Checkpoint Strategy
Current System
✅ What Exists
A checkpoint management system already exists in utils/checkpoint_manager.py:
- Automatic Saving: Checkpoints saved with metadata
- Performance Tracking: Tracks metrics (loss, accuracy, reward)
- Best Checkpoint Selection: Loads best performing checkpoint
- Automatic Cleanup: Keeps only top N checkpoints
- Database Integration: Metadata stored in database for fast access
How It Works
# Checkpoint Manager Configuration
max_checkpoints = 10 # Keep top 10 checkpoints
metric_name = "accuracy" # Rank by accuracy (or loss, reward)
checkpoint_dir = "models/checkpoints"
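For completeness, the pattern below shows how the best checkpoint could be restored at startup. This is a minimal sketch: get_checkpoint_manager matches the configuration shown above, but the load_best_checkpoint helper name and its return shape are assumptions about the manager's API, so check utils/checkpoint_manager.py for the actual method.
# Minimal sketch of restoring the best checkpoint at startup.
# NOTE: load_best_checkpoint() and its (state_dict, metadata) return shape are
# assumptions; verify against the real API in utils/checkpoint_manager.py.
import logging
from utils.checkpoint_manager import get_checkpoint_manager

logger = logging.getLogger(__name__)

checkpoint_manager = get_checkpoint_manager(
    checkpoint_dir="models/checkpoints",
    max_checkpoints=10,
    metric_name="accuracy",
)

restored = checkpoint_manager.load_best_checkpoint("transformer")
if restored is not None:
    state_dict, metadata = restored
    model.load_state_dict(state_dict)  # `model` is your transformer instance
    logger.info("Restored best checkpoint: %s", metadata.get("checkpoint_id"))
else:
    logger.info("No checkpoint found for 'transformer'; training from scratch")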
Checkpoint Saving Logic
When Checkpoints Are Saved
Current Behavior: Checkpoints are saved at fixed intervals, not based on performance improvement.
# Example from DQN Agent
def save_checkpoint(self, episode_reward: float, force_save: bool = False):
"""Save checkpoint if performance improved or forced"""
# Save every N episodes (e.g., every 100 episodes)
if self.episode_count % 100 == 0 or force_save:
save_checkpoint(
model=self.policy_net,
model_name=self.model_name,
model_type="dqn",
performance_metrics={
'loss': self.current_loss,
'reward': episode_reward,
'accuracy': self.accuracy
}
)
Cleanup Logic
After saving, the system automatically cleans up:
def _cleanup_checkpoints(self, model_name: str):
    """
    Keep only the best N checkpoints.

    Process:
    1. Load all checkpoint metadata
    2. Sort by metric (accuracy/loss/reward)
    3. Keep top N (default: 10)
    4. Delete the rest
    """
    # Sort by metric: highest first for accuracy/reward, lowest first for loss
    higher_is_better = metric_name != 'loss'
    checkpoints.sort(key=lambda x: x['metrics'][metric_name], reverse=higher_is_better)

    # Keep only the top N
    checkpoints_to_keep = checkpoints[:max_checkpoints]
    checkpoints_to_delete = checkpoints[max_checkpoints:]

    # Delete old checkpoint files (each metadata entry is assumed to record its file path)
    for checkpoint in checkpoints_to_delete:
        os.remove(checkpoint['path'])
Recommended Strategy
Option 1: Save Every Batch, Keep Best (Current + Enhancement)
Pros:
- Never miss a good checkpoint
- Automatic cleanup keeps disk usage low
- Simple to implement
Cons:
- High I/O overhead (saving every batch)
- Slower training (disk writes)
Implementation:
def train_step(self, batch):
# Train
result = trainer.train_step(batch)
# Save checkpoint after EVERY batch
save_checkpoint(
model=self.model,
model_name="transformer",
model_type="transformer",
performance_metrics={
'loss': result['total_loss'],
'accuracy': result['accuracy']
}
)
# Cleanup automatically keeps only best 10
Disk Usage: ~10 checkpoints × 200MB = 2GB (manageable)
Option 2: Save Only If Better (Recommended)
Pros:
- Minimal I/O overhead
- Only saves improvements
- Faster training
Cons:
- Need to track best performance
- Slightly more complex
Implementation:
class TrainingSession:
def __init__(self):
self.best_loss = float('inf')
self.best_accuracy = 0.0
self.checkpoints_saved = 0
def train_step(self, batch):
# Train
result = trainer.train_step(batch)
# Check if performance improved
current_loss = result['total_loss']
current_accuracy = result['accuracy']
# Save if better (lower loss OR higher accuracy)
if current_loss < self.best_loss or current_accuracy > self.best_accuracy:
logger.info(f"Performance improved! Loss: {current_loss:.4f} (best: {self.best_loss:.4f}), "
f"Accuracy: {current_accuracy:.2%} (best: {self.best_accuracy:.2%})")
save_checkpoint(
model=self.model,
model_name="transformer",
model_type="transformer",
performance_metrics={
'loss': current_loss,
'accuracy': current_accuracy
}
)
# Update best metrics
self.best_loss = min(self.best_loss, current_loss)
self.best_accuracy = max(self.best_accuracy, current_accuracy)
self.checkpoints_saved += 1
Option 3: Hybrid Approach (Best of Both)
Strategy:
- Save if performance improved (Option 2)
- Also save every N batches as backup (Option 1)
- Keep best 10 checkpoints
Implementation:
def train_step(self, batch, batch_num):
result = trainer.train_step(batch)
current_loss = result['total_loss']
current_accuracy = result['accuracy']
# Condition 1: Performance improved
performance_improved = (
current_loss < self.best_loss or
current_accuracy > self.best_accuracy
)
# Condition 2: Regular interval (every 100 batches)
regular_interval = (batch_num % 100 == 0)
# Save if either condition is met
if performance_improved or regular_interval:
reason = "improved" if performance_improved else "interval"
logger.info(f"Saving checkpoint ({reason}): loss={current_loss:.4f}, acc={current_accuracy:.2%}")
save_checkpoint(
model=self.model,
model_name="transformer",
model_type="transformer",
performance_metrics={
'loss': current_loss,
'accuracy': current_accuracy
},
training_metadata={
'batch_num': batch_num,
'reason': reason,
'epoch': self.current_epoch
}
)
# Update best metrics
if performance_improved:
self.best_loss = min(self.best_loss, current_loss)
self.best_accuracy = max(self.best_accuracy, current_accuracy)
Implementation for ANNOTATE Training
Current Code Location
In ANNOTATE/core/real_training_adapter.py, the training loop is:
def _train_transformer_real(self, session, training_data):
# ... setup ...
for epoch in range(session.total_epochs):
for i, batch in enumerate(converted_batches):
result = trainer.train_step(batch)
# ← ADD CHECKPOINT LOGIC HERE
Recommended Addition
def _train_transformer_real(self, session, training_data):
# Initialize best metrics
best_loss = float('inf')
best_accuracy = 0.0
checkpoints_saved = 0
for epoch in range(session.total_epochs):
for i, batch in enumerate(converted_batches):
result = trainer.train_step(batch)
if result is not None:
current_loss = result.get('total_loss', float('inf'))
current_accuracy = result.get('accuracy', 0.0)
# Check if performance improved
performance_improved = (
current_loss < best_loss or
current_accuracy > best_accuracy
)
# Save every 100 batches OR if improved
should_save = performance_improved or (i % 100 == 0 and i > 0)
if should_save:
try:
# Save checkpoint
from utils.checkpoint_manager import save_checkpoint
checkpoint_metadata = save_checkpoint(
model=self.orchestrator.primary_transformer,
model_name="transformer",
model_type="transformer",
performance_metrics={
'loss': current_loss,
'accuracy': current_accuracy,
'action_loss': result.get('action_loss', 0.0),
'price_loss': result.get('price_loss', 0.0)
},
training_metadata={
'epoch': epoch + 1,
'batch': i + 1,
'total_batches': len(converted_batches),
'training_session': session.training_id,
'reason': 'improved' if performance_improved else 'interval'
}
)
if checkpoint_metadata:
checkpoints_saved += 1
reason = "improved" if performance_improved else "interval"
logger.info(f" Checkpoint saved ({reason}): {checkpoint_metadata.checkpoint_id}")
logger.info(f" Loss: {current_loss:.4f}, Accuracy: {current_accuracy:.2%}")
# Update best metrics
if performance_improved:
best_loss = min(best_loss, current_loss)
best_accuracy = max(best_accuracy, current_accuracy)
logger.info(f" New best! Loss: {best_loss:.4f}, Accuracy: {best_accuracy:.2%}")
except Exception as e:
logger.error(f" Error saving checkpoint: {e}")
logger.info(f" Training complete: {checkpoints_saved} checkpoints saved")
logger.info(f" Best loss: {best_loss:.4f}, Best accuracy: {best_accuracy:.2%}")
Configuration
Checkpoint Settings
# In orchestrator initialization
checkpoint_manager = get_checkpoint_manager(
checkpoint_dir="models/checkpoints",
max_checkpoints=10, # Keep top 10 checkpoints
metric_name="accuracy" # Rank by accuracy (or "loss")
)
Tuning Parameters
| Parameter | Conservative | Balanced | Aggressive |
|---|---|---|---|
| max_checkpoints | 20 | 10 | 5 |
| save_interval | 50 batches | 100 batches | 200 batches |
| improvement_threshold | 0.1% | 0.5% | 1.0% |
Conservative: Save more often, keep more checkpoints (safer, more disk)
Balanced: Default settings (recommended)
Aggressive: Save less often, keep fewer checkpoints (faster, less disk)
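The sketch below illustrates how these parameters could drive the save decision. The parameter names mirror the table; the wiring itself is an assumption rather than existing code (in particular, improvement_threshold is not applied anywhere in the current snippets).
# Illustrative sketch of wiring the tuning parameters into the save decision.
# The wiring is an assumption, not existing code.
from dataclasses import dataclass

@dataclass
class CheckpointPolicy:
    max_checkpoints: int = 10             # kept by the cleanup step
    save_interval: int = 100              # batches between interval saves
    improvement_threshold: float = 0.005  # 0.5% relative improvement

    def should_save(self, batch_num: int, current_loss: float, best_loss: float) -> bool:
        # Improvement save: loss must beat the best by the relative threshold
        improved = current_loss < best_loss * (1.0 - self.improvement_threshold)
        # Interval save: regular backup regardless of performance
        interval = batch_num > 0 and batch_num % self.save_interval == 0
        return improved or interval

# Usage: policy = CheckpointPolicy(); policy.should_save(i, current_loss, best_loss)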
Disk Usage
Per Checkpoint
| Model | Size | Notes |
|---|---|---|
| Transformer (46M params) | ~200MB | Full model + optimizer state |
| CNN | ~50MB | Smaller model |
| DQN | ~100MB | Medium model |
Total Storage
10 checkpoints × ~200MB = ~2GB per model (worst case; CNN and DQN checkpoints are smaller)
3 models × 2GB ≈ 6GB worst-case total
With metadata and backups: ~8GB
Recommendation: Keep 10 checkpoints (2GB per model is manageable)
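As a sanity check, the worst-case figure can be recomputed from the per-checkpoint sizes in the table above; the short sketch below does that with the approximate sizes listed there.
# Back-of-envelope storage estimate using the per-checkpoint sizes above.
# Sizes are approximate; substitute the actual file sizes for your models.
checkpoint_mb = {"transformer": 200, "dqn": 100, "cnn": 50}
max_checkpoints = 10

total_gb = sum(size_mb * max_checkpoints for size_mb in checkpoint_mb.values()) / 1024
print(f"Estimated checkpoint storage: {total_gb:.1f} GB")  # ~3.4 GB with these sizes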
Monitoring
Checkpoint Logs
INFO - Checkpoint saved (improved): transformer_20251031_142530
INFO - Loss: 0.234, Accuracy: 78.5%
INFO - New best! Loss: 0.234, Accuracy: 78.5%
INFO - Checkpoint saved (interval): transformer_20251031_142630
INFO - Loss: 0.245, Accuracy: 77.2%
INFO - Deleted 1 old checkpoints for transformer
Dashboard Metrics
Checkpoints Saved: 15
Best Loss: 0.234
Best Accuracy: 78.5%
Disk Usage: 1.8GB / 2.0GB
Last Checkpoint: 2 minutes ago
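These dashboard numbers can come either from the database-backed metadata or directly from the checkpoint directory. The sketch below takes the filesystem route; the models/checkpoints/<model_name>/*.pt layout is an assumption, so adjust the glob to the actual structure.
# Sketch of collecting dashboard stats from the checkpoint directory.
# Assumes checkpoints are stored as .pt files under models/checkpoints/<model_name>/;
# the real layout (and the database-backed metadata) may differ.
import time
from pathlib import Path

def checkpoint_stats(model_name: str, checkpoint_dir: str = "models/checkpoints") -> dict:
    files = sorted(Path(checkpoint_dir, model_name).glob("*.pt"))
    if not files:
        return {"count": 0, "disk_gb": 0.0, "last_saved_min": None}
    total_bytes = sum(f.stat().st_size for f in files)
    newest_mtime = max(f.stat().st_mtime for f in files)
    return {
        "count": len(files),
        "disk_gb": round(total_bytes / 1024**3, 2),
        "last_saved_min": round((time.time() - newest_mtime) / 60, 1),
    }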
Summary
Current System
- ✅ Automatic checkpoint management
- ✅ Keeps best N checkpoints
- ✅ Database-backed metadata
- ❌ Saves at fixed intervals (not performance-based)
Recommended Enhancement
- ✅ Save when performance improves
- ✅ Also save every N batches as backup
- ✅ Keep best 10 checkpoints
- ✅ Minimal I/O overhead
- ✅ Never miss a good checkpoint
Implementation
Add checkpoint logic to _train_transformer_real() in real_training_adapter.py to save when:
- Loss decreases OR accuracy increases (performance improved)
- Every 100 batches (regular backup)
The cleanup system automatically keeps only the best 10 checkpoints!