#!/usr/bin/env python3 """ Real-time Reinforcement Learning COB Trader A sophisticated real-time RL system that: 1. Uses COB (Consolidated Order Book) data for training a 1B parameter RL model 2. Performs inference every 200ms or when new data comes 3. Predicts next price moves in real-time 4. Trains continuously based on prediction success 5. Accumulates signals based on confidence 6. Issues trade signals after 3 confident and successful predictions 7. Trains with higher weight when closing trades Integrates with existing gogo2 trading system architecture. """ import asyncio import logging import numpy as np import torch import torch.nn as nn import torch.optim as optim from datetime import datetime, timedelta from typing import Dict, List, Optional, Any, Callable, Tuple from collections import deque, defaultdict from dataclasses import dataclass, asdict import json import time import threading from threading import Lock import pickle import os # Local imports from .cob_integration import COBIntegration from .trading_executor import TradingExecutor from NN.models.cob_rl_model import MassiveRLNetwork, COBRLModelInterface logger = logging.getLogger(__name__) @dataclass class PredictionResult: """Result of a model prediction""" timestamp: datetime symbol: str predicted_direction: int # 0=DOWN, 1=SIDEWAYS, 2=UP confidence: float predicted_change: float # Predicted price change % features: np.ndarray actual_direction: Optional[int] = None # Filled later for training actual_change: Optional[float] = None # Filled later for training reward: Optional[float] = None # Calculated reward for RL training @dataclass class SignalAccumulator: """Accumulates signals for trade decision making""" symbol: str signals: deque # Recent signals confidence_sum: float = 0.0 successful_predictions: int = 0 total_predictions: int = 0 last_reset_time: Optional[datetime] = None def __post_init__(self): if self.signals is None: self.signals = deque(maxlen=10) if self.last_reset_time is None: self.last_reset_time = datetime.now() @dataclass class TrainingUpdate: """Training update event data""" timestamp: datetime symbol: str epoch: int loss: float batch_size: int learning_rate: float accuracy: float avg_confidence: float @dataclass class TradeSignal: """Trade signal event data""" timestamp: datetime symbol: str action: str # 'BUY', 'SELL', 'HOLD' confidence: float quantity: float price: float signals_count: int reason: str # MassiveRLNetwork is now imported from NN.models.cob_rl_model class RealtimeRLCOBTrader: """ Real-time RL trader using COB data with comprehensive subscriber system """ def __init__(self, symbols: Optional[List[str]] = None, trading_executor: Optional[TradingExecutor] = None, model_checkpoint_dir: str = "models/realtime_rl_cob", inference_interval_ms: int = 200, min_confidence_threshold: float = 0.35, # Lowered from 0.7 for more aggressive trading required_confident_predictions: int = 3, checkpoint_manager: Any = None): self.symbols = symbols or ['BTC/USDT', 'ETH/USDT'] self.trading_executor = trading_executor self.model_checkpoint_dir = model_checkpoint_dir self.inference_interval_ms = inference_interval_ms self.min_confidence_threshold = min_confidence_threshold self.required_confident_predictions = required_confident_predictions # Initialize CheckpointManager (either provided or get global instance) if checkpoint_manager is None: from utils.checkpoint_manager import get_checkpoint_manager self.checkpoint_manager = get_checkpoint_manager() else: self.checkpoint_manager = checkpoint_manager # Track start time for training duration calculation self.start_time = datetime.now() # Initialize start_time # Setup device self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') logger.info(f"Using device: {self.device}") # Initialize models for each symbol self.models: Dict[str, MassiveRLNetwork] = {} self.optimizers: Dict[str, optim.AdamW] = {} self.scalers: Dict[str, torch.cuda.amp.GradScaler] = {} for symbol in self.symbols: model = MassiveRLNetwork().to(self.device) self.models[symbol] = model self.optimizers[symbol] = optim.AdamW( model.parameters(), lr=1e-5, # Low learning rate for stability weight_decay=1e-6, betas=(0.9, 0.999) ) self.scalers[symbol] = torch.cuda.amp.GradScaler() # Subscriber system for real-time events self.prediction_subscribers: List[Callable[[PredictionResult], None]] = [] self.training_subscribers: List[Callable[[TrainingUpdate], None]] = [] self.signal_subscribers: List[Callable[[TradeSignal], None]] = [] self.async_prediction_subscribers: List[Callable[[PredictionResult], Any]] = [] self.async_training_subscribers: List[Callable[[TrainingUpdate], Any]] = [] self.async_signal_subscribers: List[Callable[[TradeSignal], Any]] = [] # COB integration self.cob_integration = COBIntegration(symbols=self.symbols) self.cob_integration.add_dqn_callback(self._on_cob_update_sync) # Data storage for real-time training self.prediction_history: Dict[str, deque] = {} self.feature_buffers: Dict[str, deque] = {} self.price_history: Dict[str, deque] = {} # Signal accumulation self.signal_accumulators: Dict[str, SignalAccumulator] = {} # Performance tracking self.training_stats: Dict[str, Dict] = {} self.inference_stats: Dict[str, Dict] = {} # Initialize per symbol for symbol in self.symbols: self.prediction_history[symbol] = deque(maxlen=1000) self.feature_buffers[symbol] = deque(maxlen=100) self.price_history[symbol] = deque(maxlen=1000) self.signal_accumulators[symbol] = SignalAccumulator( symbol=symbol, signals=deque(maxlen=self.required_confident_predictions * 2) ) self.training_stats[symbol] = { 'total_predictions': 0, 'successful_predictions': 0, 'total_training_steps': 0, 'average_loss': 0.0, 'last_training_time': None } self.inference_stats[symbol] = { 'total_inferences': 0, 'average_inference_time_ms': 0.0, 'last_inference_time': None } # PnL tracking for loss cutting optimization self.pnl_history: Dict[str, deque] = { symbol: deque(maxlen=1000) for symbol in self.symbols } self.position_peak_pnl: Dict[str, float] = {symbol: 0.0 for symbol in self.symbols} self.trade_history: Dict[str, List] = {symbol: [] for symbol in self.symbols} # Threading self.running = False self.inference_lock = Lock() self.training_lock = Lock() # Create checkpoint directory os.makedirs(self.model_checkpoint_dir, exist_ok=True) logger.info(f"RealtimeRLCOBTrader initialized for symbols: {self.symbols}") logger.info(f"Inference interval: {self.inference_interval_ms}ms") logger.info(f"Required confident predictions: {self.required_confident_predictions}") # Subscriber system methods def add_prediction_subscriber(self, callback: Callable[[PredictionResult], None]): """Add a subscriber for prediction events""" self.prediction_subscribers.append(callback) logger.info(f"Added prediction subscriber, total: {len(self.prediction_subscribers)}") def add_training_subscriber(self, callback: Callable[[TrainingUpdate], None]): """Add a subscriber for training events""" self.training_subscribers.append(callback) logger.info(f"Added training subscriber, total: {len(self.training_subscribers)}") def add_signal_subscriber(self, callback: Callable[[TradeSignal], None]): """Add a subscriber for trade signal events""" self.signal_subscribers.append(callback) logger.info(f"Added signal subscriber, total: {len(self.signal_subscribers)}") def add_async_prediction_subscriber(self, callback: Callable[[PredictionResult], Any]): """Add an async subscriber for prediction events""" self.async_prediction_subscribers.append(callback) logger.info(f"Added async prediction subscriber, total: {len(self.async_prediction_subscribers)}") def add_async_training_subscriber(self, callback: Callable[[TrainingUpdate], Any]): """Add an async subscriber for training events""" self.async_training_subscribers.append(callback) logger.info(f"Added async training subscriber, total: {len(self.async_training_subscribers)}") def add_async_signal_subscriber(self, callback: Callable[[TradeSignal], Any]): """Add an async subscriber for trade signal events""" self.async_signal_subscribers.append(callback) logger.info(f"Added async signal subscriber, total: {len(self.async_signal_subscribers)}") async def _emit_prediction(self, prediction: PredictionResult): """Emit prediction to all subscribers""" try: # Sync subscribers for callback in self.prediction_subscribers: try: callback(prediction) except Exception as e: logger.warning(f"Error in prediction subscriber: {e}") # Async subscribers for callback in self.async_prediction_subscribers: try: if asyncio.iscoroutinefunction(callback): asyncio.create_task(callback(prediction)) else: callback(prediction) except Exception as e: logger.warning(f"Error in async prediction subscriber: {e}") except Exception as e: logger.error(f"Error emitting prediction: {e}") async def _emit_training_update(self, update: TrainingUpdate): """Emit training update to all subscribers""" try: # Sync subscribers for callback in self.training_subscribers: try: callback(update) except Exception as e: logger.warning(f"Error in training subscriber: {e}") # Async subscribers for callback in self.async_training_subscribers: try: if asyncio.iscoroutinefunction(callback): asyncio.create_task(callback(update)) else: callback(update) except Exception as e: logger.warning(f"Error in async training subscriber: {e}") except Exception as e: logger.error(f"Error emitting training update: {e}") async def _emit_trade_signal(self, signal: TradeSignal): """Emit trade signal to all subscribers""" try: # Sync subscribers for callback in self.signal_subscribers: try: callback(signal) except Exception as e: logger.warning(f"Error in signal subscriber: {e}") # Async subscribers for callback in self.async_signal_subscribers: try: if asyncio.iscoroutinefunction(callback): asyncio.create_task(callback(signal)) else: callback(signal) except Exception as e: logger.warning(f"Error in async signal subscriber: {e}") except Exception as e: logger.error(f"Error emitting trade signal: {e}") def _on_cob_update_sync(self, symbol: str, data: Dict): """Sync wrapper for async COB update handler""" try: # Schedule the async method asyncio.create_task(self._on_cob_update(symbol, data)) except Exception as e: logger.error(f"Error scheduling COB update for {symbol}: {e}") async def start(self): """Start the real-time RL trader""" logger.info("Starting Real-time RL COB Trader") self.running = True # Load existing models if available self._load_models() # Start COB integration await self.cob_integration.start() # Start inference loop asyncio.create_task(self._inference_loop()) # Start training loop asyncio.create_task(self._training_loop()) # Start signal processing loop asyncio.create_task(self._signal_processing_loop()) # Start model saving loop asyncio.create_task(self._model_saving_loop()) logger.info("Real-time RL COB Trader started successfully") async def stop(self): """Stop the real-time RL trader""" logger.info("Stopping Real-time RL COB Trader") self.running = False # Save models self._save_models() # Stop COB integration await self.cob_integration.stop() logger.info("Real-time RL COB Trader stopped") async def _on_cob_update(self, symbol: str, data: Dict): """Handle COB updates for real-time inference""" try: if symbol not in self.symbols: return # Extract features from COB data features = self._extract_features(symbol, data) if features is None: return # Store in buffer self.feature_buffers[symbol].append({ 'timestamp': datetime.now(), 'features': features, 'raw_data': data }) # Store price for later reward calculation if 'state' in data: price = self._extract_price_from_state(data['state']) if price > 0: self.price_history[symbol].append({ 'timestamp': datetime.now(), 'price': price }) except Exception as e: logger.error(f"Error handling COB update for {symbol}: {e}") def _extract_features(self, symbol: str, data: Dict) -> Optional[np.ndarray]: """Extract features from COB data for model input""" try: # Get state from COB data if 'state' not in data: return None state = data['state'] # Ensure we have the right feature size (2000 features) if isinstance(state, np.ndarray): features = state.flatten() else: features = np.array(state).flatten() # Pad or truncate to exact size target_size = 2000 if len(features) < target_size: # Pad with zeros padded = np.zeros(target_size) padded[:len(features)] = features features = padded elif len(features) > target_size: # Truncate features = features[:target_size] # Normalize features features = self._normalize_features(features) return features except Exception as e: logger.error(f"Error extracting features for {symbol}: {e}") return None def _normalize_features(self, features: np.ndarray) -> np.ndarray: """Normalize features for model input""" try: # Clip extreme values features = np.clip(features, -10.0, 10.0) # Z-score normalization with robust statistics median = np.median(features) mad = np.median(np.abs(features - median)) if mad > 1e-6: features = (features - median) / (mad * 1.4826) # Final clipping features = np.clip(features, -5.0, 5.0) return features.astype(np.float32) except Exception as e: logger.warning(f"Error normalizing features: {e}") return features.astype(np.float32) def _extract_price_from_state(self, state) -> float: """Extract current price from state data""" try: # Try different ways to extract price if isinstance(state, np.ndarray) and len(state) > 0: # Assume first few elements might contain price info return float(state[0]) elif isinstance(state, (list, tuple)) and len(state) > 0: return float(state[0]) else: return 0.0 except: return 0.0 async def _inference_loop(self): """Main inference loop - runs every 200ms or when new data arrives""" logger.info("Starting inference loop") while self.running: try: start_time = time.time() # Run inference for all symbols for symbol in self.symbols: await self._run_inference(symbol) # Calculate sleep time to maintain interval elapsed_ms = (time.time() - start_time) * 1000 sleep_ms = max(0, self.inference_interval_ms - elapsed_ms) if sleep_ms > 0: await asyncio.sleep(sleep_ms / 1000) except Exception as e: logger.error(f"Error in inference loop: {e}") await asyncio.sleep(0.1) async def _run_inference(self, symbol: str): """Run inference for a specific symbol""" try: with self.inference_lock: # Check if we have recent features if not self.feature_buffers[symbol]: return # Get latest features latest_data = self.feature_buffers[symbol][-1] features = latest_data['features'] timestamp = latest_data['timestamp'] # Run model inference start_time = time.time() prediction = self._predict(symbol, features) inference_time_ms = (time.time() - start_time) * 1000 # Update inference stats stats = self.inference_stats[symbol] stats['total_inferences'] += 1 stats['average_inference_time_ms'] = ( (stats['average_inference_time_ms'] * (stats['total_inferences'] - 1) + inference_time_ms) / stats['total_inferences'] ) stats['last_inference_time'] = timestamp # Create prediction result result = PredictionResult( timestamp=timestamp, symbol=symbol, predicted_direction=prediction['direction'], confidence=prediction['confidence'], predicted_change=prediction['change'], features=features ) # Store prediction for later training self.prediction_history[symbol].append(result) # Emit prediction to subscribers await self._emit_prediction(result) # Add to signal accumulator if confident enough if prediction['confidence'] >= self.min_confidence_threshold: self._add_signal(symbol, result) logger.debug(f"Inference {symbol}: direction={prediction['direction']}, " f"confidence={prediction['confidence']:.3f}, " f"change={prediction['change']:.4f}, " f"time={inference_time_ms:.1f}ms") except Exception as e: logger.error(f"Error running inference for {symbol}: {e}") def _predict(self, symbol: str, features: np.ndarray) -> Dict: """Run model prediction""" try: model = self.models[symbol] model.eval() # Convert to tensor features_tensor = torch.from_numpy(features).unsqueeze(0).to(self.device) with torch.no_grad(): with torch.cuda.amp.autocast(): outputs = model(features_tensor) # Extract predictions price_probs = torch.softmax(outputs['price_logits'], dim=1) direction = torch.argmax(price_probs, dim=1).item() confidence = outputs['confidence'].item() value = outputs['value'].item() # Calculate predicted change based on direction and confidence if direction == 2: # UP predicted_change = confidence * 0.001 # Max 0.1% up elif direction == 0: # DOWN predicted_change = -confidence * 0.001 # Max 0.1% down else: # SIDEWAYS predicted_change = 0.0 return { 'direction': direction, 'confidence': confidence, 'change': predicted_change, 'value': value } except Exception as e: logger.error(f"Error in prediction for {symbol}: {e}") return { 'direction': 1, # SIDEWAYS 'confidence': 0.0, 'change': 0.0, 'value': 0.0 } def _add_signal(self, symbol: str, prediction: PredictionResult): """Add confident prediction to signal accumulator""" try: accumulator = self.signal_accumulators[symbol] accumulator.signals.append(prediction) accumulator.confidence_sum += prediction.confidence accumulator.total_predictions += 1 logger.debug(f"Added signal for {symbol}: {len(accumulator.signals)} total signals") except Exception as e: logger.error(f"Error adding signal for {symbol}: {e}") async def _signal_processing_loop(self): """Process accumulated signals and generate trade decisions""" logger.info("Starting signal processing loop") while self.running: try: for symbol in self.symbols: await self._process_signals(symbol) await asyncio.sleep(0.5) # Process signals every 500ms to reduce load except Exception as e: logger.error(f"Error in signal processing loop: {e}") await asyncio.sleep(1) async def _process_signals(self, symbol: str): """Process signals for a specific symbol and make trade decisions""" try: accumulator = self.signal_accumulators[symbol] # Check if we have enough confident predictions if len(accumulator.signals) < self.required_confident_predictions: return # Get recent signals recent_signals = list(accumulator.signals)[-self.required_confident_predictions:] # Check if all recent signals are in the same direction directions = [signal.predicted_direction for signal in recent_signals] confidences = [signal.confidence for signal in recent_signals] # Count direction consensus direction_counts = {0: 0, 1: 0, 2: 0} # DOWN, SIDEWAYS, UP for direction in directions: direction_counts[direction] += 1 # Find dominant direction dominant_direction = max(direction_counts, key=direction_counts.get) consensus_count = direction_counts[dominant_direction] # Check if we have enough consensus if consensus_count >= self.required_confident_predictions and dominant_direction != 1: # We have consensus for action (not sideways) avg_confidence = np.mean(confidences) # Determine action if dominant_direction == 2: # UP action = 'BUY' elif dominant_direction == 0: # DOWN action = 'SELL' else: return # No action for sideways # Execute trade signal await self._execute_trade_signal(symbol, action, float(avg_confidence), recent_signals) # Reset accumulator after trade signal self._reset_accumulator(symbol) except Exception as e: logger.error(f"Error processing signals for {symbol}: {e}") async def _execute_trade_signal(self, symbol: str, action: str, confidence: float, signals: List[PredictionResult]): """Execute a trade signal""" try: logger.info(f"Executing trade signal: {action} {symbol} with confidence {confidence:.3f}") # Get current price current_price = 0.0 if self.price_history[symbol]: current_price = self.price_history[symbol][-1]['price'] # Create trade signal for emission trade_signal = TradeSignal( timestamp=datetime.now(), symbol=symbol, action=action, confidence=confidence, quantity=1.0, # Default quantity price=current_price, signals_count=len(signals), reason=f"Consensus of {len(signals)} predictions" ) # Emit trade signal to subscribers await self._emit_trade_signal(trade_signal) # Execute through trading executor if available if self.trading_executor and current_price > 0: success = self.trading_executor.execute_signal( symbol=symbol, action=action, confidence=confidence, current_price=current_price ) if success: logger.info(f"Trade executed successfully: {action} {symbol}") # Schedule training with higher weight for trade closure asyncio.create_task(self._train_on_trade_execution(symbol, signals, action, current_price)) else: logger.warning(f"Trade execution failed: {action} {symbol}") else: logger.info(f"No trading executor available or price unknown for {symbol}") except Exception as e: logger.error(f"Error executing trade signal for {symbol}: {e}") def _reset_accumulator(self, symbol: str): """Reset signal accumulator after trade execution""" try: accumulator = self.signal_accumulators[symbol] accumulator.signals.clear() accumulator.confidence_sum = 0.0 accumulator.last_reset_time = datetime.now() logger.debug(f"Reset signal accumulator for {symbol}") except Exception as e: logger.error(f"Error resetting accumulator for {symbol}: {e}") async def _training_loop(self): """Main training loop for real-time model updates""" logger.info("Starting training loop") while self.running: try: for symbol in self.symbols: await self._train_symbol_model(symbol) await asyncio.sleep(1.0) # Train every second except Exception as e: logger.error(f"Error in training loop: {e}") await asyncio.sleep(5) async def _train_symbol_model(self, symbol: str): """Train model for a specific symbol using recent predictions""" try: with self.training_lock: # Check if we have enough data for training predictions = list(self.prediction_history[symbol]) if len(predictions) < 10: return # Calculate rewards for recent predictions self._calculate_rewards(symbol, predictions) # Filter predictions with calculated rewards training_predictions = [p for p in predictions if p.reward is not None] if len(training_predictions) < 5: return # Prepare training batch batch_size = min(32, len(training_predictions)) batch_predictions = training_predictions[-batch_size:] # Train model loss = await self._train_batch(symbol, batch_predictions) # Update training stats stats = self.training_stats[symbol] stats['total_training_steps'] += 1 stats['average_loss'] = ( (stats['average_loss'] * (stats['total_training_steps'] - 1) + loss) / stats['total_training_steps'] ) stats['last_training_time'] = datetime.now() # Calculate accuracy and confidence accuracy = stats['successful_predictions'] / max(1, stats['total_predictions']) * 100 avg_confidence = sum(p.confidence for p in batch_predictions) / len(batch_predictions) # Create training update for emission training_update = TrainingUpdate( timestamp=datetime.now(), symbol=symbol, epoch=stats['total_training_steps'], loss=loss, batch_size=batch_size, learning_rate=self.optimizers[symbol].param_groups[0]['lr'], accuracy=accuracy, avg_confidence=avg_confidence ) # Emit training update to subscribers await self._emit_training_update(training_update) logger.debug(f"Training {symbol}: loss={loss:.6f}, batch_size={batch_size}") except Exception as e: logger.error(f"Error training model for {symbol}: {e}") def _calculate_rewards(self, symbol: str, predictions: List[PredictionResult]): """Calculate rewards for predictions based on actual price movements""" try: price_history = list(self.price_history[symbol]) if len(price_history) < 2: return for prediction in predictions: if prediction.reward is not None: continue # Already calculated # Find actual price change after prediction pred_time = prediction.timestamp # Look for price data after prediction (with reasonable timeout) future_prices = [ p for p in price_history if p['timestamp'] > pred_time and (p['timestamp'] - pred_time).total_seconds() <= 60 # 1 minute timeout ] if not future_prices: continue # Find price at prediction time past_prices = [ p for p in price_history if abs((p['timestamp'] - pred_time).total_seconds()) <= 10 # 10 second window ] if not past_prices: continue # Calculate actual price change pred_price = past_prices[-1]['price'] future_price = future_prices[0]['price'] # Use first future price actual_change = (future_price - pred_price) / pred_price # Determine actual direction if actual_change > 0.0005: # 0.05% threshold actual_direction = 2 # UP elif actual_change < -0.0005: actual_direction = 0 # DOWN else: actual_direction = 1 # SIDEWAYS # Calculate reward based on prediction accuracy prediction.reward = self._calculate_prediction_reward( symbol=symbol, predicted_direction=prediction.predicted_direction, actual_direction=actual_direction, confidence=prediction.confidence, predicted_change=prediction.predicted_change, actual_change=actual_change ) # Update training stats stats = self.training_stats[symbol] stats['total_predictions'] += 1 if prediction.reward > 0: stats['successful_predictions'] += 1 except Exception as e: logger.error(f"Error calculating rewards for {symbol}: {e}") def _calculate_prediction_reward(self, symbol: str, predicted_direction: int, actual_direction: int, confidence: float, predicted_change: float, actual_change: float, current_pnl: float = 0.0, position_duration: float = 0.0) -> float: """Calculate reward based on prediction accuracy and actual price movement""" reward = 0.0 # Base reward for correct direction prediction if predicted_direction == actual_direction: reward += 1.0 * confidence # Reward scales with confidence else: reward -= 0.5 # Penalize incorrect predictions # Reward for predicting large changes correctly (proportional to actual change) if predicted_direction == actual_direction and abs(predicted_change) > 0.001: reward += abs(actual_change) * 5.0 # Amplify reward for significant moves # Penalize for large predicted changes that are wrong if predicted_direction != actual_direction and abs(predicted_change) > 0.001: reward -= abs(predicted_change) * 2.0 # Add reward for PnL (realized or unrealized) reward += current_pnl * 0.1 # Small reward for PnL, adjusted by a factor # Dynamic adjustment based on recent PnL (loss cutting incentive) if self.pnl_history[symbol]: latest_pnl_entry = self.pnl_history[symbol][-1] # Get the latest PnL entry # Ensure latest_pnl_entry is a dict and has 'pnl' key, otherwise default to 0.0 latest_pnl_value = latest_pnl_entry.get('pnl', 0.0) if isinstance(latest_pnl_entry, dict) else 0.0 # Incentivize closing losing trades early if latest_pnl_value < 0 and position_duration > 60: # If losing position open for > 60s # More aggressively penalize holding losing positions, or reward closing them reward -= (abs(latest_pnl_value) * 0.2) # Increased penalty for sustained losses # Discourage taking new positions if overall PnL is negative or volatile # This requires a more complex calculation of overall PnL, potentially average of last N trades # For simplicity, let's use the 'best_pnl' to decide if we are in a good state to trade # Calculate the current best PnL from history, ensuring it's not empty pnl_values = [entry.get('pnl', 0.0) for entry in self.pnl_history[symbol] if isinstance(entry, dict)] if not pnl_values: best_pnl = 0.0 else: best_pnl = max(pnl_values) if best_pnl < 0.0: # If recent best PnL is negative, reduce reward for new trades reward -= 0.1 # Small penalty for trading in a losing streak return reward async def _train_batch(self, symbol: str, predictions: List[PredictionResult]) -> float: """Train model on a batch of predictions""" try: model = self.models[symbol] optimizer = self.optimizers[symbol] scaler = self.scalers[symbol] model.train() optimizer.zero_grad() # Prepare batch data features = torch.stack([ torch.from_numpy(p.features) for p in predictions ]).to(self.device) # Targets direction_targets = torch.tensor([ p.actual_direction for p in predictions ], dtype=torch.long).to(self.device) value_targets = torch.tensor([ p.reward for p in predictions ], dtype=torch.float32).to(self.device) # Forward pass with mixed precision with torch.cuda.amp.autocast(): outputs = model(features) # Calculate losses direction_loss = nn.CrossEntropyLoss()(outputs['price_logits'], direction_targets) value_loss = nn.MSELoss()(outputs['value'].squeeze(), value_targets) # Confidence loss (encourage high confidence for correct predictions) correct_predictions = (torch.argmax(outputs['price_logits'], dim=1) == direction_targets).float() confidence_loss = nn.BCELoss()(outputs['confidence'].squeeze(), correct_predictions) # Combined loss total_loss = direction_loss + 0.5 * value_loss + 0.3 * confidence_loss # Backward pass with gradient scaling scaler.scale(total_loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) scaler.step(optimizer) scaler.update() return total_loss.item() except Exception as e: logger.error(f"Error training batch for {symbol}: {e}") return 0.0 async def _train_on_trade_execution(self, symbol: str, signals: List[PredictionResult], action: str, price: float): """Train with higher weight when a trade is executed""" try: logger.info(f"Training on trade execution: {action} {symbol} at ${price:.2f}") # Wait a bit to see trade outcome await asyncio.sleep(30) # 30 seconds to see initial outcome # Calculate actual outcome current_prices = [p['price'] for p in list(self.price_history[symbol])[-5:]] if len(current_prices) >= 2: current_price = current_prices[-1] entry_price = price # Calculate P&L if action == 'BUY': pnl_ratio = (current_price - entry_price) / entry_price elif action == 'SELL': pnl_ratio = (entry_price - current_price) / entry_price else: pnl_ratio = 0.0 # Create enhanced reward for trade execution trade_reward = pnl_ratio * 10.0 # Amplify trade outcomes # Apply enhanced training weight to signals that led to trade for signal in signals: if signal.reward is None: signal.reward = trade_reward else: signal.reward += trade_reward # Add to existing reward logger.info(f"Trade outcome for {symbol}: P&L ratio={pnl_ratio:.4f}, " f"enhanced reward={trade_reward:.4f}") # Immediate training step with higher weight if len(signals) > 0: loss = await self._train_batch(symbol, signals[-3:]) # Train on last 3 signals logger.info(f"Enhanced training loss for {symbol}: {loss:.6f}") except Exception as e: logger.error(f"Error in trade execution training for {symbol}: {e}") async def _model_saving_loop(self): """Periodically save models""" logger.info("Starting model saving loop") while self.running: try: await asyncio.sleep(300) # Save every 5 minutes self._save_models() except Exception as e: logger.error(f"Error in model saving loop: {e}") await asyncio.sleep(60) def _save_models(self): """Save all models to disk using CheckpointManager""" try: for symbol in self.symbols: model_name = f"cob_rl_{symbol.replace('/', '_').lower()}" # Standardize model name for CheckpointManager # Prepare performance metrics for CheckpointManager performance_metrics = { 'loss': self.training_stats[symbol].get('average_loss', 0.0), 'reward': self.training_stats[symbol].get('average_reward', 0.0), # Assuming average_reward is tracked 'accuracy': self.training_stats[symbol].get('average_accuracy', 0.0), # Assuming average_accuracy is tracked } if self.trading_executor: # Add check for trading_executor daily_stats = self.trading_executor.get_daily_stats() performance_metrics['pnl'] = daily_stats.get('total_pnl', 0.0) # Example, get actual pnl performance_metrics['training_samples'] = self.training_stats[symbol].get('total_training_steps', 0) # Prepare training metadata for CheckpointManager training_metadata = { 'total_parameters': sum(p.numel() for p in self.models[symbol].parameters()), 'epoch': self.training_stats[symbol].get('total_training_steps', 0), # Using total_training_steps as pseudo-epoch 'training_time_hours': (datetime.now() - self.start_time).total_seconds() / 3600 } self.checkpoint_manager.save_checkpoint( model=self.models[symbol], model_name=model_name, model_type='COB_RL', # Specify model type performance_metrics=performance_metrics, training_metadata=training_metadata ) logger.debug(f"Saved model for {symbol}") except Exception as e: logger.error(f"Error saving models: {e}") def _load_models(self): """Load existing models from disk using CheckpointManager""" try: for symbol in self.symbols: model_name = f"cob_rl_{symbol.replace('/', '_').lower()}" # Standardize model name for CheckpointManager loaded_checkpoint = self.checkpoint_manager.load_best_checkpoint(model_name) if loaded_checkpoint: model_path, metadata = loaded_checkpoint checkpoint = torch.load(model_path, map_location=self.device) self.models[symbol].load_state_dict(checkpoint['model_state_dict']) self.optimizers[symbol].load_state_dict(checkpoint['optimizer_state_dict']) if 'training_stats' in checkpoint: self.training_stats[symbol].update(checkpoint['training_stats']) if 'inference_stats' in checkpoint: self.inference_stats[symbol].update(checkpoint['inference_stats']) logger.info(f"Loaded existing model for {symbol} from checkpoint: {metadata.checkpoint_id}") else: logger.info(f"No existing model found for {symbol} via CheckpointManager, starting fresh.") except Exception as e: logger.error(f"Error loading models: {e}") def get_performance_stats(self) -> Dict[str, Any]: """Get comprehensive performance statistics""" try: stats = { 'symbols': self.symbols, 'training_stats': self.training_stats.copy(), 'inference_stats': self.inference_stats.copy(), 'signal_stats': {}, 'model_info': {} } # Add signal accumulator stats for symbol in self.symbols: accumulator = self.signal_accumulators[symbol] stats['signal_stats'][symbol] = { 'current_signals': len(accumulator.signals), 'confidence_sum': accumulator.confidence_sum, 'total_predictions': accumulator.total_predictions, 'successful_predictions': accumulator.successful_predictions, 'success_rate': ( accumulator.successful_predictions / max(1, accumulator.total_predictions) ) } # Add model parameter info for symbol in self.symbols: model = self.models[symbol] total_params = sum(p.numel() for p in model.parameters()) stats['model_info'][symbol] = { 'total_parameters': total_params, 'trainable_parameters': sum(p.numel() for p in model.parameters() if p.requires_grad) } return stats except Exception as e: logger.error(f"Error getting performance stats: {e}") return {} # Example usage async def main(): """Example usage of RealtimeRLCOBTrader""" from ..core.trading_executor import TradingExecutor # Initialize trading executor (simulation mode) trading_executor = TradingExecutor() # Initialize real-time RL trader trader = RealtimeRLCOBTrader( symbols=['BTC/USDT', 'ETH/USDT'], trading_executor=trading_executor, inference_interval_ms=200, min_confidence_threshold=0.7, required_confident_predictions=3 ) try: # Start the trader await trader.start() # Run for demonstration logger.info("Real-time RL COB Trader running...") await asyncio.sleep(300) # Run for 5 minutes # Print performance stats stats = trader.get_performance_stats() logger.info(f"Performance stats: {json.dumps(stats, indent=2, default=str)}") finally: # Stop the trader await trader.stop() if __name__ == "__main__": logging.basicConfig(level=logging.INFO) asyncio.run(main())