#!/usr/bin/env python3
"""
Training integration for checkpoint management.

Wraps the shared checkpoint manager so training loops can persist CNN and RL
checkpoints, log metrics to Weights & Biases when available, and reload the
best-performing model.
"""

import logging
from typing import Optional

import torch

from .checkpoint_manager import get_checkpoint_manager

logger = logging.getLogger(__name__)
class TrainingIntegration:
    """Bridges training loops with the checkpoint manager and W&B logging."""

    def __init__(self, enable_wandb: bool = True):
        self.checkpoint_manager = get_checkpoint_manager()
        # Store the flag; it is checked before every W&B logging attempt.
        self.enable_wandb = enable_wandb

    def save_cnn_checkpoint(self,
                            cnn_model,
                            model_name: str,
                            epoch: int,
                            train_accuracy: float,
                            val_accuracy: float,
                            train_loss: float,
                            val_loss: float,
                            training_time_hours: Optional[float] = None) -> bool:
        """Save a CNN checkpoint; returns True if the checkpoint was persisted."""
        try:
            performance_metrics = {
                'accuracy': train_accuracy,
                'val_accuracy': val_accuracy,
                'loss': train_loss,
                'val_loss': val_loss
            }

            training_metadata = {
                'epoch': epoch,
                'training_time_hours': training_time_hours,
                'total_parameters': self._count_parameters(cnn_model)
            }

            # Log metrics to W&B only when enabled and an active run exists.
            if self.enable_wandb:
                try:
                    import wandb
                    if wandb.run is not None:
                        wandb.log({
                            f"{model_name}/train_accuracy": train_accuracy,
                            f"{model_name}/val_accuracy": val_accuracy,
                            f"{model_name}/train_loss": train_loss,
                            f"{model_name}/val_loss": val_loss,
                            f"{model_name}/epoch": epoch
                        })
                except Exception as e:
                    logger.warning(f"Error logging to W&B: {e}")

            metadata = self.checkpoint_manager.save_checkpoint(
                model=cnn_model,
                model_name=model_name,
                model_type='cnn',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )

            if metadata:
                logger.info(f"CNN checkpoint saved: {metadata.checkpoint_id}")
                return True
            else:
                logger.info("CNN checkpoint not saved (performance not improved)")
                return False

        except Exception as e:
            logger.error(f"Error saving CNN checkpoint: {e}")
            return False
    def save_rl_checkpoint(self,
                           rl_agent,
                           model_name: str,
                           episode: int,
                           avg_reward: float,
                           best_reward: float,
                           epsilon: float,
                           total_pnl: Optional[float] = None) -> bool:
        """Save an RL agent checkpoint; returns True if the checkpoint was persisted."""
        try:
            performance_metrics = {
                'reward': avg_reward,
                'best_reward': best_reward
            }

            if total_pnl is not None:
                performance_metrics['pnl'] = total_pnl

            training_metadata = {
                'episode': episode,
                'epsilon': epsilon,
                'total_parameters': self._count_parameters(rl_agent)
            }

            # Log metrics to W&B only when enabled and an active run exists.
            if self.enable_wandb:
                try:
                    import wandb
                    if wandb.run is not None:
                        wandb.log({
                            f"{model_name}/avg_reward": avg_reward,
                            f"{model_name}/best_reward": best_reward,
                            f"{model_name}/epsilon": epsilon,
                            f"{model_name}/episode": episode
                        })

                        if total_pnl is not None:
                            wandb.log({f"{model_name}/total_pnl": total_pnl})

                except Exception as e:
                    logger.warning(f"Error logging to W&B: {e}")

            metadata = self.checkpoint_manager.save_checkpoint(
                model=rl_agent,
                model_name=model_name,
                model_type='rl',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )

            if metadata:
                logger.info(f"RL checkpoint saved: {metadata.checkpoint_id}")
                return True
            else:
                logger.info("RL checkpoint not saved (performance not improved)")
                return False

        except Exception as e:
            logger.error(f"Error saving RL checkpoint: {e}")
            return False
    def load_best_model(self, model_name: str, model_class=None):
        """Load the best checkpoint for `model_name`.

        Returns an instantiated model when `model_class` is provided and the
        checkpoint contains a state dict, otherwise the raw checkpoint object;
        returns None if no checkpoint exists or loading fails.
        """
        try:
            result = self.checkpoint_manager.load_best_checkpoint(model_name)
            if not result:
                logger.warning(f"No checkpoint found for model: {model_name}")
                return None

            file_path, metadata = result

            checkpoint = torch.load(file_path, map_location='cpu')

            logger.info(f"Loaded best checkpoint for {model_name}:")
            logger.info(f"  Performance score: {metadata.performance_score:.4f}")
            logger.info(f"  Created: {metadata.created_at}")

            if model_class and 'model_state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
                return model

            return checkpoint

        except Exception as e:
            logger.error(f"Error loading best model {model_name}: {e}")
            return None

    def _count_parameters(self, model) -> int:
        """Count parameters for plain modules or agents exposing policy/target networks."""
        try:
            if hasattr(model, 'parameters'):
                return sum(p.numel() for p in model.parameters())
            elif hasattr(model, 'policy_net'):
                policy_params = sum(p.numel() for p in model.policy_net.parameters())
                target_params = sum(p.numel() for p in model.target_net.parameters()) if hasattr(model, 'target_net') else 0
                return policy_params + target_params
            else:
                return 0
        except Exception:
            return 0
_training_integration = None


def get_training_integration() -> TrainingIntegration:
    """Return the process-wide TrainingIntegration instance, creating it lazily."""
    global _training_integration
    if _training_integration is None:
        _training_integration = TrainingIntegration()
    return _training_integration
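
# ---------------------------------------------------------------------------
# Illustrative usage (sketch only): shows how a training loop might call this
# module. `my_cnn` and `MyCNN` are hypothetical caller-defined torch.nn.Module
# objects, and a configured checkpoint_manager backend is assumed; neither is
# provided here.
#
#   integration = get_training_integration()
#   saved = integration.save_cnn_checkpoint(
#       cnn_model=my_cnn,
#       model_name="price_cnn",
#       epoch=10,
#       train_accuracy=0.91,
#       val_accuracy=0.88,
#       train_loss=0.24,
#       val_loss=0.31,
#       training_time_hours=1.5,
#   )
#   best = integration.load_best_model("price_cnn", model_class=MyCNN)
# ---------------------------------------------------------------------------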