Dobromir Popov
2025-07-22 16:13:42 +03:00
parent c63dc11c14
commit cc0c783411
2 changed files with 101 additions and 124 deletions


@@ -220,6 +220,11 @@ class TradingOrchestrator:
        self.data_provider.start_centralized_data_collection()
        logger.info("Centralized data collection started - all models and dashboard will receive data")
        # CRITICAL: Initialize checkpoint manager for saving training progress
        self.checkpoint_manager = None
        self.training_iterations = 0  # Track training iterations for periodic saves
        self._initialize_checkpoint_manager()
        # Initialize models, COB integration, and training system
        self._initialize_ml_models()
        self._initialize_cob_integration()
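
Note: `_initialize_checkpoint_manager()` itself is not shown in this diff. A minimal sketch of what such an initializer could look like, assuming a `get_checkpoint_manager()` helper in a `utils.checkpoint_manager` module (both names are hypothetical):

```python
# Hypothetical sketch only - the real initializer is outside this diff.
def _initialize_checkpoint_manager(self):
    """Create the checkpoint manager used to persist training progress."""
    try:
        # Assumed helper/module; replace with the project's actual import.
        from utils.checkpoint_manager import get_checkpoint_manager
        self.checkpoint_manager = get_checkpoint_manager()
        logger.info("Checkpoint manager initialized - training progress will be persisted")
    except Exception as e:
        logger.warning(f"Checkpoint manager unavailable, training will not be persisted: {e}")
        self.checkpoint_manager = None
```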
@@ -2145,6 +2150,9 @@ class TradingOrchestrator:
            if not market_data:
                return
            # Track if any model was trained for checkpoint saving
            models_trained = []
            # Train DQN agent if available
            if self.rl_agent and hasattr(self.rl_agent, 'add_experience'):
                try:
@@ -2167,6 +2175,7 @@ class TradingOrchestrator:
                        done=False
                    )
                    models_trained.append('dqn')
                    logger.debug(f"🧠 Added DQN experience: {action} {symbol} (reward: {immediate_reward:.3f})")
                except Exception as e:
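
For context, only the tail of the DQN `add_experience(...)` call (`done=False`) is visible in this hunk. A sketch of what the full call likely looks like; the state variables and action encoding are assumptions, and only `immediate_reward` and `done=False` actually appear in the diff:

```python
# Sketch of the surrounding call; argument names other than done are assumed.
self.rl_agent.add_experience(
    state=pre_decision_state,        # assumed: features observed before the action
    action=action_idx,               # assumed: encoded BUY/SELL/HOLD decision
    reward=immediate_reward,         # referenced by the debug log in this hunk
    next_state=post_decision_state,  # assumed: features observed after the action
    done=False                       # shown in the hunk: the episode continues
)
models_trained.append('dqn')
```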
@@ -2185,6 +2194,7 @@ class TradingOrchestrator:
                    # Add training sample
                    self.cnn_model.add_training_sample(cnn_features, target, weight=confidence)
                    models_trained.append('cnn')
                    logger.debug(f"🔍 Added CNN training sample: {action} {symbol}")
                except Exception as e:
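
The CNN side of `add_training_sample` is not part of this commit. A plausible sketch, assuming the model keeps a simple weighted sample buffer (the buffer name, batch size, and `_train_on_buffer` helper are hypothetical):

```python
# Hypothetical buffering scheme; not taken from this repository.
def add_training_sample(self, features, target, weight=1.0):
    """Queue one weighted sample for the next CNN training pass."""
    self.training_buffer.append((features, target, weight))
    if len(self.training_buffer) >= self.batch_size:
        self._train_on_buffer()  # hypothetical helper that consumes the buffer
```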
@@ -2206,14 +2216,105 @@ class TradingOrchestrator:
                        symbol=symbol
                    )
                    models_trained.append('cob_rl')
                    logger.debug(f"📊 Added COB RL experience: {action} {symbol}")
                except Exception as e:
                    logger.debug(f"Error training COB RL on decision: {e}")
            # CRITICAL FIX: Save checkpoints after training
            if models_trained:
                self._save_training_checkpoints(models_trained, confidence)
        except Exception as e:
            logger.error(f"Error training models on decision: {e}")

    def _save_training_checkpoints(self, models_trained: List[str], performance_score: float):
        """Save checkpoints for trained models if performance improved.

        This is CRITICAL for preserving training progress across restarts.
        """
        try:
            if not self.checkpoint_manager:
                return
            # Increment training counter
            self.training_iterations += 1
            # Save checkpoints for each trained model
            for model_name in models_trained:
                try:
                    model_obj = None
                    current_loss = None
                    # Get model object and calculate current performance
                    if model_name == 'dqn' and self.rl_agent:
                        model_obj = self.rl_agent
                        # Use (1 - performance score) as a proxy loss (higher confidence = lower loss)
                        current_loss = 1.0 - performance_score
                    elif model_name == 'cnn' and self.cnn_model:
                        model_obj = self.cnn_model
                        current_loss = 1.0 - performance_score
                    elif model_name == 'cob_rl' and self.cob_rl_agent:
                        model_obj = self.cob_rl_agent
                        current_loss = 1.0 - performance_score
                    if model_obj and current_loss is not None:
                        # Check if this is the best performance so far
                        model_state = self.model_states.get(model_name, {})
                        best_loss = model_state.get('best_loss', float('inf'))
                        # Update current loss
                        model_state['current_loss'] = current_loss
                        model_state['last_training'] = datetime.now()
                        # Save checkpoint if performance improved or periodic save
                        should_save = (
                            current_loss < best_loss or  # Performance improved
                            self.training_iterations % 100 == 0  # Periodic save every 100 iterations
                        )
                        if should_save:
                            # Prepare metadata
                            metadata = {
                                'loss': current_loss,
                                'performance_score': performance_score,
                                'training_iterations': self.training_iterations,
                                'timestamp': datetime.now().isoformat(),
                                'model_type': model_name
                            }
                            # Save checkpoint
                            checkpoint_path = self.checkpoint_manager.save_checkpoint(
                                model=model_obj,
                                model_name=model_name,
                                performance=current_loss,
                                metadata=metadata
                            )
                            if checkpoint_path:
                                # Update best performance
                                if current_loss < best_loss:
                                    model_state['best_loss'] = current_loss
                                    model_state['best_checkpoint'] = checkpoint_path
                                    logger.info(f"💾 Saved BEST checkpoint for {model_name}: {checkpoint_path} (loss: {current_loss:.4f})")
                                else:
                                    logger.debug(f"💾 Saved periodic checkpoint for {model_name}: {checkpoint_path}")
                                model_state['last_checkpoint'] = checkpoint_path
                                model_state['checkpoints_saved'] = model_state.get('checkpoints_saved', 0) + 1
                        # Update model state
                        self.model_states[model_name] = model_state
                except Exception as e:
                    logger.error(f"Error saving checkpoint for {model_name}: {e}")
        except Exception as e:
            logger.error(f"Error saving training checkpoints: {e}")

    def _get_current_market_data(self, symbol: str) -> Optional[Dict]:
        """Get current market data for training context"""
        try: