checkpoint manager
This commit is contained in:
306
utils/checkpoint_manager.py
Normal file
306
utils/checkpoint_manager.py
Normal file
@ -0,0 +1,306 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Checkpoint Management System for W&B Training
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
from collections import defaultdict
|
||||
import torch
|
||||
|
||||
try:
|
||||
import wandb
|
||||
WANDB_AVAILABLE = True
|
||||
except ImportError:
|
||||
WANDB_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class CheckpointMetadata:
|
||||
checkpoint_id: str
|
||||
model_name: str
|
||||
model_type: str
|
||||
file_path: str
|
||||
created_at: datetime
|
||||
file_size_mb: float
|
||||
performance_score: float
|
||||
accuracy: Optional[float] = None
|
||||
loss: Optional[float] = None
|
||||
val_accuracy: Optional[float] = None
|
||||
val_loss: Optional[float] = None
|
||||
reward: Optional[float] = None
|
||||
pnl: Optional[float] = None
|
||||
epoch: Optional[int] = None
|
||||
training_time_hours: Optional[float] = None
|
||||
total_parameters: Optional[int] = None
|
||||
wandb_run_id: Optional[str] = None
|
||||
wandb_artifact_name: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
data = asdict(self)
|
||||
data['created_at'] = self.created_at.isoformat()
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
|
||||
data['created_at'] = datetime.fromisoformat(data['created_at'])
|
||||
return cls(**data)
|
||||
|
||||
class CheckpointManager:
|
||||
def __init__(self,
|
||||
base_checkpoint_dir: str = "NN/models/saved",
|
||||
max_checkpoints_per_model: int = 5,
|
||||
metadata_file: str = "checkpoint_metadata.json",
|
||||
enable_wandb: bool = True):
|
||||
self.base_dir = Path(base_checkpoint_dir)
|
||||
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.max_checkpoints = max_checkpoints_per_model
|
||||
self.metadata_file = self.base_dir / metadata_file
|
||||
self.enable_wandb = enable_wandb and WANDB_AVAILABLE
|
||||
|
||||
self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
|
||||
self._load_metadata()
|
||||
|
||||
logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")
|
||||
|
||||
def save_checkpoint(self, model, model_name: str, model_type: str,
|
||||
performance_metrics: Dict[str, float],
|
||||
training_metadata: Optional[Dict[str, Any]] = None,
|
||||
force_save: bool = False) -> Optional[CheckpointMetadata]:
|
||||
try:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
checkpoint_id = f"{model_name}_{timestamp}"
|
||||
|
||||
model_dir = self.base_dir / model_name
|
||||
model_dir.mkdir(exist_ok=True)
|
||||
|
||||
checkpoint_path = model_dir / f"{checkpoint_id}.pt"
|
||||
|
||||
performance_score = self._calculate_performance_score(performance_metrics)
|
||||
|
||||
if not force_save and not self._should_save_checkpoint(model_name, performance_score):
|
||||
logger.info(f"Skipping checkpoint save for {model_name} - performance not improved")
|
||||
return None
|
||||
|
||||
success = self._save_model_file(model, checkpoint_path, model_type)
|
||||
if not success:
|
||||
return None
|
||||
|
||||
file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
metadata = CheckpointMetadata(
|
||||
checkpoint_id=checkpoint_id,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
file_path=str(checkpoint_path),
|
||||
created_at=datetime.now(),
|
||||
file_size_mb=file_size_mb,
|
||||
performance_score=performance_score,
|
||||
accuracy=performance_metrics.get('accuracy'),
|
||||
loss=performance_metrics.get('loss'),
|
||||
val_accuracy=performance_metrics.get('val_accuracy'),
|
||||
val_loss=performance_metrics.get('val_loss'),
|
||||
reward=performance_metrics.get('reward'),
|
||||
pnl=performance_metrics.get('pnl'),
|
||||
epoch=training_metadata.get('epoch') if training_metadata else None,
|
||||
training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None,
|
||||
total_parameters=training_metadata.get('total_parameters') if training_metadata else None
|
||||
)
|
||||
|
||||
if self.enable_wandb and wandb.run is not None:
|
||||
artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
|
||||
metadata.wandb_run_id = wandb.run.id
|
||||
metadata.wandb_artifact_name = artifact_name
|
||||
|
||||
self.checkpoints[model_name].append(metadata)
|
||||
self._rotate_checkpoints(model_name)
|
||||
self._save_metadata()
|
||||
|
||||
logger.info(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
|
||||
return metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint for {model_name}: {e}")
|
||||
return None
|
||||
|
||||
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
|
||||
try:
|
||||
if model_name not in self.checkpoints or not self.checkpoints[model_name]:
|
||||
logger.warning(f"No checkpoints found for model: {model_name}")
|
||||
return None
|
||||
|
||||
best_checkpoint = max(self.checkpoints[model_name], key=lambda x: x.performance_score)
|
||||
|
||||
if not Path(best_checkpoint.file_path).exists():
|
||||
logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}")
|
||||
return None
|
||||
|
||||
logger.info(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
|
||||
return best_checkpoint.file_path, best_checkpoint
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading best checkpoint for {model_name}: {e}")
|
||||
return None
|
||||
|
||||
def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
|
||||
score = 0.0
|
||||
|
||||
if 'accuracy' in metrics:
|
||||
score += metrics['accuracy'] * 100
|
||||
if 'val_accuracy' in metrics:
|
||||
score += metrics['val_accuracy'] * 100
|
||||
if 'loss' in metrics:
|
||||
score += max(0, 10 - metrics['loss'])
|
||||
if 'val_loss' in metrics:
|
||||
score += max(0, 10 - metrics['val_loss'])
|
||||
if 'reward' in metrics:
|
||||
score += metrics['reward']
|
||||
if 'pnl' in metrics:
|
||||
score += metrics['pnl']
|
||||
|
||||
if score == 0.0 and metrics:
|
||||
first_metric = next(iter(metrics.values()))
|
||||
score = first_metric if first_metric > 0 else 0.1
|
||||
|
||||
return max(score, 0.1)
|
||||
|
||||
def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
|
||||
if model_name not in self.checkpoints or not self.checkpoints[model_name]:
|
||||
return True
|
||||
|
||||
if len(self.checkpoints[model_name]) < self.max_checkpoints:
|
||||
return True
|
||||
|
||||
worst_score = min(cp.performance_score for cp in self.checkpoints[model_name])
|
||||
return performance_score > worst_score
|
||||
|
||||
def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
|
||||
try:
|
||||
if hasattr(model, 'state_dict'):
|
||||
torch.save({
|
||||
'model_state_dict': model.state_dict(),
|
||||
'model_type': model_type,
|
||||
'saved_at': datetime.now().isoformat()
|
||||
}, file_path)
|
||||
else:
|
||||
torch.save(model, file_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model file {file_path}: {e}")
|
||||
return False
|
||||
|
||||
def _rotate_checkpoints(self, model_name: str):
|
||||
checkpoint_list = self.checkpoints[model_name]
|
||||
|
||||
if len(checkpoint_list) <= self.max_checkpoints:
|
||||
return
|
||||
|
||||
checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)
|
||||
|
||||
to_remove = checkpoint_list[self.max_checkpoints:]
|
||||
self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints]
|
||||
|
||||
for checkpoint in to_remove:
|
||||
try:
|
||||
file_path = Path(checkpoint.file_path)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
logger.info(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")
|
||||
|
||||
def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
|
||||
try:
|
||||
if not self.enable_wandb or wandb.run is None:
|
||||
return None
|
||||
|
||||
artifact_name = f"{metadata.model_name}_checkpoint"
|
||||
artifact = wandb.Artifact(artifact_name, type="model")
|
||||
artifact.add_file(str(file_path))
|
||||
wandb.log_artifact(artifact)
|
||||
|
||||
return artifact_name
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading to W&B: {e}")
|
||||
return None
|
||||
|
||||
def _load_metadata(self):
|
||||
try:
|
||||
if self.metadata_file.exists():
|
||||
with open(self.metadata_file, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
for model_name, checkpoint_list in data.items():
|
||||
self.checkpoints[model_name] = [
|
||||
CheckpointMetadata.from_dict(cp_data)
|
||||
for cp_data in checkpoint_list
|
||||
]
|
||||
|
||||
logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading checkpoint metadata: {e}")
|
||||
|
||||
def _save_metadata(self):
|
||||
try:
|
||||
data = {}
|
||||
for model_name, checkpoint_list in self.checkpoints.items():
|
||||
data[model_name] = [cp.to_dict() for cp in checkpoint_list]
|
||||
|
||||
with open(self.metadata_file, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint metadata: {e}")
|
||||
|
||||
def get_checkpoint_stats(self):
|
||||
"""Get statistics about managed checkpoints"""
|
||||
stats = {
|
||||
'total_models': len(self.checkpoints),
|
||||
'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
|
||||
'total_size_mb': 0.0,
|
||||
'models': {}
|
||||
}
|
||||
|
||||
for model_name, checkpoint_list in self.checkpoints.items():
|
||||
if not checkpoint_list:
|
||||
continue
|
||||
|
||||
model_size = sum(cp.file_size_mb for cp in checkpoint_list)
|
||||
best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)
|
||||
|
||||
stats['models'][model_name] = {
|
||||
'checkpoint_count': len(checkpoint_list),
|
||||
'total_size_mb': model_size,
|
||||
'best_performance': best_checkpoint.performance_score,
|
||||
'best_checkpoint_id': best_checkpoint.checkpoint_id,
|
||||
'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
|
||||
}
|
||||
|
||||
stats['total_size_mb'] += model_size
|
||||
|
||||
return stats
|
||||
|
||||
_checkpoint_manager = None
|
||||
|
||||
def get_checkpoint_manager() -> CheckpointManager:
|
||||
global _checkpoint_manager
|
||||
if _checkpoint_manager is None:
|
||||
_checkpoint_manager = CheckpointManager()
|
||||
return _checkpoint_manager
|
||||
|
||||
def save_checkpoint(model, model_name: str, model_type: str,
|
||||
performance_metrics: Dict[str, float],
|
||||
training_metadata: Optional[Dict[str, Any]] = None,
|
||||
force_save: bool = False) -> Optional[CheckpointMetadata]:
|
||||
return get_checkpoint_manager().save_checkpoint(
|
||||
model, model_name, model_type, performance_metrics, training_metadata, force_save
|
||||
)
|
||||
|
||||
def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
|
||||
return get_checkpoint_manager().load_best_checkpoint(model_name)
|
204
utils/training_integration.py
Normal file
204
utils/training_integration.py
Normal file
@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Training Integration for Checkpoint Management
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TrainingIntegration:
|
||||
def __init__(self, enable_wandb: bool = True):
|
||||
self.checkpoint_manager = get_checkpoint_manager()
|
||||
self.enable_wandb = enable_wandb
|
||||
|
||||
if self.enable_wandb:
|
||||
self._init_wandb()
|
||||
|
||||
def _init_wandb(self):
|
||||
try:
|
||||
import wandb
|
||||
|
||||
if wandb.run is None:
|
||||
wandb.init(
|
||||
project="gogo2-trading",
|
||||
name=f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
||||
config={
|
||||
"max_checkpoints_per_model": self.checkpoint_manager.max_checkpoints,
|
||||
"checkpoint_dir": str(self.checkpoint_manager.base_dir)
|
||||
}
|
||||
)
|
||||
logger.info(f"Initialized W&B run: {wandb.run.id}")
|
||||
|
||||
except ImportError:
|
||||
logger.warning("W&B not available - checkpoint management will work without it")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing W&B: {e}")
|
||||
|
||||
def save_cnn_checkpoint(self,
|
||||
cnn_model,
|
||||
model_name: str,
|
||||
epoch: int,
|
||||
train_accuracy: float,
|
||||
val_accuracy: float,
|
||||
train_loss: float,
|
||||
val_loss: float,
|
||||
training_time_hours: float = None) -> bool:
|
||||
try:
|
||||
performance_metrics = {
|
||||
'accuracy': train_accuracy,
|
||||
'val_accuracy': val_accuracy,
|
||||
'loss': train_loss,
|
||||
'val_loss': val_loss
|
||||
}
|
||||
|
||||
training_metadata = {
|
||||
'epoch': epoch,
|
||||
'training_time_hours': training_time_hours,
|
||||
'total_parameters': self._count_parameters(cnn_model)
|
||||
}
|
||||
|
||||
if self.enable_wandb:
|
||||
try:
|
||||
import wandb
|
||||
if wandb.run is not None:
|
||||
wandb.log({
|
||||
f"{model_name}/train_accuracy": train_accuracy,
|
||||
f"{model_name}/val_accuracy": val_accuracy,
|
||||
f"{model_name}/train_loss": train_loss,
|
||||
f"{model_name}/val_loss": val_loss,
|
||||
f"{model_name}/epoch": epoch
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Error logging to W&B: {e}")
|
||||
|
||||
metadata = save_checkpoint(
|
||||
model=cnn_model,
|
||||
model_name=model_name,
|
||||
model_type='cnn',
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata
|
||||
)
|
||||
|
||||
if metadata:
|
||||
logger.info(f"CNN checkpoint saved: {metadata.checkpoint_id}")
|
||||
return True
|
||||
else:
|
||||
logger.info(f"CNN checkpoint not saved (performance not improved)")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving CNN checkpoint: {e}")
|
||||
return False
|
||||
|
||||
def save_rl_checkpoint(self,
|
||||
rl_agent,
|
||||
model_name: str,
|
||||
episode: int,
|
||||
avg_reward: float,
|
||||
best_reward: float,
|
||||
epsilon: float,
|
||||
total_pnl: float = None) -> bool:
|
||||
try:
|
||||
performance_metrics = {
|
||||
'reward': avg_reward,
|
||||
'best_reward': best_reward
|
||||
}
|
||||
|
||||
if total_pnl is not None:
|
||||
performance_metrics['pnl'] = total_pnl
|
||||
|
||||
training_metadata = {
|
||||
'episode': episode,
|
||||
'epsilon': epsilon,
|
||||
'total_parameters': self._count_parameters(rl_agent)
|
||||
}
|
||||
|
||||
if self.enable_wandb:
|
||||
try:
|
||||
import wandb
|
||||
if wandb.run is not None:
|
||||
wandb.log({
|
||||
f"{model_name}/avg_reward": avg_reward,
|
||||
f"{model_name}/best_reward": best_reward,
|
||||
f"{model_name}/epsilon": epsilon,
|
||||
f"{model_name}/episode": episode
|
||||
})
|
||||
|
||||
if total_pnl is not None:
|
||||
wandb.log({f"{model_name}/total_pnl": total_pnl})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error logging to W&B: {e}")
|
||||
|
||||
metadata = save_checkpoint(
|
||||
model=rl_agent,
|
||||
model_name=model_name,
|
||||
model_type='rl',
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata
|
||||
)
|
||||
|
||||
if metadata:
|
||||
logger.info(f"RL checkpoint saved: {metadata.checkpoint_id}")
|
||||
return True
|
||||
else:
|
||||
logger.info(f"RL checkpoint not saved (performance not improved)")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving RL checkpoint: {e}")
|
||||
return False
|
||||
|
||||
def load_best_model(self, model_name: str, model_class=None):
|
||||
try:
|
||||
result = load_best_checkpoint(model_name)
|
||||
if not result:
|
||||
logger.warning(f"No checkpoint found for model: {model_name}")
|
||||
return None
|
||||
|
||||
file_path, metadata = result
|
||||
|
||||
checkpoint = torch.load(file_path, map_location='cpu')
|
||||
|
||||
logger.info(f"Loaded best checkpoint for {model_name}:")
|
||||
logger.info(f" Performance score: {metadata.performance_score:.4f}")
|
||||
logger.info(f" Created: {metadata.created_at}")
|
||||
|
||||
if model_class and 'model_state_dict' in checkpoint:
|
||||
model = model_class()
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
return model
|
||||
|
||||
return checkpoint
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading best model {model_name}: {e}")
|
||||
return None
|
||||
|
||||
def _count_parameters(self, model) -> int:
|
||||
try:
|
||||
if hasattr(model, 'parameters'):
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
elif hasattr(model, 'policy_net'):
|
||||
policy_params = sum(p.numel() for p in model.policy_net.parameters())
|
||||
target_params = sum(p.numel() for p in model.target_net.parameters()) if hasattr(model, 'target_net') else 0
|
||||
return policy_params + target_params
|
||||
else:
|
||||
return 0
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
_training_integration = None
|
||||
|
||||
def get_training_integration() -> TrainingIntegration:
|
||||
global _training_integration
|
||||
if _training_integration is None:
|
||||
_training_integration = TrainingIntegration()
|
||||
return _training_integration
|
Reference in New Issue
Block a user