gogo2/CHECKPOINT_STRATEGY.md
2025-10-31 03:52:41 +02:00


# Checkpoint Strategy
## Current System
### ✅ What Exists
The project already includes a checkpoint manager in `utils/checkpoint_manager.py` with these features:
1. **Automatic Saving**: Checkpoints are saved together with their metadata
2. **Performance Tracking**: Tracks metrics (loss, accuracy, reward)
3. **Best Checkpoint Selection**: Loads the best-performing checkpoint
4. **Automatic Cleanup**: Keeps only the top N checkpoints
5. **Database Integration**: Metadata stored in database for fast access
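
The sketch below illustrates how a metadata database can answer "which checkpoint is best?" without loading any model files. It uses an in-memory `sqlite3` database as a stand-in. The `checkpoints` table, its columns, and the `best_checkpoint` helper are hypothetical and are not taken from `utils/checkpoint_manager.py`.

```python
import sqlite3
from typing import Optional

# Hypothetical schema for illustration; not the project's real database layout.
SCHEMA = """
CREATE TABLE checkpoints (
    checkpoint_id TEXT PRIMARY KEY,
    model_name    TEXT NOT NULL,
    loss          REAL,
    accuracy      REAL
)
"""
RANKABLE = {"loss": "ASC", "accuracy": "DESC"}  # lower loss / higher accuracy is better


def best_checkpoint(conn: sqlite3.Connection, model_name: str, metric: str) -> Optional[str]:
    """Return the id of model_name's best checkpoint by `metric`, or None if it has none."""
    if metric not in RANKABLE:
        raise ValueError(f"unsupported metric: {metric}")
    # `metric` is whitelisted above, so interpolating it into the SQL is safe.
    row = conn.execute(
        "SELECT checkpoint_id FROM checkpoints WHERE model_name = ? "
        f"ORDER BY {metric} {RANKABLE[metric]} LIMIT 1",
        (model_name,),
    ).fetchone()
    return row[0] if row else None


conn = sqlite3.connect(":memory:")
conn.execute(SCHEMA)
conn.executemany(  # made-up example values
    "INSERT INTO checkpoints VALUES (?, ?, ?, ?)",
    [
        ("transformer_a", "transformer", 0.30, 0.74),
        ("transformer_b", "transformer", 0.25, 0.77),
        ("transformer_c", "transformer", 0.27, 0.79),
    ],
)
print(best_checkpoint(conn, "transformer", "accuracy"))  # transformer_c
print(best_checkpoint(conn, "transformer", "loss"))      # transformer_b
```

With these example rows, the best checkpoint by loss differs from the best by accuracy. This is why the `metric_name` setting matters.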
### How It Works
```python
# Checkpoint Manager Configuration
max_checkpoints = 10 # Keep top 10 checkpoints
metric_name = "accuracy" # Rank by accuracy (or loss, reward)
checkpoint_dir = "models/checkpoints"
```
---
## Checkpoint Saving Logic
### When Checkpoints Are Saved
**Current Behavior**: Checkpoints are saved at **fixed intervals**, not based on performance improvement.
```python
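# Example from DQN Agent (method excerpt)
def save_checkpoint(self, episode_reward: float, force_save: bool = False):
    """Save checkpoint if performance improved or forced"""
    # Note: despite the docstring, the condition below is a fixed interval
    # (every 100 episodes); it never compares against previous performance.
    # `save_checkpoint` inside the body resolves to the module-level helper
    # (presumably from utils.checkpoint_manager), not to this method.
    if self.episode_count % 100 == 0 or force_save:
        save_checkpoint(
            model=self.policy_net,
            model_name=self.model_name,
            model_type="dqn",
            performance_metrics={
                'loss': self.current_loss,
                'reward': episode_reward,
                'accuracy': self.accuracy
            }
        )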
```
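
The hypothetical helper below shows why a fixed interval can miss the best model. It replays made-up per-episode rewards and reports which episodes the `episode_count % interval == 0` rule would save. It assumes `episode_count` equals the 1-based episode number, which the excerpt does not show.

```python
def fixed_interval_saves(rewards, interval=100):
    """Return the (episode, reward) pairs the fixed-interval rule would save.

    Episodes are numbered from 1; an episode is saved when its number is a
    multiple of `interval`, regardless of its reward.
    """
    return [
        (episode, reward)
        for episode, reward in enumerate(rewards, start=1)
        if episode % interval == 0
    ]


# Made-up rewards: the single best episode (reward 9.0) is episode 150.
rewards = [1.0] * 300
rewards[149] = 9.0
saved = fixed_interval_saves(rewards, interval=100)
print(saved)  # [(100, 1.0), (200, 1.0), (300, 1.0)]
best_episode = max(range(1, len(rewards) + 1), key=lambda e: rewards[e - 1])
print(best_episode, any(ep == best_episode for ep, _ in saved))  # 150 False
```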
### Cleanup Logic
After each save, the manager prunes the lower-ranked checkpoints:
```python
import os

def _cleanup_checkpoints(self, model_name: str):
    """
    Keep only the best N checkpoints
    Process:
    1. Load all checkpoint metadata
    2. Sort by metric (accuracy/loss/reward)
    3. Keep top N (default: 10)
    4. Delete the rest
    """
    # Step 1 elided: `checkpoints` is the list of metadata dicts for model_name;
    # `metric_name` and `max_checkpoints` come from the manager's configuration.
    # Sort best first: highest first for accuracy/reward, lowest first for loss
    checkpoints.sort(key=lambda x: x['metrics'][metric_name],
                     reverse=(metric_name != 'loss'))
    # Keep only top N
    checkpoints_to_keep = checkpoints[:max_checkpoints]
    checkpoints_to_delete = checkpoints[max_checkpoints:]
    # Delete the lower-ranked checkpoints
    for checkpoint in checkpoints_to_delete:
        checkpoint_path = checkpoint['path']  # metadata key name is illustrative
        os.remove(checkpoint_path)
```
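
Separating the ranking step from file deletion makes it easy to unit-test. The hypothetical pure function below is not part of `utils/checkpoint_manager.py`. It assumes each metadata entry is a dict with a `metrics` sub-dict, as in the excerpt above. It treats `loss` as lower-is-better and every other metric as higher-is-better.

```python
from typing import Dict, List, Tuple

LOWER_IS_BETTER = {"loss"}  # assumption: all other metrics are higher-is-better


def split_keep_delete(
    checkpoints: List[Dict], metric_name: str, max_checkpoints: int
) -> Tuple[List[Dict], List[Dict]]:
    """Rank checkpoints best-first by metric_name and return (keep, delete)."""
    ranked = sorted(
        checkpoints,
        key=lambda c: c["metrics"][metric_name],
        reverse=metric_name not in LOWER_IS_BETTER,
    )
    return ranked[:max_checkpoints], ranked[max_checkpoints:]


checkpoints = [  # made-up example values
    {"id": "a", "metrics": {"loss": 0.30, "accuracy": 0.74}},
    {"id": "b", "metrics": {"loss": 0.25, "accuracy": 0.77}},
    {"id": "c", "metrics": {"loss": 0.27, "accuracy": 0.79}},
]
keep, delete = split_keep_delete(checkpoints, "loss", max_checkpoints=2)
print([c["id"] for c in keep], [c["id"] for c in delete])  # ['b', 'c'] ['a']
keep, delete = split_keep_delete(checkpoints, "accuracy", max_checkpoints=2)
print([c["id"] for c in keep], [c["id"] for c in delete])  # ['c', 'b'] ['a']
```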
---
## Recommended Strategy
### Option 1: Save Every Batch, Keep Best (Current + Enhancement)
**Pros**:
- Never miss a good checkpoint
- Automatic cleanup keeps disk usage low
- Simple to implement

**Cons**:
- High I/O overhead: a full checkpoint (~200MB for the transformer) is written after every batch
- Slower training, because time is spent on disk writes

**Implementation**:
```python
def train_step(self, batch):
    # Train
    result = self.trainer.train_step(batch)
    # Save checkpoint after EVERY batch
    save_checkpoint(
        model=self.model,
        model_name="transformer",
        model_type="transformer",
        performance_metrics={
            'loss': result['total_loss'],
            'accuracy': result['accuracy']
        }
    )
    # Cleanup automatically keeps only best 10
```
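
Cleanup bounds what stays on disk, but not how much is written. The hypothetical helper below compares total write volume against what is retained. The 500 batches per epoch is a made-up figure. The 200MB is the transformer checkpoint size used elsewhere in this document.

```python
def write_volume_gb(batches: int, checkpoint_mb: float, epochs: int = 1) -> float:
    """GB written when one checkpoint is saved after every batch (1GB = 1000MB)."""
    return epochs * batches * checkpoint_mb / 1000


# 500 batches/epoch is made up; 200MB is this document's transformer checkpoint size.
print(write_volume_gb(batches=500, checkpoint_mb=200))             # 100.0
print(write_volume_gb(batches=500, checkpoint_mb=200, epochs=10))  # 1000.0
```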
**Disk Usage**: ~10 checkpoints × 200MB = 2GB retained after cleanup (manageable)

---
### Option 2: Save Only If Better (Recommended)
**Pros**:
- Minimal I/O overhead
- Only saves improvements
- Faster training

**Cons**:
- Need to track best performance
- Slightly more complex

**Implementation**:
```python
import logging

from utils.checkpoint_manager import save_checkpoint

logger = logging.getLogger(__name__)


class TrainingSession:
    def __init__(self, model, trainer):
        self.model = model
        self.trainer = trainer
        self.best_loss = float('inf')
        self.best_accuracy = 0.0
        self.checkpoints_saved = 0

    def train_step(self, batch):
        # Train
        result = self.trainer.train_step(batch)
        # Check if performance improved
        current_loss = result['total_loss']
        current_accuracy = result['accuracy']
        # Save if better (lower loss OR higher accuracy)
        if current_loss < self.best_loss or current_accuracy > self.best_accuracy:
            logger.info(f"Performance improved! Loss: {current_loss:.4f} (best: {self.best_loss:.4f}), "
                        f"Accuracy: {current_accuracy:.2%} (best: {self.best_accuracy:.2%})")
            save_checkpoint(
                model=self.model,
                model_name="transformer",
                model_type="transformer",
                performance_metrics={
                    'loss': current_loss,
                    'accuracy': current_accuracy
                }
            )
            # Update best metrics
            self.best_loss = min(self.best_loss, current_loss)
            self.best_accuracy = max(self.best_accuracy, current_accuracy)
            self.checkpoints_saved += 1
```
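
You can check the Option 2 save decision in isolation. The hypothetical `ImprovementTracker` below mirrors the `best_loss`/`best_accuracy` bookkeeping above without needing a model or the checkpoint manager. It replays made-up per-batch metrics to show which batches would trigger a save.

```python
class ImprovementTracker:
    """Hypothetical helper: tracks best loss/accuracy and flags improvements."""

    def __init__(self):
        self.best_loss = float("inf")
        self.best_accuracy = 0.0

    def update(self, loss: float, accuracy: float) -> bool:
        """Return True, and record the new bests, if loss or accuracy improved."""
        improved = loss < self.best_loss or accuracy > self.best_accuracy
        if improved:
            self.best_loss = min(self.best_loss, loss)
            self.best_accuracy = max(self.best_accuracy, accuracy)
        return improved


# Made-up per-batch (loss, accuracy) values
history = [(0.50, 0.60), (0.45, 0.58), (0.47, 0.62), (0.48, 0.61), (0.40, 0.65)]
tracker = ImprovementTracker()
saved_batches = []
for batch_idx, (loss, accuracy) in enumerate(history):
    if tracker.update(loss, accuracy):
        saved_batches.append(batch_idx)

# Batch 2 is saved on accuracy alone although its loss is worse than batch 1's;
# batch 3 improves neither metric and is skipped.
print(saved_batches)  # [0, 1, 2, 4]
```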
---
### Option 3: Hybrid Approach (Best of Both)
**Strategy**:
- Save if performance improved (Option 2)
- Also save every N batches as a backup (a sparser version of Option 1)
- Keep best 10 checkpoints

**Implementation**:
```python
def train_step(self, batch, batch_num):
    result = self.trainer.train_step(batch)
    current_loss = result['total_loss']
    current_accuracy = result['accuracy']
    # Condition 1: Performance improved
    performance_improved = (
        current_loss < self.best_loss or
        current_accuracy > self.best_accuracy
    )
    # Condition 2: Regular interval (every 100 batches)
    regular_interval = (batch_num % 100 == 0)
    # Save if either condition is met
    if performance_improved or regular_interval:
        reason = "improved" if performance_improved else "interval"
        logger.info(f"Saving checkpoint ({reason}): loss={current_loss:.4f}, acc={current_accuracy:.2%}")
        save_checkpoint(
            model=self.model,
            model_name="transformer",
            model_type="transformer",
            performance_metrics={
                'loss': current_loss,
                'accuracy': current_accuracy
            },
            training_metadata={
                'batch_num': batch_num,
                'reason': reason,
                'epoch': self.current_epoch
            }
        )
        # Update best metrics
        if performance_improved:
            self.best_loss = min(self.best_loss, current_loss)
            self.best_accuracy = max(self.best_accuracy, current_accuracy)
```
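
The hybrid rule reduces to a small decision function. The hypothetical `save_reason` below returns `"improved"`, `"interval"`, or `None`. Unlike the excerpt above, it skips the interval save at batch 0, which matches the `i > 0` guard in the ANNOTATE code.

```python
from typing import Optional


def save_reason(batch_num: int, improved: bool, interval: int = 100) -> Optional[str]:
    """Return why a checkpoint should be saved, or None to skip saving."""
    if improved:
        return "improved"  # improvement takes priority over the interval
    if batch_num > 0 and batch_num % interval == 0:
        return "interval"
    return None


print([save_reason(n, improved=False) for n in (0, 99, 100, 200)])
# [None, None, 'interval', 'interval']
print(save_reason(100, improved=True))  # improved
```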
---
## Implementation for ANNOTATE Training
### Current Code Location
In `ANNOTATE/core/real_training_adapter.py`, the training loop is:
```python
def _train_transformer_real(self, session, training_data):
    # ... setup ...
    for epoch in range(session.total_epochs):
        for i, batch in enumerate(converted_batches):
            result = trainer.train_step(batch)
            # ← ADD CHECKPOINT LOGIC HERE
```
### Recommended Addition
```python
def _train_transformer_real(self, session, training_data):
    from utils.checkpoint_manager import save_checkpoint
    # ... existing setup (creates `trainer` and `converted_batches`) ...
    # Initialize best metrics
    best_loss = float('inf')
    best_accuracy = 0.0
    checkpoints_saved = 0
    for epoch in range(session.total_epochs):
        for i, batch in enumerate(converted_batches):
            result = trainer.train_step(batch)
            if result is None:
                continue
            current_loss = result.get('total_loss', float('inf'))
            current_accuracy = result.get('accuracy', 0.0)
            # Check if performance improved
            performance_improved = (
                current_loss < best_loss or
                current_accuracy > best_accuracy
            )
            # Save if improved, or every 100 batches within the epoch (skipping batch 0)
            should_save = performance_improved or (i % 100 == 0 and i > 0)
            if not should_save:
                continue
            reason = 'improved' if performance_improved else 'interval'
            try:
                checkpoint_metadata = save_checkpoint(
                    model=self.orchestrator.primary_transformer,
                    model_name="transformer",
                    model_type="transformer",
                    performance_metrics={
                        'loss': current_loss,
                        'accuracy': current_accuracy,
                        'action_loss': result.get('action_loss', 0.0),
                        'price_loss': result.get('price_loss', 0.0)
                    },
                    training_metadata={
                        'epoch': epoch + 1,
                        'batch': i + 1,
                        'total_batches': len(converted_batches),
                        'training_session': session.training_id,
                        'reason': reason
                    }
                )
                if checkpoint_metadata:
                    checkpoints_saved += 1
                    logger.info(f"Checkpoint saved ({reason}): {checkpoint_metadata.checkpoint_id}")
                    logger.info(f"Loss: {current_loss:.4f}, Accuracy: {current_accuracy:.2%}")
                    # Update best metrics only once the checkpoint was actually saved
                    if performance_improved:
                        best_loss = min(best_loss, current_loss)
                        best_accuracy = max(best_accuracy, current_accuracy)
                        logger.info(f"New best! Loss: {best_loss:.4f}, Accuracy: {best_accuracy:.2%}")
            except Exception as e:
                logger.error(f"Error saving checkpoint: {e}")
    logger.info(f"Training complete: {checkpoints_saved} checkpoints saved")
    logger.info(f"Best loss: {best_loss:.4f}, Best accuracy: {best_accuracy:.2%}")
```
---
## Configuration
### Checkpoint Settings
```python
# In orchestrator initialization
checkpoint_manager = get_checkpoint_manager(
checkpoint_dir="models/checkpoints",
max_checkpoints=10, # Keep top 10 checkpoints
metric_name="accuracy" # Rank by accuracy (or "loss")
)
```
### Tuning Parameters
| Parameter | Conservative | Balanced | Aggressive |
|-----------|-------------|----------|------------|
| `max_checkpoints` | 20 | 10 | 5 |
| `save_interval` | 50 batches | 100 batches | 200 batches |
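| `improvement_threshold` | 0.1% | 0.5% | 1.0% |

`save_interval` and `improvement_threshold` are not among the `get_checkpoint_manager` arguments shown above. In the code examples in this document, the interval is hard-coded to 100 batches and any improvement triggers a save (no threshold).

- **Conservative**: Save more often, keep more checkpoints (safer, more disk)
- **Balanced**: Default settings (recommended)
- **Aggressive**: Save less often, keep fewer checkpoints (faster, less disk)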
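
The `improvement_threshold` row implies a relative-improvement test, which none of the earlier examples implement. The sketch below assumes the threshold is a fraction of the current best value, so 0.5% is `0.005`, and that metrics are non-negative. The function name is hypothetical.

```python
def is_significant_improvement(current: float, best: float, threshold: float,
                               higher_is_better: bool) -> bool:
    """Return True if `current` beats `best` by more than `threshold` (relative).

    threshold is a fraction, e.g. 0.005 for the "Balanced" 0.5% setting.
    With best = inf (no loss recorded yet), any finite loss qualifies.
    """
    if higher_is_better:
        return current > best * (1 + threshold)
    return current < best * (1 - threshold)


# Loss: best 0.2500 with a 0.5% threshold -> must drop below 0.24875
print(is_significant_improvement(0.2490, 0.2500, 0.005, higher_is_better=False))  # False
print(is_significant_improvement(0.2480, 0.2500, 0.005, higher_is_better=False))  # True
# Accuracy: best 0.780 with a 0.5% threshold -> must exceed 0.7839
print(is_significant_improvement(0.783, 0.780, 0.005, higher_is_better=True))     # False
print(is_significant_improvement(0.785, 0.780, 0.005, higher_is_better=True))     # True
```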
---
## Disk Usage
### Per Checkpoint
| Model | Size | Notes |
|-------|------|-------|
| Transformer (46M params) | ~200MB | Full model + optimizer state |
| CNN | ~50MB | Smaller model |
| DQN | ~100MB | Medium model |
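
Checkpoint size can be roughly estimated from the parameter count. The sketch below assumes fp32 parameters (4 bytes each). When optimizer state is included, it assumes an Adam-style optimizer that keeps two extra fp32 tensors per parameter. This document does not state the actual dtype or optimizer, and the estimate ignores serialization overhead and non-parameter buffers. By this estimate, ~200MB matches the transformer's fp32 weights alone. Adding Adam-style optimizer state would put a checkpoint closer to ~550MB.

```python
def checkpoint_size_mb(num_params: int, bytes_per_param: int = 4,
                       optimizer_copies: int = 0) -> float:
    """Approximate checkpoint size in MB (1MB = 1e6 bytes).

    optimizer_copies: extra per-parameter tensors stored by the optimizer
    (e.g. 2 for Adam's first and second moment estimates).
    """
    total_bytes = num_params * bytes_per_param * (1 + optimizer_copies)
    return total_bytes / 1e6


print(checkpoint_size_mb(46_000_000))                      # 184.0
print(checkpoint_size_mb(46_000_000, optimizer_copies=2))  # 552.0
```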
### Total Storage
```
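Transformer: 10 checkpoints × 200MB = 2.0GB
CNN:         10 checkpoints ×  50MB = 0.5GB
DQN:         10 checkpoints × 100MB = 1.0GB
Total for 3 models: 3.5GB
Upper bound if every model used 200MB checkpoints: 3 × 2GB = 6GB
Upper bound including metadata and backups: ~8GB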
```
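
The same per-checkpoint sizes show how `max_checkpoints` from the tuning table changes total storage. The sketch below reuses the sizes from the table above, uses 1GB = 1000MB as this section does, and excludes metadata and backups. The helper name is hypothetical.

```python
CHECKPOINT_MB = {"transformer": 200, "cnn": 50, "dqn": 100}  # sizes from the table above


def total_retained_gb(sizes_mb: dict, max_checkpoints: int) -> float:
    """Total disk used across models when cleanup keeps max_checkpoints per model."""
    return max_checkpoints * sum(sizes_mb.values()) / 1000


for profile, keep in [("Conservative", 20), ("Balanced", 10), ("Aggressive", 5)]:
    print(profile, total_retained_gb(CHECKPOINT_MB, keep))
# Conservative 7.0
# Balanced 3.5
# Aggressive 1.75
```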
**Recommendation**: Keep 10 checkpoints per model (at most 2GB per model, for the transformer)

---
## Monitoring
### Checkpoint Logs
```
INFO - Checkpoint saved (improved): transformer_20251031_142530
INFO - Loss: 0.234, Accuracy: 78.5%
INFO - New best! Loss: 0.234, Accuracy: 78.5%
INFO - Checkpoint saved (interval): transformer_20251031_142630
INFO - Loss: 0.245, Accuracy: 77.2%
INFO - Deleted 1 old checkpoints for transformer
```
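
You could tally the dashboard's "Checkpoints Saved" count from these log lines. In the minimal sketch below, the regular expression assumes the exact `Checkpoint saved (<reason>): <id>` wording shown in the sample above. `count_saves` is a hypothetical helper.

```python
import re
from collections import Counter

SAVE_LINE = re.compile(r"Checkpoint saved \((?P<reason>\w+)\): (?P<checkpoint_id>\S+)")


def count_saves(log_lines):
    """Count checkpoint saves per reason ('improved', 'interval', ...)."""
    counts = Counter()
    for line in log_lines:
        match = SAVE_LINE.search(line)
        if match:
            counts[match.group("reason")] += 1
    return counts


sample = [
    "INFO - Checkpoint saved (improved): transformer_20251031_142530",
    "INFO - Checkpoint saved (interval): transformer_20251031_142630",
    "INFO - Deleted 1 old checkpoints for transformer",
]
print(dict(count_saves(sample)))  # {'improved': 1, 'interval': 1}
```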
### Dashboard Metrics
```
Checkpoints Saved: 15
Best Loss: 0.234
Best Accuracy: 78.5%
Disk Usage: 1.8GB / 2.0GB
Last Checkpoint: 2 minutes ago
```
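
Two of the dashboard figures, disk usage and time since the last checkpoint, can be derived from the checkpoint directory itself. The minimal `pathlib` sketch below assumes checkpoints are stored as `.pt` files directly under `models/checkpoints`. Both the file extension and the `checkpoint_dir_stats` name are assumptions, not details of the checkpoint manager.

```python
import time
from pathlib import Path
from typing import Optional, Tuple


def checkpoint_dir_stats(checkpoint_dir: str, pattern: str = "*.pt") -> Tuple[float, Optional[float]]:
    """Return (disk usage in GB, seconds since the newest checkpoint, or None if none)."""
    root = Path(checkpoint_dir)
    if not root.is_dir():
        return 0.0, None
    files = [p for p in root.glob(pattern) if p.is_file()]
    if not files:
        return 0.0, None
    total_gb = sum(p.stat().st_size for p in files) / 1e9  # 1GB = 1e9 bytes
    newest_mtime = max(p.stat().st_mtime for p in files)
    return total_gb, time.time() - newest_mtime


usage_gb, age_s = checkpoint_dir_stats("models/checkpoints")
print(f"Disk Usage: {usage_gb:.1f}GB")
if age_s is not None:
    print(f"Last Checkpoint: {age_s / 60:.0f} minutes ago")
```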
---
## Summary
### Current System
- ✅ Automatic checkpoint management
- ✅ Keeps best N checkpoints
- ✅ Database-backed metadata
- ❌ Saves at fixed intervals (not performance-based)
### Recommended Enhancement
- ✅ Save when performance improves
- ✅ Also save every N batches as backup
- ✅ Keep best 10 checkpoints
- ✅ Far less I/O than saving every batch
- ✅ Never miss a good checkpoint
### Implementation
Add checkpoint logic to `_train_transformer_real()` in `ANNOTATE/core/real_training_adapter.py` to save when:
1. Loss reaches a new minimum or accuracy reaches a new maximum (performance improved)
2. Every 100 batches within an epoch (regular backup)

The cleanup system then automatically keeps only the best 10 checkpoints.