# Checkpoint Strategy

## Current System

### ✅ What Exists

The system already has a sophisticated checkpoint manager in `utils/checkpoint_manager.py`:

1. **Automatic Saving**: Checkpoints saved with metadata
2. **Performance Tracking**: Tracks metrics (loss, accuracy, reward)
3. **Best Checkpoint Selection**: Loads the best-performing checkpoint
4. **Automatic Cleanup**: Keeps only the top N checkpoints
5. **Database Integration**: Metadata stored in a database for fast access

### How It Works

```python
# Checkpoint Manager Configuration
max_checkpoints = 10                    # Keep top 10 checkpoints
metric_name = "accuracy"                # Rank by accuracy (or loss, reward)
checkpoint_dir = "models/checkpoints"
```

---

## Checkpoint Saving Logic

### When Checkpoints Are Saved

**Current Behavior**: Checkpoints are saved at **fixed intervals**, not based on performance improvement.

```python
# Example from DQN Agent
def save_checkpoint(self, episode_reward: float, force_save: bool = False):
    """Save checkpoint if performance improved or forced"""
    # Save every N episodes (e.g., every 100 episodes)
    if self.episode_count % 100 == 0 or force_save:
        save_checkpoint(
            model=self.policy_net,
            model_name=self.model_name,
            model_type="dqn",
            performance_metrics={
                'loss': self.current_loss,
                'reward': episode_reward,
                'accuracy': self.accuracy
            }
        )
```

### Cleanup Logic

After saving, the system automatically cleans up:

```python
def _cleanup_checkpoints(self, model_name: str):
    """
    Keep only the best N checkpoints

    Process:
    1. Load all checkpoint metadata
    2. Sort by metric (accuracy/loss/reward)
    3. Keep top N (default: 10)
    4. Delete the rest
    """
    # `checkpoints` is the list of metadata dicts loaded in step 1
    # Sort by metric: highest first for accuracy/reward, lowest first for loss
    reverse = (metric_name != "loss")
    checkpoints.sort(key=lambda x: x['metrics'][metric_name], reverse=reverse)

    # Keep only the top N
    checkpoints_to_keep = checkpoints[:max_checkpoints]
    checkpoints_to_delete = checkpoints[max_checkpoints:]

    # Delete old checkpoints
    for checkpoint in checkpoints_to_delete:
        os.remove(checkpoint['path'])  # file path from the checkpoint's metadata (key name illustrative)
```
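Cleanup and best-checkpoint selection work from the same ranking: whatever survives the top-N cut is what gets loaded at startup. A minimal sketch of that load path, assuming a `load_best_checkpoint`-style helper and a dict-shaped payload (both are assumptions; the exact names live in `utils/checkpoint_manager.py`):

```python
# Sketch only: load_best_checkpoint and the payload keys below are assumed
# names, not the confirmed API of utils/checkpoint_manager.py.
from utils.checkpoint_manager import get_checkpoint_manager

manager = get_checkpoint_manager(
    checkpoint_dir="models/checkpoints",
    max_checkpoints=10,
    metric_name="accuracy",
)

best = manager.load_best_checkpoint(model_name="transformer")  # hypothetical helper
if best is not None:
    # `model` is the torch.nn.Module being restored; the key name is illustrative
    model.load_state_dict(best['model_state_dict'])
```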
---

## Recommended Strategy

### Option 1: Save Every Batch, Keep Best (Current + Enhancement)

**Pros**:
- Never miss a good checkpoint
- Automatic cleanup keeps disk usage low
- Simple to implement

**Cons**:
- High I/O overhead (saving every batch)
- Slower training (disk writes)

**Implementation**:

```python
def train_step(self, batch):
    # Train
    result = trainer.train_step(batch)

    # Save checkpoint after EVERY batch
    save_checkpoint(
        model=self.model,
        model_name="transformer",
        model_type="transformer",
        performance_metrics={
            'loss': result['total_loss'],
            'accuracy': result['accuracy']
        }
    )
    # Cleanup automatically keeps only the best 10
```

**Disk Usage**: ~10 checkpoints × 200MB = 2GB (manageable)

---

### Option 2: Save Only If Better (Recommended)

**Pros**:
- Minimal I/O overhead
- Only saves improvements
- Faster training

**Cons**:
- Need to track best performance
- Slightly more complex

**Implementation**:

```python
class TrainingSession:
    def __init__(self):
        self.best_loss = float('inf')
        self.best_accuracy = 0.0
        self.checkpoints_saved = 0

    def train_step(self, batch):
        # Train
        result = trainer.train_step(batch)

        # Check if performance improved
        current_loss = result['total_loss']
        current_accuracy = result['accuracy']

        # Save if better (lower loss OR higher accuracy)
        if current_loss < self.best_loss or current_accuracy > self.best_accuracy:
            logger.info(
                f"Performance improved! "
                f"Loss: {current_loss:.4f} (best: {self.best_loss:.4f}), "
                f"Accuracy: {current_accuracy:.2%} (best: {self.best_accuracy:.2%})"
            )

            save_checkpoint(
                model=self.model,
                model_name="transformer",
                model_type="transformer",
                performance_metrics={
                    'loss': current_loss,
                    'accuracy': current_accuracy
                }
            )

            # Update best metrics
            self.best_loss = min(self.best_loss, current_loss)
            self.best_accuracy = max(self.best_accuracy, current_accuracy)
            self.checkpoints_saved += 1
```

---

### Option 3: Hybrid Approach (Best of Both)

**Strategy**:
- Save if performance improved (Option 2)
- Also save every N batches as a backup (Option 1)
- Keep the best 10 checkpoints

**Implementation**:

```python
def train_step(self, batch, batch_num):
    result = trainer.train_step(batch)

    current_loss = result['total_loss']
    current_accuracy = result['accuracy']

    # Condition 1: Performance improved
    performance_improved = (
        current_loss < self.best_loss or
        current_accuracy > self.best_accuracy
    )

    # Condition 2: Regular interval (every 100 batches)
    regular_interval = (batch_num % 100 == 0)

    # Save if either condition is met
    if performance_improved or regular_interval:
        reason = "improved" if performance_improved else "interval"
        logger.info(f"Saving checkpoint ({reason}): loss={current_loss:.4f}, acc={current_accuracy:.2%}")

        save_checkpoint(
            model=self.model,
            model_name="transformer",
            model_type="transformer",
            performance_metrics={
                'loss': current_loss,
                'accuracy': current_accuracy
            },
            training_metadata={
                'batch_num': batch_num,
                'reason': reason,
                'epoch': self.current_epoch
            }
        )

        # Update best metrics
        if performance_improved:
            self.best_loss = min(self.best_loss, current_loss)
            self.best_accuracy = max(self.best_accuracy, current_accuracy)
```
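The save decision in Option 3 is small enough to pull out into a pure helper, which keeps the training loop readable and makes the policy easy to unit-test. A sketch (the function name and signature are illustrative, not existing project code):

```python
def should_save_checkpoint(loss: float, accuracy: float,
                           best_loss: float, best_accuracy: float,
                           batch_num: int, interval: int = 100):
    """Return (should_save, reason) for the hybrid strategy above."""
    if loss < best_loss or accuracy > best_accuracy:
        return True, "improved"
    if batch_num > 0 and batch_num % interval == 0:
        return True, "interval"
    return False, ""

# Improvement wins regardless of the batch counter
assert should_save_checkpoint(0.20, 0.80, 0.25, 0.78, batch_num=37) == (True, "improved")
# No improvement, but the periodic backup still fires
assert should_save_checkpoint(0.30, 0.70, 0.25, 0.78, batch_num=100) == (True, "interval")
```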
---

## Implementation for ANNOTATE Training

### Current Code Location

In `ANNOTATE/core/real_training_adapter.py`, the training loop is:

```python
def _train_transformer_real(self, session, training_data):
    # ... setup ...
    for epoch in range(session.total_epochs):
        for i, batch in enumerate(converted_batches):
            result = trainer.train_step(batch)
            # ← ADD CHECKPOINT LOGIC HERE
```

### Recommended Addition

```python
def _train_transformer_real(self, session, training_data):
    # Initialize best metrics
    best_loss = float('inf')
    best_accuracy = 0.0
    checkpoints_saved = 0

    for epoch in range(session.total_epochs):
        for i, batch in enumerate(converted_batches):
            result = trainer.train_step(batch)

            if result is not None:
                current_loss = result.get('total_loss', float('inf'))
                current_accuracy = result.get('accuracy', 0.0)

                # Check if performance improved
                performance_improved = (
                    current_loss < best_loss or
                    current_accuracy > best_accuracy
                )

                # Save every 100 batches OR if improved
                should_save = performance_improved or (i % 100 == 0 and i > 0)

                if should_save:
                    try:
                        # Save checkpoint
                        from utils.checkpoint_manager import save_checkpoint
                        checkpoint_metadata = save_checkpoint(
                            model=self.orchestrator.primary_transformer,
                            model_name="transformer",
                            model_type="transformer",
                            performance_metrics={
                                'loss': current_loss,
                                'accuracy': current_accuracy,
                                'action_loss': result.get('action_loss', 0.0),
                                'price_loss': result.get('price_loss', 0.0)
                            },
                            training_metadata={
                                'epoch': epoch + 1,
                                'batch': i + 1,
                                'total_batches': len(converted_batches),
                                'training_session': session.training_id,
                                'reason': 'improved' if performance_improved else 'interval'
                            }
                        )

                        if checkpoint_metadata:
                            checkpoints_saved += 1
                            reason = "improved" if performance_improved else "interval"
                            logger.info(f"  Checkpoint saved ({reason}): {checkpoint_metadata.checkpoint_id}")
                            logger.info(f"  Loss: {current_loss:.4f}, Accuracy: {current_accuracy:.2%}")

                            # Update best metrics
                            if performance_improved:
                                best_loss = min(best_loss, current_loss)
                                best_accuracy = max(best_accuracy, current_accuracy)
                                logger.info(f"  New best! Loss: {best_loss:.4f}, Accuracy: {best_accuracy:.2%}")

                    except Exception as e:
                        logger.error(f"  Error saving checkpoint: {e}")

    logger.info(f"  Training complete: {checkpoints_saved} checkpoints saved")
    logger.info(f"  Best loss: {best_loss:.4f}, Best accuracy: {best_accuracy:.2%}")
```

---

## Configuration

### Checkpoint Settings

```python
# In orchestrator initialization
checkpoint_manager = get_checkpoint_manager(
    checkpoint_dir="models/checkpoints",
    max_checkpoints=10,       # Keep top 10 checkpoints
    metric_name="accuracy"    # Rank by accuracy (or "loss")
)
```

### Tuning Parameters

| Parameter | Conservative | Balanced | Aggressive |
|-----------|--------------|----------|------------|
| `max_checkpoints` | 20 | 10 | 5 |
| `save_interval` | 50 batches | 100 batches | 200 batches |
| `improvement_threshold` | 0.1% | 0.5% | 1.0% |

- **Conservative**: Save more often, keep more checkpoints (safer, more disk)
- **Balanced**: Default settings (recommended)
- **Aggressive**: Save less often, keep fewer checkpoints (faster, less disk)
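Note that `improvement_threshold` from the table is not wired into the earlier snippets. One way to use it, treating the percentage as a required relative drop in loss so that tiny, noisy improvements do not trigger a save (both the interpretation and the helper name are assumptions):

```python
def loss_improved(current_loss: float, best_loss: float,
                  improvement_threshold: float = 0.005) -> bool:
    """True if loss dropped by at least `improvement_threshold` (relative).

    0.005 corresponds to the 'Balanced' 0.5% column above.
    """
    if best_loss == float('inf'):      # first measurement always counts
        return True
    return current_loss <= best_loss * (1.0 - improvement_threshold)

print(loss_improved(0.2325, 0.234))   # True:  ~0.6% better than the best so far
print(loss_improved(0.2339, 0.234))   # False: within the noise band, skip the save
```

The same check could replace the plain `current_loss < best_loss` comparison in Options 2 and 3.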
---

## Disk Usage

### Per Checkpoint

| Model | Size | Notes |
|-------|------|-------|
| Transformer (46M params) | ~200MB | Full model + optimizer state |
| CNN | ~50MB | Smaller model |
| DQN | ~100MB | Medium model |

### Total Storage

```
10 checkpoints × 200MB = 2GB per model
3 models × 2GB = 6GB total

With metadata and backups: ~8GB
```

**Recommendation**: Keep 10 checkpoints (2GB per model is manageable)

---

## Monitoring

### Checkpoint Logs

```
INFO - Checkpoint saved (improved): transformer_20251031_142530
INFO - Loss: 0.234, Accuracy: 78.5%
INFO - New best! Loss: 0.234, Accuracy: 78.5%
INFO - Checkpoint saved (interval): transformer_20251031_142630
INFO - Loss: 0.245, Accuracy: 77.2%
INFO - Deleted 1 old checkpoints for transformer
```

### Dashboard Metrics

```
Checkpoints Saved: 15
Best Loss: 0.234
Best Accuracy: 78.5%
Disk Usage: 1.8GB / 2.0GB
Last Checkpoint: 2 minutes ago
```

---

## Summary

### Current System
- ✅ Automatic checkpoint management
- ✅ Keeps best N checkpoints
- ✅ Database-backed metadata
- ❌ Saves at fixed intervals (not performance-based)

### Recommended Enhancement
- ✅ Save when performance improves
- ✅ Also save every N batches as a backup
- ✅ Keep the best 10 checkpoints
- ✅ Minimal I/O overhead
- ✅ Never miss a good checkpoint

### Implementation

Add checkpoint logic to `_train_transformer_real()` in `real_training_adapter.py` to save when:

1. Loss decreases OR accuracy increases (performance improved)
2. Every 100 batches (regular backup)

The cleanup system automatically keeps only the best 10 checkpoints!