wip on the RL training pipeline and data collection

This commit is contained in:
Dobromir Popov
2025-05-29 14:08:14 +03:00
parent 6b7d7aec81
commit 3f4e9b9774
18 changed files with 6154 additions and 3446 deletions

View File

@ -41,6 +41,33 @@ from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator, TradingDecision
from core.trading_executor import TradingExecutor
# Enhanced RL Training Integration
try:
from core.unified_data_stream import UnifiedDataStream, TrainingDataPacket, UIDataPacket
from core.enhanced_orchestrator import EnhancedTradingOrchestrator, MarketState, TradingAction
from training.enhanced_rl_trainer import EnhancedRLTrainer
ENHANCED_RL_AVAILABLE = True
logger = logging.getLogger(__name__)
logger.info("Enhanced RL training components available")
except ImportError as e:
ENHANCED_RL_AVAILABLE = False
logger = logging.getLogger(__name__)
logger.warning(f"Enhanced RL training not available: {e}")
# Fallback classes
class UnifiedDataStream:
    """No-op stand-in used when the enhanced RL stream modules fail to import.

    Every method accepts the call and does nothing, so the dashboard can keep
    a ``unified_stream`` attribute whether or not the real implementation is
    available. The method set mirrors what the dashboard calls on the stream,
    including ``unregister_consumer`` used during shutdown.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments."""

    def register_consumer(self, *args, **kwargs):
        """Pretend to register a consumer and return a fixed placeholder ID."""
        return "fallback_consumer"

    def unregister_consumer(self, *args, **kwargs):
        """Accept an unregister request; nothing was registered, so do nothing."""
        return None

    def start_streaming(self):
        """Do nothing; there is no stream to start."""

    def stop_streaming(self):
        """Do nothing; there is no stream to stop."""

    def get_latest_training_data(self):
        """Return None because the fallback never produces training data."""
        return None

    def get_latest_ui_data(self):
        """Return None because the fallback never produces UI data."""
        return None
class TrainingDataPacket:
    """Empty placeholder for the training packet type when its import failed."""

    def __init__(self, *_args, **_kwargs):
        # Arguments are deliberately discarded; the placeholder stores nothing.
        return None
class UIDataPacket:
    """Empty placeholder for the UI packet type when its import failed."""

    def __init__(self, *_args, **_kwargs):
        # Arguments are deliberately discarded; the placeholder stores nothing.
        return None
# Try to import model registry, fallback if not available
try:
from models import get_model_registry
@ -73,16 +100,40 @@ except ImportError:
logger = logging.getLogger(__name__)
class TradingDashboard:
"""Modern trading dashboard with real-time updates"""
"""Modern trading dashboard with real-time updates and enhanced RL training integration"""
def __init__(self, data_provider: DataProvider = None, orchestrator: TradingOrchestrator = None, trading_executor: TradingExecutor = None):
"""Initialize the dashboard"""
"""Initialize the dashboard with unified data stream and enhanced RL training"""
self.config = get_config()
self.data_provider = data_provider or DataProvider()
self.orchestrator = orchestrator or TradingOrchestrator(self.data_provider)
# Enhanced orchestrator support
if ENHANCED_RL_AVAILABLE and isinstance(orchestrator, EnhancedTradingOrchestrator):
self.orchestrator = orchestrator
self.enhanced_rl_enabled = True
logger.info("Enhanced RL training orchestrator detected")
else:
self.orchestrator = orchestrator or TradingOrchestrator(self.data_provider)
self.enhanced_rl_enabled = False
logger.info("Using standard orchestrator")
self.trading_executor = trading_executor or TradingExecutor()
self.model_registry = get_model_registry()
# Initialize unified data stream for comprehensive training data
if ENHANCED_RL_AVAILABLE:
self.unified_stream = UnifiedDataStream(self.data_provider, self.orchestrator)
self.stream_consumer_id = self.unified_stream.register_consumer(
consumer_name="TradingDashboard",
callback=self._handle_unified_stream_data,
data_types=['ticks', 'ohlcv', 'training_data', 'ui_data']
)
logger.info(f"Unified data stream initialized with consumer ID: {self.stream_consumer_id}")
else:
self.unified_stream = UnifiedDataStream() # Fallback
self.stream_consumer_id = "fallback"
logger.warning("Using fallback unified data stream")
# Dashboard state
self.recent_decisions = []
self.recent_signals = [] # Track all signals (not just executed trades)
@ -126,21 +177,29 @@ class TradingDashboard:
self.ws_thread = None
self.is_streaming = False
# Load available models for real trading
self._load_available_models()
# RL Training System - Train on closed trades
# Enhanced RL Training System - Train on closed trades with comprehensive data
self.rl_training_enabled = True
self.enhanced_rl_training_enabled = ENHANCED_RL_AVAILABLE and self.enhanced_rl_enabled
self.rl_training_stats = {
'total_training_episodes': 0,
'profitable_trades_trained': 0,
'unprofitable_trades_trained': 0,
'last_training_time': None,
'training_rewards': deque(maxlen=100), # Last 100 training rewards
'model_accuracy_trend': deque(maxlen=50) # Track accuracy over time
'model_accuracy_trend': deque(maxlen=50), # Track accuracy over time
'enhanced_rl_episodes': 0,
'comprehensive_data_packets': 0
}
self.rl_training_queue = deque(maxlen=1000) # Queue of trades to train on
# Enhanced training data tracking
self.latest_training_data = None
self.latest_ui_data = None
self.training_data_available = False
# Load available models for real trading
self._load_available_models()
# Create Dash app
self.app = dash.Dash(__name__, external_stylesheets=[
'https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css',
@ -151,13 +210,244 @@ class TradingDashboard:
self._setup_layout()
self._setup_callbacks()
# Start WebSocket tick streaming
self._start_websocket_stream()
# Start unified data streaming
self._initialize_streaming()
# Start continuous training
# Start continuous training with enhanced RL support
self.start_continuous_training()
logger.info("Trading Dashboard initialized with continuous training")
logger.info("Trading Dashboard initialized with enhanced RL training integration")
logger.info(f"Enhanced RL enabled: {self.enhanced_rl_training_enabled}")
logger.info(f"Stream consumer ID: {self.stream_consumer_id}")
def _initialize_streaming(self):
    """Start the unified data stream, the WebSocket feed and background collection.

    Steps, in order:
      1. When ``ENHANCED_RL_AVAILABLE`` is set, run
         ``self.unified_stream.start_streaming()`` to completion with
         ``asyncio.run``.
      2. Start the WebSocket stream (backup/additional data source).
      3. Start the enhanced training data collection thread.

    If any step raises, the error is logged and the WebSocket stream is
    started as a fallback, but only when step 2 had not already succeeded.
    This keeps a failure in step 3 from starting the WebSocket stream twice.
    """
    websocket_started = False
    try:
        if ENHANCED_RL_AVAILABLE:
            # NOTE: asyncio.run raises RuntimeError when an event loop is
            # already running in this thread; that case lands in the
            # WebSocket-only fallback below.
            asyncio.run(self.unified_stream.start_streaming())
            logger.info("Unified data stream started")

        # Start WebSocket as backup/additional data source
        self._start_websocket_stream()
        websocket_started = True

        # Start background data collection
        self._start_enhanced_training_data_collection()

        logger.info("All data streaming initialized")
    except Exception as e:
        logger.error(f"Error initializing streaming: {e}")
        # Fallback to WebSocket only, unless it is already running.
        if not websocket_started:
            self._start_websocket_stream()
def _start_enhanced_training_data_collection(self):
"""Start enhanced training data collection using unified stream

Defines ``enhanced_training_loop`` and launches it on a daemon ``Thread``;
this method returns right after starting the thread.

The loop has no exit condition (``while True``). Each pass sleeps 10 seconds
on success; an exception inside a pass is logged and followed by a
30 second sleep before the next attempt.

NOTE(review): indentation is not visible in this view, so exactly which of
the steps below depend on ``training_data`` being truthy cannot be confirmed
here. The ``else`` presumably pairs with the
``ENHANCED_RL_AVAILABLE and self.enhanced_rl_training_enabled`` check.
"""
def enhanced_training_loop():
try:
logger.info("Enhanced training data collection started with unified stream")
# Runs until the process exits; nothing in this block signals the loop to stop.
while True:
try:
if ENHANCED_RL_AVAILABLE and self.enhanced_rl_training_enabled:
# Get latest comprehensive training data from unified stream
training_data = self.unified_stream.get_latest_training_data()
if training_data:
# Send comprehensive training data to enhanced RL pipeline
self._send_comprehensive_training_data_to_enhanced_rl(training_data)
# Update training statistics
self.rl_training_stats['comprehensive_data_packets'] += 1
self.training_data_available = True
# Update context data in orchestrator
# (only when the orchestrator exposes update_context_data)
if hasattr(self.orchestrator, 'update_context_data'):
self.orchestrator.update_context_data()
# Initialize extrema trainer if not done
# The `_initialized` attribute is set on the trainer object itself and acts
# as a run-once marker for initialize_context_data().
if hasattr(self.orchestrator, 'extrema_trainer'):
if not hasattr(self.orchestrator.extrema_trainer, '_initialized'):
self.orchestrator.extrema_trainer.initialize_context_data()
self.orchestrator.extrema_trainer._initialized = True
logger.info("Extrema trainer context data initialized")
# Run extrema detection with real data
# Results are only counted for the debug log here; they are not stored by this loop.
if hasattr(self.orchestrator, 'extrema_trainer'):
for symbol in self.orchestrator.symbols:
detected = self.orchestrator.extrema_trainer.detect_local_extrema(symbol)
if detected:
logger.debug(f"Detected {len(detected)} extrema for {symbol}")
else:
# Fallback to basic training data collection
self._collect_basic_training_data()
time.sleep(10) # Update every 10 seconds for enhanced training
except Exception as e:
# Per-iteration failure: log, back off, then keep looping.
logger.error(f"Error in enhanced training loop: {e}")
time.sleep(30) # Wait before retrying
except Exception as e:
# Reached only if something outside the per-iteration handler raises.
logger.error(f"Enhanced training loop failed: {e}")
# Start enhanced training thread
# daemon=True: the thread does not keep the interpreter alive at shutdown.
training_thread = Thread(target=enhanced_training_loop, daemon=True)
training_thread.start()
logger.info("Enhanced training data collection thread started")
def _handle_unified_stream_data(self, data_packet: Dict[str, Any]):
    """Consume one packet from the unified stream and refresh dashboard state.

    Recognised keys of ``data_packet`` (each optional):
      * ``'ui_data'``: stored as ``self.latest_ui_data``; its
        ``current_prices``, ``streaming_status`` and
        ``training_data_available`` attributes, when present, update the
        matching dashboard fields.
      * ``'training_data'``: stored as ``self.latest_training_data``.
      * ``'ticks'``: the last 100 entries are appended to ``self.tick_cache``.
      * ``'one_second_bars'``: the last 100 entries are appended to
        ``self.one_second_bars``.

    Any exception is logged and swallowed, so a bad packet cannot propagate
    an error to the caller.
    """
    try:
        # UI snapshot for dashboard display
        if 'ui_data' in data_packet:
            ui_snapshot = data_packet['ui_data']
            self.latest_ui_data = ui_snapshot

            if hasattr(ui_snapshot, 'current_prices'):
                self.current_prices.update(ui_snapshot.current_prices)

            if hasattr(ui_snapshot, 'streaming_status'):
                self.is_streaming = ui_snapshot.streaming_status == 'LIVE'

            if hasattr(ui_snapshot, 'training_data_available'):
                self.training_data_available = ui_snapshot.training_data_available

        # Comprehensive training payload for enhanced RL
        if 'training_data' in data_packet:
            self.latest_training_data = data_packet['training_data']
            logger.debug("Received comprehensive training data from unified stream")

        # Tick data for dashboard charts (only the newest 100 are taken)
        if 'ticks' in data_packet:
            newest_ticks = data_packet['ticks'][-100:]
            for tick_entry in newest_ticks:
                self.tick_cache.append(tick_entry)

        # OHLCV bars for dashboard charts (only the newest 100 are taken)
        if 'one_second_bars' in data_packet:
            newest_bars = data_packet['one_second_bars'][-100:]
            for bar_entry in newest_bars:
                self.one_second_bars.append(bar_entry)

        logger.debug(f"Processed unified stream data packet with keys: {list(data_packet.keys())}")

    except Exception as e:
        logger.error(f"Error handling unified stream data: {e}")
def _send_comprehensive_training_data_to_enhanced_rl(self, training_data: TrainingDataPacket):
"""Send comprehensive training data to enhanced RL training pipeline

Does nothing when ``self.enhanced_rl_training_enabled`` is false. Otherwise
reads ``market_state``, ``universal_stream``, ``cnn_features`` and
``cnn_predictions`` from ``training_data`` (each defaults to None when the
attribute is missing), runs one ``enhanced_rl_trainer.training_step`` if the
orchestrator has such a trainer, and logs what extrema, sensitivity and
context data is available.

All failures are logged; nothing is raised to the caller and nothing is
returned.

NOTE(review): indentation is not visible in this view, so how much of the
body sits under ``if market_state and universal_stream:`` cannot be
confirmed here.
"""
try:
# Guard: skip all work unless enhanced RL training is switched on.
if not self.enhanced_rl_training_enabled:
logger.debug("Enhanced RL training not enabled, skipping comprehensive data send")
return
# Extract comprehensive training data components
# Each lookup falls back to None when the packet lacks the attribute.
market_state = training_data.market_state if hasattr(training_data, 'market_state') else None
universal_stream = training_data.universal_stream if hasattr(training_data, 'universal_stream') else None
cnn_features = training_data.cnn_features if hasattr(training_data, 'cnn_features') else None
cnn_predictions = training_data.cnn_predictions if hasattr(training_data, 'cnn_predictions') else None
# Truthiness test, not an `is not None` test.
# NOTE(review): if either value is array-like, bool() may raise - confirm the types.
if market_state and universal_stream:
# Send to enhanced RL trainer if available
if hasattr(self.orchestrator, 'enhanced_rl_trainer'):
try:
# Create comprehensive training step with ~13,400 features
# asyncio.run drives the training_step coroutine to completion on a new event loop.
asyncio.run(self.orchestrator.enhanced_rl_trainer.training_step(universal_stream))
self.rl_training_stats['enhanced_rl_episodes'] += 1
logger.debug("Sent comprehensive data to enhanced RL trainer")
except Exception as e:
logger.warning(f"Error in enhanced RL training step: {e}")
# Send to extrema trainer for CNN training with perfect moves
# In this block the fetched samples are only counted for debug logging.
if hasattr(self.orchestrator, 'extrema_trainer'):
try:
extrema_data = self.orchestrator.extrema_trainer.get_extrema_training_data(count=50)
perfect_moves = self.orchestrator.extrema_trainer.get_perfect_moves_for_cnn(count=100)
if extrema_data:
logger.debug(f"Enhanced RL: {len(extrema_data)} extrema training samples available")
if perfect_moves:
logger.debug(f"Enhanced RL: {len(perfect_moves)} perfect moves for CNN training")
except Exception as e:
logger.warning(f"Error getting extrema training data: {e}")
# Send to sensitivity learning DQN for outcome-based learning
# Only the queue length is checked here; no item is added or removed.
if hasattr(self.orchestrator, 'sensitivity_learning_queue'):
try:
if len(self.orchestrator.sensitivity_learning_queue) > 0:
logger.debug("Enhanced RL: Sensitivity learning data available for DQN training")
except Exception as e:
logger.warning(f"Error accessing sensitivity learning queue: {e}")
# Get context features for models with real market data
# The features are only logged (via .shape); they are not forwarded from this block.
if hasattr(self.orchestrator, 'extrema_trainer'):
try:
for symbol in self.orchestrator.symbols:
context_features = self.orchestrator.extrema_trainer.get_context_features_for_model(symbol)
if context_features is not None:
logger.debug(f"Enhanced RL: Context features available for {symbol}: {context_features.shape}")
except Exception as e:
logger.warning(f"Error getting context features: {e}")
# Log comprehensive training data statistics
# Counts default to 0 when the packet lacks the corresponding attribute.
tick_count = len(training_data.tick_cache) if hasattr(training_data, 'tick_cache') else 0
bars_count = len(training_data.one_second_bars) if hasattr(training_data, 'one_second_bars') else 0
timeframe_count = len(training_data.multi_timeframe_data) if hasattr(training_data, 'multi_timeframe_data') else 0
logger.info(f"Enhanced RL Comprehensive Training Data:")
logger.info(f" Tick cache: {tick_count} ticks")
logger.info(f" 1s bars: {bars_count} bars")
logger.info(f" Multi-timeframe data: {timeframe_count} symbols")
logger.info(f" CNN features: {'Available' if cnn_features else 'Not available'}")
logger.info(f" CNN predictions: {'Available' if cnn_predictions else 'Not available'}")
logger.info(f" Market state: {'Available (~13,400 features)' if market_state else 'Not available'}")
logger.info(f" Universal stream: {'Available' if universal_stream else 'Not available'}")
except Exception as e:
# Catch-all so a failure here is logged rather than raised to the caller.
logger.error(f"Error sending comprehensive training data to enhanced RL: {e}")
def _collect_basic_training_data(self):
    """Fill the tick and 1s-bar caches from the data provider.

    For ``'ETH/USDT'`` and ``'BTC/USDT'`` this asks
    ``self.data_provider.get_recent_ticks(symbol, count=10)`` (when the
    provider has that method) and, for every tick returned, appends a tick
    dict to ``self.tick_cache`` and a bar dict to ``self.one_second_bars``.
    Each bar is built from a single tick, so open/high/low/close all equal
    that tick's price.

    Afterwards ``self.is_streaming`` is set according to whether
    ``self.tick_cache`` is non-empty. Per-symbol failures are logged at debug
    level; any other failure is logged as a warning. Nothing is returned.
    """
    try:
        for symbol in ('ETH/USDT', 'BTC/USDT'):
            try:
                if not hasattr(self.data_provider, 'get_recent_ticks'):
                    continue

                for tick in self.data_provider.get_recent_ticks(symbol, count=10):
                    # Tick record built from real market data
                    self.tick_cache.append({
                        'symbol': tick.symbol,
                        'price': tick.price,
                        'timestamp': tick.timestamp,
                        'volume': tick.volume,
                    })

                    # Single-tick 1s bar: all four prices are the tick price
                    tick_price = tick.price
                    self.one_second_bars.append({
                        'symbol': tick.symbol,
                        'open': tick_price,
                        'high': tick_price,
                        'low': tick_price,
                        'close': tick_price,
                        'volume': tick.volume,
                        'timestamp': tick.timestamp,
                    })

            except Exception as e:
                logger.debug(f"No recent tick data available for {symbol}: {e}")

        # Streaming status reflects whether any tick data is cached
        self.is_streaming = len(self.tick_cache) > 0

    except Exception as e:
        logger.warning(f"Error in basic training data collection: {e}")
def _get_initial_balance(self) -> float:
"""Get initial USDT balance from MEXC or return default"""
@ -2240,12 +2530,12 @@ class TradingDashboard:
logger.warning(f"RL prediction error: {e}")
return np.array([0.33, 0.34, 0.33]), 0.5
def get_memory_usage(self):
return 80 # MB estimate
def to_device(self, device):
self.device = device
return self
def get_memory_usage(self):
return 80 # MB estimate
def to_device(self, device):
self.device = device
return self
rl_wrapper = RLWrapper(rl_path)
@ -2511,19 +2801,20 @@ class TradingDashboard:
return pd.DataFrame()
def _create_training_metrics(self) -> List:
"""Create comprehensive model training metrics display"""
"""Create comprehensive model training metrics display with enhanced RL integration"""
try:
training_items = []
# Training Data Streaming Status
# Enhanced Training Data Streaming Status
tick_cache_size = len(self.tick_cache)
bars_cache_size = len(self.one_second_bars)
enhanced_data_available = self.training_data_available and self.enhanced_rl_training_enabled
training_items.append(
html.Div([
html.H6([
html.I(className="fas fa-database me-2 text-info"),
"Training Data Stream"
"Enhanced Training Data Stream"
], className="mb-2"),
html.Div([
html.Small([
@ -2538,11 +2829,58 @@ class TradingDashboard:
html.Strong("Stream: "),
html.Span("LIVE" if self.is_streaming else "OFFLINE",
className="text-success" if self.is_streaming else "text-danger")
], className="d-block"),
html.Small([
html.Strong("Enhanced RL: "),
html.Span("ENABLED" if self.enhanced_rl_training_enabled else "DISABLED",
className="text-success" if self.enhanced_rl_training_enabled else "text-warning")
], className="d-block"),
html.Small([
html.Strong("Comprehensive Data: "),
html.Span("AVAILABLE" if enhanced_data_available else "WAITING",
className="text-success" if enhanced_data_available else "text-warning")
], className="d-block")
])
], className="mb-3 p-2 border border-info rounded")
)
# Enhanced RL Training Statistics
if self.enhanced_rl_training_enabled:
enhanced_episodes = self.rl_training_stats.get('enhanced_rl_episodes', 0)
comprehensive_packets = self.rl_training_stats.get('comprehensive_data_packets', 0)
training_items.append(
html.Div([
html.H6([
html.I(className="fas fa-brain me-2 text-success"),
"Enhanced RL Training"
], className="mb-2"),
html.Div([
html.Small([
html.Strong("Status: "),
html.Span("ACTIVE" if enhanced_episodes > 0 else "WAITING",
className="text-success" if enhanced_episodes > 0 else "text-warning")
], className="d-block"),
html.Small([
html.Strong("Episodes: "),
html.Span(f"{enhanced_episodes}", className="text-info")
], className="d-block"),
html.Small([
html.Strong("Data Packets: "),
html.Span(f"{comprehensive_packets}", className="text-info")
], className="d-block"),
html.Small([
html.Strong("Features: "),
html.Span("~13,400 (Market State)", className="text-success")
], className="d-block"),
html.Small([
html.Strong("Training Mode: "),
html.Span("Comprehensive", className="text-success")
], className="d-block")
])
], className="mb-3 p-2 border border-success rounded")
)
# Model Training Status
try:
# Try to get real training metrics from orchestrator
@ -2553,7 +2891,7 @@ class TradingDashboard:
html.Div([
html.H6([
html.I(className="fas fa-brain me-2 text-warning"),
"CNN Model"
"CNN Model (Extrema Detection)"
], className="mb-2"),
html.Div([
html.Small([
@ -2570,59 +2908,58 @@ class TradingDashboard:
html.Span(f"{training_status['cnn']['loss']:.4f}", className="text-muted")
], className="d-block"),
html.Small([
html.Strong("Epochs: "),
html.Span(f"{training_status['cnn']['epochs']}", className="text-muted")
], className="d-block"),
html.Small([
html.Strong("Learning Rate: "),
html.Span(f"{training_status['cnn']['learning_rate']:.6f}", className="text-muted")
html.Strong("Perfect Moves: "),
html.Span("Available" if hasattr(self.orchestrator, 'extrema_trainer') else "N/A",
className="text-success" if hasattr(self.orchestrator, 'extrema_trainer') else "text-muted")
], className="d-block")
])
], className="mb-3 p-2 border border-warning rounded")
)
# RL Training Metrics
# RL Training Metrics (Enhanced)
total_episodes = self.rl_training_stats.get('total_training_episodes', 0)
profitable_trades = self.rl_training_stats.get('profitable_trades_trained', 0)
win_rate = (profitable_trades / total_episodes * 100) if total_episodes > 0 else 0
training_items.append(
html.Div([
html.H6([
html.I(className="fas fa-robot me-2 text-success"),
"RL Agent (DQN)"
html.I(className="fas fa-robot me-2 text-primary"),
"RL Agent (DQN + Sensitivity Learning)"
], className="mb-2"),
html.Div([
html.Small([
html.Strong("Status: "),
html.Span(training_status['rl']['status'],
className=f"text-{training_status['rl']['status_color']}")
html.Span("ENHANCED" if self.enhanced_rl_training_enabled else "BASIC",
className="text-success" if self.enhanced_rl_training_enabled else "text-warning")
], className="d-block"),
html.Small([
html.Strong("Win Rate: "),
html.Span(f"{training_status['rl']['win_rate']:.1%}", className="text-info")
html.Span(f"{win_rate:.1f}%", className="text-info")
], className="d-block"),
html.Small([
html.Strong("Avg Reward: "),
html.Span(f"{training_status['rl']['avg_reward']:.2f}", className="text-muted")
html.Strong("Total Episodes: "),
html.Span(f"{total_episodes}", className="text-muted")
], className="d-block"),
html.Small([
html.Strong("Episodes: "),
html.Span(f"{training_status['rl']['episodes']}", className="text-muted")
html.Strong("Enhanced Episodes: "),
html.Span(f"{enhanced_episodes}" if self.enhanced_rl_training_enabled else "N/A",
className="text-success" if self.enhanced_rl_training_enabled else "text-muted")
], className="d-block"),
html.Small([
html.Strong("Epsilon: "),
html.Span(f"{training_status['rl']['epsilon']:.3f}", className="text-muted")
], className="d-block"),
html.Small([
html.Strong("Memory: "),
html.Span(f"{training_status['rl']['memory_size']:,}", className="text-muted")
html.Strong("Sensitivity Learning: "),
html.Span("ACTIVE" if hasattr(self.orchestrator, 'sensitivity_learning_queue') else "N/A",
className="text-success" if hasattr(self.orchestrator, 'sensitivity_learning_queue') else "text-muted")
], className="d-block")
])
], className="mb-3 p-2 border border-success rounded")
], className="mb-3 p-2 border border-primary rounded")
)
# Training Progress Chart (Mini)
training_items.append(
html.Div([
html.H6([
html.I(className="fas fa-chart-line me-2 text-primary"),
html.I(className="fas fa-chart-line me-2 text-secondary"),
"Training Progress"
], className="mb-2"),
dcc.Graph(
@ -2630,7 +2967,7 @@ class TradingDashboard:
style={"height": "150px"},
config={'displayModeBar': False}
)
], className="mb-3 p-2 border border-primary rounded")
], className="mb-3 p-2 border border-secondary rounded")
)
except Exception as e:
@ -3365,7 +3702,7 @@ class TradingDashboard:
logger.error(f"Error stopping continuous training: {e}")
def _trigger_rl_training_on_closed_trade(self, closed_trade):
"""Trigger RL training based on a closed trade's profitability"""
"""Trigger enhanced RL training based on a closed trade's profitability with comprehensive data"""
try:
if not self.rl_training_enabled:
return
@ -3375,7 +3712,7 @@ class TradingDashboard:
is_profitable = net_pnl > 0
trade_duration = closed_trade.get('duration', timedelta(0))
# Create training episode data
# Create enhanced training episode data
training_episode = {
'trade_id': closed_trade.get('trade_id'),
'side': closed_trade.get('side'),
@ -3386,7 +3723,8 @@ class TradingDashboard:
'duration_seconds': trade_duration.total_seconds(),
'symbol': closed_trade.get('symbol', 'ETH/USDT'),
'timestamp': closed_trade.get('exit_time', datetime.now()),
'reward': self._calculate_rl_reward(closed_trade)
'reward': self._calculate_rl_reward(closed_trade),
'enhanced_data_available': self.enhanced_rl_training_enabled
}
# Add to training queue
@ -3402,16 +3740,126 @@ class TradingDashboard:
self.rl_training_stats['last_training_time'] = datetime.now()
self.rl_training_stats['training_rewards'].append(training_episode['reward'])
# Trigger actual RL model training
self._execute_rl_training_step(training_episode)
# Enhanced RL training with comprehensive data
if self.enhanced_rl_training_enabled:
self._execute_enhanced_rl_training_step(training_episode)
else:
# Fallback to basic RL training
self._execute_rl_training_step(training_episode)
logger.info(f"[RL_TRAINING] Trade #{training_episode['trade_id']} added to training: "
logger.info(f"[RL_TRAINING] Trade #{training_episode['trade_id']} added to {'ENHANCED' if self.enhanced_rl_training_enabled else 'BASIC'} training: "
f"{'PROFITABLE' if is_profitable else 'LOSS'} "
f"PnL: ${net_pnl:.2f}, Reward: {training_episode['reward']:.3f}")
except Exception as e:
logger.error(f"Error in RL training trigger: {e}")
def _execute_enhanced_rl_training_step(self, training_episode):
    """Feed one closed-trade episode to the enhanced RL learners.

    The latest training packet is taken from ``self.unified_stream``. When it
    carries a truthy ``market_state`` the episode is handed, in order, to:

      1. ``orchestrator.enhanced_rl_trainer.add_trading_experience`` (if the
         orchestrator has that trainer),
      2. ``orchestrator.extrema_trainer.add_trade_outcome_for_learning`` when
         ``abs(net_pnl) > 0.5`` (if the orchestrator has that trainer),
      3. ``orchestrator.sensitivity_learning_queue`` (if present).

    Steps 2 and 3 run regardless of the outcome of step 1. Previously an
    early ``return`` inside step 1 made them unreachable whenever an enhanced
    RL trainer was present.

    Without a usable packet the call falls back to
    ``self._execute_rl_training_step``.

    Args:
        training_episode: dict describing the closed trade. Keys read here:
            ``'symbol'``, ``'side'``, ``'entry_price'``, ``'exit_price'``,
            ``'timestamp'``, ``'reward'``, ``'trade_id'``,
            ``'is_profitable'`` and ``'net_pnl'``.

    Returns:
        The fallback's return value when no market state is available.
        Otherwise ``True`` when the enhanced RL trainer accepted the
        experience (or no such trainer exists), and ``False`` when that
        trainer raised or an unexpected error occurred.
    """
    try:
        # Get comprehensive training data from unified stream
        training_data = self.unified_stream.get_latest_training_data() if ENHANCED_RL_AVAILABLE else None
        market_state = getattr(training_data, 'market_state', None) if training_data else None

        if not market_state:
            logger.warning(f"[ENHANCED_RL] No comprehensive training data available, falling back to basic training")
            return self._execute_rl_training_step(training_episode)

        # Comprehensive training context shared with the sensitivity queue.
        # Missing packet attributes fall back to neutral defaults.
        enhanced_context = {
            'trade_outcome': training_episode,
            'market_state': market_state,
            'universal_stream': getattr(training_data, 'universal_stream', None),
            'tick_cache': getattr(training_data, 'tick_cache', []),
            'multi_timeframe_data': getattr(training_data, 'multi_timeframe_data', {}),
            'cnn_features': getattr(training_data, 'cnn_features', None),
            'cnn_predictions': getattr(training_data, 'cnn_predictions', None),
        }

        trainer_ok = True

        # 1) Enhanced RL trainer
        if hasattr(self.orchestrator, 'enhanced_rl_trainer'):
            try:
                symbol = training_episode['symbol']
                action = TradingAction(
                    action=training_episode['side'],
                    symbol=symbol,
                    confidence=0.8,  # Inferred from executed trade
                    price=training_episode['exit_price'],
                    size=0.1,  # Default size
                    timestamp=training_episode['timestamp']
                )
                reward = training_episode['reward']

                # Simplification kept from the original: the same snapshot is
                # used for both the entry and the exit state.
                self.orchestrator.enhanced_rl_trainer.add_trading_experience(
                    symbol=symbol,
                    action=action,
                    initial_state=market_state,
                    final_state=market_state,
                    reward=reward
                )

                logger.info(f"[ENHANCED_RL] Added comprehensive trading experience for trade #{training_episode['trade_id']}")
                logger.info(f"[ENHANCED_RL] Market state features: ~13,400, Reward: {reward:.3f}")

                # Update enhanced RL statistics
                self.rl_training_stats['enhanced_rl_episodes'] += 1
            except Exception as e:
                logger.error(f"Error in enhanced RL trainer: {e}")
                trainer_ok = False

        # 2) Extrema trainer for CNN learning
        if hasattr(self.orchestrator, 'extrema_trainer'):
            try:
                trade_context = {
                    'symbol': training_episode['symbol'],
                    'entry_price': training_episode['entry_price'],
                    'exit_price': training_episode['exit_price'],
                    'is_profitable': training_episode['is_profitable'],
                    'timestamp': training_episode['timestamp']
                }

                # Only significant moves are forwarded
                if abs(training_episode['net_pnl']) > 0.5:
                    self.orchestrator.extrema_trainer.add_trade_outcome_for_learning(trade_context)
                    logger.debug(f"[EXTREMA_CNN] Added trade outcome for CNN learning")
            except Exception as e:
                logger.warning(f"Error adding to extrema trainer: {e}")

        # 3) Sensitivity learning DQN
        if hasattr(self.orchestrator, 'sensitivity_learning_queue'):
            try:
                sensitivity_data = {
                    'trade_outcome': training_episode,
                    'market_context': enhanced_context,
                    'learning_priority': 'high' if abs(training_episode['net_pnl']) > 1.0 else 'normal'
                }
                self.orchestrator.sensitivity_learning_queue.append(sensitivity_data)
                logger.debug(f"[SENSITIVITY_DQN] Added trade outcome for sensitivity learning")
            except Exception as e:
                logger.warning(f"Error adding to sensitivity learning: {e}")

        return trainer_ok

    except Exception as e:
        logger.error(f"Error executing enhanced RL training step: {e}")
        return False
def _calculate_rl_reward(self, closed_trade):
"""Calculate reward for RL training based on trade performance"""
try:
@ -3658,6 +4106,54 @@ class TradingDashboard:
"""Get current RL training statistics"""
return self.rl_training_stats.copy()
def stop_streaming(self):
    """Stop all streaming and training components.

    Shutdown order:
      1. Unified data stream: ``stop_streaming()`` is run via
         ``asyncio.run``, then this dashboard's consumer is unregistered.
         The two calls are guarded separately so a failure to stop the
         stream does not leave the consumer registered.
      2. WebSocket: ``self.is_streaming`` is cleared, the connection is
         closed and the thread is joined for up to 5 seconds. If the thread
         is still alive after the timeout, that is reported as a warning
         rather than as a successful stop.
      3. Continuous training via ``self.stop_continuous_training()``.
      4. Enhanced RL trainer ``stop_training()`` when enhanced training is
         enabled and the trainer provides that method.

    Every step logs its own failure and shutdown continues; nothing is
    raised to the caller and nothing is returned.
    """
    try:
        logger.info("Stopping dashboard streaming and training components...")

        # Stop unified data stream
        if ENHANCED_RL_AVAILABLE and hasattr(self, 'unified_stream'):
            try:
                asyncio.run(self.unified_stream.stop_streaming())
                logger.info("Unified data stream stopped")
            except Exception as e:
                logger.warning(f"Error stopping unified stream: {e}")

            # Unregister independently of whether the stop call succeeded.
            if hasattr(self, 'stream_consumer_id'):
                try:
                    self.unified_stream.unregister_consumer(self.stream_consumer_id)
                except Exception as e:
                    logger.warning(f"Error unregistering stream consumer: {e}")

        # Stop WebSocket streaming
        self.is_streaming = False
        if self.ws_connection:
            try:
                self.ws_connection.close()
                logger.info("WebSocket connection closed")
            except Exception as e:
                logger.warning(f"Error closing WebSocket: {e}")

        if self.ws_thread and self.ws_thread.is_alive():
            try:
                self.ws_thread.join(timeout=5)
                # join() returns None even on timeout, so liveness must be
                # re-checked to know whether the thread actually finished.
                if self.ws_thread.is_alive():
                    logger.warning("WebSocket thread still running after 5s join timeout")
                else:
                    logger.info("WebSocket thread stopped")
            except Exception as e:
                logger.warning(f"Error stopping WebSocket thread: {e}")

        # Stop continuous training
        self.stop_continuous_training()

        # Stop enhanced RL training if available
        if self.enhanced_rl_training_enabled and hasattr(self.orchestrator, 'enhanced_rl_trainer'):
            try:
                if hasattr(self.orchestrator.enhanced_rl_trainer, 'stop_training'):
                    asyncio.run(self.orchestrator.enhanced_rl_trainer.stop_training())
                    logger.info("Enhanced RL training stopped")
            except Exception as e:
                logger.warning(f"Error stopping enhanced RL training: {e}")

        logger.info("All streaming and training components stopped")

    except Exception as e:
        logger.error(f"Error stopping streaming: {e}")
def create_dashboard(data_provider: DataProvider = None, orchestrator: TradingOrchestrator = None, trading_executor: TradingExecutor = None) -> TradingDashboard:
"""Factory function to create a trading dashboard"""