Integration of (legacy) training systems: initialize, train, and show on the UI
@@ -8,6 +8,7 @@ This is the core orchestrator that:
4. Manages the learning loop between components
5. Ensures memory efficiency (8GB constraint)
6. Provides real-time COB (Change of Bid) data for models
7. Integrates EnhancedRealtimeTrainingSystem for continuous learning
"""

import asyncio
@@ -35,6 +36,14 @@ except ImportError:
    COBIntegration = None
    COBSnapshot = None

# Import EnhancedRealtimeTrainingSystem
try:
    from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
    ENHANCED_TRAINING_AVAILABLE = True
except ImportError:
    ENHANCED_TRAINING_AVAILABLE = False
    EnhancedRealtimeTrainingSystem = None

logger = logging.getLogger(__name__)

@dataclass
@@ -64,6 +73,7 @@ class TradingOrchestrator:
    Enhanced Trading Orchestrator with full ML and COB integration
    Coordinates CNN, DQN, and COB models for advanced trading decisions
    Features real-time COB (Change of Bid) data for market microstructure analysis
    Includes EnhancedRealtimeTrainingSystem for continuous learning
    """

    def __init__(self, data_provider: Optional[DataProvider] = None, enhanced_rl_training: bool = True, model_registry: Optional[ModelRegistry] = None):
@@ -141,17 +151,24 @@ class TradingOrchestrator:
        self.realtime_processing: bool = False
        self.realtime_tasks: List[Any] = []

        # ENHANCED: Real-time Training System Integration
        self.enhanced_training_system: Optional[EnhancedRealtimeTrainingSystem] = None
        self.training_enabled: bool = enhanced_rl_training and ENHANCED_TRAINING_AVAILABLE

        logger.info("Enhanced TradingOrchestrator initialized with full ML capabilities")
        logger.info(f"Enhanced RL training: {enhanced_rl_training}")
        logger.info(f"Real-time training system available: {ENHANCED_TRAINING_AVAILABLE}")
        logger.info(f"Training enabled: {self.training_enabled}")
        logger.info(f"Confidence threshold: {self.confidence_threshold}")
        logger.info(f"Decision frequency: {self.decision_frequency}s")
        logger.info(f"Symbols: {self.symbols}")
        logger.info("Universal Data Adapter integrated for centralized data flow")

        # Initialize models and COB integration
        # Initialize models, COB integration, and training system
        self._initialize_ml_models()
        self._initialize_cob_integration()
        self._initialize_decision_fusion()  # Initialize fusion system
        self._initialize_enhanced_training_system()  # Initialize real-time training

    def _initialize_ml_models(self):
        """Initialize ML models for enhanced trading"""
@@ -2391,7 +2408,7 @@ class TradingOrchestrator:

    # ENHANCED: Decision Fusion Methods - Built into orchestrator (NO SEPARATE FILE NEEDED!)
    def _initialize_decision_fusion(self):
        """Initialize the decision fusion neural network"""
        """Initialize the decision fusion neural network for learning model effectiveness"""
        try:
            if not self.decision_fusion_enabled:
                return
@@ -2399,168 +2416,121 @@
            import torch
            import torch.nn as nn

            # Simple decision fusion network
            # Create decision fusion network
            class DecisionFusionNet(nn.Module):
                def __init__(self, input_size=32, hidden_size=64):
                    super().__init__()
                    self.fusion_layers = nn.Sequential(
                        nn.Linear(input_size, hidden_size),
                        nn.ReLU(),
                        nn.Dropout(0.2),
                        nn.Linear(hidden_size, hidden_size // 2),
                        nn.ReLU(),
                        nn.Linear(hidden_size // 2, 16)
                    )
                    self.action_head = nn.Linear(16, 3)  # BUY, SELL, HOLD
                    self.confidence_head = nn.Linear(16, 1)
                    self.fc1 = nn.Linear(input_size, hidden_size)
                    self.fc2 = nn.Linear(hidden_size, hidden_size)
                    self.fc3 = nn.Linear(hidden_size, 3)  # BUY, SELL, HOLD
                    self.dropout = nn.Dropout(0.2)

                def forward(self, x):
                    features = self.fusion_layers(x)
                    action_logits = self.action_head(features)
                    confidence = torch.sigmoid(self.confidence_head(features))
                    return action_logits, confidence.squeeze()
                    x = torch.relu(self.fc1(x))
                    x = self.dropout(x)
                    x = torch.relu(self.fc2(x))
                    x = self.dropout(x)
                    return torch.softmax(self.fc3(x), dim=1)

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.decision_fusion_network = DecisionFusionNet().to(device)
            self.fusion_optimizer = torch.optim.Adam(self.decision_fusion_network.parameters(), lr=0.001)
            self.fusion_device = device

            # Try to load existing checkpoint
            try:
                from utils.checkpoint_manager import load_best_checkpoint
                result = load_best_checkpoint("decision")
                if result:
                    file_path, metadata = result
                    checkpoint = torch.load(file_path, map_location=device)
                    if 'model_state_dict' in checkpoint:
                        self.decision_fusion_network.load_state_dict(checkpoint['model_state_dict'])
                        self.model_states['decision']['checkpoint_loaded'] = True
                        self.model_states['decision']['checkpoint_filename'] = metadata.checkpoint_id
                        self.model_states['decision']['current_loss'] = metadata.loss or 0.0089
                        self.model_states['decision']['best_loss'] = metadata.loss or 0.0065
                        logger.info(f"Decision fusion checkpoint loaded: {metadata.checkpoint_id} (loss={metadata.loss:.4f})")

            except Exception as e:
                logger.debug(f"No decision fusion checkpoint found: {e}")

            logger.info("Decision fusion network initialized in orchestrator - TRAINING ON EVERY SIGNAL!")
            self.decision_fusion_network = DecisionFusionNet()
            logger.info("Decision fusion network initialized")

        except Exception as e:
            logger.error(f"Error initializing decision fusion: {e}")
            logger.warning(f"Decision fusion initialization failed: {e}")
            self.decision_fusion_enabled = False

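Because the removed two-head network and the added single-head network are interleaved in the hunk above, the following is a minimal, self-contained sketch of how the simplified fusion net reads after this change. The layer names and the 32-feature / 64-hidden defaults come straight from the diff; the demo at the bottom (random input, argmax to an action label) is purely illustrative.

# Sketch of the simplified DecisionFusionNet as it reads after this change.
# Layer names and defaults mirror the diff; the demo input is illustrative.
import torch
import torch.nn as nn

class DecisionFusionNet(nn.Module):
    def __init__(self, input_size: int = 32, hidden_size: int = 64):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 3)  # BUY, SELL, HOLD
        self.dropout = nn.Dropout(0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout(x)
        return torch.softmax(self.fc3(x), dim=1)  # per-row class probabilities

if __name__ == "__main__":
    net = DecisionFusionNet().eval()
    with torch.no_grad():
        probs = net(torch.randn(1, 32))  # one 32-feature decision vector
    print(["BUY", "SELL", "HOLD"][int(probs.argmax(dim=1))], probs.tolist())

Compared with the removed version, the separate confidence head is gone, so any confidence estimate presumably has to be read off the softmax output itself.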
    def train_fusion_on_every_signal(self, decision: TradingDecision, market_outcome: Dict):
        """Train the decision fusion network on EVERY signal/action - COMPREHENSIVE TRAINING"""

    def _initialize_enhanced_training_system(self):
        """Initialize the enhanced real-time training system"""
        try:
            if not self.decision_fusion_enabled or not self.decision_fusion_network:
            if not self.training_enabled:
                logger.info("Enhanced training system disabled")
                return

            symbol = decision.symbol
            if symbol not in self.last_fusion_inputs:
            if not ENHANCED_TRAINING_AVAILABLE:
                logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled")
                self.training_enabled = False
                return

            import torch
            import torch.nn as nn

            # Get the features used for this decision
            fusion_input = self.last_fusion_inputs[symbol]
            features = fusion_input['features'].to(self.fusion_device)

            # Create training target based on outcome
            actual_outcome = market_outcome.get('price_change', 0)
            pnl = market_outcome.get('pnl', 0)

            # Convert decision and outcome to training labels
            action_target = {'BUY': 0, 'SELL': 1, 'HOLD': 2}[decision.action]

            # Enhanced reward based on actual market movement
            if decision.action == 'BUY' and actual_outcome > 0:
                confidence_target = min(0.95, 0.5 + abs(actual_outcome) * 10)  # Higher confidence for good predictions
            elif decision.action == 'SELL' and actual_outcome < 0:
                confidence_target = min(0.95, 0.5 + abs(actual_outcome) * 10)
            elif decision.action == 'HOLD':
                confidence_target = 0.5  # Neutral confidence for hold
            else:
                confidence_target = max(0.05, 0.5 - abs(actual_outcome) * 10)  # Lower confidence for bad predictions

            # Train the network
            self.decision_fusion_network.train()
            self.fusion_optimizer.zero_grad()

            action_logits, predicted_confidence = self.decision_fusion_network(features)

            # Calculate losses
            action_loss = nn.CrossEntropyLoss()(action_logits, torch.tensor([action_target], device=self.fusion_device))
            confidence_loss = nn.MSELoss()(predicted_confidence, torch.tensor([confidence_target], device=self.fusion_device))

            total_loss = action_loss + confidence_loss
            total_loss.backward()
            self.fusion_optimizer.step()

            # Update model state with REAL loss values
            self.model_states['decision']['current_loss'] = total_loss.item()
            if self.model_states['decision']['best_loss'] is None or total_loss.item() < self.model_states['decision']['best_loss']:
                self.model_states['decision']['best_loss'] = total_loss.item()

            # Store training example
            self.fusion_training_data.append({
                'features': features.cpu().numpy(),
                'action_target': action_target,
                'confidence_target': confidence_target,
                'loss': total_loss.item(),
                'timestamp': datetime.now()
            })

            # Save checkpoint periodically
            if self.fusion_decisions_count % self.fusion_checkpoint_frequency == 0:
                self._save_fusion_checkpoint()

            logger.debug(f"🧠 Fusion training: action_loss={action_loss.item():.4f}, conf_loss={confidence_loss.item():.4f}, total={total_loss.item():.4f}")

        except Exception as e:
            logger.error(f"Error training fusion network: {e}")

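To make the confidence-target shaping in the (removed) train_fusion_on_every_signal above concrete, here is a short worked example; the numbers are illustrative only.

# Worked example of the confidence-target rule above (illustrative numbers).
actual_outcome = 0.012                                # a +1.2% move after the decision
good_buy = min(0.95, 0.5 + abs(actual_outcome) * 10)  # BUY before a rise  -> 0.62
bad_buy = max(0.05, 0.5 - abs(actual_outcome) * 10)   # BUY before a fall  -> 0.38
hold = 0.5                                            # HOLD stays neutral
# The target feeds an MSE loss against the predicted confidence and is summed
# with a cross-entropy loss on the action logits to form total_loss.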
    def _save_fusion_checkpoint(self):
        """Save decision fusion checkpoint with real performance data"""
        try:
            if not self.decision_fusion_network:
                return

            from utils.checkpoint_manager import save_checkpoint

            # Prepare checkpoint data
            checkpoint_data = {
                'model_state_dict': self.decision_fusion_network.state_dict(),
                'optimizer_state_dict': self.fusion_optimizer.state_dict(),
                'fusion_decisions_count': self.fusion_decisions_count,
                'training_history': self.fusion_training_history[-100:],  # Last 100 entries
            }

            # Calculate REAL performance metrics from actual training
            recent_losses = [entry['loss'] for entry in self.fusion_training_data[-50:]]
            avg_loss = sum(recent_losses) / len(recent_losses) if recent_losses else self.model_states['decision']['current_loss']

            performance_metrics = {
                'loss': avg_loss,
                'decisions_count': self.fusion_decisions_count,
                'model_parameters': sum(p.numel() for p in self.decision_fusion_network.parameters())
            }

            metadata = save_checkpoint(
                model=checkpoint_data,
                model_name="decision",
                model_type="decision_fusion",
                performance_metrics=performance_metrics,
                training_metadata={'decisions_trained': self.fusion_decisions_count}
            # Initialize the enhanced training system
            self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
                orchestrator=self,
                data_provider=self.data_provider,
                dashboard=None  # Will be set by dashboard when available
            )

            if metadata:
                self.model_states['decision']['checkpoint_filename'] = metadata.checkpoint_id
                logger.info(f"🧠 Decision fusion checkpoint saved: {metadata.checkpoint_id} (loss={avg_loss:.4f})")
            logger.info("Enhanced real-time training system initialized")
            logger.info(" - Real-time model training: ENABLED")
            logger.info(" - Comprehensive feature extraction: ENABLED")
            logger.info(" - Enhanced reward calculation: ENABLED")
            logger.info(" - Forward-looking predictions: ENABLED")

        except Exception as e:
            logger.error(f"Error saving fusion checkpoint: {e}")

            logger.error(f"Error initializing enhanced training system: {e}")
            self.training_enabled = False
            self.enhanced_training_system = None

    def start_enhanced_training(self):
        """Start the enhanced real-time training system"""
        try:
            if not self.training_enabled or not self.enhanced_training_system:
                logger.warning("Enhanced training system not available")
                return False

            self.enhanced_training_system.start_training()
            logger.info("Enhanced real-time training started")
            return True

        except Exception as e:
            logger.error(f"Error starting enhanced training: {e}")
            return False

    def stop_enhanced_training(self):
        """Stop the enhanced real-time training system"""
        try:
            if self.enhanced_training_system:
                self.enhanced_training_system.stop_training()
                logger.info("Enhanced real-time training stopped")
                return True
            return False

        except Exception as e:
            logger.error(f"Error stopping enhanced training: {e}")
            return False

    def get_enhanced_training_stats(self) -> Dict[str, Any]:
        """Get enhanced training system statistics"""
        try:
            if not self.enhanced_training_system:
                return {
                    'training_enabled': False,
                    'system_available': ENHANCED_TRAINING_AVAILABLE,
                    'error': 'Training system not initialized'
                }

            stats = self.enhanced_training_system.get_training_statistics()
            stats['training_enabled'] = self.training_enabled
            stats['system_available'] = ENHANCED_TRAINING_AVAILABLE

            return stats

        except Exception as e:
            logger.error(f"Error getting training stats: {e}")
            return {
                'training_enabled': self.training_enabled,
                'system_available': ENHANCED_TRAINING_AVAILABLE,
                'error': str(e)
            }

    def set_training_dashboard(self, dashboard):
        """Set the dashboard reference for the training system"""
        try:
            if self.enhanced_training_system:
                self.enhanced_training_system.dashboard = dashboard
                logger.info("Dashboard reference set for enhanced training system")

        except Exception as e:
            logger.error(f"Error setting training dashboard: {e}")

    def get_universal_data_stream(self, current_time: datetime = None) -> Optional[UniversalDataStream]:
        """Get universal data stream for external consumers like dashboard"""
        try:

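Taken together, the new methods give the UI a small control surface over training. A usage sketch follows; the construction arguments are abbreviated, and data_provider / dashboard stand in for whatever objects the caller already holds.

# Usage sketch of the training-control methods added above. data_provider and
# dashboard are placeholders for objects the caller already has.
orchestrator = TradingOrchestrator(data_provider=data_provider, enhanced_rl_training=True)
orchestrator.set_training_dashboard(dashboard)  # lets the trainer push updates to the UI

if orchestrator.start_enhanced_training():      # returns False when the system is unavailable
    stats = orchestrator.get_enhanced_training_stats()
    print(stats.get('training_enabled'), stats.get('system_available'))

# ... later, on shutdown
orchestrator.stop_enhanced_training()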
@@ -873,7 +873,7 @@ class RealtimeRLCOBTrader:
        # Penalize for large predicted changes that are wrong
        if predicted_direction != actual_direction and abs(predicted_change) > 0.001:
            reward -= abs(predicted_change) * 2.0


        # Add reward for PnL (realized or unrealized)
        reward += current_pnl * 0.1  # Small reward for PnL, adjusted by a factor

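Only two reward terms are visible in this hunk; a minimal sketch that isolates them follows. The function name and everything around the two terms are assumptions, not the trader's actual reward routine.

# Sketch of the two reward terms visible in this hunk; the surrounding reward
# logic is not shown, so the function below is illustrative only.
def shape_reward(reward: float, predicted_direction: int, actual_direction: int,
                 predicted_change: float, current_pnl: float) -> float:
    # Penalize sizeable predictions that point the wrong way
    if predicted_direction != actual_direction and abs(predicted_change) > 0.001:
        reward -= abs(predicted_change) * 2.0
    # Small bonus proportional to realized or unrealized PnL
    reward += current_pnl * 0.1
    return reward

# e.g. wrong direction on a 0.5% predicted move with +$2 of open PnL:
# shape_reward(0.0, 1, -1, 0.005, 2.0) == -0.01 + 0.2 == 0.19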
@@ -219,7 +219,7 @@ class TradingExecutor:
            quote_asset = 'USDC'
        else:
            # Fallback for symbols like ETHUSDT (assuming last 4 chars are quote)
            quote_asset = symbol[-4:].upper()
            quote_asset = symbol[-4:].upper()
            # Convert USDT to USDC for MEXC spot trading
            if quote_asset == 'USDT':
                quote_asset = 'USDC'
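The hunk only shows the fallback branch of the symbol parsing. Below is a small sketch of that path; the helper name is hypothetical and the branches above the fallback are omitted.

# Sketch of the fallback path shown above; resolve_quote_asset is a
# hypothetical name, and the earlier branches are omitted.
def resolve_quote_asset(symbol: str) -> str:
    # Fallback for symbols like ETHUSDT: assume the last 4 chars are the quote
    quote_asset = symbol[-4:].upper()
    # Convert USDT to USDC for MEXC spot trading
    if quote_asset == 'USDT':
        quote_asset = 'USDC'
    return quote_asset

# resolve_quote_asset("ETHUSDT") -> "USDC"
# resolve_quote_asset("ETHUSDC") -> "USDC"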
@@ -423,7 +423,7 @@ class TradingExecutor:
            # Calculate simulated fees in simulation mode
            taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
            simulated_fees = position.quantity * current_price * taker_fee_rate


            # Create trade record
            trade_record = TradeRecord(
                symbol=symbol,
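Returning to the simulated-fee line in this hunk, a quick worked example with illustrative numbers:

# Worked example of the simulated taker fee above (illustrative numbers).
taker_fee_rate = 0.0006   # default when trading_fees.taker_fee is not configured
quantity = 0.5            # position size
current_price = 3000.0    # fill price
simulated_fees = quantity * current_price * taker_fee_rate  # 0.5 * 3000 * 0.0006 = 0.9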