checkpoint manager

This commit is contained in:
Dobromir Popov
2025-06-24 21:41:50 +03:00
parent c9d1e029c5
commit 706eb13912
7 changed files with 978 additions and 1 deletions

View File

@ -0,0 +1,126 @@
{
"example_cnn": [
{
"checkpoint_id": "example_cnn_20250624_213913",
"model_name": "example_cnn",
"model_type": "cnn",
"file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
"created_at": "2025-06-24T21:39:13.559926",
"file_size_mb": 0.0797882080078125,
"performance_score": 65.67219525381417,
"accuracy": 0.28019601724789606,
"loss": 1.9252885885630378,
"val_accuracy": 0.21531048803825983,
"val_loss": 1.953166686238386,
"reward": null,
"pnl": null,
"epoch": 1,
"training_time_hours": 0.1,
"total_parameters": 20163,
"wandb_run_id": null,
"wandb_artifact_name": null
},
{
"checkpoint_id": "example_cnn_20250624_213913",
"model_name": "example_cnn",
"model_type": "cnn",
"file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
"created_at": "2025-06-24T21:39:13.563368",
"file_size_mb": 0.0797882080078125,
"performance_score": 85.85617724870231,
"accuracy": 0.3797766367576808,
"loss": 1.738881079808816,
"val_accuracy": 0.31375868989071576,
"val_loss": 1.758474336328537,
"reward": null,
"pnl": null,
"epoch": 2,
"training_time_hours": 0.2,
"total_parameters": 20163,
"wandb_run_id": null,
"wandb_artifact_name": null
},
{
"checkpoint_id": "example_cnn_20250624_213913",
"model_name": "example_cnn",
"model_type": "cnn",
"file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
"created_at": "2025-06-24T21:39:13.566494",
"file_size_mb": 0.0797882080078125,
"performance_score": 96.86696983784515,
"accuracy": 0.41565501055141396,
"loss": 1.731468873500252,
"val_accuracy": 0.38848400580514414,
"val_loss": 1.8154629243104177,
"reward": null,
"pnl": null,
"epoch": 3,
"training_time_hours": 0.30000000000000004,
"total_parameters": 20163,
"wandb_run_id": null,
"wandb_artifact_name": null
},
{
"checkpoint_id": "example_cnn_20250624_213913",
"model_name": "example_cnn",
"model_type": "cnn",
"file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
"created_at": "2025-06-24T21:39:13.569547",
"file_size_mb": 0.0797882080078125,
"performance_score": 106.29887197896815,
"accuracy": 0.4639872237832544,
"loss": 1.4731813440281318,
"val_accuracy": 0.4291565645756503,
"val_loss": 1.5423255128941882,
"reward": null,
"pnl": null,
"epoch": 4,
"training_time_hours": 0.4,
"total_parameters": 20163,
"wandb_run_id": null,
"wandb_artifact_name": null
},
{
"checkpoint_id": "example_cnn_20250624_213913",
"model_name": "example_cnn",
"model_type": "cnn",
"file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
"created_at": "2025-06-24T21:39:13.575375",
"file_size_mb": 0.0797882080078125,
"performance_score": 115.87168812846218,
"accuracy": 0.5256293272461906,
"loss": 1.3264778472364203,
"val_accuracy": 0.46011511860837684,
"val_loss": 1.3762786097581432,
"reward": null,
"pnl": null,
"epoch": 5,
"training_time_hours": 0.5,
"total_parameters": 20163,
"wandb_run_id": null,
"wandb_artifact_name": null
}
],
"example_manual": [
{
"checkpoint_id": "example_manual_20250624_213913",
"model_name": "example_manual",
"model_type": "cnn",
"file_path": "NN\\models\\saved\\example_manual\\example_manual_20250624_213913.pt",
"created_at": "2025-06-24T21:39:13.578488",
"file_size_mb": 0.0018634796142578125,
"performance_score": 186.07000000000002,
"accuracy": 0.85,
"loss": 0.45,
"val_accuracy": 0.82,
"val_loss": 0.48,
"reward": null,
"pnl": null,
"epoch": 25,
"training_time_hours": 2.5,
"total_parameters": 33,
"wandb_run_id": null,
"wandb_artifact_name": null
}
]
}

6
_dev/notes.md Normal file
View File

@ -0,0 +1,6 @@
How do we manage our training W&B checkpoints? We need to clean up old checkpoints. For every model we keep a maximum of 5 checkpoints and rotate them. By default we always load the best one, and during training, when we save a new checkpoint, we discard the 6th when ordered by performance.
add integration of the checkpoint manager to all training pipelines
we stopped showing executed trades on the chart. let's add them back

186
cleanup_checkpoints.py Normal file
View File

@ -0,0 +1,186 @@
#!/usr/bin/env python3
"""
Checkpoint Cleanup and Migration Script
This script helps clean up existing checkpoints and migrate to the new
checkpoint management system with W&B integration.
"""
import os
import logging
import shutil
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any
import torch
from utils.checkpoint_manager import get_checkpoint_manager, CheckpointMetadata
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class CheckpointCleanup:
    """Analyze and prune ``*.pt`` checkpoint files under ``NN/models/saved``.

    Files are grouped by the part of their file name that precedes the first
    underscore.  A group holding more than ``max_files_per_group`` files is
    reported as a set of potential duplicates, and cleanup keeps only the
    most recently modified ``max_files_per_group`` files of such a group.
    """

    def __init__(self, max_files_per_group: int = 5):
        """Set up the cleanup helper.

        Args:
            max_files_per_group: Number of newest files to keep per file-name
                group.  Defaults to 5, the limit this script always used.

        Raises:
            ValueError: If ``max_files_per_group`` is negative.
        """
        if max_files_per_group < 0:
            raise ValueError("max_files_per_group must be >= 0")
        self.saved_models_dir = Path("NN/models/saved")
        self.checkpoint_manager = get_checkpoint_manager()
        self.max_files_per_group = max_files_per_group

    def analyze_existing_checkpoints(self) -> Dict[str, Any]:
        """Scan the saved-models directory and summarize the ``*.pt`` files.

        Returns:
            A dict with the keys ``total_files``, ``total_size_mb``,
            ``model_types`` (count/size per guessed model type),
            ``file_patterns`` (file details grouped by base name) and
            ``potential_duplicates`` (groups larger than the keep limit).
        """
        logger.info("Analyzing existing checkpoint files...")
        analysis = {
            'total_files': 0,
            'total_size_mb': 0.0,
            'model_types': {},
            'file_patterns': {},
            'potential_duplicates': []
        }
        if not self.saved_models_dir.exists():
            logger.warning(f"Saved models directory not found: {self.saved_models_dir}")
            return analysis

        for pt_file in self.saved_models_dir.rglob("*.pt"):
            try:
                # One stat() call supplies both the size and the mtime.
                file_stat = pt_file.stat()
                modified = datetime.fromtimestamp(file_stat.st_mtime)
            except (OSError, OverflowError, ValueError) as e:
                logger.error(f"Error analyzing {pt_file}: {e}")
                continue

            file_size_mb = file_stat.st_size / (1024 * 1024)
            analysis['total_files'] += 1
            analysis['total_size_mb'] += file_size_mb

            filename = pt_file.name
            lowered = filename.lower()
            # The model type is guessed purely from substrings of the name.
            if 'cnn' in lowered:
                model_type = 'cnn'
            elif 'dqn' in lowered or 'rl' in lowered or 'agent' in lowered:
                model_type = 'rl'
            else:
                model_type = 'unknown'

            type_stats = analysis['model_types'].setdefault(
                model_type, {'count': 0, 'size_mb': 0.0}
            )
            type_stats['count'] += 1
            type_stats['size_mb'] += file_size_mb

            base_name = filename.split('_')[0] if '_' in filename else filename.replace('.pt', '')
            analysis['file_patterns'].setdefault(base_name, []).append({
                'path': str(pt_file),
                'size_mb': file_size_mb,
                'modified': modified
            })

        for base_name, files in analysis['file_patterns'].items():
            if len(files) > self.max_files_per_group:
                analysis['potential_duplicates'].append({
                    'base_name': base_name,
                    'count': len(files),
                    'total_size_mb': sum(f['size_mb'] for f in files),
                    'files': files
                })

        logger.info("Analysis complete:")
        logger.info(f" Total files: {analysis['total_files']}")
        logger.info(f" Total size: {analysis['total_size_mb']:.2f} MB")
        logger.info(f" Model types: {analysis['model_types']}")
        logger.info(f" Potential duplicates: {len(analysis['potential_duplicates'])}")
        return analysis

    def cleanup_duplicates(self, dry_run: bool = True) -> Dict[str, Any]:
        """Remove all but the newest files of every oversized group.

        Args:
            dry_run: When True (the default) nothing is deleted; the result
                only reports what would be removed.

        Returns:
            A dict with the keys ``removed``, ``kept``, ``space_saved_mb``
            and ``details`` (one entry per file that was kept or removed).
        """
        logger.info(f"Starting duplicate cleanup (dry_run={dry_run})...")
        cleanup_results = {
            'removed': 0,
            'kept': 0,
            'space_saved_mb': 0.0,
            'details': []
        }
        analysis = self.analyze_existing_checkpoints()
        keep = self.max_files_per_group

        for duplicate_group in analysis['potential_duplicates']:
            base_name = duplicate_group['base_name']
            files = duplicate_group['files']
            # Newest first, so the head of the list is what we keep.
            files.sort(key=lambda x: x['modified'], reverse=True)
            logger.info(f"Processing {base_name}: {len(files)} files")

            for file_info in files[:keep]:
                cleanup_results['kept'] += 1
                cleanup_results['details'].append({
                    'action': 'kept',
                    'file': file_info['path']
                })

            for file_info in files[keep:]:
                if not dry_run:
                    try:
                        Path(file_info['path']).unlink()
                        logger.info(f"Removed: {file_info['path']}")
                    except OSError as e:
                        # A file that could not be deleted is not counted.
                        logger.error(f"Error removing {file_info['path']}: {e}")
                        continue
                cleanup_results['removed'] += 1
                cleanup_results['space_saved_mb'] += file_info['size_mb']
                cleanup_results['details'].append({
                    'action': 'removed',
                    'file': file_info['path'],
                    'size_mb': file_info['size_mb']
                })

        logger.info(f"Cleanup {'simulation' if dry_run else 'complete'}:")
        logger.info(f" Kept: {cleanup_results['kept']}")
        logger.info(f" Removed: {cleanup_results['removed']}")
        logger.info(f" Space saved: {cleanup_results['space_saved_mb']:.2f} MB")
        return cleanup_results
def main():
    """Run the interactive cleanup: analyze, simulate, confirm, then delete.

    Nothing is deleted unless the user answers ``y`` at the prompt; if no
    input is available (EOF) the cleanup is treated as cancelled.
    """
    keep_per_group = 5  # files retained per group, used for the estimate below

    logger.info("=== Checkpoint Cleanup Tool ===")
    cleanup = CheckpointCleanup()

    # Analyze existing checkpoints
    logger.info("\n1. Analyzing existing checkpoints...")
    analysis = cleanup.analyze_existing_checkpoints()
    if analysis['total_files'] == 0:
        logger.info("No checkpoint files found.")
        return

    # Show potential space savings
    total_duplicates = sum(
        len(group['files']) - keep_per_group
        for group in analysis['potential_duplicates']
        if len(group['files']) > keep_per_group
    )
    if total_duplicates <= 0:
        logger.info("No duplicate files found that need cleanup.")
        return

    logger.info(f"\nFound {total_duplicates} files that could be cleaned up")

    # Dry run first
    logger.info("\n2. Simulating cleanup...")
    dry_run_results = cleanup.cleanup_duplicates(dry_run=True)
    if dry_run_results['removed'] <= 0:
        logger.info("No files to remove.")
        return

    try:
        answer = input(f"\nProceed with cleanup? Will remove {dry_run_results['removed']} files "
                       f"and save {dry_run_results['space_saved_mb']:.2f} MB. (y/n): ")
    except EOFError:
        # No interactive stdin: never delete without explicit confirmation.
        answer = ''

    if answer.lower().strip() == 'y':
        logger.info("\n3. Performing actual cleanup...")
        cleanup.cleanup_duplicates(dry_run=False)
        logger.info("\n=== Cleanup Complete ===")
    else:
        logger.info("Cleanup cancelled.")
if __name__ == "__main__":
main()

148
example_checkpoint_usage.py Normal file
View File

@ -0,0 +1,148 @@
#!/usr/bin/env python3
"""
Example: Using the Checkpoint Management System
"""
import logging
import torch
import torch.nn as nn
import numpy as np
from datetime import datetime
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint, get_checkpoint_manager
from utils.training_integration import get_training_integration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ExampleCNN(nn.Module):
    """Tiny two-layer convolutional classifier used by the examples.

    Two 3x3 convolutions (32 then 64 channels, padding 1) with ReLU, global
    average pooling to 1x1, and a linear layer producing ``num_classes``
    outputs.
    """

    def __init__(self, input_channels=5, num_classes=3):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map a ``(batch, input_channels, H, W)`` tensor to class scores."""
        hidden = torch.relu(self.conv1(x))
        hidden = torch.relu(self.conv2(hidden))
        pooled = self.pool(hidden)
        # Collapse the 1x1 spatial dimensions into a flat feature vector.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)
def example_cnn_training():
    """Simulate five CNN training epochs and checkpoint each one.

    Metrics are synthetic (a linear trend plus Gaussian noise).  After the
    loop the best stored checkpoint for ``example_cnn`` is looked up and its
    id and score are logged.
    """
    logger.info("=== CNN Training Example ===")
    model = ExampleCNN()
    training_integration = get_training_integration()

    for epoch in range(5):  # Simulate 5 epochs
        # Simulate training metrics
        train_loss = 2.0 - (epoch * 0.15) + np.random.normal(0, 0.1)
        train_acc = 0.3 + (epoch * 0.06) + np.random.normal(0, 0.02)
        val_loss = train_loss + np.random.normal(0, 0.05)
        val_acc = train_acc - 0.05 + np.random.normal(0, 0.02)

        # Clamp values to realistic ranges
        train_acc = max(0.0, min(1.0, train_acc))
        val_acc = max(0.0, min(1.0, val_acc))
        train_loss = max(0.1, train_loss)
        val_loss = max(0.1, val_loss)

        logger.info(f"Epoch {epoch+1}: train_acc={train_acc:.3f}, val_acc={val_acc:.3f}")

        # Save checkpoint
        saved = training_integration.save_cnn_checkpoint(
            cnn_model=model,
            model_name="example_cnn",
            epoch=epoch + 1,
            train_accuracy=train_acc,
            val_accuracy=val_acc,
            train_loss=train_loss,
            val_loss=val_loss,
            training_time_hours=0.1 * (epoch + 1)
        )
        if saved:
            logger.info(f" Checkpoint saved for epoch {epoch+1}")
        else:
            logger.info(" Checkpoint not saved (performance not improved)")

    # Load the best checkpoint
    logger.info("\nLoading best checkpoint...")
    best_result = load_best_checkpoint("example_cnn")
    if best_result:
        _, metadata = best_result
        logger.info(f"Best checkpoint: {metadata.checkpoint_id}")
        logger.info(f"Performance score: {metadata.performance_score:.4f}")
def example_manual_checkpoint():
    """Save one checkpoint directly through ``save_checkpoint``.

    Uses a small linear model with fixed metrics and ``force_save=True`` so
    the checkpoint is written regardless of how it ranks.
    """
    logger.info("\n=== Manual Checkpoint Example ===")
    model = nn.Linear(10, 3)

    performance_metrics = {
        'accuracy': 0.85,
        'val_accuracy': 0.82,
        'loss': 0.45,
        'val_loss': 0.48
    }
    training_metadata = {
        'epoch': 25,
        'training_time_hours': 2.5,
        'total_parameters': sum(p.numel() for p in model.parameters())
    }

    logger.info("Saving checkpoint manually...")
    metadata = save_checkpoint(
        model=model,
        model_name="example_manual",
        model_type="cnn",
        performance_metrics=performance_metrics,
        training_metadata=training_metadata,
        force_save=True
    )
    if metadata:
        logger.info(f" Manual checkpoint saved: {metadata.checkpoint_id}")
        logger.info(f" Performance score: {metadata.performance_score:.4f}")
def show_checkpoint_stats():
    """Log overall and per-model statistics from the checkpoint manager."""
    logger.info("\n=== Checkpoint Statistics ===")
    checkpoint_manager = get_checkpoint_manager()
    stats = checkpoint_manager.get_checkpoint_stats()

    logger.info(f"Total models: {stats['total_models']}")
    logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
    logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")

    for model_name, model_stats in stats['models'].items():
        logger.info(f"\n{model_name}:")
        logger.info(f" Checkpoints: {model_stats['checkpoint_count']}")
        logger.info(f" Size: {model_stats['total_size_mb']:.2f} MB")
        logger.info(f" Best performance: {model_stats['best_performance']:.4f}")
def main():
    """Run every example in order and log short usage instructions.

    Raises:
        Exception: Any error raised by an example is logged and re-raised.
    """
    logger.info(" Checkpoint Management System Examples")
    logger.info("=" * 50)
    try:
        example_cnn_training()
        example_manual_checkpoint()
        show_checkpoint_stats()

        logger.info("\n All examples completed successfully!")
        logger.info("\nTo use in your training:")
        logger.info("1. Import: from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint")
        logger.info("2. Or use: from utils.training_integration import get_training_integration")
        logger.info("3. Save checkpoints during training with performance metrics")
        logger.info("4. Load best checkpoints for inference or continued training")
    except Exception as e:
        logger.error(f"Error in examples: {e}")
        raise
if __name__ == "__main__":
main()

View File

@ -14,3 +14,4 @@ scikit-learn>=1.3.0
matplotlib>=3.7.0 matplotlib>=3.7.0
seaborn>=0.12.0 seaborn>=0.12.0
asyncio-compat>=0.1.2 asyncio-compat>=0.1.2
wandb>=0.16.0

306
utils/checkpoint_manager.py Normal file
View File

@ -0,0 +1,306 @@
#!/usr/bin/env python3
"""
Checkpoint Management System for W&B Training
"""
import os
import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict
import torch
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
logger = logging.getLogger(__name__)
@dataclass
class CheckpointMetadata:
    """Descriptive record stored alongside one saved checkpoint file.

    The first seven fields are required; the remaining metric and tracking
    fields default to ``None`` when they do not apply to a model.
    """

    checkpoint_id: str
    model_name: str
    model_type: str
    file_path: str
    created_at: datetime
    file_size_mb: float
    performance_score: float
    accuracy: Optional[float] = None
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None
    wandb_run_id: Optional[str] = None
    wandb_artifact_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict; ``created_at`` becomes ISO text."""
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
        """Build an instance from a dict such as the one ``to_dict`` returns.

        Args:
            data: Field values keyed by field name.  ``created_at`` may be an
                ISO-8601 string or an existing ``datetime``.  The mapping is
                not modified.

        Returns:
            The reconstructed ``CheckpointMetadata``.
        """
        # Work on a copy: rewriting the caller's dict in place would make a
        # second from_dict() call on the same dict fail.
        fields = dict(data)
        created_at = fields['created_at']
        if not isinstance(created_at, datetime):
            fields['created_at'] = datetime.fromisoformat(created_at)
        return cls(**fields)
class CheckpointManager:
    """Save, rank and rotate model checkpoints on disk.

    Checkpoints are written to ``<base_dir>/<model_name>/<checkpoint_id>.pt``
    and described in a JSON metadata file inside ``base_dir``.  For every
    model at most ``max_checkpoints`` entries are retained, ranked by their
    performance score; lower-ranked files are deleted on rotation.  When W&B
    is importable, enabled and a run is active, saved files are also logged
    as W&B artifacts.
    """

    def __init__(self,
                 base_checkpoint_dir: str = "NN/models/saved",
                 max_checkpoints_per_model: int = 5,
                 metadata_file: str = "checkpoint_metadata.json",
                 enable_wandb: bool = True):
        """Create the base directory if needed and load existing metadata.

        Args:
            base_checkpoint_dir: Root directory for all checkpoint files.
            max_checkpoints_per_model: Retention limit per model name.
            metadata_file: Name of the JSON metadata file inside the root.
            enable_wandb: Request W&B uploads; ignored if wandb is missing.
        """
        self.base_dir = Path(base_checkpoint_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints_per_model
        self.metadata_file = self.base_dir / metadata_file
        self.enable_wandb = enable_wandb and WANDB_AVAILABLE
        self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
        self._load_metadata()
        logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[CheckpointMetadata]:
        """Save ``model`` if it ranks among the retained checkpoints.

        Args:
            model: Object to persist; its ``state_dict()`` is stored when it
                has one, otherwise the object itself is saved.
            model_name: Name used for the sub-directory and the id prefix.
            model_type: Free-form type label stored with the checkpoint.
            performance_metrics: Metric values used for scoring and metadata.
            training_metadata: Optional dict; ``epoch``,
                ``training_time_hours`` and ``total_parameters`` are recorded.
            force_save: Save even if the score would not be retained.

        Returns:
            The metadata of the new checkpoint, or ``None`` when the save was
            skipped or failed (failures are logged, not raised).
        """
        try:
            now = datetime.now()
            model_dir = self.base_dir / model_name
            model_dir.mkdir(parents=True, exist_ok=True)

            performance_score = self._calculate_performance_score(performance_metrics)
            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.info(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            # Ids only have second resolution, so two saves within the same
            # second must not end up sharing (and overwriting) one file.
            checkpoint_id = self._unique_checkpoint_id(model_name, model_dir, now)
            checkpoint_path = model_dir / f"{checkpoint_id}.pt"

            if not self._save_model_file(model, checkpoint_path, model_type):
                return None

            extra = training_metadata or {}
            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=model_name,
                model_type=model_type,
                file_path=str(checkpoint_path),
                created_at=datetime.now(),
                file_size_mb=checkpoint_path.stat().st_size / (1024 * 1024),
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=extra.get('epoch'),
                training_time_hours=extra.get('training_time_hours'),
                total_parameters=extra.get('total_parameters')
            )

            if self.enable_wandb and wandb.run is not None:
                artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
                metadata.wandb_run_id = wandb.run.id
                metadata.wandb_artifact_name = artifact_name

            self.checkpoints[model_name].append(metadata)
            self._rotate_checkpoints(model_name)
            self._save_metadata()
            logger.info(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
            return metadata
        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None

    def _unique_checkpoint_id(self, model_name: str, model_dir: Path, now: datetime) -> str:
        """Return ``<model_name>_<YYYYmmdd_HHMMSS>``, suffixed if already taken.

        An id counts as taken when a tracked checkpoint of the model uses it
        or when the corresponding ``.pt`` file already exists; in that case
        ``_1``, ``_2``, ... is appended until the id is free.
        """
        base_id = f"{model_name}_{now.strftime('%Y%m%d_%H%M%S')}"
        taken = {cp.checkpoint_id for cp in self.checkpoints.get(model_name, [])}
        candidate = base_id
        suffix = 1
        while candidate in taken or (model_dir / f"{candidate}.pt").exists():
            candidate = f"{base_id}_{suffix}"
            suffix += 1
        return candidate

    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
        """Locate the highest-scoring checkpoint whose file still exists.

        Args:
            model_name: Model whose checkpoints are searched.

        Returns:
            ``(file_path, metadata)`` of the best available checkpoint, or
            ``None`` if the model has no checkpoints or no file is present.
            Nothing is deserialized here; loading the file is left to the
            caller.
        """
        try:
            if model_name not in self.checkpoints or not self.checkpoints[model_name]:
                logger.warning(f"No checkpoints found for model: {model_name}")
                return None

            # Stable descending sort: ties keep their stored order.
            ranked = sorted(self.checkpoints[model_name],
                            key=lambda cp: cp.performance_score, reverse=True)
            for rank, candidate in enumerate(ranked):
                if Path(candidate.file_path).exists():
                    logger.info(f"Loading best checkpoint for {model_name}: {candidate.checkpoint_id}")
                    return candidate.file_path, candidate
                if rank == 0:
                    logger.error(f"Best checkpoint file not found: {candidate.file_path}")
                else:
                    logger.warning(f"Checkpoint file not found: {candidate.file_path}")
            return None
        except Exception as e:
            logger.error(f"Error loading best checkpoint for {model_name}: {e}")
            return None

    def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
        """Fold the known metrics into one score (higher is better).

        Accuracies count times 100, losses contribute ``max(0, 10 - loss)``,
        and reward/pnl are added as-is.  Missing or ``None`` metrics are
        skipped.  If nothing contributed, the first metric value is used when
        positive.  The result is never below 0.1.
        """
        score = 0.0
        for key in ('accuracy', 'val_accuracy'):
            value = metrics.get(key)
            if value is not None:
                score += value * 100
        for key in ('loss', 'val_loss'):
            value = metrics.get(key)
            if value is not None:
                score += max(0, 10 - value)
        for key in ('reward', 'pnl'):
            value = metrics.get(key)
            if value is not None:
                score += value

        if score == 0.0 and metrics:
            first_metric = next(iter(metrics.values()))
            score = first_metric if first_metric is not None and first_metric > 0 else 0.1
        return max(score, 0.1)

    def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
        """Return True if a checkpoint with this score would be retained."""
        if model_name not in self.checkpoints or not self.checkpoints[model_name]:
            return True
        if len(self.checkpoints[model_name]) < self.max_checkpoints:
            return True
        worst_score = min(cp.performance_score for cp in self.checkpoints[model_name])
        return performance_score > worst_score

    def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
        """Write the model to ``file_path``; return False (and log) on error."""
        try:
            if hasattr(model, 'state_dict'):
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'model_type': model_type,
                    'saved_at': datetime.now().isoformat()
                }, file_path)
            else:
                torch.save(model, file_path)
            return True
        except Exception as e:
            logger.error(f"Error saving model file {file_path}: {e}")
            return False

    def _rotate_checkpoints(self, model_name: str):
        """Drop the lowest-scoring checkpoints beyond the retention limit."""
        checkpoint_list = self.checkpoints[model_name]
        if len(checkpoint_list) <= self.max_checkpoints:
            return

        checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)
        kept = checkpoint_list[:self.max_checkpoints]
        to_remove = checkpoint_list[self.max_checkpoints:]
        self.checkpoints[model_name] = kept

        # Metadata written before ids were made unique can have several
        # entries pointing at one file; never delete a file that a retained
        # entry still references.
        paths_in_use = {Path(cp.file_path) for cp in kept}
        for checkpoint in to_remove:
            try:
                file_path = Path(checkpoint.file_path)
                if file_path in paths_in_use:
                    continue
                if file_path.exists():
                    file_path.unlink()
                    logger.info(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")

    def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
        """Log the file as a W&B model artifact; return its name or None."""
        try:
            if not self.enable_wandb or wandb.run is None:
                return None
            artifact_name = f"{metadata.model_name}_checkpoint"
            artifact = wandb.Artifact(artifact_name, type="model")
            artifact.add_file(str(file_path))
            wandb.log_artifact(artifact)
            return artifact_name
        except Exception as e:
            logger.error(f"Error uploading to W&B: {e}")
            return None

    def _load_metadata(self):
        """Populate ``self.checkpoints`` from the metadata file, if present."""
        try:
            if self.metadata_file.exists():
                with open(self.metadata_file, 'r') as f:
                    data = json.load(f)
                for model_name, checkpoint_list in data.items():
                    self.checkpoints[model_name] = [
                        CheckpointMetadata.from_dict(cp_data)
                        for cp_data in checkpoint_list
                    ]
                logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
        except Exception as e:
            logger.error(f"Error loading checkpoint metadata: {e}")

    def _save_metadata(self):
        """Write all checkpoint metadata to the JSON metadata file."""
        try:
            data = {
                model_name: [cp.to_dict() for cp in checkpoint_list]
                for model_name, checkpoint_list in self.checkpoints.items()
            }
            # Write to a sibling temp file and swap it in, so an interrupted
            # write cannot leave a truncated metadata file behind.
            tmp_path = self.metadata_file.with_name(self.metadata_file.name + ".tmp")
            with open(tmp_path, 'w') as f:
                json.dump(data, f, indent=2)
            tmp_path.replace(self.metadata_file)
        except Exception as e:
            logger.error(f"Error saving checkpoint metadata: {e}")

    def get_checkpoint_stats(self) -> Dict[str, Any]:
        """Get statistics about managed checkpoints"""
        stats = {
            'total_models': len(self.checkpoints),
            'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
            'total_size_mb': 0.0,
            'models': {}
        }
        for model_name, checkpoint_list in self.checkpoints.items():
            if not checkpoint_list:
                continue
            model_size = sum(cp.file_size_mb for cp in checkpoint_list)
            best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)
            stats['models'][model_name] = {
                'checkpoint_count': len(checkpoint_list),
                'total_size_mb': model_size,
                'best_performance': best_checkpoint.performance_score,
                'best_checkpoint_id': best_checkpoint.checkpoint_id,
                'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
            }
            stats['total_size_mb'] += model_size
        return stats
_checkpoint_manager = None


def get_checkpoint_manager() -> CheckpointManager:
    """Return the module-level CheckpointManager, creating it on first use."""
    global _checkpoint_manager
    if _checkpoint_manager is not None:
        return _checkpoint_manager
    # First call: build the manager with its default settings and cache it.
    _checkpoint_manager = CheckpointManager()
    return _checkpoint_manager
def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    """Save a checkpoint through the shared manager.

    Thin convenience wrapper: all arguments are forwarded unchanged to
    ``CheckpointManager.save_checkpoint`` and its result is returned.
    """
    manager = get_checkpoint_manager()
    result = manager.save_checkpoint(
        model,
        model_name,
        model_type,
        performance_metrics,
        training_metadata,
        force_save,
    )
    return result
def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    """Look up the best checkpoint for ``model_name`` via the shared manager."""
    manager = get_checkpoint_manager()
    return manager.load_best_checkpoint(model_name)

View File

@ -0,0 +1,204 @@
#!/usr/bin/env python3
"""
Training Integration for Checkpoint Management
"""
import logging
import torch
from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path
from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint
logger = logging.getLogger(__name__)
class TrainingIntegration:
    """Glue between training loops and the checkpoint manager.

    Builds the metric/metadata dicts expected by ``save_checkpoint`` for CNN
    and RL training, optionally mirrors the metrics to W&B, and loads the
    best stored checkpoint back.
    """

    def __init__(self, enable_wandb: bool = True):
        """Bind to the shared checkpoint manager and optionally start W&B.

        Args:
            enable_wandb: Try to log metrics to W&B.  Switched off
                automatically when the ``wandb`` package cannot be imported.
        """
        self.checkpoint_manager = get_checkpoint_manager()
        self.enable_wandb = enable_wandb
        if self.enable_wandb:
            self._init_wandb()

    def _init_wandb(self):
        """Start a W&B run unless one is already active."""
        try:
            import wandb
            if wandb.run is None:
                wandb.init(
                    project="gogo2-trading",
                    name=f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                    config={
                        "max_checkpoints_per_model": self.checkpoint_manager.max_checkpoints,
                        "checkpoint_dir": str(self.checkpoint_manager.base_dir)
                    }
                )
                logger.info(f"Initialized W&B run: {wandb.run.id}")
        except ImportError:
            # Without the package every later logging attempt would fail and
            # emit a warning per checkpoint, so turn W&B off for good.
            self.enable_wandb = False
            logger.warning("W&B not available - checkpoint management will work without it")
        except Exception as e:
            logger.error(f"Error initializing W&B: {e}")

    def save_cnn_checkpoint(self,
                            cnn_model,
                            model_name: str,
                            epoch: int,
                            train_accuracy: float,
                            val_accuracy: float,
                            train_loss: float,
                            val_loss: float,
                            training_time_hours: Optional[float] = None) -> bool:
        """Log CNN metrics and save a checkpoint if it is worth keeping.

        Returns:
            True if a checkpoint was written, False if it was skipped or an
            error occurred (errors are logged, not raised).
        """
        try:
            performance_metrics = {
                'accuracy': train_accuracy,
                'val_accuracy': val_accuracy,
                'loss': train_loss,
                'val_loss': val_loss
            }
            training_metadata = {
                'epoch': epoch,
                'training_time_hours': training_time_hours,
                'total_parameters': self._count_parameters(cnn_model)
            }

            if self.enable_wandb:
                try:
                    import wandb
                    if wandb.run is not None:
                        wandb.log({
                            f"{model_name}/train_accuracy": train_accuracy,
                            f"{model_name}/val_accuracy": val_accuracy,
                            f"{model_name}/train_loss": train_loss,
                            f"{model_name}/val_loss": val_loss,
                            f"{model_name}/epoch": epoch
                        })
                except Exception as e:
                    logger.warning(f"Error logging to W&B: {e}")

            metadata = save_checkpoint(
                model=cnn_model,
                model_name=model_name,
                model_type='cnn',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )
            if metadata:
                logger.info(f"CNN checkpoint saved: {metadata.checkpoint_id}")
                return True
            logger.info("CNN checkpoint not saved (performance not improved)")
            return False
        except Exception as e:
            logger.error(f"Error saving CNN checkpoint: {e}")
            return False

    def save_rl_checkpoint(self,
                           rl_agent,
                           model_name: str,
                           episode: int,
                           avg_reward: float,
                           best_reward: float,
                           epsilon: float,
                           total_pnl: Optional[float] = None) -> bool:
        """Log RL metrics and save a checkpoint if it is worth keeping.

        Returns:
            True if a checkpoint was written, False if it was skipped or an
            error occurred (errors are logged, not raised).
        """
        try:
            performance_metrics = {
                'reward': avg_reward,
                'best_reward': best_reward
            }
            if total_pnl is not None:
                performance_metrics['pnl'] = total_pnl

            training_metadata = {
                # The checkpoint manager only records 'epoch'; pass the
                # episode under that key too so it is not silently dropped.
                'epoch': episode,
                'episode': episode,
                'epsilon': epsilon,
                'total_parameters': self._count_parameters(rl_agent)
            }

            if self.enable_wandb:
                try:
                    import wandb
                    if wandb.run is not None:
                        payload = {
                            f"{model_name}/avg_reward": avg_reward,
                            f"{model_name}/best_reward": best_reward,
                            f"{model_name}/epsilon": epsilon,
                            f"{model_name}/episode": episode
                        }
                        if total_pnl is not None:
                            payload[f"{model_name}/total_pnl"] = total_pnl
                        # One call, so all values of this save are logged
                        # together instead of as two separate log entries.
                        wandb.log(payload)
                except Exception as e:
                    logger.warning(f"Error logging to W&B: {e}")

            metadata = save_checkpoint(
                model=rl_agent,
                model_name=model_name,
                model_type='rl',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )
            if metadata:
                logger.info(f"RL checkpoint saved: {metadata.checkpoint_id}")
                return True
            logger.info("RL checkpoint not saved (performance not improved)")
            return False
        except Exception as e:
            logger.error(f"Error saving RL checkpoint: {e}")
            return False

    def load_best_model(self, model_name: str, model_class=None):
        """Load the best stored checkpoint for ``model_name``.

        Args:
            model_name: Model whose best checkpoint should be loaded.
            model_class: Optional zero-argument model factory.  When given
                and the checkpoint is a dict with ``model_state_dict``, a new
                instance is created and the weights are loaded into it.

        Returns:
            The populated model instance, otherwise the raw loaded checkpoint
            object, or ``None`` if nothing could be loaded.
        """
        try:
            result = load_best_checkpoint(model_name)
            if not result:
                logger.warning(f"No checkpoint found for model: {model_name}")
                return None

            file_path, metadata = result
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            checkpoint = torch.load(file_path, map_location='cpu')
            logger.info(f"Loaded best checkpoint for {model_name}:")
            logger.info(f" Performance score: {metadata.performance_score:.4f}")
            logger.info(f" Created: {metadata.created_at}")

            # Checkpoints of objects without state_dict() are stored whole
            # and may not support `in`, hence the dict check.
            if model_class and isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
                return model
            return checkpoint
        except Exception as e:
            logger.error(f"Error loading best model {model_name}: {e}")
            return None

    def _count_parameters(self, model) -> int:
        """Count parameters of a module, or of an agent's policy/target nets.

        Returns 0 when the object exposes neither ``parameters`` nor
        ``policy_net``, or when counting fails.
        """
        try:
            if hasattr(model, 'parameters'):
                return sum(p.numel() for p in model.parameters())
            if hasattr(model, 'policy_net'):
                policy_params = sum(p.numel() for p in model.policy_net.parameters())
                target_params = (
                    sum(p.numel() for p in model.target_net.parameters())
                    if hasattr(model, 'target_net') else 0
                )
                return policy_params + target_params
            return 0
        except Exception:
            return 0
_training_integration = None


def get_training_integration() -> TrainingIntegration:
    """Return the module-level TrainingIntegration, creating it on first use."""
    global _training_integration
    if _training_integration is not None:
        return _training_integration
    # First call: construct with default settings and cache the instance.
    _training_integration = TrainingIntegration()
    return _training_integration