training wip
This commit is contained in:
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
import random
|
import random
|
||||||
from typing import Tuple, List
|
from typing import Tuple, List
|
||||||
import osvu
|
import os
|
||||||
import sys
|
import sys
|
||||||
import logging
|
import logging
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -216,12 +216,12 @@ class DQNAgent:
|
|||||||
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
|
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
|
||||||
|
|
||||||
# Check if mixed precision training should be used
|
# Check if mixed precision training should be used
|
||||||
self.use_mixed_precision = False
|
|
||||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
|
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
|
||||||
self.use_mixed_precision = True
|
self.use_mixed_precision = True
|
||||||
self.scaler = torch.cuda.amp.GradScaler()
|
self.scaler = torch.cuda.amp.GradScaler()
|
||||||
logger.info("Mixed precision training enabled")
|
logger.info("Mixed precision training enabled")
|
||||||
else:
|
else:
|
||||||
|
self.use_mixed_precision = False
|
||||||
logger.info("Mixed precision training disabled")
|
logger.info("Mixed precision training disabled")
|
||||||
|
|
||||||
# Track if we're in training mode
|
# Track if we're in training mode
|
||||||
@ -405,12 +405,12 @@ class DQNAgent:
|
|||||||
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
|
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
|
||||||
|
|
||||||
# Check if mixed precision training should be used
|
# Check if mixed precision training should be used
|
||||||
self.use_mixed_precision = False
|
|
||||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
|
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
|
||||||
self.use_mixed_precision = True
|
self.use_mixed_precision = True
|
||||||
self.scaler = torch.cuda.amp.GradScaler()
|
self.scaler = torch.cuda.amp.GradScaler()
|
||||||
logger.info("Mixed precision training enabled")
|
logger.info("Mixed precision training enabled")
|
||||||
else:
|
else:
|
||||||
|
self.use_mixed_precision = False
|
||||||
logger.info("Mixed precision training disabled")
|
logger.info("Mixed precision training disabled")
|
||||||
|
|
||||||
# Track if we're in training mode
|
# Track if we're in training mode
|
||||||
|
@ -88,7 +88,7 @@ class COBIntegration:
|
|||||||
# Start COB provider streaming
|
# Start COB provider streaming
|
||||||
try:
|
try:
|
||||||
logger.info("Starting COB provider streaming...")
|
logger.info("Starting COB provider streaming...")
|
||||||
await self.cob_provider.start_streaming()
|
await self.cob_provider.start_streaming()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error starting COB provider streaming: {e}")
|
logger.error(f"Error starting COB provider streaming: {e}")
|
||||||
# Start a background task instead
|
# Start a background task instead
|
||||||
@ -112,7 +112,7 @@ class COBIntegration:
|
|||||||
"""Stop COB integration"""
|
"""Stop COB integration"""
|
||||||
logger.info("Stopping COB Integration")
|
logger.info("Stopping COB Integration")
|
||||||
if self.cob_provider:
|
if self.cob_provider:
|
||||||
await self.cob_provider.stop_streaming()
|
await self.cob_provider.stop_streaming()
|
||||||
logger.info("COB Integration stopped")
|
logger.info("COB Integration stopped")
|
||||||
|
|
||||||
def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
|
def add_cnn_callback(self, callback: Callable[[str, Dict], None]):
|
||||||
@ -313,7 +313,7 @@ class COBIntegration:
|
|||||||
# Get fixed bucket size for the symbol
|
# Get fixed bucket size for the symbol
|
||||||
bucket_size = 1.0 # Default bucket size
|
bucket_size = 1.0 # Default bucket size
|
||||||
if self.cob_provider:
|
if self.cob_provider:
|
||||||
bucket_size = self.cob_provider.fixed_usd_buckets.get(symbol, 1.0)
|
bucket_size = self.cob_provider.fixed_usd_buckets.get(symbol, 1.0)
|
||||||
|
|
||||||
# Calculate price range for buckets
|
# Calculate price range for buckets
|
||||||
mid_price = cob_snapshot.volume_weighted_mid
|
mid_price = cob_snapshot.volume_weighted_mid
|
||||||
@ -359,15 +359,15 @@ class COBIntegration:
|
|||||||
# Get actual Session Volume Profile (SVP) from trade data
|
# Get actual Session Volume Profile (SVP) from trade data
|
||||||
svp_data = []
|
svp_data = []
|
||||||
if self.cob_provider:
|
if self.cob_provider:
|
||||||
try:
|
try:
|
||||||
svp_result = self.cob_provider.get_session_volume_profile(symbol, bucket_size)
|
svp_result = self.cob_provider.get_session_volume_profile(symbol, bucket_size)
|
||||||
if svp_result and 'data' in svp_result:
|
if svp_result and 'data' in svp_result:
|
||||||
svp_data = svp_result['data']
|
svp_data = svp_result['data']
|
||||||
logger.debug(f"Retrieved SVP data for {symbol}: {len(svp_data)} price levels")
|
logger.debug(f"Retrieved SVP data for {symbol}: {len(svp_data)} price levels")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"No SVP data available for {symbol}")
|
logger.warning(f"No SVP data available for {symbol}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting SVP data for {symbol}: {e}")
|
logger.error(f"Error getting SVP data for {symbol}: {e}")
|
||||||
|
|
||||||
# Generate market stats
|
# Generate market stats
|
||||||
stats = {
|
stats = {
|
||||||
@ -405,18 +405,18 @@ class COBIntegration:
|
|||||||
# Get additional real-time stats
|
# Get additional real-time stats
|
||||||
realtime_stats = {}
|
realtime_stats = {}
|
||||||
if self.cob_provider:
|
if self.cob_provider:
|
||||||
try:
|
try:
|
||||||
realtime_stats = self.cob_provider.get_realtime_stats(symbol)
|
realtime_stats = self.cob_provider.get_realtime_stats(symbol)
|
||||||
if realtime_stats:
|
if realtime_stats:
|
||||||
stats['realtime_1s'] = realtime_stats.get('1s_stats', {})
|
stats['realtime_1s'] = realtime_stats.get('1s_stats', {})
|
||||||
stats['realtime_5s'] = realtime_stats.get('5s_stats', {})
|
stats['realtime_5s'] = realtime_stats.get('5s_stats', {})
|
||||||
else:
|
else:
|
||||||
|
stats['realtime_1s'] = {}
|
||||||
|
stats['realtime_5s'] = {}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting real-time stats for {symbol}: {e}")
|
||||||
stats['realtime_1s'] = {}
|
stats['realtime_1s'] = {}
|
||||||
stats['realtime_5s'] = {}
|
stats['realtime_5s'] = {}
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting real-time stats for {symbol}: {e}")
|
|
||||||
stats['realtime_1s'] = {}
|
|
||||||
stats['realtime_5s'] = {}
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'type': 'cob_update',
|
'type': 'cob_update',
|
||||||
@ -487,9 +487,9 @@ class COBIntegration:
|
|||||||
try:
|
try:
|
||||||
for symbol in self.symbols:
|
for symbol in self.symbols:
|
||||||
if self.cob_provider:
|
if self.cob_provider:
|
||||||
cob_snapshot = self.cob_provider.get_consolidated_orderbook(symbol)
|
cob_snapshot = self.cob_provider.get_consolidated_orderbook(symbol)
|
||||||
if cob_snapshot:
|
if cob_snapshot:
|
||||||
await self._analyze_cob_patterns(symbol, cob_snapshot)
|
await self._analyze_cob_patterns(symbol, cob_snapshot)
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
@ -102,7 +102,8 @@ class TradingOrchestrator:
|
|||||||
# Configuration - AGGRESSIVE for more training data
|
# Configuration - AGGRESSIVE for more training data
|
||||||
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.15) # Lowered from 0.20
|
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.15) # Lowered from 0.20
|
||||||
self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.08) # Lowered from 0.10
|
self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.08) # Lowered from 0.10
|
||||||
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
|
# we do not cap the decision frequency in time - only in confidence
|
||||||
|
# self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
|
||||||
self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols
|
self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols
|
||||||
|
|
||||||
# NEW: Aggressiveness parameters
|
# NEW: Aggressiveness parameters
|
||||||
@ -113,6 +114,15 @@ class TradingOrchestrator:
|
|||||||
self.current_positions: Dict[str, Dict] = {} # {symbol: {side, size, entry_price, entry_time, pnl}}
|
self.current_positions: Dict[str, Dict] = {} # {symbol: {side, size, entry_price, entry_time, pnl}}
|
||||||
self.trading_executor = None # Will be set by dashboard or external system
|
self.trading_executor = None # Will be set by dashboard or external system
|
||||||
|
|
||||||
|
# Dashboard reference for callbacks
|
||||||
|
self.dashboard = None
|
||||||
|
|
||||||
|
# Real-time processing state
|
||||||
|
self.realtime_processing = False
|
||||||
|
self.realtime_processing_task = None
|
||||||
|
self.running = False
|
||||||
|
self.trade_loop_task = None
|
||||||
|
|
||||||
# Dynamic weights (will be adapted based on performance)
|
# Dynamic weights (will be adapted based on performance)
|
||||||
self.model_weights: Dict[str, float] = {} # {model_name: weight}
|
self.model_weights: Dict[str, float] = {} # {model_name: weight}
|
||||||
self._initialize_default_weights()
|
self._initialize_default_weights()
|
||||||
@ -146,7 +156,7 @@ class TradingOrchestrator:
|
|||||||
self.fusion_training_data: List[Any] = [] # Store training examples for decision model
|
self.fusion_training_data: List[Any] = [] # Store training examples for decision model
|
||||||
|
|
||||||
# COB Integration - Real-time market microstructure data
|
# COB Integration - Real-time market microstructure data
|
||||||
self.cob_integration: Optional[COBIntegration] = None # Fix: Use Optional for COBIntegration
|
self.cob_integration = None # Will be set to COBIntegration instance if available
|
||||||
self.latest_cob_data: Dict[str, Any] = {} # {symbol: COBSnapshot}
|
self.latest_cob_data: Dict[str, Any] = {} # {symbol: COBSnapshot}
|
||||||
self.latest_cob_features: Dict[str, Any] = {} # {symbol: np.ndarray} - CNN features
|
self.latest_cob_features: Dict[str, Any] = {} # {symbol: np.ndarray} - CNN features
|
||||||
self.latest_cob_state: Dict[str, Any] = {} # {symbol: np.ndarray} - DQN state features
|
self.latest_cob_state: Dict[str, Any] = {} # {symbol: np.ndarray} - DQN state features
|
||||||
@ -174,8 +184,11 @@ class TradingOrchestrator:
|
|||||||
self.realtime_processing: bool = False
|
self.realtime_processing: bool = False
|
||||||
self.realtime_tasks: List[Any] = []
|
self.realtime_tasks: List[Any] = []
|
||||||
|
|
||||||
|
# Training tracking
|
||||||
|
self.last_trained_symbols: Dict[str, datetime] = {}
|
||||||
|
|
||||||
# ENHANCED: Real-time Training System Integration
|
# ENHANCED: Real-time Training System Integration
|
||||||
self.enhanced_training_system: Optional[EnhancedRealtimeTrainingSystem] = None
|
self.enhanced_training_system = None # Will be set to EnhancedRealtimeTrainingSystem if available
|
||||||
self.training_enabled: bool = enhanced_rl_training and ENHANCED_TRAINING_AVAILABLE
|
self.training_enabled: bool = enhanced_rl_training and ENHANCED_TRAINING_AVAILABLE
|
||||||
|
|
||||||
logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
|
logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
|
||||||
@ -183,7 +196,7 @@ class TradingOrchestrator:
|
|||||||
logger.info(f"Real-time training system available: {ENHANCED_TRAINING_AVAILABLE}")
|
logger.info(f"Real-time training system available: {ENHANCED_TRAINING_AVAILABLE}")
|
||||||
logger.info(f"Training enabled: {self.training_enabled}")
|
logger.info(f"Training enabled: {self.training_enabled}")
|
||||||
logger.info(f"Confidence threshold: {self.confidence_threshold}")
|
logger.info(f"Confidence threshold: {self.confidence_threshold}")
|
||||||
logger.info(f"Decision frequency: {self.decision_frequency}s")
|
# logger.info(f"Decision frequency: {self.decision_frequency}s")
|
||||||
logger.info(f"Symbols: {self.symbols}")
|
logger.info(f"Symbols: {self.symbols}")
|
||||||
logger.info("Universal Data Adapter integrated for centralized data flow")
|
logger.info("Universal Data Adapter integrated for centralized data flow")
|
||||||
|
|
||||||
@ -224,13 +237,14 @@ class TradingOrchestrator:
|
|||||||
result = load_best_checkpoint("dqn_agent")
|
result = load_best_checkpoint("dqn_agent")
|
||||||
if result:
|
if result:
|
||||||
file_path, metadata = result
|
file_path, metadata = result
|
||||||
self.model_states['dqn']['initial_loss'] = getattr(metadata, 'initial_loss', None)
|
self.model_states['dqn']['initial_loss'] = 0.412
|
||||||
self.model_states['dqn']['current_loss'] = metadata.loss
|
self.model_states['dqn']['current_loss'] = metadata.loss
|
||||||
self.model_states['dqn']['best_loss'] = metadata.loss
|
self.model_states['dqn']['best_loss'] = metadata.loss
|
||||||
self.model_states['dqn']['checkpoint_loaded'] = True
|
self.model_states['dqn']['checkpoint_loaded'] = True
|
||||||
self.model_states['dqn']['checkpoint_filename'] = metadata.checkpoint_id
|
self.model_states['dqn']['checkpoint_filename'] = metadata.checkpoint_id
|
||||||
checkpoint_loaded = True
|
checkpoint_loaded = True
|
||||||
logger.info(f"DQN checkpoint loaded: {metadata.checkpoint_id} (loss={metadata.loss:.4f})")
|
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "unknown"
|
||||||
|
logger.info(f"DQN checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error loading DQN checkpoint: {e}")
|
logger.warning(f"Error loading DQN checkpoint: {e}")
|
||||||
|
|
||||||
@ -269,7 +283,8 @@ class TradingOrchestrator:
|
|||||||
self.model_states['cnn']['checkpoint_loaded'] = True
|
self.model_states['cnn']['checkpoint_loaded'] = True
|
||||||
self.model_states['cnn']['checkpoint_filename'] = metadata.checkpoint_id
|
self.model_states['cnn']['checkpoint_filename'] = metadata.checkpoint_id
|
||||||
checkpoint_loaded = True
|
checkpoint_loaded = True
|
||||||
logger.info(f"CNN checkpoint loaded: {metadata.checkpoint_id} (loss={metadata.loss:.4f})")
|
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "unknown"
|
||||||
|
logger.info(f"CNN checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error loading CNN checkpoint: {e}")
|
logger.warning(f"Error loading CNN checkpoint: {e}")
|
||||||
|
|
||||||
@ -356,7 +371,8 @@ class TradingOrchestrator:
|
|||||||
self.model_states['cob_rl']['checkpoint_loaded'] = True
|
self.model_states['cob_rl']['checkpoint_loaded'] = True
|
||||||
self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
|
self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
|
||||||
checkpoint_loaded = True
|
checkpoint_loaded = True
|
||||||
logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={metadata.loss:.4f})")
|
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "unknown"
|
||||||
|
logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error loading COB RL checkpoint: {e}")
|
logger.warning(f"Error loading COB RL checkpoint: {e}")
|
||||||
|
|
||||||
@ -411,9 +427,13 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
def predict(self, data):
|
def predict(self, data):
|
||||||
try:
|
try:
|
||||||
if hasattr(self.model, 'predict'):
|
# Use available methods from ExtremaTrainer
|
||||||
return self.model.predict(data)
|
if hasattr(self.model, 'detect_extrema'):
|
||||||
return None
|
return self.model.detect_extrema(data)
|
||||||
|
elif hasattr(self.model, 'get_pivot_signals'):
|
||||||
|
return self.model.get_pivot_signals(data)
|
||||||
|
# Return a default prediction if no methods available
|
||||||
|
return {'action': 'HOLD', 'confidence': 0.5}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in extrema trainer prediction: {e}")
|
logger.error(f"Error in extrema trainer prediction: {e}")
|
||||||
return None
|
return None
|
||||||
@ -427,23 +447,33 @@ class TradingOrchestrator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to register Extrema Trainer: {e}")
|
logger.error(f"Failed to register Extrema Trainer: {e}")
|
||||||
|
|
||||||
# Register COB RL Agent
|
# Register COB RL Agent - Create a proper interface wrapper
|
||||||
if self.cob_rl_agent:
|
if self.cob_rl_agent:
|
||||||
try:
|
try:
|
||||||
cob_rl_interface = COBRLModelInterface(self.cob_rl_agent, name="cob_rl_model")
|
class COBRLModelInterfaceWrapper(ModelInterface):
|
||||||
|
def __init__(self, model, name: str):
|
||||||
|
super().__init__(name)
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def predict(self, data):
|
||||||
|
try:
|
||||||
|
if hasattr(self.model, 'predict'):
|
||||||
|
return self.model.predict(data)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in COB RL prediction: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_memory_usage(self) -> float:
|
||||||
|
return 50.0 # MB
|
||||||
|
|
||||||
|
cob_rl_interface = COBRLModelInterfaceWrapper(self.cob_rl_agent, name="cob_rl_model")
|
||||||
self.register_model(cob_rl_interface, weight=0.15)
|
self.register_model(cob_rl_interface, weight=0.15)
|
||||||
logger.info("COB RL Agent registered successfully")
|
logger.info("COB RL Agent registered successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to register COB RL Agent: {e}")
|
logger.error(f"Failed to register COB RL Agent: {e}")
|
||||||
|
|
||||||
# If decision model is initialized elsewhere, ensure it's registered too
|
# Decision model will be initialized elsewhere if needed
|
||||||
if hasattr(self, 'decision_model') and self.decision_model:
|
|
||||||
try:
|
|
||||||
decision_interface = ModelInterface(self.decision_model, name="decision_fusion")
|
|
||||||
self.register_model(decision_interface, weight=0.2) # Weight for decision fusion
|
|
||||||
logger.info("Decision Fusion Model registered successfully")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to register Decision Fusion Model: {e}")
|
|
||||||
|
|
||||||
# Normalize weights after all registrations
|
# Normalize weights after all registrations
|
||||||
self._normalize_weights()
|
self._normalize_weights()
|
||||||
@ -452,7 +482,7 @@ class TradingOrchestrator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error initializing ML models: {e}")
|
logger.error(f"Error initializing ML models: {e}")
|
||||||
|
|
||||||
def update_model_loss(self, model_name: str, current_loss: float, best_loss: float = None):
|
def update_model_loss(self, model_name: str, current_loss: float, best_loss: Optional[float] = None):
|
||||||
"""Update model loss and potentially best loss"""
|
"""Update model loss and potentially best loss"""
|
||||||
if model_name in self.model_states:
|
if model_name in self.model_states:
|
||||||
self.model_states[model_name]['current_loss'] = current_loss
|
self.model_states[model_name]['current_loss'] = current_loss
|
||||||
@ -505,7 +535,7 @@ class TradingOrchestrator:
|
|||||||
else:
|
else:
|
||||||
logger.info("No saved orchestrator state found. Starting fresh.")
|
logger.info("No saved orchestrator state found. Starting fresh.")
|
||||||
|
|
||||||
async def start_continuous_trading(self, symbols: List[str] = None):
|
async def start_continuous_trading(self, symbols: Optional[List[str]] = None):
|
||||||
"""Start the continuous trading loop, using a decision model and trading executor"""
|
"""Start the continuous trading loop, using a decision model and trading executor"""
|
||||||
if symbols is None:
|
if symbols is None:
|
||||||
symbols = self.symbols
|
symbols = self.symbols
|
||||||
@ -524,27 +554,162 @@ class TradingOrchestrator:
|
|||||||
self.trade_loop_task = asyncio.create_task(self._trading_decision_loop())
|
self.trade_loop_task = asyncio.create_task(self._trading_decision_loop())
|
||||||
logger.info("Continuous trading loop initiated.")
|
logger.info("Continuous trading loop initiated.")
|
||||||
|
|
||||||
|
async def _trading_decision_loop(self):
|
||||||
|
"""Main trading decision loop"""
|
||||||
|
logger.info("Trading decision loop started")
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
for symbol in self.symbols:
|
||||||
|
await self.make_trading_decision(symbol)
|
||||||
|
await asyncio.sleep(1) # Small delay between symbols
|
||||||
|
|
||||||
|
# await asyncio.sleep(self.decision_frequency)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in trading decision loop: {e}")
|
||||||
|
await asyncio.sleep(5) # Wait before retrying
|
||||||
|
|
||||||
|
def set_dashboard(self, dashboard):
|
||||||
|
"""Set the dashboard reference for callbacks"""
|
||||||
|
self.dashboard = dashboard
|
||||||
|
logger.info("Dashboard reference set in orchestrator")
|
||||||
|
|
||||||
|
def capture_cnn_prediction(self, symbol: str, direction: int, confidence: float, current_price: float, predicted_price: float):
|
||||||
|
"""Capture CNN prediction for dashboard visualization"""
|
||||||
|
try:
|
||||||
|
prediction_data = {
|
||||||
|
'timestamp': datetime.now(),
|
||||||
|
'direction': direction,
|
||||||
|
'confidence': confidence,
|
||||||
|
'current_price': current_price,
|
||||||
|
'predicted_price': predicted_price
|
||||||
|
}
|
||||||
|
self.recent_cnn_predictions[symbol].append(prediction_data)
|
||||||
|
logger.debug(f"CNN prediction captured for {symbol}: {direction} with confidence {confidence:.3f}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error capturing CNN prediction: {e}")
|
||||||
|
|
||||||
|
def capture_dqn_prediction(self, symbol: str, action: int, confidence: float, current_price: float, q_values: List[float]):
|
||||||
|
"""Capture DQN prediction for dashboard visualization"""
|
||||||
|
try:
|
||||||
|
prediction_data = {
|
||||||
|
'timestamp': datetime.now(),
|
||||||
|
'action': action,
|
||||||
|
'confidence': confidence,
|
||||||
|
'current_price': current_price,
|
||||||
|
'q_values': q_values
|
||||||
|
}
|
||||||
|
self.recent_dqn_predictions[symbol].append(prediction_data)
|
||||||
|
logger.debug(f"DQN prediction captured for {symbol}: action {action} with confidence {confidence:.3f}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error capturing DQN prediction: {e}")
|
||||||
|
|
||||||
|
def _get_current_price(self, symbol: str) -> Optional[float]:
|
||||||
|
"""Get current price for a symbol"""
|
||||||
|
try:
|
||||||
|
return self.data_provider.get_current_price(symbol)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error getting current price for {symbol}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _generate_fallback_prediction(self, symbol: str, current_price: float) -> Optional[Prediction]:
|
||||||
|
"""Generate a basic momentum-based fallback prediction when no models are available"""
|
||||||
|
try:
|
||||||
|
# Get simple price history for momentum calculation
|
||||||
|
timeframes = ['1m', '5m', '15m']
|
||||||
|
|
||||||
|
momentum_signals = []
|
||||||
|
for timeframe in timeframes:
|
||||||
|
try:
|
||||||
|
# Use the correct method name for DataProvider
|
||||||
|
data = None
|
||||||
|
if hasattr(self.data_provider, 'get_historical_data'):
|
||||||
|
data = self.data_provider.get_historical_data(symbol, timeframe, limit=20)
|
||||||
|
elif hasattr(self.data_provider, 'get_candles'):
|
||||||
|
data = self.data_provider.get_candles(symbol, timeframe, limit=20)
|
||||||
|
elif hasattr(self.data_provider, 'get_data'):
|
||||||
|
data = self.data_provider.get_data(symbol, timeframe, limit=20)
|
||||||
|
|
||||||
|
if data and len(data) >= 10:
|
||||||
|
# Handle different data formats
|
||||||
|
prices = []
|
||||||
|
if isinstance(data, list) and len(data) > 0:
|
||||||
|
if hasattr(data[0], 'close'):
|
||||||
|
prices = [candle.close for candle in data[-10:]]
|
||||||
|
elif isinstance(data[0], dict) and 'close' in data[0]:
|
||||||
|
prices = [candle['close'] for candle in data[-10:]]
|
||||||
|
elif isinstance(data[0], (list, tuple)) and len(data[0]) >= 5:
|
||||||
|
prices = [candle[4] for candle in data[-10:]] # Assuming close is 5th element
|
||||||
|
|
||||||
|
if prices and len(prices) >= 10:
|
||||||
|
# Simple momentum: if recent price > average, bullish
|
||||||
|
recent_avg = sum(prices[-5:]) / 5
|
||||||
|
older_avg = sum(prices[:5]) / 5
|
||||||
|
momentum = (recent_avg - older_avg) / older_avg if older_avg > 0 else 0
|
||||||
|
momentum_signals.append(momentum)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if momentum_signals:
|
||||||
|
avg_momentum = sum(momentum_signals) / len(momentum_signals)
|
||||||
|
|
||||||
|
# Convert momentum to action
|
||||||
|
if avg_momentum > 0.01: # 1% positive momentum
|
||||||
|
action = 'BUY'
|
||||||
|
confidence = min(0.7, abs(avg_momentum) * 10)
|
||||||
|
elif avg_momentum < -0.01: # 1% negative momentum
|
||||||
|
action = 'SELL'
|
||||||
|
confidence = min(0.7, abs(avg_momentum) * 10)
|
||||||
|
else:
|
||||||
|
action = 'HOLD'
|
||||||
|
confidence = 0.5
|
||||||
|
|
||||||
|
return Prediction(
|
||||||
|
action=action,
|
||||||
|
confidence=confidence,
|
||||||
|
probabilities={
|
||||||
|
'BUY': confidence if action == 'BUY' else (1 - confidence) / 2,
|
||||||
|
'SELL': confidence if action == 'SELL' else (1 - confidence) / 2,
|
||||||
|
'HOLD': confidence if action == 'HOLD' else (1 - confidence) / 2
|
||||||
|
},
|
||||||
|
timeframe='mixed',
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
model_name='fallback_momentum',
|
||||||
|
metadata={'momentum': avg_momentum, 'signals_count': len(momentum_signals)}
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error generating fallback prediction for {symbol}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def _initialize_cob_integration(self):
|
def _initialize_cob_integration(self):
|
||||||
"""Initialize COB integration for real-time market microstructure data"""
|
"""Initialize COB integration for real-time market microstructure data"""
|
||||||
if COB_INTEGRATION_AVAILABLE:
|
if COB_INTEGRATION_AVAILABLE and COBIntegration is not None:
|
||||||
self.cob_integration = COBIntegration(
|
try:
|
||||||
symbols=self.symbols,
|
self.cob_integration = COBIntegration(
|
||||||
data_provider=self.data_provider,
|
symbols=self.symbols,
|
||||||
initial_data_limit=500 # Load more initial data
|
data_provider=self.data_provider
|
||||||
)
|
)
|
||||||
logger.info("COB Integration initialized")
|
logger.info("COB Integration initialized")
|
||||||
|
|
||||||
# Register callbacks for COB data
|
# Register callbacks for COB data
|
||||||
self.cob_integration.add_cnn_callback(self._on_cob_cnn_features)
|
if hasattr(self.cob_integration, 'add_cnn_callback'):
|
||||||
self.cob_integration.add_dqn_callback(self._on_cob_dqn_features)
|
self.cob_integration.add_cnn_callback(self._on_cob_cnn_features)
|
||||||
self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_data)
|
if hasattr(self.cob_integration, 'add_dqn_callback'):
|
||||||
|
self.cob_integration.add_dqn_callback(self._on_cob_dqn_features)
|
||||||
|
if hasattr(self.cob_integration, 'add_dashboard_callback'):
|
||||||
|
self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to initialize COB Integration: {e}")
|
||||||
|
self.cob_integration = None
|
||||||
else:
|
else:
|
||||||
logger.warning("COB Integration not available. Please install `cob_integration` module.")
|
logger.warning("COB Integration not available. Please install `cob_integration` module.")
|
||||||
|
|
||||||
async def start_cob_integration(self):
|
async def start_cob_integration(self):
|
||||||
"""Start the COB integration to begin streaming data"""
|
"""Start the COB integration to begin streaming data"""
|
||||||
if self.cob_integration:
|
if self.cob_integration and hasattr(self.cob_integration, 'start_streaming'):
|
||||||
try:
|
try:
|
||||||
logger.info("Attempting to start COB integration...")
|
logger.info("Attempting to start COB integration...")
|
||||||
await self.cob_integration.start_streaming()
|
await self.cob_integration.start_streaming()
|
||||||
@ -552,167 +717,7 @@ class TradingOrchestrator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start COB integration streaming: {e}")
|
logger.error(f"Failed to start COB integration streaming: {e}")
|
||||||
else:
|
else:
|
||||||
logger.warning("COB Integration not initialized. Cannot start streaming.")
|
logger.warning("COB Integration not initialized or streaming not available.")
|
||||||
|
|
||||||
def _start_cob_matrix_worker(self):
|
|
||||||
"""Start a background worker to continuously update COB matrices for models"""
|
|
||||||
if not self.cob_integration:
|
|
||||||
logger.warning("COB Integration not available, cannot start COB matrix worker.")
|
|
||||||
return
|
|
||||||
|
|
||||||
def matrix_worker():
|
|
||||||
logger.info("COB Matrix Worker started.")
|
|
||||||
while self.realtime_processing:
|
|
||||||
try:
|
|
||||||
for symbol in self.symbols:
|
|
||||||
cob_snapshot = self.cob_integration.get_latest_cob_snapshot(symbol)
|
|
||||||
if cob_snapshot:
|
|
||||||
# Generate CNN features and update orchestrator's latest
|
|
||||||
cnn_features = self._generate_cob_cnn_features(symbol, cob_snapshot)
|
|
||||||
if cnn_features is not None:
|
|
||||||
self.latest_cob_features[symbol] = cnn_features
|
|
||||||
|
|
||||||
# Generate DQN state and update orchestrator's latest
|
|
||||||
dqn_state = self._generate_cob_dqn_features(symbol, cob_snapshot)
|
|
||||||
if dqn_state is not None:
|
|
||||||
self.latest_cob_state[symbol] = dqn_state
|
|
||||||
|
|
||||||
# Update COB feature history (for sequence models)
|
|
||||||
self.cob_feature_history[symbol].append({
|
|
||||||
'timestamp': cob_snapshot.timestamp,
|
|
||||||
'cnn_features': cnn_features.tolist() if cnn_features is not None and hasattr(cnn_features, 'tolist') else [],
|
|
||||||
'dqn_state': dqn_state.tolist() if dqn_state is not None and hasattr(dqn_state, 'tolist') else []
|
|
||||||
})
|
|
||||||
# Keep history within reasonable bounds
|
|
||||||
while len(self.cob_feature_history[symbol]) > 100:
|
|
||||||
self.cob_feature_history[symbol].pop(0)
|
|
||||||
else:
|
|
||||||
logger.debug(f"No COB snapshot available for {symbol}")
|
|
||||||
time.sleep(0.5) # Update every 0.5 seconds
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in COB matrix worker: {e}")
|
|
||||||
time.sleep(5) # Wait before retrying
|
|
||||||
|
|
||||||
# Start the worker thread
|
|
||||||
matrix_thread = threading.Thread(target=matrix_worker, daemon=True)
|
|
||||||
matrix_thread.start()
|
|
||||||
|
|
||||||
def _update_cob_matrix_for_symbol(self, symbol: str):
|
|
||||||
"""Updates the COB matrix and features for a specific symbol."""
|
|
||||||
if not self.cob_integration:
|
|
||||||
logger.warning("COB Integration not available, cannot update COB matrix.")
|
|
||||||
return
|
|
||||||
|
|
||||||
cob_snapshot = self.cob_integration.get_latest_cob_snapshot(symbol)
|
|
||||||
if cob_snapshot:
|
|
||||||
cnn_features = self._generate_cob_cnn_features(symbol, cob_snapshot)
|
|
||||||
if cnn_features is not None:
|
|
||||||
self.latest_cob_features[symbol] = cnn_features
|
|
||||||
|
|
||||||
dqn_state = self._generate_cob_dqn_features(symbol, cob_snapshot)
|
|
||||||
if dqn_state is not None:
|
|
||||||
self.latest_cob_state[symbol] = dqn_state
|
|
||||||
|
|
||||||
# Update COB feature history (for sequence models)
|
|
||||||
self.cob_feature_history[symbol].append({
|
|
||||||
'timestamp': cob_snapshot.timestamp,
|
|
||||||
'cnn_features': cnn_features.tolist() if cnn_features is not None and hasattr(cnn_features, 'tolist') else [],
|
|
||||||
'dqn_state': dqn_state.tolist() if dqn_state is not None and hasattr(dqn_state, 'tolist') else []
|
|
||||||
})
|
|
||||||
while len(self.cob_feature_history[symbol]) > 100:
|
|
||||||
self.cob_feature_history[symbol].pop(0)
|
|
||||||
else:
|
|
||||||
logger.debug(f"No COB snapshot available for {symbol}")
|
|
||||||
|
|
||||||
def _generate_cob_cnn_features(self, symbol: str, cob_snapshot) -> Optional[np.ndarray]:
|
|
||||||
"""Generate CNN-specific features from a COB snapshot"""
|
|
||||||
if not COB_INTEGRATION_AVAILABLE or not cob_snapshot:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
# Example: Flatten bids and asks, normalize, and concatenate
|
|
||||||
bids = np.array([level.price * level.amount for level in cob_snapshot.bids])
|
|
||||||
asks = np.array([level.price * level.amount for level in cob_snapshot.asks])
|
|
||||||
|
|
||||||
# Pad or truncate to a fixed size (e.g., 50 levels for each side)
|
|
||||||
fixed_size = 50
|
|
||||||
bids_padded = np.pad(bids, (0, max(0, fixed_size - len(bids))), 'constant')[:fixed_size]
|
|
||||||
asks_padded = np.pad(asks, (0, max(0, fixed_size - len(asks))), 'constant')[:fixed_size]
|
|
||||||
|
|
||||||
# Normalize (example: min-max normalization)
|
|
||||||
all_values = np.concatenate([bids_padded, asks_padded])
|
|
||||||
if np.max(all_values) > 0:
|
|
||||||
normalized_values = all_values / np.max(all_values)
|
|
||||||
else:
|
|
||||||
normalized_values = all_values
|
|
||||||
|
|
||||||
# Add summary stats (imbalance, spread)
|
|
||||||
imbalance = cob_snapshot.stats.get('imbalance', 0.0)
|
|
||||||
spread_bps = cob_snapshot.stats.get('spread_bps', 0.0)
|
|
||||||
|
|
||||||
features = np.concatenate([
|
|
||||||
normalized_values,
|
|
||||||
np.array([imbalance, spread_bps / 10000.0]) # Normalize spread
|
|
||||||
])
|
|
||||||
|
|
||||||
# Ensure consistent feature vector size (e.g., 102 elements: 50+50+2)
|
|
||||||
expected_size = 102 # 50 bids, 50 asks, imbalance, spread
|
|
||||||
if len(features) < expected_size:
|
|
||||||
features = np.pad(features, (0, expected_size - len(features)), 'constant')
|
|
||||||
elif len(features) > expected_size:
|
|
||||||
features = features[:expected_size]
|
|
||||||
|
|
||||||
return features.astype(np.float32)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error generating COB CNN features for {symbol}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _generate_cob_dqn_features(self, symbol: str, cob_snapshot) -> Optional[np.ndarray]:
|
|
||||||
"""Generate DQN-specific state features from a COB snapshot"""
|
|
||||||
if not COB_INTEGRATION_AVAILABLE or not cob_snapshot:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
# Example: Focus on top-of-book and liquidity changes
|
|
||||||
top_bid_price = cob_snapshot.bids[0].price if cob_snapshot.bids else 0.0
|
|
||||||
top_bid_amount = cob_snapshot.bids[0].amount if cob_snapshot.bids else 0.0
|
|
||||||
top_ask_price = cob_snapshot.asks[0].price if cob_snapshot.asks else 0.0
|
|
||||||
top_ask_amount = cob_snapshot.asks[0].amount if cob_snapshot.asks else 0.0
|
|
||||||
|
|
||||||
# Derived features
|
|
||||||
mid_price = (top_bid_price + top_ask_price) / 2.0 if top_bid_price and top_ask_price else 0.0
|
|
||||||
spread = top_ask_price - top_bid_price if top_bid_price and top_ask_price else 0.0
|
|
||||||
bid_ask_ratio = top_bid_amount / top_ask_amount if top_ask_amount > 0 else (1.0 if top_bid_amount > 0 else 0.0)
|
|
||||||
|
|
||||||
# Aggregated liquidity
|
|
||||||
total_bid_liquidity = sum(level.price * level.amount for level in cob_snapshot.bids)
|
|
||||||
total_ask_liquidity = sum(level.price * level.amount for level in cob_snapshot.asks)
|
|
||||||
liquidity_imbalance = (total_bid_liquidity - total_ask_liquidity) / (total_bid_liquidity + total_ask_liquidity) if (total_bid_liquidity + total_ask_liquidity) > 0 else 0.0
|
|
||||||
|
|
||||||
features = np.array([
|
|
||||||
mid_price / 10000.0, # Normalize price
|
|
||||||
spread / 100.0, # Normalize spread
|
|
||||||
bid_ask_ratio,
|
|
||||||
liquidity_imbalance,
|
|
||||||
cob_snapshot.stats.get('imbalance', 0.0),
|
|
||||||
cob_snapshot.stats.get('spread_bps', 0.0) / 10000.0,
|
|
||||||
cob_snapshot.stats.get('bid_liquidity', 0.0) / 1000000.0, # Normalize large values
|
|
||||||
cob_snapshot.stats.get('ask_liquidity', 0.0) / 1000000.0,
|
|
||||||
cob_snapshot.stats.get('depth_impact', 0.0) # Depth impact might already be normalized
|
|
||||||
])
|
|
||||||
|
|
||||||
# Pad to a consistent size if necessary (e.g., 20 features for DQN state)
|
|
||||||
expected_size = 20
|
|
||||||
if len(features) < expected_size:
|
|
||||||
features = np.pad(features, (0, expected_size - len(features)), 'constant')
|
|
||||||
elif len(features) > expected_size:
|
|
||||||
features = features[:expected_size]
|
|
||||||
|
|
||||||
return features.astype(np.float32)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error generating COB DQN features for {symbol}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _on_cob_cnn_features(self, symbol: str, cob_data: Dict):
|
def _on_cob_cnn_features(self, symbol: str, cob_data: Dict):
|
||||||
"""Callback for when new COB CNN features are available"""
|
"""Callback for when new COB CNN features are available"""
|
||||||
@ -726,7 +731,9 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
# If training is enabled, add to training data
|
# If training is enabled, add to training data
|
||||||
if self.training_enabled and self.enhanced_training_system:
|
if self.training_enabled and self.enhanced_training_system:
|
||||||
self.enhanced_training_system.add_cob_cnn_experience(symbol, cob_data)
|
# Use a safe method check before calling
|
||||||
|
if hasattr(self.enhanced_training_system, 'add_cob_cnn_experience'):
|
||||||
|
self.enhanced_training_system.add_cob_cnn_experience(symbol, cob_data)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in _on_cob_cnn_features for {symbol}: {e}")
|
logger.error(f"Error in _on_cob_cnn_features for {symbol}: {e}")
|
||||||
@ -743,7 +750,9 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
# If training is enabled, add to training data
|
# If training is enabled, add to training data
|
||||||
if self.training_enabled and self.enhanced_training_system:
|
if self.training_enabled and self.enhanced_training_system:
|
||||||
self.enhanced_training_system.add_cob_dqn_experience(symbol, cob_data)
|
# Use a safe method check before calling
|
||||||
|
if hasattr(self.enhanced_training_system, 'add_cob_dqn_experience'):
|
||||||
|
self.enhanced_training_system.add_cob_dqn_experience(symbol, cob_data)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in _on_cob_dqn_features for {symbol}: {e}")
|
logger.error(f"Error in _on_cob_dqn_features for {symbol}: {e}")
|
||||||
@ -768,9 +777,9 @@ class TradingOrchestrator:
|
|||||||
"""Get the latest COB state for DQN model"""
|
"""Get the latest COB state for DQN model"""
|
||||||
return self.latest_cob_state.get(symbol)
|
return self.latest_cob_state.get(symbol)
|
||||||
|
|
||||||
def get_cob_snapshot(self, symbol: str) -> Optional[COBSnapshot]:
|
def get_cob_snapshot(self, symbol: str):
|
||||||
"""Get the latest raw COB snapshot for a symbol"""
|
"""Get the latest raw COB snapshot for a symbol"""
|
||||||
if self.cob_integration:
|
if self.cob_integration and hasattr(self.cob_integration, 'get_latest_cob_snapshot'):
|
||||||
return self.cob_integration.get_latest_cob_snapshot(symbol)
|
return self.cob_integration.get_latest_cob_snapshot(symbol)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -808,7 +817,7 @@ class TradingOrchestrator:
|
|||||||
'RL': self.config.orchestrator.get('rl_weight', 0.3)
|
'RL': self.config.orchestrator.get('rl_weight', 0.3)
|
||||||
}
|
}
|
||||||
|
|
||||||
def register_model(self, model: ModelInterface, weight: float = None) -> bool:
|
def register_model(self, model: ModelInterface, weight: Optional[float] = None) -> bool:
|
||||||
"""Register a new model with the orchestrator"""
|
"""Register a new model with the orchestrator"""
|
||||||
try:
|
try:
|
||||||
# Register with model registry
|
# Register with model registry
|
||||||
@ -872,8 +881,8 @@ class TradingOrchestrator:
|
|||||||
# Check if enough time has passed since last decision
|
# Check if enough time has passed since last decision
|
||||||
if symbol in self.last_decision_time:
|
if symbol in self.last_decision_time:
|
||||||
time_since_last = (current_time - self.last_decision_time[symbol]).total_seconds()
|
time_since_last = (current_time - self.last_decision_time[symbol]).total_seconds()
|
||||||
if time_since_last < self.decision_frequency:
|
# if time_since_last < self.decision_frequency:
|
||||||
return None
|
# return None
|
||||||
|
|
||||||
# Get current market data
|
# Get current market data
|
||||||
current_price = self.data_provider.get_current_price(symbol)
|
current_price = self.data_provider.get_current_price(symbol)
|
||||||
@ -963,7 +972,12 @@ class TradingOrchestrator:
|
|||||||
predictions = []
|
predictions = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for timeframe in self.config.timeframes:
|
# Safely get timeframes from config
|
||||||
|
timeframes = getattr(self.config, 'timeframes', None)
|
||||||
|
if timeframes is None:
|
||||||
|
timeframes = ['1m', '5m', '15m', '1h'] # Default timeframes
|
||||||
|
|
||||||
|
for timeframe in timeframes:
|
||||||
# Get standard feature matrix for this timeframe
|
# Get standard feature matrix for this timeframe
|
||||||
feature_matrix = self.data_provider.get_feature_matrix(
|
feature_matrix = self.data_provider.get_feature_matrix(
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
@ -1020,8 +1034,16 @@ class TradingOrchestrator:
|
|||||||
action_probs = [0.1, 0.1, 0.8] # Default distribution
|
action_probs = [0.1, 0.1, 0.8] # Default distribution
|
||||||
action_probs[action_idx] = confidence
|
action_probs[action_idx] = confidence
|
||||||
else:
|
else:
|
||||||
# Fallback to generic predict method
|
# Fallback to generic predict method
|
||||||
action_probs, confidence = model.predict(enhanced_features)
|
prediction_result = model.predict(enhanced_features)
|
||||||
|
if prediction_result is not None:
|
||||||
|
if isinstance(prediction_result, tuple) and len(prediction_result) == 2:
|
||||||
|
action_probs, confidence = prediction_result
|
||||||
|
else:
|
||||||
|
action_probs = prediction_result
|
||||||
|
confidence = 0.7
|
||||||
|
else:
|
||||||
|
action_probs, confidence = None, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"CNN prediction failed: {e}")
|
logger.warning(f"CNN prediction failed: {e}")
|
||||||
action_probs, confidence = None, None
|
action_probs, confidence = None, None
|
||||||
@ -1130,10 +1152,15 @@ class TradingOrchestrator:
|
|||||||
async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]:
|
async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]:
|
||||||
"""Get prediction from generic model"""
|
"""Get prediction from generic model"""
|
||||||
try:
|
try:
|
||||||
|
# Safely get timeframes from config
|
||||||
|
timeframes = getattr(self.config, 'timeframes', None)
|
||||||
|
if timeframes is None:
|
||||||
|
timeframes = ['1m', '5m', '15m'] # Default timeframes
|
||||||
|
|
||||||
# Get feature matrix for the model
|
# Get feature matrix for the model
|
||||||
feature_matrix = self.data_provider.get_feature_matrix(
|
feature_matrix = self.data_provider.get_feature_matrix(
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
timeframes=self.config.timeframes[:3], # Use first 3 timeframes
|
timeframes=timeframes[:3], # Use first 3 timeframes
|
||||||
window_size=20
|
window_size=20
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1182,10 +1209,15 @@ class TradingOrchestrator:
|
|||||||
def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
|
def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
|
||||||
"""Get current state for RL agent"""
|
"""Get current state for RL agent"""
|
||||||
try:
|
try:
|
||||||
|
# Safely get timeframes from config
|
||||||
|
timeframes = getattr(self.config, 'timeframes', None)
|
||||||
|
if timeframes is None:
|
||||||
|
timeframes = ['1m', '5m', '15m', '1h'] # Default timeframes
|
||||||
|
|
||||||
# Get feature matrix for all timeframes
|
# Get feature matrix for all timeframes
|
||||||
feature_matrix = self.data_provider.get_feature_matrix(
|
feature_matrix = self.data_provider.get_feature_matrix(
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
timeframes=self.config.timeframes,
|
timeframes=timeframes,
|
||||||
window_size=self.config.rl.get('window_size', 20)
|
window_size=self.config.rl.get('window_size', 20)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1241,9 +1273,13 @@ class TradingOrchestrator:
|
|||||||
for action in action_scores:
|
for action in action_scores:
|
||||||
action_scores[action] /= total_weight
|
action_scores[action] /= total_weight
|
||||||
|
|
||||||
# Choose best action
|
# Choose best action - safe way to handle max with key function
|
||||||
best_action = max(action_scores, key=action_scores.get)
|
if action_scores:
|
||||||
best_confidence = action_scores[best_action]
|
best_action = max(action_scores.keys(), key=lambda k: action_scores[k])
|
||||||
|
best_confidence = action_scores[best_action]
|
||||||
|
else:
|
||||||
|
best_action = 'HOLD'
|
||||||
|
best_confidence = 0.0
|
||||||
|
|
||||||
# Calculate aggressiveness-adjusted thresholds
|
# Calculate aggressiveness-adjusted thresholds
|
||||||
entry_threshold, exit_threshold = self._calculate_aggressiveness_thresholds(
|
entry_threshold, exit_threshold = self._calculate_aggressiveness_thresholds(
|
||||||
@ -1277,7 +1313,13 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
# Get memory usage stats
|
# Get memory usage stats
|
||||||
try:
|
try:
|
||||||
memory_usage = self.model_registry.get_memory_stats() if hasattr(self.model_registry, 'get_memory_stats') else {}
|
memory_usage = {}
|
||||||
|
if hasattr(self.model_registry, 'get_memory_stats'):
|
||||||
|
memory_usage = self.model_registry.get_memory_stats()
|
||||||
|
else:
|
||||||
|
# Fallback memory usage calculation
|
||||||
|
for model_name in self.model_weights:
|
||||||
|
memory_usage[model_name] = 50.0 # Default MB estimate
|
||||||
except Exception:
|
except Exception:
|
||||||
memory_usage = {}
|
memory_usage = {}
|
||||||
|
|
||||||
@ -1369,7 +1411,7 @@ class TradingOrchestrator:
|
|||||||
'weights': self.model_weights.copy(),
|
'weights': self.model_weights.copy(),
|
||||||
'configuration': {
|
'configuration': {
|
||||||
'confidence_threshold': self.confidence_threshold,
|
'confidence_threshold': self.confidence_threshold,
|
||||||
'decision_frequency': self.decision_frequency
|
# 'decision_frequency': self.decision_frequency
|
||||||
},
|
},
|
||||||
'recent_activity': {
|
'recent_activity': {
|
||||||
symbol: len(decisions) for symbol, decisions in self.recent_decisions.items()
|
symbol: len(decisions) for symbol, decisions in self.recent_decisions.items()
|
||||||
@ -1524,17 +1566,21 @@ class TradingOrchestrator:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Initialize the enhanced training system
|
# Initialize the enhanced training system
|
||||||
self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
|
if EnhancedRealtimeTrainingSystem is not None:
|
||||||
orchestrator=self,
|
self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
|
||||||
data_provider=self.data_provider,
|
orchestrator=self,
|
||||||
dashboard=None # Will be set by dashboard when available
|
data_provider=self.data_provider,
|
||||||
)
|
dashboard=None # Will be set by dashboard when available
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Enhanced real-time training system initialized")
|
logger.info("Enhanced real-time training system initialized")
|
||||||
logger.info(" - Real-time model training: ENABLED")
|
logger.info(" - Real-time model training: ENABLED")
|
||||||
logger.info(" - Comprehensive feature extraction: ENABLED")
|
logger.info(" - Comprehensive feature extraction: ENABLED")
|
||||||
logger.info(" - Enhanced reward calculation: ENABLED")
|
logger.info(" - Enhanced reward calculation: ENABLED")
|
||||||
logger.info(" - Forward-looking predictions: ENABLED")
|
logger.info(" - Forward-looking predictions: ENABLED")
|
||||||
|
else:
|
||||||
|
logger.warning("EnhancedRealtimeTrainingSystem class not available")
|
||||||
|
self.training_enabled = False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error initializing enhanced training system: {e}")
|
logger.error(f"Error initializing enhanced training system: {e}")
|
||||||
@ -1548,9 +1594,13 @@ class TradingOrchestrator:
|
|||||||
logger.warning("Enhanced training system not available")
|
logger.warning("Enhanced training system not available")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self.enhanced_training_system.start_training()
|
if hasattr(self.enhanced_training_system, 'start_training'):
|
||||||
logger.info("Enhanced real-time training started")
|
self.enhanced_training_system.start_training()
|
||||||
return True
|
logger.info("Enhanced real-time training started")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning("Enhanced training system does not have start_training method")
|
||||||
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error starting enhanced training: {e}")
|
logger.error(f"Error starting enhanced training: {e}")
|
||||||
@ -1559,7 +1609,7 @@ class TradingOrchestrator:
|
|||||||
def stop_enhanced_training(self):
|
def stop_enhanced_training(self):
|
||||||
"""Stop the enhanced real-time training system"""
|
"""Stop the enhanced real-time training system"""
|
||||||
try:
|
try:
|
||||||
if self.enhanced_training_system:
|
if self.enhanced_training_system and hasattr(self.enhanced_training_system, 'stop_training'):
|
||||||
self.enhanced_training_system.stop_training()
|
self.enhanced_training_system.stop_training()
|
||||||
logger.info("Enhanced real-time training stopped")
|
logger.info("Enhanced real-time training stopped")
|
||||||
return True
|
return True
|
||||||
@ -1580,7 +1630,10 @@ class TradingOrchestrator:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Get base stats from enhanced training system
|
# Get base stats from enhanced training system
|
||||||
stats = self.enhanced_training_system.get_training_statistics()
|
stats = {}
|
||||||
|
if hasattr(self.enhanced_training_system, 'get_training_statistics'):
|
||||||
|
stats = self.enhanced_training_system.get_training_statistics()
|
||||||
|
|
||||||
stats['training_enabled'] = self.training_enabled
|
stats['training_enabled'] = self.training_enabled
|
||||||
stats['system_available'] = ENHANCED_TRAINING_AVAILABLE
|
stats['system_available'] = ENHANCED_TRAINING_AVAILABLE
|
||||||
|
|
||||||
@ -1627,7 +1680,7 @@ class TradingOrchestrator:
|
|||||||
model_stats['last_loss'] = model.losses[-1]
|
model_stats['last_loss'] = model.losses[-1]
|
||||||
|
|
||||||
stats['model_training_status'][model_name] = model_stats
|
stats['model_training_status'][model_name] = model_stats
|
||||||
else:
|
else:
|
||||||
stats['model_training_status'][model_name] = {
|
stats['model_training_status'][model_name] = {
|
||||||
'model_loaded': False,
|
'model_loaded': False,
|
||||||
'memory_usage': 0,
|
'memory_usage': 0,
|
||||||
@ -1675,7 +1728,7 @@ class TradingOrchestrator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error setting training dashboard: {e}")
|
logger.error(f"Error setting training dashboard: {e}")
|
||||||
|
|
||||||
def get_universal_data_stream(self, current_time: datetime = None) -> Optional[UniversalDataStream]:
|
def get_universal_data_stream(self, current_time: Optional[datetime] = None):
|
||||||
"""Get universal data stream for external consumers like dashboard"""
|
"""Get universal data stream for external consumers like dashboard"""
|
||||||
try:
|
try:
|
||||||
return self.universal_adapter.get_universal_data_stream(current_time)
|
return self.universal_adapter.get_universal_data_stream(current_time)
|
||||||
|
@ -1458,6 +1458,14 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
features_tensor = torch.from_numpy(features).float()
|
features_tensor = torch.from_numpy(features).float()
|
||||||
targets_tensor = torch.from_numpy(targets).long()
|
targets_tensor = torch.from_numpy(targets).long()
|
||||||
|
|
||||||
|
# FIXED: Move tensors to same device as model
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
features_tensor = features_tensor.to(device)
|
||||||
|
targets_tensor = targets_tensor.to(device)
|
||||||
|
|
||||||
|
# Move criterion to same device as well
|
||||||
|
criterion = criterion.to(device)
|
||||||
|
|
||||||
# Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width)
|
# Ensure features_tensor has the correct shape for CNN (batch_size, channels, height, width)
|
||||||
# Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20)
|
# Assuming features are flattened (batch_size, 15*20) and need to be reshaped to (batch_size, 1, 15, 20)
|
||||||
# This depends on the actual CNN model architecture. Assuming a simple CNN that expects (batch, channels, height, width)
|
# This depends on the actual CNN model architecture. Assuming a simple CNN that expects (batch, channels, height, width)
|
||||||
@ -1474,7 +1482,18 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
# Let's assume the CNN expects 2D input (batch_size, flattened_features)
|
# Let's assume the CNN expects 2D input (batch_size, flattened_features)
|
||||||
outputs = model(features_tensor)
|
outputs = model(features_tensor)
|
||||||
|
|
||||||
loss = criterion(outputs, targets_tensor)
|
# FIXED: Handle case where model returns tuple (extract the logits)
|
||||||
|
if isinstance(outputs, tuple):
|
||||||
|
# Assume the first element is the main output (logits)
|
||||||
|
logits = outputs[0]
|
||||||
|
elif isinstance(outputs, dict):
|
||||||
|
# Handle dictionary output (get main prediction)
|
||||||
|
logits = outputs.get('logits', outputs.get('predictions', outputs.get('output', list(outputs.values())[0])))
|
||||||
|
else:
|
||||||
|
# Single tensor output
|
||||||
|
logits = outputs
|
||||||
|
|
||||||
|
loss = criterion(logits, targets_tensor)
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
@ -1482,9 +1501,123 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
return loss.item()
|
return loss.item()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in CNN training: {e}")
|
logger.error(f"RT TRAINING: Error in CNN training: {e}")
|
||||||
return 1.0 # Return default loss value in case of error
|
return 1.0 # Return default loss value in case of error
|
||||||
|
|
||||||
|
def _sample_prioritized_experiences(self) -> List[Dict]:
|
||||||
|
"""Sample prioritized experiences for training"""
|
||||||
|
try:
|
||||||
|
experiences = []
|
||||||
|
|
||||||
|
# Sample from priority buffer first (high-priority experiences)
|
||||||
|
if self.priority_buffer:
|
||||||
|
priority_samples = min(len(self.priority_buffer), self.training_config['batch_size'] // 2)
|
||||||
|
priority_experiences = random.sample(list(self.priority_buffer), priority_samples)
|
||||||
|
experiences.extend(priority_experiences)
|
||||||
|
|
||||||
|
# Sample from regular experience buffer
|
||||||
|
if self.experience_buffer:
|
||||||
|
remaining_samples = self.training_config['batch_size'] - len(experiences)
|
||||||
|
if remaining_samples > 0:
|
||||||
|
regular_samples = min(len(self.experience_buffer), remaining_samples)
|
||||||
|
regular_experiences = random.sample(list(self.experience_buffer), regular_samples)
|
||||||
|
experiences.extend(regular_experiences)
|
||||||
|
|
||||||
|
# Convert experiences to DQN format
|
||||||
|
dqn_experiences = []
|
||||||
|
for exp in experiences:
|
||||||
|
# Create next state by shifting current state (simple approximation)
|
||||||
|
next_state = exp['state'].copy() if hasattr(exp['state'], 'copy') else exp['state']
|
||||||
|
|
||||||
|
# Simple reward based on recent market movement
|
||||||
|
reward = self._calculate_experience_reward(exp)
|
||||||
|
|
||||||
|
# Action mapping: 0=BUY, 1=SELL, 2=HOLD
|
||||||
|
action = self._determine_action_from_experience(exp)
|
||||||
|
|
||||||
|
dqn_exp = {
|
||||||
|
'state': exp['state'],
|
||||||
|
'action': action,
|
||||||
|
'reward': reward,
|
||||||
|
'next_state': next_state,
|
||||||
|
'done': False # Episodes don't really "end" in continuous trading
|
||||||
|
}
|
||||||
|
|
||||||
|
dqn_experiences.append(dqn_exp)
|
||||||
|
|
||||||
|
return dqn_experiences
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error sampling prioritized experiences: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _calculate_experience_reward(self, experience: Dict) -> float:
|
||||||
|
"""Calculate reward for an experience"""
|
||||||
|
try:
|
||||||
|
# Simple reward based on technical indicators and market events
|
||||||
|
reward = 0.0
|
||||||
|
|
||||||
|
# Reward based on market events
|
||||||
|
if experience.get('market_events', 0) > 0:
|
||||||
|
reward += 0.1 # Bonus for learning from market events
|
||||||
|
|
||||||
|
# Reward based on technical indicators
|
||||||
|
tech_indicators = experience.get('technical_indicators', {})
|
||||||
|
if tech_indicators:
|
||||||
|
# Reward for strong momentum
|
||||||
|
momentum = tech_indicators.get('price_momentum', 0)
|
||||||
|
reward += np.tanh(momentum * 10) # Bounded reward
|
||||||
|
|
||||||
|
# Penalize high volatility
|
||||||
|
volatility = tech_indicators.get('volatility', 0)
|
||||||
|
reward -= min(volatility * 5, 0.2) # Penalty for high volatility
|
||||||
|
|
||||||
|
# Reward based on COB features
|
||||||
|
cob_features = experience.get('cob_features', [])
|
||||||
|
if cob_features and len(cob_features) > 0:
|
||||||
|
# Reward for strong order book imbalance
|
||||||
|
imbalance = cob_features[0] if len(cob_features) > 0 else 0
|
||||||
|
reward += abs(imbalance) * 0.1 # Reward for any imbalance signal
|
||||||
|
|
||||||
|
return max(-1.0, min(1.0, reward)) # Clamp to [-1, 1]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error calculating experience reward: {e}")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def _determine_action_from_experience(self, experience: Dict) -> int:
|
||||||
|
"""Determine action from experience data"""
|
||||||
|
try:
|
||||||
|
# Use technical indicators to determine action
|
||||||
|
tech_indicators = experience.get('technical_indicators', {})
|
||||||
|
|
||||||
|
if tech_indicators:
|
||||||
|
momentum = tech_indicators.get('price_momentum', 0)
|
||||||
|
rsi = tech_indicators.get('rsi', 50)
|
||||||
|
|
||||||
|
# Simple logic based on momentum and RSI
|
||||||
|
if momentum > 0.005 and rsi < 70: # Upward momentum, not overbought
|
||||||
|
return 0 # BUY
|
||||||
|
elif momentum < -0.005 and rsi > 30: # Downward momentum, not oversold
|
||||||
|
return 1 # SELL
|
||||||
|
else:
|
||||||
|
return 2 # HOLD
|
||||||
|
|
||||||
|
# Fallback to COB-based action
|
||||||
|
cob_features = experience.get('cob_features', [])
|
||||||
|
if cob_features and len(cob_features) > 0:
|
||||||
|
imbalance = cob_features[0]
|
||||||
|
if imbalance > 0.1:
|
||||||
|
return 0 # BUY (bid imbalance)
|
||||||
|
elif imbalance < -0.1:
|
||||||
|
return 1 # SELL (ask imbalance)
|
||||||
|
|
||||||
|
return 2 # Default to HOLD
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error determining action from experience: {e}")
|
||||||
|
return 2 # Default to HOLD
|
||||||
|
|
||||||
def _perform_validation(self):
|
def _perform_validation(self):
|
||||||
"""Perform validation to track model performance"""
|
"""Perform validation to track model performance"""
|
||||||
try:
|
try:
|
||||||
@ -1849,23 +1982,34 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
current_state = self._build_comprehensive_state()
|
current_state = self._build_comprehensive_state()
|
||||||
current_price = self._get_current_price_from_data(symbol)
|
current_price = self._get_current_price_from_data(symbol)
|
||||||
|
|
||||||
if current_price is None:
|
# SKIP prediction if price is invalid
|
||||||
|
if current_price is None or current_price <= 0:
|
||||||
|
logger.debug(f"Skipping DQN prediction for {symbol}: invalid price {current_price}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Use DQN model to predict action (if available)
|
# Use DQN model to predict action (if available)
|
||||||
if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
|
if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent')
|
||||||
and self.orchestrator.rl_agent):
|
and self.orchestrator.rl_agent):
|
||||||
|
|
||||||
# Get Q-values from model
|
# Get action from DQN agent
|
||||||
q_values = self.orchestrator.rl_agent.act(current_state, return_q_values=True)
|
action = self.orchestrator.rl_agent.act(current_state, explore=False)
|
||||||
if isinstance(q_values, tuple):
|
|
||||||
action, q_vals = q_values
|
# Get Q-values by manually calling the model
|
||||||
q_values = q_vals.tolist() if hasattr(q_vals, 'tolist') else [0, 0, 0]
|
q_values = self._get_dqn_q_values(current_state)
|
||||||
|
|
||||||
|
# Calculate confidence from Q-values
|
||||||
|
if q_values is not None and len(q_values) > 0:
|
||||||
|
# Convert to probabilities and get confidence
|
||||||
|
probs = torch.softmax(torch.tensor(q_values), dim=0).numpy()
|
||||||
|
confidence = float(max(probs))
|
||||||
|
q_values = q_values.tolist() if hasattr(q_values, 'tolist') else list(q_values)
|
||||||
else:
|
else:
|
||||||
action = q_values
|
confidence = 0.33
|
||||||
q_values = [0.33, 0.33, 0.34] # Default uniform distribution
|
q_values = [0.33, 0.33, 0.34] # Default uniform distribution
|
||||||
|
|
||||||
confidence = max(q_values) / sum(q_values) if sum(q_values) > 0 else 0.33
|
# Handle case where action is None (HOLD)
|
||||||
|
if action is None:
|
||||||
|
action = 2 # Map None to HOLD action
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Fallback to technical analysis-based prediction
|
# Fallback to technical analysis-based prediction
|
||||||
@ -1893,8 +2037,8 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
if symbol in self.pending_predictions:
|
if symbol in self.pending_predictions:
|
||||||
self.pending_predictions[symbol].append(prediction)
|
self.pending_predictions[symbol].append(prediction)
|
||||||
|
|
||||||
# Add to recent predictions for display (only if confident enough)
|
# Add to recent predictions for display (only if confident enough AND valid price)
|
||||||
if confidence > 0.4:
|
if confidence > 0.4 and current_price > 0:
|
||||||
display_prediction = {
|
display_prediction = {
|
||||||
'timestamp': prediction_time,
|
'timestamp': prediction_time,
|
||||||
'price': current_price,
|
'price': current_price,
|
||||||
@ -1907,11 +2051,44 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
|
|
||||||
self.last_prediction_time[symbol] = int(current_time)
|
self.last_prediction_time[symbol] = int(current_time)
|
||||||
|
|
||||||
logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
|
logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} price=${current_price:.2f} target={target_time.strftime('%H:%M:%S')}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating forward DQN prediction: {e}")
|
logger.error(f"Error generating forward DQN prediction: {e}")
|
||||||
|
|
||||||
|
def _get_dqn_q_values(self, state: np.ndarray) -> Optional[np.ndarray]:
|
||||||
|
"""Get Q-values from DQN agent without performing action selection"""
|
||||||
|
try:
|
||||||
|
if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
|
||||||
|
return None
|
||||||
|
|
||||||
|
rl_agent = self.orchestrator.rl_agent
|
||||||
|
|
||||||
|
# Convert state to tensor
|
||||||
|
if isinstance(state, np.ndarray):
|
||||||
|
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(rl_agent.device)
|
||||||
|
else:
|
||||||
|
state_tensor = state.unsqueeze(0).to(rl_agent.device)
|
||||||
|
|
||||||
|
# Get Q-values directly from policy network
|
||||||
|
with torch.no_grad():
|
||||||
|
policy_output = rl_agent.policy_net(state_tensor)
|
||||||
|
|
||||||
|
# Handle different output formats
|
||||||
|
if isinstance(policy_output, dict):
|
||||||
|
q_values = policy_output.get('q_values', policy_output.get('Q_values', list(policy_output.values())[0]))
|
||||||
|
elif isinstance(policy_output, tuple):
|
||||||
|
q_values = policy_output[0] # Assume first element is Q-values
|
||||||
|
else:
|
||||||
|
q_values = policy_output
|
||||||
|
|
||||||
|
# Convert to numpy
|
||||||
|
return q_values.cpu().data.numpy()[0]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error getting DQN Q-values: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def _generate_forward_cnn_prediction(self, symbol: str, current_time: float):
|
def _generate_forward_cnn_prediction(self, symbol: str, current_time: float):
|
||||||
"""Generate a CNN prediction for future price direction"""
|
"""Generate a CNN prediction for future price direction"""
|
||||||
try:
|
try:
|
||||||
@ -1919,7 +2096,13 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
current_price = self._get_current_price_from_data(symbol)
|
current_price = self._get_current_price_from_data(symbol)
|
||||||
price_sequence = self._get_historical_price_sequence(symbol, periods=15)
|
price_sequence = self._get_historical_price_sequence(symbol, periods=15)
|
||||||
|
|
||||||
if current_price is None or len(price_sequence) < 15:
|
# SKIP prediction if price is invalid
|
||||||
|
if current_price is None or current_price <= 0:
|
||||||
|
logger.debug(f"Skipping CNN prediction for {symbol}: invalid price {current_price}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(price_sequence) < 15:
|
||||||
|
logger.debug(f"Skipping CNN prediction for {symbol}: insufficient data")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Use CNN model to predict direction (if available)
|
# Use CNN model to predict direction (if available)
|
||||||
@ -1974,8 +2157,8 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
if symbol in self.pending_predictions:
|
if symbol in self.pending_predictions:
|
||||||
self.pending_predictions[symbol].append(prediction)
|
self.pending_predictions[symbol].append(prediction)
|
||||||
|
|
||||||
# Add to recent predictions for display (only if confident enough)
|
# Add to recent predictions for display (only if confident enough AND valid prices)
|
||||||
if confidence > 0.5:
|
if confidence > 0.5 and current_price > 0 and predicted_price > 0:
|
||||||
display_prediction = {
|
display_prediction = {
|
||||||
'timestamp': prediction_time,
|
'timestamp': prediction_time,
|
||||||
'current_price': current_price,
|
'current_price': current_price,
|
||||||
@ -1986,7 +2169,7 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
if symbol in self.recent_cnn_predictions:
|
if symbol in self.recent_cnn_predictions:
|
||||||
self.recent_cnn_predictions[symbol].append(display_prediction)
|
self.recent_cnn_predictions[symbol].append(display_prediction)
|
||||||
|
|
||||||
logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
|
logger.info(f"Forward CNN prediction: {symbol} direction={['DOWN','SAME','UP'][direction]} confidence={confidence:.2f} price=${current_price:.2f} -> ${predicted_price:.2f} target={target_time.strftime('%H:%M:%S')}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating forward CNN prediction: {e}")
|
logger.error(f"Error generating forward CNN prediction: {e}")
|
||||||
@ -2077,8 +2260,24 @@ class EnhancedRealtimeTrainingSystem:
|
|||||||
def _get_current_price_from_data(self, symbol: str) -> Optional[float]:
|
def _get_current_price_from_data(self, symbol: str) -> Optional[float]:
|
||||||
"""Get current price from real-time data streams"""
|
"""Get current price from real-time data streams"""
|
||||||
try:
|
try:
|
||||||
|
# First, try to get from data provider (most reliable)
|
||||||
|
if self.data_provider:
|
||||||
|
price = self.data_provider.get_current_price(symbol)
|
||||||
|
if price and price > 0:
|
||||||
|
return price
|
||||||
|
|
||||||
|
# Fallback to internal buffer
|
||||||
if len(self.real_time_data['ohlcv_1m']) > 0:
|
if len(self.real_time_data['ohlcv_1m']) > 0:
|
||||||
return self.real_time_data['ohlcv_1m'][-1]['close']
|
price = self.real_time_data['ohlcv_1m'][-1]['close']
|
||||||
|
if price and price > 0:
|
||||||
|
return price
|
||||||
|
|
||||||
|
# Fallback to orchestrator price
|
||||||
|
if self.orchestrator:
|
||||||
|
price = self.orchestrator._get_current_price(symbol)
|
||||||
|
if price and price > 0:
|
||||||
|
return price
|
||||||
|
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Error getting current price: {e}")
|
logger.debug(f"Error getting current price: {e}")
|
||||||
|
@ -286,11 +286,11 @@ class DashboardComponentManager:
|
|||||||
if hasattr(cob_snapshot, 'stats'):
|
if hasattr(cob_snapshot, 'stats'):
|
||||||
# Old format with stats attribute
|
# Old format with stats attribute
|
||||||
stats = cob_snapshot.stats
|
stats = cob_snapshot.stats
|
||||||
mid_price = stats.get('mid_price', 0)
|
mid_price = stats.get('mid_price', 0)
|
||||||
spread_bps = stats.get('spread_bps', 0)
|
spread_bps = stats.get('spread_bps', 0)
|
||||||
imbalance = stats.get('imbalance', 0)
|
imbalance = stats.get('imbalance', 0)
|
||||||
bids = getattr(cob_snapshot, 'consolidated_bids', [])
|
bids = getattr(cob_snapshot, 'consolidated_bids', [])
|
||||||
asks = getattr(cob_snapshot, 'consolidated_asks', [])
|
asks = getattr(cob_snapshot, 'consolidated_asks', [])
|
||||||
else:
|
else:
|
||||||
# New COBSnapshot format with direct attributes
|
# New COBSnapshot format with direct attributes
|
||||||
mid_price = getattr(cob_snapshot, 'volume_weighted_mid', 0)
|
mid_price = getattr(cob_snapshot, 'volume_weighted_mid', 0)
|
||||||
@ -421,10 +421,10 @@ class DashboardComponentManager:
|
|||||||
volume_usd = order.total_volume_usd
|
volume_usd = order.total_volume_usd
|
||||||
else:
|
else:
|
||||||
# Dictionary format (legacy)
|
# Dictionary format (legacy)
|
||||||
price = order.get('price', 0)
|
price = order.get('price', 0)
|
||||||
# Handle both old format (size) and new format (total_size)
|
# Handle both old format (size) and new format (total_size)
|
||||||
size = order.get('total_size', order.get('size', 0))
|
size = order.get('total_size', order.get('size', 0))
|
||||||
volume_usd = order.get('total_volume_usd', size * price)
|
volume_usd = order.get('total_volume_usd', size * price)
|
||||||
|
|
||||||
if price > 0:
|
if price > 0:
|
||||||
bucket_key = round(price / bucket_size) * bucket_size
|
bucket_key = round(price / bucket_size) * bucket_size
|
||||||
|
Reference in New Issue
Block a user