#!/usr/bin/env python3
"""
Real-time Reinforcement Learning COB Trader

A sophisticated real-time RL system that:
1. Uses COB (Consolidated Order Book) data to train a 1B-parameter RL model
2. Runs inference every 200ms or whenever new data arrives
3. Predicts the next price move in real time
4. Trains continuously based on prediction success
5. Accumulates signals based on confidence
6. Issues a trade signal after 3 confident, successful predictions
7. Trains with higher weight when closing trades

Integrates with the existing gogo2 trading system architecture.
"""

import asyncio
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Callable, Tuple
from collections import deque, defaultdict
from dataclasses import dataclass, asdict
import json
import time
import threading
from threading import Lock
import pickle
import os

# Local imports
from .cob_integration import COBIntegration
from .trading_executor import TradingExecutor
from NN.models.cob_rl_model import MassiveRLNetwork, COBRLModelInterface

logger = logging.getLogger(__name__)

@dataclass
class PredictionResult:
    """Result of a model prediction"""
    timestamp: datetime
    symbol: str
    predicted_direction: int  # 0=DOWN, 1=SIDEWAYS, 2=UP
    confidence: float
    predicted_change: float  # Predicted price change %
    features: np.ndarray
    actual_direction: Optional[int] = None  # Filled later for training
    actual_change: Optional[float] = None  # Filled later for training
    reward: Optional[float] = None  # Calculated reward for RL training


@dataclass
class SignalAccumulator:
    """Accumulates signals for trade decision making"""
    symbol: str
    signals: deque  # Recent signals
    confidence_sum: float = 0.0
    successful_predictions: int = 0
    total_predictions: int = 0
    last_reset_time: Optional[datetime] = None

    def __post_init__(self):
        if self.signals is None:
            self.signals = deque(maxlen=10)
        if self.last_reset_time is None:
            self.last_reset_time = datetime.now()


@dataclass
class TrainingUpdate:
    """Training update event data"""
    timestamp: datetime
    symbol: str
    epoch: int
    loss: float
    batch_size: int
    learning_rate: float
    accuracy: float
    avg_confidence: float


@dataclass
class TradeSignal:
    """Trade signal event data"""
    timestamp: datetime
    symbol: str
    action: str  # 'BUY', 'SELL', 'HOLD'
    confidence: float
    quantity: float
    price: float
    signals_count: int
    reason: str


# MassiveRLNetwork is now imported from NN.models.cob_rl_model

class RealtimeRLCOBTrader:
    """
    Real-time RL trader using COB data with comprehensive subscriber system
    """

    def __init__(self,
                 symbols: Optional[List[str]] = None,
                 trading_executor: Optional[TradingExecutor] = None,
                 model_checkpoint_dir: str = "models/realtime_rl_cob",
                 inference_interval_ms: int = 200,
                 min_confidence_threshold: float = 0.35,  # Lowered from 0.7 for more aggressive trading
                 required_confident_predictions: int = 3,
                 checkpoint_manager: Any = None):

        self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
        self.trading_executor = trading_executor
        self.model_checkpoint_dir = model_checkpoint_dir
        self.inference_interval_ms = inference_interval_ms
        self.min_confidence_threshold = min_confidence_threshold
        self.required_confident_predictions = required_confident_predictions

        # Initialize CheckpointManager (either provided or get global instance)
        if checkpoint_manager is None:
            from utils.checkpoint_manager import get_checkpoint_manager
            self.checkpoint_manager = get_checkpoint_manager()
        else:
            self.checkpoint_manager = checkpoint_manager

        # Track start time for training duration calculation
        self.start_time = datetime.now()

        # Setup device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        # Initialize models for each symbol
        self.models: Dict[str, MassiveRLNetwork] = {}
        self.optimizers: Dict[str, optim.AdamW] = {}
        self.scalers: Dict[str, torch.cuda.amp.GradScaler] = {}

        for symbol in self.symbols:
            model = MassiveRLNetwork().to(self.device)
            self.models[symbol] = model
            self.optimizers[symbol] = optim.AdamW(
                model.parameters(),
                lr=1e-5,  # Low learning rate for stability
                weight_decay=1e-6,
                betas=(0.9, 0.999)
            )
            self.scalers[symbol] = torch.cuda.amp.GradScaler()

        # Subscriber system for real-time events
        self.prediction_subscribers: List[Callable[[PredictionResult], None]] = []
        self.training_subscribers: List[Callable[[TrainingUpdate], None]] = []
        self.signal_subscribers: List[Callable[[TradeSignal], None]] = []
        self.async_prediction_subscribers: List[Callable[[PredictionResult], Any]] = []
        self.async_training_subscribers: List[Callable[[TrainingUpdate], Any]] = []
        self.async_signal_subscribers: List[Callable[[TradeSignal], Any]] = []

        # COB integration
        self.cob_integration = COBIntegration(symbols=self.symbols)
        self.cob_integration.add_dqn_callback(self._on_cob_update_sync)

        # Data storage for real-time training
        self.prediction_history: Dict[str, deque] = {}
        self.feature_buffers: Dict[str, deque] = {}
        self.price_history: Dict[str, deque] = {}

        # Signal accumulation
        self.signal_accumulators: Dict[str, SignalAccumulator] = {}

        # Performance tracking
        self.training_stats: Dict[str, Dict] = {}
        self.inference_stats: Dict[str, Dict] = {}

        # Initialize per symbol
        for symbol in self.symbols:
            self.prediction_history[symbol] = deque(maxlen=1000)
            self.feature_buffers[symbol] = deque(maxlen=100)
            self.price_history[symbol] = deque(maxlen=1000)
            self.signal_accumulators[symbol] = SignalAccumulator(
                symbol=symbol,
                signals=deque(maxlen=self.required_confident_predictions * 2)
            )
            self.training_stats[symbol] = {
                'total_predictions': 0,
                'successful_predictions': 0,
                'total_training_steps': 0,
                'average_loss': 0.0,
                'last_training_time': None
            }
            self.inference_stats[symbol] = {
                'total_inferences': 0,
                'average_inference_time_ms': 0.0,
                'last_inference_time': None
            }

        # PnL tracking for loss cutting optimization
        self.pnl_history: Dict[str, deque] = {
            symbol: deque(maxlen=1000) for symbol in self.symbols
        }
        self.position_peak_pnl: Dict[str, float] = {symbol: 0.0 for symbol in self.symbols}
        self.trade_history: Dict[str, List] = {symbol: [] for symbol in self.symbols}

        # Threading
        self.running = False
        self.inference_lock = Lock()
        self.training_lock = Lock()

        # Create checkpoint directory
        os.makedirs(self.model_checkpoint_dir, exist_ok=True)

        logger.info(f"RealtimeRLCOBTrader initialized for symbols: {self.symbols}")
        logger.info(f"Inference interval: {self.inference_interval_ms}ms")
        logger.info(f"Required confident predictions: {self.required_confident_predictions}")

    # Subscriber system methods
    def add_prediction_subscriber(self, callback: Callable[[PredictionResult], None]):
        """Add a subscriber for prediction events"""
        self.prediction_subscribers.append(callback)
        logger.info(f"Added prediction subscriber, total: {len(self.prediction_subscribers)}")

    def add_training_subscriber(self, callback: Callable[[TrainingUpdate], None]):
        """Add a subscriber for training events"""
        self.training_subscribers.append(callback)
        logger.info(f"Added training subscriber, total: {len(self.training_subscribers)}")

    def add_signal_subscriber(self, callback: Callable[[TradeSignal], None]):
        """Add a subscriber for trade signal events"""
        self.signal_subscribers.append(callback)
        logger.info(f"Added signal subscriber, total: {len(self.signal_subscribers)}")

    def add_async_prediction_subscriber(self, callback: Callable[[PredictionResult], Any]):
        """Add an async subscriber for prediction events"""
        self.async_prediction_subscribers.append(callback)
        logger.info(f"Added async prediction subscriber, total: {len(self.async_prediction_subscribers)}")

    def add_async_training_subscriber(self, callback: Callable[[TrainingUpdate], Any]):
        """Add an async subscriber for training events"""
        self.async_training_subscribers.append(callback)
        logger.info(f"Added async training subscriber, total: {len(self.async_training_subscribers)}")

    def add_async_signal_subscriber(self, callback: Callable[[TradeSignal], Any]):
        """Add an async subscriber for trade signal events"""
        self.async_signal_subscribers.append(callback)
        logger.info(f"Added async signal subscriber, total: {len(self.async_signal_subscribers)}")

    async def _emit_prediction(self, prediction: PredictionResult):
        """Emit prediction to all subscribers"""
        try:
            # Sync subscribers
            for callback in self.prediction_subscribers:
                try:
                    callback(prediction)
                except Exception as e:
                    logger.warning(f"Error in prediction subscriber: {e}")

            # Async subscribers
            for callback in self.async_prediction_subscribers:
                try:
                    if asyncio.iscoroutinefunction(callback):
                        asyncio.create_task(callback(prediction))
                    else:
                        callback(prediction)
                except Exception as e:
                    logger.warning(f"Error in async prediction subscriber: {e}")
        except Exception as e:
            logger.error(f"Error emitting prediction: {e}")

    async def _emit_training_update(self, update: TrainingUpdate):
        """Emit training update to all subscribers"""
        try:
            # Sync subscribers
            for callback in self.training_subscribers:
                try:
                    callback(update)
                except Exception as e:
                    logger.warning(f"Error in training subscriber: {e}")

            # Async subscribers
            for callback in self.async_training_subscribers:
                try:
                    if asyncio.iscoroutinefunction(callback):
                        asyncio.create_task(callback(update))
                    else:
                        callback(update)
                except Exception as e:
                    logger.warning(f"Error in async training subscriber: {e}")
        except Exception as e:
            logger.error(f"Error emitting training update: {e}")

    async def _emit_trade_signal(self, signal: TradeSignal):
        """Emit trade signal to all subscribers"""
        try:
            # Sync subscribers
            for callback in self.signal_subscribers:
                try:
                    callback(signal)
                except Exception as e:
                    logger.warning(f"Error in signal subscriber: {e}")

            # Async subscribers
            for callback in self.async_signal_subscribers:
                try:
                    if asyncio.iscoroutinefunction(callback):
                        asyncio.create_task(callback(signal))
                    else:
                        callback(signal)
                except Exception as e:
                    logger.warning(f"Error in async signal subscriber: {e}")
        except Exception as e:
            logger.error(f"Error emitting trade signal: {e}")

    def _on_cob_update_sync(self, symbol: str, data: Dict):
        """Sync wrapper for async COB update handler"""
        try:
            # Schedule the async method
            asyncio.create_task(self._on_cob_update(symbol, data))
        except Exception as e:
            logger.error(f"Error scheduling COB update for {symbol}: {e}")

    async def start(self):
        """Start the real-time RL trader"""
        logger.info("Starting Real-time RL COB Trader")

        self.running = True

        # Load existing models if available
        self._load_models()

        # Start COB integration
        await self.cob_integration.start()

        # Start inference loop
        asyncio.create_task(self._inference_loop())

        # Start training loop
        asyncio.create_task(self._training_loop())

        # Start signal processing loop
        asyncio.create_task(self._signal_processing_loop())

        # Start model saving loop
        asyncio.create_task(self._model_saving_loop())

        logger.info("Real-time RL COB Trader started successfully")

    async def stop(self):
        """Stop the real-time RL trader"""
        logger.info("Stopping Real-time RL COB Trader")

        self.running = False

        # Save models
        self._save_models()

        # Stop COB integration
        await self.cob_integration.stop()

        logger.info("Real-time RL COB Trader stopped")

    async def _on_cob_update(self, symbol: str, data: Dict):
        """Handle COB updates for real-time inference"""
        try:
            if symbol not in self.symbols:
                return

            # Extract features from COB data
            features = self._extract_features(symbol, data)
            if features is None:
                return

            # Store in buffer
            self.feature_buffers[symbol].append({
                'timestamp': datetime.now(),
                'features': features,
                'raw_data': data
            })

            # Store price for later reward calculation
            if 'state' in data:
                price = self._extract_price_from_state(data['state'])
                if price > 0:
                    self.price_history[symbol].append({
                        'timestamp': datetime.now(),
                        'price': price
                    })

        except Exception as e:
            logger.error(f"Error handling COB update for {symbol}: {e}")

    def _extract_features(self, symbol: str, data: Dict) -> Optional[np.ndarray]:
        """Extract features from COB data for model input"""
        try:
            # Get state from COB data
            if 'state' not in data:
                return None

            state = data['state']

            # Ensure we have the right feature size (2000 features)
            if isinstance(state, np.ndarray):
                features = state.flatten()
            else:
                features = np.array(state).flatten()

            # Pad or truncate to exact size
            target_size = 2000
            if len(features) < target_size:
                # Pad with zeros
                padded = np.zeros(target_size)
                padded[:len(features)] = features
                features = padded
            elif len(features) > target_size:
                # Truncate
                features = features[:target_size]

            # Normalize features
            features = self._normalize_features(features)

            return features

        except Exception as e:
            logger.error(f"Error extracting features for {symbol}: {e}")
            return None
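
    # Illustrative sketch (assumption: the exact COB state layout is defined by
    # COBIntegration, not here). _extract_features only requires something
    # array-like under data['state'], e.g.:
    #   data = {'state': np.random.randn(1850)}   # hypothetical 1850-dim state
    #   features = trader._extract_features('BTC/USDT', data)
    #   features.shape  -> (2000,)  # zero-padded to the model's input size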

    def _normalize_features(self, features: np.ndarray) -> np.ndarray:
        """Normalize features for model input"""
        try:
            # Clip extreme values
            features = np.clip(features, -10.0, 10.0)

            # Z-score normalization with robust statistics
            median = np.median(features)
            mad = np.median(np.abs(features - median))
            if mad > 1e-6:
                features = (features - median) / (mad * 1.4826)

            # Final clipping
            features = np.clip(features, -5.0, 5.0)

            return features.astype(np.float32)

        except Exception as e:
            logger.warning(f"Error normalizing features: {e}")
            return features.astype(np.float32)
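
    # Why 1.4826: for normally distributed data, sigma ~= 1.4826 * MAD, so dividing
    # by (mad * 1.4826) yields a robust z-score that resists outliers. Sketch with
    # illustrative values: features = [0, 1, 2, 100] -> median = 1.5, MAD = 1.0,
    # so 100 scales to ~66 and is then clipped to 5.0 instead of dominating the input.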

    def _extract_price_from_state(self, state) -> float:
        """Extract current price from state data"""
        try:
            # Try different ways to extract price; assumes the first element
            # of the state vector carries price info
            if isinstance(state, np.ndarray) and len(state) > 0:
                return float(state[0])
            elif isinstance(state, (list, tuple)) and len(state) > 0:
                return float(state[0])
            else:
                return 0.0
        except Exception:
            return 0.0

    async def _inference_loop(self):
        """Main inference loop - runs every 200ms or when new data arrives"""
        logger.info("Starting inference loop")

        while self.running:
            try:
                start_time = time.time()

                # Run inference for all symbols
                for symbol in self.symbols:
                    await self._run_inference(symbol)

                # Calculate sleep time to maintain interval
                elapsed_ms = (time.time() - start_time) * 1000
                sleep_ms = max(0, self.inference_interval_ms - elapsed_ms)

                if sleep_ms > 0:
                    await asyncio.sleep(sleep_ms / 1000)

            except Exception as e:
                logger.error(f"Error in inference loop: {e}")
                await asyncio.sleep(0.1)

    async def _run_inference(self, symbol: str):
        """Run inference for a specific symbol"""
        try:
            with self.inference_lock:
                # Check if we have recent features
                if not self.feature_buffers[symbol]:
                    return

                # Get latest features
                latest_data = self.feature_buffers[symbol][-1]
                features = latest_data['features']
                timestamp = latest_data['timestamp']

                # Run model inference
                start_time = time.time()
                prediction = self._predict(symbol, features)
                inference_time_ms = (time.time() - start_time) * 1000

                # Update inference stats
                stats = self.inference_stats[symbol]
                stats['total_inferences'] += 1
                stats['average_inference_time_ms'] = (
                    (stats['average_inference_time_ms'] * (stats['total_inferences'] - 1) + inference_time_ms)
                    / stats['total_inferences']
                )
                stats['last_inference_time'] = timestamp

                # Create prediction result
                result = PredictionResult(
                    timestamp=timestamp,
                    symbol=symbol,
                    predicted_direction=prediction['direction'],
                    confidence=prediction['confidence'],
                    predicted_change=prediction['change'],
                    features=features
                )

                # Store prediction for later training
                self.prediction_history[symbol].append(result)

                # Emit prediction to subscribers
                await self._emit_prediction(result)

                # Add to signal accumulator if confident enough
                if prediction['confidence'] >= self.min_confidence_threshold:
                    self._add_signal(symbol, result)

                logger.debug(f"Inference {symbol}: direction={prediction['direction']}, "
                             f"confidence={prediction['confidence']:.3f}, "
                             f"change={prediction['change']:.4f}, "
                             f"time={inference_time_ms:.1f}ms")

        except Exception as e:
            logger.error(f"Error running inference for {symbol}: {e}")

    def _predict(self, symbol: str, features: np.ndarray) -> Dict:
        """Run model prediction"""
        try:
            model = self.models[symbol]
            model.eval()

            # Convert to tensor
            features_tensor = torch.from_numpy(features).unsqueeze(0).to(self.device)

            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    outputs = model(features_tensor)

                # Extract predictions
                price_probs = torch.softmax(outputs['price_logits'], dim=1)
                direction = torch.argmax(price_probs, dim=1).item()
                confidence = outputs['confidence'].item()
                value = outputs['value'].item()

            # Calculate predicted change based on direction and confidence
            if direction == 2:  # UP
                predicted_change = confidence * 0.001  # Max 0.1% up
            elif direction == 0:  # DOWN
                predicted_change = -confidence * 0.001  # Max 0.1% down
            else:  # SIDEWAYS
                predicted_change = 0.0

            return {
                'direction': direction,
                'confidence': confidence,
                'change': predicted_change,
                'value': value
            }

        except Exception as e:
            logger.error(f"Error in prediction for {symbol}: {e}")
            return {
                'direction': 1,  # SIDEWAYS
                'confidence': 0.0,
                'change': 0.0,
                'value': 0.0
            }
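
    # Worked example (illustrative numbers): with direction=2 (UP) and
    # confidence=0.8, predicted_change = 0.8 * 0.001 = 0.0008, i.e. +0.08%.
    # The model never predicts a move larger than +/-0.1%; the magnitude is a
    # heuristic derived from confidence, not a separate regression head.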

    def _add_signal(self, symbol: str, prediction: PredictionResult):
        """Add confident prediction to signal accumulator"""
        try:
            accumulator = self.signal_accumulators[symbol]
            accumulator.signals.append(prediction)
            accumulator.confidence_sum += prediction.confidence
            accumulator.total_predictions += 1

            logger.debug(f"Added signal for {symbol}: {len(accumulator.signals)} total signals")

        except Exception as e:
            logger.error(f"Error adding signal for {symbol}: {e}")

    async def _signal_processing_loop(self):
        """Process accumulated signals and generate trade decisions"""
        logger.info("Starting signal processing loop")

        while self.running:
            try:
                for symbol in self.symbols:
                    await self._process_signals(symbol)

                await asyncio.sleep(0.1)  # Process signals every 100ms

            except Exception as e:
                logger.error(f"Error in signal processing loop: {e}")
                await asyncio.sleep(1)

    async def _process_signals(self, symbol: str):
        """Process signals for a specific symbol and make trade decisions"""
        try:
            accumulator = self.signal_accumulators[symbol]

            # Check if we have enough confident predictions
            if len(accumulator.signals) < self.required_confident_predictions:
                return

            # Get recent signals
            recent_signals = list(accumulator.signals)[-self.required_confident_predictions:]

            # Check if all recent signals are in the same direction
            directions = [signal.predicted_direction for signal in recent_signals]
            confidences = [signal.confidence for signal in recent_signals]

            # Count direction consensus
            direction_counts = {0: 0, 1: 0, 2: 0}  # DOWN, SIDEWAYS, UP
            for direction in directions:
                direction_counts[direction] += 1

            # Find dominant direction
            dominant_direction = max(direction_counts, key=direction_counts.get)
            consensus_count = direction_counts[dominant_direction]

            # Check if we have enough consensus
            if consensus_count >= self.required_confident_predictions and dominant_direction != 1:
                # We have consensus for action (not sideways)
                avg_confidence = np.mean(confidences)

                # Determine action
                if dominant_direction == 2:  # UP
                    action = 'BUY'
                elif dominant_direction == 0:  # DOWN
                    action = 'SELL'
                else:
                    return  # No action for sideways

                # Execute trade signal
                await self._execute_trade_signal(symbol, action, float(avg_confidence), recent_signals)

                # Reset accumulator after trade signal
                self._reset_accumulator(symbol)

        except Exception as e:
            logger.error(f"Error processing signals for {symbol}: {e}")

    async def _execute_trade_signal(self, symbol: str, action: str, confidence: float, signals: List[PredictionResult]):
        """Execute a trade signal"""
        try:
            logger.info(f"Executing trade signal: {action} {symbol} with confidence {confidence:.3f}")

            # Get current price
            current_price = 0.0
            if self.price_history[symbol]:
                current_price = self.price_history[symbol][-1]['price']

            # Create trade signal for emission
            trade_signal = TradeSignal(
                timestamp=datetime.now(),
                symbol=symbol,
                action=action,
                confidence=confidence,
                quantity=1.0,  # Default quantity
                price=current_price,
                signals_count=len(signals),
                reason=f"Consensus of {len(signals)} predictions"
            )

            # Emit trade signal to subscribers
            await self._emit_trade_signal(trade_signal)

            # Execute through trading executor if available
            if self.trading_executor and current_price > 0:
                success = self.trading_executor.execute_signal(
                    symbol=symbol,
                    action=action,
                    confidence=confidence,
                    current_price=current_price
                )

                if success:
                    logger.info(f"Trade executed successfully: {action} {symbol}")

                    # Schedule training with higher weight for trade closure
                    asyncio.create_task(self._train_on_trade_execution(symbol, signals, action, current_price))
                else:
                    logger.warning(f"Trade execution failed: {action} {symbol}")
            else:
                logger.info(f"No trading executor available or price unknown for {symbol}")

        except Exception as e:
            logger.error(f"Error executing trade signal for {symbol}: {e}")

    def _reset_accumulator(self, symbol: str):
        """Reset signal accumulator after trade execution"""
        try:
            accumulator = self.signal_accumulators[symbol]
            accumulator.signals.clear()
            accumulator.confidence_sum = 0.0
            accumulator.last_reset_time = datetime.now()

            logger.debug(f"Reset signal accumulator for {symbol}")

        except Exception as e:
            logger.error(f"Error resetting accumulator for {symbol}: {e}")

    async def _training_loop(self):
        """Main training loop for real-time model updates"""
        logger.info("Starting training loop")

        while self.running:
            try:
                for symbol in self.symbols:
                    await self._train_symbol_model(symbol)

                await asyncio.sleep(1.0)  # Train every second

            except Exception as e:
                logger.error(f"Error in training loop: {e}")
                await asyncio.sleep(5)

    async def _train_symbol_model(self, symbol: str):
        """Train model for a specific symbol using recent predictions"""
        try:
            with self.training_lock:
                # Check if we have enough data for training
                predictions = list(self.prediction_history[symbol])
                if len(predictions) < 10:
                    return

                # Calculate rewards for recent predictions
                self._calculate_rewards(symbol, predictions)

                # Filter predictions with calculated rewards
                training_predictions = [p for p in predictions if p.reward is not None]
                if len(training_predictions) < 5:
                    return

                # Prepare training batch
                batch_size = min(32, len(training_predictions))
                batch_predictions = training_predictions[-batch_size:]

                # Train model
                loss = await self._train_batch(symbol, batch_predictions)

                # Update training stats
                stats = self.training_stats[symbol]
                stats['total_training_steps'] += 1
                stats['average_loss'] = (
                    (stats['average_loss'] * (stats['total_training_steps'] - 1) + loss)
                    / stats['total_training_steps']
                )
                stats['last_training_time'] = datetime.now()

                # Calculate accuracy and confidence
                accuracy = stats['successful_predictions'] / max(1, stats['total_predictions']) * 100
                avg_confidence = sum(p.confidence for p in batch_predictions) / len(batch_predictions)

                # Create training update for emission
                training_update = TrainingUpdate(
                    timestamp=datetime.now(),
                    symbol=symbol,
                    epoch=stats['total_training_steps'],
                    loss=loss,
                    batch_size=batch_size,
                    learning_rate=self.optimizers[symbol].param_groups[0]['lr'],
                    accuracy=accuracy,
                    avg_confidence=avg_confidence
                )

                # Emit training update to subscribers
                await self._emit_training_update(training_update)

                logger.debug(f"Training {symbol}: loss={loss:.6f}, batch_size={batch_size}")

        except Exception as e:
            logger.error(f"Error training model for {symbol}: {e}")

    def _calculate_rewards(self, symbol: str, predictions: List[PredictionResult]):
        """Calculate rewards for predictions based on actual price movements"""
        try:
            price_history = list(self.price_history[symbol])
            if len(price_history) < 2:
                return

            for prediction in predictions:
                if prediction.reward is not None:
                    continue  # Already calculated

                # Find actual price change after prediction
                pred_time = prediction.timestamp

                # Look for price data after prediction (with reasonable timeout)
                future_prices = [
                    p for p in price_history
                    if p['timestamp'] > pred_time and
                    (p['timestamp'] - pred_time).total_seconds() <= 60  # 1 minute timeout
                ]

                if not future_prices:
                    continue

                # Find price at prediction time
                past_prices = [
                    p for p in price_history
                    if abs((p['timestamp'] - pred_time).total_seconds()) <= 10  # 10 second window
                ]

                if not past_prices:
                    continue

                # Calculate actual price change
                pred_price = past_prices[-1]['price']
                future_price = future_prices[0]['price']  # Use first future price

                actual_change = (future_price - pred_price) / pred_price

                # Determine actual direction
                if actual_change > 0.0005:  # 0.05% threshold
                    actual_direction = 2  # UP
                elif actual_change < -0.0005:
                    actual_direction = 0  # DOWN
                else:
                    actual_direction = 1  # SIDEWAYS

                # Store the labels on the prediction so _train_batch can use
                # them as targets (these fields were otherwise never filled)
                prediction.actual_direction = actual_direction
                prediction.actual_change = actual_change

                # Calculate reward based on prediction accuracy
                prediction.reward = self._calculate_prediction_reward(
                    symbol=symbol,
                    predicted_direction=prediction.predicted_direction,
                    actual_direction=actual_direction,
                    confidence=prediction.confidence,
                    predicted_change=prediction.predicted_change,
                    actual_change=actual_change
                )

                # Update training stats
                stats = self.training_stats[symbol]
                stats['total_predictions'] += 1
                if prediction.reward > 0:
                    stats['successful_predictions'] += 1

        except Exception as e:
            logger.error(f"Error calculating rewards for {symbol}: {e}")

    def _calculate_prediction_reward(self,
                                     symbol: str,
                                     predicted_direction: int,
                                     actual_direction: int,
                                     confidence: float,
                                     predicted_change: float,
                                     actual_change: float,
                                     current_pnl: float = 0.0,
                                     position_duration: float = 0.0) -> float:
        """Calculate reward based on prediction accuracy and actual price movement"""
        reward = 0.0

        # Base reward for correct direction prediction
        if predicted_direction == actual_direction:
            reward += 1.0 * confidence  # Reward scales with confidence
        else:
            reward -= 0.5  # Penalize incorrect predictions

        # Reward for correctly predicting large changes (proportional to the actual change)
        if predicted_direction == actual_direction and abs(predicted_change) > 0.001:
            reward += abs(actual_change) * 5.0  # Amplify reward for significant moves

        # Penalize large predicted changes that turn out wrong
        if predicted_direction != actual_direction and abs(predicted_change) > 0.001:
            reward -= abs(predicted_change) * 2.0

        # Add a small reward for PnL (realized or unrealized)
        reward += current_pnl * 0.1

        # Dynamic adjustment based on recent PnL (loss-cutting incentive)
        if self.pnl_history[symbol]:
            latest_pnl_entry = self.pnl_history[symbol][-1]
            # Ensure the entry is a dict with a 'pnl' key; otherwise default to 0.0
            latest_pnl_value = latest_pnl_entry.get('pnl', 0.0) if isinstance(latest_pnl_entry, dict) else 0.0

            # Incentivize closing losing trades early: penalize a losing position
            # that has been open for more than 60 seconds
            if latest_pnl_value < 0 and position_duration > 60:
                reward -= abs(latest_pnl_value) * 0.2  # Increased penalty for sustained losses

            # Discourage new positions during a losing streak. A full treatment
            # would aggregate PnL over the last N trades; as a simple proxy, use
            # the best PnL seen in recent history.
            pnl_values = [entry.get('pnl', 0.0) for entry in self.pnl_history[symbol] if isinstance(entry, dict)]
            best_pnl = max(pnl_values) if pnl_values else 0.0

            if best_pnl < 0.0:  # Recent best PnL is negative: reduce reward for new trades
                reward -= 0.1  # Small penalty for trading in a losing streak

        return reward
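
    # Worked example (illustrative numbers): predicted UP (2) with confidence 0.8
    # and predicted_change 0.002; actual UP with actual_change 0.003; no PnL history.
    #   base      = 1.0 * 0.8    = 0.80
    #   magnitude = 0.003 * 5.0  = 0.015
    #   reward    = 0.80 + 0.015 = 0.815
    # Had the direction been wrong: reward = -0.5 - 0.002 * 2.0 = -0.504.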

    async def _train_batch(self, symbol: str, predictions: List[PredictionResult]) -> float:
        """Train model on a batch of predictions"""
        try:
            # Keep only predictions with labels: signals rewarded on trade
            # execution may carry a reward but no actual_direction yet, which
            # would break the target tensor construction below
            predictions = [p for p in predictions
                           if p.actual_direction is not None and p.reward is not None]
            if not predictions:
                return 0.0

            model = self.models[symbol]
            optimizer = self.optimizers[symbol]
            scaler = self.scalers[symbol]

            model.train()
            optimizer.zero_grad()

            # Prepare batch data
            features = torch.stack([
                torch.from_numpy(p.features) for p in predictions
            ]).to(self.device)

            # Targets
            direction_targets = torch.tensor([
                p.actual_direction for p in predictions
            ], dtype=torch.long).to(self.device)

            value_targets = torch.tensor([
                p.reward for p in predictions
            ], dtype=torch.float32).to(self.device)

            # Forward pass with mixed precision
            with torch.cuda.amp.autocast():
                outputs = model(features)

                # Calculate losses
                direction_loss = nn.CrossEntropyLoss()(outputs['price_logits'], direction_targets)
                value_loss = nn.MSELoss()(outputs['value'].squeeze(), value_targets)

                # Confidence loss (encourage high confidence for correct predictions)
                correct_predictions = (torch.argmax(outputs['price_logits'], dim=1) == direction_targets).float()
                confidence_loss = nn.BCELoss()(outputs['confidence'].squeeze(), correct_predictions)

                # Combined loss
                total_loss = direction_loss + 0.5 * value_loss + 0.3 * confidence_loss

            # Backward pass with gradient scaling
            scaler.scale(total_loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            return total_loss.item()

        except Exception as e:
            logger.error(f"Error training batch for {symbol}: {e}")
            return 0.0
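
    # Loss weighting note: total = direction + 0.5 * value + 0.3 * confidence.
    # The 0.5 / 0.3 factors are heuristic; direction classification is treated
    # as the primary objective, with value regression and confidence calibration
    # as auxiliary heads.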

    async def _train_on_trade_execution(self, symbol: str, signals: List[PredictionResult],
                                        action: str, price: float):
        """Train with higher weight when a trade is executed"""
        try:
            logger.info(f"Training on trade execution: {action} {symbol} at ${price:.2f}")

            # Wait a bit to see trade outcome
            await asyncio.sleep(30)  # 30 seconds to see initial outcome

            # Calculate actual outcome
            current_prices = [p['price'] for p in list(self.price_history[symbol])[-5:]]
            if len(current_prices) >= 2:
                current_price = current_prices[-1]
                entry_price = price

                # Calculate P&L
                if action == 'BUY':
                    pnl_ratio = (current_price - entry_price) / entry_price
                elif action == 'SELL':
                    pnl_ratio = (entry_price - current_price) / entry_price
                else:
                    pnl_ratio = 0.0

                # Create enhanced reward for trade execution
                trade_reward = pnl_ratio * 10.0  # Amplify trade outcomes

                # Apply enhanced training weight to signals that led to the trade
                for signal in signals:
                    if signal.reward is None:
                        signal.reward = trade_reward
                    else:
                        signal.reward += trade_reward  # Add to existing reward

                logger.info(f"Trade outcome for {symbol}: P&L ratio={pnl_ratio:.4f}, "
                            f"enhanced reward={trade_reward:.4f}")

                # Immediate training step with higher weight
                if len(signals) > 0:
                    loss = await self._train_batch(symbol, signals[-3:])  # Train on last 3 signals
                    logger.info(f"Enhanced training loss for {symbol}: {loss:.6f}")

        except Exception as e:
            logger.error(f"Error in trade execution training for {symbol}: {e}")

    async def _model_saving_loop(self):
        """Periodically save models"""
        logger.info("Starting model saving loop")

        while self.running:
            try:
                await asyncio.sleep(300)  # Save every 5 minutes
                self._save_models()

            except Exception as e:
                logger.error(f"Error in model saving loop: {e}")
                await asyncio.sleep(60)

    def _save_models(self):
        """Save all models to disk using CheckpointManager"""
        try:
            for symbol in self.symbols:
                # Standardize model name for CheckpointManager
                model_name = f"cob_rl_{symbol.replace('/', '_').lower()}"

                # Prepare performance metrics for CheckpointManager
                performance_metrics = {
                    'loss': self.training_stats[symbol].get('average_loss', 0.0),
                    'reward': self.training_stats[symbol].get('average_reward', 0.0),  # Assumes average_reward is tracked
                    'accuracy': self.training_stats[symbol].get('average_accuracy', 0.0),  # Assumes average_accuracy is tracked
                }
                if self.trading_executor:
                    daily_stats = self.trading_executor.get_daily_stats()
                    performance_metrics['pnl'] = daily_stats.get('total_pnl', 0.0)
                performance_metrics['training_samples'] = self.training_stats[symbol].get('total_training_steps', 0)

                # Prepare training metadata for CheckpointManager
                training_metadata = {
                    'total_parameters': sum(p.numel() for p in self.models[symbol].parameters()),
                    'epoch': self.training_stats[symbol].get('total_training_steps', 0),  # total_training_steps as pseudo-epoch
                    'training_time_hours': (datetime.now() - self.start_time).total_seconds() / 3600
                }

                self.checkpoint_manager.save_checkpoint(
                    model=self.models[symbol],
                    model_name=model_name,
                    model_type='COB_RL',
                    performance_metrics=performance_metrics,
                    training_metadata=training_metadata
                )

                logger.debug(f"Saved model for {symbol}")

        except Exception as e:
            logger.error(f"Error saving models: {e}")

    def _load_models(self):
        """Load existing models from disk using CheckpointManager"""
        try:
            for symbol in self.symbols:
                # Standardize model name for CheckpointManager
                model_name = f"cob_rl_{symbol.replace('/', '_').lower()}"

                loaded_checkpoint = self.checkpoint_manager.load_best_checkpoint(model_name)

                if loaded_checkpoint:
                    model_path, metadata = loaded_checkpoint
                    checkpoint = torch.load(model_path, map_location=self.device)

                    self.models[symbol].load_state_dict(checkpoint['model_state_dict'])
                    self.optimizers[symbol].load_state_dict(checkpoint['optimizer_state_dict'])

                    if 'training_stats' in checkpoint:
                        self.training_stats[symbol].update(checkpoint['training_stats'])
                    if 'inference_stats' in checkpoint:
                        self.inference_stats[symbol].update(checkpoint['inference_stats'])

                    logger.info(f"Loaded existing model for {symbol} from checkpoint: {metadata.checkpoint_id}")
                else:
                    logger.info(f"No existing model found for {symbol} via CheckpointManager, starting fresh.")

        except Exception as e:
            logger.error(f"Error loading models: {e}")

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get comprehensive performance statistics"""
        try:
            stats = {
                'symbols': self.symbols,
                'training_stats': self.training_stats.copy(),
                'inference_stats': self.inference_stats.copy(),
                'signal_stats': {},
                'model_info': {}
            }

            # Add signal accumulator stats
            for symbol in self.symbols:
                accumulator = self.signal_accumulators[symbol]
                stats['signal_stats'][symbol] = {
                    'current_signals': len(accumulator.signals),
                    'confidence_sum': accumulator.confidence_sum,
                    'total_predictions': accumulator.total_predictions,
                    'successful_predictions': accumulator.successful_predictions,
                    'success_rate': (
                        accumulator.successful_predictions / max(1, accumulator.total_predictions)
                    )
                }

            # Add model parameter info
            for symbol in self.symbols:
                model = self.models[symbol]
                total_params = sum(p.numel() for p in model.parameters())
                stats['model_info'][symbol] = {
                    'total_parameters': total_params,
                    'trainable_parameters': sum(p.numel() for p in model.parameters() if p.requires_grad)
                }

            return stats

        except Exception as e:
            logger.error(f"Error getting performance stats: {e}")
            return {}

# Example usage
async def main():
    """Example usage of RealtimeRLCOBTrader"""
    # TradingExecutor is already imported at module level; note that this
    # module uses relative imports, so it must be run as part of its package
    # rather than directly as a script.

    # Initialize trading executor (simulation mode)
    trading_executor = TradingExecutor()

    # Initialize real-time RL trader
    trader = RealtimeRLCOBTrader(
        symbols=['BTC/USDT', 'ETH/USDT'],
        trading_executor=trading_executor,
        inference_interval_ms=200,
        min_confidence_threshold=0.7,
        required_confident_predictions=3
    )

    try:
        # Start the trader
        await trader.start()

        # Run for demonstration
        logger.info("Real-time RL COB Trader running...")
        await asyncio.sleep(300)  # Run for 5 minutes

        # Print performance stats
        stats = trader.get_performance_stats()
        logger.info(f"Performance stats: {json.dumps(stats, indent=2, default=str)}")

    finally:
        # Stop the trader
        await trader.stop()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())