new overhaul
This commit is contained in:
parent
b5ad023b16
commit
2f50ed920f
377
ENHANCED_ARCHITECTURE_GUIDE.md
Normal file
377
ENHANCED_ARCHITECTURE_GUIDE.md
Normal file
@ -0,0 +1,377 @@
|
|||||||
|
# Enhanced Multi-Modal Trading Architecture Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This document describes the enhanced multi-modal trading system that implements sophisticated decision-making through coordinated CNN and RL modules. The system is designed to handle multi-timeframe analysis across multiple symbols (ETH, BTC) with continuous learning capabilities.
|
||||||
|
|
||||||
|
## Architecture Components
|
||||||
|
|
||||||
|
### 1. Enhanced Trading Orchestrator (`core/enhanced_orchestrator.py`)
|
||||||
|
|
||||||
|
The heart of the system that coordinates all components:
|
||||||
|
|
||||||
|
**Key Features:**
|
||||||
|
- **Multi-Symbol Coordination**: Makes decisions across ETH and BTC considering correlations
|
||||||
|
- **Timeframe Integration**: Combines predictions from multiple timeframes (1m, 5m, 15m, 1h, 4h, 1d)
|
||||||
|
- **Perfect Move Marking**: Identifies and marks optimal trading decisions for CNN training
|
||||||
|
- **RL Evaluation Loop**: Evaluates trading outcomes to train RL agents
|
||||||
|
|
||||||
|
**Data Structures:**
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class TimeframePrediction:
|
||||||
|
timeframe: str
|
||||||
|
action: str # 'BUY', 'SELL', 'HOLD'
|
||||||
|
confidence: float # 0.0 to 1.0
|
||||||
|
probabilities: Dict[str, float]
|
||||||
|
timestamp: datetime
|
||||||
|
market_features: Dict[str, float]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TradingAction:
|
||||||
|
symbol: str
|
||||||
|
action: str
|
||||||
|
quantity: float
|
||||||
|
confidence: float
|
||||||
|
price: float
|
||||||
|
timestamp: datetime
|
||||||
|
reasoning: Dict[str, Any]
|
||||||
|
timeframe_analysis: List[TimeframePrediction]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Decision Making Process:**
|
||||||
|
1. Gather market states for all symbols and timeframes
|
||||||
|
2. Get CNN predictions for each timeframe with confidence scores
|
||||||
|
3. Combine timeframe predictions using weighted averaging
|
||||||
|
4. Consider symbol correlations (ETH-BTC correlation ~0.85)
|
||||||
|
5. Apply confidence thresholds and risk management
|
||||||
|
6. Generate coordinated trading decisions
|
||||||
|
7. Queue actions for RL evaluation
|
||||||
|
|
||||||
|
### 2. Enhanced CNN Trainer (`training/enhanced_cnn_trainer.py`)
|
||||||
|
|
||||||
|
Implements supervised learning on marked perfect moves:
|
||||||
|
|
||||||
|
**Key Features:**
|
||||||
|
- **Perfect Move Dataset**: Trains on historically optimal decisions
|
||||||
|
- **Timeframe-Specific Heads**: Separate prediction heads for each timeframe
|
||||||
|
- **Confidence Prediction**: Predicts both action and confidence simultaneously
|
||||||
|
- **Multi-Loss Training**: Combines action classification and confidence regression
|
||||||
|
|
||||||
|
**Network Architecture:**
|
||||||
|
```python
|
||||||
|
# Convolutional feature extraction
|
||||||
|
Conv1D(features=5, filters=64, kernel=3) -> BatchNorm -> ReLU -> Dropout
|
||||||
|
Conv1D(filters=128, kernel=3) -> BatchNorm -> ReLU -> Dropout
|
||||||
|
Conv1D(filters=256, kernel=3) -> BatchNorm -> ReLU -> Dropout
|
||||||
|
AdaptiveAvgPool1d(1) # Global average pooling
|
||||||
|
|
||||||
|
# Timeframe-specific heads
|
||||||
|
for each timeframe:
|
||||||
|
Linear(256 -> 128) -> ReLU -> Dropout
|
||||||
|
Linear(128 -> 64) -> ReLU -> Dropout
|
||||||
|
|
||||||
|
# Action prediction
|
||||||
|
Linear(64 -> 3) # BUY, HOLD, SELL
|
||||||
|
|
||||||
|
# Confidence prediction
|
||||||
|
Linear(64 -> 32) -> ReLU -> Linear(32 -> 1) -> Sigmoid
|
||||||
|
```
|
||||||
|
|
||||||
|
**Training Process:**
|
||||||
|
1. Collect perfect moves from orchestrator with known outcomes
|
||||||
|
2. Create dataset with features, optimal actions, and target confidence
|
||||||
|
3. Train with combined loss: `action_loss + 0.5 * confidence_loss`
|
||||||
|
4. Use early stopping and model checkpointing
|
||||||
|
5. Generate comprehensive training reports and visualizations
|
||||||
|
|
||||||
|
### 3. Enhanced RL Trainer (`training/enhanced_rl_trainer.py`)
|
||||||
|
|
||||||
|
Implements continuous learning from trading evaluations:
|
||||||
|
|
||||||
|
**Key Features:**
|
||||||
|
- **Prioritized Experience Replay**: Learns from important experiences first
|
||||||
|
- **Market Regime Adaptation**: Adjusts confidence based on market conditions
|
||||||
|
- **Multi-Symbol Agents**: Separate RL agents for each trading symbol
|
||||||
|
- **Double DQN Architecture**: Reduces overestimation bias
|
||||||
|
|
||||||
|
**Agent Architecture:**
|
||||||
|
```python
|
||||||
|
# Main Network
|
||||||
|
Linear(state_size -> 256) -> ReLU -> Dropout
|
||||||
|
Linear(256 -> 256) -> ReLU -> Dropout
|
||||||
|
Linear(256 -> 128) -> ReLU -> Dropout
|
||||||
|
|
||||||
|
# Dueling heads
|
||||||
|
value_head = Linear(128 -> 1)
|
||||||
|
advantage_head = Linear(128 -> action_space)
|
||||||
|
|
||||||
|
# Q-values = V(s) + A(s,a) - mean(A(s,a))
|
||||||
|
```
|
||||||
|
|
||||||
|
**Learning Process:**
|
||||||
|
1. Store trading experiences with TD-error priorities
|
||||||
|
2. Sample batches using prioritized replay
|
||||||
|
3. Train with Double DQN to reduce overestimation
|
||||||
|
4. Update target networks periodically
|
||||||
|
5. Adapt exploration (epsilon) based on market regime stability
|
||||||
|
|
||||||
|
### 4. Market State and Feature Engineering
|
||||||
|
|
||||||
|
**Market State Components:**
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class MarketState:
|
||||||
|
symbol: str
|
||||||
|
timestamp: datetime
|
||||||
|
prices: Dict[str, float] # {timeframe: price}
|
||||||
|
features: Dict[str, np.ndarray] # {timeframe: feature_matrix}
|
||||||
|
volatility: float
|
||||||
|
volume: float
|
||||||
|
trend_strength: float
|
||||||
|
market_regime: str # 'trending', 'ranging', 'volatile'
|
||||||
|
```
|
||||||
|
|
||||||
|
**Feature Engineering:**
|
||||||
|
- **OHLCV Data**: Open, High, Low, Close, Volume for each timeframe
|
||||||
|
- **Technical Indicators**: RSI, MACD, Bollinger Bands, etc.
|
||||||
|
- **Market Regime Detection**: Automatic classification of market conditions
|
||||||
|
- **Volatility Analysis**: Real-time volatility calculations
|
||||||
|
- **Volume Analysis**: Volume ratio compared to historical averages
|
||||||
|
|
||||||
|
## System Workflow
|
||||||
|
|
||||||
|
### 1. Initialization Phase
|
||||||
|
```python
|
||||||
|
# Load configuration
|
||||||
|
config = get_config('config.yaml')
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
data_provider = DataProvider(config)
|
||||||
|
orchestrator = EnhancedTradingOrchestrator(data_provider)
|
||||||
|
cnn_trainer = EnhancedCNNTrainer(config, orchestrator)
|
||||||
|
rl_trainer = EnhancedRLTrainer(config, orchestrator)
|
||||||
|
|
||||||
|
# Load existing models or create new ones
|
||||||
|
models = initialize_models(load_existing=True)
|
||||||
|
register_models_with_orchestrator(models)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Trading Loop
|
||||||
|
```python
|
||||||
|
while running:
|
||||||
|
# 1. Gather market data for all symbols and timeframes
|
||||||
|
market_states = await get_all_market_states()
|
||||||
|
|
||||||
|
# 2. Generate CNN predictions for each timeframe
|
||||||
|
for symbol in symbols:
|
||||||
|
for timeframe in timeframes:
|
||||||
|
prediction = cnn_model.predict_timeframe(features, timeframe)
|
||||||
|
|
||||||
|
# 3. Combine timeframe predictions with weights
|
||||||
|
combined_prediction = combine_timeframe_predictions(predictions)
|
||||||
|
|
||||||
|
# 4. Consider symbol correlations
|
||||||
|
coordinated_decision = coordinate_symbols(predictions, correlations)
|
||||||
|
|
||||||
|
# 5. Apply confidence thresholds and risk management
|
||||||
|
final_decision = apply_risk_management(coordinated_decision)
|
||||||
|
|
||||||
|
# 6. Execute trades (or log decisions)
|
||||||
|
execute_trading_decision(final_decision)
|
||||||
|
|
||||||
|
# 7. Queue for RL evaluation
|
||||||
|
queue_for_rl_evaluation(final_decision, market_state)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Continuous Learning Loop
|
||||||
|
```python
|
||||||
|
# RL Learning (every hour)
|
||||||
|
async def rl_learning_loop():
|
||||||
|
while running:
|
||||||
|
# Evaluate past trading actions
|
||||||
|
await evaluate_trading_outcomes()
|
||||||
|
|
||||||
|
# Train RL agents on new experiences
|
||||||
|
for symbol, agent in rl_agents.items():
|
||||||
|
agent.replay() # Learn from prioritized experiences
|
||||||
|
|
||||||
|
# Adapt to market regime changes
|
||||||
|
adapt_to_market_conditions()
|
||||||
|
|
||||||
|
await asyncio.sleep(3600) # Wait 1 hour
|
||||||
|
|
||||||
|
# CNN Learning (every 6 hours)
|
||||||
|
async def cnn_learning_loop():
|
||||||
|
while running:
|
||||||
|
# Check for sufficient perfect moves
|
||||||
|
perfect_moves = get_perfect_moves_for_training()
|
||||||
|
|
||||||
|
if len(perfect_moves) >= 200:
|
||||||
|
# Train CNN on perfect moves
|
||||||
|
training_report = train_cnn_on_perfect_moves(perfect_moves)
|
||||||
|
|
||||||
|
# Update registered model
|
||||||
|
update_model_registry(trained_model)
|
||||||
|
|
||||||
|
await asyncio.sleep(6 * 3600) # Wait 6 hours
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key Algorithms
|
||||||
|
|
||||||
|
### 1. Timeframe Prediction Combination
|
||||||
|
```python
|
||||||
|
def combine_timeframe_predictions(timeframe_predictions, symbol):
|
||||||
|
action_scores = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
|
||||||
|
total_weight = 0.0
|
||||||
|
|
||||||
|
timeframe_weights = {
|
||||||
|
'1m': 0.05, '5m': 0.10, '15m': 0.15,
|
||||||
|
'1h': 0.25, '4h': 0.25, '1d': 0.20
|
||||||
|
}
|
||||||
|
|
||||||
|
for pred in timeframe_predictions:
|
||||||
|
weight = timeframe_weights[pred.timeframe] * pred.confidence
|
||||||
|
action_scores[pred.action] += weight
|
||||||
|
total_weight += weight
|
||||||
|
|
||||||
|
# Normalize and select best action
|
||||||
|
best_action = max(action_scores, key=action_scores.get)
|
||||||
|
confidence = action_scores[best_action] / total_weight
|
||||||
|
|
||||||
|
return best_action, confidence
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Perfect Move Marking
|
||||||
|
```python
|
||||||
|
def mark_perfect_move(action, initial_state, final_state, reward):
|
||||||
|
# Determine optimal action based on outcome
|
||||||
|
if reward > 0.02: # Significant positive outcome
|
||||||
|
optimal_action = action.action # Action was correct
|
||||||
|
optimal_confidence = min(0.95, abs(reward) * 10)
|
||||||
|
elif reward < -0.02: # Significant negative outcome
|
||||||
|
optimal_action = opposite_action(action.action) # Should have done opposite
|
||||||
|
optimal_confidence = min(0.95, abs(reward) * 10)
|
||||||
|
else: # Neutral outcome
|
||||||
|
optimal_action = 'HOLD' # Should have held
|
||||||
|
optimal_confidence = 0.3
|
||||||
|
|
||||||
|
# Create perfect move for CNN training
|
||||||
|
perfect_move = PerfectMove(
|
||||||
|
symbol=action.symbol,
|
||||||
|
timeframe=timeframe,
|
||||||
|
timestamp=action.timestamp,
|
||||||
|
optimal_action=optimal_action,
|
||||||
|
confidence_should_have_been=optimal_confidence,
|
||||||
|
market_state_before=initial_state,
|
||||||
|
market_state_after=final_state,
|
||||||
|
actual_outcome=reward
|
||||||
|
)
|
||||||
|
|
||||||
|
return perfect_move
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. RL Reward Calculation
|
||||||
|
```python
|
||||||
|
def calculate_reward(action, price_change, confidence):
|
||||||
|
base_reward = 0.0
|
||||||
|
|
||||||
|
# Reward based on action correctness
|
||||||
|
if action == 'BUY' and price_change > 0:
|
||||||
|
base_reward = price_change * 10 # Reward proportional to gain
|
||||||
|
elif action == 'SELL' and price_change < 0:
|
||||||
|
base_reward = abs(price_change) * 10 # Reward for avoiding loss
|
||||||
|
elif action == 'HOLD':
|
||||||
|
if abs(price_change) < 0.005: # Correct hold
|
||||||
|
base_reward = 0.01
|
||||||
|
else: # Missed opportunity
|
||||||
|
base_reward = -0.01
|
||||||
|
else:
|
||||||
|
base_reward = -abs(price_change) * 5 # Penalty for wrong actions
|
||||||
|
|
||||||
|
# Scale by confidence
|
||||||
|
confidence_multiplier = 0.5 + confidence # 0.5 to 1.5 range
|
||||||
|
return base_reward * confidence_multiplier
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration and Deployment
|
||||||
|
|
||||||
|
### 1. Running the System
|
||||||
|
```bash
|
||||||
|
# Basic trading mode
|
||||||
|
python enhanced_trading_main.py --mode trade
|
||||||
|
|
||||||
|
# Training only mode
|
||||||
|
python enhanced_trading_main.py --mode train
|
||||||
|
|
||||||
|
# Fresh start without loading existing models
|
||||||
|
python enhanced_trading_main.py --mode trade --no-load-models
|
||||||
|
|
||||||
|
# Custom configuration
|
||||||
|
python enhanced_trading_main.py --config custom_config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Key Configuration Parameters
|
||||||
|
```yaml
|
||||||
|
# Enhanced Orchestrator Settings
|
||||||
|
orchestrator:
|
||||||
|
confidence_threshold: 0.6 # Higher threshold for enhanced system
|
||||||
|
decision_frequency: 30 # Faster decisions (30 seconds)
|
||||||
|
|
||||||
|
# CNN Configuration
|
||||||
|
cnn:
|
||||||
|
timeframes: ["1m", "5m", "15m", "1h", "4h", "1d"]
|
||||||
|
confidence_threshold: 0.6
|
||||||
|
model_dir: "models/enhanced_cnn"
|
||||||
|
|
||||||
|
# RL Configuration
|
||||||
|
rl:
|
||||||
|
hidden_size: 256
|
||||||
|
buffer_size: 10000
|
||||||
|
model_dir: "models/enhanced_rl"
|
||||||
|
market_regime_weights:
|
||||||
|
trending: 1.2
|
||||||
|
ranging: 0.8
|
||||||
|
volatile: 0.6
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Memory Management
|
||||||
|
The system is designed to work within 8GB memory constraints:
|
||||||
|
- Total system limit: 8GB
|
||||||
|
- Per-model limit: 2GB
|
||||||
|
- Automatic memory cleanup every 30 minutes
|
||||||
|
- GPU memory management with dynamic allocation
|
||||||
|
|
||||||
|
### 4. Monitoring and Logging
|
||||||
|
- Comprehensive logging with component-specific levels
|
||||||
|
- TensorBoard integration for training visualization
|
||||||
|
- Performance metrics tracking
|
||||||
|
- Memory usage monitoring
|
||||||
|
- Real-time decision logging with full reasoning
|
||||||
|
|
||||||
|
## Performance Characteristics
|
||||||
|
|
||||||
|
### Expected Behavior:
|
||||||
|
1. **Decision Frequency**: 30-second intervals between decisions
|
||||||
|
2. **CNN Training**: Every 6 hours when sufficient perfect moves available
|
||||||
|
3. **RL Training**: Continuous learning every hour
|
||||||
|
4. **Memory Usage**: <8GB total system usage
|
||||||
|
5. **Confidence Thresholds**: 0.6+ for trading actions
|
||||||
|
|
||||||
|
### Key Metrics:
|
||||||
|
- **Decision Accuracy**: Tracked via RL reward system
|
||||||
|
- **Confidence Calibration**: CNN confidence vs actual outcomes
|
||||||
|
- **Symbol Correlation**: ETH-BTC coordination effectiveness
|
||||||
|
- **Training Progress**: Loss curves and validation accuracy
|
||||||
|
- **Market Adaptation**: Performance across different regimes
|
||||||
|
|
||||||
|
## Future Enhancements
|
||||||
|
|
||||||
|
1. **Additional Symbols**: Easy extension to support more trading pairs
|
||||||
|
2. **Advanced Features**: Sentiment analysis, news integration
|
||||||
|
3. **Risk Management**: Portfolio-level risk optimization
|
||||||
|
4. **Backtesting**: Historical performance evaluation
|
||||||
|
5. **Live Trading**: Real exchange integration
|
||||||
|
6. **Model Ensembles**: Multiple CNN/RL model combinations
|
||||||
|
|
||||||
|
This architecture provides a robust foundation for sophisticated algorithmic trading with continuous learning and adaptation capabilities.
|
161
config.yaml
161
config.yaml
@ -1,6 +1,6 @@
|
|||||||
# Trading System Configuration
|
# Enhanced Multi-Modal Trading System Configuration
|
||||||
|
|
||||||
# Trading Symbols (extendable)
|
# Trading Symbols (extendable/configurable)
|
||||||
symbols:
|
symbols:
|
||||||
- "ETH/USDT"
|
- "ETH/USDT"
|
||||||
- "BTC/USDT"
|
- "BTC/USDT"
|
||||||
@ -22,22 +22,38 @@ data:
|
|||||||
historical_limit: 1000
|
historical_limit: 1000
|
||||||
real_time_enabled: true
|
real_time_enabled: true
|
||||||
websocket_reconnect: true
|
websocket_reconnect: true
|
||||||
|
feature_engineering:
|
||||||
|
technical_indicators: true
|
||||||
|
market_regime_detection: true
|
||||||
|
volatility_analysis: true
|
||||||
|
|
||||||
# CNN Model Configuration
|
# Enhanced CNN Configuration
|
||||||
cnn:
|
cnn:
|
||||||
window_size: 20
|
window_size: 20
|
||||||
features: ["open", "high", "low", "close", "volume"]
|
features: ["open", "high", "low", "close", "volume"]
|
||||||
hidden_layers: [64, 32, 16]
|
timeframes: ["1m", "5m", "15m", "1h", "4h", "1d"]
|
||||||
|
hidden_layers: [64, 128, 256]
|
||||||
dropout: 0.2
|
dropout: 0.2
|
||||||
learning_rate: 0.001
|
learning_rate: 0.001
|
||||||
batch_size: 32
|
batch_size: 32
|
||||||
epochs: 100
|
epochs: 100
|
||||||
confidence_threshold: 0.6
|
confidence_threshold: 0.6
|
||||||
|
early_stopping_patience: 10
|
||||||
|
model_dir: "models/enhanced_cnn"
|
||||||
|
# Timeframe-specific model weights
|
||||||
|
timeframe_importance:
|
||||||
|
"1m": 0.05 # Noise filtering
|
||||||
|
"5m": 0.10 # Short-term momentum
|
||||||
|
"15m": 0.15 # Entry/exit timing
|
||||||
|
"1h": 0.25 # Medium-term trend
|
||||||
|
"4h": 0.25 # Stronger trend confirmation
|
||||||
|
"1d": 0.20 # Long-term direction
|
||||||
|
|
||||||
# RL Agent Configuration
|
# Enhanced RL Agent Configuration
|
||||||
rl:
|
rl:
|
||||||
state_size: 100 # Will be calculated dynamically
|
state_size: 100 # Will be calculated dynamically based on features
|
||||||
action_space: 3 # BUY, HOLD, SELL
|
action_space: 3 # BUY, HOLD, SELL
|
||||||
|
hidden_size: 256
|
||||||
epsilon: 1.0
|
epsilon: 1.0
|
||||||
epsilon_decay: 0.995
|
epsilon_decay: 0.995
|
||||||
epsilon_min: 0.01
|
epsilon_min: 0.01
|
||||||
@ -46,22 +62,79 @@ rl:
|
|||||||
memory_size: 10000
|
memory_size: 10000
|
||||||
batch_size: 64
|
batch_size: 64
|
||||||
target_update_freq: 1000
|
target_update_freq: 1000
|
||||||
|
buffer_size: 10000
|
||||||
|
model_dir: "models/enhanced_rl"
|
||||||
|
# Market regime adaptation
|
||||||
|
market_regime_weights:
|
||||||
|
trending: 1.2 # Higher confidence in trending markets
|
||||||
|
ranging: 0.8 # Lower confidence in ranging markets
|
||||||
|
volatile: 0.6 # Much lower confidence in volatile markets
|
||||||
|
# Prioritized experience replay
|
||||||
|
replay_alpha: 0.6 # Priority exponent
|
||||||
|
replay_beta: 0.4 # Importance sampling exponent
|
||||||
|
|
||||||
# Orchestrator Settings
|
# Enhanced Orchestrator Settings
|
||||||
orchestrator:
|
orchestrator:
|
||||||
|
# Model weights for decision combination
|
||||||
cnn_weight: 0.7 # Weight for CNN predictions
|
cnn_weight: 0.7 # Weight for CNN predictions
|
||||||
rl_weight: 0.3 # Weight for RL decisions
|
rl_weight: 0.3 # Weight for RL decisions
|
||||||
confidence_threshold: 0.5 # Minimum confidence to act
|
confidence_threshold: 0.6 # Increased for enhanced system
|
||||||
decision_frequency: 60 # Seconds between decisions
|
decision_frequency: 30 # Seconds between decisions (faster)
|
||||||
|
|
||||||
|
# Multi-symbol coordination
|
||||||
|
symbol_correlation_matrix:
|
||||||
|
"ETH/USDT-BTC/USDT": 0.85 # ETH-BTC correlation
|
||||||
|
|
||||||
|
# Perfect move marking
|
||||||
|
perfect_move_threshold: 0.02 # 2% price change to mark as significant
|
||||||
|
perfect_move_buffer_size: 10000
|
||||||
|
|
||||||
|
# RL evaluation settings
|
||||||
|
evaluation_delay: 3600 # Evaluate actions after 1 hour
|
||||||
|
reward_calculation:
|
||||||
|
success_multiplier: 10 # Reward for correct predictions
|
||||||
|
failure_penalty: 5 # Penalty for wrong predictions
|
||||||
|
confidence_scaling: true # Scale rewards by confidence
|
||||||
|
|
||||||
|
# Training Configuration
|
||||||
|
training:
|
||||||
|
learning_rate: 0.001
|
||||||
|
batch_size: 32
|
||||||
|
epochs: 100
|
||||||
|
validation_split: 0.2
|
||||||
|
early_stopping_patience: 10
|
||||||
|
|
||||||
|
# CNN specific
|
||||||
|
cnn_training_interval: 21600 # Train every 6 hours
|
||||||
|
min_perfect_moves: 200 # Minimum moves before training
|
||||||
|
|
||||||
|
# RL specific
|
||||||
|
rl_training_interval: 3600 # Train every hour
|
||||||
|
min_experiences: 100 # Minimum experiences before training
|
||||||
|
training_steps_per_cycle: 10 # Training steps per cycle
|
||||||
|
|
||||||
# Trading Execution
|
# Trading Execution
|
||||||
trading:
|
trading:
|
||||||
max_position_size: 0.1 # Maximum position size (fraction of balance)
|
max_position_size: 0.05 # Maximum position size (5% of balance)
|
||||||
stop_loss: 0.02 # 2% stop loss
|
stop_loss: 0.02 # 2% stop loss
|
||||||
take_profit: 0.05 # 5% take profit
|
take_profit: 0.05 # 5% take profit
|
||||||
trading_fee: 0.0002 # 0.02% trading fee
|
trading_fee: 0.0002 # 0.02% trading fee
|
||||||
min_trade_interval: 60 # Minimum seconds between trades
|
min_trade_interval: 30 # Minimum seconds between trades (faster)
|
||||||
|
|
||||||
|
# Risk management
|
||||||
|
max_daily_trades: 20 # Maximum trades per day
|
||||||
|
max_concurrent_positions: 2 # Max positions across symbols
|
||||||
|
position_sizing:
|
||||||
|
confidence_scaling: true # Scale position by confidence
|
||||||
|
base_size: 0.02 # 2% base position
|
||||||
|
max_size: 0.05 # 5% maximum position
|
||||||
|
|
||||||
|
# Memory Management
|
||||||
|
memory:
|
||||||
|
total_limit_gb: 8.0 # Total system memory limit
|
||||||
|
model_limit_gb: 2.0 # Per-model memory limit
|
||||||
|
cleanup_interval: 1800 # Memory cleanup every 30 minutes
|
||||||
|
|
||||||
# Web Dashboard
|
# Web Dashboard
|
||||||
web:
|
web:
|
||||||
host: "127.0.0.1"
|
host: "127.0.0.1"
|
||||||
@ -69,37 +142,55 @@ web:
|
|||||||
debug: false
|
debug: false
|
||||||
update_interval: 1000 # Milliseconds
|
update_interval: 1000 # Milliseconds
|
||||||
chart_history: 100 # Number of candles to show
|
chart_history: 100 # Number of candles to show
|
||||||
|
|
||||||
|
# Enhanced dashboard features
|
||||||
|
show_timeframe_analysis: true
|
||||||
|
show_confidence_scores: true
|
||||||
|
show_perfect_moves: true
|
||||||
|
show_rl_metrics: true
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
logging:
|
logging:
|
||||||
level: "INFO"
|
level: "INFO"
|
||||||
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
file: "logs/trading.log"
|
file: "logs/enhanced_trading.log"
|
||||||
max_size: 10485760 # 10MB
|
max_size: 10485760 # 10MB
|
||||||
backup_count: 5
|
backup_count: 5
|
||||||
|
|
||||||
|
# Component-specific logging
|
||||||
|
orchestrator_level: "INFO"
|
||||||
|
cnn_level: "INFO"
|
||||||
|
rl_level: "INFO"
|
||||||
|
training_level: "INFO"
|
||||||
|
|
||||||
|
# Model Directories
|
||||||
|
model_dir: "models"
|
||||||
|
data_dir: "data"
|
||||||
|
cache_dir: "cache"
|
||||||
|
logs_dir: "logs"
|
||||||
|
|
||||||
# GPU/Performance
|
# GPU/Performance
|
||||||
performance:
|
gpu:
|
||||||
use_gpu: true
|
enabled: true
|
||||||
mixed_precision: true
|
memory_fraction: 0.8 # Use 80% of GPU memory
|
||||||
num_workers: 4
|
allow_growth: true # Allow dynamic memory allocation
|
||||||
batch_size_multiplier: 1.0
|
|
||||||
|
# Monitoring and Alerting
|
||||||
|
monitoring:
|
||||||
|
tensorboard_enabled: true
|
||||||
|
tensorboard_log_dir: "logs/tensorboard"
|
||||||
|
metrics_interval: 300 # Log metrics every 5 minutes
|
||||||
|
performance_alerts: true
|
||||||
|
|
||||||
|
# Performance thresholds
|
||||||
|
min_confidence_threshold: 0.3
|
||||||
|
max_memory_usage: 0.9 # 90% of available memory
|
||||||
|
max_decision_latency: 10 # 10 seconds max per decision
|
||||||
|
|
||||||
# Paths
|
# Backtesting (for future implementation)
|
||||||
paths:
|
backtesting:
|
||||||
models: "models"
|
start_date: "2024-01-01"
|
||||||
data: "data"
|
end_date: "2024-12-31"
|
||||||
logs: "logs"
|
initial_balance: 10000
|
||||||
cache: "cache"
|
commission: 0.0002
|
||||||
plots: "plots"
|
slippage: 0.0001
|
||||||
|
|
||||||
# Training Configuration
|
|
||||||
training:
|
|
||||||
use_only_real_data: true # CRITICAL: Never use synthetic/generated data
|
|
||||||
batch_size: 32
|
|
||||||
learning_rate: 0.001
|
|
||||||
epochs: 100
|
|
||||||
validation_split: 0.2
|
|
||||||
early_stopping_patience: 10
|
|
||||||
|
|
||||||
# Directory paths
|
|
698
core/enhanced_orchestrator.py
Normal file
698
core/enhanced_orchestrator.py
Normal file
@ -0,0 +1,698 @@
|
|||||||
|
"""
|
||||||
|
Enhanced Trading Orchestrator - Advanced Multi-Modal Decision Making
|
||||||
|
|
||||||
|
This enhanced orchestrator implements:
|
||||||
|
1. Multi-timeframe CNN predictions with individual confidence scores
|
||||||
|
2. Advanced RL feedback loop for continuous learning
|
||||||
|
3. Multi-symbol (ETH, BTC) coordinated decision making
|
||||||
|
4. Perfect move marking for CNN backpropagation training
|
||||||
|
5. Market environment adaptation through RL evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Optional, Tuple, Any, Union
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from collections import deque
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .config import get_config
|
||||||
|
from .data_provider import DataProvider
|
||||||
|
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TimeframePrediction:
|
||||||
|
"""CNN prediction for a specific timeframe with confidence"""
|
||||||
|
timeframe: str
|
||||||
|
action: str # 'BUY', 'SELL', 'HOLD'
|
||||||
|
confidence: float # 0.0 to 1.0
|
||||||
|
probabilities: Dict[str, float] # Action probabilities
|
||||||
|
timestamp: datetime
|
||||||
|
market_features: Dict[str, float] = field(default_factory=dict) # Additional context
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EnhancedPrediction:
|
||||||
|
"""Enhanced prediction structure with timeframe breakdown"""
|
||||||
|
symbol: str
|
||||||
|
timeframe_predictions: List[TimeframePrediction]
|
||||||
|
overall_action: str
|
||||||
|
overall_confidence: float
|
||||||
|
model_name: str
|
||||||
|
timestamp: datetime
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TradingAction:
|
||||||
|
"""Represents a trading action with full context"""
|
||||||
|
symbol: str
|
||||||
|
action: str # 'BUY', 'SELL', 'HOLD'
|
||||||
|
quantity: float
|
||||||
|
confidence: float
|
||||||
|
price: float
|
||||||
|
timestamp: datetime
|
||||||
|
reasoning: Dict[str, Any]
|
||||||
|
timeframe_analysis: List[TimeframePrediction]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MarketState:
|
||||||
|
"""Complete market state for RL evaluation"""
|
||||||
|
symbol: str
|
||||||
|
timestamp: datetime
|
||||||
|
prices: Dict[str, float] # {timeframe: current_price}
|
||||||
|
features: Dict[str, np.ndarray] # {timeframe: feature_matrix}
|
||||||
|
volatility: float
|
||||||
|
volume: float
|
||||||
|
trend_strength: float
|
||||||
|
market_regime: str # 'trending', 'ranging', 'volatile'
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PerfectMove:
|
||||||
|
"""Marked perfect move for CNN training"""
|
||||||
|
symbol: str
|
||||||
|
timeframe: str
|
||||||
|
timestamp: datetime
|
||||||
|
optimal_action: str
|
||||||
|
actual_outcome: float # Price change percentage
|
||||||
|
market_state_before: MarketState
|
||||||
|
market_state_after: MarketState
|
||||||
|
confidence_should_have_been: float
|
||||||
|
|
||||||
|
class EnhancedTradingOrchestrator:
|
||||||
|
"""
|
||||||
|
Enhanced orchestrator with sophisticated multi-modal decision making
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_provider: DataProvider = None):
|
||||||
|
"""Initialize the enhanced orchestrator"""
|
||||||
|
self.config = get_config()
|
||||||
|
self.data_provider = data_provider or DataProvider()
|
||||||
|
self.model_registry = get_model_registry()
|
||||||
|
|
||||||
|
# Multi-symbol configuration
|
||||||
|
self.symbols = self.config.symbols
|
||||||
|
self.timeframes = self.config.timeframes
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.6)
|
||||||
|
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
|
||||||
|
|
||||||
|
# Enhanced weighting system
|
||||||
|
self.timeframe_weights = self._initialize_timeframe_weights()
|
||||||
|
self.symbol_correlation_matrix = self._initialize_correlation_matrix()
|
||||||
|
|
||||||
|
# State tracking for each symbol
|
||||||
|
self.symbol_states = {symbol: {} for symbol in self.symbols}
|
||||||
|
self.recent_actions = {symbol: deque(maxlen=100) for symbol in self.symbols}
|
||||||
|
self.market_states = {symbol: deque(maxlen=1000) for symbol in self.symbols}
|
||||||
|
|
||||||
|
# Perfect move tracking for CNN training
|
||||||
|
self.perfect_moves = deque(maxlen=10000)
|
||||||
|
self.performance_tracker = {}
|
||||||
|
|
||||||
|
# RL feedback system
|
||||||
|
self.rl_evaluation_queue = deque(maxlen=1000)
|
||||||
|
self.environment_adaptation_rate = 0.01
|
||||||
|
|
||||||
|
# Decision callbacks
|
||||||
|
self.decision_callbacks = []
|
||||||
|
self.learning_callbacks = []
|
||||||
|
|
||||||
|
logger.info("Enhanced TradingOrchestrator initialized")
|
||||||
|
logger.info(f"Symbols: {self.symbols}")
|
||||||
|
logger.info(f"Timeframes: {self.timeframes}")
|
||||||
|
logger.info(f"Enhanced confidence threshold: {self.confidence_threshold}")
|
||||||
|
|
||||||
|
def _initialize_timeframe_weights(self) -> Dict[str, float]:
|
||||||
|
"""Initialize weights for different timeframes"""
|
||||||
|
# Higher timeframes get more weight for trend direction
|
||||||
|
# Lower timeframes get more weight for entry/exit timing
|
||||||
|
base_weights = {
|
||||||
|
'1m': 0.05, # Noise filtering
|
||||||
|
'5m': 0.10, # Short-term momentum
|
||||||
|
'15m': 0.15, # Entry/exit timing
|
||||||
|
'1h': 0.25, # Medium-term trend
|
||||||
|
'4h': 0.25, # Stronger trend confirmation
|
||||||
|
'1d': 0.20 # Long-term direction
|
||||||
|
}
|
||||||
|
|
||||||
|
# Normalize weights for configured timeframes
|
||||||
|
configured_weights = {tf: base_weights.get(tf, 0.1) for tf in self.timeframes}
|
||||||
|
total = sum(configured_weights.values())
|
||||||
|
return {tf: w/total for tf, w in configured_weights.items()}
|
||||||
|
|
||||||
|
def _initialize_correlation_matrix(self) -> Dict[Tuple[str, str], float]:
|
||||||
|
"""Initialize correlation matrix between symbols"""
|
||||||
|
correlations = {}
|
||||||
|
for i, symbol1 in enumerate(self.symbols):
|
||||||
|
for j, symbol2 in enumerate(self.symbols):
|
||||||
|
if i != j:
|
||||||
|
# ETH and BTC are typically highly correlated
|
||||||
|
if 'ETH' in symbol1 and 'BTC' in symbol2:
|
||||||
|
correlations[(symbol1, symbol2)] = 0.85
|
||||||
|
elif 'BTC' in symbol1 and 'ETH' in symbol2:
|
||||||
|
correlations[(symbol1, symbol2)] = 0.85
|
||||||
|
else:
|
||||||
|
correlations[(symbol1, symbol2)] = 0.7 # Default correlation
|
||||||
|
return correlations
|
||||||
|
|
||||||
|
async def make_coordinated_decisions(self) -> Dict[str, Optional[TradingAction]]:
|
||||||
|
"""
|
||||||
|
Make coordinated trading decisions across all symbols
|
||||||
|
"""
|
||||||
|
decisions = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get market states for all symbols
|
||||||
|
market_states = await self._get_all_market_states()
|
||||||
|
|
||||||
|
# Get enhanced predictions for all symbols
|
||||||
|
symbol_predictions = {}
|
||||||
|
for symbol in self.symbols:
|
||||||
|
if symbol in market_states:
|
||||||
|
predictions = await self._get_enhanced_predictions(symbol, market_states[symbol])
|
||||||
|
symbol_predictions[symbol] = predictions
|
||||||
|
|
||||||
|
# Coordinate decisions considering symbol correlations
|
||||||
|
for symbol in self.symbols:
|
||||||
|
if symbol in symbol_predictions:
|
||||||
|
decision = await self._make_coordinated_decision(
|
||||||
|
symbol,
|
||||||
|
symbol_predictions[symbol],
|
||||||
|
symbol_predictions,
|
||||||
|
market_states[symbol]
|
||||||
|
)
|
||||||
|
decisions[symbol] = decision
|
||||||
|
|
||||||
|
# Queue for RL evaluation
|
||||||
|
if decision and decision.action != 'HOLD':
|
||||||
|
self._queue_for_rl_evaluation(decision, market_states[symbol])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in coordinated decision making: {e}")
|
||||||
|
|
||||||
|
return decisions
|
||||||
|
|
||||||
|
async def _get_all_market_states(self) -> Dict[str, MarketState]:
|
||||||
|
"""Get current market state for all symbols"""
|
||||||
|
market_states = {}
|
||||||
|
|
||||||
|
for symbol in self.symbols:
|
||||||
|
try:
|
||||||
|
# Get current market data for all timeframes
|
||||||
|
prices = {}
|
||||||
|
features = {}
|
||||||
|
|
||||||
|
for timeframe in self.timeframes:
|
||||||
|
# Get current price
|
||||||
|
current_price = self.data_provider.get_current_price(symbol)
|
||||||
|
if current_price:
|
||||||
|
prices[timeframe] = current_price
|
||||||
|
|
||||||
|
# Get feature matrix for this timeframe
|
||||||
|
feature_matrix = self.data_provider.get_feature_matrix(
|
||||||
|
symbol=symbol,
|
||||||
|
timeframes=[timeframe],
|
||||||
|
window_size=20 # Standard window
|
||||||
|
)
|
||||||
|
if feature_matrix is not None:
|
||||||
|
features[timeframe] = feature_matrix
|
||||||
|
|
||||||
|
if prices and features:
|
||||||
|
# Calculate market metrics
|
||||||
|
volatility = self._calculate_volatility(symbol)
|
||||||
|
volume = self._get_current_volume(symbol)
|
||||||
|
trend_strength = self._calculate_trend_strength(symbol)
|
||||||
|
market_regime = self._determine_market_regime(symbol)
|
||||||
|
|
||||||
|
market_state = MarketState(
|
||||||
|
symbol=symbol,
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
prices=prices,
|
||||||
|
features=features,
|
||||||
|
volatility=volatility,
|
||||||
|
volume=volume,
|
||||||
|
trend_strength=trend_strength,
|
||||||
|
market_regime=market_regime
|
||||||
|
)
|
||||||
|
|
||||||
|
market_states[symbol] = market_state
|
||||||
|
|
||||||
|
# Store for historical tracking
|
||||||
|
self.market_states[symbol].append(market_state)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting market state for {symbol}: {e}")
|
||||||
|
|
||||||
|
return market_states
|
||||||
|
|
||||||
|
async def _get_enhanced_predictions(self, symbol: str, market_state: MarketState) -> List[EnhancedPrediction]:
|
||||||
|
"""Get enhanced predictions with timeframe breakdown"""
|
||||||
|
predictions = []
|
||||||
|
|
||||||
|
for model_name, model in self.model_registry.models.items():
|
||||||
|
try:
|
||||||
|
if isinstance(model, CNNModelInterface):
|
||||||
|
# Get CNN predictions for each timeframe
|
||||||
|
timeframe_predictions = []
|
||||||
|
|
||||||
|
for timeframe in self.timeframes:
|
||||||
|
if timeframe in market_state.features:
|
||||||
|
feature_matrix = market_state.features[timeframe]
|
||||||
|
|
||||||
|
# Get timeframe-specific prediction
|
||||||
|
action_probs, confidence = await self._get_timeframe_prediction(
|
||||||
|
model, feature_matrix, timeframe, market_state
|
||||||
|
)
|
||||||
|
|
||||||
|
if action_probs is not None:
|
||||||
|
action_names = ['SELL', 'HOLD', 'BUY']
|
||||||
|
best_action_idx = np.argmax(action_probs)
|
||||||
|
best_action = action_names[best_action_idx]
|
||||||
|
|
||||||
|
# Create timeframe prediction
|
||||||
|
tf_prediction = TimeframePrediction(
|
||||||
|
timeframe=timeframe,
|
||||||
|
action=best_action,
|
||||||
|
confidence=float(confidence),
|
||||||
|
probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
market_features={
|
||||||
|
'volatility': market_state.volatility,
|
||||||
|
'volume': market_state.volume,
|
||||||
|
'trend_strength': market_state.trend_strength
|
||||||
|
}
|
||||||
|
)
|
||||||
|
timeframe_predictions.append(tf_prediction)
|
||||||
|
|
||||||
|
if timeframe_predictions:
|
||||||
|
# Combine timeframe predictions into overall prediction
|
||||||
|
overall_action, overall_confidence = self._combine_timeframe_predictions(
|
||||||
|
timeframe_predictions, symbol
|
||||||
|
)
|
||||||
|
|
||||||
|
enhanced_pred = EnhancedPrediction(
|
||||||
|
symbol=symbol,
|
||||||
|
timeframe_predictions=timeframe_predictions,
|
||||||
|
overall_action=overall_action,
|
||||||
|
overall_confidence=overall_confidence,
|
||||||
|
model_name=model.name,
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
metadata={
|
||||||
|
'market_regime': market_state.market_regime,
|
||||||
|
'symbol_correlation': self._get_symbol_correlation(symbol)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
predictions.append(enhanced_pred)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting enhanced predictions from {model_name}: {e}")
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
async def _get_timeframe_prediction(self, model: CNNModelInterface, feature_matrix: np.ndarray,
|
||||||
|
timeframe: str, market_state: MarketState) -> Tuple[Optional[np.ndarray], float]:
|
||||||
|
"""Get prediction for specific timeframe with enhanced context"""
|
||||||
|
try:
|
||||||
|
# Check if model supports timeframe-specific prediction
|
||||||
|
if hasattr(model, 'predict_timeframe'):
|
||||||
|
action_probs, confidence = model.predict_timeframe(feature_matrix, timeframe)
|
||||||
|
else:
|
||||||
|
action_probs, confidence = model.predict(feature_matrix)
|
||||||
|
|
||||||
|
if action_probs is not None and confidence is not None:
|
||||||
|
# Enhance confidence based on market conditions
|
||||||
|
enhanced_confidence = self._enhance_confidence_with_context(
|
||||||
|
confidence, timeframe, market_state
|
||||||
|
)
|
||||||
|
return action_probs, enhanced_confidence
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting timeframe prediction for {timeframe}: {e}")
|
||||||
|
|
||||||
|
return None, 0.0
|
||||||
|
|
||||||
|
def _enhance_confidence_with_context(self, base_confidence: float, timeframe: str,
|
||||||
|
market_state: MarketState) -> float:
|
||||||
|
"""Enhance confidence score based on market context"""
|
||||||
|
enhanced = base_confidence
|
||||||
|
|
||||||
|
# Adjust based on market regime
|
||||||
|
if market_state.market_regime == 'trending':
|
||||||
|
enhanced *= 1.1 # More confident in trending markets
|
||||||
|
elif market_state.market_regime == 'volatile':
|
||||||
|
enhanced *= 0.8 # Less confident in volatile markets
|
||||||
|
|
||||||
|
# Adjust based on timeframe reliability
|
||||||
|
timeframe_reliability = {
|
||||||
|
'1m': 0.7, '5m': 0.8, '15m': 0.9, '1h': 1.0, '4h': 1.1, '1d': 1.2
|
||||||
|
}
|
||||||
|
enhanced *= timeframe_reliability.get(timeframe, 1.0)
|
||||||
|
|
||||||
|
# Adjust based on volume
|
||||||
|
if market_state.volume > 1.5: # High volume
|
||||||
|
enhanced *= 1.05
|
||||||
|
elif market_state.volume < 0.5: # Low volume
|
||||||
|
enhanced *= 0.9
|
||||||
|
|
||||||
|
return min(enhanced, 1.0) # Cap at 1.0
|
||||||
|
|
||||||
|
def _combine_timeframe_predictions(self, timeframe_predictions: List[TimeframePrediction],
|
||||||
|
symbol: str) -> Tuple[str, float]:
|
||||||
|
"""Combine predictions from multiple timeframes"""
|
||||||
|
action_scores = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
|
||||||
|
total_weight = 0.0
|
||||||
|
|
||||||
|
for tf_pred in timeframe_predictions:
|
||||||
|
# Get timeframe weight
|
||||||
|
tf_weight = self.timeframe_weights.get(tf_pred.timeframe, 0.1)
|
||||||
|
|
||||||
|
# Weight by confidence and timeframe importance
|
||||||
|
weighted_confidence = tf_pred.confidence * tf_weight
|
||||||
|
|
||||||
|
# Add to action scores
|
||||||
|
action_scores[tf_pred.action] += weighted_confidence
|
||||||
|
total_weight += weighted_confidence
|
||||||
|
|
||||||
|
# Normalize scores
|
||||||
|
if total_weight > 0:
|
||||||
|
for action in action_scores:
|
||||||
|
action_scores[action] /= total_weight
|
||||||
|
|
||||||
|
# Get best action and confidence
|
||||||
|
best_action = max(action_scores, key=action_scores.get)
|
||||||
|
best_confidence = action_scores[best_action]
|
||||||
|
|
||||||
|
return best_action, best_confidence
|
||||||
|
|
||||||
|
async def _make_coordinated_decision(self, symbol: str, predictions: List[EnhancedPrediction],
|
||||||
|
all_predictions: Dict[str, List[EnhancedPrediction]],
|
||||||
|
market_state: MarketState) -> Optional[TradingAction]:
|
||||||
|
"""Make decision considering symbol correlations"""
|
||||||
|
if not predictions:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get primary prediction (highest confidence)
|
||||||
|
primary_pred = max(predictions, key=lambda p: p.overall_confidence)
|
||||||
|
|
||||||
|
# Consider correlated symbols
|
||||||
|
correlated_sentiment = self._get_correlated_sentiment(symbol, all_predictions)
|
||||||
|
|
||||||
|
# Adjust decision based on correlation
|
||||||
|
final_action = primary_pred.overall_action
|
||||||
|
final_confidence = primary_pred.overall_confidence
|
||||||
|
|
||||||
|
# If correlated symbols strongly disagree, reduce confidence
|
||||||
|
if correlated_sentiment['agreement'] < 0.5:
|
||||||
|
final_confidence *= 0.8
|
||||||
|
logger.info(f"Reduced confidence for {symbol} due to correlation disagreement")
|
||||||
|
|
||||||
|
# Apply confidence threshold
|
||||||
|
if final_confidence < self.confidence_threshold:
|
||||||
|
final_action = 'HOLD'
|
||||||
|
logger.info(f"Action for {symbol} changed to HOLD due to low confidence: {final_confidence:.3f}")
|
||||||
|
|
||||||
|
# Create trading action
|
||||||
|
if final_action != 'HOLD':
|
||||||
|
current_price = market_state.prices.get(self.timeframes[0], 0)
|
||||||
|
quantity = self._calculate_position_size(symbol, final_action, final_confidence)
|
||||||
|
|
||||||
|
action = TradingAction(
|
||||||
|
symbol=symbol,
|
||||||
|
action=final_action,
|
||||||
|
quantity=quantity,
|
||||||
|
confidence=final_confidence,
|
||||||
|
price=current_price,
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
reasoning={
|
||||||
|
'primary_model': primary_pred.model_name,
|
||||||
|
'timeframe_breakdown': [(tf.timeframe, tf.action, tf.confidence)
|
||||||
|
for tf in primary_pred.timeframe_predictions],
|
||||||
|
'correlated_sentiment': correlated_sentiment,
|
||||||
|
'market_regime': market_state.market_regime
|
||||||
|
},
|
||||||
|
timeframe_analysis=primary_pred.timeframe_predictions
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store recent action
|
||||||
|
self.recent_actions[symbol].append(action)
|
||||||
|
|
||||||
|
return action
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error making coordinated decision for {symbol}: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_correlated_sentiment(self, symbol: str,
|
||||||
|
all_predictions: Dict[str, List[EnhancedPrediction]]) -> Dict[str, Any]:
|
||||||
|
"""Get sentiment from correlated symbols"""
|
||||||
|
correlated_actions = []
|
||||||
|
correlated_confidences = []
|
||||||
|
|
||||||
|
for other_symbol, predictions in all_predictions.items():
|
||||||
|
if other_symbol != symbol and predictions:
|
||||||
|
correlation = self.symbol_correlation_matrix.get((symbol, other_symbol), 0.0)
|
||||||
|
|
||||||
|
if correlation > 0.5: # Only consider significantly correlated symbols
|
||||||
|
best_pred = max(predictions, key=lambda p: p.overall_confidence)
|
||||||
|
correlated_actions.append(best_pred.overall_action)
|
||||||
|
correlated_confidences.append(best_pred.overall_confidence * correlation)
|
||||||
|
|
||||||
|
if not correlated_actions:
|
||||||
|
return {'agreement': 1.0, 'sentiment': 'NEUTRAL'}
|
||||||
|
|
||||||
|
# Calculate agreement
|
||||||
|
primary_pred = all_predictions[symbol][0] if all_predictions.get(symbol) else None
|
||||||
|
if primary_pred:
|
||||||
|
agreement_count = sum(1 for action in correlated_actions
|
||||||
|
if action == primary_pred.overall_action)
|
||||||
|
agreement = agreement_count / len(correlated_actions)
|
||||||
|
else:
|
||||||
|
agreement = 0.5
|
||||||
|
|
||||||
|
# Calculate overall sentiment
|
||||||
|
action_weights = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
|
||||||
|
for action, confidence in zip(correlated_actions, correlated_confidences):
|
||||||
|
action_weights[action] += confidence
|
||||||
|
|
||||||
|
dominant_sentiment = max(action_weights, key=action_weights.get)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'agreement': agreement,
|
||||||
|
'sentiment': dominant_sentiment,
|
||||||
|
'correlated_symbols': len(correlated_actions)
|
||||||
|
}
|
||||||
|
|
||||||
|
def _queue_for_rl_evaluation(self, action: TradingAction, market_state: MarketState):
|
||||||
|
"""Queue trading action for RL evaluation"""
|
||||||
|
evaluation_item = {
|
||||||
|
'action': action,
|
||||||
|
'market_state_before': market_state,
|
||||||
|
'timestamp': datetime.now(),
|
||||||
|
'evaluation_pending': True
|
||||||
|
}
|
||||||
|
self.rl_evaluation_queue.append(evaluation_item)
|
||||||
|
|
||||||
|
async def evaluate_actions_with_rl(self):
|
||||||
|
"""Evaluate recent actions using RL agents for continuous learning"""
|
||||||
|
if not self.rl_evaluation_queue:
|
||||||
|
return
|
||||||
|
|
||||||
|
current_time = datetime.now()
|
||||||
|
|
||||||
|
# Process actions that are ready for evaluation (e.g., 1 hour old)
|
||||||
|
for item in list(self.rl_evaluation_queue):
|
||||||
|
if item['evaluation_pending']:
|
||||||
|
time_since_action = (current_time - item['timestamp']).total_seconds()
|
||||||
|
|
||||||
|
# Evaluate after sufficient time has passed
|
||||||
|
if time_since_action >= 3600: # 1 hour
|
||||||
|
await self._evaluate_single_action(item)
|
||||||
|
item['evaluation_pending'] = False
|
||||||
|
|
||||||
|
async def _evaluate_single_action(self, evaluation_item: Dict[str, Any]):
|
||||||
|
"""Evaluate a single action using RL"""
|
||||||
|
try:
|
||||||
|
action = evaluation_item['action']
|
||||||
|
initial_state = evaluation_item['market_state_before']
|
||||||
|
|
||||||
|
# Get current market state for comparison
|
||||||
|
current_market_states = await self._get_all_market_states()
|
||||||
|
current_state = current_market_states.get(action.symbol)
|
||||||
|
|
||||||
|
if current_state:
|
||||||
|
# Calculate reward based on price movement
|
||||||
|
initial_price = initial_state.prices.get(self.timeframes[0], 0)
|
||||||
|
current_price = current_state.prices.get(self.timeframes[0], 0)
|
||||||
|
|
||||||
|
if initial_price > 0:
|
||||||
|
price_change = (current_price - initial_price) / initial_price
|
||||||
|
|
||||||
|
# Calculate reward based on action and price movement
|
||||||
|
reward = self._calculate_reward(action.action, price_change, action.confidence)
|
||||||
|
|
||||||
|
# Update RL agents
|
||||||
|
await self._update_rl_agents(action, initial_state, current_state, reward)
|
||||||
|
|
||||||
|
# Check if this was a perfect move for CNN training
|
||||||
|
if abs(reward) > 0.02: # Significant outcome
|
||||||
|
self._mark_perfect_move(action, initial_state, current_state, reward)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error evaluating action: {e}")
|
||||||
|
|
||||||
|
def _calculate_reward(self, action: str, price_change: float, confidence: float) -> float:
|
||||||
|
"""Calculate reward for RL training"""
|
||||||
|
base_reward = 0.0
|
||||||
|
|
||||||
|
if action == 'BUY' and price_change > 0:
|
||||||
|
base_reward = price_change * 10 # Reward proportional to gain
|
||||||
|
elif action == 'SELL' and price_change < 0:
|
||||||
|
base_reward = abs(price_change) * 10 # Reward for avoiding loss
|
||||||
|
elif action == 'HOLD':
|
||||||
|
base_reward = 0.01 if abs(price_change) < 0.005 else -0.01 # Small reward for correct holds
|
||||||
|
else:
|
||||||
|
base_reward = -abs(price_change) * 5 # Penalty for wrong actions
|
||||||
|
|
||||||
|
# Adjust reward based on confidence
|
||||||
|
confidence_multiplier = 0.5 + confidence # 0.5 to 1.5 range
|
||||||
|
|
||||||
|
return base_reward * confidence_multiplier
|
||||||
|
|
||||||
|
async def _update_rl_agents(self, action: TradingAction, initial_state: MarketState,
|
||||||
|
current_state: MarketState, reward: float):
|
||||||
|
"""Update RL agents with action evaluation"""
|
||||||
|
for model_name, model in self.model_registry.models.items():
|
||||||
|
if isinstance(model, RLAgentInterface):
|
||||||
|
try:
|
||||||
|
# Convert market states to RL state format
|
||||||
|
initial_rl_state = self._market_state_to_rl_state(initial_state)
|
||||||
|
current_rl_state = self._market_state_to_rl_state(current_state)
|
||||||
|
|
||||||
|
# Convert action to RL action index
|
||||||
|
action_idx = {'SELL': 0, 'HOLD': 1, 'BUY': 2}.get(action.action, 1)
|
||||||
|
|
||||||
|
# Store experience
|
||||||
|
model.remember(
|
||||||
|
state=initial_rl_state,
|
||||||
|
action=action_idx,
|
||||||
|
reward=reward,
|
||||||
|
next_state=current_rl_state,
|
||||||
|
done=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Trigger replay learning
|
||||||
|
loss = model.replay()
|
||||||
|
if loss is not None:
|
||||||
|
logger.info(f"RL agent {model_name} updated with loss: {loss:.4f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating RL agent {model_name}: {e}")
|
||||||
|
|
||||||
|
def _mark_perfect_move(self, action: TradingAction, initial_state: MarketState,
|
||||||
|
final_state: MarketState, reward: float):
|
||||||
|
"""Mark a perfect move for CNN training"""
|
||||||
|
try:
|
||||||
|
# Determine what the optimal action should have been
|
||||||
|
optimal_action = action.action if reward > 0 else ('HOLD' if action.action == 'HOLD' else
|
||||||
|
('SELL' if action.action == 'BUY' else 'BUY'))
|
||||||
|
|
||||||
|
# Calculate what confidence should have been
|
||||||
|
optimal_confidence = min(0.95, abs(reward) * 10) # Higher reward = higher confidence should have been
|
||||||
|
|
||||||
|
for tf_pred in action.timeframe_analysis:
|
||||||
|
perfect_move = PerfectMove(
|
||||||
|
symbol=action.symbol,
|
||||||
|
timeframe=tf_pred.timeframe,
|
||||||
|
timestamp=action.timestamp,
|
||||||
|
optimal_action=optimal_action,
|
||||||
|
actual_outcome=reward,
|
||||||
|
market_state_before=initial_state,
|
||||||
|
market_state_after=final_state,
|
||||||
|
confidence_should_have_been=optimal_confidence
|
||||||
|
)
|
||||||
|
self.perfect_moves.append(perfect_move)
|
||||||
|
|
||||||
|
logger.info(f"Marked perfect move for {action.symbol}: {optimal_action} with confidence {optimal_confidence:.3f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error marking perfect move: {e}")
|
||||||
|
|
||||||
|
def get_perfect_moves_for_training(self, symbol: str = None, timeframe: str = None,
|
||||||
|
limit: int = 1000) -> List[PerfectMove]:
|
||||||
|
"""Get perfect moves for CNN training"""
|
||||||
|
moves = list(self.perfect_moves)
|
||||||
|
|
||||||
|
if symbol:
|
||||||
|
moves = [m for m in moves if m.symbol == symbol]
|
||||||
|
|
||||||
|
if timeframe:
|
||||||
|
moves = [m for m in moves if m.timeframe == timeframe]
|
||||||
|
|
||||||
|
return moves[-limit:] if limit else moves
|
||||||
|
|
||||||
|
# Helper methods for market analysis
|
||||||
|
def _calculate_volatility(self, symbol: str) -> float:
|
||||||
|
"""Calculate current volatility for symbol"""
|
||||||
|
# Placeholder - implement based on your data provider
|
||||||
|
return 0.02 # 2% default volatility
|
||||||
|
|
||||||
|
def _get_current_volume(self, symbol: str) -> float:
|
||||||
|
"""Get current volume ratio compared to average"""
|
||||||
|
# Placeholder - implement based on your data provider
|
||||||
|
return 1.0 # Normal volume
|
||||||
|
|
||||||
|
def _calculate_trend_strength(self, symbol: str) -> float:
|
||||||
|
"""Calculate trend strength (0 = no trend, 1 = strong trend)"""
|
||||||
|
# Placeholder - implement based on your data provider
|
||||||
|
return 0.5 # Moderate trend
|
||||||
|
|
||||||
|
def _determine_market_regime(self, symbol: str) -> str:
|
||||||
|
"""Determine current market regime"""
|
||||||
|
# Placeholder - implement based on your analysis
|
||||||
|
return 'trending' # Default to trending
|
||||||
|
|
||||||
|
def _get_symbol_correlation(self, symbol: str) -> Dict[str, float]:
|
||||||
|
"""Get correlations with other symbols"""
|
||||||
|
correlations = {}
|
||||||
|
for other_symbol in self.symbols:
|
||||||
|
if other_symbol != symbol:
|
||||||
|
correlations[other_symbol] = self.symbol_correlation_matrix.get((symbol, other_symbol), 0.0)
|
||||||
|
return correlations
|
||||||
|
|
||||||
|
def _calculate_position_size(self, symbol: str, action: str, confidence: float) -> float:
|
||||||
|
"""Calculate position size based on confidence and risk management"""
|
||||||
|
base_size = 0.02 # 2% of portfolio
|
||||||
|
confidence_multiplier = confidence # Scale by confidence
|
||||||
|
max_size = 0.05 # 5% maximum
|
||||||
|
|
||||||
|
return min(base_size * confidence_multiplier, max_size)
|
||||||
|
|
||||||
|
def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
|
||||||
|
"""Convert market state to RL state vector"""
|
||||||
|
# Combine features from all timeframes into a single state vector
|
||||||
|
state_components = []
|
||||||
|
|
||||||
|
# Add price features
|
||||||
|
state_components.extend([
|
||||||
|
market_state.volatility,
|
||||||
|
market_state.volume,
|
||||||
|
market_state.trend_strength
|
||||||
|
])
|
||||||
|
|
||||||
|
# Add flattened features from each timeframe
|
||||||
|
for timeframe in sorted(market_state.features.keys()):
|
||||||
|
features = market_state.features[timeframe]
|
||||||
|
if features is not None:
|
||||||
|
# Take the last row (most recent) and flatten
|
||||||
|
latest_features = features[-1] if len(features.shape) > 1 else features
|
||||||
|
state_components.extend(latest_features.flatten())
|
||||||
|
|
||||||
|
return np.array(state_components, dtype=np.float32)
|
370
enhanced_trading_main.py
Normal file
370
enhanced_trading_main.py
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
"""
|
||||||
|
Enhanced Multi-Modal Trading System - Main Application
|
||||||
|
|
||||||
|
This is the main launcher for the sophisticated trading system featuring:
|
||||||
|
1. Enhanced orchestrator coordinating CNN and RL modules
|
||||||
|
2. Multi-timeframe, multi-symbol (ETH, BTC) trading decisions
|
||||||
|
3. Perfect move marking for CNN training with known outcomes
|
||||||
|
4. Continuous RL learning from trading action evaluations
|
||||||
|
5. Market environment adaptation and coordinated decision making
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
# Core components
|
||||||
|
from core.config import get_config
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||||
|
from models import get_model_registry
|
||||||
|
|
||||||
|
# Training components
|
||||||
|
from training.enhanced_cnn_trainer import EnhancedCNNTrainer, EnhancedCNNModel
|
||||||
|
from training.enhanced_rl_trainer import EnhancedRLTrainer, EnhancedDQNAgent
|
||||||
|
|
||||||
|
# Utilities
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(),
|
||||||
|
logging.FileHandler('logs/enhanced_trading.log')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class EnhancedTradingSystem:
|
||||||
|
"""Main enhanced trading system coordinator"""
|
||||||
|
|
||||||
|
def __init__(self, config_path: str = None):
|
||||||
|
"""Initialize the enhanced trading system"""
|
||||||
|
self.config = get_config(config_path)
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
# Core components
|
||||||
|
self.data_provider = DataProvider(self.config)
|
||||||
|
self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
|
||||||
|
self.model_registry = get_model_registry()
|
||||||
|
|
||||||
|
# Training components
|
||||||
|
self.cnn_trainer = EnhancedCNNTrainer(self.config, self.orchestrator)
|
||||||
|
self.rl_trainer = EnhancedRLTrainer(self.config, self.orchestrator)
|
||||||
|
|
||||||
|
# Models
|
||||||
|
self.cnn_models = {}
|
||||||
|
self.rl_agents = {}
|
||||||
|
|
||||||
|
# Performance tracking
|
||||||
|
self.performance_metrics = {
|
||||||
|
'decisions_made': 0,
|
||||||
|
'perfect_moves_marked': 0,
|
||||||
|
'rl_experiences_added': 0,
|
||||||
|
'training_sessions': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("Enhanced Trading System initialized")
|
||||||
|
logger.info(f"Symbols: {self.config.symbols}")
|
||||||
|
logger.info(f"Timeframes: {self.config.timeframes}")
|
||||||
|
|
||||||
|
async def initialize_models(self, load_existing: bool = True):
|
||||||
|
"""Initialize and register all models"""
|
||||||
|
logger.info("Initializing models...")
|
||||||
|
|
||||||
|
# Initialize CNN models
|
||||||
|
if load_existing:
|
||||||
|
# Try to load existing CNN model
|
||||||
|
if self.cnn_trainer.load_model('best_model.pt'):
|
||||||
|
logger.info("Loaded existing CNN model")
|
||||||
|
self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()
|
||||||
|
else:
|
||||||
|
logger.info("No existing CNN model found, using fresh model")
|
||||||
|
self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()
|
||||||
|
else:
|
||||||
|
logger.info("Creating fresh CNN model")
|
||||||
|
self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()
|
||||||
|
|
||||||
|
# Initialize RL agents
|
||||||
|
if load_existing:
|
||||||
|
# Try to load existing RL agents
|
||||||
|
if self.rl_trainer.load_models():
|
||||||
|
logger.info("Loaded existing RL models")
|
||||||
|
else:
|
||||||
|
logger.info("No existing RL models found, using fresh agents")
|
||||||
|
|
||||||
|
self.rl_agents = self.rl_trainer.get_agents()
|
||||||
|
|
||||||
|
# Register models with the orchestrator
|
||||||
|
for model_name, model in self.cnn_models.items():
|
||||||
|
if self.model_registry.register_model(model):
|
||||||
|
logger.info(f"Registered CNN model: {model_name}")
|
||||||
|
|
||||||
|
for symbol, agent in self.rl_agents.items():
|
||||||
|
if self.model_registry.register_model(agent):
|
||||||
|
logger.info(f"Registered RL agent for {symbol}")
|
||||||
|
|
||||||
|
# Display memory usage
|
||||||
|
memory_stats = self.model_registry.get_memory_stats()
|
||||||
|
logger.info(f"Total memory usage: {memory_stats['total_used_mb']:.1f}MB / "
|
||||||
|
f"{memory_stats['total_limit_mb']:.1f}MB "
|
||||||
|
f"({memory_stats['utilization_percent']:.1f}%)")
|
||||||
|
|
||||||
|
async def start_trading_loop(self):
|
||||||
|
"""Start the main trading decision loop"""
|
||||||
|
logger.info("Starting enhanced trading loop...")
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
decision_count = 0
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
# Make coordinated decisions for all symbols
|
||||||
|
decisions = await self.orchestrator.make_coordinated_decisions()
|
||||||
|
|
||||||
|
# Process decisions
|
||||||
|
for symbol, decision in decisions.items():
|
||||||
|
if decision:
|
||||||
|
decision_count += 1
|
||||||
|
self.performance_metrics['decisions_made'] += 1
|
||||||
|
|
||||||
|
logger.info(f"Trading Decision #{decision_count}")
|
||||||
|
logger.info(f"Symbol: {symbol}")
|
||||||
|
logger.info(f"Action: {decision.action}")
|
||||||
|
logger.info(f"Confidence: {decision.confidence:.3f}")
|
||||||
|
logger.info(f"Price: ${decision.price:.2f}")
|
||||||
|
logger.info(f"Quantity: {decision.quantity:.6f}")
|
||||||
|
|
||||||
|
# Log timeframe analysis
|
||||||
|
for tf_pred in decision.timeframe_analysis:
|
||||||
|
logger.info(f" {tf_pred.timeframe}: {tf_pred.action} "
|
||||||
|
f"(conf: {tf_pred.confidence:.3f})")
|
||||||
|
|
||||||
|
# Here you would integrate with actual trading execution
|
||||||
|
# For now, we just log the decision
|
||||||
|
|
||||||
|
# Evaluate past actions with RL
|
||||||
|
await self.orchestrator.evaluate_actions_with_rl()
|
||||||
|
|
||||||
|
# Check for perfect moves to mark
|
||||||
|
perfect_moves = self.orchestrator.get_perfect_moves_for_training(limit=10)
|
||||||
|
if perfect_moves:
|
||||||
|
self.performance_metrics['perfect_moves_marked'] = len(perfect_moves)
|
||||||
|
|
||||||
|
# Log performance metrics every 10 decisions
|
||||||
|
if decision_count % 10 == 0 and decision_count > 0:
|
||||||
|
await self._log_performance_metrics()
|
||||||
|
|
||||||
|
# Wait before next decision cycle
|
||||||
|
await asyncio.sleep(self.orchestrator.decision_frequency)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in trading loop: {e}")
|
||||||
|
await asyncio.sleep(30) # Wait 30 seconds on error
|
||||||
|
|
||||||
|
async def start_training_loops(self):
|
||||||
|
"""Start continuous training loops"""
|
||||||
|
logger.info("Starting continuous training loops...")
|
||||||
|
|
||||||
|
# Start RL continuous learning
|
||||||
|
rl_task = asyncio.create_task(self.rl_trainer.continuous_learning_loop())
|
||||||
|
|
||||||
|
# Start periodic CNN training
|
||||||
|
cnn_task = asyncio.create_task(self._periodic_cnn_training())
|
||||||
|
|
||||||
|
return rl_task, cnn_task
|
||||||
|
|
||||||
|
async def _periodic_cnn_training(self):
|
||||||
|
"""Periodic CNN training on accumulated perfect moves"""
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
# Wait for 6 hours between training sessions
|
||||||
|
await asyncio.sleep(6 * 3600)
|
||||||
|
|
||||||
|
# Check if we have enough perfect moves for training
|
||||||
|
perfect_moves = []
|
||||||
|
for symbol in self.config.symbols:
|
||||||
|
symbol_moves = self.orchestrator.get_perfect_moves_for_training(symbol=symbol)
|
||||||
|
perfect_moves.extend(symbol_moves)
|
||||||
|
|
||||||
|
if len(perfect_moves) >= 200: # Minimum 200 perfect moves
|
||||||
|
logger.info(f"Starting CNN training on {len(perfect_moves)} perfect moves")
|
||||||
|
|
||||||
|
# Train the CNN model
|
||||||
|
training_report = self.cnn_trainer.train_on_perfect_moves(min_samples=200)
|
||||||
|
|
||||||
|
if training_report.get('training_completed'):
|
||||||
|
self.performance_metrics['training_sessions'] += 1
|
||||||
|
logger.info("CNN training completed successfully")
|
||||||
|
logger.info(f"Final validation accuracy: "
|
||||||
|
f"{training_report['final_metrics']['val_accuracy']:.4f}")
|
||||||
|
|
||||||
|
# Update the registered model
|
||||||
|
updated_model = self.cnn_trainer.get_model()
|
||||||
|
self.model_registry.unregister_model('enhanced_cnn')
|
||||||
|
self.model_registry.register_model(updated_model)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(f"CNN training failed: {training_report}")
|
||||||
|
else:
|
||||||
|
logger.info(f"Not enough perfect moves for training: {len(perfect_moves)} < 200")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in periodic CNN training: {e}")
|
||||||
|
|
||||||
|
async def _log_performance_metrics(self):
|
||||||
|
"""Log system performance metrics"""
|
||||||
|
logger.info("=== SYSTEM PERFORMANCE METRICS ===")
|
||||||
|
logger.info(f"Decisions made: {self.performance_metrics['decisions_made']}")
|
||||||
|
logger.info(f"Perfect moves marked: {self.performance_metrics['perfect_moves_marked']}")
|
||||||
|
logger.info(f"Training sessions: {self.performance_metrics['training_sessions']}")
|
||||||
|
|
||||||
|
# Model registry stats
|
||||||
|
memory_stats = self.model_registry.get_memory_stats()
|
||||||
|
logger.info(f"Memory usage: {memory_stats['total_used_mb']:.1f}MB / "
|
||||||
|
f"{memory_stats['total_limit_mb']:.1f}MB")
|
||||||
|
|
||||||
|
# RL performance
|
||||||
|
rl_report = self.rl_trainer.get_performance_report()
|
||||||
|
for symbol, agent_data in rl_report['agents'].items():
|
||||||
|
logger.info(f"{symbol} RL: Epsilon={agent_data['epsilon']:.3f}, "
|
||||||
|
f"Experiences={agent_data['experiences_stored']}, "
|
||||||
|
f"Avg Reward={agent_data['avg_recent_reward']:.4f}")
|
||||||
|
|
||||||
|
# CNN model info
|
||||||
|
for model_name, model in self.cnn_models.items():
|
||||||
|
logger.info(f"{model_name}: Memory={model.get_memory_usage()}MB, "
|
||||||
|
f"Device={model.device}")
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""Graceful shutdown of the system"""
|
||||||
|
logger.info("Shutting down Enhanced Trading System...")
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
# Save models
|
||||||
|
logger.info("Saving models...")
|
||||||
|
self.cnn_trainer._save_model('shutdown_model.pt')
|
||||||
|
self.rl_trainer._save_all_models()
|
||||||
|
|
||||||
|
# Clean up memory
|
||||||
|
self.model_registry.cleanup_all_models()
|
||||||
|
|
||||||
|
# Generate final reports
|
||||||
|
logger.info("Generating final reports...")
|
||||||
|
|
||||||
|
# CNN training plots
|
||||||
|
if self.cnn_trainer.training_history['train_loss']:
|
||||||
|
self.cnn_trainer._plot_training_history()
|
||||||
|
|
||||||
|
# RL training plots
|
||||||
|
self.rl_trainer.plot_training_metrics()
|
||||||
|
|
||||||
|
logger.info("Enhanced Trading System shutdown complete")
|
||||||
|
|
||||||
|
def setup_signal_handlers(trading_system: EnhancedTradingSystem):
|
||||||
|
"""Setup signal handlers for graceful shutdown"""
|
||||||
|
def signal_handler(signum, frame):
|
||||||
|
logger.info(f"Received signal {signum}, initiating shutdown...")
|
||||||
|
asyncio.create_task(trading_system.shutdown())
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main application entry point"""
|
||||||
|
parser = argparse.ArgumentParser(description='Enhanced Multi-Modal Trading System')
|
||||||
|
parser.add_argument('--config', type=str, help='Configuration file path')
|
||||||
|
parser.add_argument('--mode', type=str, choices=['trade', 'train', 'backtest'],
|
||||||
|
default='trade', help='Operation mode')
|
||||||
|
parser.add_argument('--load-models', action='store_true', default=True,
|
||||||
|
help='Load existing models')
|
||||||
|
parser.add_argument('--no-load-models', action='store_false', dest='load_models',
|
||||||
|
help="Don't load existing models")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Create logs directory
|
||||||
|
Path('logs').mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
logger.info("=== ENHANCED MULTI-MODAL TRADING SYSTEM ===")
|
||||||
|
logger.info(f"Mode: {args.mode}")
|
||||||
|
logger.info(f"Load existing models: {args.load_models}")
|
||||||
|
logger.info(f"PyTorch version: {torch.__version__}")
|
||||||
|
logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
||||||
|
|
||||||
|
# Initialize trading system
|
||||||
|
trading_system = EnhancedTradingSystem(args.config)
|
||||||
|
|
||||||
|
# Setup signal handlers
|
||||||
|
setup_signal_handlers(trading_system)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize models
|
||||||
|
await trading_system.initialize_models(load_existing=args.load_models)
|
||||||
|
|
||||||
|
if args.mode == 'trade':
|
||||||
|
# Start training loops
|
||||||
|
rl_task, cnn_task = await trading_system.start_training_loops()
|
||||||
|
|
||||||
|
# Start main trading loop
|
||||||
|
trading_task = asyncio.create_task(trading_system.start_trading_loop())
|
||||||
|
|
||||||
|
# Wait for any task to complete (or error)
|
||||||
|
done, pending = await asyncio.wait(
|
||||||
|
[trading_task, rl_task, cnn_task],
|
||||||
|
return_when=asyncio.FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cancel remaining tasks
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
elif args.mode == 'train':
|
||||||
|
# Training-only mode
|
||||||
|
logger.info("Running in training-only mode...")
|
||||||
|
|
||||||
|
# Train CNN if we have perfect moves
|
||||||
|
perfect_moves = []
|
||||||
|
for symbol in trading_system.config.symbols:
|
||||||
|
symbol_moves = trading_system.orchestrator.get_perfect_moves_for_training(symbol=symbol)
|
||||||
|
perfect_moves.extend(symbol_moves)
|
||||||
|
|
||||||
|
if len(perfect_moves) >= 100:
|
||||||
|
logger.info(f"Training CNN on {len(perfect_moves)} perfect moves")
|
||||||
|
training_report = trading_system.cnn_trainer.train_on_perfect_moves(min_samples=100)
|
||||||
|
logger.info(f"CNN training report: {training_report}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)}")
|
||||||
|
|
||||||
|
# Train RL agents if they have experiences
|
||||||
|
await trading_system.rl_trainer._train_all_agents()
|
||||||
|
|
||||||
|
elif args.mode == 'backtest':
|
||||||
|
# Backtesting mode
|
||||||
|
logger.info("Backtesting mode not implemented yet")
|
||||||
|
return
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Received keyboard interrupt")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
await trading_system.shutdown()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Run the main application
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Application terminated by user")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Fatal error: {e}", exc_info=True)
|
||||||
|
sys.exit(1)
|
@ -5,4 +5,11 @@ pandas>=2.0.0
|
|||||||
numpy>=1.24.0
|
numpy>=1.24.0
|
||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
psutil>=5.9.0
|
psutil>=5.9.0
|
||||||
tensorboard>=2.15.0
|
tensorboard>=2.15.0
|
||||||
|
torch>=2.0.0
|
||||||
|
torchvision>=0.15.0
|
||||||
|
torchaudio>=2.0.0
|
||||||
|
scikit-learn>=1.3.0
|
||||||
|
matplotlib>=3.7.0
|
||||||
|
seaborn>=0.12.0
|
||||||
|
asyncio-compat>=0.1.2
|
168
run_enhanced_dashboard.py
Normal file
168
run_enhanced_dashboard.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Run Enhanced Trading Dashboard
|
||||||
|
|
||||||
|
This script starts the web dashboard with the enhanced trading system
|
||||||
|
for real-time monitoring and visualization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from threading import Thread
|
||||||
|
import time
|
||||||
|
|
||||||
|
from core.config import get_config, setup_logging
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||||
|
from web.dashboard import TradingDashboard
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class EnhancedDashboardRunner:
|
||||||
|
"""Enhanced dashboard runner with mock trading simulation"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the enhanced dashboard"""
|
||||||
|
self.config = get_config()
|
||||||
|
self.data_provider = DataProvider(self.config)
|
||||||
|
self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
|
||||||
|
|
||||||
|
# Create dashboard with enhanced orchestrator
|
||||||
|
self.dashboard = TradingDashboard(
|
||||||
|
data_provider=self.data_provider,
|
||||||
|
orchestrator=self.orchestrator
|
||||||
|
)
|
||||||
|
|
||||||
|
# Simulation state
|
||||||
|
self.running = False
|
||||||
|
self.simulation_thread = None
|
||||||
|
|
||||||
|
logger.info("Enhanced dashboard runner initialized")
|
||||||
|
|
||||||
|
def start_simulation(self):
|
||||||
|
"""Start background simulation for demonstration"""
|
||||||
|
self.running = True
|
||||||
|
self.simulation_thread = Thread(target=self._simulation_loop, daemon=True)
|
||||||
|
self.simulation_thread.start()
|
||||||
|
logger.info("Started enhanced trading simulation")
|
||||||
|
|
||||||
|
def _simulation_loop(self):
|
||||||
|
"""Background simulation loop"""
|
||||||
|
import random
|
||||||
|
from datetime import datetime
|
||||||
|
from core.enhanced_orchestrator import TradingAction, TimeframePrediction
|
||||||
|
|
||||||
|
action_count = 0
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
# Simulate trading decisions for demonstration
|
||||||
|
for symbol in self.config.symbols:
|
||||||
|
# Create mock timeframe predictions
|
||||||
|
timeframe_predictions = []
|
||||||
|
for timeframe in ['1h', '4h', '1d']:
|
||||||
|
# Random but realistic predictions
|
||||||
|
action_probs = [
|
||||||
|
random.uniform(0.1, 0.4), # SELL
|
||||||
|
random.uniform(0.3, 0.6), # HOLD
|
||||||
|
random.uniform(0.1, 0.4) # BUY
|
||||||
|
]
|
||||||
|
# Normalize probabilities
|
||||||
|
total = sum(action_probs)
|
||||||
|
action_probs = [p/total for p in action_probs]
|
||||||
|
|
||||||
|
best_action_idx = action_probs.index(max(action_probs))
|
||||||
|
actions = ['SELL', 'HOLD', 'BUY']
|
||||||
|
best_action = actions[best_action_idx]
|
||||||
|
|
||||||
|
tf_pred = TimeframePrediction(
|
||||||
|
timeframe=timeframe,
|
||||||
|
action=best_action,
|
||||||
|
confidence=random.uniform(0.5, 0.9),
|
||||||
|
probabilities={
|
||||||
|
'SELL': action_probs[0],
|
||||||
|
'HOLD': action_probs[1],
|
||||||
|
'BUY': action_probs[2]
|
||||||
|
},
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
market_features={
|
||||||
|
'volatility': random.uniform(0.01, 0.05),
|
||||||
|
'volume': random.uniform(1000, 10000),
|
||||||
|
'trend_strength': random.uniform(0.3, 0.8)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
timeframe_predictions.append(tf_pred)
|
||||||
|
|
||||||
|
# Create mock trading action
|
||||||
|
if random.random() > 0.7: # 30% chance of action
|
||||||
|
action_count += 1
|
||||||
|
mock_action = TradingAction(
|
||||||
|
symbol=symbol,
|
||||||
|
action=random.choice(['BUY', 'SELL']),
|
||||||
|
quantity=random.uniform(0.01, 0.1),
|
||||||
|
confidence=random.uniform(0.6, 0.9),
|
||||||
|
price=random.uniform(2000, 4000) if 'ETH' in symbol else random.uniform(40000, 70000),
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
reasoning={
|
||||||
|
'model': 'Enhanced Multi-Modal',
|
||||||
|
'timeframe_consensus': 'Strong',
|
||||||
|
'market_regime': random.choice(['trending', 'ranging', 'volatile']),
|
||||||
|
'action_count': action_count
|
||||||
|
},
|
||||||
|
timeframe_analysis=timeframe_predictions
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to dashboard
|
||||||
|
self.dashboard.add_trading_decision(mock_action)
|
||||||
|
|
||||||
|
logger.info(f"Simulated {mock_action.action} for {symbol} "
|
||||||
|
f"(confidence: {mock_action.confidence:.2f})")
|
||||||
|
|
||||||
|
# Sleep for next iteration
|
||||||
|
time.sleep(10) # Update every 10 seconds
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in simulation loop: {e}")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
def run_dashboard(self, host='127.0.0.1', port=8050):
|
||||||
|
"""Run the enhanced dashboard"""
|
||||||
|
logger.info(f"Starting enhanced trading dashboard at http://{host}:{port}")
|
||||||
|
logger.info("Features:")
|
||||||
|
logger.info("- Multi-modal CNN + RL predictions")
|
||||||
|
logger.info("- Multi-timeframe analysis")
|
||||||
|
logger.info("- Real-time market regime detection")
|
||||||
|
logger.info("- Perfect move tracking for CNN training")
|
||||||
|
logger.info("- RL feedback loop evaluation")
|
||||||
|
|
||||||
|
# Start simulation
|
||||||
|
self.start_simulation()
|
||||||
|
|
||||||
|
# Run dashboard
|
||||||
|
try:
|
||||||
|
self.dashboard.run(host=host, port=port, debug=False)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Dashboard stopped by user")
|
||||||
|
finally:
|
||||||
|
self.running = False
|
||||||
|
if self.simulation_thread:
|
||||||
|
self.simulation_thread.join(timeout=2)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function"""
|
||||||
|
try:
|
||||||
|
logger.info("=== ENHANCED TRADING DASHBOARD ===")
|
||||||
|
|
||||||
|
# Create and run dashboard
|
||||||
|
runner = EnhancedDashboardRunner()
|
||||||
|
runner.run_dashboard()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Fatal error: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
60
test_enhanced_system.py
Normal file
60
test_enhanced_system.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Simple test script for the enhanced trading system
|
||||||
|
Tests basic functionality without complex training loops
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from core.config import get_config, setup_logging
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def test_enhanced_system():
|
||||||
|
"""Test the enhanced trading system components"""
|
||||||
|
try:
|
||||||
|
logger.info("=== TESTING ENHANCED TRADING SYSTEM ===")
|
||||||
|
|
||||||
|
# Load configuration
|
||||||
|
config = get_config()
|
||||||
|
logger.info(f"Loaded config with symbols: {config.symbols}")
|
||||||
|
logger.info(f"Timeframes: {config.timeframes}")
|
||||||
|
|
||||||
|
# Initialize data provider
|
||||||
|
data_provider = DataProvider(config)
|
||||||
|
logger.info("Data provider initialized")
|
||||||
|
|
||||||
|
# Initialize enhanced orchestrator
orchestrator = EnhancedTradingOrchestrator(data_provider)
logger.info("Enhanced orchestrator initialized")
|
||||||
|
|
||||||
|
# Test basic functionality
|
||||||
|
logger.info("Testing orchestrator functionality...")
|
||||||
|
|
||||||
|
# Test market state creation
|
||||||
|
for symbol in config.symbols[:1]: # Test with first symbol only
|
||||||
|
logger.info(f"Testing with symbol: {symbol}")
|
||||||
|
|
||||||
|
# Test basic orchestrator methods
logger.info("Testing timeframe weights...")
weights = orchestrator._initialize_timeframe_weights()
logger.info(f"Timeframe weights: {weights}")
logger.info("Testing correlation matrix...")
correlations = orchestrator._initialize_correlation_matrix()
logger.info(f"Symbol correlations: {correlations}")
|
||||||
|
|
||||||
|
# Test basic functionality
logger.info("Basic orchestrator functionality tested successfully")
|
||||||
|
|
||||||
|
break # Test with one symbol only
|
||||||
|
|
||||||
|
logger.info("=== ENHANCED SYSTEM TEST COMPLETED SUCCESSFULLY ===")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Test failed: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = asyncio.run(test_enhanced_system())
|
||||||
|
if success:
|
||||||
|
print("\n✅ Enhanced system test PASSED")
|
||||||
|
else:
|
||||||
|
print("\n❌ Enhanced system test FAILED")
|
566
training/enhanced_cnn_trainer.py
Normal file
566
training/enhanced_cnn_trainer.py
Normal file
@ -0,0 +1,566 @@
|
|||||||
|
"""
|
||||||
|
Enhanced CNN Trainer with Perfect Move Learning
|
||||||
|
|
||||||
|
This trainer implements:
|
||||||
|
1. Training on marked perfect moves with known outcomes
|
||||||
|
2. Multi-timeframe CNN model training with confidence scoring
|
||||||
|
3. Backpropagation on optimal moves when future outcomes are known
|
||||||
|
4. Progressive learning from real trading experience
|
||||||
|
5. Symbol-specific and timeframe-specific model fine-tuning
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Dict, List, Optional, Tuple, Any
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from core.config import get_config
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
from core.enhanced_orchestrator import PerfectMove, EnhancedTradingOrchestrator
|
||||||
|
from models import CNNModelInterface
|
||||||
|
import models
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class PerfectMoveDataset(Dataset):
|
||||||
|
"""Dataset for training on perfect moves with known outcomes"""
|
||||||
|
|
||||||
|
def __init__(self, perfect_moves: List[PerfectMove], data_provider: DataProvider):
|
||||||
|
"""
|
||||||
|
Initialize dataset from perfect moves
|
||||||
|
|
||||||
|
Args:
|
||||||
|
perfect_moves: List of perfect moves with known outcomes
|
||||||
|
data_provider: Data provider to fetch additional context
|
||||||
|
"""
|
||||||
|
self.perfect_moves = perfect_moves
|
||||||
|
self.data_provider = data_provider
|
||||||
|
self.samples = []
|
||||||
|
self._prepare_samples()
|
||||||
|
|
||||||
|
def _prepare_samples(self):
|
||||||
|
"""Prepare training samples from perfect moves"""
|
||||||
|
logger.info(f"Preparing {len(self.perfect_moves)} perfect move samples")
|
||||||
|
|
||||||
|
for move in self.perfect_moves:
|
||||||
|
try:
|
||||||
|
# Get feature matrix at the time of the decision
|
||||||
|
feature_matrix = self.data_provider.get_feature_matrix(
|
||||||
|
symbol=move.symbol,
|
||||||
|
timeframes=[move.timeframe],
|
||||||
|
window_size=20,
|
||||||
|
end_time=move.timestamp
|
||||||
|
)
|
||||||
|
|
||||||
|
if feature_matrix is not None:
|
||||||
|
# Convert optimal action to label
|
||||||
|
action_to_label = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
|
||||||
|
label = action_to_label.get(move.optimal_action, 1)
|
||||||
|
|
||||||
|
# Create confidence target (what confidence should have been)
|
||||||
|
confidence_target = move.confidence_should_have_been
|
||||||
|
|
||||||
|
sample = {
|
||||||
|
'features': feature_matrix,
|
||||||
|
'action_label': label,
|
||||||
|
'confidence_target': confidence_target,
|
||||||
|
'symbol': move.symbol,
|
||||||
|
'timeframe': move.timeframe,
|
||||||
|
'outcome': move.actual_outcome,
|
||||||
|
'timestamp': move.timestamp
|
||||||
|
}
|
||||||
|
self.samples.append(sample)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error preparing sample for perfect move: {e}")
|
||||||
|
|
||||||
|
logger.info(f"Prepared {len(self.samples)} valid training samples")
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
sample = self.samples[idx]
|
||||||
|
|
||||||
|
# Convert to tensors
|
||||||
|
features = torch.FloatTensor(sample['features'])
|
||||||
|
action_label = torch.LongTensor([sample['action_label']])
|
||||||
|
confidence_target = torch.FloatTensor([sample['confidence_target']])
|
||||||
|
|
||||||
|
return {
|
||||||
|
'features': features,
|
||||||
|
'action_label': action_label,
|
||||||
|
'confidence_target': confidence_target,
|
||||||
|
'metadata': {
|
||||||
|
'symbol': sample['symbol'],
|
||||||
|
'timeframe': sample['timeframe'],
|
||||||
|
'outcome': sample['outcome'],
|
||||||
|
'timestamp': sample['timestamp']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class EnhancedCNNModel(nn.Module, CNNModelInterface):
|
||||||
|
"""Enhanced CNN model with timeframe-specific predictions and confidence scoring"""
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any]):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
CNNModelInterface.__init__(self, config)
|
||||||
|
|
||||||
|
self.timeframes = config.get('timeframes', ['1h', '4h', '1d'])
|
||||||
|
self.n_features = len(config.get('features', ['open', 'high', 'low', 'close', 'volume']))
|
||||||
|
self.window_size = config.get('window_size', 20)
|
||||||
|
|
||||||
|
# Build the neural network
|
||||||
|
self._build_network()
|
||||||
|
|
||||||
|
# Initialize device
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
self.to(self.device)
|
||||||
|
|
||||||
|
# Training components
|
||||||
|
self.optimizer = optim.Adam(self.parameters(), lr=config.get('learning_rate', 0.001))
|
||||||
|
self.action_criterion = nn.CrossEntropyLoss()
|
||||||
|
self.confidence_criterion = nn.MSELoss()
|
||||||
|
|
||||||
|
logger.info(f"Enhanced CNN model initialized for {len(self.timeframes)} timeframes")
|
||||||
|
|
||||||
|
def _build_network(self):
|
||||||
|
"""Build the CNN architecture"""
|
||||||
|
# Convolutional feature extraction
|
||||||
|
self.conv_layers = nn.Sequential(
|
||||||
|
# First conv block
|
||||||
|
nn.Conv1d(self.n_features, 64, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm1d(64),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
|
||||||
|
# Second conv block
|
||||||
|
nn.Conv1d(64, 128, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm1d(128),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
|
||||||
|
# Third conv block
|
||||||
|
nn.Conv1d(128, 256, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm1d(256),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
|
||||||
|
# Global average pooling
|
||||||
|
nn.AdaptiveAvgPool1d(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Timeframe-specific heads
|
||||||
|
self.timeframe_heads = nn.ModuleDict()
|
||||||
|
for timeframe in self.timeframes:
|
||||||
|
self.timeframe_heads[timeframe] = nn.Sequential(
|
||||||
|
nn.Linear(256, 128),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(128, 64),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Action prediction heads (one per timeframe)
|
||||||
|
self.action_heads = nn.ModuleDict()
|
||||||
|
for timeframe in self.timeframes:
|
||||||
|
self.action_heads[timeframe] = nn.Linear(64, 3)  # 3 classes: SELL, HOLD, BUY
|
||||||
|
|
||||||
|
# Confidence prediction heads (one per timeframe)
|
||||||
|
self.confidence_heads = nn.ModuleDict()
|
||||||
|
for timeframe in self.timeframes:
|
||||||
|
self.confidence_heads[timeframe] = nn.Sequential(
|
||||||
|
nn.Linear(64, 32),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(32, 1),
|
||||||
|
nn.Sigmoid() # Output between 0 and 1
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, timeframe: str = None):
|
||||||
|
"""
|
||||||
|
Forward pass through the network
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor [batch_size, window_size, features]
|
||||||
|
timeframe: Specific timeframe to predict for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
action_probs: Action probabilities
|
||||||
|
confidence: Confidence score
|
||||||
|
"""
|
||||||
|
# Reshape for conv1d: [batch, features, sequence]
|
||||||
|
x = x.transpose(1, 2)
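# Shape walk-through with the default config (window_size=20, 5 features):
# input [batch, 20, 5] -> transposed [batch, 5, 20] -> conv stack + global
# average pooling -> [batch, 256] shared by all timeframe heads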
|
||||||
|
|
||||||
|
# Extract features
|
||||||
|
features = self.conv_layers(x) # [batch, 256, 1]
|
||||||
|
features = features.squeeze(-1) # [batch, 256]
|
||||||
|
|
||||||
|
if timeframe and timeframe in self.timeframe_heads:
|
||||||
|
# Timeframe-specific prediction
|
||||||
|
tf_features = self.timeframe_heads[timeframe](features)
|
||||||
|
action_logits = self.action_heads[timeframe](tf_features)
|
||||||
|
confidence = self.confidence_heads[timeframe](tf_features)
|
||||||
|
|
||||||
|
action_probs = torch.softmax(action_logits, dim=1)
|
||||||
|
return action_probs, confidence.squeeze(-1)
|
||||||
|
else:
|
||||||
|
# Multi-timeframe prediction (average across timeframes)
|
||||||
|
all_action_probs = []
|
||||||
|
all_confidences = []
|
||||||
|
|
||||||
|
for tf in self.timeframes:
|
||||||
|
tf_features = self.timeframe_heads[tf](features)
|
||||||
|
action_logits = self.action_heads[tf](tf_features)
|
||||||
|
confidence = self.confidence_heads[tf](tf_features)
|
||||||
|
|
||||||
|
action_probs = torch.softmax(action_logits, dim=1)
|
||||||
|
all_action_probs.append(action_probs)
|
||||||
|
all_confidences.append(confidence.squeeze(-1))
|
||||||
|
|
||||||
|
# Average predictions across timeframes
|
||||||
|
avg_action_probs = torch.stack(all_action_probs).mean(dim=0)
|
||||||
|
avg_confidence = torch.stack(all_confidences).mean(dim=0)
|
||||||
|
|
||||||
|
return avg_action_probs, avg_confidence
|
||||||
|
|
||||||
|
def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
|
||||||
|
"""Predict action probabilities and confidence"""
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
x = torch.FloatTensor(features).to(self.device)
|
||||||
|
if len(x.shape) == 2:
|
||||||
|
x = x.unsqueeze(0) # Add batch dimension
|
||||||
|
|
||||||
|
action_probs, confidence = self.forward(x)
|
||||||
|
|
||||||
|
return action_probs[0].cpu().numpy(), confidence[0].cpu().item()
|
||||||
|
|
||||||
|
def predict_timeframe(self, features: np.ndarray, timeframe: str) -> Tuple[np.ndarray, float]:
|
||||||
|
"""Predict for specific timeframe"""
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
x = torch.FloatTensor(features).to(self.device)
|
||||||
|
if len(x.shape) == 2:
|
||||||
|
x = x.unsqueeze(0) # Add batch dimension
|
||||||
|
|
||||||
|
action_probs, confidence = self.forward(x, timeframe)
|
||||||
|
|
||||||
|
return action_probs[0].cpu().numpy(), confidence[0].cpu().item()
|
||||||
|
|
||||||
|
def get_memory_usage(self) -> int:
|
||||||
|
"""Get memory usage in MB"""
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return torch.cuda.memory_allocated(self.device) // (1024 * 1024)
|
||||||
|
else:
|
||||||
|
# Rough estimate for CPU
|
||||||
|
param_count = sum(p.numel() for p in self.parameters())
|
||||||
|
return (param_count * 4) // (1024 * 1024) # 4 bytes per float32
|
||||||
|
|
||||||
|
def train(self, training_data: Any = True) -> Any:
|
||||||
|
"""Train the model (placeholder for interface compatibility)"""
|
||||||
|
if isinstance(training_data, dict):
# Supervised training itself is driven by EnhancedCNNTrainer
return {}
return nn.Module.train(self, training_data)
|
||||||
|
|
||||||
|
class EnhancedCNNTrainer:
|
||||||
|
"""Enhanced CNN trainer using perfect moves and real market outcomes"""
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
|
||||||
|
"""Initialize the enhanced trainer"""
|
||||||
|
self.config = config or get_config()
|
||||||
|
self.orchestrator = orchestrator
|
||||||
|
self.data_provider = DataProvider(self.config)
|
||||||
|
|
||||||
|
# Training parameters
|
||||||
|
self.learning_rate = self.config.training.get('learning_rate', 0.001)
|
||||||
|
self.batch_size = self.config.training.get('batch_size', 32)
|
||||||
|
self.epochs = self.config.training.get('epochs', 100)
|
||||||
|
self.patience = self.config.training.get('early_stopping_patience', 10)
|
||||||
|
|
||||||
|
# Model
|
||||||
|
self.model = EnhancedCNNModel(self.config.cnn)
|
||||||
|
|
||||||
|
# Training history
|
||||||
|
self.training_history = {
|
||||||
|
'train_loss': [],
|
||||||
|
'val_loss': [],
|
||||||
|
'train_accuracy': [],
|
||||||
|
'val_accuracy': [],
|
||||||
|
'confidence_accuracy': []
|
||||||
|
}

# Create save directory
models_path = self.config.cnn.get('model_dir', "models/enhanced_cnn")
self.save_dir = Path(models_path)
self.save_dir.mkdir(parents=True, exist_ok=True)

logger.info("Enhanced CNN trainer initialized")
|
||||||
|
|
||||||
|
def train_on_perfect_moves(self, min_samples: int = 100) -> Dict[str, Any]:
|
||||||
|
"""Train the model on perfect moves from the orchestrator"""
|
||||||
|
if not self.orchestrator:
|
||||||
|
raise ValueError("Orchestrator required for perfect move training")
|
||||||
|
|
||||||
|
# Get perfect moves from orchestrator
|
||||||
|
perfect_moves = []
|
||||||
|
for symbol in self.config.symbols:
|
||||||
|
symbol_moves = self.orchestrator.get_perfect_moves_for_training(symbol=symbol)
|
||||||
|
perfect_moves.extend(symbol_moves)
|
||||||
|
|
||||||
|
if len(perfect_moves) < min_samples:
|
||||||
|
logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)} < {min_samples}")
|
||||||
|
return {'error': 'insufficient_data', 'samples': len(perfect_moves)}
|
||||||
|
|
||||||
|
logger.info(f"Training on {len(perfect_moves)} perfect moves")
|
||||||
|
|
||||||
|
# Create dataset
|
||||||
|
dataset = PerfectMoveDataset(perfect_moves, self.data_provider)
|
||||||
|
|
||||||
|
# Split into train/validation
|
||||||
|
train_size = int(0.8 * len(dataset))
|
||||||
|
val_size = len(dataset) - train_size
|
||||||
|
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
|
||||||
|
|
||||||
|
# Create data loaders
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
best_val_loss = float('inf')
|
||||||
|
patience_counter = 0
|
||||||
|
|
||||||
|
for epoch in range(self.epochs):
|
||||||
|
# Training phase
|
||||||
|
train_loss, train_acc = self._train_epoch(train_loader)
|
||||||
|
|
||||||
|
# Validation phase
|
||||||
|
val_loss, val_acc, conf_acc = self._validate_epoch(val_loader)
|
||||||
|
|
||||||
|
# Update history
|
||||||
|
self.training_history['train_loss'].append(train_loss)
|
||||||
|
self.training_history['val_loss'].append(val_loss)
|
||||||
|
self.training_history['train_accuracy'].append(train_acc)
|
||||||
|
self.training_history['val_accuracy'].append(val_acc)
|
||||||
|
self.training_history['confidence_accuracy'].append(conf_acc)
|
||||||
|
|
||||||
|
# Log progress
|
||||||
|
logger.info(f"Epoch {epoch+1}/{self.epochs}: "
|
||||||
|
f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
|
||||||
|
f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
|
||||||
|
f"Conf Acc: {conf_acc:.4f}")
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
if val_loss < best_val_loss:
|
||||||
|
best_val_loss = val_loss
|
||||||
|
patience_counter = 0
|
||||||
|
self._save_model('best_model.pt')
|
||||||
|
else:
|
||||||
|
patience_counter += 1
|
||||||
|
if patience_counter >= self.patience:
|
||||||
|
logger.info(f"Early stopping at epoch {epoch+1}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Save final model
|
||||||
|
self._save_model('final_model.pt')
|
||||||
|
|
||||||
|
# Generate training report
|
||||||
|
return self._generate_training_report()
|
||||||
|
|
||||||
|
def _train_epoch(self, train_loader: DataLoader) -> Tuple[float, float]:
|
||||||
|
"""Train for one epoch"""
|
||||||
|
self.model.train()
|
||||||
|
total_loss = 0.0
|
||||||
|
correct_predictions = 0
|
||||||
|
total_predictions = 0
|
||||||
|
|
||||||
|
for batch in train_loader:
|
||||||
|
features = batch['features'].to(self.model.device)
|
||||||
|
action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
|
||||||
|
confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)
|
||||||
|
|
||||||
|
# Zero gradients
|
||||||
|
self.model.optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
action_probs, confidence_pred = self.model(features)
|
||||||
|
|
||||||
|
# Calculate losses
|
||||||
|
action_loss = self.model.action_criterion(action_probs, action_labels)
|
||||||
|
confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)
|
||||||
|
|
||||||
|
# Combined loss
|
||||||
|
total_loss_batch = action_loss + 0.5 * confidence_loss
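# Confidence regression is down-weighted (0.5) so action classification remains the primary objective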
|
||||||
|
|
||||||
|
# Backward pass
|
||||||
|
total_loss_batch.backward()
|
||||||
|
self.model.optimizer.step()
|
||||||
|
|
||||||
|
# Track metrics
|
||||||
|
total_loss += total_loss_batch.item()
|
||||||
|
predicted_actions = torch.argmax(action_probs, dim=1)
|
||||||
|
correct_predictions += (predicted_actions == action_labels).sum().item()
|
||||||
|
total_predictions += action_labels.size(0)
|
||||||
|
|
||||||
|
avg_loss = total_loss / len(train_loader)
|
||||||
|
accuracy = correct_predictions / total_predictions
|
||||||
|
|
||||||
|
return avg_loss, accuracy
|
||||||
|
|
||||||
|
def _validate_epoch(self, val_loader: DataLoader) -> Tuple[float, float, float]:
|
||||||
|
"""Validate for one epoch"""
|
||||||
|
self.model.eval()
|
||||||
|
total_loss = 0.0
|
||||||
|
correct_predictions = 0
|
||||||
|
total_predictions = 0
|
||||||
|
confidence_errors = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in val_loader:
|
||||||
|
features = batch['features'].to(self.model.device)
|
||||||
|
action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
|
||||||
|
confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
action_probs, confidence_pred = self.model(features)
|
||||||
|
|
||||||
|
# Calculate losses
|
||||||
|
action_loss = self.model.action_criterion(action_probs, action_labels)
|
||||||
|
confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)
|
||||||
|
total_loss_batch = action_loss + 0.5 * confidence_loss
|
||||||
|
|
||||||
|
# Track metrics
|
||||||
|
total_loss += total_loss_batch.item()
|
||||||
|
predicted_actions = torch.argmax(action_probs, dim=1)
|
||||||
|
correct_predictions += (predicted_actions == action_labels).sum().item()
|
||||||
|
total_predictions += action_labels.size(0)
|
||||||
|
|
||||||
|
# Track confidence accuracy
|
||||||
|
conf_errors = torch.abs(confidence_pred - confidence_targets)
|
||||||
|
confidence_errors.extend(conf_errors.cpu().numpy())
|
||||||
|
|
||||||
|
avg_loss = total_loss / len(val_loader)
|
||||||
|
accuracy = correct_predictions / total_predictions
|
||||||
|
confidence_accuracy = 1.0 - np.mean(confidence_errors) # 1 - mean absolute error
|
||||||
|
|
||||||
|
return avg_loss, accuracy, confidence_accuracy
|
||||||
|
|
||||||
|
def _save_model(self, filename: str):
|
||||||
|
"""Save the model"""
|
||||||
|
save_path = self.save_dir / filename
|
||||||
|
torch.save({
|
||||||
|
'model_state_dict': self.model.state_dict(),
|
||||||
|
'optimizer_state_dict': self.model.optimizer.state_dict(),
|
||||||
|
'config': self.config.cnn,
|
||||||
|
'training_history': self.training_history
|
||||||
|
}, save_path)
|
||||||
|
logger.info(f"Model saved to {save_path}")
|
||||||
|
|
||||||
|
def load_model(self, filename: str) -> bool:
|
||||||
|
"""Load a saved model"""
|
||||||
|
load_path = self.save_dir / filename
|
||||||
|
if not load_path.exists():
|
||||||
|
logger.error(f"Model file not found: {load_path}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
checkpoint = torch.load(load_path, map_location=self.model.device)
|
||||||
|
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
self.model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
|
self.training_history = checkpoint.get('training_history', {})
|
||||||
|
logger.info(f"Model loaded from {load_path}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading model: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _generate_training_report(self) -> Dict[str, Any]:
|
||||||
|
"""Generate comprehensive training report"""
|
||||||
|
if not self.training_history['train_loss']:
|
||||||
|
return {'error': 'no_training_data'}
|
||||||
|
|
||||||
|
# Calculate final metrics
|
||||||
|
final_train_loss = self.training_history['train_loss'][-1]
|
||||||
|
final_val_loss = self.training_history['val_loss'][-1]
|
||||||
|
final_train_acc = self.training_history['train_accuracy'][-1]
|
||||||
|
final_val_acc = self.training_history['val_accuracy'][-1]
|
||||||
|
final_conf_acc = self.training_history['confidence_accuracy'][-1]
|
||||||
|
|
||||||
|
# Best metrics
|
||||||
|
best_val_loss = min(self.training_history['val_loss'])
|
||||||
|
best_val_acc = max(self.training_history['val_accuracy'])
|
||||||
|
best_conf_acc = max(self.training_history['confidence_accuracy'])
|
||||||
|
|
||||||
|
report = {
|
||||||
|
'training_completed': True,
|
||||||
|
'epochs_trained': len(self.training_history['train_loss']),
|
||||||
|
'final_metrics': {
|
||||||
|
'train_loss': final_train_loss,
|
||||||
|
'val_loss': final_val_loss,
|
||||||
|
'train_accuracy': final_train_acc,
|
||||||
|
'val_accuracy': final_val_acc,
|
||||||
|
'confidence_accuracy': final_conf_acc
|
||||||
|
},
|
||||||
|
'best_metrics': {
|
||||||
|
'val_loss': best_val_loss,
|
||||||
|
'val_accuracy': best_val_acc,
|
||||||
|
'confidence_accuracy': best_conf_acc
|
||||||
|
},
|
||||||
|
'model_info': {
|
||||||
|
'timeframes': self.model.timeframes,
|
||||||
|
'memory_usage_mb': self.model.get_memory_usage(),
|
||||||
|
'device': str(self.model.device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generate plots
|
||||||
|
self._plot_training_history()
|
||||||
|
|
||||||
|
logger.info("Training completed successfully")
|
||||||
|
logger.info(f"Final validation accuracy: {final_val_acc:.4f}")
|
||||||
|
logger.info(f"Final confidence accuracy: {final_conf_acc:.4f}")
|
||||||
|
|
||||||
|
return report
|
||||||
|
|
||||||
|
def _plot_training_history(self):
|
||||||
|
"""Plot training history"""
|
||||||
|
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
|
||||||
|
fig.suptitle('Enhanced CNN Training History')
|
||||||
|
|
||||||
|
# Loss plot
|
||||||
|
axes[0, 0].plot(self.training_history['train_loss'], label='Train Loss')
|
||||||
|
axes[0, 0].plot(self.training_history['val_loss'], label='Val Loss')
|
||||||
|
axes[0, 0].set_title('Loss')
|
||||||
|
axes[0, 0].set_xlabel('Epoch')
|
||||||
|
axes[0, 0].set_ylabel('Loss')
|
||||||
|
axes[0, 0].legend()
|
||||||
|
|
||||||
|
# Accuracy plot
|
||||||
|
axes[0, 1].plot(self.training_history['train_accuracy'], label='Train Accuracy')
|
||||||
|
axes[0, 1].plot(self.training_history['val_accuracy'], label='Val Accuracy')
|
||||||
|
axes[0, 1].set_title('Action Accuracy')
|
||||||
|
axes[0, 1].set_xlabel('Epoch')
|
||||||
|
axes[0, 1].set_ylabel('Accuracy')
|
||||||
|
axes[0, 1].legend()
|
||||||
|
|
||||||
|
# Confidence accuracy plot
|
||||||
|
axes[1, 0].plot(self.training_history['confidence_accuracy'], label='Confidence Accuracy')
|
||||||
|
axes[1, 0].set_title('Confidence Prediction Accuracy')
|
||||||
|
axes[1, 0].set_xlabel('Epoch')
|
||||||
|
axes[1, 0].set_ylabel('Accuracy')
|
||||||
|
axes[1, 0].legend()
|
||||||
|
|
||||||
|
# Learning curves comparison
|
||||||
|
axes[1, 1].plot(self.training_history['val_loss'], label='Validation Loss')
|
||||||
|
axes[1, 1].plot(self.training_history['confidence_accuracy'], label='Confidence Accuracy')
|
||||||
|
axes[1, 1].set_title('Model Performance Overview')
|
||||||
|
axes[1, 1].set_xlabel('Epoch')
|
||||||
|
axes[1, 1].legend()
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(self.save_dir / 'training_history.png', dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
logger.info(f"Training plots saved to {self.save_dir / 'training_history.png'}")
|
||||||
|
|
||||||
|
def get_model(self) -> EnhancedCNNModel:
|
||||||
|
"""Get the trained model"""
|
||||||
|
return self.model
|
625
training/enhanced_rl_trainer.py
Normal file
625
training/enhanced_rl_trainer.py
Normal file
@ -0,0 +1,625 @@
|
|||||||
|
"""
|
||||||
|
Enhanced RL Trainer with Market Environment Adaptation
|
||||||
|
|
||||||
|
This trainer implements:
|
||||||
|
1. Continuous learning from orchestrator action evaluations
|
||||||
|
2. Environment adaptation based on market regime changes
|
||||||
|
3. Multi-symbol coordinated RL training
|
||||||
|
4. Experience replay with prioritized sampling
|
||||||
|
5. Dynamic reward shaping based on market conditions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
import matplotlib.pyplot as plt
from pathlib import Path
|
||||||
|
|
||||||
|
from core.config import get_config
|
||||||
|
from core.data_provider import DataProvider
|
||||||
|
from core.enhanced_orchestrator import EnhancedTradingOrchestrator, MarketState, TradingAction
|
||||||
|
from models import RLAgentInterface
|
||||||
|
import models
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Experience tuple for replay buffer
|
||||||
|
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done', 'priority'])
|
||||||
|
|
||||||
|
class PrioritizedReplayBuffer:
|
||||||
|
"""Prioritized experience replay buffer for RL training"""
|
||||||
|
|
||||||
|
def __init__(self, capacity: int = 10000, alpha: float = 0.6):
|
||||||
|
"""
|
||||||
|
Initialize prioritized replay buffer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
capacity: Maximum number of experiences to store
|
||||||
|
alpha: Priority exponent (0 = uniform, 1 = fully prioritized)
|
||||||
|
"""
|
||||||
|
self.capacity = capacity
|
||||||
|
self.alpha = alpha
|
||||||
|
self.buffer = []
|
||||||
|
self.priorities = np.zeros(capacity, dtype=np.float32)
|
||||||
|
self.position = 0
|
||||||
|
self.size = 0
|
||||||
|
|
||||||
|
def add(self, experience: Experience):
|
||||||
|
"""Add experience to buffer with priority"""
|
||||||
|
max_priority = self.priorities[:self.size].max() if self.size > 0 else 1.0
|
||||||
|
|
||||||
|
if self.size < self.capacity:
|
||||||
|
self.buffer.append(experience)
|
||||||
|
self.size += 1
|
||||||
|
else:
|
||||||
|
self.buffer[self.position] = experience
|
||||||
|
|
||||||
|
self.priorities[self.position] = max_priority
|
||||||
|
self.position = (self.position + 1) % self.capacity
|
||||||
|
|
||||||
|
def sample(self, batch_size: int, beta: float = 0.4) -> Tuple[List[Experience], np.ndarray, np.ndarray]:
|
||||||
|
"""Sample batch with prioritized sampling"""
|
||||||
|
if self.size == 0:
|
||||||
|
return [], np.array([]), np.array([])
|
||||||
|
|
||||||
|
# Calculate sampling probabilities
|
||||||
|
priorities = self.priorities[:self.size] ** self.alpha
|
||||||
|
probabilities = priorities / priorities.sum()
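# i.e. P(i) = p_i**alpha / sum_k p_k**alpha; alpha=0 gives uniform sampling, alpha=1 fully prioritized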
|
||||||
|
|
||||||
|
# Sample indices
|
||||||
|
indices = np.random.choice(self.size, batch_size, p=probabilities)
|
||||||
|
experiences = [self.buffer[i] for i in indices]
|
||||||
|
|
||||||
|
# Calculate importance sampling weights
|
||||||
|
weights = (self.size * probabilities[indices]) ** (-beta)
|
||||||
|
weights = weights / weights.max() # Normalize
|
||||||
|
|
||||||
|
return experiences, indices, weights
|
||||||
|
|
||||||
|
def update_priorities(self, indices: np.ndarray, priorities: np.ndarray):
|
||||||
|
"""Update priorities for sampled experiences"""
|
||||||
|
for idx, priority in zip(indices, priorities):
|
||||||
|
self.priorities[idx] = priority + 1e-6 # Small epsilon to avoid zero priority
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.size
|
||||||
|
|
||||||
|
class EnhancedDQNAgent(nn.Module, RLAgentInterface):
|
||||||
|
"""Enhanced DQN agent with market environment adaptation"""
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any]):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
RLAgentInterface.__init__(self, config)
|
||||||
|
|
||||||
|
# Network architecture
|
||||||
|
self.state_size = config.get('state_size', 100)
|
||||||
|
self.action_space = config.get('action_space', 3)
|
||||||
|
self.hidden_size = config.get('hidden_size', 256)
|
||||||
|
|
||||||
|
# Build networks
|
||||||
|
self._build_networks()
|
||||||
|
|
||||||
|
# Training parameters
|
||||||
|
self.learning_rate = config.get('learning_rate', 0.0001)
|
||||||
|
self.gamma = config.get('gamma', 0.99)
|
||||||
|
self.epsilon = config.get('epsilon', 1.0)
|
||||||
|
self.epsilon_decay = config.get('epsilon_decay', 0.995)
|
||||||
|
self.epsilon_min = config.get('epsilon_min', 0.01)
|
||||||
|
self.target_update_freq = config.get('target_update_freq', 1000)
|
||||||
|
|
||||||
|
# Initialize device and optimizer
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
self.to(self.device)
|
||||||
|
self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
|
||||||
|
|
||||||
|
# Experience replay
|
||||||
|
self.replay_buffer = PrioritizedReplayBuffer(config.get('buffer_size', 10000))
|
||||||
|
self.batch_size = config.get('batch_size', 64)
|
||||||
|
|
||||||
|
# Market adaptation
|
||||||
|
self.market_regime_weights = {
|
||||||
|
'trending': 1.2, # Higher confidence in trending markets
|
||||||
|
'ranging': 0.8, # Lower confidence in ranging markets
|
||||||
|
'volatile': 0.6 # Much lower confidence in volatile markets
|
||||||
|
}
|
||||||
|
|
||||||
|
# Training statistics
|
||||||
|
self.training_steps = 0
|
||||||
|
self.losses = []
|
||||||
|
self.rewards = []
|
||||||
|
self.epsilon_history = []
|
||||||
|
|
||||||
|
logger.info(f"Enhanced DQN agent initialized with state size: {self.state_size}")
|
||||||
|
|
||||||
|
def _build_networks(self):
|
||||||
|
"""Build main and target networks"""
|
||||||
|
# Main network
|
||||||
|
self.main_network = nn.Sequential(
|
||||||
|
nn.Linear(self.state_size, self.hidden_size),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(self.hidden_size, self.hidden_size),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(self.hidden_size, 128),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.2)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dueling network heads
|
||||||
|
self.value_head = nn.Linear(128, 1)
|
||||||
|
self.advantage_head = nn.Linear(128, self.action_space)
|
||||||
|
|
||||||
|
# Target network (copy of main network)
|
||||||
|
self.target_network = nn.Sequential(
|
||||||
|
nn.Linear(self.state_size, self.hidden_size),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(self.hidden_size, self.hidden_size),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.3),
|
||||||
|
nn.Linear(self.hidden_size, 128),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Dropout(0.2)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.target_value_head = nn.Linear(128, 1)
|
||||||
|
self.target_advantage_head = nn.Linear(128, self.action_space)
|
||||||
|
|
||||||
|
# Initialize target network with same weights
|
||||||
|
self._update_target_network()
|
||||||
|
|
||||||
|
def forward(self, state, target: bool = False):
|
||||||
|
"""Forward pass through the network"""
|
||||||
|
if target:
|
||||||
|
features = self.target_network(state)
|
||||||
|
value = self.target_value_head(features)
|
||||||
|
advantage = self.target_advantage_head(features)
|
||||||
|
else:
|
||||||
|
features = self.main_network(state)
|
||||||
|
value = self.value_head(features)
|
||||||
|
advantage = self.advantage_head(features)
|
||||||
|
|
||||||
|
# Dueling architecture: Q(s,a) = V(s) + A(s,a) - mean(A(s,a))
|
||||||
|
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
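# Subtracting the mean advantage keeps V(s) and A(s,a) identifiable: adding a
# constant to every advantage would otherwise leave Q(s,a) unchanged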
|
||||||
|
|
||||||
|
return q_values
|
||||||
|
|
||||||
|
def act(self, state: np.ndarray) -> int:
|
||||||
|
"""Choose action using epsilon-greedy policy"""
|
||||||
|
if random.random() < self.epsilon:
|
||||||
|
return random.randint(0, self.action_space - 1)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||||
|
q_values = self.forward(state_tensor)
|
||||||
|
return q_values.argmax().item()
|
||||||
|
|
||||||
|
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
|
||||||
|
"""Choose action with confidence score adapted to market regime"""
|
||||||
|
with torch.no_grad():
|
||||||
|
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||||
|
q_values = self.forward(state_tensor)
|
||||||
|
|
||||||
|
# Convert Q-values to probabilities
|
||||||
|
action_probs = torch.softmax(q_values, dim=1)
|
||||||
|
action = q_values.argmax().item()
|
||||||
|
base_confidence = action_probs[0, action].item()
|
||||||
|
|
||||||
|
# Adapt confidence based on market regime
|
||||||
|
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
|
||||||
|
adapted_confidence = min(base_confidence * regime_weight, 1.0)
|
||||||
|
|
||||||
|
return action, adapted_confidence

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool):
        """Store experience in replay buffer"""
        # Calculate TD error for priority
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)

            current_q = self.forward(state_tensor)[0, action]
            next_q = self.forward(next_state_tensor, target=True).max(1)[0]
            target_q = reward + (self.gamma * next_q * (1 - done))

            td_error = abs(current_q.item() - target_q.item())

        experience = Experience(state, action, reward, next_state, done, td_error)
        self.replay_buffer.add(experience)

    def replay(self) -> Optional[float]:
        """Train the network on a batch of experiences"""
        if len(self.replay_buffer) < self.batch_size:
            return None

        # Sample batch
        experiences, indices, weights = self.replay_buffer.sample(self.batch_size)

        if not experiences:
            return None

        # Convert to tensors
        states = torch.FloatTensor([e.state for e in experiences]).to(self.device)
        actions = torch.LongTensor([e.action for e in experiences]).to(self.device)
        rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device)
        next_states = torch.FloatTensor([e.next_state for e in experiences]).to(self.device)
        dones = torch.BoolTensor([e.done for e in experiences]).to(self.device)
        weights_tensor = torch.FloatTensor(weights).to(self.device)

        # Current Q-values
        current_q_values = self.forward(states).gather(1, actions.unsqueeze(1))

        # Target Q-values (Double DQN)
        with torch.no_grad():
            # Use main network to select actions
            next_actions = self.forward(next_states).argmax(1)
            # Use target network to evaluate actions
            next_q_values = self.forward(next_states, target=True).gather(1, next_actions.unsqueeze(1))
            target_q_values = rewards.unsqueeze(1) + (self.gamma * next_q_values * ~dones.unsqueeze(1))

        # Calculate importance-sampling weighted loss
        # (weights are reshaped to [batch, 1] so the product stays element-wise
        # instead of broadcasting [batch] x [batch, 1] into a matrix)
        td_errors = target_q_values - current_q_values
        loss = (weights_tensor.unsqueeze(1) * (td_errors ** 2)).mean()

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        self.optimizer.step()

        # Update priorities
        new_priorities = torch.abs(td_errors).detach().cpu().numpy().flatten()
        self.replay_buffer.update_priorities(indices, new_priorities)

        # Update target network
        self.training_steps += 1
        if self.training_steps % self.target_update_freq == 0:
            self._update_target_network()

        # Decay epsilon
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        # Track statistics
        self.losses.append(loss.item())
        self.epsilon_history.append(self.epsilon)

        return loss.item()

    def _update_target_network(self):
        """Update target network with main network weights"""
        self.target_network.load_state_dict(self.main_network.state_dict())
        self.target_value_head.load_state_dict(self.value_head.state_dict())
        self.target_advantage_head.load_state_dict(self.advantage_head.state_dict())

    def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
        """Predict action probabilities and confidence (required by ModelInterface)"""
        action, confidence = self.act_with_confidence(features)

        # Convert action to probabilities
        action_probs = np.zeros(self.action_space)
        action_probs[action] = 1.0

        return action_probs, confidence

    def get_memory_usage(self) -> int:
        """Get memory usage in MB"""
        if torch.cuda.is_available():
            return torch.cuda.memory_allocated(self.device) // (1024 * 1024)
        else:
            param_count = sum(p.numel() for p in self.parameters())
            buffer_size = len(self.replay_buffer) * self.state_size * 4  # Rough estimate
            return (param_count * 4 + buffer_size) // (1024 * 1024)

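
# --- Illustrative usage sketch (not part of the trainer below) ---------------
# A minimal warm-up loop showing how the EnhancedDQNAgent API above fits
# together: confidence-gated action selection, storing transitions, and one
# prioritized replay update. It only calls methods defined in this file; the
# synthetic states, the 0.6 confidence threshold and the 200-step length are
# placeholder assumptions, not values taken from the project configuration.
def warm_up_agent_demo(agent: "EnhancedDQNAgent", steps: int = 200) -> None:
    """Fill the agent's replay buffer with synthetic transitions and train once."""
    for _ in range(steps):
        state = np.random.randn(agent.state_size).astype(np.float32)
        next_state = np.random.randn(agent.state_size).astype(np.float32)

        # Regime-adapted confidence can be used to gate low-conviction actions
        action, confidence = agent.act_with_confidence(state, market_regime='trending')
        if confidence < 0.6:  # placeholder threshold
            action = 1        # fall back to HOLD (mapping used below: 0=SELL, 1=HOLD, 2=BUY)

        reward = float(np.random.uniform(-1.0, 1.0))  # placeholder reward signal
        agent.remember(state, action, reward, next_state, done=False)

    loss = agent.replay()  # returns None until batch_size experiences are stored
    if loss is not None:
        logger.info(f"Warm-up replay loss: {loss:.4f}")
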

class EnhancedRLTrainer:
    """Enhanced RL trainer with continuous learning from market feedback"""

    def __init__(self, config: Optional[Dict] = None, orchestrator: Optional[EnhancedTradingOrchestrator] = None):
        """Initialize the enhanced RL trainer"""
        self.config = config or get_config()
        self.orchestrator = orchestrator
        self.data_provider = DataProvider(self.config)

        # Create RL agents for each symbol
        self.agents = {}
        for symbol in self.config.symbols:
            agent_config = self.config.rl.copy()
            agent_config['name'] = f'RL_{symbol}'
            self.agents[symbol] = EnhancedDQNAgent(agent_config)

        # Training parameters
        self.training_interval = 3600  # Train every hour
        self.evaluation_window = 24 * 3600  # Evaluate actions after 24 hours
        self.min_experiences = 100  # Minimum experiences before training

        # Performance tracking
        self.performance_history = {symbol: [] for symbol in self.config.symbols}
        self.training_metrics = {
            'total_episodes': 0,
            'total_rewards': {symbol: [] for symbol in self.config.symbols},
            'losses': {symbol: [] for symbol in self.config.symbols},
            'epsilon_values': {symbol: [] for symbol in self.config.symbols}
        }

        # Create save directory
        models_path = self.config.rl.get('model_dir', "models/enhanced_rl")
        self.save_dir = Path(models_path)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Enhanced RL trainer initialized for symbols: {self.config.symbols}")
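
    # NOTE: summary of the configuration fields this trainer reads above
    # (the defaults shown are the fallbacks used in this file, not authoritative values):
    #   config.symbols          -> symbols to trade; one EnhancedDQNAgent per symbol
    #   config.rl               -> dict passed to each agent (a 'name' key is added per symbol)
    #   config.rl['model_dir']  -> checkpoint directory (default "models/enhanced_rl")
    #   config.rl['state_size'] -> RL state vector length (default 100, see _market_state_to_rl_state)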

    async def continuous_learning_loop(self):
        """Main continuous learning loop"""
        logger.info("Starting continuous RL learning loop")

        while True:
            try:
                # Train agents with recent experiences
                await self._train_all_agents()

                # Evaluate recent actions
                if self.orchestrator:
                    await self.orchestrator.evaluate_actions_with_rl()

                # Adapt to market regime changes
                await self._adapt_to_market_changes()

                # Update performance metrics
                self._update_performance_metrics()

                # Save models periodically
                if self.training_metrics['total_episodes'] % 100 == 0:
                    self._save_all_models()

                # Wait before next training cycle
                await asyncio.sleep(self.training_interval)

            except Exception as e:
                logger.error(f"Error in continuous learning loop: {e}")
                await asyncio.sleep(60)  # Wait 1 minute on error

    async def _train_all_agents(self):
        """Train all RL agents with their experiences"""
        for symbol, agent in self.agents.items():
            try:
                if len(agent.replay_buffer) >= self.min_experiences:
                    # Train for multiple steps
                    losses = []
                    for _ in range(10):  # Train 10 steps per cycle
                        loss = agent.replay()
                        if loss is not None:
                            losses.append(loss)

                    if losses:
                        avg_loss = np.mean(losses)
                        self.training_metrics['losses'][symbol].append(avg_loss)
                        self.training_metrics['epsilon_values'][symbol].append(agent.epsilon)

                        logger.info(f"Trained {symbol} RL agent: Loss={avg_loss:.4f}, Epsilon={agent.epsilon:.4f}")

            except Exception as e:
                logger.error(f"Error training {symbol} agent: {e}")

    async def _adapt_to_market_changes(self):
        """Adapt agents to market regime changes"""
        if not self.orchestrator:
            return

        for symbol in self.config.symbols:
            try:
                # Get recent market states
                recent_states = list(self.orchestrator.market_states[symbol])[-10:]  # Last 10 states

                if len(recent_states) < 5:
                    continue

                # Analyze regime stability
                regimes = [state.market_regime for state in recent_states]
                regime_stability = len(set(regimes)) / len(regimes)  # Lower = more stable

                # Adjust learning parameters based on stability
                agent = self.agents[symbol]
                if regime_stability < 0.3:  # Stable regime
                    agent.epsilon *= 0.99  # Faster epsilon decay
                elif regime_stability > 0.7:  # Unstable regime
                    agent.epsilon = min(agent.epsilon * 1.01, 0.5)  # Increase exploration

                logger.debug(f"{symbol} regime stability: {regime_stability:.3f}, epsilon: {agent.epsilon:.3f}")

            except Exception as e:
                logger.error(f"Error adapting {symbol} to market changes: {e}")

    def add_trading_experience(self, symbol: str, action: TradingAction,
                               initial_state: MarketState, final_state: MarketState,
                               reward: float):
        """Add trading experience to the appropriate agent"""
        if symbol not in self.agents:
            logger.warning(f"No agent for symbol {symbol}")
            return

        try:
            # Convert market states to RL state vectors
            initial_rl_state = self._market_state_to_rl_state(initial_state)
            final_rl_state = self._market_state_to_rl_state(final_state)

            # Convert action to RL action index
            action_mapping = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
            action_idx = action_mapping.get(action.action, 1)

            # Store experience
            agent = self.agents[symbol]
            agent.remember(
                state=initial_rl_state,
                action=action_idx,
                reward=reward,
                next_state=final_rl_state,
                done=False
            )

            # Track reward
            self.training_metrics['total_rewards'][symbol].append(reward)

            logger.debug(f"Added experience for {symbol}: action={action.action}, reward={reward:.4f}")

        except Exception as e:
            logger.error(f"Error adding experience for {symbol}: {e}")

    def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
        """Convert market state to RL state vector"""
        if hasattr(self.orchestrator, '_market_state_to_rl_state'):
            return self.orchestrator._market_state_to_rl_state(market_state)

        # Fallback implementation
        state_components = [
            market_state.volatility,
            market_state.volume,
            market_state.trend_strength
        ]

        # Add price features
        for timeframe in sorted(market_state.prices.keys()):
            state_components.append(market_state.prices[timeframe])

        # Pad or truncate to expected state size
        expected_size = self.config.rl.get('state_size', 100)
        if len(state_components) < expected_size:
            state_components.extend([0.0] * (expected_size - len(state_components)))
        else:
            state_components = state_components[:expected_size]

        return np.array(state_components, dtype=np.float32)

    def _update_performance_metrics(self):
        """Update performance tracking metrics"""
        self.training_metrics['total_episodes'] += 1

        # Calculate recent performance for each agent
        for symbol, agent in self.agents.items():
            recent_rewards = self.training_metrics['total_rewards'][symbol][-100:]  # Last 100 rewards
            if recent_rewards:
                avg_reward = np.mean(recent_rewards)
                self.performance_history[symbol].append({
                    'timestamp': datetime.now(),
                    'avg_reward': avg_reward,
                    'epsilon': agent.epsilon,
                    'experiences': len(agent.replay_buffer)
                })

    def _save_all_models(self):
        """Save all RL models"""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        for symbol, agent in self.agents.items():
            filename = f"rl_agent_{symbol}_{timestamp}.pt"
            filepath = self.save_dir / filename

            torch.save({
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': agent.optimizer.state_dict(),
                'config': self.config.rl,
                'training_metrics': self.training_metrics,
                'symbol': symbol,
                'epsilon': agent.epsilon,
                'training_steps': agent.training_steps
            }, filepath)

            logger.info(f"Saved {symbol} RL agent to {filepath}")

    def load_models(self, timestamp: Optional[str] = None) -> bool:
        """Load RL models from files"""
        if timestamp is None:
            # Find most recent models
            model_files = list(self.save_dir.glob("rl_agent_*.pt"))
            if not model_files:
                logger.warning("No saved RL models found")
                return False

            # Group by timestamp and get most recent
            timestamps = set(f.stem.split('_')[-2] + '_' + f.stem.split('_')[-1] for f in model_files)
            timestamp = max(timestamps)

        loaded_count = 0
        for symbol in self.config.symbols:
            filename = f"rl_agent_{symbol}_{timestamp}.pt"
            filepath = self.save_dir / filename

            if filepath.exists():
                try:
                    checkpoint = torch.load(filepath, map_location=self.agents[symbol].device)
                    self.agents[symbol].load_state_dict(checkpoint['model_state_dict'])
                    self.agents[symbol].optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                    self.agents[symbol].epsilon = checkpoint.get('epsilon', 0.1)
                    self.agents[symbol].training_steps = checkpoint.get('training_steps', 0)

                    logger.info(f"Loaded {symbol} RL agent from {filepath}")
                    loaded_count += 1

                except Exception as e:
                    logger.error(f"Error loading {symbol} RL agent: {e}")

        return loaded_count > 0

    def get_performance_report(self) -> Dict[str, Any]:
        """Generate performance report for all agents"""
        report = {
            'total_episodes': self.training_metrics['total_episodes'],
            'agents': {}
        }

        for symbol, agent in self.agents.items():
            recent_rewards = self.training_metrics['total_rewards'][symbol][-100:]
            recent_losses = self.training_metrics['losses'][symbol][-10:]

            agent_report = {
                'symbol': symbol,
                'epsilon': agent.epsilon,
                'training_steps': agent.training_steps,
                'experiences_stored': len(agent.replay_buffer),
                'memory_usage_mb': agent.get_memory_usage(),
                'avg_recent_reward': np.mean(recent_rewards) if recent_rewards else 0.0,
                'avg_recent_loss': np.mean(recent_losses) if recent_losses else 0.0,
                'total_rewards': len(self.training_metrics['total_rewards'][symbol])
            }

            report['agents'][symbol] = agent_report

        return report

    def plot_training_metrics(self):
        """Plot training metrics for all agents"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Enhanced RL Training Metrics')

        symbols = list(self.agents.keys())
        colors = ['blue', 'red', 'green', 'orange'][:len(symbols)]

        # Rewards plot
        for i, symbol in enumerate(symbols):
            rewards = self.training_metrics['total_rewards'][symbol]
            if rewards:
                # Moving average of rewards
                window = min(100, len(rewards))
                if len(rewards) >= window:
                    moving_avg = np.convolve(rewards, np.ones(window) / window, mode='valid')
                    axes[0, 0].plot(moving_avg, label=f'{symbol}', color=colors[i])

        axes[0, 0].set_title('Average Rewards (Moving Average)')
        axes[0, 0].set_xlabel('Episodes')
        axes[0, 0].set_ylabel('Reward')
        axes[0, 0].legend()

        # Losses plot
        for i, symbol in enumerate(symbols):
            losses = self.training_metrics['losses'][symbol]
            if losses:
                axes[0, 1].plot(losses, label=f'{symbol}', color=colors[i])

        axes[0, 1].set_title('Training Losses')
        axes[0, 1].set_xlabel('Training Steps')
        axes[0, 1].set_ylabel('Loss')
        axes[0, 1].legend()

        # Epsilon values
        for i, symbol in enumerate(symbols):
            epsilon_values = self.training_metrics['epsilon_values'][symbol]
            if epsilon_values:
                axes[1, 0].plot(epsilon_values, label=f'{symbol}', color=colors[i])

        axes[1, 0].set_title('Exploration Rate (Epsilon)')
        axes[1, 0].set_xlabel('Training Steps')
        axes[1, 0].set_ylabel('Epsilon')
        axes[1, 0].legend()

        # Experience buffer sizes
        buffer_sizes = [len(agent.replay_buffer) for agent in self.agents.values()]
        axes[1, 1].bar(symbols, buffer_sizes, color=colors[:len(symbols)])
        axes[1, 1].set_title('Experience Buffer Sizes')
        axes[1, 1].set_ylabel('Number of Experiences')

        plt.tight_layout()
        plt.savefig(self.save_dir / 'rl_training_metrics.png', dpi=300, bbox_inches='tight')
        plt.close()

        logger.info(f"RL training plots saved to {self.save_dir / 'rl_training_metrics.png'}")

    def get_agents(self) -> Dict[str, EnhancedDQNAgent]:
        """Get all RL agents"""
        return self.agents
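

# --- Illustrative entry point (sketch only) -----------------------------------
# One possible way to wire the trainer into an asyncio program. The
# EnhancedTradingOrchestrator(config) call is an assumption about the
# orchestrator's constructor (see core/enhanced_orchestrator.py); adjust to the
# actual API when using this.
async def _run_enhanced_rl_training():
    config = get_config()
    orchestrator = EnhancedTradingOrchestrator(config)  # assumed constructor signature
    trainer = EnhancedRLTrainer(config=config, orchestrator=orchestrator)
    trainer.load_models()  # resume from the most recent checkpoints, if any
    await trainer.continuous_learning_loop()  # runs until the task is cancelled


if __name__ == "__main__":
    asyncio.run(_run_enhanced_rl_training())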