This commit is contained in:
Dobromir Popov
2025-07-07 23:39:12 +03:00
parent 2d8f763eeb
commit 9cd2d5d8a4
6 changed files with 188 additions and 49 deletions

View File

@ -229,8 +229,8 @@ class COBRLModelInterface(ModelInterface):
Interface for the COB RL model that handles model management, training, and inference Interface for the COB RL model that handles model management, training, and inference
""" """
def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None): def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None, name=None, **kwargs):
super().__init__(name="cob_rl_model") # Initialize ModelInterface with a name super().__init__(name=name) # Initialize ModelInterface with a name
self.model_checkpoint_dir = model_checkpoint_dir self.model_checkpoint_dir = model_checkpoint_dir
self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu')) self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))

View File

@ -5,7 +5,7 @@ import numpy as np
from collections import deque from collections import deque
import random import random
from typing import Tuple, List from typing import Tuple, List
import osvu import os
import sys import sys
import logging import logging
import torch.nn.functional as F import torch.nn.functional as F

View File

@ -34,7 +34,7 @@ class COBIntegration:
Integration layer for Multi-Exchange COB data with gogo2 trading system Integration layer for Multi-Exchange COB data with gogo2 trading system
""" """
def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None): def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None, initial_data_limit=None, **kwargs):
""" """
Initialize COB Integration Initialize COB Integration
@ -88,7 +88,7 @@ class COBIntegration:
# Start COB provider streaming # Start COB provider streaming
try: try:
logger.info("Starting COB provider streaming...") logger.info("Starting COB provider streaming...")
await self.cob_provider.start_streaming() await self.cob_provider.start_streaming()
except Exception as e: except Exception as e:
logger.error(f"Error starting COB provider streaming: {e}") logger.error(f"Error starting COB provider streaming: {e}")
# Start a background task instead # Start a background task instead
@ -112,7 +112,7 @@ class COBIntegration:
"""Stop COB integration""" """Stop COB integration"""
logger.info("Stopping COB Integration") logger.info("Stopping COB Integration")
if self.cob_provider: if self.cob_provider:
await self.cob_provider.stop_streaming() await self.cob_provider.stop_streaming()
logger.info("COB Integration stopped") logger.info("COB Integration stopped")
def add_cnn_callback(self, callback: Callable[[str, Dict], None]): def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
@ -313,7 +313,7 @@ class COBIntegration:
# Get fixed bucket size for the symbol # Get fixed bucket size for the symbol
bucket_size = 1.0 # Default bucket size bucket_size = 1.0 # Default bucket size
if self.cob_provider: if self.cob_provider:
bucket_size = self.cob_provider.fixed_usd_buckets.get(symbol, 1.0) bucket_size = self.cob_provider.fixed_usd_buckets.get(symbol, 1.0)
# Calculate price range for buckets # Calculate price range for buckets
mid_price = cob_snapshot.volume_weighted_mid mid_price = cob_snapshot.volume_weighted_mid
@ -359,15 +359,15 @@ class COBIntegration:
# Get actual Session Volume Profile (SVP) from trade data # Get actual Session Volume Profile (SVP) from trade data
svp_data = [] svp_data = []
if self.cob_provider: if self.cob_provider:
try: try:
svp_result = self.cob_provider.get_session_volume_profile(symbol, bucket_size) svp_result = self.cob_provider.get_session_volume_profile(symbol, bucket_size)
if svp_result and 'data' in svp_result: if svp_result and 'data' in svp_result:
svp_data = svp_result['data'] svp_data = svp_result['data']
logger.debug(f"Retrieved SVP data for {symbol}: {len(svp_data)} price levels") logger.debug(f"Retrieved SVP data for {symbol}: {len(svp_data)} price levels")
else: else:
logger.warning(f"No SVP data available for {symbol}") logger.warning(f"No SVP data available for {symbol}")
except Exception as e: except Exception as e:
logger.error(f"Error getting SVP data for {symbol}: {e}") logger.error(f"Error getting SVP data for {symbol}: {e}")
# Generate market stats # Generate market stats
stats = { stats = {
@ -405,18 +405,18 @@ class COBIntegration:
# Get additional real-time stats # Get additional real-time stats
realtime_stats = {} realtime_stats = {}
if self.cob_provider: if self.cob_provider:
try: try:
realtime_stats = self.cob_provider.get_realtime_stats(symbol) realtime_stats = self.cob_provider.get_realtime_stats(symbol)
if realtime_stats: if realtime_stats:
stats['realtime_1s'] = realtime_stats.get('1s_stats', {}) stats['realtime_1s'] = realtime_stats.get('1s_stats', {})
stats['realtime_5s'] = realtime_stats.get('5s_stats', {}) stats['realtime_5s'] = realtime_stats.get('5s_stats', {})
else: else:
stats['realtime_1s'] = {}
stats['realtime_5s'] = {}
except Exception as e:
logger.error(f"Error getting real-time stats for {symbol}: {e}")
stats['realtime_1s'] = {} stats['realtime_1s'] = {}
stats['realtime_5s'] = {} stats['realtime_5s'] = {}
except Exception as e:
logger.error(f"Error getting real-time stats for {symbol}: {e}")
stats['realtime_1s'] = {}
stats['realtime_5s'] = {}
return { return {
'type': 'cob_update', 'type': 'cob_update',
@ -487,9 +487,9 @@ class COBIntegration:
try: try:
for symbol in self.symbols: for symbol in self.symbols:
if self.cob_provider: if self.cob_provider:
cob_snapshot = self.cob_provider.get_consolidated_orderbook(symbol) cob_snapshot = self.cob_provider.get_consolidated_orderbook(symbol)
if cob_snapshot: if cob_snapshot:
await self._analyze_cob_patterns(symbol, cob_snapshot) await self._analyze_cob_patterns(symbol, cob_snapshot)
await asyncio.sleep(1) await asyncio.sleep(1)

View File

@ -1007,6 +1007,17 @@ class TradingOrchestrator:
if enhanced_features is not None: if enhanced_features is not None:
# Get CNN prediction - use the actual underlying model # Get CNN prediction - use the actual underlying model
try: try:
# Ensure features are properly shaped and limited
if isinstance(enhanced_features, np.ndarray):
# Flatten and limit features to prevent shape mismatches
enhanced_features = enhanced_features.flatten()
if len(enhanced_features) > 100: # Limit to 100 features
enhanced_features = enhanced_features[:100]
elif len(enhanced_features) < 100: # Pad with zeros
padded = np.zeros(100)
padded[:len(enhanced_features)] = enhanced_features
enhanced_features = padded
if hasattr(model.model, 'act'): if hasattr(model.model, 'act'):
# Use the CNN's act method # Use the CNN's act method
action_result = model.model.act(enhanced_features, explore=False) action_result = model.model.act(enhanced_features, explore=False)
@ -1020,7 +1031,7 @@ class TradingOrchestrator:
action_probs = [0.1, 0.1, 0.8] # Default distribution action_probs = [0.1, 0.1, 0.8] # Default distribution
action_probs[action_idx] = confidence action_probs[action_idx] = confidence
else: else:
# Fallback to generic predict method # Fallback to generic predict method
action_probs, confidence = model.predict(enhanced_features) action_probs, confidence = model.predict(enhanced_features)
except Exception as e: except Exception as e:
logger.warning(f"CNN prediction failed: {e}") logger.warning(f"CNN prediction failed: {e}")
@ -1138,6 +1149,17 @@ class TradingOrchestrator:
) )
if feature_matrix is not None: if feature_matrix is not None:
# Ensure feature_matrix is properly shaped and limited
if isinstance(feature_matrix, np.ndarray):
# Flatten and limit features to prevent shape mismatches
feature_matrix = feature_matrix.flatten()
if len(feature_matrix) > 2000: # Limit to 2000 features for generic models
feature_matrix = feature_matrix[:2000]
elif len(feature_matrix) < 2000: # Pad with zeros
padded = np.zeros(2000)
padded[:len(feature_matrix)] = feature_matrix
feature_matrix = padded
prediction_result = model.predict(feature_matrix) prediction_result = model.predict(feature_matrix)
# Handle different return formats from model.predict() # Handle different return formats from model.predict()
@ -1834,3 +1856,100 @@ class TradingOrchestrator:
"""Set the trading executor for position tracking""" """Set the trading executor for position tracking"""
self.trading_executor = trading_executor self.trading_executor = trading_executor
logger.info("Trading executor set for position tracking and P&L feedback") logger.info("Trading executor set for position tracking and P&L feedback")
def _get_current_price(self, symbol: str) -> float:
"""Get current price for symbol"""
try:
# Try to get from data provider
if self.data_provider:
try:
# Try different methods to get current price
if hasattr(self.data_provider, 'get_latest_data'):
latest_data = self.data_provider.get_latest_data(symbol)
if latest_data and 'price' in latest_data:
return float(latest_data['price'])
elif latest_data and 'close' in latest_data:
return float(latest_data['close'])
elif hasattr(self.data_provider, 'get_current_price'):
return float(self.data_provider.get_current_price(symbol))
elif hasattr(self.data_provider, 'get_latest_candle'):
latest_candle = self.data_provider.get_latest_candle(symbol, '1m')
if latest_candle and 'close' in latest_candle:
return float(latest_candle['close'])
except Exception as e:
logger.debug(f"Could not get price from data provider: {e}")
# Try to get from universal adapter
if self.universal_adapter:
try:
data_stream = self.universal_adapter.get_latest_data(symbol)
if data_stream and hasattr(data_stream, 'current_price'):
return float(data_stream.current_price)
except Exception as e:
logger.debug(f"Could not get price from universal adapter: {e}")
# Fallback to default prices
default_prices = {
'ETH/USDT': 2500.0,
'BTC/USDT': 108000.0
}
return default_prices.get(symbol, 1000.0)
except Exception as e:
logger.error(f"Error getting current price for {symbol}: {e}")
# Return default price based on symbol
if 'ETH' in symbol:
return 2500.0
elif 'BTC' in symbol:
return 108000.0
else:
return 1000.0
def _generate_fallback_prediction(self, symbol: str) -> Dict[str, Any]:
"""Generate fallback prediction when models fail"""
try:
return {
'action': 'HOLD',
'confidence': 0.5,
'price': self._get_current_price(symbol) or 2500.0,
'timestamp': datetime.now(),
'model': 'fallback'
}
except Exception as e:
logger.debug(f"Error generating fallback prediction: {e}")
return {
'action': 'HOLD',
'confidence': 0.5,
'price': 2500.0,
'timestamp': datetime.now(),
'model': 'fallback'
}
def capture_dqn_prediction(self, symbol: str, action_idx: int, confidence: float, price: float, q_values: List[float] = None):
"""Capture DQN prediction for dashboard visualization"""
try:
if symbol not in self.recent_dqn_predictions:
self.recent_dqn_predictions[symbol] = deque(maxlen=100)
prediction_data = {
'timestamp': datetime.now(),
'action': ['SELL', 'HOLD', 'BUY'][action_idx],
'confidence': confidence,
'price': price,
'q_values': q_values or [0.33, 0.33, 0.34]
}
self.recent_dqn_predictions[symbol].append(prediction_data)
except Exception as e:
logger.debug(f"Error capturing DQN prediction: {e}")
def capture_cnn_prediction(self, symbol: str, direction: int, confidence: float, current_price: float, predicted_price: float):
"""Capture CNN prediction for dashboard visualization"""
try:
if symbol not in self.recent_cnn_predictions:
self.recent_cnn_predictions[symbol] = deque(maxlen=50)
prediction_data = {
'timestamp': datetime.now(),
'direction': ['DOWN', 'SAME', 'UP'][direction],
'confidence': confidence,
'current_price': current_price,
'predicted_price': predicted_price
}
self.recent_cnn_predictions[symbol].append(prediction_data)
except Exception as e:
logger.debug(f"Error capturing CNN prediction: {e}")

View File

@ -1454,9 +1454,10 @@ class EnhancedRealtimeTrainingSystem:
model.train() model.train()
optimizer.zero_grad() optimizer.zero_grad()
# Convert numpy arrays to PyTorch tensors # Convert numpy arrays to PyTorch tensors and move to device
features_tensor = torch.from_numpy(features).float() device = next(model.parameters()).device
targets_tensor = torch.from_numpy(targets).long() features_tensor = torch.from_numpy(features).float().to(device)
targets_tensor = torch.from_numpy(targets).long().to(device)
# Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width) # Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width)
# Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20) # Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20)
@ -1471,7 +1472,21 @@ class EnhancedRealtimeTrainingSystem:
# If the CNN expects (batch_size, channels, sequence_length) # If the CNN expects (batch_size, channels, sequence_length)
# features_tensor = features_tensor.view(features_tensor.shape[0], 1, 15 * 20) # Example for 1D CNN # features_tensor = features_tensor.view(features_tensor.shape[0], 1, 15 * 20) # Example for 1D CNN
# Let's assume the CNN expects 2D input (batch_size, flattened_features) # Ensure proper shape for CNN input
if len(features_tensor.shape) == 2:
# If it's (batch_size, features), keep as is for 1D CNN
pass
elif len(features_tensor.shape) == 1:
# If it's (features), add batch dimension
features_tensor = features_tensor.unsqueeze(0)
else:
# Reshape to (batch_size, features) if needed
features_tensor = features_tensor.view(features_tensor.shape[0], -1)
# Limit input size to prevent shape mismatches
if features_tensor.shape[1] > 1000: # Limit to 1000 features
features_tensor = features_tensor[:, :1000]
outputs = model(features_tensor) outputs = model(features_tensor)
loss = criterion(outputs, targets_tensor) loss = criterion(outputs, targets_tensor)
@ -1857,12 +1872,17 @@ class EnhancedRealtimeTrainingSystem:
and self.orchestrator.rl_agent): and self.orchestrator.rl_agent):
# Get Q-values from model # Get Q-values from model
q_values = self.orchestrator.rl_agent.act(current_state, return_q_values=True) action = self.orchestrator.rl_agent.act(current_state, explore=False)
if isinstance(q_values, tuple): # Get Q-values separately if available
action, q_vals = q_values if hasattr(self.orchestrator.rl_agent, 'policy_net'):
q_values = q_vals.tolist() if hasattr(q_vals, 'tolist') else [0, 0, 0] with torch.no_grad():
state_tensor = torch.FloatTensor(current_state).unsqueeze(0).to(self.orchestrator.rl_agent.device)
q_values_tensor = self.orchestrator.rl_agent.policy_net(state_tensor)
if isinstance(q_values_tensor, tuple):
q_values = q_values_tensor[0].cpu().numpy()[0].tolist()
else:
q_values = q_values_tensor.cpu().numpy()[0].tolist()
else: else:
action = q_values
q_values = [0.33, 0.33, 0.34] # Default uniform distribution q_values = [0.33, 0.33, 0.34] # Default uniform distribution
confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33 confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33

View File

@ -286,11 +286,11 @@ class DashboardComponentManager:
if hasattr(cob_snapshot, 'stats'): if hasattr(cob_snapshot, 'stats'):
# Old format with stats attribute # Old format with stats attribute
stats = cob_snapshot.stats stats = cob_snapshot.stats
mid_price = stats.get('mid_price', 0) mid_price = stats.get('mid_price', 0)
spread_bps = stats.get('spread_bps', 0) spread_bps = stats.get('spread_bps', 0)
imbalance = stats.get('imbalance', 0) imbalance = stats.get('imbalance', 0)
bids = getattr(cob_snapshot, 'consolidated_bids', []) bids = getattr(cob_snapshot, 'consolidated_bids', [])
asks = getattr(cob_snapshot, 'consolidated_asks', []) asks = getattr(cob_snapshot, 'consolidated_asks', [])
else: else:
# New COBSnapshot format with direct attributes # New COBSnapshot format with direct attributes
mid_price = getattr(cob_snapshot, 'volume_weighted_mid', 0) mid_price = getattr(cob_snapshot, 'volume_weighted_mid', 0)
@ -421,10 +421,10 @@ class DashboardComponentManager:
volume_usd = order.total_volume_usd volume_usd = order.total_volume_usd
else: else:
# Dictionary format (legacy) # Dictionary format (legacy)
price = order.get('price', 0) price = order.get('price', 0)
# Handle both old format (size) and new format (total_size) # Handle both old format (size) and new format (total_size)
size = order.get('total_size', order.get('size', 0)) size = order.get('total_size', order.get('size', 0))
volume_usd = order.get('total_volume_usd', size * price) volume_usd = order.get('total_volume_usd', size * price)
if price > 0: if price > 0:
bucket_key = round(price / bucket_size) * bucket_size bucket_key = round(price / bucket_size) * bucket_size