#!/usr/bin/env python3
"""
Training Integration for Checkpoint Management

Thin wrapper that ties model training loops to the checkpoint manager and,
optionally, to Weights & Biases (W&B) experiment tracking.
"""

import logging
from datetime import datetime
from typing import Optional

import torch

from .checkpoint_manager import (
    get_checkpoint_manager,
    save_checkpoint,
    load_best_checkpoint,
)

logger = logging.getLogger(__name__)


class TrainingIntegration:
    """Bridges training loops with checkpoint management and W&B logging."""

    def __init__(self, enable_wandb: bool = True):
        self.checkpoint_manager = get_checkpoint_manager()
        self.enable_wandb = enable_wandb

        if self.enable_wandb:
            self._init_wandb()

    def _init_wandb(self):
        """Start a W&B run if the library is installed and no run is active."""
        try:
            import wandb

            if wandb.run is None:
                wandb.init(
                    project="gogo2-trading",
                    name=f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                    config={
                        "max_checkpoints_per_model": self.checkpoint_manager.max_checkpoints,
                        "checkpoint_dir": str(self.checkpoint_manager.base_dir),
                    },
                )
                logger.info(f"Initialized W&B run: {wandb.run.id}")
        except ImportError:
            logger.warning("W&B not available - checkpoint management will work without it")
        except Exception as e:
            logger.error(f"Error initializing W&B: {e}")

    def save_cnn_checkpoint(self, cnn_model, model_name: str, epoch: int,
                            train_accuracy: float, val_accuracy: float,
                            train_loss: float, val_loss: float,
                            training_time_hours: Optional[float] = None) -> bool:
        """Save a CNN checkpoint; returns True only if the checkpoint manager
        judged the performance good enough to keep."""
        try:
            performance_metrics = {
                'accuracy': train_accuracy,
                'val_accuracy': val_accuracy,
                'loss': train_loss,
                'val_loss': val_loss,
            }

            training_metadata = {
                'epoch': epoch,
                'training_time_hours': training_time_hours,
                'total_parameters': self._count_parameters(cnn_model),
            }

            # Mirror the metrics to W&B when a run is active; logging failures
            # must never block checkpointing.
            if self.enable_wandb:
                try:
                    import wandb

                    if wandb.run is not None:
                        wandb.log({
                            f"{model_name}/train_accuracy": train_accuracy,
                            f"{model_name}/val_accuracy": val_accuracy,
                            f"{model_name}/train_loss": train_loss,
                            f"{model_name}/val_loss": val_loss,
                            f"{model_name}/epoch": epoch,
                        })
                except Exception as e:
                    logger.warning(f"Error logging to W&B: {e}")

            metadata = save_checkpoint(
                model=cnn_model,
                model_name=model_name,
                model_type='cnn',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata,
            )

            if metadata:
                logger.info(f"CNN checkpoint saved: {metadata.checkpoint_id}")
                return True
            else:
                logger.info("CNN checkpoint not saved (performance not improved)")
                return False

        except Exception as e:
            logger.error(f"Error saving CNN checkpoint: {e}")
            return False

    def save_rl_checkpoint(self, rl_agent, model_name: str, episode: int,
                           avg_reward: float, best_reward: float,
                           epsilon: float, total_pnl: Optional[float] = None) -> bool:
        """Save an RL agent checkpoint; returns True only if the checkpoint
        manager judged the performance good enough to keep."""
        try:
            performance_metrics = {
                'reward': avg_reward,
                'best_reward': best_reward,
            }
            if total_pnl is not None:
                performance_metrics['pnl'] = total_pnl

            training_metadata = {
                'episode': episode,
                'epsilon': epsilon,
                'total_parameters': self._count_parameters(rl_agent),
            }

            if self.enable_wandb:
                try:
                    import wandb

                    if wandb.run is not None:
                        wandb.log({
                            f"{model_name}/avg_reward": avg_reward,
                            f"{model_name}/best_reward": best_reward,
                            f"{model_name}/epsilon": epsilon,
                            f"{model_name}/episode": episode,
                        })
                        if total_pnl is not None:
                            wandb.log({f"{model_name}/total_pnl": total_pnl})
                except Exception as e:
                    logger.warning(f"Error logging to W&B: {e}")

            metadata = save_checkpoint(
                model=rl_agent,
                model_name=model_name,
                model_type='rl',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata,
            )

            if metadata:
                logger.info(f"RL checkpoint saved: {metadata.checkpoint_id}")
                return True
            else:
                logger.info("RL checkpoint not saved (performance not improved)")
                return False

        except Exception as e:
            logger.error(f"Error saving RL checkpoint: {e}")
            return False

    def load_best_model(self, model_name: str,
                        model_class=None):
        """Load the best checkpoint for `model_name`.

        If `model_class` is given and the checkpoint stores a state dict, a
        fresh instance is built and populated; otherwise the raw checkpoint
        dict is returned.
        """
        try:
            result = load_best_checkpoint(model_name)
            if not result:
                logger.warning(f"No checkpoint found for model: {model_name}")
                return None

            file_path, metadata = result
            checkpoint = torch.load(file_path, map_location='cpu')

            logger.info(f"Loaded best checkpoint for {model_name}:")
            logger.info(f"  Performance score: {metadata.performance_score:.4f}")
            logger.info(f"  Created: {metadata.created_at}")

            if model_class and 'model_state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
                return model

            return checkpoint

        except Exception as e:
            logger.error(f"Error loading best model {model_name}: {e}")
            return None

    def _count_parameters(self, model) -> int:
        """Count parameters, handling both plain nn.Modules and DQN-style
        agents that expose policy/target networks instead."""
        try:
            if hasattr(model, 'parameters'):
                return sum(p.numel() for p in model.parameters())
            elif hasattr(model, 'policy_net'):
                policy_params = sum(p.numel() for p in model.policy_net.parameters())
                target_params = (sum(p.numel() for p in model.target_net.parameters())
                                 if hasattr(model, 'target_net') else 0)
                return policy_params + target_params
            else:
                return 0
        except Exception:
            return 0


# Module-level singleton so all training code shares one integration instance.
_training_integration = None


def get_training_integration() -> TrainingIntegration:
    """Return the shared TrainingIntegration, creating it on first use."""
    global _training_integration
    if _training_integration is None:
        _training_integration = TrainingIntegration()
    return _training_integration
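
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the tiny nn.Linear stand-in model, the
# "example_cnn" name, and the metric values below are placeholders, not part
# of the real training pipeline):
#
#     integration = get_training_integration()
#     model = torch.nn.Linear(8, 2)  # stand-in for a real CNN
#     integration.save_cnn_checkpoint(
#         cnn_model=model, model_name="example_cnn", epoch=1,
#         train_accuracy=0.61, val_accuracy=0.58,
#         train_loss=0.97, val_loss=1.02,
#     )
#     best = integration.load_best_model("example_cnn")
# ---------------------------------------------------------------------------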