Checkpoint Strategy

Current System

What Exists

The system already has a sophisticated checkpoint manager in utils/checkpoint_manager.py:

  1. Automatic Saving: Checkpoints saved with metadata
  2. Performance Tracking: Tracks metrics (loss, accuracy, reward)
  3. Best Checkpoint Selection: Loads best performing checkpoint
  4. Automatic Cleanup: Keeps only top N checkpoints
  5. Database Integration: Metadata stored in database for fast access

How It Works

# Checkpoint Manager Configuration
max_checkpoints = 10          # Keep top 10 checkpoints
metric_name = "accuracy"      # Rank by accuracy (or loss, reward)
checkpoint_dir = "models/checkpoints"

Checkpoint Saving Logic

When Checkpoints Are Saved

Current Behavior: Checkpoints are saved at fixed intervals, not based on performance improvement.

# Example from DQN Agent
def save_checkpoint(self, episode_reward: float, force_save: bool = False):
    """Save checkpoint if performance improved or forced"""
    
    # Save every N episodes (e.g., every 100 episodes)
    if self.episode_count % 100 == 0 or force_save:
        save_checkpoint(
            model=self.policy_net,
            model_name=self.model_name,
            model_type="dqn",
            performance_metrics={
                'loss': self.current_loss,
                'reward': episode_reward,
                'accuracy': self.accuracy
            }
        )

Cleanup Logic

After saving, the system automatically cleans up:

def _cleanup_checkpoints(self, model_name: str):
    """
    Keep only the best N checkpoints
    
    Process:
    1. Load all checkpoint metadata
    2. Sort by metric (accuracy/loss/reward)
    3. Keep top N (default: 10)
    4. Delete the rest
    """
    
    # Sort by metric (highest first for accuracy/reward, lowest first for loss)
    checkpoints.sort(
        key=lambda x: x['metrics'][metric_name],
        reverse=(metric_name != 'loss')
    )
    
    # Keep only top N
    checkpoints_to_keep = checkpoints[:max_checkpoints]
    checkpoints_to_delete = checkpoints[max_checkpoints:]
    
    # Delete checkpoints that fell out of the top N (each metadata entry records its file path)
    for checkpoint in checkpoints_to_delete:
        os.remove(checkpoint['path'])

Option 1: Save Every Batch, Keep Best (Current + Enhancement)

Pros:

  • Never miss a good checkpoint
  • Automatic cleanup keeps disk usage low
  • Simple to implement

Cons:

  • High I/O overhead (saving every batch)
  • Slower training (disk writes)

Implementation:

def train_step(self, batch):
    # Train
    result = trainer.train_step(batch)
    
    # Save checkpoint after EVERY batch
    save_checkpoint(
        model=self.model,
        model_name="transformer",
        model_type="transformer",
        performance_metrics={
            'loss': result['total_loss'],
            'accuracy': result['accuracy']
        }
    )
    # Cleanup automatically keeps only best 10

Disk Usage: ~10 checkpoints × 200MB = 2GB (manageable)


Option 2: Save Only When Performance Improves

Pros:

  • Minimal I/O overhead
  • Only saves improvements
  • Faster training

Cons:

  • Need to track best performance
  • Slightly more complex

Implementation:

class TrainingSession:
    def __init__(self):
        self.best_loss = float('inf')
        self.best_accuracy = 0.0
        self.checkpoints_saved = 0

    def train_step(self, batch):
        # Train
        result = trainer.train_step(batch)

        # Check if performance improved
        current_loss = result['total_loss']
        current_accuracy = result['accuracy']

        # Save if better (lower loss OR higher accuracy)
        if current_loss < self.best_loss or current_accuracy > self.best_accuracy:
            logger.info(f"Performance improved! Loss: {current_loss:.4f} (best: {self.best_loss:.4f}), "
                        f"Accuracy: {current_accuracy:.2%} (best: {self.best_accuracy:.2%})")

            save_checkpoint(
                model=self.model,
                model_name="transformer",
                model_type="transformer",
                performance_metrics={
                    'loss': current_loss,
                    'accuracy': current_accuracy
                }
            )

            # Update best metrics
            self.best_loss = min(self.best_loss, current_loss)
            self.best_accuracy = max(self.best_accuracy, current_accuracy)
            self.checkpoints_saved += 1

Option 3: Hybrid Approach (Best of Both)

Strategy:

  • Save if performance improved (Option 2)
  • Also save every N batches as backup (Option 1)
  • Keep best 10 checkpoints

Implementation:

def train_step(self, batch, batch_num):
    result = trainer.train_step(batch)
    
    current_loss = result['total_loss']
    current_accuracy = result['accuracy']
    
    # Condition 1: Performance improved
    performance_improved = (
        current_loss < self.best_loss or 
        current_accuracy > self.best_accuracy
    )
    
    # Condition 2: Regular interval (every 100 batches)
    regular_interval = (batch_num % 100 == 0)
    
    # Save if either condition is met
    if performance_improved or regular_interval:
        reason = "improved" if performance_improved else "interval"
        logger.info(f"Saving checkpoint ({reason}): loss={current_loss:.4f}, acc={current_accuracy:.2%}")
        
        save_checkpoint(
            model=self.model,
            model_name="transformer",
            model_type="transformer",
            performance_metrics={
                'loss': current_loss,
                'accuracy': current_accuracy
            },
            training_metadata={
                'batch_num': batch_num,
                'reason': reason,
                'epoch': self.current_epoch
            }
        )
        
        # Update best metrics
        if performance_improved:
            self.best_loss = min(self.best_loss, current_loss)
            self.best_accuracy = max(self.best_accuracy, current_accuracy)

Implementation for ANNOTATE Training

Current Code Location

In ANNOTATE/core/real_training_adapter.py, the training loop is:

def _train_transformer_real(self, session, training_data):
    # ... setup ...
    
    for epoch in range(session.total_epochs):
        for i, batch in enumerate(converted_batches):
            result = trainer.train_step(batch)
            
            # ← ADD CHECKPOINT LOGIC HERE

Proposed Enhancement

The enhanced version applies the hybrid strategy (Option 3) inside that loop:

def _train_transformer_real(self, session, training_data):
    # Initialize best metrics
    best_loss = float('inf')
    best_accuracy = 0.0
    checkpoints_saved = 0
    
    for epoch in range(session.total_epochs):
        for i, batch in enumerate(converted_batches):
            result = trainer.train_step(batch)
            
            if result is not None:
                current_loss = result.get('total_loss', float('inf'))
                current_accuracy = result.get('accuracy', 0.0)
                
                # Check if performance improved
                performance_improved = (
                    current_loss < best_loss or 
                    current_accuracy > best_accuracy
                )
                
                # Save every 100 batches OR if improved
                should_save = performance_improved or (i % 100 == 0 and i > 0)
                
                if should_save:
                    try:
                        # Save checkpoint
                        from utils.checkpoint_manager import save_checkpoint
                        
                        checkpoint_metadata = save_checkpoint(
                            model=self.orchestrator.primary_transformer,
                            model_name="transformer",
                            model_type="transformer",
                            performance_metrics={
                                'loss': current_loss,
                                'accuracy': current_accuracy,
                                'action_loss': result.get('action_loss', 0.0),
                                'price_loss': result.get('price_loss', 0.0)
                            },
                            training_metadata={
                                'epoch': epoch + 1,
                                'batch': i + 1,
                                'total_batches': len(converted_batches),
                                'training_session': session.training_id,
                                'reason': 'improved' if performance_improved else 'interval'
                            }
                        )
                        
                        if checkpoint_metadata:
                            checkpoints_saved += 1
                            reason = "improved" if performance_improved else "interval"
                            logger.info(f"   Checkpoint saved ({reason}): {checkpoint_metadata.checkpoint_id}")
                            logger.info(f"      Loss: {current_loss:.4f}, Accuracy: {current_accuracy:.2%}")
                        
                        # Update best metrics
                        if performance_improved:
                            best_loss = min(best_loss, current_loss)
                            best_accuracy = max(best_accuracy, current_accuracy)
                            logger.info(f"   New best! Loss: {best_loss:.4f}, Accuracy: {best_accuracy:.2%}")
                    
                    except Exception as e:
                        logger.error(f"   Error saving checkpoint: {e}")
    
    logger.info(f"   Training complete: {checkpoints_saved} checkpoints saved")
    logger.info(f"   Best loss: {best_loss:.4f}, Best accuracy: {best_accuracy:.2%}")

Configuration

Checkpoint Settings

# In orchestrator initialization
checkpoint_manager = get_checkpoint_manager(
    checkpoint_dir="models/checkpoints",
    max_checkpoints=10,        # Keep top 10 checkpoints
    metric_name="accuracy"     # Rank by accuracy (or "loss")
)

Tuning Parameters

Parameter               Conservative   Balanced      Aggressive
max_checkpoints         20             10            5
save_interval           50 batches     100 batches   200 batches
improvement_threshold   0.1%           0.5%          1.0%

Conservative: Save more often, keep more checkpoints (safer, more disk)
Balanced: Default settings (recommended)
Aggressive: Save less often, keep fewer checkpoints (faster, less disk)
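
The improvement_threshold row above is not enforced by the code sketches earlier in this document. A minimal, hypothetical helper showing how such a threshold could gate the "performance improved" check (threshold expressed as a fraction, e.g. 0.005 for the Balanced 0.5%):

# Hypothetical helper: only count a change as an improvement if it beats the
# previous best by at least `threshold`, so tiny fluctuations don't trigger
# a save on every batch.
def is_improvement(current_loss: float, best_loss: float,
                   current_accuracy: float, best_accuracy: float,
                   threshold: float = 0.005) -> bool:
    loss_improved = current_loss < best_loss * (1.0 - threshold)
    accuracy_improved = current_accuracy > best_accuracy * (1.0 + threshold)
    return loss_improved or accuracy_improved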


Disk Usage

Per Checkpoint

Model                      Size     Notes
Transformer (46M params)   ~200MB   Full model + optimizer state
CNN                        ~50MB    Smaller model
DQN                        ~100MB   Medium model

Total Storage

10 checkpoints × 200MB = 2GB per model
3 models × 2GB = 6GB total

With metadata and backups: ~8GB

Recommendation: Keep 10 checkpoints (2GB per model is manageable)
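
To check the real footprint against these estimates, a small standard-library helper over the checkpoint directory is enough (directory name taken from the configuration above):

import os

def checkpoint_disk_usage_gb(checkpoint_dir: str = "models/checkpoints") -> float:
    """Sum the size of all files under the checkpoint directory, in GB."""
    total_bytes = 0
    for root, _dirs, files in os.walk(checkpoint_dir):
        for name in files:
            total_bytes += os.path.getsize(os.path.join(root, name))
    return total_bytes / (1024 ** 3)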


Monitoring

Checkpoint Logs

INFO - Checkpoint saved (improved): transformer_20251031_142530
INFO -    Loss: 0.234, Accuracy: 78.5%
INFO - New best! Loss: 0.234, Accuracy: 78.5%

INFO - Checkpoint saved (interval): transformer_20251031_142630
INFO -    Loss: 0.245, Accuracy: 77.2%

INFO - Deleted 1 old checkpoints for transformer

Dashboard Metrics

Checkpoints Saved: 15
Best Loss: 0.234
Best Accuracy: 78.5%
Disk Usage: 1.8GB / 2.0GB
Last Checkpoint: 2 minutes ago
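
A sketch of how these values could be assembled from the counters tracked in the training loop (checkpoints_saved, best_loss, best_accuracy) plus the checkpoint directory; the field names are illustrative, not an existing dashboard API:

import glob
import os
import time

def build_checkpoint_metrics(checkpoints_saved: int, best_loss: float,
                             best_accuracy: float,
                             checkpoint_dir: str = "models/checkpoints") -> dict:
    # Timestamp of the most recently written checkpoint file, if any
    files = glob.glob(os.path.join(checkpoint_dir, "*"))
    last_saved = max((os.path.getmtime(f) for f in files), default=None)
    return {
        'checkpoints_saved': checkpoints_saved,
        'best_loss': best_loss,
        'best_accuracy': best_accuracy,
        'disk_usage_gb': checkpoint_disk_usage_gb(checkpoint_dir),  # helper from the Disk Usage section
        'minutes_since_last_checkpoint': (
            (time.time() - last_saved) / 60 if last_saved else None
        ),
    }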

Summary

Current System

  • Automatic checkpoint management
  • Keeps best N checkpoints
  • Database-backed metadata
  • Saves at fixed intervals (not performance-based)

Recommended Strategy (Hybrid)

  • Save when performance improves
  • Also save every N batches as backup
  • Keep best 10 checkpoints
  • Minimal I/O overhead
  • Never miss a good checkpoint

Implementation

Add checkpoint logic to _train_transformer_real() in real_training_adapter.py to save when:

  1. Loss decreases OR accuracy increases (performance improved)
  2. Every 100 batches (regular backup)

The cleanup system automatically keeps only the best 10 checkpoints!