diff --git a/NN/models/saved/checkpoint_metadata.json b/NN/models/saved/checkpoint_metadata.json new file mode 100644 index 0000000..3cb26aa --- /dev/null +++ b/NN/models/saved/checkpoint_metadata.json @@ -0,0 +1,126 @@ +{ + "example_cnn": [ + { + "checkpoint_id": "example_cnn_20250624_213913", + "model_name": "example_cnn", + "model_type": "cnn", + "file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt", + "created_at": "2025-06-24T21:39:13.559926", + "file_size_mb": 0.0797882080078125, + "performance_score": 65.67219525381417, + "accuracy": 0.28019601724789606, + "loss": 1.9252885885630378, + "val_accuracy": 0.21531048803825983, + "val_loss": 1.953166686238386, + "reward": null, + "pnl": null, + "epoch": 1, + "training_time_hours": 0.1, + "total_parameters": 20163, + "wandb_run_id": null, + "wandb_artifact_name": null + }, + { + "checkpoint_id": "example_cnn_20250624_213913", + "model_name": "example_cnn", + "model_type": "cnn", + "file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt", + "created_at": "2025-06-24T21:39:13.563368", + "file_size_mb": 0.0797882080078125, + "performance_score": 85.85617724870231, + "accuracy": 0.3797766367576808, + "loss": 1.738881079808816, + "val_accuracy": 0.31375868989071576, + "val_loss": 1.758474336328537, + "reward": null, + "pnl": null, + "epoch": 2, + "training_time_hours": 0.2, + "total_parameters": 20163, + "wandb_run_id": null, + "wandb_artifact_name": null + }, + { + "checkpoint_id": "example_cnn_20250624_213913", + "model_name": "example_cnn", + "model_type": "cnn", + "file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt", + "created_at": "2025-06-24T21:39:13.566494", + "file_size_mb": 0.0797882080078125, + "performance_score": 96.86696983784515, + "accuracy": 0.41565501055141396, + "loss": 1.731468873500252, + "val_accuracy": 0.38848400580514414, + "val_loss": 1.8154629243104177, + "reward": null, + "pnl": null, + "epoch": 3, + "training_time_hours": 0.30000000000000004, + "total_parameters": 20163, + "wandb_run_id": null, + "wandb_artifact_name": null + }, + { + "checkpoint_id": "example_cnn_20250624_213913", + "model_name": "example_cnn", + "model_type": "cnn", + "file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt", + "created_at": "2025-06-24T21:39:13.569547", + "file_size_mb": 0.0797882080078125, + "performance_score": 106.29887197896815, + "accuracy": 0.4639872237832544, + "loss": 1.4731813440281318, + "val_accuracy": 0.4291565645756503, + "val_loss": 1.5423255128941882, + "reward": null, + "pnl": null, + "epoch": 4, + "training_time_hours": 0.4, + "total_parameters": 20163, + "wandb_run_id": null, + "wandb_artifact_name": null + }, + { + "checkpoint_id": "example_cnn_20250624_213913", + "model_name": "example_cnn", + "model_type": "cnn", + "file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt", + "created_at": "2025-06-24T21:39:13.575375", + "file_size_mb": 0.0797882080078125, + "performance_score": 115.87168812846218, + "accuracy": 0.5256293272461906, + "loss": 1.3264778472364203, + "val_accuracy": 0.46011511860837684, + "val_loss": 1.3762786097581432, + "reward": null, + "pnl": null, + "epoch": 5, + "training_time_hours": 0.5, + "total_parameters": 20163, + "wandb_run_id": null, + "wandb_artifact_name": null + } + ], + "example_manual": [ + { + "checkpoint_id": "example_manual_20250624_213913", + "model_name": "example_manual", + "model_type": "cnn", + "file_path": 
"NN\\models\\saved\\example_manual\\example_manual_20250624_213913.pt", + "created_at": "2025-06-24T21:39:13.578488", + "file_size_mb": 0.0018634796142578125, + "performance_score": 186.07000000000002, + "accuracy": 0.85, + "loss": 0.45, + "val_accuracy": 0.82, + "val_loss": 0.48, + "reward": null, + "pnl": null, + "epoch": 25, + "training_time_hours": 2.5, + "total_parameters": 33, + "wandb_run_id": null, + "wandb_artifact_name": null + } + ] +} \ No newline at end of file diff --git a/_dev/notes.md b/_dev/notes.md new file mode 100644 index 0000000..e642342 --- /dev/null +++ b/_dev/notes.md @@ -0,0 +1,6 @@ +how we manage our training W&B checkpoints? we need to clean up old checlpoints. for every model we keep 5 checkpoints maximum and rotate them. by default we always load te best, and during training when we save new we discard the 6th ordered by performance + +add integration of the checkpoint manager to all training pipelines + +we stopped showing executed trades on the chart. let's add them back + diff --git a/cleanup_checkpoints.py b/cleanup_checkpoints.py new file mode 100644 index 0000000..5f4ab67 --- /dev/null +++ b/cleanup_checkpoints.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +""" +Checkpoint Cleanup and Migration Script + +This script helps clean up existing checkpoints and migrate to the new +checkpoint management system with W&B integration. +""" + +import os +import logging +import shutil +from pathlib import Path +from datetime import datetime +from typing import List, Dict, Any +import torch + +from utils.checkpoint_manager import get_checkpoint_manager, CheckpointMetadata + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class CheckpointCleanup: + def __init__(self): + self.saved_models_dir = Path("NN/models/saved") + self.checkpoint_manager = get_checkpoint_manager() + + def analyze_existing_checkpoints(self) -> Dict[str, Any]: + logger.info("Analyzing existing checkpoint files...") + + analysis = { + 'total_files': 0, + 'total_size_mb': 0.0, + 'model_types': {}, + 'file_patterns': {}, + 'potential_duplicates': [] + } + + if not self.saved_models_dir.exists(): + logger.warning(f"Saved models directory not found: {self.saved_models_dir}") + return analysis + + for pt_file in self.saved_models_dir.rglob("*.pt"): + try: + file_size_mb = pt_file.stat().st_size / (1024 * 1024) + analysis['total_files'] += 1 + analysis['total_size_mb'] += file_size_mb + + filename = pt_file.name + + if 'cnn' in filename.lower(): + model_type = 'cnn' + elif 'dqn' in filename.lower() or 'rl' in filename.lower(): + model_type = 'rl' + elif 'agent' in filename.lower(): + model_type = 'rl' + else: + model_type = 'unknown' + + if model_type not in analysis['model_types']: + analysis['model_types'][model_type] = {'count': 0, 'size_mb': 0.0} + + analysis['model_types'][model_type]['count'] += 1 + analysis['model_types'][model_type]['size_mb'] += file_size_mb + + base_name = filename.split('_')[0] if '_' in filename else filename.replace('.pt', '') + if base_name not in analysis['file_patterns']: + analysis['file_patterns'][base_name] = [] + + analysis['file_patterns'][base_name].append({ + 'path': str(pt_file), + 'size_mb': file_size_mb, + 'modified': datetime.fromtimestamp(pt_file.stat().st_mtime) + }) + + except Exception as e: + logger.error(f"Error analyzing {pt_file}: {e}") + + for base_name, files in analysis['file_patterns'].items(): + if len(files) > 5: # More than 5 files with same base name + 
analysis['potential_duplicates'].append({ + 'base_name': base_name, + 'count': len(files), + 'total_size_mb': sum(f['size_mb'] for f in files), + 'files': files + }) + + logger.info(f"Analysis complete:") + logger.info(f" Total files: {analysis['total_files']}") + logger.info(f" Total size: {analysis['total_size_mb']:.2f} MB") + logger.info(f" Model types: {analysis['model_types']}") + logger.info(f" Potential duplicates: {len(analysis['potential_duplicates'])}") + + return analysis + + def cleanup_duplicates(self, dry_run: bool = True) -> Dict[str, Any]: + logger.info(f"Starting duplicate cleanup (dry_run={dry_run})...") + + cleanup_results = { + 'removed': 0, + 'kept': 0, + 'space_saved_mb': 0.0, + 'details': [] + } + + analysis = self.analyze_existing_checkpoints() + + for duplicate_group in analysis['potential_duplicates']: + base_name = duplicate_group['base_name'] + files = duplicate_group['files'] + + # Sort by modification time (newest first) + files.sort(key=lambda x: x['modified'], reverse=True) + + logger.info(f"Processing {base_name}: {len(files)} files") + + # Keep only the 5 newest files + for i, file_info in enumerate(files): + if i < 5: # Keep first 5 (newest) + cleanup_results['kept'] += 1 + cleanup_results['details'].append({ + 'action': 'kept', + 'file': file_info['path'] + }) + else: # Remove the rest + if not dry_run: + try: + Path(file_info['path']).unlink() + logger.info(f"Removed: {file_info['path']}") + except Exception as e: + logger.error(f"Error removing {file_info['path']}: {e}") + continue + + cleanup_results['removed'] += 1 + cleanup_results['space_saved_mb'] += file_info['size_mb'] + cleanup_results['details'].append({ + 'action': 'removed', + 'file': file_info['path'], + 'size_mb': file_info['size_mb'] + }) + + logger.info(f"Cleanup {'simulation' if dry_run else 'complete'}:") + logger.info(f" Kept: {cleanup_results['kept']}") + logger.info(f" Removed: {cleanup_results['removed']}") + logger.info(f" Space saved: {cleanup_results['space_saved_mb']:.2f} MB") + + return cleanup_results + +def main(): + logger.info("=== Checkpoint Cleanup Tool ===") + + cleanup = CheckpointCleanup() + + # Analyze existing checkpoints + logger.info("\\n1. Analyzing existing checkpoints...") + analysis = cleanup.analyze_existing_checkpoints() + + if analysis['total_files'] == 0: + logger.info("No checkpoint files found.") + return + + # Show potential space savings + total_duplicates = sum(len(group['files']) - 5 for group in analysis['potential_duplicates'] if len(group['files']) > 5) + if total_duplicates > 0: + logger.info(f"\\nFound {total_duplicates} files that could be cleaned up") + + # Dry run first + logger.info("\\n2. Simulating cleanup...") + dry_run_results = cleanup.cleanup_duplicates(dry_run=True) + + if dry_run_results['removed'] > 0: + proceed = input(f"\\nProceed with cleanup? Will remove {dry_run_results['removed']} files " + f"and save {dry_run_results['space_saved_mb']:.2f} MB. (y/n): ").lower().strip() == 'y' + + if proceed: + logger.info("\\n3. 
Performing actual cleanup...") + cleanup_results = cleanup.cleanup_duplicates(dry_run=False) + logger.info("\\n=== Cleanup Complete ===") + else: + logger.info("Cleanup cancelled.") + else: + logger.info("No files to remove.") + else: + logger.info("No duplicate files found that need cleanup.") + +if __name__ == "__main__": + main() diff --git a/example_checkpoint_usage.py b/example_checkpoint_usage.py new file mode 100644 index 0000000..b54fe38 --- /dev/null +++ b/example_checkpoint_usage.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +""" +Example: Using the Checkpoint Management System +""" + +import logging +import torch +import torch.nn as nn +import numpy as np +from datetime import datetime + +from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint, get_checkpoint_manager +from utils.training_integration import get_training_integration + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class ExampleCNN(nn.Module): + def __init__(self, input_channels=5, num_classes=3): + super().__init__() + self.conv1 = nn.Conv2d(input_channels, 32, 3, padding=1) + self.conv2 = nn.Conv2d(32, 64, 3, padding=1) + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(64, num_classes) + + def forward(self, x): + x = torch.relu(self.conv1(x)) + x = torch.relu(self.conv2(x)) + x = self.pool(x) + x = x.view(x.size(0), -1) + return self.fc(x) + +def example_cnn_training(): + logger.info("=== CNN Training Example ===") + + model = ExampleCNN() + training_integration = get_training_integration() + + for epoch in range(5): # Simulate 5 epochs + # Simulate training metrics + train_loss = 2.0 - (epoch * 0.15) + np.random.normal(0, 0.1) + train_acc = 0.3 + (epoch * 0.06) + np.random.normal(0, 0.02) + val_loss = train_loss + np.random.normal(0, 0.05) + val_acc = train_acc - 0.05 + np.random.normal(0, 0.02) + + # Clamp values to realistic ranges + train_acc = max(0.0, min(1.0, train_acc)) + val_acc = max(0.0, min(1.0, val_acc)) + train_loss = max(0.1, train_loss) + val_loss = max(0.1, val_loss) + + logger.info(f"Epoch {epoch+1}: train_acc={train_acc:.3f}, val_acc={val_acc:.3f}") + + # Save checkpoint + saved = training_integration.save_cnn_checkpoint( + cnn_model=model, + model_name="example_cnn", + epoch=epoch + 1, + train_accuracy=train_acc, + val_accuracy=val_acc, + train_loss=train_loss, + val_loss=val_loss, + training_time_hours=0.1 * (epoch + 1) + ) + + if saved: + logger.info(f" Checkpoint saved for epoch {epoch+1}") + else: + logger.info(f" Checkpoint not saved (performance not improved)") + + # Load the best checkpoint + logger.info("\\nLoading best checkpoint...") + best_result = load_best_checkpoint("example_cnn") + if best_result: + file_path, metadata = best_result + logger.info(f"Best checkpoint: {metadata.checkpoint_id}") + logger.info(f"Performance score: {metadata.performance_score:.4f}") + +def example_manual_checkpoint(): + logger.info("\\n=== Manual Checkpoint Example ===") + + model = nn.Linear(10, 3) + + performance_metrics = { + 'accuracy': 0.85, + 'val_accuracy': 0.82, + 'loss': 0.45, + 'val_loss': 0.48 + } + + training_metadata = { + 'epoch': 25, + 'training_time_hours': 2.5, + 'total_parameters': sum(p.numel() for p in model.parameters()) + } + + logger.info("Saving checkpoint manually...") + metadata = save_checkpoint( + model=model, + model_name="example_manual", + model_type="cnn", + performance_metrics=performance_metrics, + training_metadata=training_metadata, + force_save=True + ) + + if metadata: + logger.info(f" Manual 
checkpoint saved: {metadata.checkpoint_id}") + logger.info(f" Performance score: {metadata.performance_score:.4f}") + +def show_checkpoint_stats(): + logger.info("\\n=== Checkpoint Statistics ===") + + checkpoint_manager = get_checkpoint_manager() + stats = checkpoint_manager.get_checkpoint_stats() + + logger.info(f"Total models: {stats['total_models']}") + logger.info(f"Total checkpoints: {stats['total_checkpoints']}") + logger.info(f"Total size: {stats['total_size_mb']:.2f} MB") + + for model_name, model_stats in stats['models'].items(): + logger.info(f"\\n{model_name}:") + logger.info(f" Checkpoints: {model_stats['checkpoint_count']}") + logger.info(f" Size: {model_stats['total_size_mb']:.2f} MB") + logger.info(f" Best performance: {model_stats['best_performance']:.4f}") + +def main(): + logger.info(" Checkpoint Management System Examples") + logger.info("=" * 50) + + try: + example_cnn_training() + example_manual_checkpoint() + show_checkpoint_stats() + + logger.info("\\n All examples completed successfully!") + logger.info("\\nTo use in your training:") + logger.info("1. Import: from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint") + logger.info("2. Or use: from utils.training_integration import get_training_integration") + logger.info("3. Save checkpoints during training with performance metrics") + logger.info("4. Load best checkpoints for inference or continued training") + + except Exception as e: + logger.error(f"Error in examples: {e}") + raise + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index b817be7..b410f71 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,5 @@ torchaudio>=2.0.0 scikit-learn>=1.3.0 matplotlib>=3.7.0 seaborn>=0.12.0 -asyncio-compat>=0.1.2 \ No newline at end of file +asyncio-compat>=0.1.2 +wandb>=0.16.0 \ No newline at end of file diff --git a/utils/checkpoint_manager.py b/utils/checkpoint_manager.py new file mode 100644 index 0000000..499d572 --- /dev/null +++ b/utils/checkpoint_manager.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 +""" +Checkpoint Management System for W&B Training +""" + +import os +import json +import logging +from datetime import datetime, timedelta +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Any +from dataclasses import dataclass, asdict +from collections import defaultdict +import torch + +try: + import wandb + WANDB_AVAILABLE = True +except ImportError: + WANDB_AVAILABLE = False + +logger = logging.getLogger(__name__) + +@dataclass +class CheckpointMetadata: + checkpoint_id: str + model_name: str + model_type: str + file_path: str + created_at: datetime + file_size_mb: float + performance_score: float + accuracy: Optional[float] = None + loss: Optional[float] = None + val_accuracy: Optional[float] = None + val_loss: Optional[float] = None + reward: Optional[float] = None + pnl: Optional[float] = None + epoch: Optional[int] = None + training_time_hours: Optional[float] = None + total_parameters: Optional[int] = None + wandb_run_id: Optional[str] = None + wandb_artifact_name: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + data = asdict(self) + data['created_at'] = self.created_at.isoformat() + return data + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata': + data['created_at'] = datetime.fromisoformat(data['created_at']) + return cls(**data) + +class CheckpointManager: + def __init__(self, + base_checkpoint_dir: str = "NN/models/saved", + max_checkpoints_per_model: int = 5, + 
metadata_file: str = "checkpoint_metadata.json", + enable_wandb: bool = True): + self.base_dir = Path(base_checkpoint_dir) + self.base_dir.mkdir(parents=True, exist_ok=True) + + self.max_checkpoints = max_checkpoints_per_model + self.metadata_file = self.base_dir / metadata_file + self.enable_wandb = enable_wandb and WANDB_AVAILABLE + + self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list) + self._load_metadata() + + logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}") + + def save_checkpoint(self, model, model_name: str, model_type: str, + performance_metrics: Dict[str, float], + training_metadata: Optional[Dict[str, Any]] = None, + force_save: bool = False) -> Optional[CheckpointMetadata]: + try: + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + checkpoint_id = f"{model_name}_{timestamp}" + + model_dir = self.base_dir / model_name + model_dir.mkdir(exist_ok=True) + + checkpoint_path = model_dir / f"{checkpoint_id}.pt" + + performance_score = self._calculate_performance_score(performance_metrics) + + if not force_save and not self._should_save_checkpoint(model_name, performance_score): + logger.info(f"Skipping checkpoint save for {model_name} - performance not improved") + return None + + success = self._save_model_file(model, checkpoint_path, model_type) + if not success: + return None + + file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024) + + metadata = CheckpointMetadata( + checkpoint_id=checkpoint_id, + model_name=model_name, + model_type=model_type, + file_path=str(checkpoint_path), + created_at=datetime.now(), + file_size_mb=file_size_mb, + performance_score=performance_score, + accuracy=performance_metrics.get('accuracy'), + loss=performance_metrics.get('loss'), + val_accuracy=performance_metrics.get('val_accuracy'), + val_loss=performance_metrics.get('val_loss'), + reward=performance_metrics.get('reward'), + pnl=performance_metrics.get('pnl'), + epoch=training_metadata.get('epoch') if training_metadata else None, + training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None, + total_parameters=training_metadata.get('total_parameters') if training_metadata else None + ) + + if self.enable_wandb and wandb.run is not None: + artifact_name = self._upload_to_wandb(checkpoint_path, metadata) + metadata.wandb_run_id = wandb.run.id + metadata.wandb_artifact_name = artifact_name + + self.checkpoints[model_name].append(metadata) + self._rotate_checkpoints(model_name) + self._save_metadata() + + logger.info(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})") + return metadata + + except Exception as e: + logger.error(f"Error saving checkpoint for {model_name}: {e}") + return None + + def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]: + try: + if model_name not in self.checkpoints or not self.checkpoints[model_name]: + logger.warning(f"No checkpoints found for model: {model_name}") + return None + + best_checkpoint = max(self.checkpoints[model_name], key=lambda x: x.performance_score) + + if not Path(best_checkpoint.file_path).exists(): + logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}") + return None + + logger.info(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}") + return best_checkpoint.file_path, best_checkpoint + + except Exception as e: + logger.error(f"Error loading best checkpoint for {model_name}: {e}") + return None + + def 
_calculate_performance_score(self, metrics: Dict[str, float]) -> float: + score = 0.0 + + if 'accuracy' in metrics: + score += metrics['accuracy'] * 100 + if 'val_accuracy' in metrics: + score += metrics['val_accuracy'] * 100 + if 'loss' in metrics: + score += max(0, 10 - metrics['loss']) + if 'val_loss' in metrics: + score += max(0, 10 - metrics['val_loss']) + if 'reward' in metrics: + score += metrics['reward'] + if 'pnl' in metrics: + score += metrics['pnl'] + + if score == 0.0 and metrics: + first_metric = next(iter(metrics.values())) + score = first_metric if first_metric > 0 else 0.1 + + return max(score, 0.1) + + def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool: + if model_name not in self.checkpoints or not self.checkpoints[model_name]: + return True + + if len(self.checkpoints[model_name]) < self.max_checkpoints: + return True + + worst_score = min(cp.performance_score for cp in self.checkpoints[model_name]) + return performance_score > worst_score + + def _save_model_file(self, model, file_path: Path, model_type: str) -> bool: + try: + if hasattr(model, 'state_dict'): + torch.save({ + 'model_state_dict': model.state_dict(), + 'model_type': model_type, + 'saved_at': datetime.now().isoformat() + }, file_path) + else: + torch.save(model, file_path) + return True + except Exception as e: + logger.error(f"Error saving model file {file_path}: {e}") + return False + + def _rotate_checkpoints(self, model_name: str): + checkpoint_list = self.checkpoints[model_name] + + if len(checkpoint_list) <= self.max_checkpoints: + return + + checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True) + + to_remove = checkpoint_list[self.max_checkpoints:] + self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints] + + for checkpoint in to_remove: + try: + file_path = Path(checkpoint.file_path) + if file_path.exists(): + file_path.unlink() + logger.info(f"Rotated out checkpoint: {checkpoint.checkpoint_id}") + except Exception as e: + logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}") + + def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]: + try: + if not self.enable_wandb or wandb.run is None: + return None + + artifact_name = f"{metadata.model_name}_checkpoint" + artifact = wandb.Artifact(artifact_name, type="model") + artifact.add_file(str(file_path)) + wandb.log_artifact(artifact) + + return artifact_name + except Exception as e: + logger.error(f"Error uploading to W&B: {e}") + return None + + def _load_metadata(self): + try: + if self.metadata_file.exists(): + with open(self.metadata_file, 'r') as f: + data = json.load(f) + + for model_name, checkpoint_list in data.items(): + self.checkpoints[model_name] = [ + CheckpointMetadata.from_dict(cp_data) + for cp_data in checkpoint_list + ] + + logger.info(f"Loaded metadata for {len(self.checkpoints)} models") + except Exception as e: + logger.error(f"Error loading checkpoint metadata: {e}") + + def _save_metadata(self): + try: + data = {} + for model_name, checkpoint_list in self.checkpoints.items(): + data[model_name] = [cp.to_dict() for cp in checkpoint_list] + + with open(self.metadata_file, 'w') as f: + json.dump(data, f, indent=2) + except Exception as e: + logger.error(f"Error saving checkpoint metadata: {e}") + + def get_checkpoint_stats(self): + """Get statistics about managed checkpoints""" + stats = { + 'total_models': len(self.checkpoints), + 'total_checkpoints': sum(len(checkpoints) for checkpoints in 
self.checkpoints.values()), + 'total_size_mb': 0.0, + 'models': {} + } + + for model_name, checkpoint_list in self.checkpoints.items(): + if not checkpoint_list: + continue + + model_size = sum(cp.file_size_mb for cp in checkpoint_list) + best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score) + + stats['models'][model_name] = { + 'checkpoint_count': len(checkpoint_list), + 'total_size_mb': model_size, + 'best_performance': best_checkpoint.performance_score, + 'best_checkpoint_id': best_checkpoint.checkpoint_id, + 'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id + } + + stats['total_size_mb'] += model_size + + return stats + +_checkpoint_manager = None + +def get_checkpoint_manager() -> CheckpointManager: + global _checkpoint_manager + if _checkpoint_manager is None: + _checkpoint_manager = CheckpointManager() + return _checkpoint_manager + +def save_checkpoint(model, model_name: str, model_type: str, + performance_metrics: Dict[str, float], + training_metadata: Optional[Dict[str, Any]] = None, + force_save: bool = False) -> Optional[CheckpointMetadata]: + return get_checkpoint_manager().save_checkpoint( + model, model_name, model_type, performance_metrics, training_metadata, force_save + ) + +def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]: + return get_checkpoint_manager().load_best_checkpoint(model_name) diff --git a/utils/training_integration.py b/utils/training_integration.py new file mode 100644 index 0000000..0353a84 --- /dev/null +++ b/utils/training_integration.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +""" +Training Integration for Checkpoint Management +""" + +import logging +import torch +from datetime import datetime +from typing import Dict, Any, Optional +from pathlib import Path + +from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint + +logger = logging.getLogger(__name__) + +class TrainingIntegration: + def __init__(self, enable_wandb: bool = True): + self.checkpoint_manager = get_checkpoint_manager() + self.enable_wandb = enable_wandb + + if self.enable_wandb: + self._init_wandb() + + def _init_wandb(self): + try: + import wandb + + if wandb.run is None: + wandb.init( + project="gogo2-trading", + name=f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}", + config={ + "max_checkpoints_per_model": self.checkpoint_manager.max_checkpoints, + "checkpoint_dir": str(self.checkpoint_manager.base_dir) + } + ) + logger.info(f"Initialized W&B run: {wandb.run.id}") + + except ImportError: + logger.warning("W&B not available - checkpoint management will work without it") + except Exception as e: + logger.error(f"Error initializing W&B: {e}") + + def save_cnn_checkpoint(self, + cnn_model, + model_name: str, + epoch: int, + train_accuracy: float, + val_accuracy: float, + train_loss: float, + val_loss: float, + training_time_hours: float = None) -> bool: + try: + performance_metrics = { + 'accuracy': train_accuracy, + 'val_accuracy': val_accuracy, + 'loss': train_loss, + 'val_loss': val_loss + } + + training_metadata = { + 'epoch': epoch, + 'training_time_hours': training_time_hours, + 'total_parameters': self._count_parameters(cnn_model) + } + + if self.enable_wandb: + try: + import wandb + if wandb.run is not None: + wandb.log({ + f"{model_name}/train_accuracy": train_accuracy, + f"{model_name}/val_accuracy": val_accuracy, + f"{model_name}/train_loss": train_loss, + f"{model_name}/val_loss": val_loss, + f"{model_name}/epoch": epoch + }) + except 
Exception as e: + logger.warning(f"Error logging to W&B: {e}") + + metadata = save_checkpoint( + model=cnn_model, + model_name=model_name, + model_type='cnn', + performance_metrics=performance_metrics, + training_metadata=training_metadata + ) + + if metadata: + logger.info(f"CNN checkpoint saved: {metadata.checkpoint_id}") + return True + else: + logger.info(f"CNN checkpoint not saved (performance not improved)") + return False + + except Exception as e: + logger.error(f"Error saving CNN checkpoint: {e}") + return False + + def save_rl_checkpoint(self, + rl_agent, + model_name: str, + episode: int, + avg_reward: float, + best_reward: float, + epsilon: float, + total_pnl: float = None) -> bool: + try: + performance_metrics = { + 'reward': avg_reward, + 'best_reward': best_reward + } + + if total_pnl is not None: + performance_metrics['pnl'] = total_pnl + + training_metadata = { + 'episode': episode, + 'epsilon': epsilon, + 'total_parameters': self._count_parameters(rl_agent) + } + + if self.enable_wandb: + try: + import wandb + if wandb.run is not None: + wandb.log({ + f"{model_name}/avg_reward": avg_reward, + f"{model_name}/best_reward": best_reward, + f"{model_name}/epsilon": epsilon, + f"{model_name}/episode": episode + }) + + if total_pnl is not None: + wandb.log({f"{model_name}/total_pnl": total_pnl}) + + except Exception as e: + logger.warning(f"Error logging to W&B: {e}") + + metadata = save_checkpoint( + model=rl_agent, + model_name=model_name, + model_type='rl', + performance_metrics=performance_metrics, + training_metadata=training_metadata + ) + + if metadata: + logger.info(f"RL checkpoint saved: {metadata.checkpoint_id}") + return True + else: + logger.info(f"RL checkpoint not saved (performance not improved)") + return False + + except Exception as e: + logger.error(f"Error saving RL checkpoint: {e}") + return False + + def load_best_model(self, model_name: str, model_class=None): + try: + result = load_best_checkpoint(model_name) + if not result: + logger.warning(f"No checkpoint found for model: {model_name}") + return None + + file_path, metadata = result + + checkpoint = torch.load(file_path, map_location='cpu') + + logger.info(f"Loaded best checkpoint for {model_name}:") + logger.info(f" Performance score: {metadata.performance_score:.4f}") + logger.info(f" Created: {metadata.created_at}") + + if model_class and 'model_state_dict' in checkpoint: + model = model_class() + model.load_state_dict(checkpoint['model_state_dict']) + return model + + return checkpoint + + except Exception as e: + logger.error(f"Error loading best model {model_name}: {e}") + return None + + def _count_parameters(self, model) -> int: + try: + if hasattr(model, 'parameters'): + return sum(p.numel() for p in model.parameters()) + elif hasattr(model, 'policy_net'): + policy_params = sum(p.numel() for p in model.policy_net.parameters()) + target_params = sum(p.numel() for p in model.target_net.parameters()) if hasattr(model, 'target_net') else 0 + return policy_params + target_params + else: + return 0 + except Exception: + return 0 + +_training_integration = None + +def get_training_integration() -> TrainingIntegration: + global _training_integration + if _training_integration is None: + _training_integration = TrainingIntegration() + return _training_integration
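
Usage note (not part of the diff): `example_checkpoint_usage.py` exercises only the CNN and manual-save paths, so the RL path of the new `TrainingIntegration` (`save_rl_checkpoint` in `utils/training_integration.py`) is left undemonstrated. Below is a minimal, hypothetical sketch of how an RL training loop could call it; the `DummyAgent` class, the `example_dqn` model name, and the simulated reward/epsilon values are placeholders, not code from this PR.

```python
# Hypothetical sketch: wiring an RL training loop into the checkpoint
# integration added by this PR. DummyAgent and the simulated numbers are
# stand-ins for a real DQN agent and its episode statistics.
import numpy as np
import torch.nn as nn

from utils.training_integration import get_training_integration

class DummyAgent(nn.Module):
    """Placeholder agent; only needs parameters() and state_dict() to be checkpointed."""
    def __init__(self):
        super().__init__()
        self.policy_net = nn.Linear(8, 4)

def run_dummy_rl_training(episodes: int = 5) -> None:
    agent = DummyAgent()
    integration = get_training_integration()
    best_reward = float("-inf")

    for episode in range(1, episodes + 1):
        # Simulated episode statistics; a real loop would compute these.
        avg_reward = float(np.random.normal(loc=episode, scale=0.5))
        best_reward = max(best_reward, avg_reward)
        epsilon = max(0.05, 1.0 - 0.1 * episode)

        # Internally this turns reward/pnl into a performance score and only
        # keeps the checkpoint if it beats the worst of the 5 retained ones.
        integration.save_rl_checkpoint(
            rl_agent=agent,
            model_name="example_dqn",
            episode=episode,
            avg_reward=avg_reward,
            best_reward=best_reward,
            epsilon=epsilon,
            total_pnl=None,
        )

if __name__ == "__main__":
    run_dummy_rl_training()
```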