rl cob subscription model
This commit is contained in:
@ -23,7 +23,7 @@ import torch.optim as optim
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Callable, Tuple
|
||||
from collections import deque, defaultdict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, asdict
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
@ -66,6 +66,30 @@ class SignalAccumulator:
|
||||
if self.last_reset_time is None:
|
||||
self.last_reset_time = datetime.now()
|
||||
|
||||
@dataclass
|
||||
class TrainingUpdate:
|
||||
"""Training update event data"""
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
epoch: int
|
||||
loss: float
|
||||
batch_size: int
|
||||
learning_rate: float
|
||||
accuracy: float
|
||||
avg_confidence: float
|
||||
|
||||
@dataclass
|
||||
class TradeSignal:
|
||||
"""Trade signal event data"""
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
action: str # 'BUY', 'SELL', 'HOLD'
|
||||
confidence: float
|
||||
quantity: float
|
||||
price: float
|
||||
signals_count: int
|
||||
reason: str
|
||||
|
||||
class MassiveRLNetwork(nn.Module):
|
||||
"""
|
||||
Massive 1B+ parameter RL network optimized for real-time COB trading
|
||||
@ -193,7 +217,7 @@ class MassiveRLNetwork(nn.Module):
|
||||
|
||||
class RealtimeRLCOBTrader:
|
||||
"""
|
||||
Real-time RL trader using COB data
|
||||
Real-time RL trader using COB data with comprehensive subscriber system
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -231,9 +255,17 @@ class RealtimeRLCOBTrader:
|
||||
)
|
||||
self.scalers[symbol] = torch.cuda.amp.GradScaler()
|
||||
|
||||
# Subscriber system for real-time events
|
||||
self.prediction_subscribers: List[Callable[[PredictionResult], None]] = []
|
||||
self.training_subscribers: List[Callable[[TrainingUpdate], None]] = []
|
||||
self.signal_subscribers: List[Callable[[TradeSignal], None]] = []
|
||||
self.async_prediction_subscribers: List[Callable[[PredictionResult], Any]] = []
|
||||
self.async_training_subscribers: List[Callable[[TrainingUpdate], Any]] = []
|
||||
self.async_signal_subscribers: List[Callable[[TradeSignal], Any]] = []
|
||||
|
||||
# COB integration
|
||||
self.cob_integration = COBIntegration(symbols=self.symbols)
|
||||
self.cob_integration.add_dqn_callback(self._on_cob_update)
|
||||
self.cob_integration.add_dqn_callback(self._on_cob_update_sync)
|
||||
|
||||
# Data storage for real-time training
|
||||
self.prediction_history: Dict[str, deque] = {}
|
||||
@ -280,6 +312,111 @@ class RealtimeRLCOBTrader:
|
||||
logger.info(f"RealtimeRLCOBTrader initialized for symbols: {self.symbols}")
|
||||
logger.info(f"Inference interval: {self.inference_interval_ms}ms")
|
||||
logger.info(f"Required confident predictions: {self.required_confident_predictions}")
|
||||
|
||||
# Subscriber system methods
|
||||
def add_prediction_subscriber(self, callback: Callable[[PredictionResult], None]):
|
||||
"""Add a subscriber for prediction events"""
|
||||
self.prediction_subscribers.append(callback)
|
||||
logger.info(f"Added prediction subscriber, total: {len(self.prediction_subscribers)}")
|
||||
|
||||
def add_training_subscriber(self, callback: Callable[[TrainingUpdate], None]):
|
||||
"""Add a subscriber for training events"""
|
||||
self.training_subscribers.append(callback)
|
||||
logger.info(f"Added training subscriber, total: {len(self.training_subscribers)}")
|
||||
|
||||
def add_signal_subscriber(self, callback: Callable[[TradeSignal], None]):
|
||||
"""Add a subscriber for trade signal events"""
|
||||
self.signal_subscribers.append(callback)
|
||||
logger.info(f"Added signal subscriber, total: {len(self.signal_subscribers)}")
|
||||
|
||||
def add_async_prediction_subscriber(self, callback: Callable[[PredictionResult], Any]):
|
||||
"""Add an async subscriber for prediction events"""
|
||||
self.async_prediction_subscribers.append(callback)
|
||||
logger.info(f"Added async prediction subscriber, total: {len(self.async_prediction_subscribers)}")
|
||||
|
||||
def add_async_training_subscriber(self, callback: Callable[[TrainingUpdate], Any]):
|
||||
"""Add an async subscriber for training events"""
|
||||
self.async_training_subscribers.append(callback)
|
||||
logger.info(f"Added async training subscriber, total: {len(self.async_training_subscribers)}")
|
||||
|
||||
def add_async_signal_subscriber(self, callback: Callable[[TradeSignal], Any]):
|
||||
"""Add an async subscriber for trade signal events"""
|
||||
self.async_signal_subscribers.append(callback)
|
||||
logger.info(f"Added async signal subscriber, total: {len(self.async_signal_subscribers)}")
|
||||
|
||||
async def _emit_prediction(self, prediction: PredictionResult):
|
||||
"""Emit prediction to all subscribers"""
|
||||
try:
|
||||
# Sync subscribers
|
||||
for callback in self.prediction_subscribers:
|
||||
try:
|
||||
callback(prediction)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in prediction subscriber: {e}")
|
||||
|
||||
# Async subscribers
|
||||
for callback in self.async_prediction_subscribers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
asyncio.create_task(callback(prediction))
|
||||
else:
|
||||
callback(prediction)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in async prediction subscriber: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting prediction: {e}")
|
||||
|
||||
async def _emit_training_update(self, update: TrainingUpdate):
|
||||
"""Emit training update to all subscribers"""
|
||||
try:
|
||||
# Sync subscribers
|
||||
for callback in self.training_subscribers:
|
||||
try:
|
||||
callback(update)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in training subscriber: {e}")
|
||||
|
||||
# Async subscribers
|
||||
for callback in self.async_training_subscribers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
asyncio.create_task(callback(update))
|
||||
else:
|
||||
callback(update)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in async training subscriber: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting training update: {e}")
|
||||
|
||||
async def _emit_trade_signal(self, signal: TradeSignal):
|
||||
"""Emit trade signal to all subscribers"""
|
||||
try:
|
||||
# Sync subscribers
|
||||
for callback in self.signal_subscribers:
|
||||
try:
|
||||
callback(signal)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in signal subscriber: {e}")
|
||||
|
||||
# Async subscribers
|
||||
for callback in self.async_signal_subscribers:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
asyncio.create_task(callback(signal))
|
||||
else:
|
||||
callback(signal)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in async signal subscriber: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting trade signal: {e}")
|
||||
|
||||
def _on_cob_update_sync(self, symbol: str, data: Dict):
|
||||
"""Sync wrapper for async COB update handler"""
|
||||
try:
|
||||
# Schedule the async method
|
||||
asyncio.create_task(self._on_cob_update(symbol, data))
|
||||
except Exception as e:
|
||||
logger.error(f"Error scheduling COB update for {symbol}: {e}")
|
||||
|
||||
async def start(self):
|
||||
"""Start the real-time RL trader"""
|
||||
@ -484,6 +621,9 @@ class RealtimeRLCOBTrader:
|
||||
# Store prediction for later training
|
||||
self.prediction_history[symbol].append(result)
|
||||
|
||||
# Emit prediction to subscribers
|
||||
await self._emit_prediction(result)
|
||||
|
||||
# Add to signal accumulator if confident enough
|
||||
if prediction['confidence'] >= self.min_confidence_threshold:
|
||||
self._add_signal(symbol, result)
|
||||
@ -606,7 +746,7 @@ class RealtimeRLCOBTrader:
|
||||
return # No action for sideways
|
||||
|
||||
# Execute trade signal
|
||||
await self._execute_trade_signal(symbol, action, avg_confidence, recent_signals)
|
||||
await self._execute_trade_signal(symbol, action, float(avg_confidence), recent_signals)
|
||||
|
||||
# Reset accumulator after trade signal
|
||||
self._reset_accumulator(symbol)
|
||||
@ -624,6 +764,21 @@ class RealtimeRLCOBTrader:
|
||||
if self.price_history[symbol]:
|
||||
current_price = self.price_history[symbol][-1]['price']
|
||||
|
||||
# Create trade signal for emission
|
||||
trade_signal = TradeSignal(
|
||||
timestamp=datetime.now(),
|
||||
symbol=symbol,
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
quantity=1.0, # Default quantity
|
||||
price=current_price,
|
||||
signals_count=len(signals),
|
||||
reason=f"Consensus of {len(signals)} predictions"
|
||||
)
|
||||
|
||||
# Emit trade signal to subscribers
|
||||
await self._emit_trade_signal(trade_signal)
|
||||
|
||||
# Execute through trading executor if available
|
||||
if self.trading_executor and current_price > 0:
|
||||
success = self.trading_executor.execute_signal(
|
||||
@ -707,6 +862,25 @@ class RealtimeRLCOBTrader:
|
||||
)
|
||||
stats['last_training_time'] = datetime.now()
|
||||
|
||||
# Calculate accuracy and confidence
|
||||
accuracy = stats['successful_predictions'] / max(1, stats['total_predictions']) * 100
|
||||
avg_confidence = sum(p.confidence for p in batch_predictions) / len(batch_predictions)
|
||||
|
||||
# Create training update for emission
|
||||
training_update = TrainingUpdate(
|
||||
timestamp=datetime.now(),
|
||||
symbol=symbol,
|
||||
epoch=stats['total_training_steps'],
|
||||
loss=loss,
|
||||
batch_size=batch_size,
|
||||
learning_rate=self.optimizers[symbol].param_groups[0]['lr'],
|
||||
accuracy=accuracy,
|
||||
avg_confidence=avg_confidence
|
||||
)
|
||||
|
||||
# Emit training update to subscribers
|
||||
await self._emit_training_update(training_update)
|
||||
|
||||
logger.debug(f"Training {symbol}: loss={loss:.6f}, batch_size={batch_size}")
|
||||
|
||||
except Exception as e:
|
||||
|
Reference in New Issue
Block a user