# Checkpoint Strategy

## Current System

### ✅ What Exists

A sophisticated checkpoint management system already exists in `utils/checkpoint_manager.py`:

1. **Automatic Saving**: Checkpoints are saved with metadata
2. **Performance Tracking**: Tracks metrics (loss, accuracy, reward)
3. **Best Checkpoint Selection**: Loads the best-performing checkpoint (see the loading sketch after the configuration example below)
4. **Automatic Cleanup**: Keeps only the top N checkpoints
5. **Database Integration**: Metadata is stored in a database for fast access

### How It Works

```python
# Checkpoint Manager Configuration
max_checkpoints = 10          # Keep the top 10 checkpoints
metric_name = "accuracy"      # Rank by accuracy (or loss, reward)
checkpoint_dir = "models/checkpoints"
```
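
The loading side works the same way in reverse: at startup, the manager can hand back the top-ranked checkpoint. Below is a minimal sketch of that idea; the `load_best_checkpoint(model_name)` method and the shape of its return value are assumptions for illustration, not the manager's documented API (only `get_checkpoint_manager` and the `metric_name` ranking appear in this document).

```python
# Sketch: restore the best checkpoint at startup.
# Assumption: the manager exposes a load_best_checkpoint(model_name) method
# returning (state_dict, metadata) for the top-ranked checkpoint, or None.
from utils.checkpoint_manager import get_checkpoint_manager

def restore_best(model, model_name: str = "transformer"):
    manager = get_checkpoint_manager(
        checkpoint_dir="models/checkpoints",
        max_checkpoints=10,
        metric_name="accuracy",
    )
    result = manager.load_best_checkpoint(model_name)  # hypothetical method name
    if result is None:
        return model  # no checkpoint yet; keep current weights

    state_dict, metadata = result
    model.load_state_dict(state_dict)  # PyTorch model assumed
    return model
```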

---

## Checkpoint Saving Logic

### When Checkpoints Are Saved

**Current Behavior**: Checkpoints are saved at **fixed intervals**, not based on performance improvement.

```python
# Example from the DQN agent
def save_checkpoint(self, episode_reward: float, force_save: bool = False):
    """Save a checkpoint at a fixed episode interval, or when forced"""

    # Save every N episodes (e.g., every 100 episodes)
    if self.episode_count % 100 == 0 or force_save:
        save_checkpoint(
            model=self.policy_net,
            model_name=self.model_name,
            model_type="dqn",
            performance_metrics={
                'loss': self.current_loss,
                'reward': episode_reward,
                'accuracy': self.accuracy
            }
        )
```

### Cleanup Logic

After saving, the system automatically cleans up:

```python
def _cleanup_checkpoints(self, model_name: str):
    """
    Keep only the best N checkpoints.

    Process:
    1. Load all checkpoint metadata
    2. Sort by metric (accuracy/loss/reward)
    3. Keep the top N (default: 10)
    4. Delete the rest
    """

    # Sort by metric (highest first for accuracy/reward; the direction is
    # flipped for loss -- shown here for the accuracy case)
    checkpoints.sort(key=lambda x: x['metrics'][metric_name], reverse=True)

    # Keep only the top N
    checkpoints_to_keep = checkpoints[:max_checkpoints]
    checkpoints_to_delete = checkpoints[max_checkpoints:]

    # Delete old checkpoints
    for checkpoint in checkpoints_to_delete:
        os.remove(checkpoint['path'])  # checkpoint file path (simplified excerpt)
```
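
The excerpt above always sorts with `reverse=True`, while the comment notes the direction depends on the metric. A direction-aware version of the ranking step could look like the sketch below; the `LOWER_IS_BETTER` set and the dict layout of each checkpoint entry are illustrative assumptions, not the manager's actual implementation.

```python
# Sketch: rank checkpoints with a sort direction that depends on the metric.
# Assumption: each entry is a dict with a 'metrics' mapping, as in the excerpt above.
LOWER_IS_BETTER = {"loss", "total_loss", "action_loss", "price_loss"}

def rank_checkpoints(checkpoints, metric_name):
    ascending = metric_name in LOWER_IS_BETTER
    return sorted(
        checkpoints,
        key=lambda c: c["metrics"][metric_name],
        reverse=not ascending,  # highest first for accuracy/reward, lowest first for loss
    )

# Usage: keep the top N, delete the rest
# ranked = rank_checkpoints(checkpoints, metric_name)
# to_keep, to_delete = ranked[:max_checkpoints], ranked[max_checkpoints:]
```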

---

## Recommended Strategy

### Option 1: Save Every Batch, Keep Best (Current + Enhancement)

**Pros**:
- Never miss a good checkpoint
- Automatic cleanup keeps disk usage low
- Simple to implement

**Cons**:
- High I/O overhead (saving every batch)
- Slower training (disk writes)

**Implementation**:
```python
def train_step(self, batch):
    # Train
    result = trainer.train_step(batch)

    # Save checkpoint after EVERY batch
    save_checkpoint(
        model=self.model,
        model_name="transformer",
        model_type="transformer",
        performance_metrics={
            'loss': result['total_loss'],
            'accuracy': result['accuracy']
        }
    )
    # Cleanup automatically keeps only the best 10
```

**Disk Usage**: ~10 checkpoints × 200MB = 2GB (manageable)

---

### Option 2: Save Only If Better (Recommended)

**Pros**:
- Minimal I/O overhead
- Only saves improvements
- Faster training

**Cons**:
- Need to track best performance
- Slightly more complex

**Implementation**:
```python
class TrainingSession:
    def __init__(self):
        self.best_loss = float('inf')
        self.best_accuracy = 0.0
        self.checkpoints_saved = 0

    def train_step(self, batch):
        # Train
        result = trainer.train_step(batch)

        # Check if performance improved
        current_loss = result['total_loss']
        current_accuracy = result['accuracy']

        # Save if better (lower loss OR higher accuracy)
        if current_loss < self.best_loss or current_accuracy > self.best_accuracy:
            logger.info(f"Performance improved! Loss: {current_loss:.4f} (best: {self.best_loss:.4f}), "
                        f"Accuracy: {current_accuracy:.2%} (best: {self.best_accuracy:.2%})")

            save_checkpoint(
                model=self.model,
                model_name="transformer",
                model_type="transformer",
                performance_metrics={
                    'loss': current_loss,
                    'accuracy': current_accuracy
                }
            )

            # Update best metrics
            self.best_loss = min(self.best_loss, current_loss)
            self.best_accuracy = max(self.best_accuracy, current_accuracy)
            self.checkpoints_saved += 1
```

---

### Option 3: Hybrid Approach (Best of Both)

**Strategy**:
- Save if performance improved (Option 2)
- Also save every N batches as a backup (Option 1)
- Keep the best 10 checkpoints

**Implementation**:
```python
def train_step(self, batch, batch_num):
    result = trainer.train_step(batch)

    current_loss = result['total_loss']
    current_accuracy = result['accuracy']

    # Condition 1: performance improved
    performance_improved = (
        current_loss < self.best_loss or
        current_accuracy > self.best_accuracy
    )

    # Condition 2: regular interval (every 100 batches)
    regular_interval = (batch_num % 100 == 0)

    # Save if either condition is met
    if performance_improved or regular_interval:
        reason = "improved" if performance_improved else "interval"
        logger.info(f"Saving checkpoint ({reason}): loss={current_loss:.4f}, acc={current_accuracy:.2%}")

        save_checkpoint(
            model=self.model,
            model_name="transformer",
            model_type="transformer",
            performance_metrics={
                'loss': current_loss,
                'accuracy': current_accuracy
            },
            training_metadata={
                'batch_num': batch_num,
                'reason': reason,
                'epoch': self.current_epoch
            }
        )

        # Update best metrics
        if performance_improved:
            self.best_loss = min(self.best_loss, current_loss)
            self.best_accuracy = max(self.best_accuracy, current_accuracy)
```

---

## Implementation for ANNOTATE Training

### Current Code Location

In `ANNOTATE/core/real_training_adapter.py`, the training loop is:

```python
def _train_transformer_real(self, session, training_data):
    # ... setup ...

    for epoch in range(session.total_epochs):
        for i, batch in enumerate(converted_batches):
            result = trainer.train_step(batch)

            # ← ADD CHECKPOINT LOGIC HERE
```

### Recommended Addition

```python
def _train_transformer_real(self, session, training_data):
    # Initialize best metrics
    best_loss = float('inf')
    best_accuracy = 0.0
    checkpoints_saved = 0

    for epoch in range(session.total_epochs):
        for i, batch in enumerate(converted_batches):
            result = trainer.train_step(batch)

            if result is not None:
                current_loss = result.get('total_loss', float('inf'))
                current_accuracy = result.get('accuracy', 0.0)

                # Check if performance improved
                performance_improved = (
                    current_loss < best_loss or
                    current_accuracy > best_accuracy
                )

                # Save every 100 batches OR if improved
                should_save = performance_improved or (i % 100 == 0 and i > 0)

                if should_save:
                    try:
                        # Save checkpoint
                        from utils.checkpoint_manager import save_checkpoint

                        checkpoint_metadata = save_checkpoint(
                            model=self.orchestrator.primary_transformer,
                            model_name="transformer",
                            model_type="transformer",
                            performance_metrics={
                                'loss': current_loss,
                                'accuracy': current_accuracy,
                                'action_loss': result.get('action_loss', 0.0),
                                'price_loss': result.get('price_loss', 0.0)
                            },
                            training_metadata={
                                'epoch': epoch + 1,
                                'batch': i + 1,
                                'total_batches': len(converted_batches),
                                'training_session': session.training_id,
                                'reason': 'improved' if performance_improved else 'interval'
                            }
                        )

                        if checkpoint_metadata:
                            checkpoints_saved += 1
                            reason = "improved" if performance_improved else "interval"
                            logger.info(f"  Checkpoint saved ({reason}): {checkpoint_metadata.checkpoint_id}")
                            logger.info(f"  Loss: {current_loss:.4f}, Accuracy: {current_accuracy:.2%}")

                        # Update best metrics
                        if performance_improved:
                            best_loss = min(best_loss, current_loss)
                            best_accuracy = max(best_accuracy, current_accuracy)
                            logger.info(f"  New best! Loss: {best_loss:.4f}, Accuracy: {best_accuracy:.2%}")

                    except Exception as e:
                        logger.error(f"  Error saving checkpoint: {e}")

    logger.info(f"Training complete: {checkpoints_saved} checkpoints saved")
    logger.info(f"Best loss: {best_loss:.4f}, Best accuracy: {best_accuracy:.2%}")
```

---

## Configuration

### Checkpoint Settings

```python
# In orchestrator initialization
checkpoint_manager = get_checkpoint_manager(
    checkpoint_dir="models/checkpoints",
    max_checkpoints=10,       # Keep the top 10 checkpoints
    metric_name="accuracy"    # Rank by accuracy (or "loss")
)
```

### Tuning Parameters

| Parameter | Conservative | Balanced | Aggressive |
|-----------|--------------|----------|------------|
| `max_checkpoints` | 20 | 10 | 5 |
| `save_interval` | 50 batches | 100 batches | 200 batches |
| `improvement_threshold` | 0.1% | 0.5% | 1.0% |

- **Conservative**: Save more often, keep more checkpoints (safer, more disk)
- **Balanced**: Default settings (recommended)
- **Aggressive**: Save less often, keep fewer checkpoints (faster, less disk)
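
The `improvement_threshold` row implies that a result should beat the previous best by a margin before it counts as an improvement; the options above save on any improvement, however small. Here is a sketch of how that threshold could be folded into the "performance improved" check; the helper below is illustrative and not part of the existing checkpoint manager.

```python
# Sketch: require a relative margin before counting a result as "improved",
# so tiny fluctuations don't trigger a checkpoint save.
# improvement_threshold is the fraction from the table, e.g. 0.005 for 0.5%.
def is_improvement(current_loss, best_loss, current_accuracy, best_accuracy,
                   improvement_threshold=0.005):
    loss_improved = current_loss < best_loss * (1.0 - improvement_threshold)
    accuracy_improved = current_accuracy > best_accuracy * (1.0 + improvement_threshold)
    return loss_improved or accuracy_improved
```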

---

## Disk Usage

### Per Checkpoint

| Model | Size | Notes |
|-------|------|-------|
| Transformer (46M params) | ~200MB | Full model + optimizer state |
| CNN | ~50MB | Smaller model |
| DQN | ~100MB | Medium model |

### Total Storage

```
10 checkpoints × 200MB = 2GB per model
3 models × 2GB = 6GB total

With metadata and backups: ~8GB
```

**Recommendation**: Keep 10 checkpoints (2GB per model is manageable)

---

## Monitoring

### Checkpoint Logs

```
INFO - Checkpoint saved (improved): transformer_20251031_142530
INFO - Loss: 0.234, Accuracy: 78.5%
INFO - New best! Loss: 0.234, Accuracy: 78.5%

INFO - Checkpoint saved (interval): transformer_20251031_142630
INFO - Loss: 0.245, Accuracy: 77.2%

INFO - Deleted 1 old checkpoints for transformer
```

### Dashboard Metrics

```
Checkpoints Saved: 15
Best Loss: 0.234
Best Accuracy: 78.5%
Disk Usage: 1.8GB / 2.0GB
Last Checkpoint: 2 minutes ago
```
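
Most of these figures can be derived directly from the checkpoint directory. A rough sketch follows; the directory layout and reporting format are assumptions, and best loss/accuracy would come from the training loop or the stored metadata rather than from disk.

```python
# Sketch: compute simple dashboard figures from the checkpoint directory.
# Assumption: checkpoint files live under models/checkpoints/ (possibly in
# per-model subdirectories).
import time
from pathlib import Path

def checkpoint_dashboard(checkpoint_dir="models/checkpoints"):
    files = [p for p in Path(checkpoint_dir).glob("**/*") if p.is_file()]
    if not files:
        return {"checkpoints": 0, "disk_gb": 0.0, "last_checkpoint_age_s": None}

    total_bytes = sum(p.stat().st_size for p in files)
    newest_mtime = max(p.stat().st_mtime for p in files)
    return {
        "checkpoints": len(files),
        "disk_gb": round(total_bytes / 1e9, 2),
        "last_checkpoint_age_s": int(time.time() - newest_mtime),
    }
```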

---

## Summary

### Current System
- ✅ Automatic checkpoint management
- ✅ Keeps the best N checkpoints
- ✅ Database-backed metadata
- ❌ Saves at fixed intervals (not performance-based)

### Recommended Enhancement
- ✅ Save when performance improves
- ✅ Also save every N batches as a backup
- ✅ Keep the best 10 checkpoints
- ✅ Minimal I/O overhead
- ✅ Never miss a good checkpoint

### Implementation
Add checkpoint logic to `_train_transformer_real()` in `real_training_adapter.py` so that a checkpoint is saved when:
1. Loss decreases OR accuracy increases (performance improved)
2. A regular 100-batch interval is reached (backup)

The cleanup system then automatically keeps only the best 10 checkpoints.