improve training on signals, add save session button to store all progress
This commit is contained in:
@ -16,13 +16,13 @@ import time
|
|||||||
import threading
|
import threading
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, List, Optional, Tuple, Any
|
from typing import Dict, List, Optional, Tuple, Any, Union
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
from .config import get_config
|
from .config import get_config
|
||||||
from .data_provider import DataProvider
|
from .data_provider import DataProvider
|
||||||
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface
|
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface, ModelRegistry
|
||||||
|
|
||||||
# Import COB integration for real-time market microstructure data
|
# Import COB integration for real-time market microstructure data
|
||||||
try:
|
try:
|
||||||
@ -45,7 +45,7 @@ class Prediction:
|
|||||||
timeframe: str # Timeframe this prediction is for
|
timeframe: str # Timeframe this prediction is for
|
||||||
timestamp: datetime
|
timestamp: datetime
|
||||||
model_name: str # Name of the model that made this prediction
|
model_name: str # Name of the model that made this prediction
|
||||||
metadata: Dict[str, Any] = None # Additional model-specific data
|
metadata: Optional[Dict[str, Any]] = None # Additional model-specific data
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TradingDecision:
|
class TradingDecision:
|
||||||
@ -62,10 +62,10 @@ class TradingOrchestrator:
|
|||||||
"""
|
"""
|
||||||
Enhanced Trading Orchestrator with full ML and COB integration
|
Enhanced Trading Orchestrator with full ML and COB integration
|
||||||
Coordinates CNN, DQN, and COB models for advanced trading decisions
|
Coordinates CNN, DQN, and COB models for advanced trading decisions
|
||||||
Features real-time COB (Change of Bid) integration for market microstructure data
|
Features real-time COB (Change of Bid) data for market microstructure data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data_provider: DataProvider = None, enhanced_rl_training: bool = True, model_registry: Dict = None):
|
def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[ModelRegistry] = None):
|
||||||
"""Initialize the enhanced orchestrator with full ML capabilities"""
|
"""Initialize the enhanced orchestrator with full ML capabilities"""
|
||||||
self.config = get_config()
|
self.config = get_config()
|
||||||
self.data_provider = data_provider or DataProvider()
|
self.data_provider = data_provider or DataProvider()
|
||||||
@ -79,18 +79,18 @@ class TradingOrchestrator:
|
|||||||
self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols
|
self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols
|
||||||
|
|
||||||
# Dynamic weights (will be adapted based on performance)
|
# Dynamic weights (will be adapted based on performance)
|
||||||
self.model_weights = {} # {model_name: weight}
|
self.model_weights: Dict[str, float] = {} # {model_name: weight}
|
||||||
self._initialize_default_weights()
|
self._initialize_default_weights()
|
||||||
|
|
||||||
# State tracking
|
# State tracking
|
||||||
self.last_decision_time = {} # {symbol: datetime}
|
self.last_decision_time: Dict[str, datetime] = {} # {symbol: datetime}
|
||||||
self.recent_decisions = {} # {symbol: List[TradingDecision]}
|
self.recent_decisions: Dict[str, List[TradingDecision]] = {} # {symbol: List[TradingDecision]}
|
||||||
self.model_performance = {} # {model_name: {'correct': int, 'total': int, 'accuracy': float}}
|
self.model_performance: Dict[str, Dict[str, Any]] = {} # {model_name: {'correct': int, 'total': int, 'accuracy': float}}
|
||||||
|
|
||||||
# Model prediction tracking for dashboard visualization
|
# Model prediction tracking for dashboard visualization
|
||||||
self.recent_dqn_predictions = {} # {symbol: List[Dict]} - Recent DQN predictions
|
self.recent_dqn_predictions: Dict[str, deque] = {} # {symbol: List[Dict]} - Recent DQN predictions
|
||||||
self.recent_cnn_predictions = {} # {symbol: List[Dict]} - Recent CNN predictions
|
self.recent_cnn_predictions: Dict[str, deque] = {} # {symbol: List[Dict]} - Recent CNN predictions
|
||||||
self.prediction_accuracy_history = {} # {symbol: List[Dict]} - Prediction accuracy tracking
|
self.prediction_accuracy_history: Dict[str, deque] = {} # {symbol: List[Dict]} - Prediction accuracy tracking
|
||||||
|
|
||||||
# Initialize prediction tracking for each symbol
|
# Initialize prediction tracking for each symbol
|
||||||
for symbol in self.symbols:
|
for symbol in self.symbols:
|
||||||
@ -99,39 +99,45 @@ class TradingOrchestrator:
|
|||||||
self.prediction_accuracy_history[symbol] = deque(maxlen=200)
|
self.prediction_accuracy_history[symbol] = deque(maxlen=200)
|
||||||
|
|
||||||
# Decision callbacks
|
# Decision callbacks
|
||||||
self.decision_callbacks = []
|
self.decision_callbacks: List[Any] = []
|
||||||
|
|
||||||
# ENHANCED: Decision Fusion System - Built into orchestrator (no separate file needed!)
|
# ENHANCED: Decision Fusion System - Built into orchestrator (no separate file needed!)
|
||||||
self.decision_fusion_enabled = True
|
self.decision_fusion_enabled: bool = True
|
||||||
self.decision_fusion_network = None
|
self.decision_fusion_network: Any = None
|
||||||
self.fusion_training_history = []
|
self.fusion_training_history: List[Any] = []
|
||||||
self.last_fusion_inputs = {}
|
self.last_fusion_inputs: Dict[str, Any] = {} # Fix: Explicitly initialize as dictionary
|
||||||
self.fusion_checkpoint_frequency = 50 # Save every 50 decisions
|
self.fusion_checkpoint_frequency: int = 50 # Save every 50 decisions
|
||||||
self.fusion_decisions_count = 0
|
self.fusion_decisions_count: int = 0
|
||||||
self.fusion_training_data = [] # Store training examples for decision model
|
self.fusion_training_data: List[Any] = [] # Store training examples for decision model
|
||||||
|
|
||||||
# COB Integration - Real-time market microstructure data
|
# COB Integration - Real-time market microstructure data
|
||||||
self.cob_integration = None
|
self.cob_integration: Optional[COBIntegration] = None # Fix: Use Optional for COBIntegration
|
||||||
self.latest_cob_data: Dict[str, Any] = {} # {symbol: COBSnapshot}
|
self.latest_cob_data: Dict[str, Any] = {} # {symbol: COBSnapshot}
|
||||||
self.latest_cob_features: Dict[str, Any] = {} # {symbol: np.ndarray} - CNN features
|
self.latest_cob_features: Dict[str, Any] = {} # {symbol: np.ndarray} - CNN features
|
||||||
self.latest_cob_state: Dict[str, Any] = {} # {symbol: np.ndarray} - DQN state features
|
self.latest_cob_state: Dict[str, Any] = {} # {symbol: np.ndarray} - DQN state features
|
||||||
self.cob_feature_history: Dict[str, List] = {symbol: [] for symbol in self.symbols} # Rolling history for models
|
self.cob_feature_history: Dict[str, List[Any]] = {symbol: [] for symbol in self.symbols} # Rolling history for models
|
||||||
|
|
||||||
# Enhanced ML Models
|
# Enhanced ML Models
|
||||||
self.rl_agent = None # DQN Agent
|
self.rl_agent: Any = None # DQN Agent
|
||||||
self.cnn_model = None # CNN Model for pattern recognition
|
self.cnn_model: Any = None # CNN Model for pattern recognition
|
||||||
self.extrema_trainer = None # Extrema/pivot trainer
|
self.extrema_trainer: Any = None # Extrema/pivot trainer
|
||||||
|
self.primary_transformer: Any = None # Transformer model
|
||||||
|
self.primary_transformer_trainer: Any = None # Transformer model trainer
|
||||||
|
self.transformer_checkpoint_info: Dict[str, Any] = {} # Transformer checkpoint info
|
||||||
|
self.cob_rl_agent: Any = None # COB RL Agent
|
||||||
|
self.decision_model: Any = None # Decision Fusion model
|
||||||
|
|
||||||
self.latest_cnn_features: Dict[str, Any] = {} # CNN hidden features
|
self.latest_cnn_features: Dict[str, Any] = {} # CNN hidden features
|
||||||
self.latest_cnn_predictions: Dict[str, Any] = {} # CNN predictions
|
self.latest_cnn_predictions: Dict[str, Any] = {} # CNN predictions
|
||||||
|
|
||||||
# Enhanced RL features
|
# Enhanced RL features
|
||||||
self.sensitivity_learning_queue = [] # For outcome-based learning
|
self.sensitivity_learning_queue: List[Any] = [] # For outcome-based learning
|
||||||
self.perfect_move_buffer = [] # Buffer for perfect move analysis
|
self.perfect_move_buffer: List[Any] = [] # Buffer for perfect move analysis
|
||||||
self.position_status = {} # Current positions
|
self.position_status: Dict[str, Any] = {} # Current positions
|
||||||
|
|
||||||
# Real-time processing
|
# Real-time processing
|
||||||
self.realtime_processing = False
|
self.realtime_processing: bool = False
|
||||||
self.realtime_tasks = []
|
self.realtime_tasks: List[Any] = []
|
||||||
|
|
||||||
logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
|
logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
|
||||||
logger.info(f"Enhanced RL training: {enhanced_rl_training}")
|
logger.info(f"Enhanced RL training: {enhanced_rl_training}")
|
||||||
@ -310,6 +316,7 @@ class TradingOrchestrator:
|
|||||||
self.cob_integration = COBIntegration(symbols=self.symbols)
|
self.cob_integration = COBIntegration(symbols=self.symbols)
|
||||||
|
|
||||||
# Register callbacks to receive real-time COB data
|
# Register callbacks to receive real-time COB data
|
||||||
|
if self.cob_integration:
|
||||||
self.cob_integration.add_cnn_callback(self._on_cob_cnn_features)
|
self.cob_integration.add_cnn_callback(self._on_cob_cnn_features)
|
||||||
self.cob_integration.add_dqn_callback(self._on_cob_dqn_features)
|
self.cob_integration.add_dqn_callback(self._on_cob_dqn_features)
|
||||||
self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_data)
|
self.cob_integration.add_dashboard_callback(self._on_cob_dashboard_data)
|
||||||
@ -320,9 +327,9 @@ class TradingOrchestrator:
|
|||||||
self.cob_matrix_size = self.cob_matrix_duration // self.cob_matrix_resolution # 300 samples
|
self.cob_matrix_size = self.cob_matrix_duration // self.cob_matrix_resolution # 300 samples
|
||||||
|
|
||||||
# COB data matrix storage - 5 minutes of 1-second snapshots
|
# COB data matrix storage - 5 minutes of 1-second snapshots
|
||||||
self.cob_data_matrix: Dict[str, deque] = {}
|
self.cob_data_matrix: Dict[str, deque[Any]] = {}
|
||||||
self.cob_feature_matrix: Dict[str, deque] = {}
|
self.cob_feature_matrix: Dict[str, deque[Any]] = {}
|
||||||
self.cob_state_matrix: Dict[str, deque] = {}
|
self.cob_state_matrix: Dict[str, deque[Any]] = {}
|
||||||
|
|
||||||
# Initialize matrix storage for each symbol
|
# Initialize matrix storage for each symbol
|
||||||
for symbol in self.symbols:
|
for symbol in self.symbols:
|
||||||
@ -336,16 +343,16 @@ class TradingOrchestrator:
|
|||||||
self.cob_state_matrix[symbol] = deque(maxlen=self.cob_matrix_size)
|
self.cob_state_matrix[symbol] = deque(maxlen=self.cob_matrix_size)
|
||||||
|
|
||||||
# Initialize COB data storage (legacy support)
|
# Initialize COB data storage (legacy support)
|
||||||
self.latest_cob_snapshots = {}
|
self.latest_cob_snapshots: Dict[str, Any] = {}
|
||||||
self.cob_feature_cache = {}
|
self.cob_feature_cache: Dict[str, Any] = {}
|
||||||
self.cob_state_cache = {}
|
self.cob_state_cache: Dict[str, Any] = {}
|
||||||
|
|
||||||
# COB matrix update tracking
|
# COB matrix update tracking
|
||||||
self.last_cob_matrix_update = {}
|
self.last_cob_matrix_update: Dict[str, float] = {}
|
||||||
self.cob_matrix_update_interval = 1.0 # Update every 1 second
|
self.cob_matrix_update_interval = 1.0 # Update every 1 second
|
||||||
|
|
||||||
# COB matrix statistics
|
# COB matrix statistics
|
||||||
self.cob_matrix_stats = {
|
self.cob_matrix_stats: Dict[str, Any] = {
|
||||||
'total_updates': 0,
|
'total_updates': 0,
|
||||||
'matrix_fills': {symbol: 0 for symbol in self.symbols},
|
'matrix_fills': {symbol: 0 for symbol in self.symbols},
|
||||||
'feature_generations': 0,
|
'feature_generations': 0,
|
||||||
@ -375,6 +382,7 @@ class TradingOrchestrator:
|
|||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
async def cob_main():
|
async def cob_main():
|
||||||
|
if self.cob_integration: # Additional check
|
||||||
await self.cob_integration.start()
|
await self.cob_integration.start()
|
||||||
# Keep running until stopped
|
# Keep running until stopped
|
||||||
while True:
|
while True:
|
||||||
|
@ -571,6 +571,21 @@ class CleanTradingDashboard:
|
|||||||
self._clear_session()
|
self._clear_session()
|
||||||
return [html.I(className="fas fa-trash me-1"), "Clear Session"]
|
return [html.I(className="fas fa-trash me-1"), "Clear Session"]
|
||||||
|
|
||||||
|
@self.app.callback(
|
||||||
|
Output('store-models-btn', 'children'),
|
||||||
|
[Input('store-models-btn', 'n_clicks')],
|
||||||
|
prevent_initial_call=True
|
||||||
|
)
|
||||||
|
def handle_store_models(n_clicks):
|
||||||
|
"""Handle store all models button click"""
|
||||||
|
if n_clicks:
|
||||||
|
success = self._store_all_models()
|
||||||
|
if success:
|
||||||
|
return [html.I(className="fas fa-save me-1"), "Models Stored"]
|
||||||
|
else:
|
||||||
|
return [html.I(className="fas fa-exclamation-triangle me-1"), "Store Failed"]
|
||||||
|
return [html.I(className="fas fa-save me-1"), "Store All Models"]
|
||||||
|
|
||||||
def _get_current_price(self, symbol: str) -> Optional[float]:
|
def _get_current_price(self, symbol: str) -> Optional[float]:
|
||||||
"""Get current price for symbol"""
|
"""Get current price for symbol"""
|
||||||
try:
|
try:
|
||||||
@ -2927,6 +2942,10 @@ class CleanTradingDashboard:
|
|||||||
if len(self.recent_decisions) > 200:
|
if len(self.recent_decisions) > 200:
|
||||||
self.recent_decisions = self.recent_decisions[-200:]
|
self.recent_decisions = self.recent_decisions[-200:]
|
||||||
|
|
||||||
|
# Train ALL models on the signal (if executed)
|
||||||
|
if signal['executed']:
|
||||||
|
self._train_all_models_on_signal(signal)
|
||||||
|
|
||||||
# Log signal processing
|
# Log signal processing
|
||||||
status = "EXECUTED" if signal['executed'] else ("BLOCKED" if signal['blocked'] else "PENDING")
|
status = "EXECUTED" if signal['executed'] else ("BLOCKED" if signal['blocked'] else "PENDING")
|
||||||
logger.info(f"[{status}] {signal['action']} signal for {signal['symbol']} "
|
logger.info(f"[{status}] {signal['action']} signal for {signal['symbol']} "
|
||||||
@ -2935,11 +2954,308 @@ class CleanTradingDashboard:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing dashboard signal: {e}")
|
logger.error(f"Error processing dashboard signal: {e}")
|
||||||
|
|
||||||
def _train_dqn_on_signal(self, signal: Dict):
|
def _train_all_models_on_signal(self, signal: Dict):
    """Train every model on an executed trade signal.

    Resolves a trade outcome for ``signal`` and, when one is available,
    hands the signal and that outcome to the DQN, CNN, Transformer, COB RL
    and Decision Fusion trainers, in that order. Training is best-effort:
    any exception is logged at debug level and never reaches the caller.

    Args:
        signal: Signal dict. This method itself only reads ``action`` (for
            the summary log); the trainers read further keys.
    """
    try:
        trade_outcome = self._get_trade_outcome_for_training(signal)
        if not trade_outcome:
            # No outcome could be derived, so there is nothing to train on.
            return

        # 1. DQN
        self._train_dqn_on_signal(signal, trade_outcome)

        # 2. CNN
        self._train_cnn_on_signal(signal, trade_outcome)

        # 3. Transformer
        self._train_transformer_on_signal(signal, trade_outcome)

        # 4. COB RL
        self._train_cob_rl_on_signal(signal, trade_outcome)

        # 5. Decision Fusion
        self._train_decision_fusion_on_signal(signal, trade_outcome)

        # Read with .get(): by this point every trainer has already run, so
        # a missing key must not be reported as a training failure.
        logger.debug(
            "Trained all models on %s signal with outcome: %.2f",
            signal.get('action', 'UNKNOWN'),
            trade_outcome.get('pnl', 0),
        )

    except Exception as e:
        logger.debug(f"Error training models on signal: {e}")
|
||||||
|
|
||||||
|
def _get_trade_outcome_for_training(self, signal: Dict) -> Optional[Dict]:
    """Build the trade outcome used as the training target for ``signal``.

    Two sources are tried in order:

    1. The last entry of ``self.closed_trades``, when its symbol equals the
       signal's and its entry time is within 60 seconds of the signal
       timestamp.
    2. The open ``self.current_position``, valued at the current price to
       give a leveraged unrealized P&L.

    Timestamps may be plain numbers or objects exposing ``timestamp()``
    (e.g. ``datetime``). Both are normalized to seconds before comparison,
    so a ``datetime`` entry time and a float signal timestamp can be
    matched instead of aborting the lookup with a ``TypeError``. A
    timestamp that cannot be interpreted never matches.

    Args:
        signal: Signal dict; ``symbol`` and ``timestamp`` are read.

    Returns:
        A dict with ``pnl``, ``entry_price``, ``side``, ``quantity``,
        ``duration``, ``trade_type`` and either ``exit_price`` (completed
        trade) or ``current_price`` (open position); ``None`` when no
        outcome can be derived or an error occurs.
    """

    def _as_seconds(value):
        """Return ``value`` as float seconds, or None if not interpretable."""
        to_timestamp = getattr(value, 'timestamp', None)
        if callable(to_timestamp):
            return float(to_timestamp())
        if isinstance(value, (int, float)):
            return float(value)
        return None

    try:
        # Prefer a completed trade that corresponds to this signal.
        if self.closed_trades:
            latest_trade = self.closed_trades[-1]
            if latest_trade.get('symbol') == signal.get('symbol'):
                entry_s = _as_seconds(latest_trade.get('entry_time', 0))
                signal_s = _as_seconds(signal.get('timestamp', 0))
                within_window = (
                    entry_s is not None
                    and signal_s is not None
                    and abs(entry_s - signal_s) < 60  # Within 1 minute
                )
                if within_window:
                    exit_s = _as_seconds(latest_trade.get('exit_time', 0))
                    return {
                        'pnl': latest_trade.get('pnl', 0),
                        'entry_price': latest_trade.get('entry_price', 0),
                        'exit_price': latest_trade.get('exit_price', 0),
                        'side': latest_trade.get('side', 'UNKNOWN'),
                        'quantity': latest_trade.get('quantity', 0),
                        'duration': exit_s - entry_s if exit_s is not None else 0,
                        'trade_type': 'completed'
                    }

        # Otherwise fall back to the open position's unrealized P&L.
        if self.current_position:
            current_price = self._get_current_price(signal.get('symbol', 'ETH/USDT'))
            if current_price:
                entry_price = self.current_position.get('price', 0)
                side = self.current_position.get('side', 'UNKNOWN')
                size = self.current_position.get('size', 0)

                if entry_price > 0 and size > 0:
                    # NOTE(review): every side other than 'LONG' (including
                    # the 'UNKNOWN' default) is valued as a short - confirm
                    # that this is intended.
                    if side.upper() == 'LONG':
                        pnl = (current_price - entry_price) * size * self.current_leverage
                    else:  # SHORT
                        pnl = (entry_price - current_price) * size * self.current_leverage

                    return {
                        'pnl': pnl,
                        'entry_price': entry_price,
                        'current_price': current_price,
                        'side': side,
                        'quantity': size,
                        'duration': 0,  # Position still open
                        'trade_type': 'position_change'
                    }

        return None

    except Exception as e:
        logger.debug(f"Error getting trade outcome: {e}")
        return None
|
||||||
|
|
||||||
|
def _train_dqn_on_signal(self, signal: Dict, trade_outcome: Dict):
    """Feed one executed signal and its outcome to the DQN agent.

    Stores a single experience ``(state, action, reward, next_state)`` via
    the agent's ``remember`` and, once the agent's ``memory`` holds more
    than 32 entries, runs one ``replay`` step with a batch size of 32.
    Does nothing when the orchestrator has no ``rl_agent`` or when no state
    features are available. Errors are logged at debug level and swallowed.

    Args:
        signal: Signal dict; ``symbol``, ``price`` and ``action`` are read.
        trade_outcome: Outcome dict; ``pnl`` is read and scaled by 100 to
            form the reward.
    """
    try:
        if not self.orchestrator or not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
            return
        agent = self.orchestrator.rl_agent

        state_features = self._get_dqn_state_features(signal.get('symbol', 'ETH/USDT'), signal.get('price', 0))
        if state_features is None:
            # Without a state there is no experience to store; return here
            # instead of failing on .copy() below.
            return

        # NOTE(review): any action other than 'BUY' is encoded as SELL.
        action = 0 if signal['action'] == 'BUY' else 1  # 0=BUY, 1=SELL

        # Reward is the trade P&L, scaled up.
        reward = trade_outcome.get('pnl', 0) * 100

        # Simplified next state: a copy of the current state rather than
        # the following market state.
        next_state_features = state_features.copy()

        if hasattr(agent, 'remember'):
            agent.remember(
                state_features, action, reward, next_state_features, done=True
            )

        # Train only once enough samples are buffered.
        if hasattr(agent, 'memory') and len(agent.memory) > 32:
            if hasattr(agent, 'replay'):
                loss = agent.replay(batch_size=32)
                if loss is not None:
                    logger.debug(f"DQN trained on signal - loss: {loss:.4f}, reward: {reward:.2f}")

    except Exception as e:
        logger.debug(f"Error training DQN on signal: {e}")
|
||||||
|
|
||||||
|
def _train_cnn_on_signal(self, signal: Dict, trade_outcome: Dict):
    """Train the CNN model on one executed signal and its outcome.

    Uses the ``features`` entry returned by
    ``_get_cnn_features_and_predictions`` as the input and a binary target
    (1.0 for a positive P&L, otherwise 0.0), then calls the model's
    ``train_on_batch`` when it has one. Does nothing when the orchestrator
    has no ``cnn_model`` or when no features are available. Errors are
    logged at debug level and swallowed.

    Args:
        signal: Signal dict; ``symbol`` is read.
        trade_outcome: Outcome dict; ``pnl`` is read.
    """
    try:
        if not self.orchestrator or not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
            return
        model = self.orchestrator.cnn_model

        symbol = signal.get('symbol', 'ETH/USDT')

        market_features = self._get_cnn_features_and_predictions(symbol)
        if not market_features:
            return

        # Binary classification target: profitable vs not.
        pnl = trade_outcome.get('pnl', 0)
        target = 1.0 if pnl > 0 else 0.0

        features = market_features.get('features', [])
        # Use an explicit emptiness check: plain truthiness raises
        # ValueError for a NumPy array with more than one element.
        if features is None or len(features) == 0:
            return

        import numpy as np
        feature_tensor = np.array(features, dtype=np.float32)
        target_tensor = np.array([target], dtype=np.float32)

        if hasattr(model, 'train_on_batch'):
            loss = model.train_on_batch(feature_tensor, target_tensor)
            # A model may not report a loss; formatting None would raise.
            if loss is not None:
                logger.debug(f"CNN trained on signal - loss: {loss:.4f}, target: {target}")

    except Exception as e:
        logger.debug(f"Error training CNN on signal: {e}")
|
||||||
|
|
||||||
|
def _train_transformer_on_signal(self, signal: Dict, trade_outcome: Dict):
    """Train the Transformer model on one executed signal.

    Uses the values of ``_get_comprehensive_market_state`` as the input
    vector and ``[action, confidence]`` as the target, then calls the
    model's ``train_on_batch`` when it has one. Does nothing when the
    orchestrator has no ``primary_transformer`` or when the market state is
    missing or empty. Errors are logged at debug level and swallowed.

    Args:
        signal: Signal dict; ``symbol``, ``price``, ``action`` and
            ``confidence`` are read.
        trade_outcome: Outcome dict. Not used for the target here; kept so
            all trainers share one signature.
    """
    try:
        if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
            return
        model = self.orchestrator.primary_transformer

        symbol = signal.get('symbol', 'ETH/USDT')
        current_price = signal.get('price', 0)

        market_state = self._get_comprehensive_market_state(symbol, current_price)
        if not market_state:
            # Covers both a missing (None) and an empty market state.
            return

        # NOTE(review): any action other than 'BUY' is encoded as SELL.
        target_action = 0 if signal['action'] == 'BUY' else 1  # 0=BUY, 1=SELL
        target_confidence = signal.get('confidence', 0.5)

        import numpy as np
        feature_tensor = np.array(list(market_state.values()), dtype=np.float32)
        target_tensor = np.array([target_action, target_confidence], dtype=np.float32)

        if hasattr(model, 'train_on_batch'):
            loss = model.train_on_batch(feature_tensor, target_tensor)
            # A model may not report a loss; formatting None would raise.
            if loss is not None:
                logger.debug(f"Transformer trained on signal - loss: {loss:.4f}, action: {target_action}")

    except Exception as e:
        logger.debug(f"Error training Transformer on signal: {e}")
|
||||||
|
|
||||||
|
def _train_cob_rl_on_signal(self, signal: Dict, trade_outcome: Dict):
    """Feed one executed signal and its outcome to the COB RL agent.

    Stores a single experience via the agent's ``remember`` (the COB
    features serve as both state and next state) and, once the agent's
    ``memory`` holds more than 32 entries, runs one ``replay`` step with a
    batch size of 32. Does nothing when the orchestrator has no
    ``cob_rl_agent`` or when no COB features are available. Errors are
    logged at debug level and swallowed.

    Args:
        signal: Signal dict; ``symbol``, ``price`` and ``action`` are read.
        trade_outcome: Outcome dict; ``pnl`` is read and scaled by 100 to
            form the reward.
    """
    try:
        if not self.orchestrator or not hasattr(self.orchestrator, 'cob_rl_agent') or not self.orchestrator.cob_rl_agent:
            return
        agent = self.orchestrator.cob_rl_agent

        symbol = signal.get('symbol', 'ETH/USDT')

        cob_features = self._get_cob_features_for_training(symbol, signal.get('price', 0))
        # Use an explicit emptiness check: plain truthiness raises
        # ValueError for a NumPy array with more than one element.
        if cob_features is None or len(cob_features) == 0:
            return

        # NOTE(review): any action other than 'BUY' is encoded as 1.
        action = 0 if signal['action'] == 'BUY' else 1
        reward = trade_outcome.get('pnl', 0) * 100  # Scale reward

        if hasattr(agent, 'remember'):
            agent.remember(
                cob_features, action, reward, cob_features, done=True  # Simplified next state
            )

        # Train only once enough samples are buffered.
        if hasattr(agent, 'memory') and len(agent.memory) > 32:
            if hasattr(agent, 'replay'):
                loss = agent.replay(batch_size=32)
                if loss is not None:
                    logger.debug(f"COB RL trained on signal - loss: {loss:.4f}, reward: {reward:.2f}")

    except Exception as e:
        logger.debug(f"Error training COB RL on signal: {e}")
|
||||||
|
|
||||||
|
def _train_decision_fusion_on_signal(self, signal: Dict, trade_outcome: Dict):
    """Train the Decision Fusion model on one executed signal.

    Collects the DQN, CNN, Transformer and COB RL predictions for the
    signal as the input vector and uses a binary target (1.0 for a
    positive P&L, otherwise 0.0), then calls the model's
    ``train_on_batch``.

    The model is checked before any prediction is requested: when the
    orchestrator's ``decision_model`` is unset or has no
    ``train_on_batch``, the method returns without querying the other
    models. Errors are logged at debug level and swallowed.

    Args:
        signal: Signal dict; ``symbol`` and ``price`` are read.
        trade_outcome: Outcome dict; ``pnl`` is read.
    """
    try:
        model = getattr(self.orchestrator, 'decision_model', None) if self.orchestrator else None
        # hasattr() on the orchestrator alone is not enough: the attribute
        # can exist and still be None.
        if model is None or not hasattr(model, 'train_on_batch'):
            return

        symbol = signal.get('symbol', 'ETH/USDT')
        current_price = signal.get('price', 0)

        # Predictions from all models, in a fixed order.
        model_predictions = {
            'dqn': self._get_dqn_prediction(symbol, current_price),
            'cnn': self._get_cnn_prediction(symbol, current_price),
            'transformer': self._get_transformer_prediction(symbol, current_price),
            'cob_rl': self._get_cob_rl_prediction(symbol, current_price)
        }

        # Binary target: profitable vs not.
        pnl = trade_outcome.get('pnl', 0)
        target = 1.0 if pnl > 0 else 0.0

        import numpy as np
        prediction_tensor = np.array(list(model_predictions.values()), dtype=np.float32)
        target_tensor = np.array([target], dtype=np.float32)

        loss = model.train_on_batch(prediction_tensor, target_tensor)
        # A model may not report a loss; formatting None would raise.
        if loss is not None:
            logger.debug(f"Decision Fusion trained on signal - loss: {loss:.4f}, target: {target}")

    except Exception as e:
        logger.debug(f"Error training Decision Fusion on signal: {e}")
|
||||||
|
|
||||||
|
def _get_dqn_prediction(self, symbol: str, current_price: float) -> float:
    """Return the DQN agent's prediction for use by decision fusion.

    Args:
        symbol: Symbol passed to ``_get_dqn_state_features``.
        current_price: Price passed to ``_get_dqn_state_features``.

    Returns:
        Whatever the agent's ``predict`` returns for the state features, or
        the neutral value 0.5 when the orchestrator, the agent or its
        ``predict`` method is unavailable, or when prediction fails.
    """
    try:
        if self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
            state_features = self._get_dqn_state_features(symbol, current_price)
            if hasattr(self.orchestrator.rl_agent, 'predict'):
                return self.orchestrator.rl_agent.predict(state_features)
        return 0.5  # Default neutral prediction
    except Exception as e:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit are not turned into a neutral prediction.
        logger.debug(f"Error getting DQN prediction: {e}")
        return 0.5
|
||||||
|
|
||||||
|
def _get_cnn_prediction(self, symbol: str, current_price: float) -> float:
    """Return the CNN model's prediction for use by decision fusion.

    Args:
        symbol: Symbol passed to ``_get_cnn_features_and_predictions``.
        current_price: Unused; kept so all prediction getters share one
            signature.

    Returns:
        Whatever the model's ``predict`` returns for a one-row batch of the
        ``features`` entry, or the neutral value 0.5 when the model, its
        ``predict`` method or the features are unavailable, or when
        prediction fails.
    """
    try:
        if self.orchestrator and hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
            market_features = self._get_cnn_features_and_predictions(symbol)
            if market_features and hasattr(self.orchestrator.cnn_model, 'predict'):
                features = market_features.get('features', [])
                # Use an explicit emptiness check: plain truthiness raises
                # ValueError for a NumPy array with more than one element.
                if features is not None and len(features) > 0:
                    import numpy as np
                    return self.orchestrator.cnn_model.predict(np.array([features]))
        return 0.5  # Default neutral prediction
    except Exception as e:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit are not turned into a neutral prediction.
        logger.debug(f"Error getting CNN prediction: {e}")
        return 0.5
|
||||||
|
|
||||||
|
def _get_transformer_prediction(self, symbol: str, current_price: float) -> float:
    """Return the Transformer model's prediction for use by decision fusion.

    Args:
        symbol: Symbol passed to ``_get_comprehensive_market_state``.
        current_price: Price passed to ``_get_comprehensive_market_state``.

    Returns:
        Whatever the model's ``predict`` returns for a one-row batch of the
        market-state values, or the neutral value 0.5 when the model, its
        ``predict`` method or the market state is unavailable, or when
        prediction fails.
    """
    try:
        if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
            market_state = self._get_comprehensive_market_state(symbol, current_price)
            # A missing (None) or empty market state yields the neutral
            # default instead of failing on .values().
            if market_state and hasattr(self.orchestrator.primary_transformer, 'predict'):
                features = list(market_state.values())
                import numpy as np
                return self.orchestrator.primary_transformer.predict(np.array([features]))
        return 0.5  # Default neutral prediction
    except Exception as e:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit are not turned into a neutral prediction.
        logger.debug(f"Error getting Transformer prediction: {e}")
        return 0.5
|
||||||
|
|
||||||
|
def _get_cob_rl_prediction(self, symbol: str, current_price: float) -> float:
    """Return the COB RL agent's prediction for use by decision fusion.

    Args:
        symbol: Symbol passed to ``_get_cob_features_for_training``.
        current_price: Price passed to ``_get_cob_features_for_training``.

    Returns:
        Whatever the agent's ``predict`` returns for a one-row batch of the
        COB features, or the neutral value 0.5 when the agent, its
        ``predict`` method or the features are unavailable, or when
        prediction fails.
    """
    try:
        if self.orchestrator and hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
            cob_features = self._get_cob_features_for_training(symbol, current_price)
            # Use an explicit emptiness check: plain truthiness raises
            # ValueError for a NumPy array with more than one element.
            has_features = cob_features is not None and len(cob_features) > 0
            if has_features and hasattr(self.orchestrator.cob_rl_agent, 'predict'):
                import numpy as np
                return self.orchestrator.cob_rl_agent.predict(np.array([cob_features]))
        return 0.5  # Default neutral prediction
    except Exception as e:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit are not turned into a neutral prediction.
        logger.debug(f"Error getting COB RL prediction: {e}")
        return 0.5
|
||||||
|
|
||||||
def _execute_manual_trade(self, action: str):
|
def _execute_manual_trade(self, action: str):
|
||||||
"""Execute manual trading action - ENHANCED with PERSISTENT SIGNAL STORAGE"""
|
"""Execute manual trading action - ENHANCED with PERSISTENT SIGNAL STORAGE"""
|
||||||
try:
|
try:
|
||||||
@ -3097,6 +3413,18 @@ class CleanTradingDashboard:
|
|||||||
self.closed_trades.append(trade_record)
|
self.closed_trades.append(trade_record)
|
||||||
logger.info(f"Added completed trade to closed_trades: {action} P&L ${leveraged_pnl:.2f} (raw: ${raw_pnl:.2f}, leverage: x{self.current_leverage})")
|
logger.info(f"Added completed trade to closed_trades: {action} P&L ${leveraged_pnl:.2f} (raw: ${raw_pnl:.2f}, leverage: x{self.current_leverage})")
|
||||||
|
|
||||||
|
# TRAIN ALL MODELS ON MANUAL TRADE OUTCOME
|
||||||
|
manual_signal = {
|
||||||
|
'action': action,
|
||||||
|
'price': current_price,
|
||||||
|
'symbol': symbol,
|
||||||
|
'confidence': 1.0,
|
||||||
|
'executed': True,
|
||||||
|
'manual': True,
|
||||||
|
'timestamp': datetime.now().timestamp()
|
||||||
|
}
|
||||||
|
self._train_all_models_on_signal(manual_signal)
|
||||||
|
|
||||||
# MOVE BASE CASE TO POSITIVE/NEGATIVE based on leveraged outcome
|
# MOVE BASE CASE TO POSITIVE/NEGATIVE based on leveraged outcome
|
||||||
if hasattr(self, 'pending_trade_case_id') and self.pending_trade_case_id:
|
if hasattr(self, 'pending_trade_case_id') and self.pending_trade_case_id:
|
||||||
try:
|
try:
|
||||||
@ -3512,6 +3840,100 @@ class CleanTradingDashboard:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error clearing session: {e}")
|
logger.error(f"Error clearing session: {e}")
|
||||||
|
|
||||||
|
def _store_all_models(self) -> bool:
    """Persist every available orchestrator model and a session metadata snapshot.

    Each known model attribute on ``self.orchestrator`` is saved through its
    own ``save(path)`` method when the attribute is set and exposes one.  A
    failure while saving one model is logged and does not prevent the
    remaining models from being saved.  Afterwards a JSON snapshot of the
    session (P&L, closed-trade count, saved models, training iterations and
    performance metrics) is written to ``models/saved/session_metadata.json``.

    Returns:
        True when at least one model was saved, False when there is no
        orchestrator, nothing could be saved, or an unexpected error occurred.
        A failure to write the metadata file alone does not change the result.
    """
    import json
    import os
    from datetime import datetime

    save_dir = 'models/saved'
    # (display name, orchestrator attribute, save target)
    model_targets = (
        ('DQN', 'rl_agent', 'models/saved/dqn_agent_session'),
        ('CNN', 'cnn_model', 'models/saved/cnn_model_session'),
        ('Transformer', 'primary_transformer', 'models/saved/transformer_model_session'),
        ('COB RL', 'cob_rl_agent', 'models/saved/cob_rl_agent_session'),
        ('Decision Fusion', 'decision_model', 'models/saved/decision_fusion_session'),
    )

    try:
        if not self.orchestrator:
            logger.warning("No orchestrator available for model storage")
            return False

        # The save targets and the metadata file all live in this directory;
        # create it up front so a fresh checkout does not fail every save.
        os.makedirs(save_dir, exist_ok=True)

        stored_models = []
        for label, attr_name, target in model_targets:
            model = getattr(self.orchestrator, attr_name, None)
            if not model:
                continue
            if not hasattr(model, 'save'):
                logger.debug(f"{label} model has no save method - skipped")
                continue
            try:
                # Record whatever save() reports back; it is stored as-is in
                # the metadata file.
                save_path = model.save(target)
                stored_models.append((label, save_path))
                logger.info(f"Stored {label} model: {save_path}")
            except Exception as e:
                # Best effort: one failing model must not block the others.
                logger.warning(f"Failed to store {label} model: {e}")

        # Store model metadata and training state
        try:
            metadata = {
                'timestamp': datetime.now().isoformat(),
                'session_pnl': self.session_pnl,
                'trade_count': len(self.closed_trades),
                'stored_models': stored_models,
                'training_iterations': getattr(self, 'training_iteration', 0),
                'model_performance': self.get_model_performance_metrics()
            }

            # Serialize before opening the file: if a value is not JSON
            # serializable, the previous metadata file is left untouched
            # instead of being truncated to a partial document.
            payload = json.dumps(metadata, indent=2)

            metadata_path = 'models/saved/session_metadata.json'
            with open(metadata_path, 'w', encoding='utf-8') as f:
                f.write(payload)

            logger.info(f"Stored session metadata: {metadata_path}")

        except Exception as e:
            logger.warning(f"Failed to store metadata: {e}")

        # Log summary
        if stored_models:
            logger.info(f"Successfully stored {len(stored_models)} models: {[name for name, _ in stored_models]}")
            return True

        logger.warning("No models were stored - no models available or save methods not found")
        return False

    except Exception as e:
        logger.error(f"Error storing models: {e}")
        return False
|
||||||
|
|
||||||
def _get_signal_attribute(self, signal, attr_name, default=None):
|
def _get_signal_attribute(self, signal, attr_name, default=None):
|
||||||
"""Safely get attribute from signal (handles both dict and dataclass objects)"""
|
"""Safely get attribute from signal (handles both dict and dataclass objects)"""
|
||||||
try:
|
try:
|
||||||
|
@ -149,6 +149,10 @@ class DashboardLayoutManager:
|
|||||||
html.I(className="fas fa-trash me-1"),
|
html.I(className="fas fa-trash me-1"),
|
||||||
"Clear Session"
|
"Clear Session"
|
||||||
], id="clear-session-btn", className="btn btn-warning btn-sm w-100"),
|
], id="clear-session-btn", className="btn btn-warning btn-sm w-100"),
|
||||||
|
html.Button([
|
||||||
|
html.I(className="fas fa-save me-1"),
|
||||||
|
"Store All Models"
|
||||||
|
], id="store-models-btn", className="btn btn-info btn-sm w-100 mt-2"),
|
||||||
html.Hr(className="my-2"),
|
html.Hr(className="my-2"),
|
||||||
html.Small("System Status", className="text-muted d-block mb-1"),
|
html.Small("System Status", className="text-muted d-block mb-1"),
|
||||||
html.Div([
|
html.Div([
|
||||||
|
Reference in New Issue
Block a user