fixes
This commit is contained in:
@ -229,8 +229,8 @@ class COBRLModelInterface(ModelInterface):
|
|||||||
Interface for the COB RL model that handles model management, training, and inference
|
Interface for the COB RL model that handles model management, training, and inference
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None):
|
def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None, name=None, **kwargs):
|
||||||
super().__init__(name="cob_rl_model") # Initialize ModelInterface with a name
|
super().__init__(name=name) # Initialize ModelInterface with a name
|
||||||
self.model_checkpoint_dir = model_checkpoint_dir
|
self.model_checkpoint_dir = model_checkpoint_dir
|
||||||
self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
|
self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
import random
|
import random
|
||||||
from typing import Tuple, List
|
from typing import Tuple, List
|
||||||
import osvu
|
import os
|
||||||
import sys
|
import sys
|
||||||
import logging
|
import logging
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -34,7 +34,7 @@ class COBIntegration:
|
|||||||
Integration layer for Multi-Exchange COB data with gogo2 trading system
|
Integration layer for Multi-Exchange COB data with gogo2 trading system
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None):
|
def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None, initial_data_limit=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Initialize COB Integration
|
Initialize COB Integration
|
||||||
|
|
||||||
@ -88,7 +88,7 @@ class COBIntegration:
|
|||||||
# Start COB provider streaming
|
# Start COB provider streaming
|
||||||
try:
|
try:
|
||||||
logger.info("Starting COB provider streaming...")
|
logger.info("Starting COB provider streaming...")
|
||||||
await self.cob_provider.start_streaming()
|
await self.cob_provider.start_streaming()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error starting COB provider streaming: {e}")
|
logger.error(f"Error starting COB provider streaming: {e}")
|
||||||
# Start a background task instead
|
# Start a background task instead
|
||||||
@ -112,7 +112,7 @@ class COBIntegration:
|
|||||||
"""Stop COB integration"""
|
"""Stop COB integration"""
|
||||||
logger.info("Stopping COB Integration")
|
logger.info("Stopping COB Integration")
|
||||||
if self.cob_provider:
|
if self.cob_provider:
|
||||||
await self.cob_provider.stop_streaming()
|
await self.cob_provider.stop_streaming()
|
||||||
logger.info("COB Integration stopped")
|
logger.info("COB Integration stopped")
|
||||||
|
|
||||||
def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
|
def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
|
||||||
@ -313,7 +313,7 @@ class COBIntegration:
|
|||||||
# Get fixed bucket size for the symbol
|
# Get fixed bucket size for the symbol
|
||||||
bucket_size = 1.0 # Default bucket size
|
bucket_size = 1.0 # Default bucket size
|
||||||
if self.cob_provider:
|
if self.cob_provider:
|
||||||
bucket_size = self.cob_provider.fixed_usd_buckets.get(symbol, 1.0)
|
bucket_size = self.cob_provider.fixed_usd_buckets.get(symbol, 1.0)
|
||||||
|
|
||||||
# Calculate price range for buckets
|
# Calculate price range for buckets
|
||||||
mid_price = cob_snapshot.volume_weighted_mid
|
mid_price = cob_snapshot.volume_weighted_mid
|
||||||
@ -359,15 +359,15 @@ class COBIntegration:
|
|||||||
# Get actual Session Volume Profile (SVP) from trade data
|
# Get actual Session Volume Profile (SVP) from trade data
|
||||||
svp_data = []
|
svp_data = []
|
||||||
if self.cob_provider:
|
if self.cob_provider:
|
||||||
try:
|
try:
|
||||||
svp_result = self.cob_provider.get_session_volume_profile(symbol, bucket_size)
|
svp_result = self.cob_provider.get_session_volume_profile(symbol, bucket_size)
|
||||||
if svp_result and 'data' in svp_result:
|
if svp_result and 'data' in svp_result:
|
||||||
svp_data = svp_result['data']
|
svp_data = svp_result['data']
|
||||||
logger.debug(f"Retrieved SVP data for {symbol}: {len(svp_data)} price levels")
|
logger.debug(f"Retrieved SVP data for {symbol}: {len(svp_data)} price levels")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"No SVP data available for {symbol}")
|
logger.warning(f"No SVP data available for {symbol}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting SVP data for {symbol}: {e}")
|
logger.error(f"Error getting SVP data for {symbol}: {e}")
|
||||||
|
|
||||||
# Generate market stats
|
# Generate market stats
|
||||||
stats = {
|
stats = {
|
||||||
@ -405,18 +405,18 @@ class COBIntegration:
|
|||||||
# Get additional real-time stats
|
# Get additional real-time stats
|
||||||
realtime_stats = {}
|
realtime_stats = {}
|
||||||
if self.cob_provider:
|
if self.cob_provider:
|
||||||
try:
|
try:
|
||||||
realtime_stats = self.cob_provider.get_realtime_stats(symbol)
|
realtime_stats = self.cob_provider.get_realtime_stats(symbol)
|
||||||
if realtime_stats:
|
if realtime_stats:
|
||||||
stats['realtime_1s'] = realtime_stats.get('1s_stats', {})
|
stats['realtime_1s'] = realtime_stats.get('1s_stats', {})
|
||||||
stats['realtime_5s'] = realtime_stats.get('5s_stats', {})
|
stats['realtime_5s'] = realtime_stats.get('5s_stats', {})
|
||||||
else:
|
else:
|
||||||
|
stats['realtime_1s'] = {}
|
||||||
|
stats['realtime_5s'] = {}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting real-time stats for {symbol}: {e}")
|
||||||
stats['realtime_1s'] = {}
|
stats['realtime_1s'] = {}
|
||||||
stats['realtime_5s'] = {}
|
stats['realtime_5s'] = {}
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting real-time stats for {symbol}: {e}")
|
|
||||||
stats['realtime_1s'] = {}
|
|
||||||
stats['realtime_5s'] = {}
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'type': 'cob_update',
|
'type': 'cob_update',
|
||||||
@ -487,9 +487,9 @@ class COBIntegration:
|
|||||||
try:
|
try:
|
||||||
for symbol in self.symbols:
|
for symbol in self.symbols:
|
||||||
if self.cob_provider:
|
if self.cob_provider:
|
||||||
cob_snapshot = self.cob_provider.get_consolidated_orderbook(symbol)
|
cob_snapshot = self.cob_provider.get_consolidated_orderbook(symbol)
|
||||||
if cob_snapshot:
|
if cob_snapshot:
|
||||||
await self._analyze_cob_patterns(symbol, cob_snapshot)
|
await self._analyze_cob_patterns(symbol, cob_snapshot)
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
@ -1007,6 +1007,17 @@ class TradingOrchestrator:
|
|||||||
if enhanced_features is not None:
|
if enhanced_features is not None:
|
||||||
# Get CNN prediction - use the actual underlying model
|
# Get CNN prediction - use the actual underlying model
|
||||||
try:
|
try:
|
||||||
|
# Ensure features are properly shaped and limited
|
||||||
|
if isinstance(enhanced_features, np.ndarray):
|
||||||
|
# Flatten and limit features to prevent shape mismatches
|
||||||
|
enhanced_features = enhanced_features.flatten()
|
||||||
|
if len(enhanced_features) > 100: # Limit to 100 features
|
||||||
|
enhanced_features = enhanced_features[:100]
|
||||||
|
elif len(enhanced_features) < 100: # Pad with zeros
|
||||||
|
padded = np.zeros(100)
|
||||||
|
padded[:len(enhanced_features)] = enhanced_features
|
||||||
|
enhanced_features = padded
|
||||||
|
|
||||||
if hasattr(model.model, 'act'):
|
if hasattr(model.model, 'act'):
|
||||||
# Use the CNN's act method
|
# Use the CNN's act method
|
||||||
action_result = model.model.act(enhanced_features, explore=False)
|
action_result = model.model.act(enhanced_features, explore=False)
|
||||||
@ -1020,7 +1031,7 @@ class TradingOrchestrator:
|
|||||||
action_probs = [0.1, 0.1, 0.8] # Default distribution
|
action_probs = [0.1, 0.1, 0.8] # Default distribution
|
||||||
action_probs[action_idx] = confidence
|
action_probs[action_idx] = confidence
|
||||||
else:
|
else:
|
||||||
# Fallback to generic predict method
|
# Fallback to generic predict method
|
||||||
action_probs, confidence = model.predict(enhanced_features)
|
action_probs, confidence = model.predict(enhanced_features)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"CNN prediction failed: {e}")
|
logger.warning(f"CNN prediction failed: {e}")
|
||||||
@ -1138,6 +1149,17 @@ class TradingOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if feature_matrix is not None:
|
if feature_matrix is not None:
|
||||||
|
# Ensure feature_matrix is properly shaped and limited
|
||||||
|
if isinstance(feature_matrix, np.ndarray):
|
||||||
|
# Flatten and limit features to prevent shape mismatches
|
||||||
|
feature_matrix = feature_matrix.flatten()
|
||||||
|
if len(feature_matrix) > 2000: # Limit to 2000 features for generic models
|
||||||
|
feature_matrix = feature_matrix[:2000]
|
||||||
|
elif len(feature_matrix) < 2000: # Pad with zeros
|
||||||
|
padded = np.zeros(2000)
|
||||||
|
padded[:len(feature_matrix)] = feature_matrix
|
||||||
|
feature_matrix = padded
|
||||||
|
|
||||||
prediction_result = model.predict(feature_matrix)
|
prediction_result = model.predict(feature_matrix)
|
||||||
|
|
||||||
# Handle different return formats from model.predict()
|
# Handle different return formats from model.predict()
|
||||||
@ -1834,3 +1856,100 @@ class TradingOrchestrator:
|
|||||||
"""Set the trading executor for position tracking"""
|
"""Set the trading executor for position tracking"""
|
||||||
self.trading_executor = trading_executor
|
self.trading_executor = trading_executor
|
||||||
logger.info("Trading executor set for position tracking and P&L feedback")
|
logger.info("Trading executor set for position tracking and P&L feedback")
|
||||||
|
|
||||||
|
def _get_current_price(self, symbol: str) -> float:
|
||||||
|
"""Get current price for symbol"""
|
||||||
|
try:
|
||||||
|
# Try to get from data provider
|
||||||
|
if self.data_provider:
|
||||||
|
try:
|
||||||
|
# Try different methods to get current price
|
||||||
|
if hasattr(self.data_provider, 'get_latest_data'):
|
||||||
|
latest_data = self.data_provider.get_latest_data(symbol)
|
||||||
|
if latest_data and 'price' in latest_data:
|
||||||
|
return float(latest_data['price'])
|
||||||
|
elif latest_data and 'close' in latest_data:
|
||||||
|
return float(latest_data['close'])
|
||||||
|
elif hasattr(self.data_provider, 'get_current_price'):
|
||||||
|
return float(self.data_provider.get_current_price(symbol))
|
||||||
|
elif hasattr(self.data_provider, 'get_latest_candle'):
|
||||||
|
latest_candle = self.data_provider.get_latest_candle(symbol, '1m')
|
||||||
|
if latest_candle and 'close' in latest_candle:
|
||||||
|
return float(latest_candle['close'])
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not get price from data provider: {e}")
|
||||||
|
# Try to get from universal adapter
|
||||||
|
if self.universal_adapter:
|
||||||
|
try:
|
||||||
|
data_stream = self.universal_adapter.get_latest_data(symbol)
|
||||||
|
if data_stream and hasattr(data_stream, 'current_price'):
|
||||||
|
return float(data_stream.current_price)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not get price from universal adapter: {e}")
|
||||||
|
# Fallback to default prices
|
||||||
|
default_prices = {
|
||||||
|
'ETH/USDT': 2500.0,
|
||||||
|
'BTC/USDT': 108000.0
|
||||||
|
}
|
||||||
|
return default_prices.get(symbol, 1000.0)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting current price for {symbol}: {e}")
|
||||||
|
# Return default price based on symbol
|
||||||
|
if 'ETH' in symbol:
|
||||||
|
return 2500.0
|
||||||
|
elif 'BTC' in symbol:
|
||||||
|
return 108000.0
|
||||||
|
else:
|
||||||
|
return 1000.0
|
||||||
|
|
||||||
|
def _generate_fallback_prediction(self, symbol: str) -> Dict[str, Any]:
|
||||||
|
"""Generate fallback prediction when models fail"""
|
||||||
|
try:
|
||||||
|
return {
|
||||||
|
'action': 'HOLD',
|
||||||
|
'confidence': 0.5,
|
||||||
|
'price': self._get_current_price(symbol) or 2500.0,
|
||||||
|
'timestamp': datetime.now(),
|
||||||
|
'model': 'fallback'
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error generating fallback prediction: {e}")
|
||||||
|
return {
|
||||||
|
'action': 'HOLD',
|
||||||
|
'confidence': 0.5,
|
||||||
|
'price': 2500.0,
|
||||||
|
'timestamp': datetime.now(),
|
||||||
|
'model': 'fallback'
|
||||||
|
}
|
||||||
|
|
||||||
|
def capture_dqn_prediction(self, symbol: str, action_idx: int, confidence: float, price: float, q_values: List[float] = None):
|
||||||
|
"""Capture DQN prediction for dashboard visualization"""
|
||||||
|
try:
|
||||||
|
if symbol not in self.recent_dqn_predictions:
|
||||||
|
self.recent_dqn_predictions[symbol] = deque(maxlen=100)
|
||||||
|
prediction_data = {
|
||||||
|
'timestamp': datetime.now(),
|
||||||
|
'action': ['SELL', 'HOLD', 'BUY'][action_idx],
|
||||||
|
'confidence': confidence,
|
||||||
|
'price': price,
|
||||||
|
'q_values': q_values or [0.33, 0.33, 0.34]
|
||||||
|
}
|
||||||
|
self.recent_dqn_predictions[symbol].append(prediction_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error capturing DQN prediction: {e}")
|
||||||
|
|
||||||
|
def capture_cnn_prediction(self, symbol: str, direction: int, confidence: float, current_price: float, predicted_price: float):
|
||||||
|
"""Capture CNN prediction for dashboard visualization"""
|
||||||
|
try:
|
||||||
|
if symbol not in self.recent_cnn_predictions:
|
||||||
|
self.recent_cnn_predictions[symbol] = deque(maxlen=50)
|
||||||
|
prediction_data = {
|
||||||
|
'timestamp': datetime.now(),
|
||||||
|
'direction': ['DOWN', 'SAME', 'UP'][direction],
|
||||||
|
'confidence': confidence,
|
||||||
|
'current_price': current_price,
|
||||||
|
'predicted_price': predicted_price
|
||||||
|
}
|
||||||
|
self.recent_cnn_predictions[symbol].append(prediction_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error capturing CNN prediction: {e}")
|
@ -1454,9 +1454,10 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
model.train()
|
model.train()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# Convert numpy arrays to PyTorch tensors
|
# Convert numpy arrays to PyTorch tensors and move to device
|
||||||
features_tensor = torch.from_numpy(features).float()
|
device = next(model.parameters()).device
|
||||||
targets_tensor = torch.from_numpy(targets).long()
|
features_tensor = torch.from_numpy(features).float().to(device)
|
||||||
|
targets_tensor = torch.from_numpy(targets).long().to(device)
|
||||||
|
|
||||||
# Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width)
|
# Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width)
|
||||||
# Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20)
|
# Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20)
|
||||||
@ -1471,7 +1472,21 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
# If the CNN expects (batch_size, channels, sequence_length)
|
# If the CNN expects (batch_size, channels, sequence_length)
|
||||||
# features_tensor = features_tensor.view(features_tensor.shape[0], 1, 15 * 20) # Example for 1D CNN
|
# features_tensor = features_tensor.view(features_tensor.shape[0], 1, 15 * 20) # Example for 1D CNN
|
||||||
|
|
||||||
# Let's assume the CNN expects 2D input (batch_size, flattened_features)
|
# Ensure proper shape for CNN input
|
||||||
|
if len(features_tensor.shape) == 2:
|
||||||
|
# If it's (batch_size, features), keep as is for 1D CNN
|
||||||
|
pass
|
||||||
|
elif len(features_tensor.shape) == 1:
|
||||||
|
# If it's (features), add batch dimension
|
||||||
|
features_tensor = features_tensor.unsqueeze(0)
|
||||||
|
else:
|
||||||
|
# Reshape to (batch_size, features) if needed
|
||||||
|
features_tensor = features_tensor.view(features_tensor.shape[0], -1)
|
||||||
|
|
||||||
|
# Limit input size to prevent shape mismatches
|
||||||
|
if features_tensor.shape[1] > 1000: # Limit to 1000 features
|
||||||
|
features_tensor = features_tensor[:, :1000]
|
||||||
|
|
||||||
outputs = model(features_tensor)
|
outputs = model(features_tensor)
|
||||||
|
|
||||||
loss = criterion(outputs, targets_tensor)
|
loss = criterion(outputs, targets_tensor)
|
||||||
@ -1857,12 +1872,17 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
and self.orchestrator.rl_agent):
|
and self.orchestrator.rl_agent):
|
||||||
|
|
||||||
# Get Q-values from model
|
# Get Q-values from model
|
||||||
q_values = self.orchestrator.rl_agent.act(current_state, return_q_values=True)
|
action = self.orchestrator.rl_agent.act(current_state, explore=False)
|
||||||
if isinstance(q_values, tuple):
|
# Get Q-values separately if available
|
||||||
action, q_vals = q_values
|
if hasattr(self.orchestrator.rl_agent, 'policy_net'):
|
||||||
q_values = q_vals.tolist() if hasattr(q_vals, 'tolist') else [0, 0, 0]
|
with torch.no_grad():
|
||||||
|
state_tensor = torch.FloatTensor(current_state).unsqueeze(0).to(self.orchestrator.rl_agent.device)
|
||||||
|
q_values_tensor = self.orchestrator.rl_agent.policy_net(state_tensor)
|
||||||
|
if isinstance(q_values_tensor, tuple):
|
||||||
|
q_values = q_values_tensor[0].cpu().numpy()[0].tolist()
|
||||||
|
else:
|
||||||
|
q_values = q_values_tensor.cpu().numpy()[0].tolist()
|
||||||
else:
|
else:
|
||||||
action = q_values
|
|
||||||
q_values = [0.33, 0.33, 0.34] # Default uniform distribution
|
q_values = [0.33, 0.33, 0.34] # Default uniform distribution
|
||||||
|
|
||||||
confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33
|
confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33
|
||||||
|
@ -286,11 +286,11 @@ class DashboardComponentManager:
|
|||||||
if hasattr(cob_snapshot, 'stats'):
|
if hasattr(cob_snapshot, 'stats'):
|
||||||
# Old format with stats attribute
|
# Old format with stats attribute
|
||||||
stats = cob_snapshot.stats
|
stats = cob_snapshot.stats
|
||||||
mid_price = stats.get('mid_price', 0)
|
mid_price = stats.get('mid_price', 0)
|
||||||
spread_bps = stats.get('spread_bps', 0)
|
spread_bps = stats.get('spread_bps', 0)
|
||||||
imbalance = stats.get('imbalance', 0)
|
imbalance = stats.get('imbalance', 0)
|
||||||
bids = getattr(cob_snapshot, 'consolidated_bids', [])
|
bids = getattr(cob_snapshot, 'consolidated_bids', [])
|
||||||
asks = getattr(cob_snapshot, 'consolidated_asks', [])
|
asks = getattr(cob_snapshot, 'consolidated_asks', [])
|
||||||
else:
|
else:
|
||||||
# New COBSnapshot format with direct attributes
|
# New COBSnapshot format with direct attributes
|
||||||
mid_price = getattr(cob_snapshot, 'volume_weighted_mid', 0)
|
mid_price = getattr(cob_snapshot, 'volume_weighted_mid', 0)
|
||||||
@ -421,10 +421,10 @@ class DashboardComponentManager:
|
|||||||
volume_usd = order.total_volume_usd
|
volume_usd = order.total_volume_usd
|
||||||
else:
|
else:
|
||||||
# Dictionary format (legacy)
|
# Dictionary format (legacy)
|
||||||
price = order.get('price', 0)
|
price = order.get('price', 0)
|
||||||
# Handle both old format (size) and new format (total_size)
|
# Handle both old format (size) and new format (total_size)
|
||||||
size = order.get('total_size', order.get('size', 0))
|
size = order.get('total_size', order.get('size', 0))
|
||||||
volume_usd = order.get('total_volume_usd', size * price)
|
volume_usd = order.get('total_volume_usd', size * price)
|
||||||
|
|
||||||
if price > 0:
|
if price > 0:
|
||||||
bucket_key = round(price / bucket_size) * bucket_size
|
bucket_key = round(price / bucket_size) * bucket_size
|
||||||
|
Reference in New Issue
Block a user