rl cob subscription model

This commit is contained in:
Dobromir Popov
2025-06-24 19:07:42 +03:00
parent 6702a490dd
commit 97f7f54c30
6 changed files with 405 additions and 43 deletions

View File

@ -23,7 +23,7 @@ import torch.optim as optim
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Callable, Tuple
from collections import deque, defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, asdict
import json
import time
import threading
@ -66,6 +66,30 @@ class SignalAccumulator:
if self.last_reset_time is None:
self.last_reset_time = datetime.now()
@dataclass
class TrainingUpdate:
    """Event payload describing one completed training step.

    Instances are handed to training subscribers via
    ``_emit_training_update``.
    """

    # Wall-clock time at which the update object was created.
    timestamp: datetime
    # Trading symbol whose model was trained.
    symbol: str
    # Filled from the running ``total_training_steps`` counter by the trader.
    epoch: int
    # Loss value reported for this training step.
    loss: float
    # Number of samples in the training batch.
    batch_size: int
    # Learning rate read from the optimizer's first param group.
    learning_rate: float
    # Successful / total predictions, expressed as a percentage (0-100).
    accuracy: float
    # Mean confidence over the predictions in the batch.
    avg_confidence: float
@dataclass
class TradeSignal:
    """Event payload describing a trade decision.

    Instances are handed to signal subscribers via ``_emit_trade_signal``.
    """

    # Wall-clock time at which the signal object was created.
    timestamp: datetime
    # Trading symbol the signal applies to.
    symbol: str
    action: str  # 'BUY', 'SELL', 'HOLD'
    # Confidence value passed along with the action.
    confidence: float
    # Order size; the trader currently emits a default of 1.0.
    quantity: float
    # Price taken from the latest price-history entry when available.
    price: float
    # Number of accumulated predictions behind this signal.
    signals_count: int
    # Human-readable explanation of why the signal was raised.
    reason: str
class MassiveRLNetwork(nn.Module):
"""
Massive 1B+ parameter RL network optimized for real-time COB trading
@ -193,7 +217,7 @@ class MassiveRLNetwork(nn.Module):
class RealtimeRLCOBTrader:
"""
Real-time RL trader using COB data
Real-time RL trader using COB data with comprehensive subscriber system
"""
def __init__(self,
@ -231,9 +255,17 @@ class RealtimeRLCOBTrader:
)
self.scalers[symbol] = torch.cuda.amp.GradScaler()
# Subscriber system for real-time events
self.prediction_subscribers: List[Callable[[PredictionResult], None]] = []
self.training_subscribers: List[Callable[[TrainingUpdate], None]] = []
self.signal_subscribers: List[Callable[[TradeSignal], None]] = []
self.async_prediction_subscribers: List[Callable[[PredictionResult], Any]] = []
self.async_training_subscribers: List[Callable[[TrainingUpdate], Any]] = []
self.async_signal_subscribers: List[Callable[[TradeSignal], Any]] = []
# COB integration
self.cob_integration = COBIntegration(symbols=self.symbols)
self.cob_integration.add_dqn_callback(self._on_cob_update)
self.cob_integration.add_dqn_callback(self._on_cob_update_sync)
# Data storage for real-time training
self.prediction_history: Dict[str, deque] = {}
@ -280,6 +312,111 @@ class RealtimeRLCOBTrader:
logger.info(f"RealtimeRLCOBTrader initialized for symbols: {self.symbols}")
logger.info(f"Inference interval: {self.inference_interval_ms}ms")
logger.info(f"Required confident predictions: {self.required_confident_predictions}")
# Subscriber system methods
def add_prediction_subscriber(self, callback: Callable[[PredictionResult], None]):
    """Register a callback that is invoked directly with each emitted prediction.

    Args:
        callback: Callable receiving the ``PredictionResult``.
    """
    registry = self.prediction_subscribers
    registry.append(callback)
    logger.info(f"Added prediction subscriber, total: {len(self.prediction_subscribers)}")
def add_training_subscriber(self, callback: Callable[[TrainingUpdate], None]):
    """Register a callback that is invoked directly with each training update.

    Args:
        callback: Callable receiving the ``TrainingUpdate``.
    """
    registry = self.training_subscribers
    registry.append(callback)
    logger.info(f"Added training subscriber, total: {len(self.training_subscribers)}")
def add_signal_subscriber(self, callback: Callable[[TradeSignal], None]):
    """Register a callback that is invoked directly with each trade signal.

    Args:
        callback: Callable receiving the ``TradeSignal``.
    """
    registry = self.signal_subscribers
    registry.append(callback)
    logger.info(f"Added signal subscriber, total: {len(self.signal_subscribers)}")
def add_async_prediction_subscriber(self, callback: Callable[[PredictionResult], Any]):
    """Register a prediction callback on the async subscriber list.

    On emission, coroutine functions in this list are scheduled as tasks;
    any other callable is invoked directly.

    Args:
        callback: Callable (typically a coroutine function) receiving the
            ``PredictionResult``.
    """
    registry = self.async_prediction_subscribers
    registry.append(callback)
    logger.info(f"Added async prediction subscriber, total: {len(self.async_prediction_subscribers)}")
def add_async_training_subscriber(self, callback: Callable[[TrainingUpdate], Any]):
    """Register a training-update callback on the async subscriber list.

    On emission, coroutine functions in this list are scheduled as tasks;
    any other callable is invoked directly.

    Args:
        callback: Callable (typically a coroutine function) receiving the
            ``TrainingUpdate``.
    """
    registry = self.async_training_subscribers
    registry.append(callback)
    logger.info(f"Added async training subscriber, total: {len(self.async_training_subscribers)}")
def add_async_signal_subscriber(self, callback: Callable[[TradeSignal], Any]):
    """Register a trade-signal callback on the async subscriber list.

    On emission, coroutine functions in this list are scheduled as tasks;
    any other callable is invoked directly.

    Args:
        callback: Callable (typically a coroutine function) receiving the
            ``TradeSignal``.
    """
    registry = self.async_signal_subscribers
    registry.append(callback)
    logger.info(f"Added async signal subscriber, total: {len(self.async_signal_subscribers)}")
async def _emit_prediction(self, prediction: PredictionResult):
    """Emit prediction to all subscribers.

    Sync subscribers are called inline; coroutine-function subscribers are
    scheduled as tasks and are not awaited here. Failures are logged and
    never propagate to the caller.
    """
    try:
        self._dispatch_event(
            "prediction",
            self.prediction_subscribers,
            self.async_prediction_subscribers,
            prediction,
        )
    except Exception as e:
        logger.error(f"Error emitting prediction: {e}")


async def _emit_training_update(self, update: TrainingUpdate):
    """Emit training update to all subscribers.

    Sync subscribers are called inline; coroutine-function subscribers are
    scheduled as tasks and are not awaited here. Failures are logged and
    never propagate to the caller.
    """
    try:
        self._dispatch_event(
            "training",
            self.training_subscribers,
            self.async_training_subscribers,
            update,
        )
    except Exception as e:
        logger.error(f"Error emitting training update: {e}")


async def _emit_trade_signal(self, signal: TradeSignal):
    """Emit trade signal to all subscribers.

    Sync subscribers are called inline; coroutine-function subscribers are
    scheduled as tasks and are not awaited here. Failures are logged and
    never propagate to the caller.
    """
    try:
        self._dispatch_event(
            "signal",
            self.signal_subscribers,
            self.async_signal_subscribers,
            signal,
        )
    except Exception as e:
        logger.error(f"Error emitting trade signal: {e}")


def _dispatch_event(self, kind: str, sync_subscribers: List[Callable[[Any], None]],
                    async_subscribers: List[Callable[[Any], Any]], event: Any) -> None:
    """Deliver ``event`` to one sync and one async subscriber list.

    Args:
        kind: Label used in log messages ("prediction", "training", "signal").
        sync_subscribers: Callables invoked directly with ``event``.
        async_subscribers: Coroutine functions are scheduled as tasks; any
            other callable is invoked directly.
        event: Payload passed to every subscriber.

    Must be called while an event loop is running (the ``_emit_*``
    coroutines guarantee this), because coroutine subscribers are scheduled
    with ``asyncio.create_task``.
    """
    for callback in sync_subscribers:
        try:
            callback(event)
        except Exception as e:
            # One failing subscriber must not stop delivery to the others.
            logger.warning(f"Error in {kind} subscriber: {e}")

    for callback in async_subscribers:
        try:
            if asyncio.iscoroutinefunction(callback):
                task = asyncio.create_task(callback(event))
                self._track_subscriber_task(task, kind)
            else:
                callback(event)
        except Exception as e:
            logger.warning(f"Error in async {kind} subscriber: {e}")


def _track_subscriber_task(self, task, kind: str) -> None:
    """Keep a strong reference to a subscriber task and log its failure.

    The event loop only holds weak references to tasks, so a fire-and-forget
    task can be garbage-collected before it finishes. An exception raised
    inside the coroutine is also invisible to the ``try`` around
    ``create_task``; the done-callback reports it instead.
    """
    # Created lazily so this helper does not depend on __init__ having set it up.
    pending = getattr(self, "_subscriber_tasks", None)
    if pending is None:
        pending = self._subscriber_tasks = set()
    pending.add(task)

    def _finalize(done_task):
        pending.discard(done_task)
        if done_task.cancelled():
            return
        error = done_task.exception()
        if error is not None:
            logger.warning(f"Error in async {kind} subscriber: {error}")

    task.add_done_callback(_finalize)
def _on_cob_update_sync(self, symbol: str, data: Dict):
    """Sync wrapper that schedules the async COB update handler.

    Args:
        symbol: Symbol the COB update belongs to.
        data: Update payload forwarded unchanged to ``_on_cob_update``.

    Scheduling failures (for example, no event loop running in the calling
    thread) are logged and swallowed; nothing is raised to the caller.
    NOTE(review): if the COB integration invokes this from a thread other
    than the event-loop thread, no task can be scheduled here — confirm the
    calling context.
    """
    try:
        # Resolve the loop BEFORE building the coroutine: if no loop is
        # running we bail out without leaving a never-awaited coroutine.
        loop = asyncio.get_running_loop()
        task = loop.create_task(self._on_cob_update(symbol, data))
    except Exception as e:
        logger.error(f"Error scheduling COB update for {symbol}: {e}")
        return

    # The loop holds only a weak reference to the task; retain it until it
    # completes so it cannot be garbage-collected mid-execution.
    # Created lazily so this method does not depend on __init__ changes.
    pending = getattr(self, "_cob_update_tasks", None)
    if pending is None:
        pending = self._cob_update_tasks = set()
    pending.add(task)
    task.add_done_callback(pending.discard)
async def start(self):
"""Start the real-time RL trader"""
@ -484,6 +621,9 @@ class RealtimeRLCOBTrader:
# Store prediction for later training
self.prediction_history[symbol].append(result)
# Emit prediction to subscribers
await self._emit_prediction(result)
# Add to signal accumulator if confident enough
if prediction['confidence'] >= self.min_confidence_threshold:
self._add_signal(symbol, result)
@ -606,7 +746,7 @@ class RealtimeRLCOBTrader:
return # No action for sideways
# Execute trade signal
await self._execute_trade_signal(symbol, action, avg_confidence, recent_signals)
await self._execute_trade_signal(symbol, action, float(avg_confidence), recent_signals)
# Reset accumulator after trade signal
self._reset_accumulator(symbol)
@ -624,6 +764,21 @@ class RealtimeRLCOBTrader:
if self.price_history[symbol]:
current_price = self.price_history[symbol][-1]['price']
# Create trade signal for emission
trade_signal = TradeSignal(
timestamp=datetime.now(),
symbol=symbol,
action=action,
confidence=confidence,
quantity=1.0, # Default quantity
price=current_price,
signals_count=len(signals),
reason=f"Consensus of {len(signals)} predictions"
)
# Emit trade signal to subscribers
await self._emit_trade_signal(trade_signal)
# Execute through trading executor if available
if self.trading_executor and current_price > 0:
success = self.trading_executor.execute_signal(
@ -707,6 +862,25 @@ class RealtimeRLCOBTrader:
)
stats['last_training_time'] = datetime.now()
# Calculate accuracy and confidence
accuracy = stats['successful_predictions'] / max(1, stats['total_predictions']) * 100
avg_confidence = sum(p.confidence for p in batch_predictions) / len(batch_predictions)
# Create training update for emission
training_update = TrainingUpdate(
timestamp=datetime.now(),
symbol=symbol,
epoch=stats['total_training_steps'],
loss=loss,
batch_size=batch_size,
learning_rate=self.optimizers[symbol].param_groups[0]['lr'],
accuracy=accuracy,
avg_confidence=avg_confidence
)
# Emit training update to subscribers
await self._emit_training_update(training_update)
logger.debug(f"Training {symbol}: loss={loss:.6f}, batch_size={batch_size}")
except Exception as e: