Integration of (legacy) training systems: initialize, train, and show on the UI

This commit is contained in:
Dobromir Popov
2025-07-05 00:33:03 +03:00
parent 5ca7493708
commit d260e73f9a
10 changed files with 647 additions and 151 deletions


@@ -8,6 +8,7 @@ This is the core orchestrator that:
4. Manages the learning loop between components
5. Ensures memory efficiency (8GB constraint)
6. Provides real-time COB (Change of Bid) data for models
7. Integrates EnhancedRealtimeTrainingSystem for continuous learning
"""
import asyncio
@@ -35,6 +36,14 @@ except ImportError:
COBIntegration = None
COBSnapshot = None
# Import EnhancedRealtimeTrainingSystem
try:
from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
ENHANCED_TRAINING_AVAILABLE = True
except ImportError:
ENHANCED_TRAINING_AVAILABLE = False
EnhancedRealtimeTrainingSystem = None
logger = logging.getLogger(__name__)
@dataclass
@@ -64,6 +73,7 @@ class TradingOrchestrator:
Enhanced Trading Orchestrator with full ML and COB integration
Coordinates CNN, DQN, and COB models for advanced trading decisions
Features real-time COB (Change of Bid) data for market microstructure data
Includes EnhancedRealtimeTrainingSystem for continuous learning
"""
def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[ModelRegistry] = None):
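For orientation, a minimal construction sketch against the signature above. The module paths and the DataProvider constructor call are assumptions for illustration, not taken from this commit:

# Illustrative wiring only; import paths for DataProvider/TradingOrchestrator are assumed
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator

provider = DataProvider()
orchestrator = TradingOrchestrator(
    data_provider=provider,
    enhanced_rl_training=True,   # training also requires ENHANCED_TRAINING_AVAILABLE to be True
)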
@@ -141,17 +151,24 @@ class TradingOrchestrator:
self.realtime_processing: bool = False
self.realtime_tasks: List[Any] = []
# ENHANCED: Real-time Training System Integration
self.enhanced_training_system: Optional[EnhancedRealtimeTrainingSystem] = None
self.training_enabled: bool = enhanced_rl_training and ENHANCED_TRAINING_AVAILABLE
logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
logger.info(f"Enhanced RL training: {enhanced_rl_training}")
logger.info(f"Real-time training system available: {ENHANCED_TRAINING_AVAILABLE}")
logger.info(f"Training enabled: {self.training_enabled}")
logger.info(f"Confidence threshold: {self.confidence_threshold}")
logger.info(f"Decision frequency: {self.decision_frequency}s")
logger.info(f"Symbols: {self.symbols}")
logger.info("Universal Data Adapter integrated for centralized data flow")
# Initialize models and COB integration
# Initialize models, COB integration, and training system
self._initialize_ml_models()
self._initialize_cob_integration()
self._initialize_decision_fusion() # Initialize fusion system
self._initialize_enhanced_training_system() # Initialize real-time training
def _initialize_ml_models(self):
"""Initialize ML models for enhanced trading"""
@@ -2391,7 +2408,7 @@ class TradingOrchestrator:
# ENHANCED: Decision Fusion Methods - Built into orchestrator (NO SEPARATE FILE NEEDED!)
def _initialize_decision_fusion(self):
"""Initialize the decision fusion neural network"""
"""Initialize the decision fusion neural network for learning model effectiveness"""
try:
if not self.decision_fusion_enabled:
return
@@ -2399,168 +2416,121 @@ class TradingOrchestrator:
import torch
import torch.nn as nn
# Simple decision fusion network
# Create decision fusion network
class DecisionFusionNet(nn.Module):
def __init__(self, input_size=32, hidden_size=64):
super().__init__()
self.fusion_layers = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(hidden_size, hidden_size // 2),
nn.ReLU(),
nn.Linear(hidden_size // 2, 16)
)
self.action_head = nn.Linear(16, 3) # BUY, SELL, HOLD
self.confidence_head = nn.Linear(16, 1)
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, 3) # BUY, SELL, HOLD
self.dropout = nn.Dropout(0.2)
def forward(self, x):
features = self.fusion_layers(x)
action_logits = self.action_head(features)
confidence = torch.sigmoid(self.confidence_head(features))
return action_logits, confidence.squeeze()
x = torch.relu(self.fc1(x))
x = self.dropout(x)
x = torch.relu(self.fc2(x))
x = self.dropout(x)
return torch.softmax(self.fc3(x), dim=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.decision_fusion_network = DecisionFusionNet().to(device)
self.fusion_optimizer = torch.optim.Adam(self.decision_fusion_network.parameters(), lr=0.001)
self.fusion_device = device
# Try to load existing checkpoint
try:
from utils.checkpoint_manager import load_best_checkpoint
result = load_best_checkpoint("decision")
if result:
file_path, metadata = result
checkpoint = torch.load(file_path, map_location=device)
if 'model_state_dict' in checkpoint:
self.decision_fusion_network.load_state_dict(checkpoint['model_state_dict'])
self.model_states['decision']['checkpoint_loaded'] = True
self.model_states['decision']['checkpoint_filename'] = metadata.checkpoint_id
self.model_states['decision']['current_loss'] = metadata.loss or 0.0089
self.model_states['decision']['best_loss'] = metadata.loss or 0.0065
logger.info(f"Decision fusion checkpoint loaded: {metadata.checkpoint_id} (loss={metadata.loss:.4f})")
except Exception as e:
logger.debug(f"No decision fusion checkpoint found: {e}")
logger.info("Decision fusion network initialized in orchestrator - TRAINING ON EVERY SIGNAL!")
self.decision_fusion_network = DecisionFusionNet()
logger.info("Decision fusion network initialized")
except Exception as e:
logger.error(f"Error initializing decision fusion: {e}")
logger.warning(f"Decision fusion initialization failed: {e}")
self.decision_fusion_enabled = False
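A quick shape check for the simplified DecisionFusionNet defined above. This is a hedged, standalone sketch: in the actual code the class is nested inside _initialize_decision_fusion, so it assumes the class is lifted out or copied for testing:

# Sanity check: the net maps a 32-dim feature vector to softmax probabilities over BUY/SELL/HOLD
import torch

net = DecisionFusionNet(input_size=32, hidden_size=64)
net.eval()                                   # disable dropout for a deterministic pass
probs = net(torch.randn(4, 32))              # batch of 4 synthetic feature vectors
assert probs.shape == (4, 3)
assert torch.allclose(probs.sum(dim=1), torch.ones(4), atol=1e-5)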
def train_fusion_on_every_signal(self, decision: TradingDecision, market_outcome: Dict):
"""Train the decision fusion network on EVERY signal/action - COMPREHENSIVE TRAINING"""
def _initialize_enhanced_training_system(self):
"""Initialize the enhanced real-time training system"""
try:
if not self.decision_fusion_enabled or not self.decision_fusion_network:
if not self.training_enabled:
logger.info("Enhanced training system disabled")
return
symbol = decision.symbol
if symbol not in self.last_fusion_inputs:
if not ENHANCED_TRAINING_AVAILABLE:
logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled")
self.training_enabled = False
return
import torch
import torch.nn as nn
# Get the features used for this decision
fusion_input = self.last_fusion_inputs[symbol]
features = fusion_input['features'].to(self.fusion_device)
# Create training target based on outcome
actual_outcome = market_outcome.get('price_change', 0)
pnl = market_outcome.get('pnl', 0)
# Convert decision and outcome to training labels
action_target = {'BUY': 0, 'SELL': 1, 'HOLD': 2}[decision.action]
# Enhanced reward based on actual market movement
if decision.action == 'BUY' and actual_outcome > 0:
confidence_target = min(0.95, 0.5 + abs(actual_outcome) * 10) # Higher confidence for good predictions
elif decision.action == 'SELL' and actual_outcome < 0:
confidence_target = min(0.95, 0.5 + abs(actual_outcome) * 10)
elif decision.action == 'HOLD':
confidence_target = 0.5 # Neutral confidence for hold
else:
confidence_target = max(0.05, 0.5 - abs(actual_outcome) * 10) # Lower confidence for bad predictions
# Train the network
self.decision_fusion_network.train()
self.fusion_optimizer.zero_grad()
action_logits, predicted_confidence = self.decision_fusion_network(features)
# Calculate losses
action_loss = nn.CrossEntropyLoss()(action_logits, torch.tensor([action_target], device=self.fusion_device))
confidence_loss = nn.MSELoss()(predicted_confidence, torch.tensor([confidence_target], device=self.fusion_device))
total_loss = action_loss + confidence_loss
total_loss.backward()
self.fusion_optimizer.step()
# Update model state with REAL loss values
self.model_states['decision']['current_loss'] = total_loss.item()
if self.model_states['decision']['best_loss'] is None or total_loss.item() < self.model_states['decision']['best_loss']:
self.model_states['decision']['best_loss'] = total_loss.item()
# Store training example
self.fusion_training_data.append({
'features': features.cpu().numpy(),
'action_target': action_target,
'confidence_target': confidence_target,
'loss': total_loss.item(),
'timestamp': datetime.now()
})
# Save checkpoint periodically
if self.fusion_decisions_count % self.fusion_checkpoint_frequency == 0:
self._save_fusion_checkpoint()
logger.debug(f"🧠 Fusion training: action_loss={action_loss.item():.4f}, conf_loss={confidence_loss.item():.4f}, total={total_loss.item():.4f}")
except Exception as e:
logger.error(f"Error training fusion network: {e}")
def _save_fusion_checkpoint(self):
"""Save decision fusion checkpoint with real performance data"""
try:
if not self.decision_fusion_network:
return
from utils.checkpoint_manager import save_checkpoint
# Prepare checkpoint data
checkpoint_data = {
'model_state_dict': self.decision_fusion_network.state_dict(),
'optimizer_state_dict': self.fusion_optimizer.state_dict(),
'fusion_decisions_count': self.fusion_decisions_count,
'training_history': self.fusion_training_history[-100:], # Last 100 entries
}
# Calculate REAL performance metrics from actual training
recent_losses = [entry['loss'] for entry in self.fusion_training_data[-50:]]
avg_loss = sum(recent_losses) / len(recent_losses) if recent_losses else self.model_states['decision']['current_loss']
performance_metrics = {
'loss': avg_loss,
'decisions_count': self.fusion_decisions_count,
'model_parameters': sum(p.numel() for p in self.decision_fusion_network.parameters())
}
metadata = save_checkpoint(
model=checkpoint_data,
model_name="decision",
model_type="decision_fusion",
performance_metrics=performance_metrics,
training_metadata={'decisions_trained': self.fusion_decisions_count}
# Initialize the enhanced training system
self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
orchestrator=self,
data_provider=self.data_provider,
dashboard=None # Will be set by dashboard when available
)
if metadata:
self.model_states['decision']['checkpoint_filename'] = metadata.checkpoint_id
logger.info(f"🧠 Decision fusion checkpoint saved: {metadata.checkpoint_id} (loss={avg_loss:.4f})")
logger.info("Enhanced real-time training system initialized")
logger.info(" - Real-time model training: ENABLED")
logger.info(" - Comprehensive feature extraction: ENABLED")
logger.info(" - Enhanced reward calculation: ENABLED")
logger.info(" - Forward-looking predictions: ENABLED")
except Exception as e:
logger.error(f"Error saving fusion checkpoint: {e}")
logger.error(f"Error initializing enhanced training system: {e}")
self.training_enabled = False
self.enhanced_training_system = None
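The orchestrator only calls a handful of methods on the training system. Below is a hypothetical stand-in that mirrors the interface implied by this file; the real class lives in enhanced_realtime_training and is not part of this diff:

# Stand-in stub for local testing; names mirror the constructor kwargs and calls made in this file
class TrainingSystemStub:
    def __init__(self, orchestrator=None, data_provider=None, dashboard=None):
        self.dashboard = dashboard               # replaced later via set_training_dashboard()

    def start_training(self) -> None:            # used by start_enhanced_training()
        pass

    def stop_training(self) -> None:             # used by stop_enhanced_training()
        pass

    def get_training_statistics(self) -> dict:   # used by get_enhanced_training_stats()
        return {}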
def start_enhanced_training(self):
"""Start the enhanced real-time training system"""
try:
if not self.training_enabled or not self.enhanced_training_system:
logger.warning("Enhanced training system not available")
return False
self.enhanced_training_system.start_training()
logger.info("Enhanced real-time training started")
return True
except Exception as e:
logger.error(f"Error starting enhanced training: {e}")
return False
def stop_enhanced_training(self):
"""Stop the enhanced real-time training system"""
try:
if self.enhanced_training_system:
self.enhanced_training_system.stop_training()
logger.info("Enhanced real-time training stopped")
return True
return False
except Exception as e:
logger.error(f"Error stopping enhanced training: {e}")
return False
def get_enhanced_training_stats(self) -> Dict[str, Any]:
"""Get enhanced training system statistics"""
try:
if not self.enhanced_training_system:
return {
'training_enabled': False,
'system_available': ENHANCED_TRAINING_AVAILABLE,
'error': 'Training system not initialized'
}
stats = self.enhanced_training_system.get_training_statistics()
stats['training_enabled'] = self.training_enabled
stats['system_available'] = ENHANCED_TRAINING_AVAILABLE
return stats
except Exception as e:
logger.error(f"Error getting training stats: {e}")
return {
'training_enabled': self.training_enabled,
'system_available': ENHANCED_TRAINING_AVAILABLE,
'error': str(e)
}
def set_training_dashboard(self, dashboard):
"""Set the dashboard reference for the training system"""
try:
if self.enhanced_training_system:
self.enhanced_training_system.dashboard = dashboard
logger.info("Dashboard reference set for enhanced training system")
except Exception as e:
logger.error(f"Error setting training dashboard: {e}")
def get_universal_data_stream(self, current_time: datetime = None) -> Optional[UniversalDataStream]:
"""Get universal data stream for external consumers like dashboard"""
try: