new overhaul

This commit is contained in:
Dobromir Popov 2025-05-24 11:00:40 +03:00
parent b5ad023b16
commit 2f50ed920f
9 changed files with 2998 additions and 36 deletions

View File

@ -0,0 +1,377 @@
# Enhanced Multi-Modal Trading Architecture Guide
## Overview
This document describes the enhanced multi-modal trading system that implements sophisticated decision-making through coordinated CNN and RL modules. The system is designed to handle multi-timeframe analysis across multiple symbols (ETH, BTC) with continuous learning capabilities.
## Architecture Components
### 1. Enhanced Trading Orchestrator (`core/enhanced_orchestrator.py`)
The heart of the system that coordinates all components:
**Key Features:**
- **Multi-Symbol Coordination**: Makes decisions across ETH and BTC considering correlations
- **Timeframe Integration**: Combines predictions from multiple timeframes (1m, 5m, 15m, 1h, 4h, 1d)
- **Perfect Move Marking**: Identifies and marks optimal trading decisions for CNN training
- **RL Evaluation Loop**: Evaluates trading outcomes to train RL agents
**Data Structures:**
```python
@dataclass
class TimeframePrediction:
timeframe: str
action: str # 'BUY', 'SELL', 'HOLD'
confidence: float # 0.0 to 1.0
probabilities: Dict[str, float]
timestamp: datetime
market_features: Dict[str, float]
@dataclass
class TradingAction:
symbol: str
action: str
quantity: float
confidence: float
price: float
timestamp: datetime
reasoning: Dict[str, Any]
timeframe_analysis: List[TimeframePrediction]
```
**Decision Making Process:**
1. Gather market states for all symbols and timeframes
2. Get CNN predictions for each timeframe with confidence scores
3. Combine timeframe predictions using weighted averaging
4. Consider symbol correlations (ETH-BTC correlation ~0.85)
5. Apply confidence thresholds and risk management
6. Generate coordinated trading decisions
7. Queue actions for RL evaluation
### 2. Enhanced CNN Trainer (`training/enhanced_cnn_trainer.py`)
Implements supervised learning on marked perfect moves:
**Key Features:**
- **Perfect Move Dataset**: Trains on historically optimal decisions
- **Timeframe-Specific Heads**: Separate prediction heads for each timeframe
- **Confidence Prediction**: Predicts both action and confidence simultaneously
- **Multi-Loss Training**: Combines action classification and confidence regression
**Network Architecture:**
```python
# Convolutional feature extraction
Conv1D(features=5, filters=64, kernel=3) -> BatchNorm -> ReLU -> Dropout
Conv1D(filters=128, kernel=3) -> BatchNorm -> ReLU -> Dropout
Conv1D(filters=256, kernel=3) -> BatchNorm -> ReLU -> Dropout
AdaptiveAvgPool1d(1) # Global average pooling
# Timeframe-specific heads
for each timeframe:
Linear(256 -> 128) -> ReLU -> Dropout
Linear(128 -> 64) -> ReLU -> Dropout
# Action prediction
Linear(64 -> 3) # BUY, HOLD, SELL
# Confidence prediction
Linear(64 -> 32) -> ReLU -> Linear(32 -> 1) -> Sigmoid
```
**Training Process:**
1. Collect perfect moves from orchestrator with known outcomes
2. Create dataset with features, optimal actions, and target confidence
3. Train with combined loss: `action_loss + 0.5 * confidence_loss`
4. Use early stopping and model checkpointing
5. Generate comprehensive training reports and visualizations
### 3. Enhanced RL Trainer (`training/enhanced_rl_trainer.py`)
Implements continuous learning from trading evaluations:
**Key Features:**
- **Prioritized Experience Replay**: Learns from important experiences first
- **Market Regime Adaptation**: Adjusts confidence based on market conditions
- **Multi-Symbol Agents**: Separate RL agents for each trading symbol
- **Double DQN Architecture**: Reduces overestimation bias
**Agent Architecture:**
```python
# Main Network
Linear(state_size -> 256) -> ReLU -> Dropout
Linear(256 -> 256) -> ReLU -> Dropout
Linear(256 -> 128) -> ReLU -> Dropout
# Dueling heads
value_head = Linear(128 -> 1)
advantage_head = Linear(128 -> action_space)
# Q-values = V(s) + A(s,a) - mean(A(s,a))
```
**Learning Process:**
1. Store trading experiences with TD-error priorities
2. Sample batches using prioritized replay
3. Train with Double DQN to reduce overestimation
4. Update target networks periodically
5. Adapt exploration (epsilon) based on market regime stability
### 4. Market State and Feature Engineering
**Market State Components:**
```python
@dataclass
class MarketState:
symbol: str
timestamp: datetime
prices: Dict[str, float] # {timeframe: price}
features: Dict[str, np.ndarray] # {timeframe: feature_matrix}
volatility: float
volume: float
trend_strength: float
market_regime: str # 'trending', 'ranging', 'volatile'
```
**Feature Engineering:**
- **OHLCV Data**: Open, High, Low, Close, Volume for each timeframe
- **Technical Indicators**: RSI, MACD, Bollinger Bands, etc.
- **Market Regime Detection**: Automatic classification of market conditions
- **Volatility Analysis**: Real-time volatility calculations
- **Volume Analysis**: Volume ratio compared to historical averages
## System Workflow
### 1. Initialization Phase
```python
# Load configuration
config = get_config('config.yaml')
# Initialize components
data_provider = DataProvider(config)
orchestrator = EnhancedTradingOrchestrator(data_provider)
cnn_trainer = EnhancedCNNTrainer(config, orchestrator)
rl_trainer = EnhancedRLTrainer(config, orchestrator)
# Load existing models or create new ones
models = initialize_models(load_existing=True)
register_models_with_orchestrator(models)
```
### 2. Trading Loop
```python
while running:
# 1. Gather market data for all symbols and timeframes
market_states = await get_all_market_states()
# 2. Generate CNN predictions for each timeframe
for symbol in symbols:
for timeframe in timeframes:
prediction = cnn_model.predict_timeframe(features, timeframe)
# 3. Combine timeframe predictions with weights
combined_prediction = combine_timeframe_predictions(predictions)
# 4. Consider symbol correlations
coordinated_decision = coordinate_symbols(predictions, correlations)
# 5. Apply confidence thresholds and risk management
final_decision = apply_risk_management(coordinated_decision)
# 6. Execute trades (or log decisions)
execute_trading_decision(final_decision)
# 7. Queue for RL evaluation
queue_for_rl_evaluation(final_decision, market_state)
```
### 3. Continuous Learning Loop
```python
# RL Learning (every hour)
async def rl_learning_loop():
while running:
# Evaluate past trading actions
await evaluate_trading_outcomes()
# Train RL agents on new experiences
for symbol, agent in rl_agents.items():
agent.replay() # Learn from prioritized experiences
# Adapt to market regime changes
adapt_to_market_conditions()
await asyncio.sleep(3600) # Wait 1 hour
# CNN Learning (every 6 hours)
async def cnn_learning_loop():
while running:
# Check for sufficient perfect moves
perfect_moves = get_perfect_moves_for_training()
if len(perfect_moves) >= 200:
# Train CNN on perfect moves
training_report = train_cnn_on_perfect_moves(perfect_moves)
# Update registered model
update_model_registry(trained_model)
await asyncio.sleep(6 * 3600) # Wait 6 hours
```
## Key Algorithms
### 1. Timeframe Prediction Combination
```python
def combine_timeframe_predictions(timeframe_predictions, symbol):
action_scores = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
total_weight = 0.0
timeframe_weights = {
'1m': 0.05, '5m': 0.10, '15m': 0.15,
'1h': 0.25, '4h': 0.25, '1d': 0.20
}
for pred in timeframe_predictions:
weight = timeframe_weights[pred.timeframe] * pred.confidence
action_scores[pred.action] += weight
total_weight += weight
# Normalize and select best action
best_action = max(action_scores, key=action_scores.get)
confidence = action_scores[best_action] / total_weight
return best_action, confidence
```
### 2. Perfect Move Marking
```python
def mark_perfect_move(action, initial_state, final_state, reward):
# Determine optimal action based on outcome
if reward > 0.02: # Significant positive outcome
optimal_action = action.action # Action was correct
optimal_confidence = min(0.95, abs(reward) * 10)
elif reward < -0.02: # Significant negative outcome
optimal_action = opposite_action(action.action) # Should have done opposite
optimal_confidence = min(0.95, abs(reward) * 10)
else: # Neutral outcome
optimal_action = 'HOLD' # Should have held
optimal_confidence = 0.3
# Create perfect move for CNN training
perfect_move = PerfectMove(
symbol=action.symbol,
timeframe=timeframe,
timestamp=action.timestamp,
optimal_action=optimal_action,
confidence_should_have_been=optimal_confidence,
market_state_before=initial_state,
market_state_after=final_state,
actual_outcome=reward
)
return perfect_move
```
### 3. RL Reward Calculation
```python
def calculate_reward(action, price_change, confidence):
base_reward = 0.0
# Reward based on action correctness
if action == 'BUY' and price_change > 0:
base_reward = price_change * 10 # Reward proportional to gain
elif action == 'SELL' and price_change < 0:
base_reward = abs(price_change) * 10 # Reward for avoiding loss
elif action == 'HOLD':
if abs(price_change) < 0.005: # Correct hold
base_reward = 0.01
else: # Missed opportunity
base_reward = -0.01
else:
base_reward = -abs(price_change) * 5 # Penalty for wrong actions
# Scale by confidence
confidence_multiplier = 0.5 + confidence # 0.5 to 1.5 range
return base_reward * confidence_multiplier
```
## Configuration and Deployment
### 1. Running the System
```bash
# Basic trading mode
python enhanced_trading_main.py --mode trade
# Training only mode
python enhanced_trading_main.py --mode train
# Fresh start without loading existing models
python enhanced_trading_main.py --mode trade --no-load-models
# Custom configuration
python enhanced_trading_main.py --config custom_config.yaml
```
### 2. Key Configuration Parameters
```yaml
# Enhanced Orchestrator Settings
orchestrator:
confidence_threshold: 0.6 # Higher threshold for enhanced system
decision_frequency: 30 # Faster decisions (30 seconds)
# CNN Configuration
cnn:
timeframes: ["1m", "5m", "15m", "1h", "4h", "1d"]
confidence_threshold: 0.6
model_dir: "models/enhanced_cnn"
# RL Configuration
rl:
hidden_size: 256
buffer_size: 10000
model_dir: "models/enhanced_rl"
market_regime_weights:
trending: 1.2
ranging: 0.8
volatile: 0.6
```
### 3. Memory Management
The system is designed to work within 8GB memory constraints:
- Total system limit: 8GB
- Per-model limit: 2GB
- Automatic memory cleanup every 30 minutes
- GPU memory management with dynamic allocation
### 4. Monitoring and Logging
- Comprehensive logging with component-specific levels
- TensorBoard integration for training visualization
- Performance metrics tracking
- Memory usage monitoring
- Real-time decision logging with full reasoning
## Performance Characteristics
### Expected Behavior:
1. **Decision Frequency**: 30-second intervals between decisions
2. **CNN Training**: Every 6 hours when sufficient perfect moves available
3. **RL Training**: Continuous learning every hour
4. **Memory Usage**: <8GB total system usage
5. **Confidence Thresholds**: 0.6+ for trading actions
### Key Metrics:
- **Decision Accuracy**: Tracked via RL reward system
- **Confidence Calibration**: CNN confidence vs actual outcomes
- **Symbol Correlation**: ETH-BTC coordination effectiveness
- **Training Progress**: Loss curves and validation accuracy
- **Market Adaptation**: Performance across different regimes
## Future Enhancements
1. **Additional Symbols**: Easy extension to support more trading pairs
2. **Advanced Features**: Sentiment analysis, news integration
3. **Risk Management**: Portfolio-level risk optimization
4. **Backtesting**: Historical performance evaluation
5. **Live Trading**: Real exchange integration
6. **Model Ensembles**: Multiple CNN/RL model combinations
This architecture provides a robust foundation for sophisticated algorithmic trading with continuous learning and adaptation capabilities.

View File

@ -1,6 +1,6 @@
# Trading System Configuration # Enhanced Multi-Modal Trading System Configuration
# Trading Symbols (extendable) # Trading Symbols (extendable/configurable)
symbols: symbols:
- "ETH/USDT" - "ETH/USDT"
- "BTC/USDT" - "BTC/USDT"
@ -22,22 +22,38 @@ data:
historical_limit: 1000 historical_limit: 1000
real_time_enabled: true real_time_enabled: true
websocket_reconnect: true websocket_reconnect: true
feature_engineering:
technical_indicators: true
market_regime_detection: true
volatility_analysis: true
# CNN Model Configuration # Enhanced CNN Configuration
cnn: cnn:
window_size: 20 window_size: 20
features: ["open", "high", "low", "close", "volume"] features: ["open", "high", "low", "close", "volume"]
hidden_layers: [64, 32, 16] timeframes: ["1m", "5m", "15m", "1h", "4h", "1d"]
hidden_layers: [64, 128, 256]
dropout: 0.2 dropout: 0.2
learning_rate: 0.001 learning_rate: 0.001
batch_size: 32 batch_size: 32
epochs: 100 epochs: 100
confidence_threshold: 0.6 confidence_threshold: 0.6
early_stopping_patience: 10
model_dir: "models/enhanced_cnn"
# Timeframe-specific model weights
timeframe_importance:
"1m": 0.05 # Noise filtering
"5m": 0.10 # Short-term momentum
"15m": 0.15 # Entry/exit timing
"1h": 0.25 # Medium-term trend
"4h": 0.25 # Stronger trend confirmation
"1d": 0.20 # Long-term direction
# RL Agent Configuration # Enhanced RL Agent Configuration
rl: rl:
state_size: 100 # Will be calculated dynamically state_size: 100 # Will be calculated dynamically based on features
action_space: 3 # BUY, HOLD, SELL action_space: 3 # BUY, HOLD, SELL
hidden_size: 256
epsilon: 1.0 epsilon: 1.0
epsilon_decay: 0.995 epsilon_decay: 0.995
epsilon_min: 0.01 epsilon_min: 0.01
@ -46,21 +62,78 @@ rl:
memory_size: 10000 memory_size: 10000
batch_size: 64 batch_size: 64
target_update_freq: 1000 target_update_freq: 1000
buffer_size: 10000
model_dir: "models/enhanced_rl"
# Market regime adaptation
market_regime_weights:
trending: 1.2 # Higher confidence in trending markets
ranging: 0.8 # Lower confidence in ranging markets
volatile: 0.6 # Much lower confidence in volatile markets
# Prioritized experience replay
replay_alpha: 0.6 # Priority exponent
replay_beta: 0.4 # Importance sampling exponent
# Orchestrator Settings # Enhanced Orchestrator Settings
orchestrator: orchestrator:
# Model weights for decision combination
cnn_weight: 0.7 # Weight for CNN predictions cnn_weight: 0.7 # Weight for CNN predictions
rl_weight: 0.3 # Weight for RL decisions rl_weight: 0.3 # Weight for RL decisions
confidence_threshold: 0.5 # Minimum confidence to act confidence_threshold: 0.6 # Increased for enhanced system
decision_frequency: 60 # Seconds between decisions decision_frequency: 30 # Seconds between decisions (faster)
# Multi-symbol coordination
symbol_correlation_matrix:
"ETH/USDT-BTC/USDT": 0.85 # ETH-BTC correlation
# Perfect move marking
perfect_move_threshold: 0.02 # 2% price change to mark as significant
perfect_move_buffer_size: 10000
# RL evaluation settings
evaluation_delay: 3600 # Evaluate actions after 1 hour
reward_calculation:
success_multiplier: 10 # Reward for correct predictions
failure_penalty: 5 # Penalty for wrong predictions
confidence_scaling: true # Scale rewards by confidence
# Training Configuration
training:
learning_rate: 0.001
batch_size: 32
epochs: 100
validation_split: 0.2
early_stopping_patience: 10
# CNN specific
cnn_training_interval: 21600 # Train every 6 hours
min_perfect_moves: 200 # Minimum moves before training
# RL specific
rl_training_interval: 3600 # Train every hour
min_experiences: 100 # Minimum experiences before training
training_steps_per_cycle: 10 # Training steps per cycle
# Trading Execution # Trading Execution
trading: trading:
max_position_size: 0.1 # Maximum position size (fraction of balance) max_position_size: 0.05 # Maximum position size (5% of balance)
stop_loss: 0.02 # 2% stop loss stop_loss: 0.02 # 2% stop loss
take_profit: 0.05 # 5% take profit take_profit: 0.05 # 5% take profit
trading_fee: 0.0002 # 0.02% trading fee trading_fee: 0.0002 # 0.02% trading fee
min_trade_interval: 60 # Minimum seconds between trades min_trade_interval: 30 # Minimum seconds between trades (faster)
# Risk management
max_daily_trades: 20 # Maximum trades per day
max_concurrent_positions: 2 # Max positions across symbols
position_sizing:
confidence_scaling: true # Scale position by confidence
base_size: 0.02 # 2% base position
max_size: 0.05 # 5% maximum position
# Memory Management
memory:
total_limit_gb: 8.0 # Total system memory limit
model_limit_gb: 2.0 # Per-model memory limit
cleanup_interval: 1800 # Memory cleanup every 30 minutes
# Web Dashboard # Web Dashboard
web: web:
@ -70,36 +143,54 @@ web:
update_interval: 1000 # Milliseconds update_interval: 1000 # Milliseconds
chart_history: 100 # Number of candles to show chart_history: 100 # Number of candles to show
# Enhanced dashboard features
show_timeframe_analysis: true
show_confidence_scores: true
show_perfect_moves: true
show_rl_metrics: true
# Logging # Logging
logging: logging:
level: "INFO" level: "INFO"
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
file: "logs/trading.log" file: "logs/enhanced_trading.log"
max_size: 10485760 # 10MB max_size: 10485760 # 10MB
backup_count: 5 backup_count: 5
# Component-specific logging
orchestrator_level: "INFO"
cnn_level: "INFO"
rl_level: "INFO"
training_level: "INFO"
# Model Directories
model_dir: "models"
data_dir: "data"
cache_dir: "cache"
logs_dir: "logs"
# GPU/Performance # GPU/Performance
performance: gpu:
use_gpu: true enabled: true
mixed_precision: true memory_fraction: 0.8 # Use 80% of GPU memory
num_workers: 4 allow_growth: true # Allow dynamic memory allocation
batch_size_multiplier: 1.0
# Paths # Monitoring and Alerting
paths: monitoring:
models: "models" tensorboard_enabled: true
data: "data" tensorboard_log_dir: "logs/tensorboard"
logs: "logs" metrics_interval: 300 # Log metrics every 5 minutes
cache: "cache" performance_alerts: true
plots: "plots"
# Training Configuration # Performance thresholds
training: min_confidence_threshold: 0.3
use_only_real_data: true # CRITICAL: Never use synthetic/generated data max_memory_usage: 0.9 # 90% of available memory
batch_size: 32 max_decision_latency: 10 # 10 seconds max per decision
learning_rate: 0.001
epochs: 100
validation_split: 0.2
early_stopping_patience: 10
# Directory paths # Backtesting (for future implementation)
backtesting:
start_date: "2024-01-01"
end_date: "2024-12-31"
initial_balance: 10000
commission: 0.0002
slippage: 0.0001

View File

@ -0,0 +1,698 @@
"""
Enhanced Trading Orchestrator - Advanced Multi-Modal Decision Making
This enhanced orchestrator implements:
1. Multi-timeframe CNN predictions with individual confidence scores
2. Advanced RL feedback loop for continuous learning
3. Multi-symbol (ETH, BTC) coordinated decision making
4. Perfect move marking for CNN backpropagation training
5. Market environment adaptation through RL evaluation
"""
import asyncio
import logging
import time
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass, field
from collections import deque
import torch
from .config import get_config
from .data_provider import DataProvider
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface
logger = logging.getLogger(__name__)
@dataclass
class TimeframePrediction:
"""CNN prediction for a specific timeframe with confidence"""
timeframe: str
action: str # 'BUY', 'SELL', 'HOLD'
confidence: float # 0.0 to 1.0
probabilities: Dict[str, float] # Action probabilities
timestamp: datetime
market_features: Dict[str, float] = field(default_factory=dict) # Additional context
@dataclass
class EnhancedPrediction:
"""Enhanced prediction structure with timeframe breakdown"""
symbol: str
timeframe_predictions: List[TimeframePrediction]
overall_action: str
overall_confidence: float
model_name: str
timestamp: datetime
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class TradingAction:
"""Represents a trading action with full context"""
symbol: str
action: str # 'BUY', 'SELL', 'HOLD'
quantity: float
confidence: float
price: float
timestamp: datetime
reasoning: Dict[str, Any]
timeframe_analysis: List[TimeframePrediction]
@dataclass
class MarketState:
"""Complete market state for RL evaluation"""
symbol: str
timestamp: datetime
prices: Dict[str, float] # {timeframe: current_price}
features: Dict[str, np.ndarray] # {timeframe: feature_matrix}
volatility: float
volume: float
trend_strength: float
market_regime: str # 'trending', 'ranging', 'volatile'
@dataclass
class PerfectMove:
"""Marked perfect move for CNN training"""
symbol: str
timeframe: str
timestamp: datetime
optimal_action: str
actual_outcome: float # Price change percentage
market_state_before: MarketState
market_state_after: MarketState
confidence_should_have_been: float
class EnhancedTradingOrchestrator:
"""
Enhanced orchestrator with sophisticated multi-modal decision making
"""
def __init__(self, data_provider: DataProvider = None):
"""Initialize the enhanced orchestrator"""
self.config = get_config()
self.data_provider = data_provider or DataProvider()
self.model_registry = get_model_registry()
# Multi-symbol configuration
self.symbols = self.config.symbols
self.timeframes = self.config.timeframes
# Configuration
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.6)
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
# Enhanced weighting system
self.timeframe_weights = self._initialize_timeframe_weights()
self.symbol_correlation_matrix = self._initialize_correlation_matrix()
# State tracking for each symbol
self.symbol_states = {symbol: {} for symbol in self.symbols}
self.recent_actions = {symbol: deque(maxlen=100) for symbol in self.symbols}
self.market_states = {symbol: deque(maxlen=1000) for symbol in self.symbols}
# Perfect move tracking for CNN training
self.perfect_moves = deque(maxlen=10000)
self.performance_tracker = {}
# RL feedback system
self.rl_evaluation_queue = deque(maxlen=1000)
self.environment_adaptation_rate = 0.01
# Decision callbacks
self.decision_callbacks = []
self.learning_callbacks = []
logger.info("Enhanced TradingOrchestrator initialized")
logger.info(f"Symbols: {self.symbols}")
logger.info(f"Timeframes: {self.timeframes}")
logger.info(f"Enhanced confidence threshold: {self.confidence_threshold}")
def _initialize_timeframe_weights(self) -> Dict[str, float]:
"""Initialize weights for different timeframes"""
# Higher timeframes get more weight for trend direction
# Lower timeframes get more weight for entry/exit timing
base_weights = {
'1m': 0.05, # Noise filtering
'5m': 0.10, # Short-term momentum
'15m': 0.15, # Entry/exit timing
'1h': 0.25, # Medium-term trend
'4h': 0.25, # Stronger trend confirmation
'1d': 0.20 # Long-term direction
}
# Normalize weights for configured timeframes
configured_weights = {tf: base_weights.get(tf, 0.1) for tf in self.timeframes}
total = sum(configured_weights.values())
return {tf: w/total for tf, w in configured_weights.items()}
def _initialize_correlation_matrix(self) -> Dict[Tuple[str, str], float]:
"""Initialize correlation matrix between symbols"""
correlations = {}
for i, symbol1 in enumerate(self.symbols):
for j, symbol2 in enumerate(self.symbols):
if i != j:
# ETH and BTC are typically highly correlated
if 'ETH' in symbol1 and 'BTC' in symbol2:
correlations[(symbol1, symbol2)] = 0.85
elif 'BTC' in symbol1 and 'ETH' in symbol2:
correlations[(symbol1, symbol2)] = 0.85
else:
correlations[(symbol1, symbol2)] = 0.7 # Default correlation
return correlations
async def make_coordinated_decisions(self) -> Dict[str, Optional[TradingAction]]:
"""
Make coordinated trading decisions across all symbols
"""
decisions = {}
try:
# Get market states for all symbols
market_states = await self._get_all_market_states()
# Get enhanced predictions for all symbols
symbol_predictions = {}
for symbol in self.symbols:
if symbol in market_states:
predictions = await self._get_enhanced_predictions(symbol, market_states[symbol])
symbol_predictions[symbol] = predictions
# Coordinate decisions considering symbol correlations
for symbol in self.symbols:
if symbol in symbol_predictions:
decision = await self._make_coordinated_decision(
symbol,
symbol_predictions[symbol],
symbol_predictions,
market_states[symbol]
)
decisions[symbol] = decision
# Queue for RL evaluation
if decision and decision.action != 'HOLD':
self._queue_for_rl_evaluation(decision, market_states[symbol])
except Exception as e:
logger.error(f"Error in coordinated decision making: {e}")
return decisions
async def _get_all_market_states(self) -> Dict[str, MarketState]:
"""Get current market state for all symbols"""
market_states = {}
for symbol in self.symbols:
try:
# Get current market data for all timeframes
prices = {}
features = {}
for timeframe in self.timeframes:
# Get current price
current_price = self.data_provider.get_current_price(symbol)
if current_price:
prices[timeframe] = current_price
# Get feature matrix for this timeframe
feature_matrix = self.data_provider.get_feature_matrix(
symbol=symbol,
timeframes=[timeframe],
window_size=20 # Standard window
)
if feature_matrix is not None:
features[timeframe] = feature_matrix
if prices and features:
# Calculate market metrics
volatility = self._calculate_volatility(symbol)
volume = self._get_current_volume(symbol)
trend_strength = self._calculate_trend_strength(symbol)
market_regime = self._determine_market_regime(symbol)
market_state = MarketState(
symbol=symbol,
timestamp=datetime.now(),
prices=prices,
features=features,
volatility=volatility,
volume=volume,
trend_strength=trend_strength,
market_regime=market_regime
)
market_states[symbol] = market_state
# Store for historical tracking
self.market_states[symbol].append(market_state)
except Exception as e:
logger.error(f"Error getting market state for {symbol}: {e}")
return market_states
async def _get_enhanced_predictions(self, symbol: str, market_state: MarketState) -> List[EnhancedPrediction]:
"""Get enhanced predictions with timeframe breakdown"""
predictions = []
for model_name, model in self.model_registry.models.items():
try:
if isinstance(model, CNNModelInterface):
# Get CNN predictions for each timeframe
timeframe_predictions = []
for timeframe in self.timeframes:
if timeframe in market_state.features:
feature_matrix = market_state.features[timeframe]
# Get timeframe-specific prediction
action_probs, confidence = await self._get_timeframe_prediction(
model, feature_matrix, timeframe, market_state
)
if action_probs is not None:
action_names = ['SELL', 'HOLD', 'BUY']
best_action_idx = np.argmax(action_probs)
best_action = action_names[best_action_idx]
# Create timeframe prediction
tf_prediction = TimeframePrediction(
timeframe=timeframe,
action=best_action,
confidence=float(confidence),
probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
timestamp=datetime.now(),
market_features={
'volatility': market_state.volatility,
'volume': market_state.volume,
'trend_strength': market_state.trend_strength
}
)
timeframe_predictions.append(tf_prediction)
if timeframe_predictions:
# Combine timeframe predictions into overall prediction
overall_action, overall_confidence = self._combine_timeframe_predictions(
timeframe_predictions, symbol
)
enhanced_pred = EnhancedPrediction(
symbol=symbol,
timeframe_predictions=timeframe_predictions,
overall_action=overall_action,
overall_confidence=overall_confidence,
model_name=model.name,
timestamp=datetime.now(),
metadata={
'market_regime': market_state.market_regime,
'symbol_correlation': self._get_symbol_correlation(symbol)
}
)
predictions.append(enhanced_pred)
except Exception as e:
logger.error(f"Error getting enhanced predictions from {model_name}: {e}")
return predictions
async def _get_timeframe_prediction(self, model: CNNModelInterface, feature_matrix: np.ndarray,
timeframe: str, market_state: MarketState) -> Tuple[Optional[np.ndarray], float]:
"""Get prediction for specific timeframe with enhanced context"""
try:
# Check if model supports timeframe-specific prediction
if hasattr(model, 'predict_timeframe'):
action_probs, confidence = model.predict_timeframe(feature_matrix, timeframe)
else:
action_probs, confidence = model.predict(feature_matrix)
if action_probs is not None and confidence is not None:
# Enhance confidence based on market conditions
enhanced_confidence = self._enhance_confidence_with_context(
confidence, timeframe, market_state
)
return action_probs, enhanced_confidence
except Exception as e:
logger.error(f"Error getting timeframe prediction for {timeframe}: {e}")
return None, 0.0
def _enhance_confidence_with_context(self, base_confidence: float, timeframe: str,
market_state: MarketState) -> float:
"""Enhance confidence score based on market context"""
enhanced = base_confidence
# Adjust based on market regime
if market_state.market_regime == 'trending':
enhanced *= 1.1 # More confident in trending markets
elif market_state.market_regime == 'volatile':
enhanced *= 0.8 # Less confident in volatile markets
# Adjust based on timeframe reliability
timeframe_reliability = {
'1m': 0.7, '5m': 0.8, '15m': 0.9, '1h': 1.0, '4h': 1.1, '1d': 1.2
}
enhanced *= timeframe_reliability.get(timeframe, 1.0)
# Adjust based on volume
if market_state.volume > 1.5: # High volume
enhanced *= 1.05
elif market_state.volume < 0.5: # Low volume
enhanced *= 0.9
return min(enhanced, 1.0) # Cap at 1.0
def _combine_timeframe_predictions(self, timeframe_predictions: List[TimeframePrediction],
symbol: str) -> Tuple[str, float]:
"""Combine predictions from multiple timeframes"""
action_scores = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
total_weight = 0.0
for tf_pred in timeframe_predictions:
# Get timeframe weight
tf_weight = self.timeframe_weights.get(tf_pred.timeframe, 0.1)
# Weight by confidence and timeframe importance
weighted_confidence = tf_pred.confidence * tf_weight
# Add to action scores
action_scores[tf_pred.action] += weighted_confidence
total_weight += weighted_confidence
# Normalize scores
if total_weight > 0:
for action in action_scores:
action_scores[action] /= total_weight
# Get best action and confidence
best_action = max(action_scores, key=action_scores.get)
best_confidence = action_scores[best_action]
return best_action, best_confidence
async def _make_coordinated_decision(self, symbol: str, predictions: List[EnhancedPrediction],
all_predictions: Dict[str, List[EnhancedPrediction]],
market_state: MarketState) -> Optional[TradingAction]:
"""Make decision considering symbol correlations"""
if not predictions:
return None
try:
# Get primary prediction (highest confidence)
primary_pred = max(predictions, key=lambda p: p.overall_confidence)
# Consider correlated symbols
correlated_sentiment = self._get_correlated_sentiment(symbol, all_predictions)
# Adjust decision based on correlation
final_action = primary_pred.overall_action
final_confidence = primary_pred.overall_confidence
# If correlated symbols strongly disagree, reduce confidence
if correlated_sentiment['agreement'] < 0.5:
final_confidence *= 0.8
logger.info(f"Reduced confidence for {symbol} due to correlation disagreement")
# Apply confidence threshold
if final_confidence < self.confidence_threshold:
final_action = 'HOLD'
logger.info(f"Action for {symbol} changed to HOLD due to low confidence: {final_confidence:.3f}")
# Create trading action
if final_action != 'HOLD':
current_price = market_state.prices.get(self.timeframes[0], 0)
quantity = self._calculate_position_size(symbol, final_action, final_confidence)
action = TradingAction(
symbol=symbol,
action=final_action,
quantity=quantity,
confidence=final_confidence,
price=current_price,
timestamp=datetime.now(),
reasoning={
'primary_model': primary_pred.model_name,
'timeframe_breakdown': [(tf.timeframe, tf.action, tf.confidence)
for tf in primary_pred.timeframe_predictions],
'correlated_sentiment': correlated_sentiment,
'market_regime': market_state.market_regime
},
timeframe_analysis=primary_pred.timeframe_predictions
)
# Store recent action
self.recent_actions[symbol].append(action)
return action
except Exception as e:
logger.error(f"Error making coordinated decision for {symbol}: {e}")
return None
def _get_correlated_sentiment(self, symbol: str,
all_predictions: Dict[str, List[EnhancedPrediction]]) -> Dict[str, Any]:
"""Get sentiment from correlated symbols"""
correlated_actions = []
correlated_confidences = []
for other_symbol, predictions in all_predictions.items():
if other_symbol != symbol and predictions:
correlation = self.symbol_correlation_matrix.get((symbol, other_symbol), 0.0)
if correlation > 0.5: # Only consider significantly correlated symbols
best_pred = max(predictions, key=lambda p: p.overall_confidence)
correlated_actions.append(best_pred.overall_action)
correlated_confidences.append(best_pred.overall_confidence * correlation)
if not correlated_actions:
return {'agreement': 1.0, 'sentiment': 'NEUTRAL'}
# Calculate agreement
primary_pred = all_predictions[symbol][0] if all_predictions.get(symbol) else None
if primary_pred:
agreement_count = sum(1 for action in correlated_actions
if action == primary_pred.overall_action)
agreement = agreement_count / len(correlated_actions)
else:
agreement = 0.5
# Calculate overall sentiment
action_weights = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
for action, confidence in zip(correlated_actions, correlated_confidences):
action_weights[action] += confidence
dominant_sentiment = max(action_weights, key=action_weights.get)
return {
'agreement': agreement,
'sentiment': dominant_sentiment,
'correlated_symbols': len(correlated_actions)
}
def _queue_for_rl_evaluation(self, action: TradingAction, market_state: MarketState):
"""Queue trading action for RL evaluation"""
evaluation_item = {
'action': action,
'market_state_before': market_state,
'timestamp': datetime.now(),
'evaluation_pending': True
}
self.rl_evaluation_queue.append(evaluation_item)
async def evaluate_actions_with_rl(self):
"""Evaluate recent actions using RL agents for continuous learning"""
if not self.rl_evaluation_queue:
return
current_time = datetime.now()
# Process actions that are ready for evaluation (e.g., 1 hour old)
for item in list(self.rl_evaluation_queue):
if item['evaluation_pending']:
time_since_action = (current_time - item['timestamp']).total_seconds()
# Evaluate after sufficient time has passed
if time_since_action >= 3600: # 1 hour
await self._evaluate_single_action(item)
item['evaluation_pending'] = False
async def _evaluate_single_action(self, evaluation_item: Dict[str, Any]):
"""Evaluate a single action using RL"""
try:
action = evaluation_item['action']
initial_state = evaluation_item['market_state_before']
# Get current market state for comparison
current_market_states = await self._get_all_market_states()
current_state = current_market_states.get(action.symbol)
if current_state:
# Calculate reward based on price movement
initial_price = initial_state.prices.get(self.timeframes[0], 0)
current_price = current_state.prices.get(self.timeframes[0], 0)
if initial_price > 0:
price_change = (current_price - initial_price) / initial_price
# Calculate reward based on action and price movement
reward = self._calculate_reward(action.action, price_change, action.confidence)
# Update RL agents
await self._update_rl_agents(action, initial_state, current_state, reward)
# Check if this was a perfect move for CNN training
if abs(reward) > 0.02: # Significant outcome
self._mark_perfect_move(action, initial_state, current_state, reward)
except Exception as e:
logger.error(f"Error evaluating action: {e}")
def _calculate_reward(self, action: str, price_change: float, confidence: float) -> float:
"""Calculate reward for RL training"""
base_reward = 0.0
if action == 'BUY' and price_change > 0:
base_reward = price_change * 10 # Reward proportional to gain
elif action == 'SELL' and price_change < 0:
base_reward = abs(price_change) * 10 # Reward for avoiding loss
elif action == 'HOLD':
base_reward = 0.01 if abs(price_change) < 0.005 else -0.01 # Small reward for correct holds
else:
base_reward = -abs(price_change) * 5 # Penalty for wrong actions
# Adjust reward based on confidence
confidence_multiplier = 0.5 + confidence # 0.5 to 1.5 range
return base_reward * confidence_multiplier
async def _update_rl_agents(self, action: TradingAction, initial_state: MarketState,
current_state: MarketState, reward: float):
"""Update RL agents with action evaluation"""
for model_name, model in self.model_registry.models.items():
if isinstance(model, RLAgentInterface):
try:
# Convert market states to RL state format
initial_rl_state = self._market_state_to_rl_state(initial_state)
current_rl_state = self._market_state_to_rl_state(current_state)
# Convert action to RL action index
action_idx = {'SELL': 0, 'HOLD': 1, 'BUY': 2}.get(action.action, 1)
# Store experience
model.remember(
state=initial_rl_state,
action=action_idx,
reward=reward,
next_state=current_rl_state,
done=False
)
# Trigger replay learning
loss = model.replay()
if loss is not None:
logger.info(f"RL agent {model_name} updated with loss: {loss:.4f}")
except Exception as e:
logger.error(f"Error updating RL agent {model_name}: {e}")
def _mark_perfect_move(self, action: TradingAction, initial_state: MarketState,
final_state: MarketState, reward: float):
"""Mark a perfect move for CNN training"""
try:
# Determine what the optimal action should have been
optimal_action = action.action if reward > 0 else ('HOLD' if action.action == 'HOLD' else
('SELL' if action.action == 'BUY' else 'BUY'))
# Calculate what confidence should have been
optimal_confidence = min(0.95, abs(reward) * 10) # Higher reward = higher confidence should have been
for tf_pred in action.timeframe_analysis:
perfect_move = PerfectMove(
symbol=action.symbol,
timeframe=tf_pred.timeframe,
timestamp=action.timestamp,
optimal_action=optimal_action,
actual_outcome=reward,
market_state_before=initial_state,
market_state_after=final_state,
confidence_should_have_been=optimal_confidence
)
self.perfect_moves.append(perfect_move)
logger.info(f"Marked perfect move for {action.symbol}: {optimal_action} with confidence {optimal_confidence:.3f}")
except Exception as e:
logger.error(f"Error marking perfect move: {e}")
def get_perfect_moves_for_training(self, symbol: str = None, timeframe: str = None,
limit: int = 1000) -> List[PerfectMove]:
"""Get perfect moves for CNN training"""
moves = list(self.perfect_moves)
if symbol:
moves = [m for m in moves if m.symbol == symbol]
if timeframe:
moves = [m for m in moves if m.timeframe == timeframe]
return moves[-limit:] if limit else moves
# Helper methods for market analysis
def _calculate_volatility(self, symbol: str) -> float:
"""Calculate current volatility for symbol"""
# Placeholder - implement based on your data provider
return 0.02 # 2% default volatility
def _get_current_volume(self, symbol: str) -> float:
"""Get current volume ratio compared to average"""
# Placeholder - implement based on your data provider
return 1.0 # Normal volume
def _calculate_trend_strength(self, symbol: str) -> float:
"""Calculate trend strength (0 = no trend, 1 = strong trend)"""
# Placeholder - implement based on your data provider
return 0.5 # Moderate trend
def _determine_market_regime(self, symbol: str) -> str:
"""Determine current market regime"""
# Placeholder - implement based on your analysis
return 'trending' # Default to trending
def _get_symbol_correlation(self, symbol: str) -> Dict[str, float]:
"""Get correlations with other symbols"""
correlations = {}
for other_symbol in self.symbols:
if other_symbol != symbol:
correlations[other_symbol] = self.symbol_correlation_matrix.get((symbol, other_symbol), 0.0)
return correlations
def _calculate_position_size(self, symbol: str, action: str, confidence: float) -> float:
"""Calculate position size based on confidence and risk management"""
base_size = 0.02 # 2% of portfolio
confidence_multiplier = confidence # Scale by confidence
max_size = 0.05 # 5% maximum
return min(base_size * confidence_multiplier, max_size)
def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
"""Convert market state to RL state vector"""
# Combine features from all timeframes into a single state vector
state_components = []
# Add price features
state_components.extend([
market_state.volatility,
market_state.volume,
market_state.trend_strength
])
# Add flattened features from each timeframe
for timeframe in sorted(market_state.features.keys()):
features = market_state.features[timeframe]
if features is not None:
# Take the last row (most recent) and flatten
latest_features = features[-1] if len(features.shape) > 1 else features
state_components.extend(latest_features.flatten())
return np.array(state_components, dtype=np.float32)

370
enhanced_trading_main.py Normal file
View File

@ -0,0 +1,370 @@
"""
Enhanced Multi-Modal Trading System - Main Application
This is the main launcher for the sophisticated trading system featuring:
1. Enhanced orchestrator coordinating CNN and RL modules
2. Multi-timeframe, multi-symbol (ETH, BTC) trading decisions
3. Perfect move marking for CNN training with known outcomes
4. Continuous RL learning from trading action evaluations
5. Market environment adaptation and coordinated decision making
"""
import asyncio
import logging
import signal
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional
import argparse
# Core components
from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from models import get_model_registry
# Training components
from training.enhanced_cnn_trainer import EnhancedCNNTrainer, EnhancedCNNModel
from training.enhanced_rl_trainer import EnhancedRLTrainer, EnhancedDQNAgent
# Utilities
import torch
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler('logs/enhanced_trading.log')
]
)
logger = logging.getLogger(__name__)
class EnhancedTradingSystem:
"""Main enhanced trading system coordinator"""
def __init__(self, config_path: str = None):
"""Initialize the enhanced trading system"""
self.config = get_config(config_path)
self.running = False
# Core components
self.data_provider = DataProvider(self.config)
self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
self.model_registry = get_model_registry()
# Training components
self.cnn_trainer = EnhancedCNNTrainer(self.config, self.orchestrator)
self.rl_trainer = EnhancedRLTrainer(self.config, self.orchestrator)
# Models
self.cnn_models = {}
self.rl_agents = {}
# Performance tracking
self.performance_metrics = {
'decisions_made': 0,
'perfect_moves_marked': 0,
'rl_experiences_added': 0,
'training_sessions': 0
}
logger.info("Enhanced Trading System initialized")
logger.info(f"Symbols: {self.config.symbols}")
logger.info(f"Timeframes: {self.config.timeframes}")
async def initialize_models(self, load_existing: bool = True):
"""Initialize and register all models"""
logger.info("Initializing models...")
# Initialize CNN models
if load_existing:
# Try to load existing CNN model
if self.cnn_trainer.load_model('best_model.pt'):
logger.info("Loaded existing CNN model")
self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()
else:
logger.info("No existing CNN model found, using fresh model")
self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()
else:
logger.info("Creating fresh CNN model")
self.cnn_models['enhanced_cnn'] = self.cnn_trainer.get_model()
# Initialize RL agents
if load_existing:
# Try to load existing RL agents
if self.rl_trainer.load_models():
logger.info("Loaded existing RL models")
else:
logger.info("No existing RL models found, using fresh agents")
self.rl_agents = self.rl_trainer.get_agents()
# Register models with the orchestrator
for model_name, model in self.cnn_models.items():
if self.model_registry.register_model(model):
logger.info(f"Registered CNN model: {model_name}")
for symbol, agent in self.rl_agents.items():
if self.model_registry.register_model(agent):
logger.info(f"Registered RL agent for {symbol}")
# Display memory usage
memory_stats = self.model_registry.get_memory_stats()
logger.info(f"Total memory usage: {memory_stats['total_used_mb']:.1f}MB / "
f"{memory_stats['total_limit_mb']:.1f}MB "
f"({memory_stats['utilization_percent']:.1f}%)")
async def start_trading_loop(self):
"""Start the main trading decision loop"""
logger.info("Starting enhanced trading loop...")
self.running = True
decision_count = 0
while self.running:
try:
# Make coordinated decisions for all symbols
decisions = await self.orchestrator.make_coordinated_decisions()
# Process decisions
for symbol, decision in decisions.items():
if decision:
decision_count += 1
self.performance_metrics['decisions_made'] += 1
logger.info(f"Trading Decision #{decision_count}")
logger.info(f"Symbol: {symbol}")
logger.info(f"Action: {decision.action}")
logger.info(f"Confidence: {decision.confidence:.3f}")
logger.info(f"Price: ${decision.price:.2f}")
logger.info(f"Quantity: {decision.quantity:.6f}")
# Log timeframe analysis
for tf_pred in decision.timeframe_analysis:
logger.info(f" {tf_pred.timeframe}: {tf_pred.action} "
f"(conf: {tf_pred.confidence:.3f})")
# Here you would integrate with actual trading execution
# For now, we just log the decision
# Evaluate past actions with RL
await self.orchestrator.evaluate_actions_with_rl()
# Check for perfect moves to mark
perfect_moves = self.orchestrator.get_perfect_moves_for_training(limit=10)
if perfect_moves:
self.performance_metrics['perfect_moves_marked'] = len(perfect_moves)
# Log performance metrics every 10 decisions
if decision_count % 10 == 0 and decision_count > 0:
await self._log_performance_metrics()
# Wait before next decision cycle
await asyncio.sleep(self.orchestrator.decision_frequency)
except Exception as e:
logger.error(f"Error in trading loop: {e}")
await asyncio.sleep(30) # Wait 30 seconds on error
async def start_training_loops(self):
"""Start continuous training loops"""
logger.info("Starting continuous training loops...")
# Start RL continuous learning
rl_task = asyncio.create_task(self.rl_trainer.continuous_learning_loop())
# Start periodic CNN training
cnn_task = asyncio.create_task(self._periodic_cnn_training())
return rl_task, cnn_task
async def _periodic_cnn_training(self):
"""Periodic CNN training on accumulated perfect moves"""
while self.running:
try:
# Wait for 6 hours between training sessions
await asyncio.sleep(6 * 3600)
# Check if we have enough perfect moves for training
perfect_moves = []
for symbol in self.config.symbols:
symbol_moves = self.orchestrator.get_perfect_moves_for_training(symbol=symbol)
perfect_moves.extend(symbol_moves)
if len(perfect_moves) >= 200: # Minimum 200 perfect moves
logger.info(f"Starting CNN training on {len(perfect_moves)} perfect moves")
# Train the CNN model
training_report = self.cnn_trainer.train_on_perfect_moves(min_samples=200)
if training_report.get('training_completed'):
self.performance_metrics['training_sessions'] += 1
logger.info("CNN training completed successfully")
logger.info(f"Final validation accuracy: "
f"{training_report['final_metrics']['val_accuracy']:.4f}")
# Update the registered model
updated_model = self.cnn_trainer.get_model()
self.model_registry.unregister_model('enhanced_cnn')
self.model_registry.register_model(updated_model)
else:
logger.warning(f"CNN training failed: {training_report}")
else:
logger.info(f"Not enough perfect moves for training: {len(perfect_moves)} < 200")
except Exception as e:
logger.error(f"Error in periodic CNN training: {e}")
async def _log_performance_metrics(self):
"""Log system performance metrics"""
logger.info("=== SYSTEM PERFORMANCE METRICS ===")
logger.info(f"Decisions made: {self.performance_metrics['decisions_made']}")
logger.info(f"Perfect moves marked: {self.performance_metrics['perfect_moves_marked']}")
logger.info(f"Training sessions: {self.performance_metrics['training_sessions']}")
# Model registry stats
memory_stats = self.model_registry.get_memory_stats()
logger.info(f"Memory usage: {memory_stats['total_used_mb']:.1f}MB / "
f"{memory_stats['total_limit_mb']:.1f}MB")
# RL performance
rl_report = self.rl_trainer.get_performance_report()
for symbol, agent_data in rl_report['agents'].items():
logger.info(f"{symbol} RL: Epsilon={agent_data['epsilon']:.3f}, "
f"Experiences={agent_data['experiences_stored']}, "
f"Avg Reward={agent_data['avg_recent_reward']:.4f}")
# CNN model info
for model_name, model in self.cnn_models.items():
logger.info(f"{model_name}: Memory={model.get_memory_usage()}MB, "
f"Device={model.device}")
async def shutdown(self):
"""Graceful shutdown of the system"""
logger.info("Shutting down Enhanced Trading System...")
self.running = False
# Save models
logger.info("Saving models...")
self.cnn_trainer._save_model('shutdown_model.pt')
self.rl_trainer._save_all_models()
# Clean up memory
self.model_registry.cleanup_all_models()
# Generate final reports
logger.info("Generating final reports...")
# CNN training plots
if self.cnn_trainer.training_history['train_loss']:
self.cnn_trainer._plot_training_history()
# RL training plots
self.rl_trainer.plot_training_metrics()
logger.info("Enhanced Trading System shutdown complete")
def setup_signal_handlers(trading_system: EnhancedTradingSystem):
"""Setup signal handlers for graceful shutdown"""
def signal_handler(signum, frame):
logger.info(f"Received signal {signum}, initiating shutdown...")
asyncio.create_task(trading_system.shutdown())
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
async def main():
"""Main application entry point"""
parser = argparse.ArgumentParser(description='Enhanced Multi-Modal Trading System')
parser.add_argument('--config', type=str, help='Configuration file path')
parser.add_argument('--mode', type=str, choices=['trade', 'train', 'backtest'],
default='trade', help='Operation mode')
parser.add_argument('--load-models', action='store_true', default=True,
help='Load existing models')
parser.add_argument('--no-load-models', action='store_false', dest='load_models',
help="Don't load existing models")
args = parser.parse_args()
# Create logs directory
Path('logs').mkdir(exist_ok=True)
logger.info("=== ENHANCED MULTI-MODAL TRADING SYSTEM ===")
logger.info(f"Mode: {args.mode}")
logger.info(f"Load existing models: {args.load_models}")
logger.info(f"PyTorch version: {torch.__version__}")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
# Initialize trading system
trading_system = EnhancedTradingSystem(args.config)
# Setup signal handlers
setup_signal_handlers(trading_system)
try:
# Initialize models
await trading_system.initialize_models(load_existing=args.load_models)
if args.mode == 'trade':
# Start training loops
rl_task, cnn_task = await trading_system.start_training_loops()
# Start main trading loop
trading_task = asyncio.create_task(trading_system.start_trading_loop())
# Wait for any task to complete (or error)
done, pending = await asyncio.wait(
[trading_task, rl_task, cnn_task],
return_when=asyncio.FIRST_COMPLETED
)
# Cancel remaining tasks
for task in pending:
task.cancel()
elif args.mode == 'train':
# Training-only mode
logger.info("Running in training-only mode...")
# Train CNN if we have perfect moves
perfect_moves = []
for symbol in trading_system.config.symbols:
symbol_moves = trading_system.orchestrator.get_perfect_moves_for_training(symbol=symbol)
perfect_moves.extend(symbol_moves)
if len(perfect_moves) >= 100:
logger.info(f"Training CNN on {len(perfect_moves)} perfect moves")
training_report = trading_system.cnn_trainer.train_on_perfect_moves(min_samples=100)
logger.info(f"CNN training report: {training_report}")
else:
logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)}")
# Train RL agents if they have experiences
await trading_system.rl_trainer._train_all_agents()
elif args.mode == 'backtest':
# Backtesting mode
logger.info("Backtesting mode not implemented yet")
return
except KeyboardInterrupt:
logger.info("Received keyboard interrupt")
except Exception as e:
logger.error(f"Unexpected error: {e}", exc_info=True)
finally:
await trading_system.shutdown()
if __name__ == "__main__":
# Run the main application
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Application terminated by user")
except Exception as e:
logger.error(f"Fatal error: {e}", exc_info=True)
sys.exit(1)

View File

@ -6,3 +6,10 @@ numpy>=1.24.0
python-dotenv>=1.0.0 python-dotenv>=1.0.0
psutil>=5.9.0 psutil>=5.9.0
tensorboard>=2.15.0 tensorboard>=2.15.0
torch>=2.0.0
torchvision>=0.15.0
torchaudio>=2.0.0
scikit-learn>=1.3.0
matplotlib>=3.7.0
seaborn>=0.12.0
asyncio-compat>=0.1.2

168
run_enhanced_dashboard.py Normal file
View File

@ -0,0 +1,168 @@
#!/usr/bin/env python3
"""
Run Enhanced Trading Dashboard
This script starts the web dashboard with the enhanced trading system
for real-time monitoring and visualization.
"""
import logging
import asyncio
from threading import Thread
import time
from core.config import get_config, setup_logging
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from web.dashboard import TradingDashboard
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EnhancedDashboardRunner:
"""Enhanced dashboard runner with mock trading simulation"""
def __init__(self):
"""Initialize the enhanced dashboard"""
self.config = get_config()
self.data_provider = DataProvider(self.config)
self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)
# Create dashboard with enhanced orchestrator
self.dashboard = TradingDashboard(
data_provider=self.data_provider,
orchestrator=self.orchestrator
)
# Simulation state
self.running = False
self.simulation_thread = None
logger.info("Enhanced dashboard runner initialized")
def start_simulation(self):
"""Start background simulation for demonstration"""
self.running = True
self.simulation_thread = Thread(target=self._simulation_loop, daemon=True)
self.simulation_thread.start()
logger.info("Started enhanced trading simulation")
def _simulation_loop(self):
"""Background simulation loop"""
import random
from datetime import datetime
from core.enhanced_orchestrator import TradingAction, TimeframePrediction
action_count = 0
while self.running:
try:
# Simulate trading decisions for demonstration
for symbol in self.config.symbols:
# Create mock timeframe predictions
timeframe_predictions = []
for timeframe in ['1h', '4h', '1d']:
# Random but realistic predictions
action_probs = [
random.uniform(0.1, 0.4), # SELL
random.uniform(0.3, 0.6), # HOLD
random.uniform(0.1, 0.4) # BUY
]
# Normalize probabilities
total = sum(action_probs)
action_probs = [p/total for p in action_probs]
best_action_idx = action_probs.index(max(action_probs))
actions = ['SELL', 'HOLD', 'BUY']
best_action = actions[best_action_idx]
tf_pred = TimeframePrediction(
timeframe=timeframe,
action=best_action,
confidence=random.uniform(0.5, 0.9),
probabilities={
'SELL': action_probs[0],
'HOLD': action_probs[1],
'BUY': action_probs[2]
},
timestamp=datetime.now(),
market_features={
'volatility': random.uniform(0.01, 0.05),
'volume': random.uniform(1000, 10000),
'trend_strength': random.uniform(0.3, 0.8)
}
)
timeframe_predictions.append(tf_pred)
# Create mock trading action
if random.random() > 0.7: # 30% chance of action
action_count += 1
mock_action = TradingAction(
symbol=symbol,
action=random.choice(['BUY', 'SELL']),
quantity=random.uniform(0.01, 0.1),
confidence=random.uniform(0.6, 0.9),
price=random.uniform(2000, 4000) if 'ETH' in symbol else random.uniform(40000, 70000),
timestamp=datetime.now(),
reasoning={
'model': 'Enhanced Multi-Modal',
'timeframe_consensus': 'Strong',
'market_regime': random.choice(['trending', 'ranging', 'volatile']),
'action_count': action_count
},
timeframe_analysis=timeframe_predictions
)
# Add to dashboard
self.dashboard.add_trading_decision(mock_action)
logger.info(f"Simulated {mock_action.action} for {symbol} "
f"(confidence: {mock_action.confidence:.2f})")
# Sleep for next iteration
time.sleep(10) # Update every 10 seconds
except Exception as e:
logger.error(f"Error in simulation loop: {e}")
time.sleep(5)
def run_dashboard(self, host='127.0.0.1', port=8050):
"""Run the enhanced dashboard"""
logger.info(f"Starting enhanced trading dashboard at http://{host}:{port}")
logger.info("Features:")
logger.info("- Multi-modal CNN + RL predictions")
logger.info("- Multi-timeframe analysis")
logger.info("- Real-time market regime detection")
logger.info("- Perfect move tracking for CNN training")
logger.info("- RL feedback loop evaluation")
# Start simulation
self.start_simulation()
# Run dashboard
try:
self.dashboard.run(host=host, port=port, debug=False)
except KeyboardInterrupt:
logger.info("Dashboard stopped by user")
finally:
self.running = False
if self.simulation_thread:
self.simulation_thread.join(timeout=2)
def main():
"""Main function"""
try:
logger.info("=== ENHANCED TRADING DASHBOARD ===")
# Create and run dashboard
runner = EnhancedDashboardRunner()
runner.run_dashboard()
except Exception as e:
logger.error(f"Fatal error: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

60
test_enhanced_system.py Normal file
View File

@ -0,0 +1,60 @@
#!/usr/bin/env python3
"""
Simple test script for the enhanced trading system
Tests basic functionality without complex training loops
"""
import logging
import asyncio
from core.config import get_config, setup_logging
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def test_enhanced_system():
"""Test the enhanced trading system components"""
try:
logger.info("=== TESTING ENHANCED TRADING SYSTEM ===")
# Load configuration
config = get_config()
logger.info(f"Loaded config with symbols: {config.symbols}")
logger.info(f"Timeframes: {config.timeframes}")
# Initialize data provider
data_provider = DataProvider(config)
logger.info("Data provider initialized")
# Initialize enhanced orchestrator orchestrator = EnhancedTradingOrchestrator(data_provider) logger.info("Enhanced orchestrator initialized")
# Test basic functionality
logger.info("Testing orchestrator functionality...")
# Test market state creation
for symbol in config.symbols[:1]: # Test with first symbol only
logger.info(f"Testing with symbol: {symbol}")
# Test basic orchestrator methods logger.info("Testing timeframe weights...") weights = orchestrator._initialize_timeframe_weights() logger.info(f"Timeframe weights: {weights}") logger.info("Testing correlation matrix...") correlations = orchestrator._initialize_correlation_matrix() logger.info(f"Symbol correlations: {correlations}")
# Test basic functionality logger.info("Basic orchestrator functionality tested successfully")
break # Test with one symbol only
logger.info("=== ENHANCED SYSTEM TEST COMPLETED SUCCESSFULLY ===")
return True
except Exception as e:
logger.error(f"Test failed: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = asyncio.run(test_enhanced_system())
if success:
print("\n✅ Enhanced system test PASSED")
else:
print("\n❌ Enhanced system test FAILED")

View File

@ -0,0 +1,566 @@
"""
Enhanced CNN Trainer with Perfect Move Learning
This trainer implements:
1. Training on marked perfect moves with known outcomes
2. Multi-timeframe CNN model training with confidence scoring
3. Backpropagation on optimal moves when future outcomes are known
4. Progressive learning from real trading experience
5. Symbol-specific and timeframe-specific model fine-tuning
"""
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import PerfectMove, EnhancedTradingOrchestrator
from models import CNNModelInterface
import models
logger = logging.getLogger(__name__)
class PerfectMoveDataset(Dataset):
"""Dataset for training on perfect moves with known outcomes"""
def __init__(self, perfect_moves: List[PerfectMove], data_provider: DataProvider):
"""
Initialize dataset from perfect moves
Args:
perfect_moves: List of perfect moves with known outcomes
data_provider: Data provider to fetch additional context
"""
self.perfect_moves = perfect_moves
self.data_provider = data_provider
self.samples = []
self._prepare_samples()
def _prepare_samples(self):
"""Prepare training samples from perfect moves"""
logger.info(f"Preparing {len(self.perfect_moves)} perfect move samples")
for move in self.perfect_moves:
try:
# Get feature matrix at the time of the decision
feature_matrix = self.data_provider.get_feature_matrix(
symbol=move.symbol,
timeframes=[move.timeframe],
window_size=20,
end_time=move.timestamp
)
if feature_matrix is not None:
# Convert optimal action to label
action_to_label = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
label = action_to_label.get(move.optimal_action, 1)
# Create confidence target (what confidence should have been)
confidence_target = move.confidence_should_have_been
sample = {
'features': feature_matrix,
'action_label': label,
'confidence_target': confidence_target,
'symbol': move.symbol,
'timeframe': move.timeframe,
'outcome': move.actual_outcome,
'timestamp': move.timestamp
}
self.samples.append(sample)
except Exception as e:
logger.warning(f"Error preparing sample for perfect move: {e}")
logger.info(f"Prepared {len(self.samples)} valid training samples")
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
sample = self.samples[idx]
# Convert to tensors
features = torch.FloatTensor(sample['features'])
action_label = torch.LongTensor([sample['action_label']])
confidence_target = torch.FloatTensor([sample['confidence_target']])
return {
'features': features,
'action_label': action_label,
'confidence_target': confidence_target,
'metadata': {
'symbol': sample['symbol'],
'timeframe': sample['timeframe'],
'outcome': sample['outcome'],
'timestamp': sample['timestamp']
}
}
class EnhancedCNNModel(nn.Module, CNNModelInterface):
"""Enhanced CNN model with timeframe-specific predictions and confidence scoring"""
def __init__(self, config: Dict[str, Any]):
nn.Module.__init__(self)
CNNModelInterface.__init__(self, config)
self.timeframes = config.get('timeframes', ['1h', '4h', '1d'])
self.n_features = len(config.get('features', ['open', 'high', 'low', 'close', 'volume']))
self.window_size = config.get('window_size', 20)
# Build the neural network
self._build_network()
# Initialize device
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.to(self.device)
# Training components
self.optimizer = optim.Adam(self.parameters(), lr=config.get('learning_rate', 0.001))
self.action_criterion = nn.CrossEntropyLoss()
self.confidence_criterion = nn.MSELoss()
logger.info(f"Enhanced CNN model initialized for {len(self.timeframes)} timeframes")
def _build_network(self):
"""Build the CNN architecture"""
# Convolutional feature extraction
self.conv_layers = nn.Sequential(
# First conv block
nn.Conv1d(self.n_features, 64, kernel_size=3, padding=1),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.Dropout(0.2),
# Second conv block
nn.Conv1d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Dropout(0.2),
# Third conv block
nn.Conv1d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Dropout(0.2),
# Global average pooling
nn.AdaptiveAvgPool1d(1)
)
# Timeframe-specific heads
self.timeframe_heads = nn.ModuleDict()
for timeframe in self.timeframes:
self.timeframe_heads[timeframe] = nn.Sequential(
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(128, 64),
nn.ReLU(),
nn.Dropout(0.3)
)
# Action prediction heads (one per timeframe)
self.action_heads = nn.ModuleDict()
for timeframe in self.timeframes:
self.action_heads[timeframe] = nn.Linear(64, 3) # BUY, HOLD, SELL
# Confidence prediction heads (one per timeframe)
self.confidence_heads = nn.ModuleDict()
for timeframe in self.timeframes:
self.confidence_heads[timeframe] = nn.Sequential(
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 1),
nn.Sigmoid() # Output between 0 and 1
)
def forward(self, x, timeframe: str = None):
"""
Forward pass through the network
Args:
x: Input tensor [batch_size, window_size, features]
timeframe: Specific timeframe to predict for
Returns:
action_probs: Action probabilities
confidence: Confidence score
"""
# Reshape for conv1d: [batch, features, sequence]
x = x.transpose(1, 2)
# Extract features
features = self.conv_layers(x) # [batch, 256, 1]
features = features.squeeze(-1) # [batch, 256]
if timeframe and timeframe in self.timeframe_heads:
# Timeframe-specific prediction
tf_features = self.timeframe_heads[timeframe](features)
action_logits = self.action_heads[timeframe](tf_features)
confidence = self.confidence_heads[timeframe](tf_features)
action_probs = torch.softmax(action_logits, dim=1)
return action_probs, confidence.squeeze(-1)
else:
# Multi-timeframe prediction (average across timeframes)
all_action_probs = []
all_confidences = []
for tf in self.timeframes:
tf_features = self.timeframe_heads[tf](features)
action_logits = self.action_heads[tf](tf_features)
confidence = self.confidence_heads[tf](tf_features)
action_probs = torch.softmax(action_logits, dim=1)
all_action_probs.append(action_probs)
all_confidences.append(confidence.squeeze(-1))
# Average predictions across timeframes
avg_action_probs = torch.stack(all_action_probs).mean(dim=0)
avg_confidence = torch.stack(all_confidences).mean(dim=0)
return avg_action_probs, avg_confidence
def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
"""Predict action probabilities and confidence"""
self.eval()
with torch.no_grad():
x = torch.FloatTensor(features).to(self.device)
if len(x.shape) == 2:
x = x.unsqueeze(0) # Add batch dimension
action_probs, confidence = self.forward(x)
return action_probs[0].cpu().numpy(), confidence[0].cpu().item()
def predict_timeframe(self, features: np.ndarray, timeframe: str) -> Tuple[np.ndarray, float]:
"""Predict for specific timeframe"""
self.eval()
with torch.no_grad():
x = torch.FloatTensor(features).to(self.device)
if len(x.shape) == 2:
x = x.unsqueeze(0) # Add batch dimension
action_probs, confidence = self.forward(x, timeframe)
return action_probs[0].cpu().numpy(), confidence[0].cpu().item()
def get_memory_usage(self) -> int:
"""Get memory usage in MB"""
if torch.cuda.is_available():
return torch.cuda.memory_allocated(self.device) // (1024 * 1024)
else:
# Rough estimate for CPU
param_count = sum(p.numel() for p in self.parameters())
return (param_count * 4) // (1024 * 1024) # 4 bytes per float32
def train(self, training_data: Dict[str, Any]) -> Dict[str, Any]:
"""Train the model (placeholder for interface compatibility)"""
return {}
class EnhancedCNNTrainer:
"""Enhanced CNN trainer using perfect moves and real market outcomes"""
def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
"""Initialize the enhanced trainer"""
self.config = config or get_config()
self.orchestrator = orchestrator
self.data_provider = DataProvider(self.config)
# Training parameters
self.learning_rate = self.config.training.get('learning_rate', 0.001)
self.batch_size = self.config.training.get('batch_size', 32)
self.epochs = self.config.training.get('epochs', 100)
self.patience = self.config.training.get('early_stopping_patience', 10)
# Model
self.model = EnhancedCNNModel(self.config.cnn)
# Training history
self.training_history = {
'train_loss': [],
'val_loss': [],
'train_accuracy': [],
'val_accuracy': [],
'confidence_accuracy': []
} # Create save directory models_path = self.config.cnn.get('model_dir', "models/enhanced_cnn") self.save_dir = Path(models_path) self.save_dir.mkdir(parents=True, exist_ok=True) logger.info("Enhanced CNN trainer initialized")
def train_on_perfect_moves(self, min_samples: int = 100) -> Dict[str, Any]:
"""Train the model on perfect moves from the orchestrator"""
if not self.orchestrator:
raise ValueError("Orchestrator required for perfect move training")
# Get perfect moves from orchestrator
perfect_moves = []
for symbol in self.config.symbols:
symbol_moves = self.orchestrator.get_perfect_moves_for_training(symbol=symbol)
perfect_moves.extend(symbol_moves)
if len(perfect_moves) < min_samples:
logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)} < {min_samples}")
return {'error': 'insufficient_data', 'samples': len(perfect_moves)}
logger.info(f"Training on {len(perfect_moves)} perfect moves")
# Create dataset
dataset = PerfectMoveDataset(perfect_moves, self.data_provider)
# Split into train/validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
# Training loop
best_val_loss = float('inf')
patience_counter = 0
for epoch in range(self.epochs):
# Training phase
train_loss, train_acc = self._train_epoch(train_loader)
# Validation phase
val_loss, val_acc, conf_acc = self._validate_epoch(val_loader)
# Update history
self.training_history['train_loss'].append(train_loss)
self.training_history['val_loss'].append(val_loss)
self.training_history['train_accuracy'].append(train_acc)
self.training_history['val_accuracy'].append(val_acc)
self.training_history['confidence_accuracy'].append(conf_acc)
# Log progress
logger.info(f"Epoch {epoch+1}/{self.epochs}: "
f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
f"Conf Acc: {conf_acc:.4f}")
# Early stopping
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
self._save_model('best_model.pt')
else:
patience_counter += 1
if patience_counter >= self.patience:
logger.info(f"Early stopping at epoch {epoch+1}")
break
# Save final model
self._save_model('final_model.pt')
# Generate training report
return self._generate_training_report()
def _train_epoch(self, train_loader: DataLoader) -> Tuple[float, float]:
"""Train for one epoch"""
self.model.train()
total_loss = 0.0
correct_predictions = 0
total_predictions = 0
for batch in train_loader:
features = batch['features'].to(self.model.device)
action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)
# Zero gradients
self.model.optimizer.zero_grad()
# Forward pass
action_probs, confidence_pred = self.model(features)
# Calculate losses
action_loss = self.model.action_criterion(action_probs, action_labels)
confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)
# Combined loss
total_loss_batch = action_loss + 0.5 * confidence_loss
# Backward pass
total_loss_batch.backward()
self.model.optimizer.step()
# Track metrics
total_loss += total_loss_batch.item()
predicted_actions = torch.argmax(action_probs, dim=1)
correct_predictions += (predicted_actions == action_labels).sum().item()
total_predictions += action_labels.size(0)
avg_loss = total_loss / len(train_loader)
accuracy = correct_predictions / total_predictions
return avg_loss, accuracy
def _validate_epoch(self, val_loader: DataLoader) -> Tuple[float, float, float]:
"""Validate for one epoch"""
self.model.eval()
total_loss = 0.0
correct_predictions = 0
total_predictions = 0
confidence_errors = []
with torch.no_grad():
for batch in val_loader:
features = batch['features'].to(self.model.device)
action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)
# Forward pass
action_probs, confidence_pred = self.model(features)
# Calculate losses
action_loss = self.model.action_criterion(action_probs, action_labels)
confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)
total_loss_batch = action_loss + 0.5 * confidence_loss
# Track metrics
total_loss += total_loss_batch.item()
predicted_actions = torch.argmax(action_probs, dim=1)
correct_predictions += (predicted_actions == action_labels).sum().item()
total_predictions += action_labels.size(0)
# Track confidence accuracy
conf_errors = torch.abs(confidence_pred - confidence_targets)
confidence_errors.extend(conf_errors.cpu().numpy())
avg_loss = total_loss / len(val_loader)
accuracy = correct_predictions / total_predictions
confidence_accuracy = 1.0 - np.mean(confidence_errors) # 1 - mean absolute error
return avg_loss, accuracy, confidence_accuracy
def _save_model(self, filename: str):
"""Save the model"""
save_path = self.save_dir / filename
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.model.optimizer.state_dict(),
'config': self.config.cnn,
'training_history': self.training_history
}, save_path)
logger.info(f"Model saved to {save_path}")
def load_model(self, filename: str) -> bool:
"""Load a saved model"""
load_path = self.save_dir / filename
if not load_path.exists():
logger.error(f"Model file not found: {load_path}")
return False
try:
checkpoint = torch.load(load_path, map_location=self.model.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.training_history = checkpoint.get('training_history', {})
logger.info(f"Model loaded from {load_path}")
return True
except Exception as e:
logger.error(f"Error loading model: {e}")
return False
def _generate_training_report(self) -> Dict[str, Any]:
"""Generate comprehensive training report"""
if not self.training_history['train_loss']:
return {'error': 'no_training_data'}
# Calculate final metrics
final_train_loss = self.training_history['train_loss'][-1]
final_val_loss = self.training_history['val_loss'][-1]
final_train_acc = self.training_history['train_accuracy'][-1]
final_val_acc = self.training_history['val_accuracy'][-1]
final_conf_acc = self.training_history['confidence_accuracy'][-1]
# Best metrics
best_val_loss = min(self.training_history['val_loss'])
best_val_acc = max(self.training_history['val_accuracy'])
best_conf_acc = max(self.training_history['confidence_accuracy'])
report = {
'training_completed': True,
'epochs_trained': len(self.training_history['train_loss']),
'final_metrics': {
'train_loss': final_train_loss,
'val_loss': final_val_loss,
'train_accuracy': final_train_acc,
'val_accuracy': final_val_acc,
'confidence_accuracy': final_conf_acc
},
'best_metrics': {
'val_loss': best_val_loss,
'val_accuracy': best_val_acc,
'confidence_accuracy': best_conf_acc
},
'model_info': {
'timeframes': self.model.timeframes,
'memory_usage_mb': self.model.get_memory_usage(),
'device': str(self.model.device)
}
}
# Generate plots
self._plot_training_history()
logger.info("Training completed successfully")
logger.info(f"Final validation accuracy: {final_val_acc:.4f}")
logger.info(f"Final confidence accuracy: {final_conf_acc:.4f}")
return report
def _plot_training_history(self):
"""Plot training history"""
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle('Enhanced CNN Training History')
# Loss plot
axes[0, 0].plot(self.training_history['train_loss'], label='Train Loss')
axes[0, 0].plot(self.training_history['val_loss'], label='Val Loss')
axes[0, 0].set_title('Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
# Accuracy plot
axes[0, 1].plot(self.training_history['train_accuracy'], label='Train Accuracy')
axes[0, 1].plot(self.training_history['val_accuracy'], label='Val Accuracy')
axes[0, 1].set_title('Action Accuracy')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
# Confidence accuracy plot
axes[1, 0].plot(self.training_history['confidence_accuracy'], label='Confidence Accuracy')
axes[1, 0].set_title('Confidence Prediction Accuracy')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Accuracy')
axes[1, 0].legend()
# Learning curves comparison
axes[1, 1].plot(self.training_history['val_loss'], label='Validation Loss')
axes[1, 1].plot(self.training_history['confidence_accuracy'], label='Confidence Accuracy')
axes[1, 1].set_title('Model Performance Overview')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].legend()
plt.tight_layout()
plt.savefig(self.save_dir / 'training_history.png', dpi=300, bbox_inches='tight')
plt.close()
logger.info(f"Training plots saved to {self.save_dir / 'training_history.png'}")
def get_model(self) -> EnhancedCNNModel:
"""Get the trained model"""
return self.model

View File

@ -0,0 +1,625 @@
"""
Enhanced RL Trainer with Market Environment Adaptation
This trainer implements:
1. Continuous learning from orchestrator action evaluations
2. Environment adaptation based on market regime changes
3. Multi-symbol coordinated RL training
4. Experience replay with prioritized sampling
5. Dynamic reward shaping based on market conditions
"""
import asyncioimport asyncioimport loggingimport numpy as npimport torchimport torch.nn as nnimport torch.optim as optimfrom collections import deque, namedtupleimport randomfrom datetime import datetime, timedeltafrom typing import Dict, List, Optional, Tuple, Anyimport matplotlib.pyplot as pltfrom pathlib import Path
from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator, MarketState, TradingAction
from models import RLAgentInterface
import models
logger = logging.getLogger(__name__)
# Experience tuple for replay buffer
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done', 'priority'])
class PrioritizedReplayBuffer:
"""Prioritized experience replay buffer for RL training"""
def __init__(self, capacity: int = 10000, alpha: float = 0.6):
"""
Initialize prioritized replay buffer
Args:
capacity: Maximum number of experiences to store
alpha: Priority exponent (0 = uniform, 1 = fully prioritized)
"""
self.capacity = capacity
self.alpha = alpha
self.buffer = []
self.priorities = np.zeros(capacity, dtype=np.float32)
self.position = 0
self.size = 0
def add(self, experience: Experience):
"""Add experience to buffer with priority"""
max_priority = self.priorities[:self.size].max() if self.size > 0 else 1.0
if self.size < self.capacity:
self.buffer.append(experience)
self.size += 1
else:
self.buffer[self.position] = experience
self.priorities[self.position] = max_priority
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size: int, beta: float = 0.4) -> Tuple[List[Experience], np.ndarray, np.ndarray]:
"""Sample batch with prioritized sampling"""
if self.size == 0:
return [], np.array([]), np.array([])
# Calculate sampling probabilities
priorities = self.priorities[:self.size] ** self.alpha
probabilities = priorities / priorities.sum()
# Sample indices
indices = np.random.choice(self.size, batch_size, p=probabilities)
experiences = [self.buffer[i] for i in indices]
# Calculate importance sampling weights
weights = (self.size * probabilities[indices]) ** (-beta)
weights = weights / weights.max() # Normalize
return experiences, indices, weights
def update_priorities(self, indices: np.ndarray, priorities: np.ndarray):
"""Update priorities for sampled experiences"""
for idx, priority in zip(indices, priorities):
self.priorities[idx] = priority + 1e-6 # Small epsilon to avoid zero priority
def __len__(self):
return self.size
class EnhancedDQNAgent(nn.Module, RLAgentInterface):
"""Enhanced DQN agent with market environment adaptation"""
def __init__(self, config: Dict[str, Any]):
nn.Module.__init__(self)
RLAgentInterface.__init__(self, config)
# Network architecture
self.state_size = config.get('state_size', 100)
self.action_space = config.get('action_space', 3)
self.hidden_size = config.get('hidden_size', 256)
# Build networks
self._build_networks()
# Training parameters
self.learning_rate = config.get('learning_rate', 0.0001)
self.gamma = config.get('gamma', 0.99)
self.epsilon = config.get('epsilon', 1.0)
self.epsilon_decay = config.get('epsilon_decay', 0.995)
self.epsilon_min = config.get('epsilon_min', 0.01)
self.target_update_freq = config.get('target_update_freq', 1000)
# Initialize device and optimizer
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.to(self.device)
self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
# Experience replay
self.replay_buffer = PrioritizedReplayBuffer(config.get('buffer_size', 10000))
self.batch_size = config.get('batch_size', 64)
# Market adaptation
self.market_regime_weights = {
'trending': 1.2, # Higher confidence in trending markets
'ranging': 0.8, # Lower confidence in ranging markets
'volatile': 0.6 # Much lower confidence in volatile markets
}
# Training statistics
self.training_steps = 0
self.losses = []
self.rewards = []
self.epsilon_history = []
logger.info(f"Enhanced DQN agent initialized with state size: {self.state_size}")
def _build_networks(self):
"""Build main and target networks"""
# Main network
self.main_network = nn.Sequential(
nn.Linear(self.state_size, self.hidden_size),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(self.hidden_size, self.hidden_size),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(self.hidden_size, 128),
nn.ReLU(),
nn.Dropout(0.2)
)
# Dueling network heads
self.value_head = nn.Linear(128, 1)
self.advantage_head = nn.Linear(128, self.action_space)
# Target network (copy of main network)
self.target_network = nn.Sequential(
nn.Linear(self.state_size, self.hidden_size),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(self.hidden_size, self.hidden_size),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(self.hidden_size, 128),
nn.ReLU(),
nn.Dropout(0.2)
)
self.target_value_head = nn.Linear(128, 1)
self.target_advantage_head = nn.Linear(128, self.action_space)
# Initialize target network with same weights
self._update_target_network()
def forward(self, state, target: bool = False):
"""Forward pass through the network"""
if target:
features = self.target_network(state)
value = self.target_value_head(features)
advantage = self.target_advantage_head(features)
else:
features = self.main_network(state)
value = self.value_head(features)
advantage = self.advantage_head(features)
# Dueling architecture: Q(s,a) = V(s) + A(s,a) - mean(A(s,a))
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
return q_values
def act(self, state: np.ndarray) -> int:
"""Choose action using epsilon-greedy policy"""
if random.random() < self.epsilon:
return random.randint(0, self.action_space - 1)
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
q_values = self.forward(state_tensor)
return q_values.argmax().item()
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
"""Choose action with confidence score adapted to market regime"""
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
q_values = self.forward(state_tensor)
# Convert Q-values to probabilities
action_probs = torch.softmax(q_values, dim=1)
action = q_values.argmax().item()
base_confidence = action_probs[0, action].item()
# Adapt confidence based on market regime
regime_weight = self.market_regime_weights.get(market_regime, 1.0)
adapted_confidence = min(base_confidence * regime_weight, 1.0)
return action, adapted_confidence
def remember(self, state: np.ndarray, action: int, reward: float,
next_state: np.ndarray, done: bool):
"""Store experience in replay buffer"""
# Calculate TD error for priority
with torch.no_grad():
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)
current_q = self.forward(state_tensor)[0, action]
next_q = self.forward(next_state_tensor, target=True).max(1)[0]
target_q = reward + (self.gamma * next_q * (1 - done))
td_error = abs(current_q.item() - target_q.item())
experience = Experience(state, action, reward, next_state, done, td_error)
self.replay_buffer.add(experience)
def replay(self) -> Optional[float]:
"""Train the network on a batch of experiences"""
if len(self.replay_buffer) < self.batch_size:
return None
# Sample batch
experiences, indices, weights = self.replay_buffer.sample(self.batch_size)
if not experiences:
return None
# Convert to tensors
states = torch.FloatTensor([e.state for e in experiences]).to(self.device)
actions = torch.LongTensor([e.action for e in experiences]).to(self.device)
rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device)
next_states = torch.FloatTensor([e.next_state for e in experiences]).to(self.device)
dones = torch.BoolTensor([e.done for e in experiences]).to(self.device)
weights_tensor = torch.FloatTensor(weights).to(self.device)
# Current Q-values
current_q_values = self.forward(states).gather(1, actions.unsqueeze(1))
# Target Q-values (Double DQN)
with torch.no_grad():
# Use main network to select actions
next_actions = self.forward(next_states).argmax(1)
# Use target network to evaluate actions
next_q_values = self.forward(next_states, target=True).gather(1, next_actions.unsqueeze(1))
target_q_values = rewards.unsqueeze(1) + (self.gamma * next_q_values * ~dones.unsqueeze(1))
# Calculate weighted loss
td_errors = target_q_values - current_q_values
loss = (weights_tensor * (td_errors ** 2)).mean()
# Optimize
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
self.optimizer.step()
# Update priorities
new_priorities = torch.abs(td_errors).detach().cpu().numpy().flatten()
self.replay_buffer.update_priorities(indices, new_priorities)
# Update target network
self.training_steps += 1
if self.training_steps % self.target_update_freq == 0:
self._update_target_network()
# Decay epsilon
if self.epsilon > self.epsilon_min:
self.epsilon *= self.epsilon_decay
# Track statistics
self.losses.append(loss.item())
self.epsilon_history.append(self.epsilon)
return loss.item()
def _update_target_network(self):
"""Update target network with main network weights"""
self.target_network.load_state_dict(self.main_network.state_dict())
self.target_value_head.load_state_dict(self.value_head.state_dict())
self.target_advantage_head.load_state_dict(self.advantage_head.state_dict())
def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]: """Predict action probabilities and confidence (required by ModelInterface)""" action, confidence = self.act_with_confidence(features) # Convert action to probabilities action_probs = np.zeros(self.action_space) action_probs[action] = 1.0 return action_probs, confidence def get_memory_usage(self) -> int: """Get memory usage in MB""" if torch.cuda.is_available(): return torch.cuda.memory_allocated(self.device) // (1024 * 1024) else: param_count = sum(p.numel() for p in self.parameters()) buffer_size = len(self.replay_buffer) * self.state_size * 4 # Rough estimate return (param_count * 4 + buffer_size) // (1024 * 1024)
class EnhancedRLTrainer:
"""Enhanced RL trainer with continuous learning from market feedback"""
def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
"""Initialize the enhanced RL trainer"""
self.config = config or get_config()
self.orchestrator = orchestrator
self.data_provider = DataProvider(self.config)
# Create RL agents for each symbol
self.agents = {}
for symbol in self.config.symbols:
agent_config = self.config.rl.copy()
agent_config['name'] = f'RL_{symbol}'
self.agents[symbol] = EnhancedDQNAgent(agent_config)
# Training parameters
self.training_interval = 3600 # Train every hour
self.evaluation_window = 24 * 3600 # Evaluate actions after 24 hours
self.min_experiences = 100 # Minimum experiences before training
# Performance tracking
self.performance_history = {symbol: [] for symbol in self.config.symbols}
self.training_metrics = {
'total_episodes': 0,
'total_rewards': {symbol: [] for symbol in self.config.symbols},
'losses': {symbol: [] for symbol in self.config.symbols},
'epsilon_values': {symbol: [] for symbol in self.config.symbols}
}
# Create save directory models_path = self.config.rl.get('model_dir', "models/enhanced_rl") self.save_dir = Path(models_path) self.save_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Enhanced RL trainer initialized for symbols: {self.config.symbols}")
async def continuous_learning_loop(self):
"""Main continuous learning loop"""
logger.info("Starting continuous RL learning loop")
while True:
try:
# Train agents with recent experiences
await self._train_all_agents()
# Evaluate recent actions
if self.orchestrator:
await self.orchestrator.evaluate_actions_with_rl()
# Adapt to market regime changes
await self._adapt_to_market_changes()
# Update performance metrics
self._update_performance_metrics()
# Save models periodically
if self.training_metrics['total_episodes'] % 100 == 0:
self._save_all_models()
# Wait before next training cycle
await asyncio.sleep(self.training_interval)
except Exception as e:
logger.error(f"Error in continuous learning loop: {e}")
await asyncio.sleep(60) # Wait 1 minute on error
async def _train_all_agents(self):
"""Train all RL agents with their experiences"""
for symbol, agent in self.agents.items():
try:
if len(agent.replay_buffer) >= self.min_experiences:
# Train for multiple steps
losses = []
for _ in range(10): # Train 10 steps per cycle
loss = agent.replay()
if loss is not None:
losses.append(loss)
if losses:
avg_loss = np.mean(losses)
self.training_metrics['losses'][symbol].append(avg_loss)
self.training_metrics['epsilon_values'][symbol].append(agent.epsilon)
logger.info(f"Trained {symbol} RL agent: Loss={avg_loss:.4f}, Epsilon={agent.epsilon:.4f}")
except Exception as e:
logger.error(f"Error training {symbol} agent: {e}")
async def _adapt_to_market_changes(self):
"""Adapt agents to market regime changes"""
if not self.orchestrator:
return
for symbol in self.config.symbols:
try:
# Get recent market states
recent_states = list(self.orchestrator.market_states[symbol])[-10:] # Last 10 states
if len(recent_states) < 5:
continue
# Analyze regime stability
regimes = [state.market_regime for state in recent_states]
regime_stability = len(set(regimes)) / len(regimes) # Lower = more stable
# Adjust learning parameters based on stability
agent = self.agents[symbol]
if regime_stability < 0.3: # Stable regime
agent.epsilon *= 0.99 # Faster epsilon decay
elif regime_stability > 0.7: # Unstable regime
agent.epsilon = min(agent.epsilon * 1.01, 0.5) # Increase exploration
logger.debug(f"{symbol} regime stability: {regime_stability:.3f}, epsilon: {agent.epsilon:.3f}")
except Exception as e:
logger.error(f"Error adapting {symbol} to market changes: {e}")
def add_trading_experience(self, symbol: str, action: TradingAction,
initial_state: MarketState, final_state: MarketState,
reward: float):
"""Add trading experience to the appropriate agent"""
if symbol not in self.agents:
logger.warning(f"No agent for symbol {symbol}")
return
try:
# Convert market states to RL state vectors
initial_rl_state = self._market_state_to_rl_state(initial_state)
final_rl_state = self._market_state_to_rl_state(final_state)
# Convert action to RL action index
action_mapping = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
action_idx = action_mapping.get(action.action, 1)
# Store experience
agent = self.agents[symbol]
agent.remember(
state=initial_rl_state,
action=action_idx,
reward=reward,
next_state=final_rl_state,
done=False
)
# Track reward
self.training_metrics['total_rewards'][symbol].append(reward)
logger.debug(f"Added experience for {symbol}: action={action.action}, reward={reward:.4f}")
except Exception as e:
logger.error(f"Error adding experience for {symbol}: {e}")
def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
"""Convert market state to RL state vector"""
if hasattr(self.orchestrator, '_market_state_to_rl_state'):
return self.orchestrator._market_state_to_rl_state(market_state)
# Fallback implementation
state_components = [
market_state.volatility,
market_state.volume,
market_state.trend_strength
]
# Add price features
for timeframe in sorted(market_state.prices.keys()):
state_components.append(market_state.prices[timeframe])
# Pad or truncate to expected state size
expected_size = self.config.rl.get('state_size', 100)
if len(state_components) < expected_size:
state_components.extend([0.0] * (expected_size - len(state_components)))
else:
state_components = state_components[:expected_size]
return np.array(state_components, dtype=np.float32)
def _update_performance_metrics(self):
"""Update performance tracking metrics"""
self.training_metrics['total_episodes'] += 1
# Calculate recent performance for each agent
for symbol, agent in self.agents.items():
recent_rewards = self.training_metrics['total_rewards'][symbol][-100:] # Last 100 rewards
if recent_rewards:
avg_reward = np.mean(recent_rewards)
self.performance_history[symbol].append({
'timestamp': datetime.now(),
'avg_reward': avg_reward,
'epsilon': agent.epsilon,
'experiences': len(agent.replay_buffer)
})
def _save_all_models(self):
"""Save all RL models"""
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
for symbol, agent in self.agents.items():
filename = f"rl_agent_{symbol}_{timestamp}.pt"
filepath = self.save_dir / filename
torch.save({
'model_state_dict': agent.state_dict(),
'optimizer_state_dict': agent.optimizer.state_dict(),
'config': self.config.rl,
'training_metrics': self.training_metrics,
'symbol': symbol,
'epsilon': agent.epsilon,
'training_steps': agent.training_steps
}, filepath)
logger.info(f"Saved {symbol} RL agent to {filepath}")
def load_models(self, timestamp: str = None):
"""Load RL models from files"""
if timestamp is None:
# Find most recent models
model_files = list(self.save_dir.glob("rl_agent_*.pt"))
if not model_files:
logger.warning("No saved RL models found")
return False
# Group by timestamp and get most recent
timestamps = set(f.stem.split('_')[-2] + '_' + f.stem.split('_')[-1] for f in model_files)
timestamp = max(timestamps)
loaded_count = 0
for symbol in self.config.symbols:
filename = f"rl_agent_{symbol}_{timestamp}.pt"
filepath = self.save_dir / filename
if filepath.exists():
try:
checkpoint = torch.load(filepath, map_location=self.agents[symbol].device)
self.agents[symbol].load_state_dict(checkpoint['model_state_dict'])
self.agents[symbol].optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.agents[symbol].epsilon = checkpoint.get('epsilon', 0.1)
self.agents[symbol].training_steps = checkpoint.get('training_steps', 0)
logger.info(f"Loaded {symbol} RL agent from {filepath}")
loaded_count += 1
except Exception as e:
logger.error(f"Error loading {symbol} RL agent: {e}")
return loaded_count > 0
def get_performance_report(self) -> Dict[str, Any]:
"""Generate performance report for all agents"""
report = {
'total_episodes': self.training_metrics['total_episodes'],
'agents': {}
}
for symbol, agent in self.agents.items():
recent_rewards = self.training_metrics['total_rewards'][symbol][-100:]
recent_losses = self.training_metrics['losses'][symbol][-10:]
agent_report = {
'symbol': symbol,
'epsilon': agent.epsilon,
'training_steps': agent.training_steps,
'experiences_stored': len(agent.replay_buffer),
'memory_usage_mb': agent.get_memory_usage(),
'avg_recent_reward': np.mean(recent_rewards) if recent_rewards else 0.0,
'avg_recent_loss': np.mean(recent_losses) if recent_losses else 0.0,
'total_rewards': len(self.training_metrics['total_rewards'][symbol])
}
report['agents'][symbol] = agent_report
return report
def plot_training_metrics(self):
"""Plot training metrics for all agents"""
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Enhanced RL Training Metrics')
symbols = list(self.agents.keys())
colors = ['blue', 'red', 'green', 'orange'][:len(symbols)]
# Rewards plot
for i, symbol in enumerate(symbols):
rewards = self.training_metrics['total_rewards'][symbol]
if rewards:
# Moving average of rewards
window = min(100, len(rewards))
if len(rewards) >= window:
moving_avg = np.convolve(rewards, np.ones(window)/window, mode='valid')
axes[0, 0].plot(moving_avg, label=f'{symbol}', color=colors[i])
axes[0, 0].set_title('Average Rewards (Moving Average)')
axes[0, 0].set_xlabel('Episodes')
axes[0, 0].set_ylabel('Reward')
axes[0, 0].legend()
# Losses plot
for i, symbol in enumerate(symbols):
losses = self.training_metrics['losses'][symbol]
if losses:
axes[0, 1].plot(losses, label=f'{symbol}', color=colors[i])
axes[0, 1].set_title('Training Losses')
axes[0, 1].set_xlabel('Training Steps')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].legend()
# Epsilon values
for i, symbol in enumerate(symbols):
epsilon_values = self.training_metrics['epsilon_values'][symbol]
if epsilon_values:
axes[1, 0].plot(epsilon_values, label=f'{symbol}', color=colors[i])
axes[1, 0].set_title('Exploration Rate (Epsilon)')
axes[1, 0].set_xlabel('Training Steps')
axes[1, 0].set_ylabel('Epsilon')
axes[1, 0].legend()
# Experience buffer sizes
buffer_sizes = [len(agent.replay_buffer) for agent in self.agents.values()]
axes[1, 1].bar(symbols, buffer_sizes, color=colors[:len(symbols)])
axes[1, 1].set_title('Experience Buffer Sizes')
axes[1, 1].set_ylabel('Number of Experiences')
plt.tight_layout()
plt.savefig(self.save_dir / 'rl_training_metrics.png', dpi=300, bbox_inches='tight')
plt.close()
logger.info(f"RL training plots saved to {self.save_dir / 'rl_training_metrics.png'}")
def get_agents(self) -> Dict[str, EnhancedDQNAgent]:
"""Get all RL agents"""
return self.agents