#!/usr/bin/env python3
"""
Real-time Reinforcement Learning COB Trader
A sophisticated real-time RL system that:
1. Uses COB (Consolidated Order Book) data for training a 1B parameter RL model
2. Performs inference every 200ms or when new data comes
3. Predicts next price moves in real-time
4. Trains continuously based on prediction success
5. Accumulates signals based on confidence
6. Issues trade signals after 3 confident and successful predictions
7. Trains with higher weight when closing trades
Integrates with existing gogo2 trading system architecture.
"""
import asyncio
import json
import logging
import os
import time
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from threading import Lock
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
# Local imports
from .cob_integration import COBIntegration
from .trading_executor import TradingExecutor
logger = logging.getLogger(__name__)
@dataclass
class PredictionResult:
"""Result of a model prediction"""
timestamp: datetime
symbol: str
predicted_direction: int # 0=DOWN, 1=SIDEWAYS, 2=UP
confidence: float
predicted_change: float # Predicted price change %
features: np.ndarray
actual_direction: Optional[int] = None # Filled later for training
actual_change: Optional[float] = None # Filled later for training
reward: Optional[float] = None # Calculated reward for RL training
@dataclass
class SignalAccumulator:
"""Accumulates signals for trade decision making"""
symbol: str
signals: deque # Recent signals
confidence_sum: float = 0.0
successful_predictions: int = 0
total_predictions: int = 0
    last_reset_time: Optional[datetime] = None  # Set in __post_init__ if not provided
def __post_init__(self):
if self.signals is None:
self.signals = deque(maxlen=10)
if self.last_reset_time is None:
self.last_reset_time = datetime.now()
@dataclass
class TrainingUpdate:
"""Training update event data"""
timestamp: datetime
symbol: str
epoch: int
loss: float
batch_size: int
learning_rate: float
accuracy: float
avg_confidence: float
@dataclass
class TradeSignal:
"""Trade signal event data"""
timestamp: datetime
symbol: str
action: str # 'BUY', 'SELL', 'HOLD'
confidence: float
quantity: float
price: float
signals_count: int
reason: str
class MassiveRLNetwork(nn.Module):
"""
Massive 1B+ parameter RL network optimized for real-time COB trading
"""
def __init__(self, input_size: int = 2000, hidden_size: int = 4096, num_layers: int = 12):
super(MassiveRLNetwork, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
# Massive input processing layers
self.input_projection = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.LayerNorm(hidden_size),
nn.GELU(),
nn.Dropout(0.1)
)
# Massive transformer-style encoder layers
self.encoder_layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=hidden_size,
nhead=32, # Large number of attention heads
dim_feedforward=hidden_size * 4, # 16K feedforward
dropout=0.1,
activation='gelu',
batch_first=True
) for _ in range(num_layers)
])
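        # Rough sizing note (approximate): with d_model=4096 and dim_feedforward=16384,
        # each encoder layer holds roughly 200M parameters, so the 12-layer stack alone
        # is on the order of 2.4B - comfortably past the "1B+" target.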
# Market regime understanding layers
self.regime_encoder = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 2),
nn.LayerNorm(hidden_size * 2),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(hidden_size * 2, hidden_size),
nn.LayerNorm(hidden_size),
nn.GELU()
)
# Price prediction head (main RL objective)
self.price_head = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.LayerNorm(hidden_size // 2),
nn.GELU(),
nn.Dropout(0.2),
nn.Linear(hidden_size // 2, hidden_size // 4),
nn.LayerNorm(hidden_size // 4),
nn.GELU(),
nn.Linear(hidden_size // 4, 3) # DOWN, SIDEWAYS, UP
)
# Value estimation head for RL
self.value_head = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.LayerNorm(hidden_size // 2),
nn.GELU(),
nn.Dropout(0.2),
nn.Linear(hidden_size // 2, hidden_size // 4),
nn.LayerNorm(hidden_size // 4),
nn.GELU(),
nn.Linear(hidden_size // 4, 1)
)
# Confidence head
self.confidence_head = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 4),
nn.LayerNorm(hidden_size // 4),
nn.GELU(),
nn.Linear(hidden_size // 4, 1),
nn.Sigmoid()
)
# Initialize weights
self.apply(self._init_weights)
# Calculate total parameters
total_params = sum(p.numel() for p in self.parameters())
logger.info(f"Massive RL Network initialized with {total_params:,} parameters")
def _init_weights(self, module):
"""Initialize weights with proper scaling for large models"""
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.ones_(module.weight)
torch.nn.init.zeros_(module.bias)
def forward(self, x):
"""Forward pass through massive network"""
batch_size = x.size(0)
# Project input
x = self.input_projection(x) # [batch, hidden_size]
# Add sequence dimension for transformer
x = x.unsqueeze(1) # [batch, 1, hidden_size]
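        # With a sequence length of 1, self-attention reduces to a per-sample projection,
        # so the transformer stack effectively behaves as a deep residual MLP here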
# Pass through transformer layers
for layer in self.encoder_layers:
x = layer(x)
# Remove sequence dimension
x = x.squeeze(1) # [batch, hidden_size]
# Apply regime encoding
x = self.regime_encoder(x)
# Generate predictions
price_logits = self.price_head(x)
value = self.value_head(x)
confidence = self.confidence_head(x)
return {
'price_logits': price_logits,
'value': value,
'confidence': confidence,
'features': x # Hidden features for analysis
}
class RealtimeRLCOBTrader:
"""
Real-time RL trader using COB data with comprehensive subscriber system
"""
def __init__(self,
symbols: List[str] = None,
trading_executor: TradingExecutor = None,
model_checkpoint_dir: str = "models/realtime_rl_cob",
inference_interval_ms: int = 200,
min_confidence_threshold: float = 0.7,
required_confident_predictions: int = 3):
self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
self.trading_executor = trading_executor
self.model_checkpoint_dir = model_checkpoint_dir
self.inference_interval_ms = inference_interval_ms
self.min_confidence_threshold = min_confidence_threshold
self.required_confident_predictions = required_confident_predictions
# Setup device
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {self.device}")
# Initialize models for each symbol
self.models: Dict[str, MassiveRLNetwork] = {}
self.optimizers: Dict[str, optim.AdamW] = {}
self.scalers: Dict[str, torch.cuda.amp.GradScaler] = {}
for symbol in self.symbols:
model = MassiveRLNetwork().to(self.device)
self.models[symbol] = model
self.optimizers[symbol] = optim.AdamW(
model.parameters(),
lr=1e-5, # Low learning rate for stability
weight_decay=1e-6,
betas=(0.9, 0.999)
)
            # Disable grad scaling when running on CPU so autocast/scaler become no-ops
            self.scalers[symbol] = torch.cuda.amp.GradScaler(enabled=(self.device.type == 'cuda'))
# Subscriber system for real-time events
self.prediction_subscribers: List[Callable[[PredictionResult], None]] = []
self.training_subscribers: List[Callable[[TrainingUpdate], None]] = []
self.signal_subscribers: List[Callable[[TradeSignal], None]] = []
self.async_prediction_subscribers: List[Callable[[PredictionResult], Any]] = []
self.async_training_subscribers: List[Callable[[TrainingUpdate], Any]] = []
self.async_signal_subscribers: List[Callable[[TradeSignal], Any]] = []
# COB integration
self.cob_integration = COBIntegration(symbols=self.symbols)
self.cob_integration.add_dqn_callback(self._on_cob_update_sync)
# Data storage for real-time training
self.prediction_history: Dict[str, deque] = {}
self.feature_buffers: Dict[str, deque] = {}
self.price_history: Dict[str, deque] = {}
# Signal accumulation
self.signal_accumulators: Dict[str, SignalAccumulator] = {}
# Performance tracking
self.training_stats: Dict[str, Dict] = {}
self.inference_stats: Dict[str, Dict] = {}
# Initialize per symbol
for symbol in self.symbols:
self.prediction_history[symbol] = deque(maxlen=1000)
self.feature_buffers[symbol] = deque(maxlen=100)
self.price_history[symbol] = deque(maxlen=1000)
self.signal_accumulators[symbol] = SignalAccumulator(
symbol=symbol,
signals=deque(maxlen=self.required_confident_predictions * 2)
)
self.training_stats[symbol] = {
'total_predictions': 0,
'successful_predictions': 0,
'total_training_steps': 0,
'average_loss': 0.0,
'last_training_time': None
}
self.inference_stats[symbol] = {
'total_inferences': 0,
'average_inference_time_ms': 0.0,
'last_inference_time': None
}
# PnL tracking for loss cutting optimization
self.pnl_history: Dict[str, deque] = {
symbol: deque(maxlen=1000) for symbol in self.symbols
}
self.position_peak_pnl: Dict[str, float] = {symbol: 0.0 for symbol in self.symbols}
self.trade_history: Dict[str, List] = {symbol: [] for symbol in self.symbols}
# Threading
self.running = False
self.inference_lock = Lock()
self.training_lock = Lock()
# Create checkpoint directory
os.makedirs(self.model_checkpoint_dir, exist_ok=True)
logger.info(f"RealtimeRLCOBTrader initialized for symbols: {self.symbols}")
logger.info(f"Inference interval: {self.inference_interval_ms}ms")
logger.info(f"Required confident predictions: {self.required_confident_predictions}")
# Subscriber system methods
def add_prediction_subscriber(self, callback: Callable[[PredictionResult], None]):
"""Add a subscriber for prediction events"""
self.prediction_subscribers.append(callback)
logger.info(f"Added prediction subscriber, total: {len(self.prediction_subscribers)}")
def add_training_subscriber(self, callback: Callable[[TrainingUpdate], None]):
"""Add a subscriber for training events"""
self.training_subscribers.append(callback)
logger.info(f"Added training subscriber, total: {len(self.training_subscribers)}")
def add_signal_subscriber(self, callback: Callable[[TradeSignal], None]):
"""Add a subscriber for trade signal events"""
self.signal_subscribers.append(callback)
logger.info(f"Added signal subscriber, total: {len(self.signal_subscribers)}")
def add_async_prediction_subscriber(self, callback: Callable[[PredictionResult], Any]):
"""Add an async subscriber for prediction events"""
self.async_prediction_subscribers.append(callback)
logger.info(f"Added async prediction subscriber, total: {len(self.async_prediction_subscribers)}")
def add_async_training_subscriber(self, callback: Callable[[TrainingUpdate], Any]):
"""Add an async subscriber for training events"""
self.async_training_subscribers.append(callback)
logger.info(f"Added async training subscriber, total: {len(self.async_training_subscribers)}")
def add_async_signal_subscriber(self, callback: Callable[[TradeSignal], Any]):
"""Add an async subscriber for trade signal events"""
self.async_signal_subscribers.append(callback)
logger.info(f"Added async signal subscriber, total: {len(self.async_signal_subscribers)}")
async def _emit_prediction(self, prediction: PredictionResult):
"""Emit prediction to all subscribers"""
try:
# Sync subscribers
for callback in self.prediction_subscribers:
try:
callback(prediction)
except Exception as e:
logger.warning(f"Error in prediction subscriber: {e}")
# Async subscribers
for callback in self.async_prediction_subscribers:
try:
if asyncio.iscoroutinefunction(callback):
asyncio.create_task(callback(prediction))
else:
callback(prediction)
except Exception as e:
logger.warning(f"Error in async prediction subscriber: {e}")
except Exception as e:
logger.error(f"Error emitting prediction: {e}")
async def _emit_training_update(self, update: TrainingUpdate):
"""Emit training update to all subscribers"""
try:
# Sync subscribers
for callback in self.training_subscribers:
try:
callback(update)
except Exception as e:
logger.warning(f"Error in training subscriber: {e}")
# Async subscribers
for callback in self.async_training_subscribers:
try:
if asyncio.iscoroutinefunction(callback):
asyncio.create_task(callback(update))
else:
callback(update)
except Exception as e:
logger.warning(f"Error in async training subscriber: {e}")
except Exception as e:
logger.error(f"Error emitting training update: {e}")
async def _emit_trade_signal(self, signal: TradeSignal):
"""Emit trade signal to all subscribers"""
try:
# Sync subscribers
for callback in self.signal_subscribers:
try:
callback(signal)
except Exception as e:
logger.warning(f"Error in signal subscriber: {e}")
# Async subscribers
for callback in self.async_signal_subscribers:
try:
if asyncio.iscoroutinefunction(callback):
asyncio.create_task(callback(signal))
else:
callback(signal)
except Exception as e:
logger.warning(f"Error in async signal subscriber: {e}")
except Exception as e:
logger.error(f"Error emitting trade signal: {e}")
def _on_cob_update_sync(self, symbol: str, data: Dict):
"""Sync wrapper for async COB update handler"""
try:
# Schedule the async method
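            # NOTE: asyncio.create_task requires a running event loop in the calling
            # thread; if the COB callback ever fires from a worker thread this raises
            # RuntimeError, which is caught and logged below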
asyncio.create_task(self._on_cob_update(symbol, data))
except Exception as e:
logger.error(f"Error scheduling COB update for {symbol}: {e}")
async def start(self):
"""Start the real-time RL trader"""
logger.info("Starting Real-time RL COB Trader")
self.running = True
# Load existing models if available
self._load_models()
# Start COB integration
await self.cob_integration.start()
# Start inference loop
asyncio.create_task(self._inference_loop())
# Start training loop
asyncio.create_task(self._training_loop())
# Start signal processing loop
asyncio.create_task(self._signal_processing_loop())
# Start model saving loop
asyncio.create_task(self._model_saving_loop())
logger.info("Real-time RL COB Trader started successfully")
async def stop(self):
"""Stop the real-time RL trader"""
logger.info("Stopping Real-time RL COB Trader")
self.running = False
# Save models
self._save_models()
# Stop COB integration
await self.cob_integration.stop()
logger.info("Real-time RL COB Trader stopped")
async def _on_cob_update(self, symbol: str, data: Dict):
"""Handle COB updates for real-time inference"""
try:
if symbol not in self.symbols:
return
# Extract features from COB data
features = self._extract_features(symbol, data)
if features is None:
return
# Store in buffer
self.feature_buffers[symbol].append({
'timestamp': datetime.now(),
'features': features,
'raw_data': data
})
# Store price for later reward calculation
if 'state' in data:
price = self._extract_price_from_state(data['state'])
if price > 0:
self.price_history[symbol].append({
'timestamp': datetime.now(),
'price': price
})
except Exception as e:
logger.error(f"Error handling COB update for {symbol}: {e}")
def _extract_features(self, symbol: str, data: Dict) -> Optional[np.ndarray]:
"""Extract features from COB data for model input"""
try:
# Get state from COB data
if 'state' not in data:
return None
state = data['state']
# Ensure we have the right feature size (2000 features)
if isinstance(state, np.ndarray):
features = state.flatten()
else:
features = np.array(state).flatten()
# Pad or truncate to exact size
target_size = 2000
if len(features) < target_size:
# Pad with zeros
padded = np.zeros(target_size)
padded[:len(features)] = features
features = padded
elif len(features) > target_size:
# Truncate
features = features[:target_size]
# Normalize features
features = self._normalize_features(features)
return features
except Exception as e:
logger.error(f"Error extracting features for {symbol}: {e}")
return None
def _normalize_features(self, features: np.ndarray) -> np.ndarray:
"""Normalize features for model input"""
try:
# Clip extreme values
features = np.clip(features, -10.0, 10.0)
# Z-score normalization with robust statistics
median = np.median(features)
mad = np.median(np.abs(features - median))
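            # 1.4826 scales the MAD so it estimates the standard deviation under normality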
if mad > 1e-6:
features = (features - median) / (mad * 1.4826)
# Final clipping
features = np.clip(features, -5.0, 5.0)
return features.astype(np.float32)
except Exception as e:
logger.warning(f"Error normalizing features: {e}")
return features.astype(np.float32)
def _extract_price_from_state(self, state) -> float:
"""Extract current price from state data"""
try:
# Try different ways to extract price
if isinstance(state, np.ndarray) and len(state) > 0:
# Assume first few elements might contain price info
return float(state[0])
elif isinstance(state, (list, tuple)) and len(state) > 0:
return float(state[0])
else:
return 0.0
except:
return 0.0
async def _inference_loop(self):
"""Main inference loop - runs every 200ms or when new data arrives"""
logger.info("Starting inference loop")
while self.running:
try:
start_time = time.time()
# Run inference for all symbols
for symbol in self.symbols:
await self._run_inference(symbol)
# Calculate sleep time to maintain interval
elapsed_ms = (time.time() - start_time) * 1000
sleep_ms = max(0, self.inference_interval_ms - elapsed_ms)
if sleep_ms > 0:
await asyncio.sleep(sleep_ms / 1000)
except Exception as e:
logger.error(f"Error in inference loop: {e}")
await asyncio.sleep(0.1)
async def _run_inference(self, symbol: str):
"""Run inference for a specific symbol"""
try:
with self.inference_lock:
# Check if we have recent features
if not self.feature_buffers[symbol]:
return
# Get latest features
latest_data = self.feature_buffers[symbol][-1]
features = latest_data['features']
timestamp = latest_data['timestamp']
# Run model inference
start_time = time.time()
prediction = self._predict(symbol, features)
inference_time_ms = (time.time() - start_time) * 1000
# Update inference stats
stats = self.inference_stats[symbol]
stats['total_inferences'] += 1
stats['average_inference_time_ms'] = (
(stats['average_inference_time_ms'] * (stats['total_inferences'] - 1) + inference_time_ms)
/ stats['total_inferences']
)
stats['last_inference_time'] = timestamp
# Create prediction result
result = PredictionResult(
timestamp=timestamp,
symbol=symbol,
predicted_direction=prediction['direction'],
confidence=prediction['confidence'],
predicted_change=prediction['change'],
features=features
)
# Store prediction for later training
self.prediction_history[symbol].append(result)
# Emit prediction to subscribers
await self._emit_prediction(result)
# Add to signal accumulator if confident enough
if prediction['confidence'] >= self.min_confidence_threshold:
self._add_signal(symbol, result)
logger.debug(f"Inference {symbol}: direction={prediction['direction']}, "
f"confidence={prediction['confidence']:.3f}, "
f"change={prediction['change']:.4f}, "
f"time={inference_time_ms:.1f}ms")
except Exception as e:
logger.error(f"Error running inference for {symbol}: {e}")
def _predict(self, symbol: str, features: np.ndarray) -> Dict:
"""Run model prediction"""
try:
model = self.models[symbol]
model.eval()
# Convert to tensor
features_tensor = torch.from_numpy(features).unsqueeze(0).to(self.device)
with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=(self.device.type == 'cuda')):
outputs = model(features_tensor)
# Extract predictions
price_probs = torch.softmax(outputs['price_logits'], dim=1)
direction = torch.argmax(price_probs, dim=1).item()
confidence = outputs['confidence'].item()
value = outputs['value'].item()
# Calculate predicted change based on direction and confidence
if direction == 2: # UP
predicted_change = confidence * 0.001 # Max 0.1% up
elif direction == 0: # DOWN
predicted_change = -confidence * 0.001 # Max 0.1% down
else: # SIDEWAYS
predicted_change = 0.0
return {
'direction': direction,
'confidence': confidence,
'change': predicted_change,
'value': value
}
except Exception as e:
logger.error(f"Error in prediction for {symbol}: {e}")
return {
'direction': 1, # SIDEWAYS
'confidence': 0.0,
'change': 0.0,
'value': 0.0
}
def _add_signal(self, symbol: str, prediction: PredictionResult):
"""Add confident prediction to signal accumulator"""
try:
accumulator = self.signal_accumulators[symbol]
accumulator.signals.append(prediction)
accumulator.confidence_sum += prediction.confidence
accumulator.total_predictions += 1
logger.debug(f"Added signal for {symbol}: {len(accumulator.signals)} total signals")
except Exception as e:
logger.error(f"Error adding signal for {symbol}: {e}")
async def _signal_processing_loop(self):
"""Process accumulated signals and generate trade decisions"""
logger.info("Starting signal processing loop")
while self.running:
try:
for symbol in self.symbols:
await self._process_signals(symbol)
await asyncio.sleep(0.1) # Process signals every 100ms
except Exception as e:
logger.error(f"Error in signal processing loop: {e}")
await asyncio.sleep(1)
async def _process_signals(self, symbol: str):
"""Process signals for a specific symbol and make trade decisions"""
try:
accumulator = self.signal_accumulators[symbol]
# Check if we have enough confident predictions
if len(accumulator.signals) < self.required_confident_predictions:
return
# Get recent signals
recent_signals = list(accumulator.signals)[-self.required_confident_predictions:]
# Check if all recent signals are in the same direction
directions = [signal.predicted_direction for signal in recent_signals]
confidences = [signal.confidence for signal in recent_signals]
# Count direction consensus
direction_counts = {0: 0, 1: 0, 2: 0} # DOWN, SIDEWAYS, UP
for direction in directions:
direction_counts[direction] += 1
# Find dominant direction
dominant_direction = max(direction_counts, key=direction_counts.get)
consensus_count = direction_counts[dominant_direction]
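            # Only the last `required_confident_predictions` signals are inspected here,
            # so reaching the threshold below means every one of them agrees on direction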
# Check if we have enough consensus
if consensus_count >= self.required_confident_predictions and dominant_direction != 1:
# We have consensus for action (not sideways)
avg_confidence = np.mean(confidences)
# Determine action
if dominant_direction == 2: # UP
action = 'BUY'
elif dominant_direction == 0: # DOWN
action = 'SELL'
else:
return # No action for sideways
# Execute trade signal
await self._execute_trade_signal(symbol, action, float(avg_confidence), recent_signals)
# Reset accumulator after trade signal
self._reset_accumulator(symbol)
except Exception as e:
logger.error(f"Error processing signals for {symbol}: {e}")
async def _execute_trade_signal(self, symbol: str, action: str, confidence: float, signals: List[PredictionResult]):
"""Execute a trade signal"""
try:
logger.info(f"Executing trade signal: {action} {symbol} with confidence {confidence:.3f}")
# Get current price
current_price = 0.0
if self.price_history[symbol]:
current_price = self.price_history[symbol][-1]['price']
# Create trade signal for emission
trade_signal = TradeSignal(
timestamp=datetime.now(),
symbol=symbol,
action=action,
confidence=confidence,
quantity=1.0, # Default quantity
price=current_price,
signals_count=len(signals),
reason=f"Consensus of {len(signals)} predictions"
)
# Emit trade signal to subscribers
await self._emit_trade_signal(trade_signal)
# Execute through trading executor if available
if self.trading_executor and current_price > 0:
success = self.trading_executor.execute_signal(
symbol=symbol,
action=action,
confidence=confidence,
current_price=current_price
)
if success:
logger.info(f"Trade executed successfully: {action} {symbol}")
# Schedule training with higher weight for trade closure
asyncio.create_task(self._train_on_trade_execution(symbol, signals, action, current_price))
else:
logger.warning(f"Trade execution failed: {action} {symbol}")
else:
logger.info(f"No trading executor available or price unknown for {symbol}")
except Exception as e:
logger.error(f"Error executing trade signal for {symbol}: {e}")
def _reset_accumulator(self, symbol: str):
"""Reset signal accumulator after trade execution"""
try:
accumulator = self.signal_accumulators[symbol]
accumulator.signals.clear()
accumulator.confidence_sum = 0.0
accumulator.last_reset_time = datetime.now()
logger.debug(f"Reset signal accumulator for {symbol}")
except Exception as e:
logger.error(f"Error resetting accumulator for {symbol}: {e}")
async def _training_loop(self):
"""Main training loop for real-time model updates"""
logger.info("Starting training loop")
while self.running:
try:
for symbol in self.symbols:
await self._train_symbol_model(symbol)
await asyncio.sleep(1.0) # Train every second
except Exception as e:
logger.error(f"Error in training loop: {e}")
await asyncio.sleep(5)
async def _train_symbol_model(self, symbol: str):
"""Train model for a specific symbol using recent predictions"""
try:
with self.training_lock:
# Check if we have enough data for training
predictions = list(self.prediction_history[symbol])
if len(predictions) < 10:
return
# Calculate rewards for recent predictions
self._calculate_rewards(symbol, predictions)
# Filter predictions with calculated rewards
training_predictions = [p for p in predictions if p.reward is not None]
if len(training_predictions) < 5:
return
# Prepare training batch
batch_size = min(32, len(training_predictions))
batch_predictions = training_predictions[-batch_size:]
# Train model
loss = await self._train_batch(symbol, batch_predictions)
# Update training stats
stats = self.training_stats[symbol]
stats['total_training_steps'] += 1
stats['average_loss'] = (
(stats['average_loss'] * (stats['total_training_steps'] - 1) + loss)
/ stats['total_training_steps']
)
stats['last_training_time'] = datetime.now()
# Calculate accuracy and confidence
accuracy = stats['successful_predictions'] / max(1, stats['total_predictions']) * 100
avg_confidence = sum(p.confidence for p in batch_predictions) / len(batch_predictions)
# Create training update for emission
training_update = TrainingUpdate(
timestamp=datetime.now(),
symbol=symbol,
epoch=stats['total_training_steps'],
loss=loss,
batch_size=batch_size,
learning_rate=self.optimizers[symbol].param_groups[0]['lr'],
accuracy=accuracy,
avg_confidence=avg_confidence
)
# Emit training update to subscribers
await self._emit_training_update(training_update)
logger.debug(f"Training {symbol}: loss={loss:.6f}, batch_size={batch_size}")
except Exception as e:
logger.error(f"Error training model for {symbol}: {e}")
def _calculate_rewards(self, symbol: str, predictions: List[PredictionResult]):
"""Calculate rewards for predictions based on actual price movements"""
try:
price_history = list(self.price_history[symbol])
if len(price_history) < 2:
return
for prediction in predictions:
if prediction.reward is not None:
continue # Already calculated
# Find actual price change after prediction
pred_time = prediction.timestamp
# Look for price data after prediction (with reasonable timeout)
future_prices = [
p for p in price_history
if p['timestamp'] > pred_time and
(p['timestamp'] - pred_time).total_seconds() <= 60 # 1 minute timeout
]
if not future_prices:
continue
                # Find the reference price at (or just before) the prediction time
                past_prices = [
                    p for p in price_history
                    if 0 <= (pred_time - p['timestamp']).total_seconds() <= 10  # 10 second lookback
                ]
if not past_prices:
continue
# Calculate actual price change
pred_price = past_prices[-1]['price']
future_price = future_prices[0]['price'] # Use first future price
actual_change = (future_price - pred_price) / pred_price
# Determine actual direction
if actual_change > 0.0005: # 0.05% threshold
actual_direction = 2 # UP
elif actual_change < -0.0005:
actual_direction = 0 # DOWN
else:
actual_direction = 1 # SIDEWAYS
# Calculate reward based on prediction accuracy
reward = self._calculate_prediction_reward(
prediction.predicted_direction,
actual_direction,
prediction.confidence,
prediction.predicted_change,
actual_change
)
# Update prediction
prediction.actual_direction = actual_direction
prediction.actual_change = actual_change
prediction.reward = reward
# Update training stats
stats = self.training_stats[symbol]
stats['total_predictions'] += 1
if reward > 0:
stats['successful_predictions'] += 1
except Exception as e:
logger.error(f"Error calculating rewards for {symbol}: {e}")
def _calculate_prediction_reward(self,
predicted_direction: int,
actual_direction: int,
confidence: float,
predicted_change: float,
actual_change: float,
current_pnl: float = 0.0,
position_duration: float = 0.0) -> float:
"""Calculate reward for a prediction with PnL-aware loss cutting optimization"""
try:
# Base reward for correct direction
if predicted_direction == actual_direction:
base_reward = 1.0
else:
base_reward = -1.0
# Scale by confidence
confidence_scaled_reward = base_reward * confidence
# Additional reward for magnitude accuracy
if predicted_direction != 1: # Not sideways
magnitude_accuracy = 1.0 - abs(predicted_change - actual_change) / max(abs(actual_change), 0.001)
magnitude_accuracy = max(0.0, magnitude_accuracy)
confidence_scaled_reward += magnitude_accuracy * 0.5
# Penalty for overconfident wrong predictions
if base_reward < 0 and confidence > 0.8:
confidence_scaled_reward *= 1.5 # Increase penalty
# === PnL-AWARE LOSS CUTTING REWARDS ===
pnl_reward = 0.0
# Reward cutting losses early (SIDEWAYS when losing)
if current_pnl < -10.0: # In significant loss
if predicted_direction == 1: # SIDEWAYS (exit signal)
# Reward cutting losses before they get worse
loss_cutting_bonus = min(1.0, abs(current_pnl) / 100.0) * confidence
pnl_reward += loss_cutting_bonus
elif predicted_direction != 1: # Continuing to trade while in loss
# Penalty for not cutting losses
pnl_reward -= 0.5 * confidence
# Reward protecting profits (SIDEWAYS when in profit and market turning)
elif current_pnl > 10.0: # In profit
if predicted_direction == 1 and base_reward > 0: # Correct SIDEWAYS prediction
# Reward protecting profits from reversal
profit_protection_bonus = min(0.5, current_pnl / 200.0) * confidence
pnl_reward += profit_protection_bonus
# Duration penalty for holding losing positions
if current_pnl < 0 and position_duration > 3600: # Losing for > 1 hour
duration_penalty = min(1.0, position_duration / 7200.0) * 0.3 # Up to 30% penalty
confidence_scaled_reward -= duration_penalty
# Severe penalty for letting small losses become big losses
if current_pnl < -50.0: # Large loss
drawdown_penalty = min(2.0, abs(current_pnl) / 100.0) * confidence
confidence_scaled_reward -= drawdown_penalty
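            # Worked example (approximate): a correct UP call with confidence 0.9 and a
            # magnitude error of 20% of the actual move scores 0.9 + 0.8 * 0.5 = 1.3 when
            # flat; the same non-exit call while sitting on a -$60 PnL also takes the
            # 0.5 * 0.9 loss-cutting penalty and the 0.6 * 0.9 drawdown penalty (~0.31 total)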
# Total reward
total_reward = confidence_scaled_reward + pnl_reward
# Clamp final reward
return max(-5.0, min(5.0, float(total_reward)))
except Exception as e:
logger.error(f"Error calculating reward: {e}")
return 0.0
async def _train_batch(self, symbol: str, predictions: List[PredictionResult]) -> float:
"""Train model on a batch of predictions"""
        try:
            # Skip predictions whose outcome has not been scored yet; otherwise the
            # direction/value targets below would contain None values
            predictions = [p for p in predictions
                           if p.actual_direction is not None and p.reward is not None]
            if not predictions:
                return 0.0
            model = self.models[symbol]
            optimizer = self.optimizers[symbol]
            scaler = self.scalers[symbol]
model.train()
optimizer.zero_grad()
# Prepare batch data
features = torch.stack([
torch.from_numpy(p.features) for p in predictions
]).to(self.device)
# Targets
direction_targets = torch.tensor([
p.actual_direction for p in predictions
], dtype=torch.long).to(self.device)
value_targets = torch.tensor([
p.reward for p in predictions
], dtype=torch.float32).to(self.device)
# Forward pass with mixed precision
            with torch.cuda.amp.autocast(enabled=(self.device.type == 'cuda')):
outputs = model(features)
# Calculate losses
direction_loss = nn.CrossEntropyLoss()(outputs['price_logits'], direction_targets)
                value_loss = nn.MSELoss()(outputs['value'].squeeze(-1), value_targets)
                # Confidence loss (encourage high confidence for correct predictions)
                correct_predictions = (torch.argmax(outputs['price_logits'], dim=1) == direction_targets).float()
                confidence_loss = nn.BCELoss()(outputs['confidence'].squeeze(-1), correct_predictions)
# Combined loss
total_loss = direction_loss + 0.5 * value_loss + 0.3 * confidence_loss
# Backward pass with gradient scaling
scaler.scale(total_loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
return total_loss.item()
except Exception as e:
logger.error(f"Error training batch for {symbol}: {e}")
return 0.0
async def _train_on_trade_execution(self, symbol: str, signals: List[PredictionResult],
action: str, price: float):
"""Train with higher weight when a trade is executed"""
try:
logger.info(f"Training on trade execution: {action} {symbol} at ${price:.2f}")
# Wait a bit to see trade outcome
await asyncio.sleep(30) # 30 seconds to see initial outcome
# Calculate actual outcome
current_prices = [p['price'] for p in list(self.price_history[symbol])[-5:]]
if len(current_prices) >= 2:
current_price = current_prices[-1]
entry_price = price
# Calculate P&L
if action == 'BUY':
pnl_ratio = (current_price - entry_price) / entry_price
elif action == 'SELL':
pnl_ratio = (entry_price - current_price) / entry_price
else:
pnl_ratio = 0.0
# Create enhanced reward for trade execution
trade_reward = pnl_ratio * 10.0 # Amplify trade outcomes
# Apply enhanced training weight to signals that led to trade
for signal in signals:
if signal.reward is None:
signal.reward = trade_reward
else:
signal.reward += trade_reward # Add to existing reward
logger.info(f"Trade outcome for {symbol}: P&L ratio={pnl_ratio:.4f}, "
f"enhanced reward={trade_reward:.4f}")
# Immediate training step with higher weight
if len(signals) > 0:
loss = await self._train_batch(symbol, signals[-3:]) # Train on last 3 signals
logger.info(f"Enhanced training loss for {symbol}: {loss:.6f}")
except Exception as e:
logger.error(f"Error in trade execution training for {symbol}: {e}")
async def _model_saving_loop(self):
"""Periodically save models"""
logger.info("Starting model saving loop")
while self.running:
try:
await asyncio.sleep(300) # Save every 5 minutes
self._save_models()
except Exception as e:
logger.error(f"Error in model saving loop: {e}")
await asyncio.sleep(60)
def _save_models(self):
"""Save all models to disk"""
try:
for symbol in self.symbols:
symbol_safe = symbol.replace('/', '_')
model_path = os.path.join(self.model_checkpoint_dir, f"{symbol_safe}_model.pt")
# Save model state
torch.save({
'model_state_dict': self.models[symbol].state_dict(),
'optimizer_state_dict': self.optimizers[symbol].state_dict(),
'training_stats': self.training_stats[symbol],
'inference_stats': self.inference_stats[symbol],
'timestamp': datetime.now().isoformat()
}, model_path)
logger.debug(f"Saved model for {symbol}")
except Exception as e:
logger.error(f"Error saving models: {e}")
def _load_models(self):
"""Load existing models from disk"""
try:
for symbol in self.symbols:
symbol_safe = symbol.replace('/', '_')
model_path = os.path.join(self.model_checkpoint_dir, f"{symbol_safe}_model.pt")
if os.path.exists(model_path):
checkpoint = torch.load(model_path, map_location=self.device)
self.models[symbol].load_state_dict(checkpoint['model_state_dict'])
self.optimizers[symbol].load_state_dict(checkpoint['optimizer_state_dict'])
if 'training_stats' in checkpoint:
self.training_stats[symbol].update(checkpoint['training_stats'])
if 'inference_stats' in checkpoint:
self.inference_stats[symbol].update(checkpoint['inference_stats'])
logger.info(f"Loaded existing model for {symbol}")
else:
logger.info(f"No existing model found for {symbol}, starting fresh")
except Exception as e:
logger.error(f"Error loading models: {e}")
def get_performance_stats(self) -> Dict[str, Any]:
"""Get comprehensive performance statistics"""
try:
stats = {
'symbols': self.symbols,
'training_stats': self.training_stats.copy(),
'inference_stats': self.inference_stats.copy(),
'signal_stats': {},
'model_info': {}
}
# Add signal accumulator stats
for symbol in self.symbols:
accumulator = self.signal_accumulators[symbol]
stats['signal_stats'][symbol] = {
'current_signals': len(accumulator.signals),
'confidence_sum': accumulator.confidence_sum,
'total_predictions': accumulator.total_predictions,
'successful_predictions': accumulator.successful_predictions,
'success_rate': (
accumulator.successful_predictions / max(1, accumulator.total_predictions)
)
}
# Add model parameter info
for symbol in self.symbols:
model = self.models[symbol]
total_params = sum(p.numel() for p in model.parameters())
stats['model_info'][symbol] = {
'total_parameters': total_params,
'trainable_parameters': sum(p.numel() for p in model.parameters() if p.requires_grad)
}
return stats
except Exception as e:
logger.error(f"Error getting performance stats: {e}")
return {}
# Example usage
async def main():
"""Example usage of RealtimeRLCOBTrader"""
# Initialize trading executor (simulation mode)
trading_executor = TradingExecutor(simulation_mode=True)
# Initialize real-time RL trader
trader = RealtimeRLCOBTrader(
symbols=['BTC/USDT', 'ETH/USDT'],
trading_executor=trading_executor,
inference_interval_ms=200,
min_confidence_threshold=0.7,
required_confident_predictions=3
)
try:
# Start the trader
await trader.start()
# Run for demonstration
logger.info("Real-time RL COB Trader running...")
await asyncio.sleep(300) # Run for 5 minutes
# Print performance stats
stats = trader.get_performance_stats()
logger.info(f"Performance stats: {json.dumps(stats, indent=2, default=str)}")
finally:
# Stop the trader
await trader.stop()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
asyncio.run(main())