commit 9cd2d5d8a4 (parent 2d8f763eeb)
Author: Dobromir Popov
Date: 2025-07-07 23:39:12 +03:00

6 changed files with 188 additions and 49 deletions

View File

@@ -229,8 +229,8 @@ class COBRLModelInterface(ModelInterface):
     Interface for the COB RL model that handles model management, training, and inference
     """
-    def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None):
-        super().__init__(name="cob_rl_model")  # Initialize ModelInterface with a name
+    def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob", device: str = None, name=None, **kwargs):
+        super().__init__(name=name)  # Initialize ModelInterface with a name
         self.model_checkpoint_dir = model_checkpoint_dir
         self.device = torch.device(device if device else ('cuda' if torch.cuda.is_available() else 'cpu'))
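Note: the widened constructor above lets the orchestrator build every model interface through one uniform call signature, because an explicit name is forwarded to the base class and any unknown keyword arguments are absorbed by **kwargs. A minimal sketch of that pattern, using a simplified stand-in for the repo's ModelInterface base class (names and the extra keyword here are illustrative, not the actual implementation):

class ModelInterface:                      # stand-in base class, simplified for the sketch
    def __init__(self, name: str):
        self.name = name

class COBRLModelInterfaceSketch(ModelInterface):
    # Accepting name plus **kwargs means callers can pass a uniform set of
    # constructor arguments without this class breaking on ones it ignores.
    def __init__(self, model_checkpoint_dir: str = "models/realtime_rl_cob",
                 device: str = None, name=None, **kwargs):
        super().__init__(name=name)
        self.model_checkpoint_dir = model_checkpoint_dir

iface = COBRLModelInterfaceSketch(name="cob_rl_model", some_future_option=True)  # extra kwarg is swallowed
print(iface.name)  # -> cob_rl_model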

View File

@@ -5,7 +5,7 @@ import numpy as np
 from collections import deque
 import random
 from typing import Tuple, List
-import osvu
+import os
 import sys
 import logging
 import torch.nn.functional as F

View File

@@ -34,7 +34,7 @@ class COBIntegration:
     Integration layer for Multi-Exchange COB data with gogo2 trading system
     """
-    def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None):
+    def __init__(self, data_provider: Optional[DataProvider] = None, symbols: Optional[List[str]] = None, initial_data_limit=None, **kwargs):
        """
        Initialize COB Integration

View File

@@ -1007,6 +1007,17 @@ class TradingOrchestrator:
             if enhanced_features is not None:
                 # Get CNN prediction - use the actual underlying model
                 try:
+                    # Ensure features are properly shaped and limited
+                    if isinstance(enhanced_features, np.ndarray):
+                        # Flatten and limit features to prevent shape mismatches
+                        enhanced_features = enhanced_features.flatten()
+                        if len(enhanced_features) > 100:  # Limit to 100 features
+                            enhanced_features = enhanced_features[:100]
+                        elif len(enhanced_features) < 100:  # Pad with zeros
+                            padded = np.zeros(100)
+                            padded[:len(enhanced_features)] = enhanced_features
+                            enhanced_features = padded
                     if hasattr(model.model, 'act'):
                         # Use the CNN's act method
                         action_result = model.model.act(enhanced_features, explore=False)
@@ -1138,6 +1149,17 @@ class TradingOrchestrator:
                     )
                     if feature_matrix is not None:
+                        # Ensure feature_matrix is properly shaped and limited
+                        if isinstance(feature_matrix, np.ndarray):
+                            # Flatten and limit features to prevent shape mismatches
+                            feature_matrix = feature_matrix.flatten()
+                            if len(feature_matrix) > 2000:  # Limit to 2000 features for generic models
+                                feature_matrix = feature_matrix[:2000]
+                            elif len(feature_matrix) < 2000:  # Pad with zeros
+                                padded = np.zeros(2000)
+                                padded[:len(feature_matrix)] = feature_matrix
+                                feature_matrix = padded
                         prediction_result = model.predict(feature_matrix)
                         # Handle different return formats from model.predict()
@@ -1834,3 +1856,100 @@ class TradingOrchestrator:
         """Set the trading executor for position tracking"""
         self.trading_executor = trading_executor
         logger.info("Trading executor set for position tracking and P&L feedback")
+
+    def _get_current_price(self, symbol: str) -> float:
+        """Get current price for symbol"""
+        try:
+            # Try to get from data provider
+            if self.data_provider:
+                try:
+                    # Try different methods to get current price
+                    if hasattr(self.data_provider, 'get_latest_data'):
+                        latest_data = self.data_provider.get_latest_data(symbol)
+                        if latest_data and 'price' in latest_data:
+                            return float(latest_data['price'])
+                        elif latest_data and 'close' in latest_data:
+                            return float(latest_data['close'])
+                    elif hasattr(self.data_provider, 'get_current_price'):
+                        return float(self.data_provider.get_current_price(symbol))
+                    elif hasattr(self.data_provider, 'get_latest_candle'):
+                        latest_candle = self.data_provider.get_latest_candle(symbol, '1m')
+                        if latest_candle and 'close' in latest_candle:
+                            return float(latest_candle['close'])
+                except Exception as e:
+                    logger.debug(f"Could not get price from data provider: {e}")
+
+            # Try to get from universal adapter
+            if self.universal_adapter:
+                try:
+                    data_stream = self.universal_adapter.get_latest_data(symbol)
+                    if data_stream and hasattr(data_stream, 'current_price'):
+                        return float(data_stream.current_price)
+                except Exception as e:
+                    logger.debug(f"Could not get price from universal adapter: {e}")
+
+            # Fallback to default prices
+            default_prices = {
+                'ETH/USDT': 2500.0,
+                'BTC/USDT': 108000.0
+            }
+            return default_prices.get(symbol, 1000.0)
+
+        except Exception as e:
+            logger.error(f"Error getting current price for {symbol}: {e}")
+            # Return default price based on symbol
+            if 'ETH' in symbol:
+                return 2500.0
+            elif 'BTC' in symbol:
+                return 108000.0
+            else:
+                return 1000.0
+
+    def _generate_fallback_prediction(self, symbol: str) -> Dict[str, Any]:
+        """Generate fallback prediction when models fail"""
+        try:
+            return {
+                'action': 'HOLD',
+                'confidence': 0.5,
+                'price': self._get_current_price(symbol) or 2500.0,
+                'timestamp': datetime.now(),
+                'model': 'fallback'
+            }
+        except Exception as e:
+            logger.debug(f"Error generating fallback prediction: {e}")
+            return {
+                'action': 'HOLD',
+                'confidence': 0.5,
+                'price': 2500.0,
+                'timestamp': datetime.now(),
+                'model': 'fallback'
+            }
+
+    def capture_dqn_prediction(self, symbol: str, action_idx: int, confidence: float, price: float, q_values: List[float] = None):
+        """Capture DQN prediction for dashboard visualization"""
+        try:
+            if symbol not in self.recent_dqn_predictions:
+                self.recent_dqn_predictions[symbol] = deque(maxlen=100)
+            prediction_data = {
+                'timestamp': datetime.now(),
+                'action': ['SELL', 'HOLD', 'BUY'][action_idx],
+                'confidence': confidence,
+                'price': price,
+                'q_values': q_values or [0.33, 0.33, 0.34]
+            }
+            self.recent_dqn_predictions[symbol].append(prediction_data)
+        except Exception as e:
+            logger.debug(f"Error capturing DQN prediction: {e}")
+
+    def capture_cnn_prediction(self, symbol: str, direction: int, confidence: float, current_price: float, predicted_price: float):
+        """Capture CNN prediction for dashboard visualization"""
+        try:
+            if symbol not in self.recent_cnn_predictions:
+                self.recent_cnn_predictions[symbol] = deque(maxlen=50)
+            prediction_data = {
+                'timestamp': datetime.now(),
+                'direction': ['DOWN', 'SAME', 'UP'][direction],
+                'confidence': confidence,
+                'current_price': current_price,
+                'predicted_price': predicted_price
+            }
+            self.recent_cnn_predictions[symbol].append(prediction_data)
+        except Exception as e:
+            logger.debug(f"Error capturing CNN prediction: {e}")

View File

@@ -1454,9 +1454,10 @@ class EnhancedRealtimeTrainingSystem:
             model.train()
             optimizer.zero_grad()

-            # Convert numpy arrays to PyTorch tensors
-            features_tensor = torch.from_numpy(features).float()
-            targets_tensor = torch.from_numpy(targets).long()
+            # Convert numpy arrays to PyTorch tensors and move to device
+            device = next(model.parameters()).device
+            features_tensor = torch.from_numpy(features).float().to(device)
+            targets_tensor = torch.from_numpy(targets).long().to(device)

             # Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width)
             # Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20)
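The training fix above reads the model's device from its first parameter and moves both tensors there before the forward pass, which avoids the usual "expected all tensors to be on the same device" error when the CNN lives on the GPU. A minimal sketch of the same idea with a stand-in nn.Linear model (not the project's CNN):

import numpy as np
import torch
import torch.nn as nn

model = nn.Linear(300, 3)                                 # stand-in for the CNN
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

device = next(model.parameters()).device                  # device of the model's weights
features = np.random.rand(8, 300).astype(np.float32)
features_tensor = torch.from_numpy(features).to(device)   # inputs must live on the same device
outputs = model(features_tensor)                          # no CPU/GPU mismatch at the matmul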
@@ -1471,7 +1472,21 @@ class EnhancedRealtimeTrainingSystem:
             # If the CNN expects (batch_size, channels, sequence_length)
             # features_tensor = features_tensor.view(features_tensor.shape[0], 1, 15 * 20)  # Example for 1D CNN

-            # Let's assume the CNN expects 2D input (batch_size, flattened_features)
+            # Ensure proper shape for CNN input
+            if len(features_tensor.shape) == 2:
+                # If it's (batch_size, features), keep as is for 1D CNN
+                pass
+            elif len(features_tensor.shape) == 1:
+                # If it's (features), add batch dimension
+                features_tensor = features_tensor.unsqueeze(0)
+            else:
+                # Reshape to (batch_size, features) if needed
+                features_tensor = features_tensor.view(features_tensor.shape[0], -1)
+
+            # Limit input size to prevent shape mismatches
+            if features_tensor.shape[1] > 1000:  # Limit to 1000 features
+                features_tensor = features_tensor[:, :1000]

             outputs = model(features_tensor)
             loss = criterion(outputs, targets_tensor)
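The shape handling inlined above (add a batch dimension for 1-D inputs, flatten anything higher-dimensional to (batch, features), then cap the feature count) can also be expressed as a small reusable helper. This is a sketch with a hypothetical function name, not code from the repo:

import torch

def to_batch_2d(x: torch.Tensor, max_features: int = 1000) -> torch.Tensor:
    """Coerce a tensor to (batch_size, features) and cap the feature count."""
    if x.dim() == 1:
        x = x.unsqueeze(0)                 # single sample -> add batch dimension
    elif x.dim() > 2:
        x = x.view(x.shape[0], -1)         # collapse trailing dims into one feature axis
    return x[:, :max_features]             # truncate extra feature columns, if any

assert to_batch_2d(torch.rand(15, 20).flatten()).shape == (1, 300)
assert to_batch_2d(torch.rand(4, 1, 15, 20)).shape == (4, 300)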
@@ -1857,12 +1872,17 @@ class EnhancedRealtimeTrainingSystem:
                 and self.orchestrator.rl_agent):

                 # Get Q-values from model
-                q_values = self.orchestrator.rl_agent.act(current_state, return_q_values=True)
-                if isinstance(q_values, tuple):
-                    action, q_vals = q_values
-                    q_values = q_vals.tolist() if hasattr(q_vals, 'tolist') else [0, 0, 0]
+                action = self.orchestrator.rl_agent.act(current_state, explore=False)
+                # Get Q-values separately if available
+                if hasattr(self.orchestrator.rl_agent, 'policy_net'):
+                    with torch.no_grad():
+                        state_tensor = torch.FloatTensor(current_state).unsqueeze(0).to(self.orchestrator.rl_agent.device)
+                        q_values_tensor = self.orchestrator.rl_agent.policy_net(state_tensor)
+                        if isinstance(q_values_tensor, tuple):
+                            q_values = q_values_tensor[0].cpu().numpy()[0].tolist()
+                        else:
+                            q_values = q_values_tensor.cpu().numpy()[0].tolist()
                 else:
-                    action = q_values
                     q_values = [0.33, 0.33, 0.34]  # Default uniform distribution

                 confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33
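The rewritten block queries the agent's greedy action and then reads raw Q-values directly from its policy network under torch.no_grad(), instead of relying on act() to return them. A self-contained sketch of that extraction with a stand-in policy network (the agent's policy_net and device attributes are taken from the diff, not verified elsewhere); the max/sum confidence heuristic mirrors the diff's formula and implicitly assumes non-negative Q-values:

import numpy as np
import torch
import torch.nn as nn

policy_net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))  # stand-in for the agent's policy_net
current_state = np.random.rand(10).astype(np.float32)

with torch.no_grad():                                               # inference only, no autograd graph
    state_tensor = torch.FloatTensor(current_state).unsqueeze(0)    # (1, state_dim)
    q_values = policy_net(state_tensor).cpu().numpy()[0].tolist()   # per-action Q-values

action = int(np.argmax(q_values))                                   # greedy action over [SELL, HOLD, BUY]
confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33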