Compare commits
5 Commits
demo
...
6d55061e86
Author | SHA1 | Date | |
---|---|---|---|
6d55061e86 | |||
c3010a6737 | |||
6b9482d2be | |||
b4e592b406 | |||
f73cd17dfc |
439
.kiro/specs/multi-modal-trading-system/design.md
Normal file
439
.kiro/specs/multi-modal-trading-system/design.md
Normal file
@ -0,0 +1,439 @@
|
||||
# Multi-Modal Trading System Design Document
|
||||
|
||||
## Overview
|
||||
|
||||
The Multi-Modal Trading System is designed as an advanced algorithmic trading platform that combines Convolutional Neural Networks (CNN) and Reinforcement Learning (RL) models orchestrated by a decision-making module. The system processes multi-timeframe and multi-symbol market data (primarily ETH and BTC) to generate trading actions.
|
||||
|
||||
This design document outlines the architecture, components, data flow, and implementation details for the system based on the requirements and existing codebase.
|
||||
|
||||
## Architecture
|
||||
|
||||
The system follows a modular architecture with clear separation of concerns:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Data Provider] --> B[Data Processor]
|
||||
B --> C[CNN Model]
|
||||
B --> D[RL Model]
|
||||
C --> E[Orchestrator]
|
||||
D --> E
|
||||
E --> F[Trading Executor]
|
||||
E --> G[Dashboard]
|
||||
F --> G
|
||||
H[Risk Manager] --> F
|
||||
H --> G
|
||||
```
|
||||
|
||||
### Key Components
|
||||
|
||||
1. **Data Provider**: Centralized component responsible for collecting, processing, and distributing market data from multiple sources.
|
||||
2. **Data Processor**: Processes raw market data, calculates technical indicators, and identifies pivot points.
|
||||
3. **CNN Model**: Analyzes patterns in market data and predicts pivot points across multiple timeframes.
|
||||
4. **RL Model**: Learns optimal trading strategies based on market data and CNN predictions.
|
||||
5. **Orchestrator**: Makes final trading decisions based on inputs from both CNN and RL models.
|
||||
6. **Trading Executor**: Executes trading actions through brokerage APIs.
|
||||
7. **Risk Manager**: Implements risk management features like stop-loss and position sizing.
|
||||
8. **Dashboard**: Provides a user interface for monitoring and controlling the system.
|
||||
|
||||
## Components and Interfaces
|
||||
|
||||
### 1. Data Provider
|
||||
|
||||
The Data Provider is the foundation of the system, responsible for collecting, processing, and distributing market data to all other components.
|
||||
|
||||
#### Key Classes and Interfaces
|
||||
|
||||
- **DataProvider**: Central class that manages data collection, processing, and distribution.
|
||||
- **MarketTick**: Data structure for standardized market tick data.
|
||||
- **DataSubscriber**: Interface for components that subscribe to market data.
|
||||
- **PivotBounds**: Data structure for pivot-based normalization bounds.
|
||||
|
||||
#### Implementation Details
|
||||
|
||||
The DataProvider class will:
|
||||
- Collect data from multiple sources (Binance, MEXC)
|
||||
- Support multiple timeframes (1s, 1m, 1h, 1d)
|
||||
- Support multiple symbols (ETH, BTC)
|
||||
- Calculate technical indicators
|
||||
- Identify pivot points
|
||||
- Normalize data
|
||||
- Distribute data to subscribers
|
||||
|
||||
Based on the existing implementation in `core/data_provider.py`, we'll enhance it to:
|
||||
- Improve pivot point calculation using Williams Market Structure
|
||||
- Optimize data caching for better performance
|
||||
- Enhance real-time data streaming
|
||||
- Implement better error handling and fallback mechanisms
|
||||
|
||||
### 2. CNN Model
|
||||
|
||||
The CNN Model is responsible for analyzing patterns in market data and predicting pivot points across multiple timeframes.
|
||||
|
||||
#### Key Classes and Interfaces
|
||||
|
||||
- **CNNModel**: Main class for the CNN model.
|
||||
- **PivotPointPredictor**: Interface for predicting pivot points.
|
||||
- **CNNTrainer**: Class for training the CNN model.
|
||||
|
||||
#### Implementation Details
|
||||
|
||||
The CNN Model will:
|
||||
- Accept multi-timeframe and multi-symbol data as input
|
||||
- Output predicted pivot points for each timeframe (1s, 1m, 1h, 1d)
|
||||
- Provide confidence scores for each prediction
|
||||
- Make hidden layer states available for the RL model
|
||||
|
||||
Architecture:
|
||||
- Input layer: Multi-channel input for different timeframes and symbols
|
||||
- Convolutional layers: Extract patterns from time series data
|
||||
- LSTM/GRU layers: Capture temporal dependencies
|
||||
- Attention mechanism: Focus on relevant parts of the input
|
||||
- Output layer: Predict pivot points and confidence scores
|
||||
|
||||
Training:
|
||||
- Use programmatically calculated pivot points as ground truth
|
||||
- Train on historical data
|
||||
- Update model when new pivot points are detected
|
||||
- Use backpropagation to optimize weights
|
||||
|
||||
### 3. RL Model
|
||||
|
||||
The RL Model is responsible for learning optimal trading strategies based on market data and CNN predictions.
|
||||
|
||||
#### Key Classes and Interfaces
|
||||
|
||||
- **RLModel**: Main class for the RL model.
|
||||
- **TradingActionGenerator**: Interface for generating trading actions.
|
||||
- **RLTrainer**: Class for training the RL model.
|
||||
|
||||
#### Implementation Details
|
||||
|
||||
The RL Model will:
|
||||
- Accept market data, CNN predictions, and CNN hidden layer states as input
|
||||
- Output trading action recommendations (buy/sell)
|
||||
- Provide confidence scores for each action
|
||||
- Learn from past experiences to adapt to the current market environment
|
||||
|
||||
Architecture:
|
||||
- State representation: Market data, CNN predictions, CNN hidden layer states
|
||||
- Action space: Buy, Sell
|
||||
- Reward function: PnL, risk-adjusted returns
|
||||
- Policy network: Deep neural network
|
||||
- Value network: Estimate expected returns
|
||||
|
||||
Training:
|
||||
- Use reinforcement learning algorithms (DQN, PPO, A3C)
|
||||
- Train on historical data
|
||||
- Update model based on trading outcomes
|
||||
- Use experience replay to improve sample efficiency
|
||||
|
||||
### 4. Orchestrator
|
||||
|
||||
The Orchestrator is responsible for making final trading decisions based on inputs from both CNN and RL models.
|
||||
|
||||
#### Key Classes and Interfaces
|
||||
|
||||
- **Orchestrator**: Main class for the orchestrator.
|
||||
- **DecisionMaker**: Interface for making trading decisions.
|
||||
- **MoEGateway**: Mixture of Experts gateway for model integration.
|
||||
|
||||
#### Implementation Details
|
||||
|
||||
The Orchestrator will:
|
||||
- Accept inputs from both CNN and RL models
|
||||
- Output final trading actions (buy/sell)
|
||||
- Consider confidence levels of both models
|
||||
- Learn to avoid entering positions when uncertain
|
||||
- Allow for configurable thresholds for entering and exiting positions
|
||||
|
||||
Architecture:
|
||||
- Mixture of Experts (MoE) approach
|
||||
- Gating network: Determine which expert to trust
|
||||
- Expert models: CNN, RL, and potentially others
|
||||
- Decision network: Combine expert outputs
|
||||
|
||||
Training:
|
||||
- Train on historical data
|
||||
- Update model based on trading outcomes
|
||||
- Use reinforcement learning to optimize decision-making
|
||||
|
||||
### 5. Trading Executor
|
||||
|
||||
The Trading Executor is responsible for executing trading actions through brokerage APIs.
|
||||
|
||||
#### Key Classes and Interfaces
|
||||
|
||||
- **TradingExecutor**: Main class for the trading executor.
|
||||
- **BrokerageAPI**: Interface for interacting with brokerages.
|
||||
- **OrderManager**: Class for managing orders.
|
||||
|
||||
#### Implementation Details
|
||||
|
||||
The Trading Executor will:
|
||||
- Accept trading actions from the orchestrator
|
||||
- Execute orders through brokerage APIs
|
||||
- Manage order lifecycle
|
||||
- Handle errors and retries
|
||||
- Provide feedback on order execution
|
||||
|
||||
Supported brokerages:
|
||||
- MEXC
|
||||
- Binance
|
||||
- Bybit (future extension)
|
||||
|
||||
Order types:
|
||||
- Market orders
|
||||
- Limit orders
|
||||
- Stop-loss orders
|
||||
|
||||
### 6. Risk Manager
|
||||
|
||||
The Risk Manager is responsible for implementing risk management features like stop-loss and position sizing.
|
||||
|
||||
#### Key Classes and Interfaces
|
||||
|
||||
- **RiskManager**: Main class for the risk manager.
|
||||
- **StopLossManager**: Class for managing stop-loss orders.
|
||||
- **PositionSizer**: Class for determining position sizes.
|
||||
|
||||
#### Implementation Details
|
||||
|
||||
The Risk Manager will:
|
||||
- Implement configurable stop-loss functionality
|
||||
- Implement configurable position sizing based on risk parameters
|
||||
- Implement configurable maximum drawdown limits
|
||||
- Provide real-time risk metrics
|
||||
- Provide alerts for high-risk situations
|
||||
|
||||
Risk parameters:
|
||||
- Maximum position size
|
||||
- Maximum drawdown
|
||||
- Risk per trade
|
||||
- Maximum leverage
|
||||
|
||||
### 7. Dashboard
|
||||
|
||||
The Dashboard provides a user interface for monitoring and controlling the system.
|
||||
|
||||
#### Key Classes and Interfaces
|
||||
|
||||
- **Dashboard**: Main class for the dashboard.
|
||||
- **ChartManager**: Class for managing charts.
|
||||
- **ControlPanel**: Class for managing controls.
|
||||
|
||||
#### Implementation Details
|
||||
|
||||
The Dashboard will:
|
||||
- Display real-time market data for all symbols and timeframes
|
||||
- Display OHLCV charts for all timeframes
|
||||
- Display CNN pivot point predictions and confidence levels
|
||||
- Display RL and orchestrator trading actions and confidence levels
|
||||
- Display system status and model performance metrics
|
||||
- Provide start/stop toggles for all system processes
|
||||
- Provide sliders to adjust buy/sell thresholds for the orchestrator
|
||||
|
||||
Implementation:
|
||||
- Web-based dashboard using Flask/Dash
|
||||
- Real-time updates using WebSockets
|
||||
- Interactive charts using Plotly
|
||||
- Server-side processing for all models
|
||||
|
||||
## Data Models
|
||||
|
||||
### Market Data
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class MarketTick:
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
price: float
|
||||
volume: float
|
||||
quantity: float
|
||||
side: str # 'buy' or 'sell'
|
||||
trade_id: str
|
||||
is_buyer_maker: bool
|
||||
raw_data: Dict[str, Any] = field(default_factory=dict)
|
||||
```
|
||||
|
||||
### OHLCV Data
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class OHLCVBar:
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
open: float
|
||||
high: float
|
||||
low: float
|
||||
close: float
|
||||
volume: float
|
||||
timeframe: str
|
||||
indicators: Dict[str, float] = field(default_factory=dict)
|
||||
```
|
||||
|
||||
### Pivot Points
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class PivotPoint:
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
price: float
|
||||
type: str # 'high' or 'low'
|
||||
level: int # Pivot level (1, 2, 3, etc.)
|
||||
confidence: float = 1.0
|
||||
```
|
||||
|
||||
### Trading Actions
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class TradingAction:
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
action: str # 'buy' or 'sell'
|
||||
confidence: float
|
||||
source: str # 'rl', 'cnn', 'orchestrator'
|
||||
price: Optional[float] = None
|
||||
quantity: Optional[float] = None
|
||||
reason: Optional[str] = None
|
||||
```
|
||||
|
||||
### Model Predictions
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CNNPrediction:
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
pivot_points: List[PivotPoint]
|
||||
hidden_states: Dict[str, Any]
|
||||
confidence: float
|
||||
```
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class RLPrediction:
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
action: str # 'buy' or 'sell'
|
||||
confidence: float
|
||||
expected_reward: float
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Data Collection Errors
|
||||
|
||||
- Implement retry mechanisms for API failures
|
||||
- Use fallback data sources when primary sources are unavailable
|
||||
- Log all errors with detailed information
|
||||
- Notify users through the dashboard
|
||||
|
||||
### Model Errors
|
||||
|
||||
- Implement model validation before deployment
|
||||
- Use fallback models when primary models fail
|
||||
- Log all errors with detailed information
|
||||
- Notify users through the dashboard
|
||||
|
||||
### Trading Errors
|
||||
|
||||
- Implement order validation before submission
|
||||
- Use retry mechanisms for order failures
|
||||
- Implement circuit breakers for extreme market conditions
|
||||
- Log all errors with detailed information
|
||||
- Notify users through the dashboard
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Unit Testing
|
||||
|
||||
- Test individual components in isolation
|
||||
- Use mock objects for dependencies
|
||||
- Focus on edge cases and error handling
|
||||
|
||||
### Integration Testing
|
||||
|
||||
- Test interactions between components
|
||||
- Use real data for testing
|
||||
- Focus on data flow and error propagation
|
||||
|
||||
### System Testing
|
||||
|
||||
- Test the entire system end-to-end
|
||||
- Use real data for testing
|
||||
- Focus on performance and reliability
|
||||
|
||||
### Backtesting
|
||||
|
||||
- Test trading strategies on historical data
|
||||
- Measure performance metrics (PnL, Sharpe ratio, etc.)
|
||||
- Compare against benchmarks
|
||||
|
||||
### Live Testing
|
||||
|
||||
- Test the system in a live environment with small position sizes
|
||||
- Monitor performance and stability
|
||||
- Gradually increase position sizes as confidence grows
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
The implementation will follow a phased approach:
|
||||
|
||||
1. **Phase 1: Data Provider**
|
||||
- Implement the enhanced data provider
|
||||
- Implement pivot point calculation
|
||||
- Implement technical indicator calculation
|
||||
- Implement data normalization
|
||||
|
||||
2. **Phase 2: CNN Model**
|
||||
- Implement the CNN model architecture
|
||||
- Implement the training pipeline
|
||||
- Implement the inference pipeline
|
||||
- Implement the pivot point prediction
|
||||
|
||||
3. **Phase 3: RL Model**
|
||||
- Implement the RL model architecture
|
||||
- Implement the training pipeline
|
||||
- Implement the inference pipeline
|
||||
- Implement the trading action generation
|
||||
|
||||
4. **Phase 4: Orchestrator**
|
||||
- Implement the orchestrator architecture
|
||||
- Implement the decision-making logic
|
||||
- Implement the MoE gateway
|
||||
- Implement the confidence-based filtering
|
||||
|
||||
5. **Phase 5: Trading Executor**
|
||||
- Implement the trading executor
|
||||
- Implement the brokerage API integrations
|
||||
- Implement the order management
|
||||
- Implement the error handling
|
||||
|
||||
6. **Phase 6: Risk Manager**
|
||||
- Implement the risk manager
|
||||
- Implement the stop-loss functionality
|
||||
- Implement the position sizing
|
||||
- Implement the risk metrics
|
||||
|
||||
7. **Phase 7: Dashboard**
|
||||
- Implement the dashboard UI
|
||||
- Implement the chart management
|
||||
- Implement the control panel
|
||||
- Implement the real-time updates
|
||||
|
||||
8. **Phase 8: Integration and Testing**
|
||||
- Integrate all components
|
||||
- Implement comprehensive testing
|
||||
- Fix bugs and optimize performance
|
||||
- Deploy to production
|
||||
|
||||
## Conclusion
|
||||
|
||||
This design document outlines the architecture, components, data flow, and implementation details for the Multi-Modal Trading System. The system is designed to be modular, extensible, and robust, with a focus on performance, reliability, and user experience.
|
||||
|
||||
The implementation will follow a phased approach, with each phase building on the previous one. The system will be thoroughly tested at each phase to ensure that it meets the requirements and performs as expected.
|
||||
|
||||
The final system will provide traders with a powerful tool for analyzing market data, identifying trading opportunities, and executing trades with confidence.
|
133
.kiro/specs/multi-modal-trading-system/requirements.md
Normal file
133
.kiro/specs/multi-modal-trading-system/requirements.md
Normal file
@ -0,0 +1,133 @@
|
||||
# Requirements Document
|
||||
|
||||
## Introduction
|
||||
|
||||
The Multi-Modal Trading System is an advanced algorithmic trading platform that combines Convolutional Neural Networks (CNN) and Reinforcement Learning (RL) models orchestrated by a decision-making module. The system processes multi-timeframe and multi-symbol market data (primarily ETH and BTC) to generate trading actions. The system is designed to adapt to current market conditions through continuous learning from past experiences, with the CNN module trained on historical data to predict pivot points and the RL module optimizing trading decisions based on these predictions and market data.
|
||||
|
||||
## Requirements
|
||||
|
||||
### Requirement 1: Data Collection and Processing
|
||||
|
||||
**User Story:** As a trader, I want the system to collect and process multi-timeframe and multi-symbol market data, so that the models have comprehensive market information for making accurate trading decisions.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
0. NEVER USE GENERATED/SYNTHETIC DATA or mock implementations and UI. If somethings is not implemented yet, it should be obvious.
|
||||
1. WHEN the system starts THEN it SHALL collect and process data for both ETH and BTC symbols.
|
||||
2. WHEN collecting data THEN the system SHALL store the following for the primary symbol (ETH):
|
||||
- 300 seconds of raw tick data - price and COB snapshot for all prices +- 1% on fine reslolution buckets (1$ for ETH, 10$ for BTC)
|
||||
- 300 seconds of 1-second OHLCV data + 1s aggregated COB data
|
||||
- 300 bars of OHLCV + indicators for each timeframe (1s, 1m, 1h, 1d)
|
||||
3. WHEN collecting data THEN the system SHALL store similar data for the reference symbol (BTC).
|
||||
4. WHEN processing data THEN the system SHALL calculate standard technical indicators for all timeframes.
|
||||
5. WHEN processing data THEN the system SHALL calculate pivot points for all timeframes according to the specified methodology.
|
||||
6. WHEN new data arrives THEN the system SHALL update its data cache in real-time.
|
||||
7. IF tick data is not available THEN the system SHALL substitute with the lowest available timeframe data.
|
||||
8. WHEN normalizing data THEN the system SHALL normalize to the max and min of the highest timeframe to maintain relationships between different timeframes.
|
||||
9. data is cached for longer (let's start with double the model inputs so 600 bars) to support performing backtesting when we know the current predictions outcomes so we can generate test cases.
|
||||
10. In general all models have access to the whole data we collect in a central data provider implementation. only some are specialized. All models should also take as input the last output of evey other model (also cached in the data provider). there should be a room for adding more models in the other models data input so we can extend the system without having to loose existing models and trained W&B
|
||||
|
||||
### Requirement 2: CNN Model Implementation
|
||||
|
||||
**User Story:** As a trader, I want the system to implement a CNN model that can identify patterns and predict pivot points across multiple timeframes, so that I can anticipate market direction changes.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN the CNN model is initialized THEN it SHALL accept multi-timeframe and multi-symbol data as input.
|
||||
2. WHEN processing input data THEN the CNN model SHALL output predicted pivot points for each timeframe (1s, 1m, 1h, 1d).
|
||||
3. WHEN predicting pivot points THEN the CNN model SHALL provide both the predicted pivot point value and the timestamp when it is expected to occur.
|
||||
4. WHEN a pivot point is detected THEN the system SHALL trigger a training round for the CNN model using historical data.
|
||||
5. WHEN training the CNN model THEN the system SHALL use programmatically calculated pivot points from historical data as ground truth.
|
||||
6. WHEN outputting predictions THEN the CNN model SHALL include a confidence score for each prediction.
|
||||
7. WHEN calculating pivot points THEN the system SHALL implement both standard pivot points and the recursive Williams market structure pivot points as described.
|
||||
8. WHEN processing data THEN the CNN model SHALL make available its hidden layer states for use by the RL model.
|
||||
|
||||
### Requirement 3: RL Model Implementation
|
||||
|
||||
**User Story:** As a trader, I want the system to implement an RL model that can learn optimal trading strategies based on market data and CNN predictions, so that the system can adapt to changing market conditions.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN the RL model is initialized THEN it SHALL accept market data, CNN predictions, and CNN hidden layer states as input.
|
||||
2. WHEN processing input data THEN the RL model SHALL output trading action recommendations (buy/sell).
|
||||
3. WHEN evaluating trading actions THEN the RL model SHALL learn from past experiences to adapt to the current market environment.
|
||||
4. WHEN making decisions THEN the RL model SHALL consider the confidence levels of CNN predictions.
|
||||
5. WHEN uncertain about market direction THEN the RL model SHALL learn to avoid entering positions.
|
||||
6. WHEN training the RL model THEN the system SHALL use a reward function that incentivizes high risk/reward setups.
|
||||
7. WHEN outputting trading actions THEN the RL model SHALL provide a confidence score for each action.
|
||||
8. WHEN a trading action is executed THEN the system SHALL store the input data for future training.
|
||||
|
||||
### Requirement 4: Orchestrator Implementation
|
||||
|
||||
**User Story:** As a trader, I want the system to implement an orchestrator that can make final trading decisions based on inputs from both CNN and RL models, so that the system can make more balanced and informed trading decisions.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN the orchestrator is initialized THEN it SHALL accept inputs from both CNN and RL models.
|
||||
2. WHEN processing model inputs THEN the orchestrator SHALL output final trading actions (buy/sell).
|
||||
3. WHEN making decisions THEN the orchestrator SHALL consider the confidence levels of both CNN and RL models.
|
||||
4. WHEN uncertain about market direction THEN the orchestrator SHALL learn to avoid entering positions.
|
||||
5. WHEN implementing the orchestrator THEN the system SHALL use a Mixture of Experts (MoE) approach to allow for future model integration.
|
||||
6. WHEN outputting trading actions THEN the orchestrator SHALL provide a confidence score for each action.
|
||||
7. WHEN a trading action is executed THEN the system SHALL store the input data for future training.
|
||||
8. WHEN implementing the orchestrator THEN the system SHALL allow for configurable thresholds for entering and exiting positions.
|
||||
|
||||
### Requirement 5: Training Pipeline
|
||||
|
||||
**User Story:** As a developer, I want the system to implement a unified training pipeline for both CNN and RL models, so that the models can be trained efficiently and consistently.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN training models THEN the system SHALL use a unified data provider to prepare data for all models.
|
||||
2. WHEN a pivot point is detected THEN the system SHALL trigger a training round for the CNN model.
|
||||
3. WHEN training the CNN model THEN the system SHALL use programmatically calculated pivot points from historical data as ground truth.
|
||||
4. WHEN training the RL model THEN the system SHALL use a reward function that incentivizes high risk/reward setups.
|
||||
5. WHEN training models THEN the system SHALL run the training process on the server without requiring the dashboard to be open.
|
||||
6. WHEN training models THEN the system SHALL provide real-time feedback on training progress through the dashboard.
|
||||
7. WHEN training models THEN the system SHALL store model checkpoints for future use.
|
||||
8. WHEN training models THEN the system SHALL provide metrics on model performance.
|
||||
|
||||
### Requirement 6: Dashboard Implementation
|
||||
|
||||
**User Story:** As a trader, I want the system to implement a comprehensive dashboard that displays real-time data, model predictions, and trading actions, so that I can monitor the system's performance and make informed decisions.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN the dashboard is initialized THEN it SHALL display real-time market data for all symbols and timeframes.
|
||||
2. WHEN displaying market data THEN the dashboard SHALL show OHLCV charts for all timeframes.
|
||||
3. WHEN displaying model predictions THEN the dashboard SHALL show CNN pivot point predictions and confidence levels.
|
||||
4. WHEN displaying trading actions THEN the dashboard SHALL show RL and orchestrator trading actions and confidence levels.
|
||||
5. WHEN displaying system status THEN the dashboard SHALL show training progress and model performance metrics.
|
||||
6. WHEN implementing controls THEN the dashboard SHALL provide start/stop toggles for all system processes.
|
||||
7. WHEN implementing controls THEN the dashboard SHALL provide sliders to adjust buy/sell thresholds for the orchestrator.
|
||||
8. WHEN implementing the dashboard THEN the system SHALL ensure all processes run on the server without requiring the dashboard to be open.
|
||||
|
||||
### Requirement 7: Risk Management
|
||||
|
||||
**User Story:** As a trader, I want the system to implement risk management features, so that I can protect my capital from significant losses.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN implementing risk management THEN the system SHALL provide configurable stop-loss functionality.
|
||||
2. WHEN a stop-loss is triggered THEN the system SHALL automatically close the position.
|
||||
3. WHEN implementing risk management THEN the system SHALL provide configurable position sizing based on risk parameters.
|
||||
4. WHEN implementing risk management THEN the system SHALL provide configurable maximum drawdown limits.
|
||||
5. WHEN maximum drawdown limits are reached THEN the system SHALL automatically stop trading.
|
||||
6. WHEN implementing risk management THEN the system SHALL provide real-time risk metrics through the dashboard.
|
||||
7. WHEN implementing risk management THEN the system SHALL allow for different risk parameters for different market conditions.
|
||||
8. WHEN implementing risk management THEN the system SHALL provide alerts for high-risk situations.
|
||||
|
||||
### Requirement 8: System Architecture and Integration
|
||||
|
||||
**User Story:** As a developer, I want the system to implement a clean and modular architecture, so that the system is easy to maintain and extend.
|
||||
|
||||
#### Acceptance Criteria
|
||||
|
||||
1. WHEN implementing the system architecture THEN the system SHALL use a unified data provider to prepare data for all models.
|
||||
2. WHEN implementing the system architecture THEN the system SHALL use a modular approach to allow for easy extension.
|
||||
3. WHEN implementing the system architecture THEN the system SHALL use a clean separation of concerns between data collection, model training, and trading execution.
|
||||
4. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all models.
|
||||
5. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all data providers.
|
||||
6. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all trading executors.
|
||||
7. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all risk management components.
|
||||
8. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all dashboard components.
|
247
.kiro/specs/multi-modal-trading-system/tasks.md
Normal file
247
.kiro/specs/multi-modal-trading-system/tasks.md
Normal file
@ -0,0 +1,247 @@
|
||||
# Implementation Plan
|
||||
|
||||
## Data Provider and Processing
|
||||
|
||||
- [ ] 1. Enhance the existing DataProvider class
|
||||
|
||||
|
||||
- Extend the current implementation in core/data_provider.py
|
||||
- Ensure it supports all required timeframes (1s, 1m, 1h, 1d)
|
||||
- Implement better error handling and fallback mechanisms
|
||||
- _Requirements: 1.1, 1.2, 1.3, 1.6_
|
||||
|
||||
- [ ] 1.1. Implement Williams Market Structure pivot point calculation
|
||||
- Create a dedicated method for identifying pivot points
|
||||
- Implement the recursive pivot point calculation as described
|
||||
- Add unit tests to verify pivot point detection accuracy
|
||||
- _Requirements: 1.5, 2.7_
|
||||
|
||||
- [ ] 1.2. Optimize data caching for better performance
|
||||
- Implement efficient caching strategies for different timeframes
|
||||
- Add cache invalidation mechanisms
|
||||
- Ensure thread safety for cache access
|
||||
- _Requirements: 1.6, 8.1_
|
||||
|
||||
- [ ] 1.3. Enhance real-time data streaming
|
||||
- Improve WebSocket connection management
|
||||
- Implement reconnection strategies
|
||||
- Add data validation to ensure data integrity
|
||||
- _Requirements: 1.6, 8.5_
|
||||
|
||||
- [ ] 1.4. Implement data normalization
|
||||
- Normalize data based on the highest timeframe
|
||||
- Ensure relationships between different timeframes are maintained
|
||||
- Add unit tests to verify normalization correctness
|
||||
- _Requirements: 1.8, 2.1_
|
||||
|
||||
## CNN Model Implementation
|
||||
|
||||
- [ ] 2. Design and implement the CNN model architecture
|
||||
- Create a CNNModel class that accepts multi-timeframe and multi-symbol data
|
||||
- Implement the model using PyTorch or TensorFlow
|
||||
- Design the architecture with convolutional, LSTM/GRU, and attention layers
|
||||
- _Requirements: 2.1, 2.2, 2.8_
|
||||
|
||||
- [ ] 2.1. Implement pivot point prediction
|
||||
- Create a PivotPointPredictor class
|
||||
- Implement methods to predict pivot points for each timeframe
|
||||
- Add confidence score calculation for predictions
|
||||
- _Requirements: 2.2, 2.3, 2.6_
|
||||
|
||||
- [ ] 2.2. Implement CNN training pipeline
|
||||
- Create a CNNTrainer class
|
||||
- Implement methods for training the model on historical data
|
||||
- Add mechanisms to trigger training when new pivot points are detected
|
||||
- _Requirements: 2.4, 2.5, 5.2, 5.3_
|
||||
|
||||
- [ ] 2.3. Implement CNN inference pipeline
|
||||
- Create methods for real-time inference
|
||||
- Ensure hidden layer states are accessible for the RL model
|
||||
- Optimize for performance to minimize latency
|
||||
- _Requirements: 2.2, 2.6, 2.8_
|
||||
|
||||
- [ ] 2.4. Implement model evaluation and validation
|
||||
- Create methods to evaluate model performance
|
||||
- Implement metrics for prediction accuracy
|
||||
- Add validation against historical pivot points
|
||||
- _Requirements: 2.5, 5.8_
|
||||
|
||||
## RL Model Implementation
|
||||
|
||||
- [ ] 3. Design and implement the RL model architecture
|
||||
- Create an RLModel class that accepts market data and CNN outputs
|
||||
- Implement the model using PyTorch or TensorFlow
|
||||
- Design the architecture with state representation, action space, and reward function
|
||||
- _Requirements: 3.1, 3.2, 3.7_
|
||||
|
||||
- [ ] 3.1. Implement trading action generation
|
||||
- Create a TradingActionGenerator class
|
||||
- Implement methods to generate buy/sell recommendations
|
||||
- Add confidence score calculation for actions
|
||||
- _Requirements: 3.2, 3.7_
|
||||
|
||||
- [ ] 3.2. Implement RL training pipeline
|
||||
- Create an RLTrainer class
|
||||
- Implement methods for training the model on historical data
|
||||
- Add experience replay for improved sample efficiency
|
||||
- _Requirements: 3.3, 3.5, 5.4_
|
||||
|
||||
- [ ] 3.3. Implement RL inference pipeline
|
||||
- Create methods for real-time inference
|
||||
- Optimize for performance to minimize latency
|
||||
- Ensure proper handling of CNN inputs
|
||||
- _Requirements: 3.1, 3.2, 3.4_
|
||||
|
||||
- [ ] 3.4. Implement model evaluation and validation
|
||||
- Create methods to evaluate model performance
|
||||
- Implement metrics for trading performance
|
||||
- Add validation against historical trading opportunities
|
||||
- _Requirements: 3.3, 5.8_
|
||||
|
||||
## Orchestrator Implementation
|
||||
|
||||
- [ ] 4. Design and implement the orchestrator architecture
|
||||
- Create an Orchestrator class that accepts inputs from CNN and RL models
|
||||
- Implement the Mixture of Experts (MoE) approach
|
||||
- Design the architecture with gating network and decision network
|
||||
- _Requirements: 4.1, 4.2, 4.5_
|
||||
|
||||
- [ ] 4.1. Implement decision-making logic
|
||||
- Create a DecisionMaker class
|
||||
- Implement methods to make final trading decisions
|
||||
- Add confidence-based filtering
|
||||
- _Requirements: 4.2, 4.3, 4.4_
|
||||
|
||||
- [ ] 4.2. Implement MoE gateway
|
||||
- Create a MoEGateway class
|
||||
- Implement methods to determine which expert to trust
|
||||
- Add mechanisms for future model integration
|
||||
- _Requirements: 4.5, 8.2_
|
||||
|
||||
- [ ] 4.3. Implement configurable thresholds
|
||||
- Add parameters for entering and exiting positions
|
||||
- Implement methods to adjust thresholds dynamically
|
||||
- Add validation to ensure thresholds are within reasonable ranges
|
||||
- _Requirements: 4.8, 6.7_
|
||||
|
||||
- [ ] 4.4. Implement model evaluation and validation
|
||||
- Create methods to evaluate orchestrator performance
|
||||
- Implement metrics for decision quality
|
||||
- Add validation against historical trading decisions
|
||||
- _Requirements: 4.6, 5.8_
|
||||
|
||||
## Trading Executor Implementation
|
||||
|
||||
- [ ] 5. Design and implement the trading executor
|
||||
- Create a TradingExecutor class that accepts trading actions from the orchestrator
|
||||
- Implement order execution through brokerage APIs
|
||||
- Add order lifecycle management
|
||||
- _Requirements: 7.1, 7.2, 8.6_
|
||||
|
||||
- [ ] 5.1. Implement brokerage API integrations
|
||||
- Create a BrokerageAPI interface
|
||||
- Implement concrete classes for MEXC and Binance
|
||||
- Add error handling and retry mechanisms
|
||||
- _Requirements: 7.1, 7.2, 8.6_
|
||||
|
||||
- [ ] 5.2. Implement order management
|
||||
- Create an OrderManager class
|
||||
- Implement methods for creating, updating, and canceling orders
|
||||
- Add order tracking and status updates
|
||||
- _Requirements: 7.1, 7.2, 8.6_
|
||||
|
||||
- [ ] 5.3. Implement error handling
|
||||
- Add comprehensive error handling for API failures
|
||||
- Implement circuit breakers for extreme market conditions
|
||||
- Add logging and notification mechanisms
|
||||
- _Requirements: 7.1, 7.2, 8.6_
|
||||
|
||||
## Risk Manager Implementation
|
||||
|
||||
- [ ] 6. Design and implement the risk manager
|
||||
- Create a RiskManager class
|
||||
- Implement risk parameter management
|
||||
- Add risk metric calculation
|
||||
- _Requirements: 7.1, 7.3, 7.4_
|
||||
|
||||
- [ ] 6.1. Implement stop-loss functionality
|
||||
- Create a StopLossManager class
|
||||
- Implement methods for creating and managing stop-loss orders
|
||||
- Add mechanisms to automatically close positions when stop-loss is triggered
|
||||
- _Requirements: 7.1, 7.2_
|
||||
|
||||
- [ ] 6.2. Implement position sizing
|
||||
- Create a PositionSizer class
|
||||
- Implement methods for calculating position sizes based on risk parameters
|
||||
- Add validation to ensure position sizes are within limits
|
||||
- _Requirements: 7.3, 7.7_
|
||||
|
||||
- [ ] 6.3. Implement risk metrics
|
||||
- Add methods to calculate risk metrics (drawdown, VaR, etc.)
|
||||
- Implement real-time risk monitoring
|
||||
- Add alerts for high-risk situations
|
||||
- _Requirements: 7.4, 7.5, 7.6, 7.8_
|
||||
|
||||
## Dashboard Implementation
|
||||
|
||||
- [ ] 7. Design and implement the dashboard UI
|
||||
- Create a Dashboard class
|
||||
- Implement the web-based UI using Flask/Dash
|
||||
- Add real-time updates using WebSockets
|
||||
- _Requirements: 6.1, 6.8_
|
||||
|
||||
- [ ] 7.1. Implement chart management
|
||||
- Create a ChartManager class
|
||||
- Implement methods for creating and updating charts
|
||||
- Add interactive features (zoom, pan, etc.)
|
||||
- _Requirements: 6.1, 6.2_
|
||||
|
||||
- [ ] 7.2. Implement control panel
|
||||
- Create a ControlPanel class
|
||||
- Implement start/stop toggles for system processes
|
||||
- Add sliders for adjusting buy/sell thresholds
|
||||
- _Requirements: 6.6, 6.7_
|
||||
|
||||
- [ ] 7.3. Implement system status display
|
||||
- Add methods to display training progress
|
||||
- Implement model performance metrics visualization
|
||||
- Add real-time system status updates
|
||||
- _Requirements: 6.5, 5.6_
|
||||
|
||||
- [ ] 7.4. Implement server-side processing
|
||||
- Ensure all processes run on the server without requiring the dashboard to be open
|
||||
- Implement background tasks for model training and inference
|
||||
- Add mechanisms to persist system state
|
||||
- _Requirements: 6.8, 5.5_
|
||||
|
||||
## Integration and Testing
|
||||
|
||||
- [ ] 8. Integrate all components
|
||||
- Connect the data provider to the CNN and RL models
|
||||
- Connect the CNN and RL models to the orchestrator
|
||||
- Connect the orchestrator to the trading executor
|
||||
- _Requirements: 8.1, 8.2, 8.3_
|
||||
|
||||
- [ ] 8.1. Implement comprehensive unit tests
|
||||
- Create unit tests for each component
|
||||
- Implement test fixtures and mocks
|
||||
- Add test coverage reporting
|
||||
- _Requirements: 8.1, 8.2, 8.3_
|
||||
|
||||
- [ ] 8.2. Implement integration tests
|
||||
- Create tests for component interactions
|
||||
- Implement end-to-end tests
|
||||
- Add performance benchmarks
|
||||
- _Requirements: 8.1, 8.2, 8.3_
|
||||
|
||||
- [ ] 8.3. Implement backtesting framework
|
||||
- Create a backtesting environment
|
||||
- Implement methods to replay historical data
|
||||
- Add performance metrics calculation
|
||||
- _Requirements: 5.8, 8.1_
|
||||
|
||||
- [ ] 8.4. Optimize performance
|
||||
- Profile the system to identify bottlenecks
|
||||
- Implement optimizations for critical paths
|
||||
- Add caching and parallelization where appropriate
|
||||
- _Requirements: 8.1, 8.2, 8.3_
|
0
audit_training_system.py
Normal file
0
audit_training_system.py
Normal file
@ -34,6 +34,7 @@ from collections import deque
|
||||
from .config import get_config
|
||||
from .tick_aggregator import RealTimeTickAggregator, RawTick, OHLCVBar
|
||||
from .cnn_monitor import log_cnn_prediction
|
||||
from .williams_market_structure import WilliamsMarketStructure, PivotPoint, TrendLevel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -182,6 +183,16 @@ class DataProvider:
|
||||
'1h': 3600, '4h': 14400, '1d': 86400
|
||||
}
|
||||
|
||||
# Williams Market Structure integration
|
||||
self.williams_structure: Dict[str, WilliamsMarketStructure] = {}
|
||||
for symbol in self.symbols:
|
||||
self.williams_structure[symbol] = WilliamsMarketStructure(min_pivot_distance=3)
|
||||
|
||||
# Pivot point caching
|
||||
self.pivot_points_cache: Dict[str, Dict[int, TrendLevel]] = {} # {symbol: {level: TrendLevel}}
|
||||
self.last_pivot_calculation: Dict[str, datetime] = {}
|
||||
self.pivot_calculation_interval = timedelta(minutes=5) # Recalculate every 5 minutes
|
||||
|
||||
# Load existing pivot bounds from cache
|
||||
self._load_all_pivot_bounds()
|
||||
|
||||
@ -189,6 +200,7 @@ class DataProvider:
|
||||
logger.info(f"Timeframes: {self.timeframes}")
|
||||
logger.info("Centralized data distribution enabled")
|
||||
logger.info("Pivot-based normalization system enabled")
|
||||
logger.info("Williams Market Structure integration enabled")
|
||||
|
||||
# Rate limiting
|
||||
self.last_request_time = {}
|
||||
@ -1613,6 +1625,151 @@ class DataProvider:
|
||||
logger.error(f"Error getting current price for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def calculate_williams_pivot_points(self, symbol: str, force_recalculate: bool = False) -> Dict[int, TrendLevel]:
|
||||
"""
|
||||
Calculate Williams Market Structure pivot points for a symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH/USDT')
|
||||
force_recalculate: Force recalculation even if cache is fresh
|
||||
|
||||
Returns:
|
||||
Dictionary of trend levels with pivot points
|
||||
"""
|
||||
try:
|
||||
# Check if we need to recalculate
|
||||
now = datetime.now()
|
||||
if (not force_recalculate and
|
||||
symbol in self.last_pivot_calculation and
|
||||
now - self.last_pivot_calculation[symbol] < self.pivot_calculation_interval):
|
||||
# Return cached results
|
||||
return self.pivot_points_cache.get(symbol, {})
|
||||
|
||||
# Get 1s OHLCV data for Williams Market Structure calculation
|
||||
df_1s = self.get_historical_data(symbol, '1s', limit=1000)
|
||||
if df_1s is None or len(df_1s) < 50:
|
||||
logger.warning(f"Insufficient 1s data for Williams pivot calculation: {symbol}")
|
||||
return {}
|
||||
|
||||
# Convert DataFrame to numpy array for Williams calculation
|
||||
# Format: [timestamp_ms, open, high, low, close, volume]
|
||||
ohlcv_array = np.column_stack([
|
||||
df_1s.index.astype(np.int64) // 10**6, # Convert to milliseconds
|
||||
df_1s['open'].values,
|
||||
df_1s['high'].values,
|
||||
df_1s['low'].values,
|
||||
df_1s['close'].values,
|
||||
df_1s['volume'].values
|
||||
])
|
||||
|
||||
# Calculate recursive pivot points using Williams Market Structure
|
||||
williams = self.williams_structure[symbol]
|
||||
pivot_levels = williams.calculate_recursive_pivot_points(ohlcv_array)
|
||||
|
||||
# Cache the results
|
||||
self.pivot_points_cache[symbol] = pivot_levels
|
||||
self.last_pivot_calculation[symbol] = now
|
||||
|
||||
logger.debug(f"Calculated Williams pivot points for {symbol}: {len(pivot_levels)} levels")
|
||||
return pivot_levels
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating Williams pivot points for {symbol}: {e}")
|
||||
return {}
|
||||
|
||||
def get_pivot_features_for_ml(self, symbol: str) -> np.ndarray:
|
||||
"""
|
||||
Get pivot point features for machine learning models
|
||||
|
||||
Returns a 250-element feature vector containing:
|
||||
- Recent pivot points (price, strength, type) for each level
|
||||
- Trend direction and strength for each level
|
||||
- Time since last pivot for each level
|
||||
"""
|
||||
try:
|
||||
# Ensure we have fresh pivot points
|
||||
pivot_levels = self.calculate_williams_pivot_points(symbol)
|
||||
|
||||
if not pivot_levels:
|
||||
logger.warning(f"No pivot points available for {symbol}")
|
||||
return np.zeros(250, dtype=np.float32)
|
||||
|
||||
# Use Williams Market Structure to extract ML features
|
||||
williams = self.williams_structure[symbol]
|
||||
features = williams.get_pivot_features_for_ml(symbol)
|
||||
|
||||
return features
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pivot features for ML: {e}")
|
||||
return np.zeros(250, dtype=np.float32)
|
||||
|
||||
def get_market_structure_summary(self, symbol: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current market structure summary for dashboard display
|
||||
|
||||
Returns:
|
||||
Dictionary containing market structure information
|
||||
"""
|
||||
try:
|
||||
# Ensure we have fresh pivot points
|
||||
pivot_levels = self.calculate_williams_pivot_points(symbol)
|
||||
|
||||
if not pivot_levels:
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'levels': {},
|
||||
'overall_trend': 'sideways',
|
||||
'overall_strength': 0.0,
|
||||
'last_update': datetime.now().isoformat(),
|
||||
'error': 'No pivot points available'
|
||||
}
|
||||
|
||||
# Use Williams Market Structure to get summary
|
||||
williams = self.williams_structure[symbol]
|
||||
structure = williams.get_current_market_structure()
|
||||
structure['symbol'] = symbol
|
||||
|
||||
return structure
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting market structure summary for {symbol}: {e}")
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'levels': {},
|
||||
'overall_trend': 'sideways',
|
||||
'overall_strength': 0.0,
|
||||
'last_update': datetime.now().isoformat(),
|
||||
'error': str(e)
|
||||
}
|
||||
|
||||
def get_recent_pivot_points(self, symbol: str, level: int = 1, count: int = 10) -> List[PivotPoint]:
|
||||
"""
|
||||
Get recent pivot points for a specific level
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
level: Pivot level (1-5)
|
||||
count: Number of recent pivots to return
|
||||
|
||||
Returns:
|
||||
List of recent pivot points
|
||||
"""
|
||||
try:
|
||||
pivot_levels = self.calculate_williams_pivot_points(symbol)
|
||||
|
||||
if level not in pivot_levels:
|
||||
return []
|
||||
|
||||
trend_level = pivot_levels[level]
|
||||
recent_pivots = trend_level.pivot_points[-count:] if len(trend_level.pivot_points) >= count else trend_level.pivot_points
|
||||
|
||||
return recent_pivots
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting recent pivot points for {symbol} level {level}: {e}")
|
||||
return []
|
||||
|
||||
def get_price_at_index(self, symbol: str, index: int, timeframe: str = '1m') -> Optional[float]:
|
||||
"""Get price at specific index for backtesting"""
|
||||
try:
|
||||
|
@ -136,6 +136,11 @@ class TradingOrchestrator:
|
||||
self.recent_decisions: Dict[str, List[TradingDecision]] = {} # {symbol: List[TradingDecision]}
|
||||
self.model_performance: Dict[str, Dict[str, Any]] = {} # {model_name: {'correct': int, 'total': int, 'accuracy': float}}
|
||||
|
||||
# Signal rate limiting to prevent spam
|
||||
self.last_signal_time: Dict[str, Dict[str, datetime]] = {} # {symbol: {action: datetime}}
|
||||
self.min_signal_interval = timedelta(seconds=30) # Minimum 30 seconds between same signals
|
||||
self.last_confirmed_signal: Dict[str, Dict[str, Any]] = {} # {symbol: {action, timestamp, confidence}}
|
||||
|
||||
# Signal accumulation for trend confirmation
|
||||
self.signal_accumulator: Dict[str, List[Dict]] = {} # {symbol: List[signal_data]}
|
||||
self.required_confirmations = 3 # Number of consistent signals needed
|
||||
@ -871,6 +876,22 @@ class TradingOrchestrator:
|
||||
'CNN': self.config.orchestrator.get('cnn_weight', 0.7),
|
||||
'RL': self.config.orchestrator.get('rl_weight', 0.3)
|
||||
}
|
||||
|
||||
# Add weights for specific models if they exist
|
||||
if hasattr(self, 'cnn_model') and self.cnn_model:
|
||||
self.model_weights["enhanced_cnn"] = 0.4
|
||||
|
||||
# Only add DQN agent weight if it exists
|
||||
if hasattr(self, 'rl_agent') and self.rl_agent:
|
||||
self.model_weights["dqn_agent"] = 0.3
|
||||
|
||||
# Add COB RL model weight if it exists
|
||||
if hasattr(self, 'cob_rl_agent') and self.cob_rl_agent:
|
||||
self.model_weights["cob_rl_model"] = 0.2
|
||||
|
||||
# Add extrema trainer weight if it exists
|
||||
if hasattr(self, 'extrema_trainer') and self.extrema_trainer:
|
||||
self.model_weights["extrema_trainer"] = 0.15
|
||||
|
||||
def register_model(self, model: ModelInterface, weight: Optional[float] = None) -> bool:
|
||||
"""Register a new model with the orchestrator"""
|
||||
@ -1960,10 +1981,27 @@ class TradingOrchestrator:
|
||||
logger.info("Trading executor set for position tracking and P&L feedback")
|
||||
|
||||
def _check_signal_confirmation(self, symbol: str, signal_data: Dict) -> Optional[str]:
|
||||
"""Check if we have enough signal confirmations for trend confirmation"""
|
||||
"""Check if we have enough signal confirmations for trend confirmation with rate limiting"""
|
||||
try:
|
||||
# Clean up expired signals
|
||||
current_time = signal_data['timestamp']
|
||||
action = signal_data['action']
|
||||
|
||||
# Initialize signal tracking for this symbol if needed
|
||||
if symbol not in self.last_signal_time:
|
||||
self.last_signal_time[symbol] = {}
|
||||
if symbol not in self.last_confirmed_signal:
|
||||
self.last_confirmed_signal[symbol] = {}
|
||||
|
||||
# RATE LIMITING: Check if we recently confirmed the same signal
|
||||
if action in self.last_confirmed_signal[symbol]:
|
||||
last_confirmed = self.last_confirmed_signal[symbol][action]
|
||||
time_since_last = current_time - last_confirmed['timestamp']
|
||||
if time_since_last < self.min_signal_interval:
|
||||
logger.debug(f"Rate limiting: {action} signal for {symbol} too recent "
|
||||
f"({time_since_last.total_seconds():.1f}s < {self.min_signal_interval.total_seconds()}s)")
|
||||
return None
|
||||
|
||||
# Clean up expired signals
|
||||
self.signal_accumulator[symbol] = [
|
||||
s for s in self.signal_accumulator[symbol]
|
||||
if (current_time - s['timestamp']).total_seconds() < self.signal_timeout_seconds
|
||||
@ -1982,8 +2020,8 @@ class TradingOrchestrator:
|
||||
|
||||
# Count action consensus
|
||||
action_counts = {}
|
||||
for action in actions:
|
||||
action_counts[action] = action_counts.get(action, 0) + 1
|
||||
for action_item in actions:
|
||||
action_counts[action_item] = action_counts.get(action_item, 0) + 1
|
||||
|
||||
# Find dominant action
|
||||
dominant_action = max(action_counts, key=action_counts.get)
|
||||
@ -1991,8 +2029,24 @@ class TradingOrchestrator:
|
||||
|
||||
# Require at least 2/3 consensus
|
||||
if consensus_count >= max(2, self.required_confirmations * 0.67):
|
||||
# ADDITIONAL RATE LIMITING: Don't confirm if we just confirmed the same action
|
||||
if dominant_action in self.last_confirmed_signal[symbol]:
|
||||
last_confirmed = self.last_confirmed_signal[symbol][dominant_action]
|
||||
time_since_last = current_time - last_confirmed['timestamp']
|
||||
if time_since_last < self.min_signal_interval:
|
||||
logger.debug(f"Rate limiting: Preventing duplicate {dominant_action} confirmation for {symbol}")
|
||||
return None
|
||||
|
||||
# Record this confirmation
|
||||
self.last_confirmed_signal[symbol][dominant_action] = {
|
||||
'timestamp': current_time,
|
||||
'confidence': signal_data['confidence']
|
||||
}
|
||||
|
||||
# Clear accumulator after confirmation
|
||||
self.signal_accumulator[symbol] = []
|
||||
|
||||
logger.info(f"Signal confirmed after rate limiting: {dominant_action} for {symbol}")
|
||||
return dominant_action
|
||||
|
||||
return None
|
||||
|
710
core/overnight_training_coordinator.py
Normal file
710
core/overnight_training_coordinator.py
Normal file
@ -0,0 +1,710 @@
|
||||
"""
|
||||
Overnight Training Coordinator
|
||||
|
||||
This module coordinates comprehensive training for CNN and COB RL models during overnight sessions.
|
||||
It ensures that:
|
||||
1. Training passes occur on each signal when predictions change
|
||||
2. Trades are executed and recorded in simulation mode
|
||||
3. Performance statistics are tracked and logged
|
||||
4. Models learn from both successful and unsuccessful trades
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from collections import deque
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class TrainingSession:
|
||||
"""Represents a training session for a model"""
|
||||
model_name: str
|
||||
symbol: str
|
||||
start_time: datetime
|
||||
end_time: Optional[datetime] = None
|
||||
training_samples: int = 0
|
||||
initial_loss: Optional[float] = None
|
||||
final_loss: Optional[float] = None
|
||||
improvement: Optional[float] = None
|
||||
trades_executed: int = 0
|
||||
successful_trades: int = 0
|
||||
total_pnl: float = 0.0
|
||||
|
||||
@dataclass
|
||||
class SignalTradeRecord:
|
||||
"""Records a signal and its corresponding trade execution"""
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
signal_action: str
|
||||
signal_confidence: float
|
||||
model_source: str
|
||||
executed: bool = False
|
||||
execution_price: Optional[float] = None
|
||||
trade_pnl: Optional[float] = None
|
||||
training_triggered: bool = False
|
||||
training_loss: Optional[float] = None
|
||||
|
||||
class OvernightTrainingCoordinator:
|
||||
"""
|
||||
Coordinates comprehensive overnight training for all models
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator, data_provider, trading_executor, dashboard=None):
|
||||
self.orchestrator = orchestrator
|
||||
self.data_provider = data_provider
|
||||
self.trading_executor = trading_executor
|
||||
self.dashboard = dashboard
|
||||
|
||||
# Training configuration
|
||||
self.config = {
|
||||
'training_on_signal_change': True, # Train when prediction changes
|
||||
'min_confidence_for_trade': 0.3, # Minimum confidence to execute trade
|
||||
'max_trades_per_hour': 20, # Rate limiting
|
||||
'training_batch_size': 32, # Training batch size
|
||||
'performance_tracking_window': 100, # Number of trades to track for performance
|
||||
'model_checkpoint_interval': 50, # Save checkpoints every N trades
|
||||
}
|
||||
|
||||
# State tracking
|
||||
self.is_running = False
|
||||
self.training_thread = None
|
||||
self.last_predictions: Dict[str, Dict[str, Any]] = {} # {symbol: {model: prediction}}
|
||||
self.signal_trade_records: deque = deque(maxlen=1000)
|
||||
self.training_sessions: Dict[str, TrainingSession] = {}
|
||||
|
||||
# Performance tracking
|
||||
self.performance_stats = {
|
||||
'total_signals': 0,
|
||||
'total_trades': 0,
|
||||
'successful_trades': 0,
|
||||
'total_pnl': 0.0,
|
||||
'training_sessions': 0,
|
||||
'models_trained': set(),
|
||||
'hourly_stats': deque(maxlen=24) # Last 24 hours
|
||||
}
|
||||
|
||||
# Rate limiting
|
||||
self.last_trade_time: Dict[str, datetime] = {}
|
||||
self.trades_this_hour: Dict[str, int] = {}
|
||||
self.hour_reset_time = datetime.now().replace(minute=0, second=0, microsecond=0)
|
||||
|
||||
logger.info("Overnight Training Coordinator initialized")
|
||||
|
||||
def start_overnight_training(self):
|
||||
"""Start the overnight training session"""
|
||||
if self.is_running:
|
||||
logger.warning("Training coordinator already running")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.training_thread = threading.Thread(target=self._training_loop, daemon=True)
|
||||
self.training_thread.start()
|
||||
|
||||
logger.info("🌙 OVERNIGHT TRAINING SESSION STARTED")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Features enabled:")
|
||||
logger.info("✅ CNN training on signal changes")
|
||||
logger.info("✅ COB RL training on market microstructure")
|
||||
logger.info("✅ Trade execution and recording")
|
||||
logger.info("✅ Performance tracking and statistics")
|
||||
logger.info("✅ Model checkpointing")
|
||||
logger.info("=" * 60)
|
||||
|
||||
def stop_overnight_training(self):
|
||||
"""Stop the overnight training session"""
|
||||
self.is_running = False
|
||||
if self.training_thread:
|
||||
self.training_thread.join(timeout=10)
|
||||
|
||||
# Generate final report
|
||||
self._generate_training_report()
|
||||
|
||||
logger.info("🌅 OVERNIGHT TRAINING SESSION COMPLETED")
|
||||
|
||||
def _training_loop(self):
|
||||
"""Main training loop that monitors signals and triggers training"""
|
||||
while self.is_running:
|
||||
try:
|
||||
# Reset hourly counters if needed
|
||||
self._reset_hourly_counters()
|
||||
|
||||
# Process signals from orchestrator
|
||||
self._process_orchestrator_signals()
|
||||
|
||||
# Check for model training opportunities
|
||||
self._check_training_opportunities()
|
||||
|
||||
# Update performance statistics
|
||||
self._update_performance_stats()
|
||||
|
||||
# Sleep briefly to avoid overwhelming the system
|
||||
time.sleep(0.5)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in training loop: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
def _process_orchestrator_signals(self):
|
||||
"""Process signals from the orchestrator and trigger training/trading"""
|
||||
try:
|
||||
# Get recent decisions from orchestrator
|
||||
if not hasattr(self.orchestrator, 'recent_decisions'):
|
||||
return
|
||||
|
||||
for symbol in self.orchestrator.symbols:
|
||||
if symbol not in self.orchestrator.recent_decisions:
|
||||
continue
|
||||
|
||||
recent_decisions = self.orchestrator.recent_decisions[symbol]
|
||||
if not recent_decisions:
|
||||
continue
|
||||
|
||||
# Get the latest decision
|
||||
latest_decision = recent_decisions[-1]
|
||||
|
||||
# Check if this is a new signal that requires processing
|
||||
if self._is_new_signal_requiring_action(symbol, latest_decision):
|
||||
self._process_new_signal(symbol, latest_decision)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing orchestrator signals: {e}")
|
||||
|
||||
def _is_new_signal_requiring_action(self, symbol: str, decision) -> bool:
|
||||
"""Check if this signal requires training or trading action"""
|
||||
try:
|
||||
# Get current prediction for comparison
|
||||
current_action = decision.action
|
||||
current_confidence = decision.confidence
|
||||
current_time = decision.timestamp
|
||||
|
||||
# Check if we have a previous prediction for this symbol
|
||||
if symbol not in self.last_predictions:
|
||||
self.last_predictions[symbol] = {}
|
||||
|
||||
# Check if prediction has changed significantly
|
||||
last_action = self.last_predictions[symbol].get('action')
|
||||
last_confidence = self.last_predictions[symbol].get('confidence', 0.0)
|
||||
last_time = self.last_predictions[symbol].get('timestamp')
|
||||
|
||||
# Determine if action is required
|
||||
action_changed = last_action != current_action
|
||||
confidence_changed = abs(current_confidence - last_confidence) > 0.1
|
||||
time_elapsed = not last_time or (current_time - last_time).total_seconds() > 30
|
||||
|
||||
# Update last prediction
|
||||
self.last_predictions[symbol] = {
|
||||
'action': current_action,
|
||||
'confidence': current_confidence,
|
||||
'timestamp': current_time
|
||||
}
|
||||
|
||||
return action_changed or confidence_changed or time_elapsed
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking if signal requires action: {e}")
|
||||
return False
|
||||
|
||||
def _process_new_signal(self, symbol: str, decision):
|
||||
"""Process a new signal by triggering training and potentially executing trade"""
|
||||
try:
|
||||
signal_record = SignalTradeRecord(
|
||||
timestamp=decision.timestamp,
|
||||
symbol=symbol,
|
||||
signal_action=decision.action,
|
||||
signal_confidence=decision.confidence,
|
||||
model_source=getattr(decision, 'reasoning', {}).get('primary_model', 'orchestrator')
|
||||
)
|
||||
|
||||
# 1. Trigger training on signal change
|
||||
if self.config['training_on_signal_change']:
|
||||
training_loss = self._trigger_model_training(symbol, decision)
|
||||
signal_record.training_triggered = True
|
||||
signal_record.training_loss = training_loss
|
||||
|
||||
# 2. Execute trade if confidence is sufficient
|
||||
if (decision.confidence >= self.config['min_confidence_for_trade'] and
|
||||
decision.action in ['BUY', 'SELL'] and
|
||||
self._can_execute_trade(symbol)):
|
||||
|
||||
trade_executed, execution_price, trade_pnl = self._execute_signal_trade(symbol, decision)
|
||||
signal_record.executed = trade_executed
|
||||
signal_record.execution_price = execution_price
|
||||
signal_record.trade_pnl = trade_pnl
|
||||
|
||||
# Update performance stats
|
||||
self.performance_stats['total_trades'] += 1
|
||||
if trade_pnl and trade_pnl > 0:
|
||||
self.performance_stats['successful_trades'] += 1
|
||||
if trade_pnl:
|
||||
self.performance_stats['total_pnl'] += trade_pnl
|
||||
|
||||
# 3. Record the signal
|
||||
self.signal_trade_records.append(signal_record)
|
||||
self.performance_stats['total_signals'] += 1
|
||||
|
||||
# 4. Log the action
|
||||
status = "EXECUTED" if signal_record.executed else "SIGNAL_ONLY"
|
||||
logger.info(f"[{status}] {symbol} {decision.action} "
|
||||
f"(conf: {decision.confidence:.3f}, "
|
||||
f"training: {'✅' if signal_record.training_triggered else '❌'}, "
|
||||
f"pnl: {signal_record.trade_pnl:.2f if signal_record.trade_pnl else 'N/A'})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing new signal for {symbol}: {e}")
|
||||
|
||||
def _trigger_model_training(self, symbol: str, decision) -> Optional[float]:
|
||||
"""Trigger training for all relevant models"""
|
||||
try:
|
||||
training_losses = []
|
||||
|
||||
# 1. Train CNN model
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
cnn_loss = self._train_cnn_model(symbol, decision)
|
||||
if cnn_loss is not None:
|
||||
training_losses.append(cnn_loss)
|
||||
self.performance_stats['models_trained'].add('CNN')
|
||||
|
||||
# 2. Train COB RL model
|
||||
if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
|
||||
cob_rl_loss = self._train_cob_rl_model(symbol, decision)
|
||||
if cob_rl_loss is not None:
|
||||
training_losses.append(cob_rl_loss)
|
||||
self.performance_stats['models_trained'].add('COB_RL')
|
||||
|
||||
# 3. Train DQN model
|
||||
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||
dqn_loss = self._train_dqn_model(symbol, decision)
|
||||
if dqn_loss is not None:
|
||||
training_losses.append(dqn_loss)
|
||||
self.performance_stats['models_trained'].add('DQN')
|
||||
|
||||
# Return average loss
|
||||
return np.mean(training_losses) if training_losses else None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering model training: {e}")
|
||||
return None
|
||||
|
||||
def _train_cnn_model(self, symbol: str, decision) -> Optional[float]:
|
||||
"""Train CNN model on current market data"""
|
||||
try:
|
||||
# Get market data for training
|
||||
df = self.data_provider.get_historical_data(symbol, '1m', limit=100)
|
||||
if df is None or len(df) < 50:
|
||||
return None
|
||||
|
||||
# Prepare training data
|
||||
features = self._prepare_cnn_features(df)
|
||||
target = self._prepare_cnn_target(decision)
|
||||
|
||||
if features is None or target is None:
|
||||
return None
|
||||
|
||||
# Train the model
|
||||
if hasattr(self.orchestrator.cnn_model, 'train_on_batch'):
|
||||
loss = self.orchestrator.cnn_model.train_on_batch(features, target)
|
||||
logger.debug(f"CNN training loss for {symbol}: {loss:.4f}")
|
||||
return loss
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training CNN model: {e}")
|
||||
return None
|
||||
|
||||
def _train_cob_rl_model(self, symbol: str, decision) -> Optional[float]:
|
||||
"""Train COB RL model on market microstructure data"""
|
||||
try:
|
||||
# Get COB data if available
|
||||
if not hasattr(self.dashboard, 'latest_cob_data') or symbol not in self.dashboard.latest_cob_data:
|
||||
return None
|
||||
|
||||
cob_data = self.dashboard.latest_cob_data[symbol]
|
||||
|
||||
# Prepare COB features
|
||||
features = self._prepare_cob_features(cob_data)
|
||||
reward = self._calculate_cob_reward(decision)
|
||||
|
||||
if features is None:
|
||||
return None
|
||||
|
||||
# Train the model
|
||||
if hasattr(self.orchestrator.cob_rl_agent, 'train'):
|
||||
loss = self.orchestrator.cob_rl_agent.train(features, reward)
|
||||
logger.debug(f"COB RL training loss for {symbol}: {loss:.4f}")
|
||||
return loss
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training COB RL model: {e}")
|
||||
return None
|
||||
|
||||
def _train_dqn_model(self, symbol: str, decision) -> Optional[float]:
|
||||
"""Train DQN model on trading decision"""
|
||||
try:
|
||||
# Get state features
|
||||
state_features = self._prepare_dqn_state(symbol)
|
||||
action = self._map_action_to_index(decision.action)
|
||||
reward = decision.confidence # Use confidence as immediate reward
|
||||
|
||||
if state_features is None:
|
||||
return None
|
||||
|
||||
# Add experience to replay buffer
|
||||
if hasattr(self.orchestrator.rl_agent, 'remember'):
|
||||
# We'll use a dummy next_state for now
|
||||
next_state = state_features # Simplified
|
||||
done = False
|
||||
self.orchestrator.rl_agent.remember(state_features, action, reward, next_state, done)
|
||||
|
||||
# Train if we have enough experiences
|
||||
if hasattr(self.orchestrator.rl_agent, 'replay'):
|
||||
loss = self.orchestrator.rl_agent.replay()
|
||||
if loss is not None:
|
||||
logger.debug(f"DQN training loss for {symbol}: {loss:.4f}")
|
||||
return loss
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training DQN model: {e}")
|
||||
return None
|
||||
|
||||
def _execute_signal_trade(self, symbol: str, decision) -> Tuple[bool, Optional[float], Optional[float]]:
|
||||
"""Execute a trade based on the signal"""
|
||||
try:
|
||||
if not self.trading_executor:
|
||||
return False, None, None
|
||||
|
||||
# Get current price
|
||||
current_price = self.data_provider.get_current_price(symbol)
|
||||
if not current_price:
|
||||
return False, None, None
|
||||
|
||||
# Execute the trade
|
||||
success = self.trading_executor.execute_signal(
|
||||
symbol=symbol,
|
||||
action=decision.action,
|
||||
confidence=decision.confidence,
|
||||
current_price=current_price
|
||||
)
|
||||
|
||||
if success:
|
||||
# Calculate PnL (simplified - in real implementation this would be more complex)
|
||||
trade_pnl = self._calculate_trade_pnl(symbol, decision.action, current_price)
|
||||
|
||||
# Update rate limiting
|
||||
self.last_trade_time[symbol] = datetime.now()
|
||||
if symbol not in self.trades_this_hour:
|
||||
self.trades_this_hour[symbol] = 0
|
||||
self.trades_this_hour[symbol] += 1
|
||||
|
||||
return True, current_price, trade_pnl
|
||||
|
||||
return False, None, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing signal trade: {e}")
|
||||
return False, None, None
|
||||
|
||||
def _can_execute_trade(self, symbol: str) -> bool:
|
||||
"""Check if we can execute a trade based on rate limiting"""
|
||||
try:
|
||||
# Check hourly limit
|
||||
if symbol in self.trades_this_hour:
|
||||
if self.trades_this_hour[symbol] >= self.config['max_trades_per_hour']:
|
||||
return False
|
||||
|
||||
# Check minimum time between trades (30 seconds)
|
||||
if symbol in self.last_trade_time:
|
||||
time_since_last = (datetime.now() - self.last_trade_time[symbol]).total_seconds()
|
||||
if time_since_last < 30:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking if can execute trade: {e}")
|
||||
return False
|
||||
|
||||
def _prepare_cnn_features(self, df) -> Optional[np.ndarray]:
|
||||
"""Prepare features for CNN training"""
|
||||
try:
|
||||
# Use OHLCV data as features
|
||||
features = df[['open', 'high', 'low', 'close', 'volume']].values
|
||||
|
||||
# Normalize features
|
||||
features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)
|
||||
|
||||
# Reshape for CNN (add batch and channel dimensions)
|
||||
features = features.reshape(1, features.shape[0], features.shape[1])
|
||||
|
||||
return features.astype(np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing CNN features: {e}")
|
||||
return None
|
||||
|
||||
def _prepare_cnn_target(self, decision) -> Optional[np.ndarray]:
|
||||
"""Prepare target for CNN training"""
|
||||
try:
|
||||
# Map action to target
|
||||
action_map = {'BUY': [1, 0, 0], 'SELL': [0, 1, 0], 'HOLD': [0, 0, 1]}
|
||||
target = action_map.get(decision.action, [0, 0, 1])
|
||||
|
||||
return np.array([target], dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing CNN target: {e}")
|
||||
return None
|
||||
|
||||
def _prepare_cob_features(self, cob_data) -> Optional[np.ndarray]:
|
||||
"""Prepare COB features for training"""
|
||||
try:
|
||||
# Extract key COB features
|
||||
features = []
|
||||
|
||||
# Order book imbalance
|
||||
imbalance = cob_data.get('stats', {}).get('imbalance', 0)
|
||||
features.append(imbalance)
|
||||
|
||||
# Bid/Ask liquidity
|
||||
bid_liquidity = cob_data.get('stats', {}).get('bid_liquidity', 0)
|
||||
ask_liquidity = cob_data.get('stats', {}).get('ask_liquidity', 0)
|
||||
features.extend([bid_liquidity, ask_liquidity])
|
||||
|
||||
# Spread
|
||||
spread = cob_data.get('stats', {}).get('spread_bps', 0)
|
||||
features.append(spread)
|
||||
|
||||
# Pad to expected size (2000 features for COB RL)
|
||||
while len(features) < 2000:
|
||||
features.append(0.0)
|
||||
|
||||
return np.array(features[:2000], dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing COB features: {e}")
|
||||
return None
|
||||
|
||||
def _calculate_cob_reward(self, decision) -> float:
|
||||
"""Calculate reward for COB RL training"""
|
||||
try:
|
||||
# Use confidence as base reward
|
||||
base_reward = decision.confidence
|
||||
|
||||
# Adjust based on action
|
||||
if decision.action in ['BUY', 'SELL']:
|
||||
return base_reward
|
||||
else:
|
||||
return base_reward * 0.1 # Lower reward for HOLD
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating COB reward: {e}")
|
||||
return 0.0
|
||||
|
||||
def _prepare_dqn_state(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Prepare state features for DQN training"""
|
||||
try:
|
||||
# Get market data
|
||||
df = self.data_provider.get_historical_data(symbol, '1m', limit=50)
|
||||
if df is None or len(df) < 10:
|
||||
return None
|
||||
|
||||
# Prepare basic features
|
||||
features = []
|
||||
|
||||
# Price features
|
||||
close_prices = df['close'].values
|
||||
features.extend(close_prices[-10:]) # Last 10 prices
|
||||
|
||||
# Technical indicators
|
||||
if len(close_prices) >= 20:
|
||||
sma_20 = np.mean(close_prices[-20:])
|
||||
features.append(sma_20)
|
||||
else:
|
||||
features.append(close_prices[-1])
|
||||
|
||||
# Volume features
|
||||
volumes = df['volume'].values
|
||||
features.extend(volumes[-5:]) # Last 5 volumes
|
||||
|
||||
# Pad to expected size (100 features for DQN)
|
||||
while len(features) < 100:
|
||||
features.append(0.0)
|
||||
|
||||
return np.array(features[:100], dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing DQN state: {e}")
|
||||
return None
|
||||
|
||||
def _map_action_to_index(self, action: str) -> int:
|
||||
"""Map action string to index"""
|
||||
action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
|
||||
return action_map.get(action, 2)
|
||||
|
||||
def _calculate_trade_pnl(self, symbol: str, action: str, price: float) -> float:
|
||||
"""Calculate simplified PnL for a trade"""
|
||||
try:
|
||||
# This is a simplified PnL calculation
|
||||
# In a real implementation, this would track actual position changes
|
||||
|
||||
# Get previous price for comparison
|
||||
df = self.data_provider.get_historical_data(symbol, '1m', limit=2)
|
||||
if df is None or len(df) < 2:
|
||||
return 0.0
|
||||
|
||||
prev_price = df['close'].iloc[-2]
|
||||
current_price = price
|
||||
|
||||
# Calculate price change
|
||||
price_change = (current_price - prev_price) / prev_price
|
||||
|
||||
# Apply action direction
|
||||
if action == 'BUY':
|
||||
return price_change * 100 # Simplified PnL
|
||||
elif action == 'SELL':
|
||||
return -price_change * 100 # Simplified PnL
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating trade PnL: {e}")
|
||||
return 0.0
|
||||
|
||||
def _check_training_opportunities(self):
|
||||
"""Check for additional training opportunities"""
|
||||
try:
|
||||
# Check if we should save model checkpoints
|
||||
if (self.performance_stats['total_trades'] > 0 and
|
||||
self.performance_stats['total_trades'] % self.config['model_checkpoint_interval'] == 0):
|
||||
self._save_model_checkpoints()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking training opportunities: {e}")
|
||||
|
||||
def _save_model_checkpoints(self):
|
||||
"""Save model checkpoints"""
|
||||
try:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# Save CNN model
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
if hasattr(self.orchestrator.cnn_model, 'save'):
|
||||
checkpoint_path = f"models/overnight_cnn_{timestamp}.pth"
|
||||
self.orchestrator.cnn_model.save(checkpoint_path)
|
||||
logger.info(f"CNN checkpoint saved: {checkpoint_path}")
|
||||
|
||||
# Save COB RL model
|
||||
if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
|
||||
if hasattr(self.orchestrator.cob_rl_agent, 'save_model'):
|
||||
checkpoint_path = f"models/overnight_cob_rl_{timestamp}.pth"
|
||||
self.orchestrator.cob_rl_agent.save_model(checkpoint_path)
|
||||
logger.info(f"COB RL checkpoint saved: {checkpoint_path}")
|
||||
|
||||
# Save DQN model
|
||||
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
|
||||
if hasattr(self.orchestrator.rl_agent, 'save'):
|
||||
checkpoint_path = f"models/overnight_dqn_{timestamp}.pth"
|
||||
self.orchestrator.rl_agent.save(checkpoint_path)
|
||||
logger.info(f"DQN checkpoint saved: {checkpoint_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model checkpoints: {e}")
|
||||
|
||||
def _reset_hourly_counters(self):
|
||||
"""Reset hourly trade counters"""
|
||||
try:
|
||||
current_hour = datetime.now().replace(minute=0, second=0, microsecond=0)
|
||||
if current_hour > self.hour_reset_time:
|
||||
self.trades_this_hour = {}
|
||||
self.hour_reset_time = current_hour
|
||||
logger.info("Hourly trade counters reset")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error resetting hourly counters: {e}")
|
||||
|
||||
def _update_performance_stats(self):
|
||||
"""Update performance statistics"""
|
||||
try:
|
||||
# Update hourly stats every hour
|
||||
current_hour = datetime.now().replace(minute=0, second=0, microsecond=0)
|
||||
|
||||
# Check if we need to add a new hourly stat
|
||||
if not self.performance_stats['hourly_stats'] or self.performance_stats['hourly_stats'][-1]['hour'] != current_hour:
|
||||
hourly_stat = {
|
||||
'hour': current_hour,
|
||||
'signals': 0,
|
||||
'trades': 0,
|
||||
'pnl': 0.0,
|
||||
'models_trained': set()
|
||||
}
|
||||
self.performance_stats['hourly_stats'].append(hourly_stat)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating performance stats: {e}")
|
||||
|
||||
def _generate_training_report(self):
|
||||
"""Generate a comprehensive training report"""
|
||||
try:
|
||||
logger.info("=" * 80)
|
||||
logger.info("🌅 OVERNIGHT TRAINING SESSION REPORT")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Overall statistics
|
||||
logger.info(f"📊 OVERALL STATISTICS:")
|
||||
logger.info(f" Total Signals Processed: {self.performance_stats['total_signals']}")
|
||||
logger.info(f" Total Trades Executed: {self.performance_stats['total_trades']}")
|
||||
logger.info(f" Successful Trades: {self.performance_stats['successful_trades']}")
|
||||
logger.info(f" Success Rate: {(self.performance_stats['successful_trades'] / max(1, self.performance_stats['total_trades']) * 100):.1f}%")
|
||||
logger.info(f" Total P&L: ${self.performance_stats['total_pnl']:.2f}")
|
||||
|
||||
# Model training statistics
|
||||
logger.info(f"🧠 MODEL TRAINING:")
|
||||
logger.info(f" Models Trained: {', '.join(self.performance_stats['models_trained'])}")
|
||||
logger.info(f" Training Sessions: {len(self.training_sessions)}")
|
||||
|
||||
# Recent performance
|
||||
if self.signal_trade_records:
|
||||
recent_records = list(self.signal_trade_records)[-20:] # Last 20 records
|
||||
executed_trades = [r for r in recent_records if r.executed]
|
||||
successful_trades = [r for r in executed_trades if r.trade_pnl and r.trade_pnl > 0]
|
||||
|
||||
logger.info(f"📈 RECENT PERFORMANCE (Last 20 signals):")
|
||||
logger.info(f" Signals: {len(recent_records)}")
|
||||
logger.info(f" Executed: {len(executed_trades)}")
|
||||
logger.info(f" Successful: {len(successful_trades)}")
|
||||
if executed_trades:
|
||||
recent_pnl = sum(r.trade_pnl for r in executed_trades if r.trade_pnl)
|
||||
logger.info(f" Recent P&L: ${recent_pnl:.2f}")
|
||||
|
||||
logger.info("=" * 80)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating training report: {e}")
|
||||
|
||||
def get_performance_summary(self) -> Dict[str, Any]:
|
||||
"""Get current performance summary"""
|
||||
try:
|
||||
return {
|
||||
'total_signals': self.performance_stats['total_signals'],
|
||||
'total_trades': self.performance_stats['total_trades'],
|
||||
'successful_trades': self.performance_stats['successful_trades'],
|
||||
'success_rate': (self.performance_stats['successful_trades'] / max(1, self.performance_stats['total_trades'])),
|
||||
'total_pnl': self.performance_stats['total_pnl'],
|
||||
'models_trained': list(self.performance_stats['models_trained']),
|
||||
'is_running': self.is_running,
|
||||
'recent_signals': len(self.signal_trade_records)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting performance summary: {e}")
|
||||
return {}
|
555
core/williams_market_structure.py
Normal file
555
core/williams_market_structure.py
Normal file
@ -0,0 +1,555 @@
|
||||
"""
|
||||
Williams Market Structure Implementation
|
||||
|
||||
This module implements Larry Williams' market structure analysis with recursive pivot points.
|
||||
The system identifies swing highs and swing lows, then uses these pivot points to determine
|
||||
higher-level trends recursively.
|
||||
|
||||
Key Features:
|
||||
- Recursive pivot point calculation (5 levels)
|
||||
- Swing high/low identification
|
||||
- Trend direction and strength analysis
|
||||
- Integration with CNN model for pivot prediction
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass, field
|
||||
from collections import deque
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class PivotPoint:
|
||||
"""Represents a pivot point in the market structure"""
|
||||
timestamp: datetime
|
||||
price: float
|
||||
pivot_type: str # 'high' or 'low'
|
||||
level: int # Pivot level (1-5)
|
||||
index: int # Index in the original data
|
||||
strength: float = 0.0 # Strength of the pivot (0.0 to 1.0)
|
||||
confirmed: bool = False # Whether the pivot is confirmed
|
||||
|
||||
@dataclass
|
||||
class TrendLevel:
|
||||
"""Represents a trend level in the Williams Market Structure"""
|
||||
level: int
|
||||
pivot_points: List[PivotPoint]
|
||||
trend_direction: str # 'up', 'down', 'sideways'
|
||||
trend_strength: float # 0.0 to 1.0
|
||||
last_pivot_high: Optional[PivotPoint] = None
|
||||
last_pivot_low: Optional[PivotPoint] = None
|
||||
|
||||
class WilliamsMarketStructure:
|
||||
"""
|
||||
Implementation of Larry Williams Market Structure Analysis
|
||||
|
||||
This class implements the recursive pivot point calculation system where:
|
||||
1. Level 1: Direct swing highs/lows from 1s OHLCV data
|
||||
2. Level 2-5: Recursive analysis using previous level's pivot points as "candles"
|
||||
"""
|
||||
|
||||
def __init__(self, min_pivot_distance: int = 3):
|
||||
"""
|
||||
Initialize Williams Market Structure analyzer
|
||||
|
||||
Args:
|
||||
min_pivot_distance: Minimum distance between pivot points
|
||||
"""
|
||||
self.min_pivot_distance = min_pivot_distance
|
||||
self.pivot_levels: Dict[int, TrendLevel] = {}
|
||||
self.max_levels = 5
|
||||
|
||||
logger.info(f"Williams Market Structure initialized with {self.max_levels} levels")
|
||||
|
||||
def calculate_recursive_pivot_points(self, ohlcv_data: np.ndarray) -> Dict[int, TrendLevel]:
|
||||
"""
|
||||
Calculate recursive pivot points following Williams Market Structure methodology
|
||||
|
||||
Args:
|
||||
ohlcv_data: OHLCV data array with shape (N, 6) [timestamp, O, H, L, C, V]
|
||||
|
||||
Returns:
|
||||
Dictionary of trend levels with pivot points
|
||||
"""
|
||||
try:
|
||||
if len(ohlcv_data) < self.min_pivot_distance * 2 + 1:
|
||||
logger.warning(f"Insufficient data for pivot calculation: {len(ohlcv_data)} bars")
|
||||
return {}
|
||||
|
||||
# Convert to DataFrame for easier processing
|
||||
df = pd.DataFrame(ohlcv_data, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
|
||||
|
||||
# Initialize pivot levels
|
||||
self.pivot_levels = {}
|
||||
|
||||
# Level 1: Calculate pivot points from raw OHLCV data
|
||||
level_1_pivots = self._calculate_level_1_pivots(df)
|
||||
if level_1_pivots:
|
||||
self.pivot_levels[1] = TrendLevel(
|
||||
level=1,
|
||||
pivot_points=level_1_pivots,
|
||||
trend_direction=self._determine_trend_direction(level_1_pivots),
|
||||
trend_strength=self._calculate_trend_strength(level_1_pivots)
|
||||
)
|
||||
|
||||
# Levels 2-5: Recursive calculation using previous level's pivots
|
||||
for level in range(2, self.max_levels + 1):
|
||||
higher_level_pivots = self._calculate_higher_level_pivots(level)
|
||||
if higher_level_pivots:
|
||||
self.pivot_levels[level] = TrendLevel(
|
||||
level=level,
|
||||
pivot_points=higher_level_pivots,
|
||||
trend_direction=self._determine_trend_direction(higher_level_pivots),
|
||||
trend_strength=self._calculate_trend_strength(higher_level_pivots)
|
||||
)
|
||||
else:
|
||||
break # No more higher level pivots possible
|
||||
|
||||
logger.debug(f"Calculated {len(self.pivot_levels)} pivot levels")
|
||||
return self.pivot_levels
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating recursive pivot points: {e}")
|
||||
return {}
|
||||
|
||||
def _calculate_level_1_pivots(self, df: pd.DataFrame) -> List[PivotPoint]:
|
||||
"""
|
||||
Calculate Level 1 pivot points from raw OHLCV data
|
||||
|
||||
A swing high is a candle with lower highs on both sides
|
||||
A swing low is a candle with higher lows on both sides
|
||||
"""
|
||||
pivots = []
|
||||
|
||||
try:
|
||||
for i in range(self.min_pivot_distance, len(df) - self.min_pivot_distance):
|
||||
current_high = df.iloc[i]['high']
|
||||
current_low = df.iloc[i]['low']
|
||||
current_timestamp = df.iloc[i]['timestamp']
|
||||
|
||||
# Check for swing high
|
||||
is_swing_high = True
|
||||
for j in range(i - self.min_pivot_distance, i + self.min_pivot_distance + 1):
|
||||
if j != i and df.iloc[j]['high'] >= current_high:
|
||||
is_swing_high = False
|
||||
break
|
||||
|
||||
if is_swing_high:
|
||||
pivot = PivotPoint(
|
||||
timestamp=current_timestamp,
|
||||
price=current_high,
|
||||
pivot_type='high',
|
||||
level=1,
|
||||
index=i,
|
||||
strength=self._calculate_pivot_strength(df, i, 'high'),
|
||||
confirmed=True
|
||||
)
|
||||
pivots.append(pivot)
|
||||
continue
|
||||
|
||||
# Check for swing low
|
||||
is_swing_low = True
|
||||
for j in range(i - self.min_pivot_distance, i + self.min_pivot_distance + 1):
|
||||
if j != i and df.iloc[j]['low'] <= current_low:
|
||||
is_swing_low = False
|
||||
break
|
||||
|
||||
if is_swing_low:
|
||||
pivot = PivotPoint(
|
||||
timestamp=current_timestamp,
|
||||
price=current_low,
|
||||
pivot_type='low',
|
||||
level=1,
|
||||
index=i,
|
||||
strength=self._calculate_pivot_strength(df, i, 'low'),
|
||||
confirmed=True
|
||||
)
|
||||
pivots.append(pivot)
|
||||
|
||||
logger.debug(f"Level 1: Found {len(pivots)} pivot points")
|
||||
return pivots
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating Level 1 pivots: {e}")
|
||||
return []
|
||||
|
||||
def _calculate_higher_level_pivots(self, level: int) -> List[PivotPoint]:
|
||||
"""
|
||||
Calculate higher level pivot points using previous level's pivots as "candles"
|
||||
|
||||
This is the recursive part of Williams Market Structure where we treat
|
||||
pivot points from the previous level as if they were OHLCV candles
|
||||
"""
|
||||
if level - 1 not in self.pivot_levels:
|
||||
return []
|
||||
|
||||
previous_level_pivots = self.pivot_levels[level - 1].pivot_points
|
||||
if len(previous_level_pivots) < self.min_pivot_distance * 2 + 1:
|
||||
return []
|
||||
|
||||
pivots = []
|
||||
|
||||
try:
|
||||
# Group pivots by type to find swing points
|
||||
highs = [p for p in previous_level_pivots if p.pivot_type == 'high']
|
||||
lows = [p for p in previous_level_pivots if p.pivot_type == 'low']
|
||||
|
||||
# Find swing highs among the high pivots
|
||||
for i in range(self.min_pivot_distance, len(highs) - self.min_pivot_distance):
|
||||
current_pivot = highs[i]
|
||||
|
||||
# Check if this high is surrounded by lower highs
|
||||
is_swing_high = True
|
||||
for j in range(i - self.min_pivot_distance, i + self.min_pivot_distance + 1):
|
||||
if j != i and j < len(highs) and highs[j].price >= current_pivot.price:
|
||||
is_swing_high = False
|
||||
break
|
||||
|
||||
if is_swing_high:
|
||||
pivot = PivotPoint(
|
||||
timestamp=current_pivot.timestamp,
|
||||
price=current_pivot.price,
|
||||
pivot_type='high',
|
||||
level=level,
|
||||
index=current_pivot.index,
|
||||
strength=current_pivot.strength * 0.8, # Reduce strength at higher levels
|
||||
confirmed=True
|
||||
)
|
||||
pivots.append(pivot)
|
||||
|
||||
# Find swing lows among the low pivots
|
||||
for i in range(self.min_pivot_distance, len(lows) - self.min_pivot_distance):
|
||||
current_pivot = lows[i]
|
||||
|
||||
# Check if this low is surrounded by higher lows
|
||||
is_swing_low = True
|
||||
for j in range(i - self.min_pivot_distance, i + self.min_pivot_distance + 1):
|
||||
if j != i and j < len(lows) and lows[j].price <= current_pivot.price:
|
||||
is_swing_low = False
|
||||
break
|
||||
|
||||
if is_swing_low:
|
||||
pivot = PivotPoint(
|
||||
timestamp=current_pivot.timestamp,
|
||||
price=current_pivot.price,
|
||||
pivot_type='low',
|
||||
level=level,
|
||||
index=current_pivot.index,
|
||||
strength=current_pivot.strength * 0.8, # Reduce strength at higher levels
|
||||
confirmed=True
|
||||
)
|
||||
pivots.append(pivot)
|
||||
|
||||
# Sort pivots by timestamp
|
||||
pivots.sort(key=lambda x: x.timestamp)
|
||||
|
||||
logger.debug(f"Level {level}: Found {len(pivots)} pivot points")
|
||||
return pivots
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating Level {level} pivots: {e}")
|
||||
return []
|
||||
|
||||
def _calculate_pivot_strength(self, df: pd.DataFrame, index: int, pivot_type: str) -> float:
|
||||
"""
|
||||
Calculate the strength of a pivot point based on surrounding price action
|
||||
|
||||
Strength is determined by:
|
||||
- Distance from surrounding highs/lows
|
||||
- Volume at the pivot point
|
||||
- Duration of the pivot formation
|
||||
"""
|
||||
try:
|
||||
if pivot_type == 'high':
|
||||
current_price = df.iloc[index]['high']
|
||||
# Calculate average of surrounding highs
|
||||
surrounding_prices = []
|
||||
for i in range(max(0, index - self.min_pivot_distance),
|
||||
min(len(df), index + self.min_pivot_distance + 1)):
|
||||
if i != index:
|
||||
surrounding_prices.append(df.iloc[i]['high'])
|
||||
|
||||
if surrounding_prices:
|
||||
avg_surrounding = np.mean(surrounding_prices)
|
||||
strength = min(1.0, (current_price - avg_surrounding) / avg_surrounding * 10)
|
||||
else:
|
||||
strength = 0.5
|
||||
else: # pivot_type == 'low'
|
||||
current_price = df.iloc[index]['low']
|
||||
# Calculate average of surrounding lows
|
||||
surrounding_prices = []
|
||||
for i in range(max(0, index - self.min_pivot_distance),
|
||||
min(len(df), index + self.min_pivot_distance + 1)):
|
||||
if i != index:
|
||||
surrounding_prices.append(df.iloc[i]['low'])
|
||||
|
||||
if surrounding_prices:
|
||||
avg_surrounding = np.mean(surrounding_prices)
|
||||
strength = min(1.0, (avg_surrounding - current_price) / avg_surrounding * 10)
|
||||
else:
|
||||
strength = 0.5
|
||||
|
||||
# Factor in volume if available
|
||||
if 'volume' in df.columns and df.iloc[index]['volume'] > 0:
|
||||
avg_volume = df['volume'].rolling(window=20, center=True).mean().iloc[index]
|
||||
if avg_volume > 0:
|
||||
volume_factor = min(2.0, df.iloc[index]['volume'] / avg_volume)
|
||||
strength *= volume_factor
|
||||
|
||||
return max(0.0, min(1.0, strength))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating pivot strength: {e}")
|
||||
return 0.5
|
||||
|
||||
def _determine_trend_direction(self, pivots: List[PivotPoint]) -> str:
|
||||
"""
|
||||
Determine the overall trend direction based on pivot points
|
||||
|
||||
Trend is determined by comparing recent highs and lows:
|
||||
- Uptrend: Higher highs and higher lows
|
||||
- Downtrend: Lower highs and lower lows
|
||||
- Sideways: Mixed or insufficient data
|
||||
"""
|
||||
if len(pivots) < 4:
|
||||
return 'sideways'
|
||||
|
||||
try:
|
||||
# Get recent pivots (last 10 or all if less than 10)
|
||||
recent_pivots = pivots[-10:] if len(pivots) >= 10 else pivots
|
||||
|
||||
highs = [p for p in recent_pivots if p.pivot_type == 'high']
|
||||
lows = [p for p in recent_pivots if p.pivot_type == 'low']
|
||||
|
||||
if len(highs) < 2 or len(lows) < 2:
|
||||
return 'sideways'
|
||||
|
||||
# Sort by timestamp
|
||||
highs.sort(key=lambda x: x.timestamp)
|
||||
lows.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# Check for higher highs and higher lows (uptrend)
|
||||
higher_highs = highs[-1].price > highs[-2].price if len(highs) >= 2 else False
|
||||
higher_lows = lows[-1].price > lows[-2].price if len(lows) >= 2 else False
|
||||
|
||||
# Check for lower highs and lower lows (downtrend)
|
||||
lower_highs = highs[-1].price < highs[-2].price if len(highs) >= 2 else False
|
||||
lower_lows = lows[-1].price < lows[-2].price if len(lows) >= 2 else False
|
||||
|
||||
if higher_highs and higher_lows:
|
||||
return 'up'
|
||||
elif lower_highs and lower_lows:
|
||||
return 'down'
|
||||
else:
|
||||
return 'sideways'
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error determining trend direction: {e}")
|
||||
return 'sideways'
|
||||
|
||||
def _calculate_trend_strength(self, pivots: List[PivotPoint]) -> float:
|
||||
"""
|
||||
Calculate the strength of the current trend
|
||||
|
||||
Strength is based on:
|
||||
- Consistency of pivot point progression
|
||||
- Average strength of individual pivots
|
||||
- Number of confirming pivots
|
||||
"""
|
||||
if not pivots:
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
# Average individual pivot strengths
|
||||
avg_pivot_strength = np.mean([p.strength for p in pivots])
|
||||
|
||||
# Factor in number of pivots (more pivots = stronger trend)
|
||||
pivot_count_factor = min(1.0, len(pivots) / 10.0)
|
||||
|
||||
# Calculate consistency (how well pivots follow the trend)
|
||||
trend_direction = self._determine_trend_direction(pivots)
|
||||
consistency_score = self._calculate_trend_consistency(pivots, trend_direction)
|
||||
|
||||
# Combine factors
|
||||
trend_strength = (avg_pivot_strength * 0.4 +
|
||||
pivot_count_factor * 0.3 +
|
||||
consistency_score * 0.3)
|
||||
|
||||
return max(0.0, min(1.0, trend_strength))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating trend strength: {e}")
|
||||
return 0.0
|
||||
|
||||
def _calculate_trend_consistency(self, pivots: List[PivotPoint], trend_direction: str) -> float:
|
||||
"""
|
||||
Calculate how consistently the pivots follow the expected trend direction
|
||||
"""
|
||||
if len(pivots) < 4 or trend_direction == 'sideways':
|
||||
return 0.5
|
||||
|
||||
try:
|
||||
highs = [p for p in pivots if p.pivot_type == 'high']
|
||||
lows = [p for p in pivots if p.pivot_type == 'low']
|
||||
|
||||
if len(highs) < 2 or len(lows) < 2:
|
||||
return 0.5
|
||||
|
||||
# Sort by timestamp
|
||||
highs.sort(key=lambda x: x.timestamp)
|
||||
lows.sort(key=lambda x: x.timestamp)
|
||||
|
||||
consistent_moves = 0
|
||||
total_moves = 0
|
||||
|
||||
# Check high-to-high moves
|
||||
for i in range(1, len(highs)):
|
||||
total_moves += 1
|
||||
if trend_direction == 'up' and highs[i].price > highs[i-1].price:
|
||||
consistent_moves += 1
|
||||
elif trend_direction == 'down' and highs[i].price < highs[i-1].price:
|
||||
consistent_moves += 1
|
||||
|
||||
# Check low-to-low moves
|
||||
for i in range(1, len(lows)):
|
||||
total_moves += 1
|
||||
if trend_direction == 'up' and lows[i].price > lows[i-1].price:
|
||||
consistent_moves += 1
|
||||
elif trend_direction == 'down' and lows[i].price < lows[i-1].price:
|
||||
consistent_moves += 1
|
||||
|
||||
if total_moves == 0:
|
||||
return 0.5
|
||||
|
||||
return consistent_moves / total_moves
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating trend consistency: {e}")
|
||||
return 0.5
|
||||
|
||||
def get_pivot_features_for_ml(self, symbol: str = "ETH/USDT") -> np.ndarray:
|
||||
"""
|
||||
Extract pivot point features for machine learning models
|
||||
|
||||
Returns a feature vector containing:
|
||||
- Recent pivot points (price, strength, type)
|
||||
- Trend direction and strength for each level
|
||||
- Time since last pivot for each level
|
||||
|
||||
Total features: 250 (50 features per level * 5 levels)
|
||||
"""
|
||||
features = []
|
||||
|
||||
try:
|
||||
for level in range(1, self.max_levels + 1):
|
||||
level_features = []
|
||||
|
||||
if level in self.pivot_levels:
|
||||
trend_level = self.pivot_levels[level]
|
||||
pivots = trend_level.pivot_points
|
||||
|
||||
# Get last 5 pivots for this level
|
||||
recent_pivots = pivots[-5:] if len(pivots) >= 5 else pivots
|
||||
|
||||
# Pad with zeros if we have fewer than 5 pivots
|
||||
while len(recent_pivots) < 5:
|
||||
recent_pivots.insert(0, PivotPoint(
|
||||
timestamp=datetime.now(),
|
||||
price=0.0,
|
||||
pivot_type='high',
|
||||
level=level,
|
||||
index=0,
|
||||
strength=0.0
|
||||
))
|
||||
|
||||
# Extract features for each pivot (8 features per pivot)
|
||||
for pivot in recent_pivots:
|
||||
level_features.extend([
|
||||
pivot.price,
|
||||
pivot.strength,
|
||||
1.0 if pivot.pivot_type == 'high' else 0.0, # Pivot type
|
||||
float(pivot.level),
|
||||
1.0 if pivot.confirmed else 0.0, # Confirmation status
|
||||
float((datetime.now() - pivot.timestamp).total_seconds() / 3600), # Hours since pivot
|
||||
float(pivot.index), # Position in data
|
||||
0.0 # Reserved for future use
|
||||
])
|
||||
|
||||
# Add trend features (10 features)
|
||||
trend_direction_encoded = {
|
||||
'up': [1.0, 0.0, 0.0],
|
||||
'down': [0.0, 1.0, 0.0],
|
||||
'sideways': [0.0, 0.0, 1.0]
|
||||
}.get(trend_level.trend_direction, [0.0, 0.0, 1.0])
|
||||
|
||||
level_features.extend(trend_direction_encoded)
|
||||
level_features.append(trend_level.trend_strength)
|
||||
level_features.extend([0.0] * 6) # Reserved for future use
|
||||
|
||||
else:
|
||||
# No data for this level, fill with zeros
|
||||
level_features = [0.0] * 50
|
||||
|
||||
features.extend(level_features)
|
||||
|
||||
return np.array(features, dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting pivot features for ML: {e}")
|
||||
return np.zeros(250, dtype=np.float32)
|
||||
|
||||
def get_current_market_structure(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current market structure summary for dashboard display
|
||||
"""
|
||||
try:
|
||||
structure = {
|
||||
'levels': {},
|
||||
'overall_trend': 'sideways',
|
||||
'overall_strength': 0.0,
|
||||
'last_update': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Aggregate information from all levels
|
||||
trend_votes = {'up': 0, 'down': 0, 'sideways': 0}
|
||||
total_strength = 0.0
|
||||
active_levels = 0
|
||||
|
||||
for level, trend_level in self.pivot_levels.items():
|
||||
structure['levels'][level] = {
|
||||
'trend_direction': trend_level.trend_direction,
|
||||
'trend_strength': trend_level.trend_strength,
|
||||
'pivot_count': len(trend_level.pivot_points),
|
||||
'last_pivot': {
|
||||
'timestamp': trend_level.pivot_points[-1].timestamp.isoformat() if trend_level.pivot_points else None,
|
||||
'price': trend_level.pivot_points[-1].price if trend_level.pivot_points else 0.0,
|
||||
'type': trend_level.pivot_points[-1].pivot_type if trend_level.pivot_points else 'none'
|
||||
} if trend_level.pivot_points else None
|
||||
}
|
||||
|
||||
# Vote for overall trend
|
||||
trend_votes[trend_level.trend_direction] += trend_level.trend_strength
|
||||
total_strength += trend_level.trend_strength
|
||||
active_levels += 1
|
||||
|
||||
# Determine overall trend
|
||||
if active_levels > 0:
|
||||
structure['overall_trend'] = max(trend_votes, key=trend_votes.get)
|
||||
structure['overall_strength'] = total_strength / active_levels
|
||||
|
||||
return structure
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current market structure: {e}")
|
||||
return {
|
||||
'levels': {},
|
||||
'overall_trend': 'sideways',
|
||||
'overall_strength': 0.0,
|
||||
'last_update': datetime.now().isoformat(),
|
||||
'error': str(e)
|
||||
}
|
179
start_overnight_training.py
Normal file
179
start_overnight_training.py
Normal file
@ -0,0 +1,179 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Start Overnight Training Session
|
||||
|
||||
This script starts a comprehensive overnight training session that:
|
||||
1. Ensures CNN and COB RL training processes are implemented and running
|
||||
2. Executes training passes on each signal when predictions change
|
||||
3. Calculates PnL and records trades in SIM mode
|
||||
4. Tracks model performance statistics
|
||||
5. Converts signals to actual trades for performance tracking
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(f'overnight_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main():
|
||||
"""Start the overnight training session"""
|
||||
try:
|
||||
logger.info("🌙 STARTING OVERNIGHT TRAINING SESSION")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Import required components
|
||||
from core.config import get_config, setup_logging
|
||||
from core.data_provider import DataProvider
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.trading_executor import TradingExecutor
|
||||
from web.clean_dashboard import CleanTradingDashboard
|
||||
|
||||
# Setup logging
|
||||
setup_logging()
|
||||
|
||||
# Initialize components
|
||||
logger.info("Initializing components...")
|
||||
|
||||
# Create data provider
|
||||
data_provider = DataProvider()
|
||||
logger.info("✅ Data Provider initialized")
|
||||
|
||||
# Create trading executor in simulation mode
|
||||
trading_executor = TradingExecutor()
|
||||
trading_executor.simulation_mode = True # Ensure we're in simulation mode
|
||||
logger.info("✅ Trading Executor initialized (SIMULATION MODE)")
|
||||
|
||||
# Create orchestrator with enhanced training
|
||||
orchestrator = TradingOrchestrator(
|
||||
data_provider=data_provider,
|
||||
enhanced_rl_training=True
|
||||
)
|
||||
logger.info("✅ Trading Orchestrator initialized")
|
||||
|
||||
# Connect trading executor to orchestrator
|
||||
if hasattr(orchestrator, 'set_trading_executor'):
|
||||
orchestrator.set_trading_executor(trading_executor)
|
||||
logger.info("✅ Trading Executor connected to Orchestrator")
|
||||
|
||||
# Create dashboard (this initializes the overnight training coordinator)
|
||||
dashboard = CleanTradingDashboard(
|
||||
data_provider=data_provider,
|
||||
orchestrator=orchestrator,
|
||||
trading_executor=trading_executor
|
||||
)
|
||||
logger.info("✅ Dashboard initialized with Overnight Training Coordinator")
|
||||
|
||||
# Start the overnight training session
|
||||
logger.info("Starting overnight training session...")
|
||||
success = dashboard.start_overnight_training()
|
||||
|
||||
if success:
|
||||
logger.info("🌙 OVERNIGHT TRAINING SESSION STARTED SUCCESSFULLY")
|
||||
logger.info("=" * 80)
|
||||
logger.info("Training Features Active:")
|
||||
logger.info("✅ CNN training on signal changes")
|
||||
logger.info("✅ COB RL training on market microstructure")
|
||||
logger.info("✅ DQN training on trading decisions")
|
||||
logger.info("✅ Trade execution and recording (SIMULATION)")
|
||||
logger.info("✅ Performance tracking and statistics")
|
||||
logger.info("✅ Model checkpointing every 50 trades")
|
||||
logger.info("✅ Signal-to-trade conversion with PnL calculation")
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Monitor training progress
|
||||
logger.info("Monitoring training progress...")
|
||||
logger.info("Press Ctrl+C to stop the training session")
|
||||
|
||||
# Keep the session running and periodically report progress
|
||||
start_time = datetime.now()
|
||||
last_report_time = start_time
|
||||
|
||||
while True:
|
||||
try:
|
||||
time.sleep(60) # Check every minute
|
||||
|
||||
current_time = datetime.now()
|
||||
elapsed_time = current_time - start_time
|
||||
|
||||
# Get performance summary every 10 minutes
|
||||
if (current_time - last_report_time).total_seconds() >= 600: # 10 minutes
|
||||
performance = dashboard.get_training_performance_summary()
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"🌙 TRAINING PROGRESS REPORT - {elapsed_time}")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Total Signals: {performance.get('total_signals', 0)}")
|
||||
logger.info(f"Total Trades: {performance.get('total_trades', 0)}")
|
||||
logger.info(f"Successful Trades: {performance.get('successful_trades', 0)}")
|
||||
logger.info(f"Success Rate: {performance.get('success_rate', 0):.1%}")
|
||||
logger.info(f"Total P&L: ${performance.get('total_pnl', 0):.2f}")
|
||||
logger.info(f"Models Trained: {', '.join(performance.get('models_trained', []))}")
|
||||
logger.info(f"Training Status: {'ACTIVE' if performance.get('is_running', False) else 'INACTIVE'}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
last_report_time = current_time
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n🛑 Training session interrupted by user")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error during training monitoring: {e}")
|
||||
time.sleep(30) # Wait 30 seconds before retrying
|
||||
|
||||
# Stop the training session
|
||||
logger.info("Stopping overnight training session...")
|
||||
dashboard.stop_overnight_training()
|
||||
|
||||
# Final report
|
||||
final_performance = dashboard.get_training_performance_summary()
|
||||
total_time = datetime.now() - start_time
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info("🌅 OVERNIGHT TRAINING SESSION COMPLETED")
|
||||
logger.info("=" * 80)
|
||||
logger.info(f"Total Duration: {total_time}")
|
||||
logger.info(f"Final Statistics:")
|
||||
logger.info(f" Total Signals: {final_performance.get('total_signals', 0)}")
|
||||
logger.info(f" Total Trades: {final_performance.get('total_trades', 0)}")
|
||||
logger.info(f" Successful Trades: {final_performance.get('successful_trades', 0)}")
|
||||
logger.info(f" Success Rate: {final_performance.get('success_rate', 0):.1%}")
|
||||
logger.info(f" Total P&L: ${final_performance.get('total_pnl', 0):.2f}")
|
||||
logger.info(f" Models Trained: {', '.join(final_performance.get('models_trained', []))}")
|
||||
logger.info("=" * 80)
|
||||
|
||||
else:
|
||||
logger.error("❌ Failed to start overnight training session")
|
||||
return 1
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n🛑 Training session interrupted by user")
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in overnight training session: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = main()
|
||||
sys.exit(exit_code)
|
@ -80,6 +80,9 @@ except ImportError:
|
||||
# Import RL COB trader for 1B parameter model integration
|
||||
from core.realtime_rl_cob_trader import RealtimeRLCOBTrader, PredictionResult
|
||||
|
||||
# Import overnight training coordinator
|
||||
from core.overnight_training_coordinator import OvernightTrainingCoordinator
|
||||
|
||||
# Single unified orchestrator with full ML capabilities
|
||||
|
||||
class CleanTradingDashboard:
|
||||
@ -220,10 +223,58 @@ class CleanTradingDashboard:
|
||||
if not self.trading_executor.simulation_mode:
|
||||
threading.Thread(target=self._monitor_order_execution, daemon=True).start()
|
||||
|
||||
# Initialize overnight training coordinator
|
||||
self.overnight_training_coordinator = OvernightTrainingCoordinator(
|
||||
orchestrator=self.orchestrator,
|
||||
data_provider=self.data_provider,
|
||||
trading_executor=self.trading_executor,
|
||||
dashboard=self
|
||||
)
|
||||
|
||||
# Start training sessions if models are showing FRESH status
|
||||
threading.Thread(target=self._delayed_training_check, daemon=True).start()
|
||||
|
||||
logger.debug("Clean Trading Dashboard initialized with HIGH-FREQUENCY COB integration and signal generation")
|
||||
logger.info("🌙 Overnight Training Coordinator ready - call start_overnight_training() to begin")
|
||||
|
||||
def start_overnight_training(self):
|
||||
"""Start the overnight training session"""
|
||||
try:
|
||||
if hasattr(self, 'overnight_training_coordinator'):
|
||||
self.overnight_training_coordinator.start_overnight_training()
|
||||
logger.info("🌙 OVERNIGHT TRAINING SESSION STARTED")
|
||||
return True
|
||||
else:
|
||||
logger.error("Overnight training coordinator not available")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting overnight training: {e}")
|
||||
return False
|
||||
|
||||
def stop_overnight_training(self):
|
||||
"""Stop the overnight training session"""
|
||||
try:
|
||||
if hasattr(self, 'overnight_training_coordinator'):
|
||||
self.overnight_training_coordinator.stop_overnight_training()
|
||||
logger.info("🌅 OVERNIGHT TRAINING SESSION STOPPED")
|
||||
return True
|
||||
else:
|
||||
logger.error("Overnight training coordinator not available")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping overnight training: {e}")
|
||||
return False
|
||||
|
||||
def get_training_performance_summary(self) -> Dict[str, Any]:
|
||||
"""Get training performance summary"""
|
||||
try:
|
||||
if hasattr(self, 'overnight_training_coordinator'):
|
||||
return self.overnight_training_coordinator.get_performance_summary()
|
||||
else:
|
||||
return {'error': 'Training coordinator not available'}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting training performance summary: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
def _get_universal_data_from_orchestrator(self) -> Optional[UniversalDataStream]:
|
||||
"""Get universal data through orchestrator as per architecture."""
|
||||
@ -436,9 +487,21 @@ class CleanTradingDashboard:
|
||||
symbol = 'ETH/USDT'
|
||||
self._sync_position_from_executor(symbol)
|
||||
|
||||
# Get current price
|
||||
# Get current price with better error handling
|
||||
current_price = self._get_current_price('ETH/USDT')
|
||||
price_str = f"${current_price:.2f}" if current_price else "Loading..."
|
||||
if current_price and current_price > 0:
|
||||
price_str = f"${current_price:.2f}"
|
||||
else:
|
||||
# Try to get price from COB data as fallback
|
||||
if hasattr(self, 'latest_cob_data') and 'ETH/USDT' in self.latest_cob_data:
|
||||
cob_data = self.latest_cob_data['ETH/USDT']
|
||||
if 'stats' in cob_data and 'mid_price' in cob_data['stats']:
|
||||
current_price = cob_data['stats']['mid_price']
|
||||
price_str = f"${current_price:.2f}"
|
||||
else:
|
||||
price_str = "Loading..."
|
||||
else:
|
||||
price_str = "Loading..."
|
||||
|
||||
# Calculate session P&L including unrealized P&L from current position
|
||||
total_session_pnl = self.session_pnl # Start with realized P&L
|
||||
@ -539,12 +602,29 @@ class CleanTradingDashboard:
|
||||
if self.update_batch_counter % self.update_batch_interval != 0:
|
||||
raise PreventUpdate
|
||||
|
||||
# Filter out HOLD signals before displaying
|
||||
# Filter out HOLD signals and duplicate signals before displaying
|
||||
filtered_decisions = []
|
||||
seen_signals = set() # Track recent signals to avoid duplicates
|
||||
|
||||
for decision in self.recent_decisions:
|
||||
action = self._get_signal_attribute(decision, 'action', 'UNKNOWN')
|
||||
if action != 'HOLD':
|
||||
filtered_decisions.append(decision)
|
||||
# Create a unique key for this signal to avoid duplicates
|
||||
timestamp = decision.get('timestamp', datetime.now())
|
||||
price = decision.get('price', 0)
|
||||
confidence = decision.get('confidence', 0)
|
||||
|
||||
# Only show signals that are significantly different or from different time periods
|
||||
signal_key = f"{action}_{int(price)}_{int(confidence*100)}"
|
||||
time_key = int(timestamp.timestamp() // 30) # Group by 30-second intervals
|
||||
full_key = f"{signal_key}_{time_key}"
|
||||
|
||||
if full_key not in seen_signals:
|
||||
seen_signals.add(full_key)
|
||||
filtered_decisions.append(decision)
|
||||
|
||||
# Limit to last 10 signals to prevent UI clutter
|
||||
filtered_decisions = filtered_decisions[-10:]
|
||||
|
||||
# Log COB signal activity
|
||||
cob_signals = [d for d in filtered_decisions if d.get('type') == 'cob_liquidity_imbalance']
|
||||
@ -621,9 +701,9 @@ class CleanTradingDashboard:
|
||||
eth_snapshot = self._get_cob_snapshot('ETH/USDT')
|
||||
btc_snapshot = self._get_cob_snapshot('BTC/USDT')
|
||||
|
||||
# Debug: Log COB data availability
|
||||
if n % 5 == 0: # Log every 5 seconds to avoid spam
|
||||
logger.info(f"COB Update #{n}: ETH snapshot: {eth_snapshot is not None}, BTC snapshot: {btc_snapshot is not None}")
|
||||
# Debug: Log COB data availability - OPTIMIZED: Less frequent logging
|
||||
if n % 30 == 0: # Log every 30 seconds to reduce spam and improve performance
|
||||
logger.info(f"COB Update #{n % 100}: ETH snapshot: {eth_snapshot is not None}, BTC snapshot: {btc_snapshot is not None}")
|
||||
if hasattr(self, 'latest_cob_data'):
|
||||
eth_data_time = self.cob_last_update.get('ETH/USDT', 0) if hasattr(self, 'cob_last_update') else 0
|
||||
btc_data_time = self.cob_last_update.get('BTC/USDT', 0) if hasattr(self, 'cob_last_update') else 0
|
||||
@ -759,26 +839,98 @@ class CleanTradingDashboard:
|
||||
return [html.I(className="fas fa-save me-1"), "Store All Models"]
|
||||
|
||||
def _get_current_price(self, symbol: str) -> Optional[float]:
|
||||
"""Get current price for symbol"""
|
||||
"""Get current price for symbol - ENHANCED with better fallbacks"""
|
||||
try:
|
||||
# Try WebSocket cache first
|
||||
ws_symbol = symbol.replace('/', '')
|
||||
if ws_symbol in self.ws_price_cache:
|
||||
if ws_symbol in self.ws_price_cache and self.ws_price_cache[ws_symbol] > 0:
|
||||
return self.ws_price_cache[ws_symbol]
|
||||
|
||||
# Fallback to data provider
|
||||
if symbol in self.current_prices:
|
||||
# Try data provider current prices
|
||||
if hasattr(self.data_provider, 'current_prices') and symbol in self.data_provider.current_prices:
|
||||
price = self.data_provider.current_prices[symbol]
|
||||
if price and price > 0:
|
||||
return price
|
||||
|
||||
# Try data provider get_current_price method
|
||||
if hasattr(self.data_provider, 'get_current_price'):
|
||||
try:
|
||||
price = self.data_provider.get_current_price(symbol)
|
||||
if price and price > 0:
|
||||
self.current_prices[symbol] = price
|
||||
return price
|
||||
except Exception as dp_error:
|
||||
logger.debug(f"Data provider get_current_price failed: {dp_error}")
|
||||
|
||||
# Fallback to dashboard current prices
|
||||
if symbol in self.current_prices and self.current_prices[symbol] > 0:
|
||||
return self.current_prices[symbol]
|
||||
|
||||
# Get fresh price from data provider
|
||||
df = self.data_provider.get_historical_data(symbol, '1m', limit=1)
|
||||
if df is not None and not df.empty:
|
||||
price = float(df['close'].iloc[-1])
|
||||
self.current_prices[symbol] = price
|
||||
return price
|
||||
# Get fresh price from data provider - try multiple timeframes
|
||||
for timeframe in ['1m', '5m', '1h']: # Start with 1m instead of 1s for better reliability
|
||||
try:
|
||||
df = self.data_provider.get_historical_data(symbol, timeframe, limit=1, refresh=True)
|
||||
if df is not None and not df.empty:
|
||||
price = float(df['close'].iloc[-1])
|
||||
if price > 0:
|
||||
self.current_prices[symbol] = price
|
||||
logger.debug(f"Got current price for {symbol} from {timeframe}: ${price:.2f}")
|
||||
return price
|
||||
except Exception as tf_error:
|
||||
logger.debug(f"Failed to get {timeframe} data for {symbol}: {tf_error}")
|
||||
continue
|
||||
|
||||
# Last resort: try to get from orchestrator if available
|
||||
if hasattr(self, 'orchestrator') and self.orchestrator:
|
||||
try:
|
||||
# Try to get price from orchestrator's data
|
||||
if hasattr(self.orchestrator, 'data_provider'):
|
||||
price = self.orchestrator.data_provider.get_current_price(symbol)
|
||||
if price and price > 0:
|
||||
self.current_prices[symbol] = price
|
||||
logger.debug(f"Got current price for {symbol} from orchestrator: ${price:.2f}")
|
||||
return price
|
||||
except Exception as orch_error:
|
||||
logger.debug(f"Failed to get price from orchestrator: {orch_error}")
|
||||
|
||||
# Try external API as last resort
|
||||
try:
|
||||
import requests
|
||||
if symbol == 'ETH/USDT':
|
||||
response = requests.get('https://api.binance.com/api/v3/ticker/price?symbol=ETHUSDT', timeout=2)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
price = float(data['price'])
|
||||
if price > 0:
|
||||
self.current_prices[symbol] = price
|
||||
logger.debug(f"Got current price for {symbol} from Binance API: ${price:.2f}")
|
||||
return price
|
||||
elif symbol == 'BTC/USDT':
|
||||
response = requests.get('https://api.binance.com/api/v3/ticker/price?symbol=BTCUSDT', timeout=2)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
price = float(data['price'])
|
||||
if price > 0:
|
||||
self.current_prices[symbol] = price
|
||||
logger.debug(f"Got current price for {symbol} from Binance API: ${price:.2f}")
|
||||
return price
|
||||
except Exception as api_error:
|
||||
logger.debug(f"External API failed: {api_error}")
|
||||
|
||||
logger.warning(f"Could not get current price for {symbol} from any source")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting current price for {symbol}: {e}")
|
||||
logger.error(f"Error getting current price for {symbol}: {e}")
|
||||
|
||||
# Return a fallback price if we have any cached data
|
||||
if symbol in self.current_prices and self.current_prices[symbol] > 0:
|
||||
return self.current_prices[symbol]
|
||||
|
||||
# Return a reasonable fallback based on current market conditions
|
||||
if symbol == 'ETH/USDT':
|
||||
return 3385.0 # Current market price fallback
|
||||
elif symbol == 'BTC/USDT':
|
||||
return 119500.0 # Current market price fallback
|
||||
|
||||
return None
|
||||
|
||||
|
Reference in New Issue
Block a user