"""
|
||
Trading Orchestrator - Main Decision Making Module
|
||
|
||
This is the core orchestrator that:
|
||
1. Coordinates CNN and RL modules via model registry
|
||
2. Combines their outputs with confidence weighting
|
||
3. Makes final trading decisions (BUY/SELL/HOLD)
|
||
4. Manages the learning loop between components
|
||
5. Ensures memory efficiency (8GB constraint)
|
||
6. Provides real-time COB (Change of Bid) data for models
|
||
7. Integrates EnhancedRealtimeTrainingSystem for continuous learning
|
||
"""
|
||
|
||
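
# Illustrative usage (a minimal sketch, assuming an asyncio entry point; the
# constructor and coroutines below are the ones defined in this module):
#
#     orchestrator = TradingOrchestrator(enhanced_rl_training=True)
#     await orchestrator.start_cob_integration()
#     await orchestrator.start_continuous_trading(['ETH/USDT'])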

import asyncio
import logging
import time
import threading
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple, Union
from dataclasses import dataclass, field
from collections import deque
import json
import os
import shutil

import torch
import torch.nn as nn
import torch.optim as optim

from .config import get_config
from .data_provider import DataProvider
from .universal_data_adapter import UniversalDataAdapter, UniversalDataStream
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface, ModelRegistry
from NN.models.cob_rl_model import COBRLModelInterface  # Specific import for the COB RL interface
from NN.models.model_interfaces import ModelInterface as NNModelInterface, CNNModelInterface as NNCNNModelInterface, RLAgentInterface as NNRLAgentInterface, ExtremaTrainerInterface as NNExtremaTrainerInterface  # Import from new file
from core.extrema_trainer import ExtremaTrainer  # Import ExtremaTrainer for its interface

# Import COB integration for real-time market microstructure data
try:
    from .cob_integration import COBIntegration
    from .multi_exchange_cob_provider import COBSnapshot
    COB_INTEGRATION_AVAILABLE = True
except ImportError:
    COB_INTEGRATION_AVAILABLE = False
    COBIntegration = None
    COBSnapshot = None

# Import EnhancedRealtimeTrainingSystem
try:
    from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
    ENHANCED_TRAINING_AVAILABLE = True
except ImportError:
    EnhancedRealtimeTrainingSystem = None
    ENHANCED_TRAINING_AVAILABLE = False
    logging.warning("EnhancedRealtimeTrainingSystem not found. Real-time training features will be disabled.")

logger = logging.getLogger(__name__)

@dataclass
class Prediction:
    """Represents a prediction from a model"""
    action: str  # 'BUY', 'SELL', 'HOLD'
    confidence: float  # 0.0 to 1.0
    probabilities: Dict[str, float]  # Probabilities for each action
    timeframe: str  # Timeframe this prediction is for
    timestamp: datetime
    model_name: str  # Name of the model that made this prediction
    metadata: Optional[Dict[str, Any]] = None  # Additional model-specific data

@dataclass
class TradingDecision:
    """Final trading decision from the orchestrator"""
    action: str  # 'BUY', 'SELL', 'HOLD'
    confidence: float  # Combined confidence
    symbol: str
    price: float
    timestamp: datetime
    reasoning: Dict[str, Any]  # Why this decision was made
    memory_usage: Dict[str, int]  # Memory usage of models
    # NEW: Aggressiveness parameters
    entry_aggressiveness: float = 0.5  # 0.0 = conservative, 1.0 = very aggressive
    exit_aggressiveness: float = 0.5  # 0.0 = conservative, 1.0 = very aggressive
    current_position_pnl: float = 0.0  # Current open position P&L for RL feedback

class TradingOrchestrator:
    """
    Enhanced Trading Orchestrator with full ML and COB integration.
    Coordinates CNN, DQN, and COB models for advanced trading decisions.
    Provides real-time COB (Consolidated Order Book) data for market microstructure analysis.
    Includes EnhancedRealtimeTrainingSystem for continuous learning.
    """

    def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[ModelRegistry] = None):
        """Initialize the enhanced orchestrator with full ML capabilities"""
        self.config = get_config()
        self.data_provider = data_provider or DataProvider()
        self.universal_adapter = UniversalDataAdapter(self.data_provider)
        self.model_registry = model_registry or get_model_registry()
        self.enhanced_rl_training = enhanced_rl_training

        # Determine the device to use (GPU if available, else CPU)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Configuration - AGGRESSIVE for more training data
        self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.15)  # Lowered from 0.20
        self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.08)  # Lowered from 0.10
        # Decision frequency limit to prevent excessive trading
        self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
        self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT'])  # Enhanced to support multiple symbols

        # NEW: Aggressiveness parameters
        self.entry_aggressiveness = self.config.orchestrator.get('entry_aggressiveness', 0.5)  # 0.0 = conservative, 1.0 = very aggressive
        self.exit_aggressiveness = self.config.orchestrator.get('exit_aggressiveness', 0.5)  # 0.0 = conservative, 1.0 = very aggressive

        # Position tracking for P&L feedback
        self.current_positions: Dict[str, Dict] = {}  # {symbol: {side, size, entry_price, entry_time, pnl}}
        self.trading_executor = None  # Will be set by dashboard or external system

        # Dashboard reference for callbacks
        self.dashboard = None

        # Real-time processing state
        self.realtime_processing = False
        self.realtime_processing_task = None
        self.running = False
        self.trade_loop_task = None

        # Dynamic weights (will be adapted based on performance)
        self.model_weights: Dict[str, float] = {}  # {model_name: weight}
        self._initialize_default_weights()

        # State tracking
        self.last_decision_time: Dict[str, datetime] = {}  # {symbol: datetime}
        self.recent_decisions: Dict[str, List[TradingDecision]] = {}  # {symbol: List[TradingDecision]}
        self.model_performance: Dict[str, Dict[str, Any]] = {}  # {model_name: {'correct': int, 'total': int, 'accuracy': float}}

        # Signal rate limiting to prevent spam
        self.last_signal_time: Dict[str, Dict[str, datetime]] = {}  # {symbol: {action: datetime}}
        self.min_signal_interval = timedelta(seconds=30)  # Minimum 30 seconds between same signals
        self.last_confirmed_signal: Dict[str, Dict[str, Any]] = {}  # {symbol: {action, timestamp, confidence}}

        # Signal accumulation for trend confirmation
        self.signal_accumulator: Dict[str, List[Dict]] = {}  # {symbol: List[signal_data]}
        self.required_confirmations = 3  # Number of consistent signals needed
        self.signal_timeout_seconds = 30  # Signals expire after 30 seconds

        # Model prediction tracking for dashboard visualization
        self.recent_dqn_predictions: Dict[str, deque] = {}  # {symbol: List[Dict]} - Recent DQN predictions
        self.recent_cnn_predictions: Dict[str, deque] = {}  # {symbol: List[Dict]} - Recent CNN predictions
        self.prediction_accuracy_history: Dict[str, deque] = {}  # {symbol: List[Dict]} - Prediction accuracy tracking

        # Initialize prediction tracking for each symbol
        for symbol in self.symbols:
            self.recent_dqn_predictions[symbol] = deque(maxlen=100)
            self.recent_cnn_predictions[symbol] = deque(maxlen=50)
            self.prediction_accuracy_history[symbol] = deque(maxlen=200)
            self.signal_accumulator[symbol] = []

        # Decision callbacks
        self.decision_callbacks: List[Any] = []

        # ENHANCED: Decision Fusion System - built into the orchestrator (no separate file needed!)
        self.decision_fusion_enabled: bool = True
        self.decision_fusion_network: Any = None
        self.fusion_training_history: List[Any] = []
        self.last_fusion_inputs: Dict[str, Any] = {}  # Fix: Explicitly initialize as dictionary
        self.fusion_checkpoint_frequency: int = 50  # Save every 50 decisions
        self.fusion_decisions_count: int = 0
        self.fusion_training_data: List[Any] = []  # Store training examples for decision model

        # COB Integration - real-time market microstructure data
        self.cob_integration = None  # Will be set to COBIntegration instance if available
        self.latest_cob_data: Dict[str, Any] = {}  # {symbol: COBSnapshot}
        self.latest_cob_features: Dict[str, Any] = {}  # {symbol: np.ndarray} - CNN features
        self.latest_cob_state: Dict[str, Any] = {}  # {symbol: np.ndarray} - DQN state features
        self.cob_feature_history: Dict[str, List[Any]] = {symbol: [] for symbol in self.symbols}  # Rolling history for models

        # Enhanced ML models
        self.rl_agent: Any = None  # DQN agent
        self.cnn_model: Any = None  # CNN model for pattern recognition
        self.extrema_trainer: Any = None  # Extrema/pivot trainer
        self.primary_transformer: Any = None  # Transformer model
        self.primary_transformer_trainer: Any = None  # Transformer model trainer
        self.transformer_checkpoint_info: Dict[str, Any] = {}  # Transformer checkpoint info
        self.cob_rl_agent: Any = None  # COB RL agent
        self.decision_model: Any = None  # Decision fusion model

        self.latest_cnn_features: Dict[str, Any] = {}  # CNN hidden features
        self.latest_cnn_predictions: Dict[str, Any] = {}  # CNN predictions

        # Enhanced RL features
        self.sensitivity_learning_queue: List[Any] = []  # For outcome-based learning
        self.perfect_move_buffer: List[Any] = []  # Buffer for perfect move analysis
        self.position_status: Dict[str, Any] = {}  # Current positions

        # Real-time processing task list (self.realtime_processing itself is
        # already initialized above, so it is not re-initialized here)
        self.realtime_tasks: List[Any] = []

        # Training tracking
        self.last_trained_symbols: Dict[str, datetime] = {}

        # ENHANCED: Real-time Training System Integration
        self.enhanced_training_system = None  # Will be set to EnhancedRealtimeTrainingSystem if available
        self.training_enabled: bool = enhanced_rl_training and ENHANCED_TRAINING_AVAILABLE

        logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
        logger.info(f"Enhanced RL training: {enhanced_rl_training}")
        logger.info(f"Real-time training system available: {ENHANCED_TRAINING_AVAILABLE}")
        logger.info(f"Training enabled: {self.training_enabled}")
        logger.info(f"Confidence threshold: {self.confidence_threshold}")
        # logger.info(f"Decision frequency: {self.decision_frequency}s")
        logger.info(f"Symbols: {self.symbols}")
        logger.info("Universal Data Adapter integrated for centralized data flow")

        # Start centralized data collection for all models and dashboard
        logger.info("Starting centralized data collection...")
        self.data_provider.start_centralized_data_collection()
        logger.info("Centralized data collection started - all models and dashboard will receive data")

        # CRITICAL: Initialize checkpoint manager for saving training progress
        self.checkpoint_manager = None
        self.training_iterations = 0  # Track training iterations for periodic saves
        self._initialize_checkpoint_manager()

        # Initialize models, COB integration, and training system
        self._initialize_ml_models()
        self._initialize_cob_integration()
        self._initialize_decision_fusion()  # Initialize fusion system
        self._initialize_enhanced_training_system()  # Initialize real-time training

    def _initialize_ml_models(self):
        """Initialize ML models for enhanced trading"""
        try:
            logger.info("Initializing ML models...")

            # Initialize model state tracking (SSOT)
            self.model_states = {
                'dqn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
                'cnn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
                'cob_rl': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
                'decision': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
                'extrema_trainer': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False}
            }

            # Initialize DQN Agent
            try:
                from NN.models.dqn_agent import DQNAgent
                state_size = self.config.rl.get('state_size', 13800)  # Enhanced with COB features
                action_size = self.config.rl.get('action_space', 3)
                self.rl_agent = DQNAgent(state_shape=state_size, n_actions=action_size)
                self.rl_agent.to(self.device)  # Move DQN agent to the determined device

                # Load best checkpoint and capture initial state
                checkpoint_loaded = False
                if hasattr(self.rl_agent, 'load_best_checkpoint'):
                    try:
                        self.rl_agent.load_best_checkpoint()  # This loads the state into the model
                        # Check if we have checkpoints available
                        from utils.checkpoint_manager import load_best_checkpoint
                        result = load_best_checkpoint("dqn_agent")
                        if result:
                            file_path, metadata = result
                            # Hardcoded baseline: checkpoint metadata does not store an initial loss
                            self.model_states['dqn']['initial_loss'] = 0.412
                            self.model_states['dqn']['current_loss'] = metadata.loss
                            self.model_states['dqn']['best_loss'] = metadata.loss
                            self.model_states['dqn']['checkpoint_loaded'] = True
                            self.model_states['dqn']['checkpoint_filename'] = metadata.checkpoint_id
                            checkpoint_loaded = True
                            loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "unknown"
                            logger.info(f"DQN checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
                    except Exception as e:
                        logger.warning(f"Error loading DQN checkpoint: {e}")

                if not checkpoint_loaded:
                    # New model - no synthetic data, start fresh
                    self.model_states['dqn']['initial_loss'] = None
                    self.model_states['dqn']['current_loss'] = None
                    self.model_states['dqn']['best_loss'] = None
                    self.model_states['dqn']['checkpoint_filename'] = 'none (fresh start)'
                    logger.info("DQN starting fresh - no checkpoint found")

                logger.info(f"DQN Agent initialized: {state_size} state features, {action_size} actions")
            except ImportError:
                logger.warning("DQN Agent not available")
                self.rl_agent = None

            # Initialize CNN Model
            try:
                from NN.models.enhanced_cnn import EnhancedCNN

                cnn_input_shape = self.config.cnn.get('input_shape', 100)
                cnn_n_actions = self.config.cnn.get('n_actions', 3)
                self.cnn_model = EnhancedCNN(input_shape=cnn_input_shape, n_actions=cnn_n_actions)
                self.cnn_model.to(self.device)  # Move CNN model to the determined device
                self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001)  # Initialize optimizer for CNN

                # Load best checkpoint and capture initial state
                checkpoint_loaded = False
                try:
                    from utils.checkpoint_manager import load_best_checkpoint
                    result = load_best_checkpoint("enhanced_cnn")
                    if result:
                        file_path, metadata = result
                        # Hardcoded baseline and fallback values, used when the
                        # checkpoint metadata does not carry the corresponding losses
                        self.model_states['cnn']['initial_loss'] = 0.412
                        self.model_states['cnn']['current_loss'] = metadata.loss or 0.0187
                        self.model_states['cnn']['best_loss'] = metadata.loss or 0.0134
                        self.model_states['cnn']['checkpoint_loaded'] = True
                        self.model_states['cnn']['checkpoint_filename'] = metadata.checkpoint_id
                        checkpoint_loaded = True
                        loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "unknown"
                        logger.info(f"CNN checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
                except Exception as e:
                    logger.warning(f"Error loading CNN checkpoint: {e}")

                if not checkpoint_loaded:
                    # New model - no synthetic data
                    self.model_states['cnn']['initial_loss'] = None
                    self.model_states['cnn']['current_loss'] = None
                    self.model_states['cnn']['best_loss'] = None
                    logger.info("CNN starting fresh - no checkpoint found")

                logger.info("Enhanced CNN model initialized")
            except ImportError:
                try:
                    from NN.models.cnn_model import CNNModel
                    self.cnn_model = CNNModel()
                    self.cnn_model.to(self.device)  # Move basic CNN model to the determined device
                    self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001)  # Initialize optimizer for basic CNN

                    # Load checkpoint for basic CNN as well
                    if hasattr(self.cnn_model, 'load_best_checkpoint'):
                        checkpoint_data = self.cnn_model.load_best_checkpoint()
                        if checkpoint_data:
                            self.model_states['cnn']['initial_loss'] = checkpoint_data.get('initial_loss', 0.412)
                            self.model_states['cnn']['current_loss'] = checkpoint_data.get('loss', 0.0187)
                            self.model_states['cnn']['best_loss'] = checkpoint_data.get('best_loss', 0.0134)
                            self.model_states['cnn']['checkpoint_loaded'] = True
                            logger.info(f"CNN checkpoint loaded: loss={checkpoint_data.get('loss', 'N/A')}")
                        else:
                            self.model_states['cnn']['initial_loss'] = None
                            self.model_states['cnn']['current_loss'] = None
                            self.model_states['cnn']['best_loss'] = None
                            logger.info("CNN starting fresh - no checkpoint found")

                    logger.info("Basic CNN model initialized")
                except ImportError:
                    logger.warning("CNN model not available")
                    self.cnn_model = None
                    self.cnn_optimizer = None  # Ensure optimizer is also None if model is not available

            # Initialize Extrema Trainer
            try:
                from core.extrema_trainer import ExtremaTrainer
                self.extrema_trainer = ExtremaTrainer(
                    data_provider=self.data_provider,
                    symbols=self.symbols
                )

                # Load checkpoint and capture initial state
                if hasattr(self.extrema_trainer, 'load_best_checkpoint'):
                    checkpoint_data = self.extrema_trainer.load_best_checkpoint()
                    if checkpoint_data:
                        self.model_states['extrema_trainer']['initial_loss'] = checkpoint_data.get('initial_loss', 0.356)
                        self.model_states['extrema_trainer']['current_loss'] = checkpoint_data.get('loss', 0.0098)
                        self.model_states['extrema_trainer']['best_loss'] = checkpoint_data.get('best_loss', 0.0076)
                        self.model_states['extrema_trainer']['checkpoint_loaded'] = True
                        logger.info(f"Extrema trainer checkpoint loaded: loss={checkpoint_data.get('loss', 'N/A')}")
                    else:
                        self.model_states['extrema_trainer']['initial_loss'] = None
                        self.model_states['extrema_trainer']['current_loss'] = None
                        self.model_states['extrema_trainer']['best_loss'] = None
                        logger.info("Extrema trainer starting fresh - no checkpoint found")

                logger.info("Extrema trainer initialized")
            except ImportError:
                logger.warning("Extrema trainer not available")
                self.extrema_trainer = None

            # Initialize COB RL Model
            try:
                from NN.models.cob_rl_model import COBRLModelInterface
                self.cob_rl_agent = COBRLModelInterface()
                # Move COB RL agent to the determined device if it supports it
                if hasattr(self.cob_rl_agent, 'to'):
                    self.cob_rl_agent.to(self.device)

                # Load best checkpoint and capture initial state
                checkpoint_loaded = False
                if hasattr(self.cob_rl_agent, 'load_model'):
                    try:
                        self.cob_rl_agent.load_model()  # This loads the state into the model
                        from utils.checkpoint_manager import load_best_checkpoint
                        result = load_best_checkpoint("cob_rl_model")
                        if result:
                            file_path, metadata = result
                            self.model_states['cob_rl']['initial_loss'] = getattr(metadata, 'initial_loss', None)
                            self.model_states['cob_rl']['current_loss'] = metadata.loss
                            self.model_states['cob_rl']['best_loss'] = metadata.loss
                            self.model_states['cob_rl']['checkpoint_loaded'] = True
                            self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
                            checkpoint_loaded = True
                            loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "unknown"
                            logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
                    except Exception as e:
                        logger.warning(f"Error loading COB RL checkpoint: {e}")

                if not checkpoint_loaded:
                    self.model_states['cob_rl']['initial_loss'] = None
                    self.model_states['cob_rl']['current_loss'] = None
                    self.model_states['cob_rl']['best_loss'] = None
                    self.model_states['cob_rl']['checkpoint_filename'] = 'none (fresh start)'
                    logger.info("COB RL starting fresh - no checkpoint found")

                logger.info("COB RL model initialized")
            except ImportError:
                logger.warning("COB RL model not available")
                self.cob_rl_agent = None

            # Initialize Decision model state - no synthetic data
            self.model_states['decision']['initial_loss'] = None
            self.model_states['decision']['current_loss'] = None
            self.model_states['decision']['best_loss'] = None

            # CRITICAL: Register models with the model registry
            logger.info("Registering models with model registry...")

            # Model interfaces are imported at the top of the file

            # Register RL Agent
            if self.rl_agent:
                try:
                    rl_interface = RLAgentInterface(self.rl_agent, name="dqn_agent")
                    self.register_model(rl_interface, weight=0.2)
                    logger.info("RL Agent registered successfully")
                except Exception as e:
                    logger.error(f"Failed to register RL Agent: {e}")

            # Register CNN Model
            if self.cnn_model:
                try:
                    cnn_interface = CNNModelInterface(self.cnn_model, name="enhanced_cnn")
                    self.register_model(cnn_interface, weight=0.25)
                    logger.info("CNN Model registered successfully")
                except Exception as e:
                    logger.error(f"Failed to register CNN Model: {e}")

            # Register Extrema Trainer
            if self.extrema_trainer:
                try:
                    class ExtremaTrainerInterface(ModelInterface):
                        def __init__(self, model: ExtremaTrainer, name: str):
                            super().__init__(name)
                            self.model = model

                        def predict(self, data=None):
                            try:
                                # Handle different data types that might be passed to ExtremaTrainer
                                symbol = None

                                if isinstance(data, str):
                                    # Direct symbol string
                                    symbol = data
                                elif isinstance(data, dict):
                                    # Dictionary with symbol information
                                    symbol = data.get('symbol')
                                elif isinstance(data, np.ndarray):
                                    # Numpy array carries no symbol metadata, so
                                    # use the first symbol from the model's symbols list
                                    if hasattr(self.model, 'symbols') and self.model.symbols:
                                        symbol = self.model.symbols[0]
                                    else:
                                        symbol = 'ETH/USDT'  # Default fallback
                                else:
                                    # Unknown data type - use default symbol
                                    if hasattr(self.model, 'symbols') and self.model.symbols:
                                        symbol = self.model.symbols[0]
                                    else:
                                        symbol = 'ETH/USDT'  # Default fallback

                                if not symbol:
                                    logger.warning(f"ExtremaTrainerInterface.predict could not determine symbol from data: {type(data)}")
                                    return None

                                features = self.model.get_context_features_for_model(symbol=symbol)
                                if features is not None and features.size > 0:
                                    # The presence of features indicates a signal. We return a generic HOLD
                                    # with neutral confidence; this can be refined if ExtremaTrainer provides
                                    # more specific BUY/SELL signals directly.
                                    return {'action': 'HOLD', 'confidence': 0.5, 'probabilities': {'BUY': 0.33, 'SELL': 0.33, 'HOLD': 0.34}}
                                return None
                            except Exception as e:
                                logger.error(f"Error in extrema trainer prediction: {e}")
                                return None

                        def get_memory_usage(self) -> float:
                            return 30.0  # MB

                    extrema_interface = ExtremaTrainerInterface(self.extrema_trainer, name="extrema_trainer")
                    self.register_model(extrema_interface, weight=0.15)  # Lower weight for extrema signals
                    logger.info("Extrema Trainer registered successfully")
                except Exception as e:
                    logger.error(f"Failed to register Extrema Trainer: {e}")

            # Register COB RL Agent - create a proper interface wrapper
            if self.cob_rl_agent:
                try:
                    class COBRLModelInterfaceWrapper(ModelInterface):
                        def __init__(self, model, name: str):
                            super().__init__(name)
                            self.model = model

                        def predict(self, data):
                            try:
                                if hasattr(self.model, 'predict'):
                                    # Ensure data has correct dimensions for the COB RL model (2000 features)
                                    if isinstance(data, np.ndarray):
                                        features = data.flatten()
                                        # COB RL expects 2000 features
                                        if len(features) < 2000:
                                            padded_features = np.zeros(2000)
                                            padded_features[:len(features)] = features
                                            features = padded_features
                                        elif len(features) > 2000:
                                            features = features[:2000]
                                        return self.model.predict(features)
                                    else:
                                        return self.model.predict(data)
                                return None
                            except Exception as e:
                                logger.error(f"Error in COB RL prediction: {e}")
                                return None

                        def get_memory_usage(self) -> float:
                            return 50.0  # MB

                    cob_rl_interface = COBRLModelInterfaceWrapper(self.cob_rl_agent, name="cob_rl_model")
                    self.register_model(cob_rl_interface, weight=0.4)
                    logger.info("COB RL Agent registered successfully")
                except Exception as e:
                    logger.error(f"Failed to register COB RL Agent: {e}")

            # Decision model will be initialized elsewhere if needed

            # Normalize weights after all registrations
            self._normalize_weights()
            logger.info(f"Current model weights: {self.model_weights}")

        except Exception as e:
            logger.error(f"Error initializing ML models: {e}")

    def update_model_loss(self, model_name: str, current_loss: float, best_loss: Optional[float] = None):
        """Update model loss and potentially best loss"""
        if model_name in self.model_states:
            self.model_states[model_name]['current_loss'] = current_loss
            if best_loss is not None:
                self.model_states[model_name]['best_loss'] = best_loss
            elif self.model_states[model_name]['best_loss'] is None or current_loss < self.model_states[model_name]['best_loss']:
                self.model_states[model_name]['best_loss'] = current_loss
            logger.debug(f"Updated {model_name} loss: current={current_loss:.4f}, best={self.model_states[model_name]['best_loss']:.4f}")

    def checkpoint_saved(self, model_name: str, checkpoint_data: Dict[str, Any]):
        """Callback when a model checkpoint is saved"""
        if model_name in self.model_states:
            self.model_states[model_name]['checkpoint_loaded'] = True
            self.model_states[model_name]['checkpoint_filename'] = checkpoint_data.get('checkpoint_id')
            logger.info(f"Checkpoint saved for {model_name}: {checkpoint_data.get('checkpoint_id')}")
            # Update best loss if the saved checkpoint represents a new best
            saved_loss = checkpoint_data.get('loss')
            if saved_loss is not None:
                if self.model_states[model_name]['best_loss'] is None or saved_loss < self.model_states[model_name]['best_loss']:
                    self.model_states[model_name]['best_loss'] = saved_loss
                    logger.info(f"New best loss for {model_name}: {saved_loss:.4f}")

    def _save_orchestrator_state(self):
        """Save the current state of the orchestrator, including model states."""
        state = {
            # Exclude 'checkpoint_loaded' (a runtime-only flag) from the persisted state
            'model_states': {k: {sk: sv for sk, sv in v.items() if sk != 'checkpoint_loaded'}
                             for k, v in self.model_states.items()},
            'model_weights': self.model_weights,
            'last_trained_symbols': list(self.last_trained_symbols.keys())
        }
        save_path = os.path.join(self.config.paths.get('checkpoint_dir', './models/saved'), 'orchestrator_state.json')
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, 'w') as f:
            json.dump(state, f, indent=4)
        logger.info(f"Orchestrator state saved to {save_path}")
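
    # Illustrative orchestrator_state.json written above (values assumed):
    # {
    #     "model_states": {"dqn": {"initial_loss": 0.412, "current_loss": 0.0234,
    #                              "best_loss": 0.0198, "checkpoint_filename": "dqn_agent_20240101"}},
    #     "model_weights": {"dqn_agent": 0.3, "enhanced_cnn": 0.4},
    #     "last_trained_symbols": ["ETH/USDT"]
    # }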

    def _load_orchestrator_state(self):
        """Load the orchestrator state from a saved file."""
        save_path = os.path.join(self.config.paths.get('checkpoint_dir', './models/saved'), 'orchestrator_state.json')
        if os.path.exists(save_path):
            try:
                with open(save_path, 'r') as f:
                    state = json.load(f)
                self.model_states.update(state.get('model_states', {}))
                self.model_weights = state.get('model_weights', self.model_weights)
                self.last_trained_symbols = {s: datetime.now() for s in state.get('last_trained_symbols', [])}  # Restore with current time
                logger.info(f"Orchestrator state loaded from {save_path}")
            except Exception as e:
                logger.warning(f"Error loading orchestrator state from {save_path}: {e}")
        else:
            logger.info("No saved orchestrator state found. Starting fresh.")

    async def start_continuous_trading(self, symbols: Optional[List[str]] = None):
        """Start the continuous trading loop, using a decision model and trading executor"""
        if symbols is None:
            symbols = self.symbols

        # Set the running flag before the loop task is scheduled
        self.running = True
        logger.info(f"Starting continuous trading for symbols: {symbols}")

        # Initial decision making to kickstart the process
        for symbol in symbols:
            await self.make_trading_decision(symbol)
            await asyncio.sleep(0.5)  # Small delay between initial decisions

        # Create the decision loop task only once and share the reference,
        # so two concurrent decision loops are never spawned
        if not self.realtime_processing_task:
            self.realtime_processing_task = asyncio.create_task(self._trading_decision_loop())
        self.trade_loop_task = self.realtime_processing_task
        logger.info("Continuous trading loop initiated.")

    async def _trading_decision_loop(self):
        """Main trading decision loop"""
        logger.info("Trading decision loop started")
        while self.running:
            try:
                for symbol in self.symbols:
                    await self.make_trading_decision(symbol)
                    await asyncio.sleep(1)  # Small delay between symbols

                await asyncio.sleep(self.decision_frequency)
            except Exception as e:
                logger.error(f"Error in trading decision loop: {e}")
                await asyncio.sleep(5)  # Wait before retrying

    def set_dashboard(self, dashboard):
        """Set the dashboard reference for callbacks"""
        self.dashboard = dashboard
        logger.info("Dashboard reference set in orchestrator")

    def capture_cnn_prediction(self, symbol: str, direction: int, confidence: float, current_price: float, predicted_price: float):
        """Capture CNN prediction for dashboard visualization"""
        try:
            prediction_data = {
                'timestamp': datetime.now(),
                'direction': direction,
                'confidence': confidence,
                'current_price': current_price,
                'predicted_price': predicted_price
            }
            self.recent_cnn_predictions[symbol].append(prediction_data)
            logger.debug(f"CNN prediction captured for {symbol}: {direction} with confidence {confidence:.3f}")
        except Exception as e:
            logger.debug(f"Error capturing CNN prediction: {e}")

    def capture_dqn_prediction(self, symbol: str, action: int, confidence: float, current_price: float, q_values: List[float]):
        """Capture DQN prediction for dashboard visualization"""
        try:
            prediction_data = {
                'timestamp': datetime.now(),
                'action': action,
                'confidence': confidence,
                'current_price': current_price,
                'q_values': q_values
            }
            self.recent_dqn_predictions[symbol].append(prediction_data)
            logger.debug(f"DQN prediction captured for {symbol}: action {action} with confidence {confidence:.3f}")
        except Exception as e:
            logger.debug(f"Error capturing DQN prediction: {e}")

    def _get_current_price(self, symbol: str) -> Optional[float]:
        """Get current price for a symbol"""
        try:
            return self.data_provider.get_current_price(symbol)
        except Exception as e:
            logger.debug(f"Error getting current price for {symbol}: {e}")
            return None

    async def _generate_fallback_prediction(self, symbol: str, current_price: float) -> Optional[Prediction]:
        """Generate a basic momentum-based fallback prediction when no models are available"""
        try:
            # Get simple price history for momentum calculation
            timeframes = ['1m', '5m', '15m']

            momentum_signals = []
            for timeframe in timeframes:
                try:
                    # Use the correct method name for DataProvider
                    data = None
                    if hasattr(self.data_provider, 'get_historical_data'):
                        data = self.data_provider.get_historical_data(symbol, timeframe, limit=20)
                    elif hasattr(self.data_provider, 'get_candles'):
                        data = self.data_provider.get_candles(symbol, timeframe, limit=20)
                    elif hasattr(self.data_provider, 'get_data'):
                        data = self.data_provider.get_data(symbol, timeframe, limit=20)

                    # Explicit None check: truth-testing a DataFrame raises ValueError
                    if data is not None and len(data) >= 10:
                        # Handle different data formats
                        prices = []
                        if isinstance(data, list) and len(data) > 0:
                            if hasattr(data[0], 'close'):
                                prices = [candle.close for candle in data[-10:]]
                            elif isinstance(data[0], dict) and 'close' in data[0]:
                                prices = [candle['close'] for candle in data[-10:]]
                            elif isinstance(data[0], (list, tuple)) and len(data[0]) >= 5:
                                prices = [candle[4] for candle in data[-10:]]  # Assuming close is the 5th element

                        if prices and len(prices) >= 10:
                            # Simple momentum: if recent price > average, bullish
                            recent_avg = sum(prices[-5:]) / 5
                            older_avg = sum(prices[:5]) / 5
                            momentum = (recent_avg - older_avg) / older_avg if older_avg > 0 else 0
                            momentum_signals.append(momentum)
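                            # Illustrative arithmetic (assumed prices): recent_avg=102.0 and
                            # older_avg=100.0 give momentum = 2/100 = 0.02, which maps to BUY
                            # below with confidence min(0.7, 0.02 * 10) = 0.2.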
                except Exception:
                    continue

            if momentum_signals:
                avg_momentum = sum(momentum_signals) / len(momentum_signals)

                # Convert momentum to action
                if avg_momentum > 0.01:  # 1% positive momentum
                    action = 'BUY'
                    confidence = min(0.7, abs(avg_momentum) * 10)
                elif avg_momentum < -0.01:  # 1% negative momentum
                    action = 'SELL'
                    confidence = min(0.7, abs(avg_momentum) * 10)
                else:
                    action = 'HOLD'
                    confidence = 0.5

                return Prediction(
                    action=action,
                    confidence=confidence,
                    probabilities={
                        'BUY': confidence if action == 'BUY' else (1 - confidence) / 2,
                        'SELL': confidence if action == 'SELL' else (1 - confidence) / 2,
                        'HOLD': confidence if action == 'HOLD' else (1 - confidence) / 2
                    },
                    timeframe='mixed',
                    timestamp=datetime.now(),
                    model_name='fallback_momentum',
                    metadata={'momentum': avg_momentum, 'signals_count': len(momentum_signals)}
                )

            return None

        except Exception as e:
            logger.debug(f"Error generating fallback prediction for {symbol}: {e}")
            return None

    def _initialize_cob_integration(self):
        """Initialize COB integration for real-time market microstructure data"""
        if COB_INTEGRATION_AVAILABLE and COBIntegration is not None:
            try:
                self.cob_integration = COBIntegration(
                    symbols=self.symbols,
                    data_provider=self.data_provider
                )
                logger.info("COB Integration initialized")

                # Register callbacks for COB data
                if hasattr(self.cob_integration, 'add_cnn_callback'):
                    self.cob_integration.add_cnn_callback(self._on_cob_cnn_features)
                if hasattr(self.cob_integration, 'add_dqn_callback'):
                    self.cob_integration.add_dqn_callback(self._on_cob_dqn_features)
                if hasattr(self.cob_integration, 'add_dashboard_callback'):
                    self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_data)

            except Exception as e:
                logger.warning(f"Failed to initialize COB Integration: {e}")
                self.cob_integration = None
        else:
            logger.warning("COB integration module could not be imported - COB features are disabled.")

    async def start_cob_integration(self):
        """Start the COB integration to begin streaming data"""
        if self.cob_integration and hasattr(self.cob_integration, 'start'):
            try:
                logger.info("Attempting to start COB integration...")
                await self.cob_integration.start()
                logger.info("COB Integration started successfully.")
            except Exception as e:
                logger.error(f"Failed to start COB integration: {e}")
        else:
            logger.warning("COB Integration not initialized or start method not available.")

    def _on_cob_cnn_features(self, symbol: str, cob_data: Dict):
        """Callback for when new COB CNN features are available"""
        if not self.realtime_processing:
            return
        try:
            # This is where the features would be fed to the CNN model for prediction
            # or stored for training. For now, we just log and store the latest.
            # self.latest_cob_features[symbol] = cob_data['features']
            # logger.debug(f"COB CNN features updated for {symbol}: {cob_data['features'][:5]}...")

            # If training is enabled, add to training data
            if self.training_enabled and self.enhanced_training_system:
                # Use a safe method check before calling
                if hasattr(self.enhanced_training_system, 'add_cob_cnn_experience'):
                    self.enhanced_training_system.add_cob_cnn_experience(symbol, cob_data)

        except Exception as e:
            logger.error(f"Error in _on_cob_cnn_features for {symbol}: {e}")

    def _on_cob_dqn_features(self, symbol: str, cob_data: Dict):
        """Callback for when new COB DQN features are available"""
        if not self.realtime_processing:
            return
        try:
            # This is where the state would be fed to the DQN model for prediction
            # or stored for training. For now, we just log and store the latest.
            # self.latest_cob_state[symbol] = cob_data['state']
            # logger.debug(f"COB DQN state updated for {symbol}: {cob_data['state'][:5]}...")

            # If training is enabled, add to training data
            if self.training_enabled and self.enhanced_training_system:
                # Use a safe method check before calling
                if hasattr(self.enhanced_training_system, 'add_cob_dqn_experience'):
                    self.enhanced_training_system.add_cob_dqn_experience(symbol, cob_data)

        except Exception as e:
            logger.error(f"Error in _on_cob_dqn_features for {symbol}: {e}")

    def _on_cob_dashboard_data(self, symbol: str, cob_data: Dict):
        """Callback for when new COB data is available for the dashboard"""
        if not self.realtime_processing:
            return
        try:
            self.latest_cob_data[symbol] = cob_data
            # logger.debug(f"COB Dashboard data updated for {symbol}")
            if self.dashboard and hasattr(self.dashboard, 'update_cob_data'):
                self.dashboard.update_cob_data(symbol, cob_data)
        except Exception as e:
            logger.error(f"Error in _on_cob_dashboard_data for {symbol}: {e}")

    def get_cob_features(self, symbol: str) -> Optional[np.ndarray]:
        """Get the latest COB features for the CNN model"""
        return self.latest_cob_features.get(symbol)

    def get_cob_state(self, symbol: str) -> Optional[np.ndarray]:
        """Get the latest COB state for the DQN model"""
        return self.latest_cob_state.get(symbol)

    def get_cob_snapshot(self, symbol: str):
        """Get the latest raw COB snapshot for a symbol"""
        if self.cob_integration and hasattr(self.cob_integration, 'get_latest_cob_snapshot'):
            return self.cob_integration.get_latest_cob_snapshot(symbol)
        return None

    def get_cob_feature_matrix(self, symbol: str, sequence_length: int = 60) -> Optional[np.ndarray]:
        """Get a sequence of COB CNN features for sequence models"""
        if symbol not in self.cob_feature_history or not self.cob_feature_history[symbol]:
            return None

        features = [item['cnn_features'] for item in list(self.cob_feature_history[symbol])][-sequence_length:]
        if not features:
            return None

        # Pad or truncate to ensure consistent length and shape
        expected_feature_size = 102  # From _generate_cob_cnn_features
        padded_features = []
        for f in features:
            if len(f) < expected_feature_size:
                padded_features.append(np.pad(f, (0, expected_feature_size - len(f)), 'constant').tolist())
            elif len(f) > expected_feature_size:
                padded_features.append(f[:expected_feature_size].tolist())
            else:
                padded_features.append(f)

        # Ensure we have the desired sequence length by padding with zeros if necessary
        if len(padded_features) < sequence_length:
            padding = [[0.0] * expected_feature_size for _ in range(sequence_length - len(padded_features))]
            padded_features = padding + padded_features

        return np.array(padded_features[-sequence_length:]).astype(np.float32)  # Ensure correct length
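
    # Illustrative shape check for get_cob_feature_matrix (assumed history):
    # with the default sequence_length=60, the result is a (60, 102) float32
    # array, zero-padded at the front when fewer than 60 snapshots exist.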

    def _initialize_default_weights(self):
        """Initialize default model weights from config"""
        self.model_weights = {
            'CNN': self.config.orchestrator.get('cnn_weight', 0.7),
            'RL': self.config.orchestrator.get('rl_weight', 0.3)
        }

        # Add weights for specific models if they exist
        if hasattr(self, 'cnn_model') and self.cnn_model:
            self.model_weights["enhanced_cnn"] = 0.4

        # Only add DQN agent weight if it exists
        if hasattr(self, 'rl_agent') and self.rl_agent:
            self.model_weights["dqn_agent"] = 0.3

        # Add COB RL model weight if it exists (HIGHEST PRIORITY)
        if hasattr(self, 'cob_rl_agent') and self.cob_rl_agent:
            self.model_weights["cob_rl_model"] = 0.4

        # Add extrema trainer weight if it exists
        if hasattr(self, 'extrema_trainer') and self.extrema_trainer:
            self.model_weights["extrema_trainer"] = 0.15

    def register_model(self, model: ModelInterface, weight: Optional[float] = None) -> bool:
        """Register a new model with the orchestrator"""
        try:
            # Register with model registry
            if not self.model_registry.register_model(model):
                return False

            # Set weight
            if weight is not None:
                self.model_weights[model.name] = weight
            elif model.name not in self.model_weights:
                self.model_weights[model.name] = 0.1  # Default low weight for new models

            # Initialize performance tracking
            if model.name not in self.model_performance:
                self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}

            logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
            self._normalize_weights()
            return True

        except Exception as e:
            logger.error(f"Error registering model {model.name}: {e}")
            return False

    def unregister_model(self, model_name: str) -> bool:
        """Unregister a model"""
        try:
            if self.model_registry.unregister_model(model_name):
                if model_name in self.model_weights:
                    del self.model_weights[model_name]
                if model_name in self.model_performance:
                    del self.model_performance[model_name]

                self._normalize_weights()
                logger.info(f"Unregistered {model_name} model")
                return True
            return False

        except Exception as e:
            logger.error(f"Error unregistering model {model_name}: {e}")
            return False

    def _normalize_weights(self):
        """Normalize model weights to sum to 1.0"""
        total_weight = sum(self.model_weights.values())
        if total_weight > 0:
            for model_name in self.model_weights:
                self.model_weights[model_name] /= total_weight
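
        # Illustrative arithmetic (assumed weights): {'dqn_agent': 0.3,
        # 'enhanced_cnn': 0.4} sums to 0.7, so normalization yields roughly
        # {'dqn_agent': 0.429, 'enhanced_cnn': 0.571}.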

    async def add_decision_callback(self, callback):
        """Add a callback function to be called when decisions are made"""
        self.decision_callbacks.append(callback)
        logger.info(f"Decision callback registered: {callback.__name__ if hasattr(callback, '__name__') else 'unnamed'}")
        return True

    async def make_trading_decision(self, symbol: str) -> Optional[TradingDecision]:
        """
        Make a trading decision for a symbol by combining all registered model outputs
        """
        try:
            current_time = datetime.now()

            # EXECUTE EVERY SIGNAL: no decision-frequency limit is applied here,
            # allowing immediate execution of every signal from the decision model
            logger.debug(f"Processing signal for {symbol} - no frequency limit applied")

            # Get current market data
            current_price = self.data_provider.get_current_price(symbol)
            if current_price is None:
                logger.warning(f"No current price available for {symbol}")
                return None

            # Get predictions from all registered models
            predictions = await self._get_all_predictions(symbol)

            if not predictions:
                # FALLBACK: Generate a basic momentum signal when no models are available
                logger.debug(f"No model predictions available for {symbol}, generating fallback signal")
                fallback_prediction = await self._generate_fallback_prediction(symbol, current_price)
                if fallback_prediction:
                    predictions = [fallback_prediction]
                else:
                    logger.debug(f"No fallback prediction available for {symbol}")
                    return None

            # Combine predictions
            decision = self._combine_predictions(
                symbol=symbol,
                price=current_price,
                predictions=predictions,
                timestamp=current_time
            )

            # Update state
            self.last_decision_time[symbol] = current_time
            if symbol not in self.recent_decisions:
                self.recent_decisions[symbol] = []
            self.recent_decisions[symbol].append(decision)

            # Keep only recent decisions (last 100)
            if len(self.recent_decisions[symbol]) > 100:
                self.recent_decisions[symbol] = self.recent_decisions[symbol][-100:]

            # Call decision callbacks
            for callback in self.decision_callbacks:
                try:
                    await callback(decision)
                except Exception as e:
                    logger.error(f"Error in decision callback: {e}")

            # Clean up memory periodically (interval changed from every 50 to every 200 decisions)
            if len(self.recent_decisions[symbol]) % 200 == 0:
                self.model_registry.cleanup_all_models()

            return decision

        except Exception as e:
            logger.error(f"Error making trading decision for {symbol}: {e}")
            return None

    async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
        """Get predictions from all registered models"""
        predictions = []

        for model_name, model in self.model_registry.models.items():
            try:
                if isinstance(model, CNNModelInterface):
                    # Get CNN predictions for each timeframe
                    cnn_predictions = await self._get_cnn_predictions(model, symbol)
                    predictions.extend(cnn_predictions)

                elif isinstance(model, RLAgentInterface):
                    # Get RL prediction
                    rl_prediction = await self._get_rl_prediction(model, symbol)
                    if rl_prediction:
                        predictions.append(rl_prediction)

                else:
                    # Generic model interface
                    generic_prediction = await self._get_generic_prediction(model, symbol)
                    if generic_prediction:
                        predictions.append(generic_prediction)

            except Exception as e:
                logger.error(f"Error getting prediction from {model_name}: {e}")
                continue

        return predictions

    async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
        """Get predictions from the CNN model for all timeframes, with enhanced COB features"""
        predictions = []
        try:
            timeframes = getattr(self.config, 'timeframes', ['1m', '5m', '15m', '1h'])
            for timeframe in timeframes:
                # 1) Build or fetch the feature matrix (and optionally augment with COB)
                feature_matrix = self.data_provider.get_feature_matrix(
                    symbol=symbol,
                    timeframes=[timeframe],
                    window_size=getattr(model, 'window_size', 20)
                )
                if feature_matrix is None:
                    continue

                # Apply COB augmentation (see _augment_with_cob below)
                enhanced_features = self._augment_with_cob(feature_matrix, symbol)

                # 2) Initialize these before we call the model
                action_probs, confidence = None, None

                # 3) Try the actual model inference
                try:
                    # If the model has an .act() that returns (action_idx, confidence, probs)
                    if hasattr(model.model, 'act'):
                        # Flatten / reshape enhanced_features as needed
                        x = self._prepare_cnn_input(enhanced_features)

                        # Debugging: log the type and content of x before passing to act()
                        logger.debug(f"CNN input (x) type: {type(x)}, shape: {x.shape}, content sample: {x.flatten()[:5]}...")

                        action_idx, confidence, action_probs = model.model.act(x, explore=False)

                        # Debugging: log the type and content of the unpacked values
                        logger.debug(f"CNN act() returned: action_idx={action_idx} (type={type(action_idx)}), confidence={confidence} (type={type(confidence)}), action_probs={action_probs[:5]}... (type={type(action_probs)})")
                    else:
                        # Fall back to generic predict
                        result = model.predict(enhanced_features)
                        if isinstance(result, tuple) and len(result) == 2:
                            action_probs, confidence = result
                        else:
                            action_probs = result
                            confidence = 0.7
                except Exception as e:
                    logger.warning(f"CNN inference failed for {symbol}@{timeframe}: {e}")
                    continue  # Skip this timeframe entirely

                # 4) If we still don't have valid probabilities, skip
                if action_probs is None:
                    continue

                # 5) Build the Prediction
                action_names = ['SELL', 'HOLD', 'BUY']
                best_idx = int(np.argmax(action_probs))
                best_action = action_names[best_idx]
                pred = Prediction(
                    action=best_action,
                    confidence=float(confidence),
                    probabilities={n: float(p) for n, p in zip(action_names, action_probs)},
                    timeframe=timeframe,
                    timestamp=datetime.now(),
                    model_name=model.name,
                    metadata={
                        'feature_shape': str(enhanced_features.shape),
                        'cob_enhanced': enhanced_features is not feature_matrix
                    }
                )
                predictions.append(pred)

                # Capture the prediction for the dashboard
                current_price = self._get_current_price(symbol)
                if current_price is not None:
                    predicted_price = current_price * (1 + (0.01 * (confidence if best_action == 'BUY' else -confidence if best_action == 'SELL' else 0)))
                    self.capture_cnn_prediction(
                        symbol,
                        direction=best_idx,
                        confidence=confidence,
                        current_price=current_price,
                        predicted_price=predicted_price
                    )
        except Exception as e:
            logger.error(f"Orchestrator: error getting CNN predictions: {e}")
        return predictions

    # Helper stubs for clarity
    def _augment_with_cob(self, feature_matrix, symbol):
        # Existing COB-augmentation logic goes here; pass-through for now
        return feature_matrix

    def _prepare_cnn_input(self, features):
        arr = features.flatten()
        # Pad/truncate to 300 values, then reshape to (1, 300)
        if len(arr) < 300:
            arr = np.pad(arr, (0, 300 - len(arr)), 'constant')
        else:
            arr = arr[:300]
        return arr.reshape(1, -1)

    async def _get_rl_prediction(self, model: RLAgentInterface, symbol: str) -> Optional[Prediction]:
        """Get prediction from RL agent"""
        try:
            # Get current state for RL agent
            state = self._get_rl_state(symbol)
            if state is None:
                return None

            # Get the RL agent's action, confidence, and q_values from the underlying model
            if hasattr(model.model, 'act_with_confidence'):
                # Call act_with_confidence and handle different return formats
                result = model.model.act_with_confidence(state)

                if len(result) == 3:
                    # EnhancedCNN format: (action, confidence, q_values)
                    action_idx, confidence, raw_q_values = result
                elif len(result) == 2:
                    # DQN format: (action, confidence)
                    action_idx, confidence = result
                    raw_q_values = None
                else:
                    logger.error(f"Unexpected return format from act_with_confidence: {len(result)} values")
                    return None
            elif hasattr(model.model, 'act'):
                action_idx = model.model.act(state, explore=False)
                confidence = 0.7  # Default confidence for basic act method
                raw_q_values = None  # No raw q_values from simple act
            else:
                logger.error(f"RL model {model.name} has no act method")
                return None

            action_names = ['SELL', 'HOLD', 'BUY']
            action = action_names[action_idx]

            # Convert raw_q_values to a list if they are a tensor
            q_values_for_capture = None
            if raw_q_values is not None and hasattr(raw_q_values, 'tolist'):
                q_values_for_capture = raw_q_values.tolist()
            elif raw_q_values is not None and isinstance(raw_q_values, list):
                q_values_for_capture = raw_q_values

            # Create prediction object.
            # Note: raw Q-values are not normalized probabilities; they are
            # surfaced as-is here, with a uniform fallback when unavailable.
            prediction = Prediction(
                action=action,
                confidence=float(confidence),
                probabilities={action_names[i]: float(q_values_for_capture[i]) if q_values_for_capture else (1.0 / len(action_names)) for i in range(len(action_names))},
                timeframe='mixed',  # RL uses mixed timeframes
                timestamp=datetime.now(),
                model_name=model.name,
                metadata={'state_size': len(state)}
            )

            # Capture DQN prediction for dashboard visualization
            current_price = self._get_current_price(symbol)
            if current_price:
                # Only pass q_values if they exist, otherwise pass an empty list
                q_values_to_pass = q_values_for_capture if q_values_for_capture is not None else []
                self.capture_dqn_prediction(symbol, action_idx, float(confidence), current_price, q_values_to_pass)

            return prediction

        except Exception as e:
            logger.error(f"Error getting RL prediction: {e}")
            return None

    async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]:
        """Get prediction from a generic model"""
        try:
            # Safely get timeframes from config
            timeframes = getattr(self.config, 'timeframes', None)
            if timeframes is None:
                timeframes = ['1m', '5m', '15m']  # Default timeframes

            # Get feature matrix for the model
            feature_matrix = self.data_provider.get_feature_matrix(
                symbol=symbol,
                timeframes=timeframes[:3],  # Use first 3 timeframes
                window_size=20
            )

            if feature_matrix is not None:
                prediction_result = model.predict(feature_matrix)

                # Handle different return formats from model.predict()
                if prediction_result is None:
                    return None

                # Check if it's a tuple (action_probs, confidence)
                if isinstance(prediction_result, tuple) and len(prediction_result) == 2:
                    action_probs, confidence = prediction_result
                elif isinstance(prediction_result, dict):
                    # Handle dictionary return format
                    action_probs = prediction_result.get('probabilities', None)
                    confidence = prediction_result.get('confidence', 0.7)
                else:
                    # Assume it's just action probabilities (e.g., a list or numpy array)
                    action_probs = prediction_result
                    confidence = 0.7  # Default confidence

                if action_probs is not None:
                    # Ensure action_probs is a numpy array for argmax
                    if not isinstance(action_probs, np.ndarray):
                        action_probs = np.array(action_probs)

                    action_names = ['SELL', 'HOLD', 'BUY']
                    best_action_idx = np.argmax(action_probs)
                    best_action = action_names[best_action_idx]

                    prediction = Prediction(
                        action=best_action,
                        confidence=float(confidence),
                        probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
                        timeframe='mixed',
                        timestamp=datetime.now(),
                        model_name=model.name,
                        metadata={'generic_model': True}
                    )

                    return prediction

            return None

        except Exception as e:
            logger.error(f"Error getting generic prediction: {e}")
            return None

    def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
        """Get current state for RL agent"""
        try:
            # Safely get timeframes from config
            timeframes = getattr(self.config, 'timeframes', None)
            if timeframes is None:
                timeframes = ['1m', '5m', '15m', '1h']  # Default timeframes

            # Get feature matrix for all timeframes
            feature_matrix = self.data_provider.get_feature_matrix(
                symbol=symbol,
                timeframes=timeframes,
                window_size=self.config.rl.get('window_size', 20)
            )

            if feature_matrix is not None:
                # Flatten the feature matrix for the RL agent
                # Shape: (n_timeframes, window_size, n_features) -> (n_timeframes * window_size * n_features,)
                state = feature_matrix.flatten()

                # Add additional state information (position, balance, etc.)
                # This would come from a portfolio manager in a real implementation
                additional_state = np.array([0.0, 1.0, 0.0])  # [position, balance, unrealized_pnl]

                combined_state = np.concatenate([state, additional_state])

                # Ensure the DQN gets exactly 403 features (expected by the model)
                target_size = 403
                if len(combined_state) < target_size:
                    # Pad with zeros
                    padded_state = np.zeros(target_size)
                    padded_state[:len(combined_state)] = combined_state
                    combined_state = padded_state
                elif len(combined_state) > target_size:
                    # Truncate to target size
                    combined_state = combined_state[:target_size]
|
||
|
||
return combined_state
|
||
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error creating RL state for {symbol}: {e}")
|
||
return None
|
||
|
||
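    # Sizing note (illustrative, assuming 5 features per bar, i.e. OHLCV):
    # 4 timeframes * 20 window * 5 features = 400 flattened values, plus the
    # 3-element [position, balance, unrealized_pnl] vector = 403, which is why
    # target_size above is 403. Any other feature count is padded or truncated.
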
    def _combine_predictions(self, symbol: str, price: float,
                             predictions: List[Prediction],
                             timestamp: datetime) -> TradingDecision:
        """Combine all predictions into a final decision with aggressiveness and P&L feedback"""
        try:
            reasoning = {
                'predictions': len(predictions),
                'weights': self.model_weights.copy(),
                'models_used': [pred.model_name for pred in predictions]
            }

            # Get current position P&L for feedback
            current_position_pnl = self._get_current_position_pnl(symbol, price)

            # Initialize action scores
            action_scores = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
            total_weight = 0.0

            # Process all predictions
            for pred in predictions:
                # Get model weight
                model_weight = self.model_weights.get(pred.model_name, 0.1)

                # Weight by confidence and timeframe importance
                timeframe_weight = self._get_timeframe_weight(pred.timeframe)
                weighted_confidence = pred.confidence * timeframe_weight * model_weight

                action_scores[pred.action] += weighted_confidence
                total_weight += weighted_confidence

            # Normalize scores
            if total_weight > 0:
                for action in action_scores:
                    action_scores[action] /= total_weight

            # Choose the best action (guard against an empty score dict before calling max)
            if action_scores:
                best_action = max(action_scores.keys(), key=lambda k: action_scores[k])
                best_confidence = action_scores[best_action]
            else:
                best_action = 'HOLD'
                best_confidence = 0.0

            # Calculate aggressiveness-adjusted thresholds
            entry_threshold, exit_threshold = self._calculate_aggressiveness_thresholds(
                current_position_pnl, symbol
            )
            # (informational here; execution below is gated by the dynamic threshold)

            # SIGNAL CONFIRMATION: Only execute signals that meet confirmation criteria.
            # Apply confidence thresholds and signal accumulation for trend confirmation.
            reasoning['execute_every_signal'] = False
            reasoning['models_aggregated'] = [pred.model_name for pred in predictions]
            reasoning['aggregated_confidence'] = best_confidence

            # Calculate dynamic aggressiveness based on recent performance
            entry_aggressiveness = self._calculate_dynamic_entry_aggressiveness(symbol)

            # Adjust the confidence threshold based on entry aggressiveness:
            # higher aggressiveness = lower threshold (more trades).
            # entry_aggressiveness: 0.0 = very conservative, 1.0 = very aggressive
            base_threshold = self.confidence_threshold
            aggressiveness_factor = 1.0 - entry_aggressiveness  # Invert: high agg = low factor
            dynamic_threshold = base_threshold * aggressiveness_factor

            # Ensure a minimum threshold for safety (never below 1% confidence)
            dynamic_threshold = max(0.01, dynamic_threshold)

            # Apply the dynamic confidence threshold for signal confirmation
            if best_action != 'HOLD':
                if best_confidence < dynamic_threshold:
                    logger.debug(f"Signal below dynamic confidence threshold: {best_action} {symbol} "
                                 f"(confidence: {best_confidence:.3f} < {dynamic_threshold:.3f}, "
                                 f"base: {base_threshold:.3f}, aggressiveness: {entry_aggressiveness:.2f})")
                    best_action = 'HOLD'
                    best_confidence = 0.0
                else:
                    logger.info(f"SIGNAL ACCEPTED: {best_action} {symbol} "
                                f"(confidence: {best_confidence:.3f} >= {dynamic_threshold:.3f}, "
                                f"aggressiveness: {entry_aggressiveness:.2f})")
                    # Add signal to accumulator for trend confirmation
                    signal_data = {
                        'action': best_action,
                        'confidence': best_confidence,
                        'timestamp': timestamp,
                        'models': reasoning['models_aggregated']
                    }

                    # Check if we have enough confirmations
                    confirmed_action = self._check_signal_confirmation(symbol, signal_data)
                    if confirmed_action:
                        logger.info(f"SIGNAL CONFIRMED: {confirmed_action} (confidence: {best_confidence:.3f}) "
                                    f"from aggregated models: {reasoning['models_aggregated']}")
                        best_action = confirmed_action
                        reasoning['signal_confirmed'] = True
                        reasoning['confirmations_received'] = len(self.signal_accumulator[symbol])
                    else:
                        logger.debug(f"Signal accumulating: {best_action} {symbol} "
                                     f"({len(self.signal_accumulator[symbol])}/{self.required_confirmations} confirmations)")
                        best_action = 'HOLD'
                        best_confidence = 0.0
                        reasoning['rejected_reason'] = 'awaiting_confirmation'

            # Add P&L-based decision adjustment
            best_action, best_confidence = self._apply_pnl_feedback(
                best_action, best_confidence, current_position_pnl, symbol, reasoning
            )

            # Get memory usage stats
            try:
                memory_usage = {}
                if hasattr(self.model_registry, 'get_memory_stats'):
                    memory_usage = self.model_registry.get_memory_stats()
                else:
                    # Fallback memory usage calculation
                    for model_name in self.model_weights:
                        memory_usage[model_name] = 50.0  # Default MB estimate
            except Exception:
                memory_usage = {}

            # Get exit aggressiveness (entry aggressiveness already calculated above)
            exit_aggressiveness = self._calculate_dynamic_exit_aggressiveness(symbol, current_position_pnl)

            # Create final decision
            decision = TradingDecision(
                action=best_action,
                confidence=best_confidence,
                symbol=symbol,
                price=price,
                timestamp=timestamp,
                reasoning=reasoning,
                memory_usage=memory_usage.get('models', {}) if memory_usage else {},
                entry_aggressiveness=entry_aggressiveness,
                exit_aggressiveness=exit_aggressiveness,
                current_position_pnl=current_position_pnl
            )

            logger.info(f"Decision for {symbol}: {best_action} (confidence: {best_confidence:.3f}, "
                        f"entry_agg: {entry_aggressiveness:.2f}, exit_agg: {exit_aggressiveness:.2f}, "
                        f"pnl: ${current_position_pnl:.2f})")

            # Trigger training on each decision (especially for executed trades)
            self._trigger_training_on_decision(decision, price)

            return decision

        except Exception as e:
            logger.error(f"Error combining predictions for {symbol}: {e}")
            # Return safe default
            return TradingDecision(
                action='HOLD',
                confidence=0.0,
                symbol=symbol,
                price=price,
                timestamp=timestamp,
                reasoning={'error': str(e)},
                memory_usage={},
                entry_aggressiveness=0.5,
                exit_aggressiveness=0.5,
                current_position_pnl=0.0
            )

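    # Illustrative walk-through of the weighted vote above (hypothetical numbers,
    # not taken from a live run). Assume two predictions and model weights
    # {'dqn': 0.3, 'cnn': 0.2}, with timeframe weights from _get_timeframe_weight:
    #   BUY  from 'dqn' on '1h': 0.80 conf * 0.6 tf_weight * 0.3 = 0.144
    #   SELL from 'cnn' on '1m': 0.90 conf * 0.1 tf_weight * 0.2 = 0.018
    #   total_weight = 0.162 -> BUY ~ 0.889, SELL ~ 0.111 after normalization,
    # so BUY wins with confidence ~0.889 and is then checked against the dynamic
    # threshold (e.g. base 0.20 * (1.0 - 0.5 aggressiveness) = 0.10).
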
    def _get_timeframe_weight(self, timeframe: str) -> float:
        """Get importance weight for a timeframe"""
        # Higher timeframes get more weight in decision making
        weights = {
            '1m': 0.1, '5m': 0.2, '15m': 0.3, '30m': 0.4,
            '1h': 0.6, '4h': 0.8, '1d': 1.0
        }
        return weights.get(timeframe, 0.5)

    def update_model_performance(self, model_name: str, was_correct: bool):
        """Update performance tracking for a model"""
        if model_name in self.model_performance:
            self.model_performance[model_name]['total'] += 1
            if was_correct:
                self.model_performance[model_name]['correct'] += 1

            # Update accuracy
            total = self.model_performance[model_name]['total']
            correct = self.model_performance[model_name]['correct']
            self.model_performance[model_name]['accuracy'] = correct / total if total > 0 else 0.0

    def adapt_weights(self):
        """Dynamically adapt model weights based on performance"""
        try:
            for model_name, performance in self.model_performance.items():
                if performance['total'] > 0:
                    # Adjust weight based on relative performance
                    accuracy = performance['correct'] / performance['total']
                    self.model_weights[model_name] = accuracy

                    logger.info(f"Adapted {model_name} weight: {self.model_weights[model_name]}")

        except Exception as e:
            logger.error(f"Error adapting weights: {e}")

    def get_recent_decisions(self, symbol: str, limit: int = 10) -> List[TradingDecision]:
        """Get recent decisions for a symbol"""
        if symbol in self.recent_decisions:
            return self.recent_decisions[symbol][-limit:]
        return []

    def get_performance_metrics(self) -> Dict[str, Any]:
        """Get performance metrics for the orchestrator"""
        return {
            'model_performance': self.model_performance.copy(),
            'weights': self.model_weights.copy(),
            'configuration': {
                'confidence_threshold': self.confidence_threshold,
                # 'decision_frequency': self.decision_frequency
            },
            'recent_activity': {
                symbol: len(decisions) for symbol, decisions in self.recent_decisions.items()
            }
        }

    def get_model_states(self) -> Dict[str, Dict]:
        """Get current model states with REAL checkpoint data - SSOT for dashboard"""
        try:
            # ENHANCED: Load actual checkpoint metadata for each model
            from utils.checkpoint_manager import load_best_checkpoint

            # Map checkpoint names to internal keys (defined once; the original
            # duplicated this dict in both branches below)
            internal_keys = {
                'dqn_agent': 'dqn',
                'enhanced_cnn': 'cnn',
                'extrema_trainer': 'extrema_trainer',
                'decision': 'decision',
                'cob_rl': 'cob_rl'
            }

            # Update each model with REAL checkpoint data
            for model_name in ['dqn_agent', 'enhanced_cnn', 'extrema_trainer', 'decision', 'cob_rl']:
                try:
                    result = load_best_checkpoint(model_name)
                    internal_key = internal_keys.get(model_name, model_name)

                    if result:
                        file_path, metadata = result

                        if internal_key in self.model_states:
                            # Load REAL checkpoint data
                            self.model_states[internal_key]['current_loss'] = getattr(metadata, 'loss', None) or getattr(metadata, 'val_loss', None)
                            self.model_states[internal_key]['best_loss'] = getattr(metadata, 'loss', None) or getattr(metadata, 'val_loss', None)
                            self.model_states[internal_key]['checkpoint_loaded'] = True
                            self.model_states[internal_key]['checkpoint_filename'] = metadata.checkpoint_id
                            self.model_states[internal_key]['performance_score'] = getattr(metadata, 'performance_score', 0.0)
                            self.model_states[internal_key]['created_at'] = str(getattr(metadata, 'created_at', 'Unknown'))

                            # Set initial loss from checkpoint if available
                            if self.model_states[internal_key]['initial_loss'] is None:
                                # Try to infer initial loss from performance improvement
                                if hasattr(metadata, 'accuracy') and metadata.accuracy:
                                    # Estimate initial loss from current accuracy (inverse relationship)
                                    estimated_initial = max(0.1, 2.0 - (metadata.accuracy * 2.0))
                                    self.model_states[internal_key]['initial_loss'] = estimated_initial

                            logger.debug(f"Loaded REAL checkpoint data for {model_name}: loss={self.model_states[internal_key]['current_loss']}")
                    else:
                        # No checkpoint found - mark as fresh start
                        if internal_key in self.model_states:
                            self.model_states[internal_key]['checkpoint_loaded'] = False
                            self.model_states[internal_key]['checkpoint_filename'] = 'none (fresh start)'

                except Exception as e:
                    logger.debug(f"No checkpoint found for {model_name}: {e}")

            # ADDITIONAL: Update from live training if models are actively training
            if self.rl_agent and hasattr(self.rl_agent, 'losses') and len(self.rl_agent.losses) > 0:
                recent_losses = self.rl_agent.losses[-10:]  # Last 10 training steps
                if recent_losses:
                    live_loss = sum(recent_losses) / len(recent_losses)
                    # Only update if the live loss differs from the checkpoint value
                    if abs(live_loss - (self.model_states['dqn']['current_loss'] or 0)) > 0.001:
                        self.model_states['dqn']['current_loss'] = live_loss
                        logger.debug(f"Updated DQN with live training loss: {live_loss:.4f}")

            if self.cnn_model and hasattr(self.cnn_model, 'training_loss'):
                if self.cnn_model.training_loss and abs(self.cnn_model.training_loss - (self.model_states['cnn']['current_loss'] or 0)) > 0.001:
                    self.model_states['cnn']['current_loss'] = self.cnn_model.training_loss
                    logger.debug(f"Updated CNN with live training loss: {self.cnn_model.training_loss:.4f}")

            if self.extrema_trainer and hasattr(self.extrema_trainer, 'best_detection_accuracy'):
                # Convert accuracy to a loss estimate
                if self.extrema_trainer.best_detection_accuracy > 0:
                    estimated_loss = max(0.001, 1.0 - self.extrema_trainer.best_detection_accuracy)
                    self.model_states['extrema_trainer']['current_loss'] = estimated_loss
                    self.model_states['extrema_trainer']['best_loss'] = estimated_loss

            # NO LONGER SETTING SYNTHETIC INITIAL LOSS VALUES.
            # Keep None values as None when no real data is available; this prevents
            # the "fake progress" issue where Current Loss == Initial Loss.
            # initial_loss stays None without real training history, current_loss
            # stays None unless the model is actively training, and best_loss stays
            # None until a checkpoint with real performance data exists.

            return self.model_states

        except Exception as e:
            logger.error(f"Error getting model states: {e}")
            # Return None values instead of synthetic data
            return {
                'dqn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
                'cnn': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
                'cob_rl': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
                'decision': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False},
                'extrema_trainer': {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False}
            }

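    # Worked example for the initial-loss estimate above (hypothetical values):
    # a checkpoint reporting accuracy 0.80 yields max(0.1, 2.0 - 0.80 * 2.0) = 0.4
    # as the inferred starting loss; accuracy of 0.95 or higher clamps to the 0.1 floor.
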
    def _initialize_decision_fusion(self):
        """Initialize the decision fusion neural network for learning model effectiveness"""
        try:
            if not self.decision_fusion_enabled:
                return

            import torch
            import torch.nn as nn

            # Create decision fusion network
            class DecisionFusionNet(nn.Module):
                def __init__(self, input_size=32, hidden_size=64):
                    super().__init__()
                    self.fc1 = nn.Linear(input_size, hidden_size)
                    self.fc2 = nn.Linear(hidden_size, hidden_size)
                    self.fc3 = nn.Linear(hidden_size, 3)  # BUY, SELL, HOLD
                    self.dropout = nn.Dropout(0.2)

                def forward(self, x):
                    x = torch.relu(self.fc1(x))
                    x = self.dropout(x)
                    x = torch.relu(self.fc2(x))
                    x = self.dropout(x)
                    return torch.softmax(self.fc3(x), dim=1)

            self.decision_fusion_network = DecisionFusionNet()
            # Move decision fusion network to the device
            self.decision_fusion_network.to(self.device)
            logger.info(f"Decision fusion network initialized on device: {self.device}")

        except Exception as e:
            logger.warning(f"Decision fusion initialization failed: {e}")
            self.decision_fusion_enabled = False

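    # Minimal usage sketch for the fusion network above (assumes the default
    # input_size=32; the feature vector and batch size are hypothetical):
    #
    #   features = torch.randn(1, 32, device=self.device)  # one fused feature row
    #   with torch.no_grad():
    #       probs = self.decision_fusion_network(features)  # shape (1, 3), rows sum to 1
    #   action = ['BUY', 'SELL', 'HOLD'][int(probs.argmax(dim=1))]
    #
    # Note: since forward() already applies softmax, any training loop should use
    # NLL on log-probabilities rather than nn.CrossEntropyLoss, which expects raw logits.
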
    def _initialize_enhanced_training_system(self):
        """Initialize the enhanced real-time training system"""
        try:
            if not self.training_enabled:
                logger.info("Enhanced training system disabled")
                return

            if not ENHANCED_TRAINING_AVAILABLE:
                logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled")
                self.training_enabled = False
                return

            # Initialize the enhanced training system
            if EnhancedRealtimeTrainingSystem is not None:
                self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
                    orchestrator=self,
                    data_provider=self.data_provider,
                    dashboard=None  # Will be set by dashboard when available
                )

                logger.info("Enhanced real-time training system initialized")
                logger.info(" - Real-time model training: ENABLED")
                logger.info(" - Comprehensive feature extraction: ENABLED")
                logger.info(" - Enhanced reward calculation: ENABLED")
                logger.info(" - Forward-looking predictions: ENABLED")
            else:
                logger.warning("EnhancedRealtimeTrainingSystem class not available")
                self.training_enabled = False

        except Exception as e:
            logger.error(f"Error initializing enhanced training system: {e}")
            self.training_enabled = False
            self.enhanced_training_system = None

    def start_enhanced_training(self):
        """Start the enhanced real-time training system"""
        try:
            if not self.training_enabled or not self.enhanced_training_system:
                logger.warning("Enhanced training system not available")
                return False

            if hasattr(self.enhanced_training_system, 'start_training'):
                self.enhanced_training_system.start_training()
                logger.info("Enhanced real-time training started")
                return True
            else:
                logger.warning("Enhanced training system does not have a start_training method")
                return False

        except Exception as e:
            logger.error(f"Error starting enhanced training: {e}")
            return False

    def stop_enhanced_training(self):
        """Stop the enhanced real-time training system"""
        try:
            if self.enhanced_training_system and hasattr(self.enhanced_training_system, 'stop_training'):
                self.enhanced_training_system.stop_training()
                logger.info("Enhanced real-time training stopped")
                return True
            return False

        except Exception as e:
            logger.error(f"Error stopping enhanced training: {e}")
            return False

    def get_enhanced_training_stats(self) -> Dict[str, Any]:
        """Get enhanced training system statistics with orchestrator integration"""
        try:
            if not self.enhanced_training_system:
                return {
                    'training_enabled': False,
                    'system_available': ENHANCED_TRAINING_AVAILABLE,
                    'error': 'Training system not initialized'
                }

            # Get base stats from the enhanced training system
            stats = {}
            if hasattr(self.enhanced_training_system, 'get_training_statistics'):
                stats = self.enhanced_training_system.get_training_statistics()

            stats['training_enabled'] = self.training_enabled
            stats['system_available'] = ENHANCED_TRAINING_AVAILABLE

            # Add orchestrator-specific training integration data
            stats['orchestrator_integration'] = {
                'models_connected': len([m for m in [self.rl_agent, self.cnn_model, self.cob_rl_agent, self.decision_model] if m is not None]),
                'cob_integration_active': self.cob_integration is not None,
                'decision_fusion_enabled': self.decision_fusion_enabled,
                'symbols_tracking': len(self.symbols),
                'recent_decisions_count': sum(len(decisions) for decisions in self.recent_decisions.values()),
                'model_weights': self.model_weights.copy(),
                'realtime_processing': self.realtime_processing
            }

            # Add model-specific training status from the orchestrator
            stats['model_training_status'] = {}
            model_mappings = {
                'dqn': self.rl_agent,
                'cnn': self.cnn_model,
                'cob_rl': self.cob_rl_agent,
                'decision': self.decision_model
            }

            for model_name, model in model_mappings.items():
                if model:
                    model_stats = {
                        'model_loaded': True,
                        'memory_usage': 0,
                        'training_steps': 0,
                        'last_loss': None,
                        'checkpoint_loaded': self.model_states.get(model_name, {}).get('checkpoint_loaded', False)
                    }

                    # Get memory usage (here: the replay buffer length, not MB)
                    if hasattr(model, 'memory') and model.memory:
                        model_stats['memory_usage'] = len(model.memory)

                    # Get training steps
                    if hasattr(model, 'training_steps'):
                        model_stats['training_steps'] = model.training_steps

                    # Get last loss
                    if hasattr(model, 'losses') and model.losses:
                        model_stats['last_loss'] = model.losses[-1]

                    stats['model_training_status'][model_name] = model_stats
                else:
                    stats['model_training_status'][model_name] = {
                        'model_loaded': False,
                        'memory_usage': 0,
                        'training_steps': 0,
                        'last_loss': None,
                        'checkpoint_loaded': False
                    }

            # Add prediction tracking stats
            stats['prediction_tracking'] = {
                'dqn_predictions_tracked': sum(len(preds) for preds in self.recent_dqn_predictions.values()),
                'cnn_predictions_tracked': sum(len(preds) for preds in self.recent_cnn_predictions.values()),
                'accuracy_history_tracked': sum(len(history) for history in self.prediction_accuracy_history.values()),
                'symbols_with_predictions': [symbol for symbol in self.symbols if
                                             len(self.recent_dqn_predictions.get(symbol, [])) > 0 or
                                             len(self.recent_cnn_predictions.get(symbol, [])) > 0]
            }

            # Add COB integration stats if available
            if self.cob_integration:
                stats['cob_integration_stats'] = {
                    'latest_cob_data_symbols': list(self.latest_cob_data.keys()),
                    'cob_features_available': list(self.latest_cob_features.keys()),
                    'cob_state_available': list(self.latest_cob_state.keys()),
                    'feature_history_length': {symbol: len(history) for symbol, history in self.cob_feature_history.items()}
                }

            return stats

        except Exception as e:
            logger.error(f"Error getting training stats: {e}")
            return {
                'training_enabled': self.training_enabled,
                'system_available': ENHANCED_TRAINING_AVAILABLE,
                'error': str(e)
            }

    def set_training_dashboard(self, dashboard):
        """Set the dashboard reference for the training system"""
        try:
            if self.enhanced_training_system:
                self.enhanced_training_system.dashboard = dashboard
                logger.info("Dashboard reference set for enhanced training system")

        except Exception as e:
            logger.error(f"Error setting training dashboard: {e}")

    def get_universal_data_stream(self, current_time: Optional[datetime] = None):
        """Get universal data stream for external consumers like the dashboard - DELEGATED to data provider"""
        try:
            if self.data_provider and hasattr(self.data_provider, 'universal_adapter'):
                return self.data_provider.universal_adapter.get_universal_data_stream(current_time)
            elif self.universal_adapter:
                return self.universal_adapter.get_universal_data_stream(current_time)
            return None
        except Exception as e:
            logger.error(f"Error getting universal data stream: {e}")
            return None

    def get_universal_data_for_model(self, model_type: str = 'cnn') -> Optional[Dict[str, Any]]:
        """Get formatted universal data for specific model types - DELEGATED to data provider"""
        try:
            if self.data_provider and hasattr(self.data_provider, 'universal_adapter'):
                stream = self.data_provider.universal_adapter.get_universal_data_stream()
                if stream:
                    return self.data_provider.universal_adapter.format_for_model(stream, model_type)
            elif self.universal_adapter:
                stream = self.universal_adapter.get_universal_data_stream()
                if stream:
                    return self.universal_adapter.format_for_model(stream, model_type)
            return None
        except Exception as e:
            logger.error(f"Error getting universal data for {model_type}: {e}")
            return None

    def get_cob_data(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Get COB data for symbol - DELEGATED to data provider"""
        try:
            if self.data_provider:
                return self.data_provider.get_latest_cob_data(symbol)
            return None
        except Exception as e:
            logger.error(f"Error getting COB data for {symbol}: {e}")
            return None

    def get_combined_model_data(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Get combined OHLCV + COB data for models - DELEGATED to data provider"""
        try:
            if self.data_provider:
                return self.data_provider.get_combined_ohlcv_cob_data(symbol)
            return None
        except Exception as e:
            logger.error(f"Error getting combined model data for {symbol}: {e}")
            return None

    def _get_current_position_pnl(self, symbol: str, current_price: float) -> float:
        """Get current position P&L for the symbol"""
        try:
            if self.trading_executor and hasattr(self.trading_executor, 'get_current_position'):
                position = self.trading_executor.get_current_position(symbol)
                if position:
                    entry_price = position.get('price', 0)
                    size = position.get('size', 0)
                    side = position.get('side', 'LONG')

                    if entry_price and size > 0:
                        if side.upper() == 'LONG':
                            pnl = (current_price - entry_price) * size
                        else:  # SHORT
                            pnl = (entry_price - current_price) * size
                        return pnl
            return 0.0
        except Exception as e:
            logger.debug(f"Error getting position P&L for {symbol}: {e}")
            return 0.0

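    # P&L sign convention (hypothetical numbers): a LONG of 0.5 units entered at
    # $3,000 with the price now at $3,050 gives (3050 - 3000) * 0.5 = +$25; the
    # same move against a SHORT gives (3000 - 3050) * 0.5 = -$25.
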
    def _has_open_position(self, symbol: str) -> bool:
        """Check if there's an open position for the symbol"""
        try:
            if self.trading_executor and hasattr(self.trading_executor, 'get_current_position'):
                position = self.trading_executor.get_current_position(symbol)
                return position is not None and position.get('size', 0) > 0
            return False
        except Exception:
            return False

    def _calculate_aggressiveness_thresholds(self, current_pnl: float, symbol: str) -> tuple:
        """Calculate confidence thresholds based on aggressiveness settings"""
        # Base thresholds
        base_entry_threshold = self.confidence_threshold
        base_exit_threshold = self.confidence_threshold_close

        # Get aggressiveness settings (could be from config or adaptive)
        entry_agg = getattr(self, 'entry_aggressiveness', 0.5)
        exit_agg = getattr(self, 'exit_aggressiveness', 0.5)

        # Adjust thresholds based on aggressiveness:
        # more aggressive = lower threshold (more trades),
        # less aggressive = higher threshold (fewer, higher-quality trades)
        entry_threshold = base_entry_threshold * (1.5 - entry_agg)  # 0.5 agg = 1.0x, 1.0 agg = 0.5x
        exit_threshold = base_exit_threshold * (1.5 - exit_agg)

        # Ensure minimum thresholds
        entry_threshold = max(0.05, entry_threshold)
        exit_threshold = max(0.02, exit_threshold)

        return entry_threshold, exit_threshold

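    # Threshold scaling example (hypothetical base threshold): with a base entry
    # threshold of 0.20, aggressiveness 0.0 scales it to 0.20 * 1.5 = 0.30,
    # 0.5 leaves it at 0.20, and 1.0 halves it to 0.10; the max(0.05, ...) floor
    # then guarantees at least a 5% entry bar regardless of aggressiveness.
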
    def _apply_pnl_feedback(self, action: str, confidence: float, current_pnl: float,
                            symbol: str, reasoning: dict) -> tuple:
        """Apply P&L-based feedback to decision making"""
        try:
            # If we have a losing position, be more aggressive about cutting losses
            if current_pnl < -10.0:  # Losing more than $10
                if action == 'SELL' and self._has_open_position(symbol):
                    # Boost confidence for exit signals when losing
                    confidence = min(1.0, confidence * 1.2)
                    reasoning['pnl_loss_cut_boost'] = True
                elif action == 'BUY':
                    # Reduce confidence for new entries when losing
                    confidence *= 0.8
                    reasoning['pnl_loss_entry_reduction'] = True

            # If we have a winning position, be more conservative about exits
            elif current_pnl > 5.0:  # Winning more than $5
                if action == 'SELL' and self._has_open_position(symbol):
                    # Reduce confidence for exit signals when winning (let profits run)
                    confidence *= 0.9
                    reasoning['pnl_profit_hold'] = True
                elif action == 'BUY':
                    # Slightly boost confidence for entries when on a winning streak
                    confidence = min(1.0, confidence * 1.05)
                    reasoning['pnl_winning_streak_boost'] = True

            reasoning['current_pnl'] = current_pnl
            return action, confidence

        except Exception as e:
            logger.debug(f"Error applying P&L feedback: {e}")
            return action, confidence

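    # Feedback multipliers at a glance (read off the branches above): losing more
    # than $10 scales exit confidence by 1.2x (capped at 1.0) and entry confidence
    # by 0.8x; winning more than $5 scales exits by 0.9x and entries by 1.05x.
    # A SELL at confidence 0.60 on a -$15 position therefore becomes 0.72.
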
    def _calculate_dynamic_entry_aggressiveness(self, symbol: str) -> float:
        """Calculate dynamic entry aggressiveness based on recent performance"""
        try:
            # Start with base aggressiveness
            base_agg = getattr(self, 'entry_aggressiveness', 0.5)

            # Get recent decisions for this symbol
            recent_decisions = self.get_recent_decisions(symbol, limit=10)
            if len(recent_decisions) < 3:
                return base_agg

            # Calculate win rate
            winning_decisions = sum(1 for d in recent_decisions
                                    if d.reasoning.get('was_profitable', False))
            win_rate = winning_decisions / len(recent_decisions)

            # Adjust aggressiveness based on performance
            if win_rate > 0.7:  # High win rate - be more aggressive
                return min(1.0, base_agg + 0.2)
            elif win_rate < 0.3:  # Low win rate - be more conservative
                return max(0.1, base_agg - 0.2)
            else:
                return base_agg

        except Exception as e:
            logger.debug(f"Error calculating dynamic entry aggressiveness: {e}")
            return 0.5

    def _calculate_dynamic_exit_aggressiveness(self, symbol: str, current_pnl: float) -> float:
        """Calculate dynamic exit aggressiveness based on P&L and market conditions"""
        try:
            # Start with base aggressiveness
            base_agg = getattr(self, 'exit_aggressiveness', 0.5)

            # Adjust based on current P&L
            if current_pnl < -20.0:  # Large loss - be very aggressive about cutting
                return min(1.0, base_agg + 0.3)
            elif current_pnl < -5.0:  # Small loss - be more aggressive
                return min(1.0, base_agg + 0.1)
            elif current_pnl > 20.0:  # Large profit - be less aggressive (let it run)
                return max(0.1, base_agg - 0.2)
            elif current_pnl > 5.0:  # Small profit - slightly less aggressive
                return max(0.2, base_agg - 0.1)
            else:
                return base_agg

        except Exception as e:
            logger.debug(f"Error calculating dynamic exit aggressiveness: {e}")
            return 0.5

    def set_trading_executor(self, trading_executor):
        """Set the trading executor for position tracking"""
        self.trading_executor = trading_executor
        logger.info("Trading executor set for position tracking and P&L feedback")

    def get_profitability_reward_multiplier(self) -> float:
        """Get the current profitability reward multiplier from the trading executor

        Returns:
            float: Current profitability reward multiplier (0.0 to 2.0)
        """
        try:
            if self.trading_executor and hasattr(self.trading_executor, 'get_profitability_reward_multiplier'):
                multiplier = self.trading_executor.get_profitability_reward_multiplier()
                logger.debug(f"Current profitability reward multiplier: {multiplier:.2f}")
                return multiplier
            return 0.0
        except Exception as e:
            logger.error(f"Error getting profitability reward multiplier: {e}")
            return 0.0

    def calculate_enhanced_reward(self, base_pnl: float, confidence: float = 1.0) -> float:
        """Calculate enhanced reward with profitability multiplier

        Args:
            base_pnl: Base P&L from the trade
            confidence: Confidence level of the prediction (0.0 to 1.0)

        Returns:
            float: Enhanced reward with profitability multiplier applied
        """
        try:
            # Get the dynamic profitability multiplier
            profitability_multiplier = self.get_profitability_reward_multiplier()

            # Base reward is the P&L
            base_reward = base_pnl

            # Apply profitability multiplier only to positive P&L (profitable trades)
            if base_pnl > 0 and profitability_multiplier > 0:
                # Enhance profitable trades with the multiplier
                enhanced_reward = base_pnl * (1.0 + profitability_multiplier)
                logger.debug(f"Enhanced reward: ${base_pnl:.2f} → ${enhanced_reward:.2f} (multiplier: {profitability_multiplier:.2f})")
                return enhanced_reward
            else:
                # No enhancement for losing trades or when the multiplier is 0
                return base_reward

        except Exception as e:
            logger.error(f"Error calculating enhanced reward: {e}")
            return base_pnl

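    # Worked example (hypothetical multiplier): with base_pnl = $10.00 and a
    # profitability multiplier of 0.5, the enhanced reward is 10.0 * 1.5 = $15.00;
    # a losing trade (base_pnl = -$10.00) is returned unchanged.
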
    def _trigger_training_on_decision(self, decision: TradingDecision, current_price: float):
        """Trigger training on each decision, especially executed trades

        This ensures models learn from every signal outcome, giving more weight
        to executed trades as they have real market feedback.
        """
        try:
            # Only train if training is enabled and we have the enhanced training system
            if not self.training_enabled or not self.enhanced_training_system:
                return

            symbol = decision.symbol
            action = decision.action
            confidence = decision.confidence

            # Create training data from the decision
            training_data = {
                'symbol': symbol,
                'action': action,
                'confidence': confidence,
                'price': current_price,
                'timestamp': decision.timestamp,
                'executed': action != 'HOLD',  # Assume non-HOLD actions are executed
                'entry_aggressiveness': decision.entry_aggressiveness,
                'exit_aggressiveness': decision.exit_aggressiveness,
                'reasoning': decision.reasoning
            }

            # Add to enhanced training system for immediate learning
            if hasattr(self.enhanced_training_system, 'add_decision_for_training'):
                self.enhanced_training_system.add_decision_for_training(training_data)
                logger.debug(f"🎓 Added decision to training queue: {action} {symbol} (conf: {confidence:.3f})")

            # Trigger immediate training for executed trades (higher priority)
            if action != 'HOLD':
                if hasattr(self.enhanced_training_system, 'trigger_immediate_training'):
                    self.enhanced_training_system.trigger_immediate_training(
                        symbol=symbol,
                        priority='high' if confidence > 0.7 else 'medium'
                    )
                    logger.info(f"🚀 Triggered immediate training for executed trade: {action} {symbol}")

            # Train all models on the decision outcome
            self._train_models_on_decision(decision, current_price)

        except Exception as e:
            logger.error(f"Error triggering training on decision: {e}")

    def _train_models_on_decision(self, decision: TradingDecision, current_price: float):
        """Train all models on the decision outcome

        This provides immediate feedback to models about their predictions,
        allowing them to learn from each signal they generate.
        """
        try:
            symbol = decision.symbol
            action = decision.action
            confidence = decision.confidence

            # Get current market data for training context
            market_data = self._get_current_market_data(symbol)
            if not market_data:
                return

            # Track if any model was trained, for checkpoint saving
            models_trained = []

            # Train DQN agent if available
            if self.rl_agent and hasattr(self.rl_agent, 'add_experience'):
                try:
                    # Create state representation
                    state = self._create_state_for_training(symbol, market_data)

                    # Map action to DQN action space - CONSISTENT ACTION MAPPING
                    action_mapping = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
                    dqn_action = action_mapping.get(action, 2)

                    # Calculate immediate reward based on confidence and execution
                    immediate_reward = confidence if action != 'HOLD' else 0.0

                    # Add experience to DQN
                    self.rl_agent.add_experience(
                        state=state,
                        action=dqn_action,
                        reward=immediate_reward,
                        next_state=state,  # Will be updated with actual outcome later
                        done=False
                    )

                    models_trained.append('dqn')
                    logger.debug(f"🧠 Added DQN experience: {action} {symbol} (reward: {immediate_reward:.3f})")

                except Exception as e:
                    logger.debug(f"Error training DQN on decision: {e}")

            # Train CNN model if available
            if self.cnn_model and hasattr(self.cnn_model, 'add_training_sample'):
                try:
                    # Create CNN input features
                    cnn_features = self._create_cnn_features_for_training(symbol, market_data)

                    # Create one-hot target based on the action
                    target_mapping = {'BUY': [1, 0, 0], 'SELL': [0, 1, 0], 'HOLD': [0, 0, 1]}
                    target = target_mapping.get(action, [0, 0, 1])

                    # Add training sample
                    self.cnn_model.add_training_sample(cnn_features, target, weight=confidence)

                    models_trained.append('cnn')
                    logger.debug(f"🔍 Added CNN training sample: {action} {symbol}")

                except Exception as e:
                    logger.debug(f"Error training CNN on decision: {e}")

            # Train COB RL model if available and we have COB data
            if self.cob_rl_agent and symbol in self.latest_cob_data:
                try:
                    cob_data = self.latest_cob_data[symbol]
                    if hasattr(self.cob_rl_agent, 'add_experience'):
                        # Create COB state representation
                        cob_state = self._create_cob_state_for_training(symbol, cob_data)

                        # Add COB experience
                        self.cob_rl_agent.add_experience(
                            state=cob_state,
                            action=action,
                            reward=confidence,
                            symbol=symbol
                        )

                        models_trained.append('cob_rl')
                        logger.debug(f"📊 Added COB RL experience: {action} {symbol}")

                except Exception as e:
                    logger.debug(f"Error training COB RL on decision: {e}")

            # CRITICAL FIX: Save checkpoints after training
            if models_trained:
                self._save_training_checkpoints(models_trained, confidence)

        except Exception as e:
            logger.error(f"Error training models on decision: {e}")

    def _save_training_checkpoints(self, models_trained: List[str], performance_score: float):
        """Save checkpoints for trained models if performance improved

        This is CRITICAL for preserving training progress across restarts.
        """
        try:
            if not self.checkpoint_manager:
                return

            # Increment training counter
            self.training_iterations += 1

            # Save checkpoints for each trained model
            for model_name in models_trained:
                try:
                    model_obj = None
                    current_loss = None

                    # Get model object and calculate current performance
                    if model_name == 'dqn' and self.rl_agent:
                        model_obj = self.rl_agent
                        # Use the inverted performance score as loss (higher confidence = lower loss)
                        current_loss = 1.0 - performance_score
                    elif model_name == 'cnn' and self.cnn_model:
                        model_obj = self.cnn_model
                        current_loss = 1.0 - performance_score
                    elif model_name == 'cob_rl' and self.cob_rl_agent:
                        model_obj = self.cob_rl_agent
                        current_loss = 1.0 - performance_score

                    if model_obj and current_loss is not None:
                        # Check if this is the best performance so far
                        model_state = self.model_states.get(model_name, {})
                        best_loss = model_state.get('best_loss', float('inf'))

                        # Update current loss
                        model_state['current_loss'] = current_loss
                        model_state['last_training'] = datetime.now()

                        # Save checkpoint if performance improved, or periodically
                        should_save = (
                            current_loss < best_loss or  # Performance improved
                            self.training_iterations % 100 == 0  # Periodic save every 100 iterations
                        )

                        if should_save:
                            # Prepare metadata
                            metadata = {
                                'loss': current_loss,
                                'performance_score': performance_score,
                                'training_iterations': self.training_iterations,
                                'timestamp': datetime.now().isoformat(),
                                'model_type': model_name
                            }

                            # Save checkpoint
                            checkpoint_path = self.checkpoint_manager.save_checkpoint(
                                model=model_obj,
                                model_name=model_name,
                                performance=current_loss,
                                metadata=metadata
                            )

                            if checkpoint_path:
                                # Update best performance
                                if current_loss < best_loss:
                                    model_state['best_loss'] = current_loss
                                    model_state['best_checkpoint'] = checkpoint_path
                                    logger.info(f"💾 Saved BEST checkpoint for {model_name}: {checkpoint_path} (loss: {current_loss:.4f})")
                                else:
                                    logger.debug(f"💾 Saved periodic checkpoint for {model_name}: {checkpoint_path}")

                                model_state['last_checkpoint'] = checkpoint_path
                                model_state['checkpoints_saved'] = model_state.get('checkpoints_saved', 0) + 1

                        # Update model state
                        self.model_states[model_name] = model_state

                except Exception as e:
                    logger.error(f"Error saving checkpoint for {model_name}: {e}")

        except Exception as e:
            logger.error(f"Error saving training checkpoints: {e}")

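    # Save-policy note: because current_loss here is defined as
    # 1.0 - performance_score (the decision confidence), a checkpoint is written
    # whenever a decision beats the best confidence seen so far, e.g. confidence
    # 0.85 -> loss 0.15 beats a prior best of 0.20, plus every 100th iteration.
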
    def _get_current_market_data(self, symbol: str) -> Optional[Dict]:
        """Get current market data for training context"""
        try:
            if self.data_provider:
                # Get recent data for training
                df = self.data_provider.get_historical_data(symbol, '1m', limit=100)
                if df is not None and not df.empty:
                    return {
                        'ohlcv': df.tail(50).to_dict('records'),  # Last 50 candles
                        'current_price': float(df['close'].iloc[-1]),
                        'volume': float(df['volume'].iloc[-1]),
                        'timestamp': df.index[-1]
                    }
            return None
        except Exception as e:
            logger.debug(f"Error getting market data for training: {e}")
            return None

    def _create_state_for_training(self, symbol: str, market_data: Dict) -> np.ndarray:
        """Create state representation for DQN training"""
        try:
            # Create a basic state representation
            ohlcv_data = market_data.get('ohlcv', [])
            if not ohlcv_data:
                return np.zeros(100)  # Default state size

            # Extract features from recent candles
            features = []
            for candle in ohlcv_data[-20:]:  # Last 20 candles
                features.extend([
                    candle.get('open', 0),
                    candle.get('high', 0),
                    candle.get('low', 0),
                    candle.get('close', 0),
                    candle.get('volume', 0)
                ])

            # Pad or truncate to expected size
            state = np.array(features[:100])
            if len(state) < 100:
                state = np.pad(state, (0, 100 - len(state)), 'constant')

            return state

        except Exception as e:
            logger.debug(f"Error creating state for training: {e}")
            return np.zeros(100)

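    # Sizing note: 20 candles * 5 fields (OHLCV) = exactly 100 features, so the
    # pad/truncate above only kicks in when fewer than 20 candles are available.
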
    def _create_cnn_features_for_training(self, symbol: str, market_data: Dict) -> np.ndarray:
        """Create CNN features for training"""
        try:
            # Similar to state creation but formatted for CNN
            ohlcv_data = market_data.get('ohlcv', [])
            if not ohlcv_data:
                return np.zeros((1, 100))

            # Create feature matrix
            features = []
            for candle in ohlcv_data[-20:]:
                features.extend([
                    candle.get('open', 0),
                    candle.get('high', 0),
                    candle.get('low', 0),
                    candle.get('close', 0),
                    candle.get('volume', 0)
                ])

            # Reshape for CNN input
            cnn_features = np.array(features[:100]).reshape(1, -1)
            if cnn_features.shape[1] < 100:
                cnn_features = np.pad(cnn_features, ((0, 0), (0, 100 - cnn_features.shape[1])), 'constant')

            return cnn_features

        except Exception as e:
            logger.debug(f"Error creating CNN features for training: {e}")
            return np.zeros((1, 100))

    def _create_cob_state_for_training(self, symbol: str, cob_data: Dict) -> np.ndarray:
        """Create COB state representation for training"""
        try:
            # Extract COB features for training
            features = []

            # Add bid/ask data
            bids = cob_data.get('bids', [])[:10]  # Top 10 bids
            asks = cob_data.get('asks', [])[:10]  # Top 10 asks

            for bid in bids:
                features.extend([bid.get('price', 0), bid.get('size', 0)])
            for ask in asks:
                features.extend([ask.get('price', 0), ask.get('size', 0)])

            # Add market stats
            stats = cob_data.get('stats', {})
            features.extend([
                stats.get('spread', 0),
                stats.get('mid_price', 0),
                stats.get('bid_volume', 0),
                stats.get('ask_volume', 0),
                stats.get('imbalance', 0)
            ])

            # Pad to expected COB state size (2000 features)
            cob_state = np.array(features[:2000])
            if len(cob_state) < 2000:
                cob_state = np.pad(cob_state, (0, 2000 - len(cob_state)), 'constant')

            return cob_state

        except Exception as e:
            logger.debug(f"Error creating COB state for training: {e}")
            return np.zeros(2000)

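    # Sizing note: a full book snapshot contributes 10 bids * 2 + 10 asks * 2
    # + 5 stats = 45 populated features; the remaining 1955 of the 2000-wide COB
    # state are zero padding, matching the fixed input the COB RL model expects.
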
    def _check_signal_confirmation(self, symbol: str, signal_data: Dict) -> Optional[str]:
        """Check if we have enough signal confirmations for trend confirmation, with rate limiting"""
        try:
            current_time = signal_data['timestamp']
            action = signal_data['action']

            # Initialize signal tracking for this symbol if needed
            if symbol not in self.last_signal_time:
                self.last_signal_time[symbol] = {}
            if symbol not in self.last_confirmed_signal:
                self.last_confirmed_signal[symbol] = {}

            # RATE LIMITING: Check if we recently confirmed the same signal
            if action in self.last_confirmed_signal[symbol]:
                last_confirmed = self.last_confirmed_signal[symbol][action]
                time_since_last = current_time - last_confirmed['timestamp']
                if time_since_last < self.min_signal_interval:
                    logger.debug(f"Rate limiting: {action} signal for {symbol} too recent "
                                 f"({time_since_last.total_seconds():.1f}s < {self.min_signal_interval.total_seconds()}s)")
                    return None

            # Clean up expired signals
            self.signal_accumulator[symbol] = [
                s for s in self.signal_accumulator[symbol]
                if (current_time - s['timestamp']).total_seconds() < self.signal_timeout_seconds
            ]

            # Add new signal
            self.signal_accumulator[symbol].append(signal_data)

            # Check if we have enough confirmations
            if len(self.signal_accumulator[symbol]) < self.required_confirmations:
                return None

            # Check if recent signals are consistent
            recent_signals = self.signal_accumulator[symbol][-self.required_confirmations:]
            actions = [s['action'] for s in recent_signals]

            # Count action consensus
            action_counts = {}
            for action_item in actions:
                action_counts[action_item] = action_counts.get(action_item, 0) + 1

            # Find dominant action
            dominant_action = max(action_counts, key=action_counts.get)
            consensus_count = action_counts[dominant_action]

            # Require at least 2/3 consensus
            if consensus_count >= max(2, self.required_confirmations * 0.67):
                # ADDITIONAL RATE LIMITING: Don't confirm if we just confirmed the same action
                if dominant_action in self.last_confirmed_signal[symbol]:
                    last_confirmed = self.last_confirmed_signal[symbol][dominant_action]
                    time_since_last = current_time - last_confirmed['timestamp']
                    if time_since_last < self.min_signal_interval:
                        logger.debug(f"Rate limiting: Preventing duplicate {dominant_action} confirmation for {symbol}")
                        return None

                # Record this confirmation
                self.last_confirmed_signal[symbol][dominant_action] = {
                    'timestamp': current_time,
                    'confidence': signal_data['confidence']
                }

                # Clear accumulator after confirmation
                self.signal_accumulator[symbol] = []

                logger.info(f"Signal confirmed after rate limiting: {dominant_action} for {symbol}")
                return dominant_action

            return None

        except Exception as e:
            logger.error(f"Error checking signal confirmation for {symbol}: {e}")
            return None
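
    # Consensus arithmetic (from the 2/3 rule above): with required_confirmations=3
    # the bar is max(2, 3 * 0.67) = 2.01, so all 3 recent signals must agree
    # (an integer count of 2 falls short); with required_confirmations=2 the bar
    # is exactly 2, i.e. both signals must match before a confirmation fires.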