cp man
This commit is contained in:
@ -220,6 +220,11 @@ class TradingOrchestrator:
|
||||
self.data_provider.start_centralized_data_collection()
|
||||
logger.info("Centralized data collection started - all models and dashboard will receive data")
|
||||
|
||||
# CRITICAL: Initialize checkpoint manager for saving training progress
|
||||
self.checkpoint_manager = None
|
||||
self.training_iterations = 0 # Track training iterations for periodic saves
|
||||
self._initialize_checkpoint_manager()
|
||||
|
||||
# Initialize models, COB integration, and training system
|
||||
self._initialize_ml_models()
|
||||
self._initialize_cob_integration()
|
||||
@ -2145,6 +2150,9 @@ class TradingOrchestrator:
|
||||
if not market_data:
|
||||
return
|
||||
|
||||
# Track if any model was trained for checkpoint saving
|
||||
models_trained = []
|
||||
|
||||
# Train DQN agent if available
|
||||
if self.rl_agent and hasattr(self.rl_agent, 'add_experience'):
|
||||
try:
|
||||
@ -2167,6 +2175,7 @@ class TradingOrchestrator:
|
||||
done=False
|
||||
)
|
||||
|
||||
models_trained.append('dqn')
|
||||
logger.debug(f"🧠 Added DQN experience: {action} {symbol} (reward: {immediate_reward:.3f})")
|
||||
|
||||
except Exception as e:
|
||||
@ -2185,6 +2194,7 @@ class TradingOrchestrator:
|
||||
# Add training sample
|
||||
self.cnn_model.add_training_sample(cnn_features, target, weight=confidence)
|
||||
|
||||
models_trained.append('cnn')
|
||||
logger.debug(f"🔍 Added CNN training sample: {action} {symbol}")
|
||||
|
||||
except Exception as e:
|
||||
@ -2206,14 +2216,105 @@ class TradingOrchestrator:
|
||||
symbol=symbol
|
||||
)
|
||||
|
||||
models_trained.append('cob_rl')
|
||||
logger.debug(f"📊 Added COB RL experience: {action} {symbol}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error training COB RL on decision: {e}")
|
||||
|
||||
# CRITICAL FIX: Save checkpoints after training
|
||||
if models_trained:
|
||||
self._save_training_checkpoints(models_trained, confidence)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training models on decision: {e}")
|
||||
|
||||
def _save_training_checkpoints(self, models_trained: List[str], performance_score: float):
|
||||
"""Save checkpoints for trained models if performance improved
|
||||
|
||||
This is CRITICAL for preserving training progress across restarts.
|
||||
"""
|
||||
try:
|
||||
if not self.checkpoint_manager:
|
||||
return
|
||||
|
||||
# Increment training counter
|
||||
self.training_iterations += 1
|
||||
|
||||
# Save checkpoints for each trained model
|
||||
for model_name in models_trained:
|
||||
try:
|
||||
model_obj = None
|
||||
current_loss = None
|
||||
|
||||
# Get model object and calculate current performance
|
||||
if model_name == 'dqn' and self.rl_agent:
|
||||
model_obj = self.rl_agent
|
||||
# Use negative performance score as loss (higher confidence = lower loss)
|
||||
current_loss = 1.0 - performance_score
|
||||
|
||||
elif model_name == 'cnn' and self.cnn_model:
|
||||
model_obj = self.cnn_model
|
||||
current_loss = 1.0 - performance_score
|
||||
|
||||
elif model_name == 'cob_rl' and self.cob_rl_agent:
|
||||
model_obj = self.cob_rl_agent
|
||||
current_loss = 1.0 - performance_score
|
||||
|
||||
if model_obj and current_loss is not None:
|
||||
# Check if this is the best performance so far
|
||||
model_state = self.model_states.get(model_name, {})
|
||||
best_loss = model_state.get('best_loss', float('inf'))
|
||||
|
||||
# Update current loss
|
||||
model_state['current_loss'] = current_loss
|
||||
model_state['last_training'] = datetime.now()
|
||||
|
||||
# Save checkpoint if performance improved or periodic save
|
||||
should_save = (
|
||||
current_loss < best_loss or # Performance improved
|
||||
self.training_iterations % 100 == 0 # Periodic save every 100 iterations
|
||||
)
|
||||
|
||||
if should_save:
|
||||
# Prepare metadata
|
||||
metadata = {
|
||||
'loss': current_loss,
|
||||
'performance_score': performance_score,
|
||||
'training_iterations': self.training_iterations,
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'model_type': model_name
|
||||
}
|
||||
|
||||
# Save checkpoint
|
||||
checkpoint_path = self.checkpoint_manager.save_checkpoint(
|
||||
model=model_obj,
|
||||
model_name=model_name,
|
||||
performance=current_loss,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
if checkpoint_path:
|
||||
# Update best performance
|
||||
if current_loss < best_loss:
|
||||
model_state['best_loss'] = current_loss
|
||||
model_state['best_checkpoint'] = checkpoint_path
|
||||
logger.info(f"💾 Saved BEST checkpoint for {model_name}: {checkpoint_path} (loss: {current_loss:.4f})")
|
||||
else:
|
||||
logger.debug(f"💾 Saved periodic checkpoint for {model_name}: {checkpoint_path}")
|
||||
|
||||
model_state['last_checkpoint'] = checkpoint_path
|
||||
model_state['checkpoints_saved'] = model_state.get('checkpoints_saved', 0) + 1
|
||||
|
||||
# Update model state
|
||||
self.model_states[model_name] = model_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint for {model_name}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving training checkpoints: {e}")
|
||||
|
||||
def _get_current_market_data(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get current market data for training context"""
|
||||
try:
|
||||
|
Reference in New Issue
Block a user