checkpoint manager
NN/models/saved/checkpoint_metadata.json (new file, 126 lines)
@@ -0,0 +1,126 @@
{
  "example_cnn": [
    {
      "checkpoint_id": "example_cnn_20250624_213913",
      "model_name": "example_cnn",
      "model_type": "cnn",
      "file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
      "created_at": "2025-06-24T21:39:13.559926",
      "file_size_mb": 0.0797882080078125,
      "performance_score": 65.67219525381417,
      "accuracy": 0.28019601724789606,
      "loss": 1.9252885885630378,
      "val_accuracy": 0.21531048803825983,
      "val_loss": 1.953166686238386,
      "reward": null,
      "pnl": null,
      "epoch": 1,
      "training_time_hours": 0.1,
      "total_parameters": 20163,
      "wandb_run_id": null,
      "wandb_artifact_name": null
    },
    {
      "checkpoint_id": "example_cnn_20250624_213913",
      "model_name": "example_cnn",
      "model_type": "cnn",
      "file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
      "created_at": "2025-06-24T21:39:13.563368",
      "file_size_mb": 0.0797882080078125,
      "performance_score": 85.85617724870231,
      "accuracy": 0.3797766367576808,
      "loss": 1.738881079808816,
      "val_accuracy": 0.31375868989071576,
      "val_loss": 1.758474336328537,
      "reward": null,
      "pnl": null,
      "epoch": 2,
      "training_time_hours": 0.2,
      "total_parameters": 20163,
      "wandb_run_id": null,
      "wandb_artifact_name": null
    },
    {
      "checkpoint_id": "example_cnn_20250624_213913",
      "model_name": "example_cnn",
      "model_type": "cnn",
      "file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
      "created_at": "2025-06-24T21:39:13.566494",
      "file_size_mb": 0.0797882080078125,
      "performance_score": 96.86696983784515,
      "accuracy": 0.41565501055141396,
      "loss": 1.731468873500252,
      "val_accuracy": 0.38848400580514414,
      "val_loss": 1.8154629243104177,
      "reward": null,
      "pnl": null,
      "epoch": 3,
      "training_time_hours": 0.30000000000000004,
      "total_parameters": 20163,
      "wandb_run_id": null,
      "wandb_artifact_name": null
    },
    {
      "checkpoint_id": "example_cnn_20250624_213913",
      "model_name": "example_cnn",
      "model_type": "cnn",
      "file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
      "created_at": "2025-06-24T21:39:13.569547",
      "file_size_mb": 0.0797882080078125,
      "performance_score": 106.29887197896815,
      "accuracy": 0.4639872237832544,
      "loss": 1.4731813440281318,
      "val_accuracy": 0.4291565645756503,
      "val_loss": 1.5423255128941882,
      "reward": null,
      "pnl": null,
      "epoch": 4,
      "training_time_hours": 0.4,
      "total_parameters": 20163,
      "wandb_run_id": null,
      "wandb_artifact_name": null
    },
    {
      "checkpoint_id": "example_cnn_20250624_213913",
      "model_name": "example_cnn",
      "model_type": "cnn",
      "file_path": "NN\\models\\saved\\example_cnn\\example_cnn_20250624_213913.pt",
      "created_at": "2025-06-24T21:39:13.575375",
      "file_size_mb": 0.0797882080078125,
      "performance_score": 115.87168812846218,
      "accuracy": 0.5256293272461906,
      "loss": 1.3264778472364203,
      "val_accuracy": 0.46011511860837684,
      "val_loss": 1.3762786097581432,
      "reward": null,
      "pnl": null,
      "epoch": 5,
      "training_time_hours": 0.5,
      "total_parameters": 20163,
      "wandb_run_id": null,
      "wandb_artifact_name": null
    }
  ],
  "example_manual": [
    {
      "checkpoint_id": "example_manual_20250624_213913",
      "model_name": "example_manual",
      "model_type": "cnn",
      "file_path": "NN\\models\\saved\\example_manual\\example_manual_20250624_213913.pt",
      "created_at": "2025-06-24T21:39:13.578488",
      "file_size_mb": 0.0018634796142578125,
      "performance_score": 186.07000000000002,
      "accuracy": 0.85,
      "loss": 0.45,
      "val_accuracy": 0.82,
      "val_loss": 0.48,
      "reward": null,
      "pnl": null,
      "epoch": 25,
      "training_time_hours": 2.5,
      "total_parameters": 33,
      "wandb_run_id": null,
      "wandb_artifact_name": null
    }
  ]
}
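
The file above is plain JSON keyed by model name. A minimal sketch (standard library only) of reading it and picking the highest-scoring checkpoint per model by performance_score, the same ordering CheckpointManager.load_best_checkpoint uses later in this commit:

# Illustrative only: inspect checkpoint_metadata.json and report the
# best checkpoint recorded for each model.
import json
from pathlib import Path

data = json.loads(Path("NN/models/saved/checkpoint_metadata.json").read_text())
for model_name, checkpoints in data.items():
    best = max(checkpoints, key=lambda cp: cp["performance_score"])
    print(f"{model_name}: {best['checkpoint_id']} (score {best['performance_score']:.2f})")
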
_dev/notes.md (new file, 6 lines)
@@ -0,0 +1,6 @@
How do we manage our training W&B checkpoints? We need to clean up old checkpoints: for every model we keep at most 5 checkpoints and rotate them. By default we always load the best, and when we save a new one during training we discard the 6th, ordered by performance (a sketch of this rotation follows below).

Add integration of the checkpoint manager to all training pipelines.

We stopped showing executed trades on the chart; let's add them back.
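
A minimal sketch of that rotation policy (illustrative only; MAX_CHECKPOINTS and rotate are placeholder names, and the behaviour actually shipped in this commit lives in CheckpointManager._rotate_checkpoints in utils/checkpoint_manager.py):

# Keep at most MAX_CHECKPOINTS per model; when a new checkpoint arrives,
# drop whichever one has the lowest performance score.
MAX_CHECKPOINTS = 5

def rotate(checkpoints: list, new_checkpoint: dict) -> list:
    checkpoints = checkpoints + [new_checkpoint]
    # Best first; everything past the cap is discarded.
    checkpoints.sort(key=lambda cp: cp["performance_score"], reverse=True)
    return checkpoints[:MAX_CHECKPOINTS]
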
cleanup_checkpoints.py (new file, 186 lines)
@@ -0,0 +1,186 @@
#!/usr/bin/env python3
"""
Checkpoint Cleanup and Migration Script

This script helps clean up existing checkpoints and migrate to the new
checkpoint management system with W&B integration.
"""

import os
import logging
import shutil
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any
import torch

from utils.checkpoint_manager import get_checkpoint_manager, CheckpointMetadata

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class CheckpointCleanup:
    def __init__(self):
        self.saved_models_dir = Path("NN/models/saved")
        self.checkpoint_manager = get_checkpoint_manager()

    def analyze_existing_checkpoints(self) -> Dict[str, Any]:
        logger.info("Analyzing existing checkpoint files...")

        analysis = {
            'total_files': 0,
            'total_size_mb': 0.0,
            'model_types': {},
            'file_patterns': {},
            'potential_duplicates': []
        }

        if not self.saved_models_dir.exists():
            logger.warning(f"Saved models directory not found: {self.saved_models_dir}")
            return analysis

        for pt_file in self.saved_models_dir.rglob("*.pt"):
            try:
                file_size_mb = pt_file.stat().st_size / (1024 * 1024)
                analysis['total_files'] += 1
                analysis['total_size_mb'] += file_size_mb

                filename = pt_file.name

                if 'cnn' in filename.lower():
                    model_type = 'cnn'
                elif 'dqn' in filename.lower() or 'rl' in filename.lower():
                    model_type = 'rl'
                elif 'agent' in filename.lower():
                    model_type = 'rl'
                else:
                    model_type = 'unknown'

                if model_type not in analysis['model_types']:
                    analysis['model_types'][model_type] = {'count': 0, 'size_mb': 0.0}

                analysis['model_types'][model_type]['count'] += 1
                analysis['model_types'][model_type]['size_mb'] += file_size_mb

                base_name = filename.split('_')[0] if '_' in filename else filename.replace('.pt', '')
                if base_name not in analysis['file_patterns']:
                    analysis['file_patterns'][base_name] = []

                analysis['file_patterns'][base_name].append({
                    'path': str(pt_file),
                    'size_mb': file_size_mb,
                    'modified': datetime.fromtimestamp(pt_file.stat().st_mtime)
                })

            except Exception as e:
                logger.error(f"Error analyzing {pt_file}: {e}")

        for base_name, files in analysis['file_patterns'].items():
            if len(files) > 5:  # More than 5 files with the same base name
                analysis['potential_duplicates'].append({
                    'base_name': base_name,
                    'count': len(files),
                    'total_size_mb': sum(f['size_mb'] for f in files),
                    'files': files
                })

        logger.info("Analysis complete:")
        logger.info(f"  Total files: {analysis['total_files']}")
        logger.info(f"  Total size: {analysis['total_size_mb']:.2f} MB")
        logger.info(f"  Model types: {analysis['model_types']}")
        logger.info(f"  Potential duplicates: {len(analysis['potential_duplicates'])}")

        return analysis

    def cleanup_duplicates(self, dry_run: bool = True) -> Dict[str, Any]:
        logger.info(f"Starting duplicate cleanup (dry_run={dry_run})...")

        cleanup_results = {
            'removed': 0,
            'kept': 0,
            'space_saved_mb': 0.0,
            'details': []
        }

        analysis = self.analyze_existing_checkpoints()

        for duplicate_group in analysis['potential_duplicates']:
            base_name = duplicate_group['base_name']
            files = duplicate_group['files']

            # Sort by modification time (newest first)
            files.sort(key=lambda x: x['modified'], reverse=True)

            logger.info(f"Processing {base_name}: {len(files)} files")

            # Keep only the 5 newest files
            for i, file_info in enumerate(files):
                if i < 5:  # Keep first 5 (newest)
                    cleanup_results['kept'] += 1
                    cleanup_results['details'].append({
                        'action': 'kept',
                        'file': file_info['path']
                    })
                else:  # Remove the rest
                    if not dry_run:
                        try:
                            Path(file_info['path']).unlink()
                            logger.info(f"Removed: {file_info['path']}")
                        except Exception as e:
                            logger.error(f"Error removing {file_info['path']}: {e}")
                            continue

                    cleanup_results['removed'] += 1
                    cleanup_results['space_saved_mb'] += file_info['size_mb']
                    cleanup_results['details'].append({
                        'action': 'removed',
                        'file': file_info['path'],
                        'size_mb': file_info['size_mb']
                    })

        logger.info(f"Cleanup {'simulation' if dry_run else 'complete'}:")
        logger.info(f"  Kept: {cleanup_results['kept']}")
        logger.info(f"  Removed: {cleanup_results['removed']}")
        logger.info(f"  Space saved: {cleanup_results['space_saved_mb']:.2f} MB")

        return cleanup_results

def main():
    logger.info("=== Checkpoint Cleanup Tool ===")

    cleanup = CheckpointCleanup()

    # Analyze existing checkpoints
    logger.info("\n1. Analyzing existing checkpoints...")
    analysis = cleanup.analyze_existing_checkpoints()

    if analysis['total_files'] == 0:
        logger.info("No checkpoint files found.")
        return

    # Show potential space savings
    total_duplicates = sum(len(group['files']) - 5 for group in analysis['potential_duplicates'] if len(group['files']) > 5)
    if total_duplicates > 0:
        logger.info(f"\nFound {total_duplicates} files that could be cleaned up")

        # Dry run first
        logger.info("\n2. Simulating cleanup...")
        dry_run_results = cleanup.cleanup_duplicates(dry_run=True)

        if dry_run_results['removed'] > 0:
            proceed = input(f"\nProceed with cleanup? Will remove {dry_run_results['removed']} files "
                            f"and save {dry_run_results['space_saved_mb']:.2f} MB. (y/n): ").lower().strip() == 'y'

            if proceed:
                logger.info("\n3. Performing actual cleanup...")
                cleanup_results = cleanup.cleanup_duplicates(dry_run=False)
                logger.info("\n=== Cleanup Complete ===")
            else:
                logger.info("Cleanup cancelled.")
        else:
            logger.info("No files to remove.")
    else:
        logger.info("No duplicate files found that need cleanup.")

if __name__ == "__main__":
    main()
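
For non-interactive use (for example from a scheduled job), the same class can be driven directly instead of going through main() and its y/n prompt; a minimal sketch, assuming the module layout above:

# Illustrative only: run the analysis and a dry-run cleanup programmatically.
from cleanup_checkpoints import CheckpointCleanup

cleanup = CheckpointCleanup()
analysis = cleanup.analyze_existing_checkpoints()
results = cleanup.cleanup_duplicates(dry_run=True)  # set dry_run=False to actually delete
print(f"Would remove {results['removed']} files, saving {results['space_saved_mb']:.2f} MB")
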
example_checkpoint_usage.py (new file, 148 lines)
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
"""
Example: Using the Checkpoint Management System
"""

import logging
import torch
import torch.nn as nn
import numpy as np
from datetime import datetime

from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint, get_checkpoint_manager
from utils.training_integration import get_training_integration

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ExampleCNN(nn.Module):
    def __init__(self, input_channels=5, num_classes=3):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

def example_cnn_training():
    logger.info("=== CNN Training Example ===")

    model = ExampleCNN()
    training_integration = get_training_integration()

    for epoch in range(5):  # Simulate 5 epochs
        # Simulate training metrics
        train_loss = 2.0 - (epoch * 0.15) + np.random.normal(0, 0.1)
        train_acc = 0.3 + (epoch * 0.06) + np.random.normal(0, 0.02)
        val_loss = train_loss + np.random.normal(0, 0.05)
        val_acc = train_acc - 0.05 + np.random.normal(0, 0.02)

        # Clamp values to realistic ranges
        train_acc = max(0.0, min(1.0, train_acc))
        val_acc = max(0.0, min(1.0, val_acc))
        train_loss = max(0.1, train_loss)
        val_loss = max(0.1, val_loss)

        logger.info(f"Epoch {epoch+1}: train_acc={train_acc:.3f}, val_acc={val_acc:.3f}")

        # Save checkpoint
        saved = training_integration.save_cnn_checkpoint(
            cnn_model=model,
            model_name="example_cnn",
            epoch=epoch + 1,
            train_accuracy=train_acc,
            val_accuracy=val_acc,
            train_loss=train_loss,
            val_loss=val_loss,
            training_time_hours=0.1 * (epoch + 1)
        )

        if saved:
            logger.info(f"Checkpoint saved for epoch {epoch+1}")
        else:
            logger.info("Checkpoint not saved (performance not improved)")

    # Load the best checkpoint
    logger.info("\nLoading best checkpoint...")
    best_result = load_best_checkpoint("example_cnn")
    if best_result:
        file_path, metadata = best_result
        logger.info(f"Best checkpoint: {metadata.checkpoint_id}")
        logger.info(f"Performance score: {metadata.performance_score:.4f}")

def example_manual_checkpoint():
    logger.info("\n=== Manual Checkpoint Example ===")

    model = nn.Linear(10, 3)

    performance_metrics = {
        'accuracy': 0.85,
        'val_accuracy': 0.82,
        'loss': 0.45,
        'val_loss': 0.48
    }

    training_metadata = {
        'epoch': 25,
        'training_time_hours': 2.5,
        'total_parameters': sum(p.numel() for p in model.parameters())
    }

    logger.info("Saving checkpoint manually...")
    metadata = save_checkpoint(
        model=model,
        model_name="example_manual",
        model_type="cnn",
        performance_metrics=performance_metrics,
        training_metadata=training_metadata,
        force_save=True
    )

    if metadata:
        logger.info(f"Manual checkpoint saved: {metadata.checkpoint_id}")
        logger.info(f"Performance score: {metadata.performance_score:.4f}")

def show_checkpoint_stats():
    logger.info("\n=== Checkpoint Statistics ===")

    checkpoint_manager = get_checkpoint_manager()
    stats = checkpoint_manager.get_checkpoint_stats()

    logger.info(f"Total models: {stats['total_models']}")
    logger.info(f"Total checkpoints: {stats['total_checkpoints']}")
    logger.info(f"Total size: {stats['total_size_mb']:.2f} MB")

    for model_name, model_stats in stats['models'].items():
        logger.info(f"\n{model_name}:")
        logger.info(f"  Checkpoints: {model_stats['checkpoint_count']}")
        logger.info(f"  Size: {model_stats['total_size_mb']:.2f} MB")
        logger.info(f"  Best performance: {model_stats['best_performance']:.4f}")

def main():
    logger.info("Checkpoint Management System Examples")
    logger.info("=" * 50)

    try:
        example_cnn_training()
        example_manual_checkpoint()
        show_checkpoint_stats()

        logger.info("\nAll examples completed successfully!")
        logger.info("\nTo use in your training:")
        logger.info("1. Import: from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint")
        logger.info("2. Or use: from utils.training_integration import get_training_integration")
        logger.info("3. Save checkpoints during training with performance metrics")
        logger.info("4. Load best checkpoints for inference or continued training")

    except Exception as e:
        logger.error(f"Error in examples: {e}")
        raise

if __name__ == "__main__":
    main()
@@ -13,4 +13,5 @@ torchaudio>=2.0.0
scikit-learn>=1.3.0
matplotlib>=3.7.0
seaborn>=0.12.0
asyncio-compat>=0.1.2
wandb>=0.16.0
utils/checkpoint_manager.py (new file, 306 lines)
@@ -0,0 +1,306 @@
#!/usr/bin/env python3
"""
Checkpoint Management System for W&B Training
"""

import os
import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict
import torch

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

logger = logging.getLogger(__name__)

@dataclass
class CheckpointMetadata:
    checkpoint_id: str
    model_name: str
    model_type: str
    file_path: str
    created_at: datetime
    file_size_mb: float
    performance_score: float
    accuracy: Optional[float] = None
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None
    wandb_run_id: Optional[str] = None
    wandb_artifact_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
        data['created_at'] = datetime.fromisoformat(data['created_at'])
        return cls(**data)

class CheckpointManager:
    def __init__(self,
                 base_checkpoint_dir: str = "NN/models/saved",
                 max_checkpoints_per_model: int = 5,
                 metadata_file: str = "checkpoint_metadata.json",
                 enable_wandb: bool = True):
        self.base_dir = Path(base_checkpoint_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

        self.max_checkpoints = max_checkpoints_per_model
        self.metadata_file = self.base_dir / metadata_file
        self.enable_wandb = enable_wandb and WANDB_AVAILABLE

        self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
        self._load_metadata()

        logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[CheckpointMetadata]:
        try:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            checkpoint_id = f"{model_name}_{timestamp}"

            model_dir = self.base_dir / model_name
            model_dir.mkdir(exist_ok=True)

            checkpoint_path = model_dir / f"{checkpoint_id}.pt"

            performance_score = self._calculate_performance_score(performance_metrics)

            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.info(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            success = self._save_model_file(model, checkpoint_path, model_type)
            if not success:
                return None

            file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)

            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=model_name,
                model_type=model_type,
                file_path=str(checkpoint_path),
                created_at=datetime.now(),
                file_size_mb=file_size_mb,
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=training_metadata.get('epoch') if training_metadata else None,
                training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None,
                total_parameters=training_metadata.get('total_parameters') if training_metadata else None
            )

            if self.enable_wandb and wandb.run is not None:
                artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
                metadata.wandb_run_id = wandb.run.id
                metadata.wandb_artifact_name = artifact_name

            self.checkpoints[model_name].append(metadata)
            self._rotate_checkpoints(model_name)
            self._save_metadata()

            logger.info(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
            return metadata

        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None

    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
        try:
            if model_name not in self.checkpoints or not self.checkpoints[model_name]:
                logger.warning(f"No checkpoints found for model: {model_name}")
                return None

            best_checkpoint = max(self.checkpoints[model_name], key=lambda x: x.performance_score)

            if not Path(best_checkpoint.file_path).exists():
                logger.error(f"Best checkpoint file not found: {best_checkpoint.file_path}")
                return None

            logger.info(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
            return best_checkpoint.file_path, best_checkpoint

        except Exception as e:
            logger.error(f"Error loading best checkpoint for {model_name}: {e}")
            return None

    def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
        score = 0.0

        if 'accuracy' in metrics:
            score += metrics['accuracy'] * 100
        if 'val_accuracy' in metrics:
            score += metrics['val_accuracy'] * 100
        if 'loss' in metrics:
            score += max(0, 10 - metrics['loss'])
        if 'val_loss' in metrics:
            score += max(0, 10 - metrics['val_loss'])
        if 'reward' in metrics:
            score += metrics['reward']
        if 'pnl' in metrics:
            score += metrics['pnl']

        if score == 0.0 and metrics:
            first_metric = next(iter(metrics.values()))
            score = first_metric if first_metric > 0 else 0.1

        return max(score, 0.1)

    def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
        if model_name not in self.checkpoints or not self.checkpoints[model_name]:
            return True

        if len(self.checkpoints[model_name]) < self.max_checkpoints:
            return True

        worst_score = min(cp.performance_score for cp in self.checkpoints[model_name])
        return performance_score > worst_score

    def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
        try:
            if hasattr(model, 'state_dict'):
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'model_type': model_type,
                    'saved_at': datetime.now().isoformat()
                }, file_path)
            else:
                torch.save(model, file_path)
            return True
        except Exception as e:
            logger.error(f"Error saving model file {file_path}: {e}")
            return False

    def _rotate_checkpoints(self, model_name: str):
        checkpoint_list = self.checkpoints[model_name]

        if len(checkpoint_list) <= self.max_checkpoints:
            return

        checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)

        to_remove = checkpoint_list[self.max_checkpoints:]
        self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints]

        for checkpoint in to_remove:
            try:
                file_path = Path(checkpoint.file_path)
                if file_path.exists():
                    file_path.unlink()
                    logger.info(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")

    def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
        try:
            if not self.enable_wandb or wandb.run is None:
                return None

            artifact_name = f"{metadata.model_name}_checkpoint"
            artifact = wandb.Artifact(artifact_name, type="model")
            artifact.add_file(str(file_path))
            wandb.log_artifact(artifact)

            return artifact_name
        except Exception as e:
            logger.error(f"Error uploading to W&B: {e}")
            return None

    def _load_metadata(self):
        try:
            if self.metadata_file.exists():
                with open(self.metadata_file, 'r') as f:
                    data = json.load(f)

                for model_name, checkpoint_list in data.items():
                    self.checkpoints[model_name] = [
                        CheckpointMetadata.from_dict(cp_data)
                        for cp_data in checkpoint_list
                    ]

                logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
        except Exception as e:
            logger.error(f"Error loading checkpoint metadata: {e}")

    def _save_metadata(self):
        try:
            data = {}
            for model_name, checkpoint_list in self.checkpoints.items():
                data[model_name] = [cp.to_dict() for cp in checkpoint_list]

            with open(self.metadata_file, 'w') as f:
                json.dump(data, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving checkpoint metadata: {e}")

    def get_checkpoint_stats(self):
        """Get statistics about managed checkpoints"""
        stats = {
            'total_models': len(self.checkpoints),
            'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
            'total_size_mb': 0.0,
            'models': {}
        }

        for model_name, checkpoint_list in self.checkpoints.items():
            if not checkpoint_list:
                continue

            model_size = sum(cp.file_size_mb for cp in checkpoint_list)
            best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)

            stats['models'][model_name] = {
                'checkpoint_count': len(checkpoint_list),
                'total_size_mb': model_size,
                'best_performance': best_checkpoint.performance_score,
                'best_checkpoint_id': best_checkpoint.checkpoint_id,
                'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
            }

            stats['total_size_mb'] += model_size

        return stats

_checkpoint_manager = None

def get_checkpoint_manager() -> CheckpointManager:
    global _checkpoint_manager
    if _checkpoint_manager is None:
        _checkpoint_manager = CheckpointManager()
    return _checkpoint_manager

def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    return get_checkpoint_manager().save_checkpoint(
        model, model_name, model_type, performance_metrics, training_metadata, force_save
    )

def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    return get_checkpoint_manager().load_best_checkpoint(model_name)
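
A worked check of the scoring formula in _calculate_performance_score, using the epoch-1 example_cnn entry from checkpoint_metadata.json above (that checkpoint was saved without reward/pnl metrics, so only the accuracy and loss terms contribute):

# Illustrative only: reproduce the stored performance_score of 65.67219525381417.
metrics = {
    'accuracy': 0.28019601724789606,
    'val_accuracy': 0.21531048803825983,
    'loss': 1.9252885885630378,
    'val_loss': 1.953166686238386,
}
score = (metrics['accuracy'] * 100
         + metrics['val_accuracy'] * 100
         + max(0, 10 - metrics['loss'])
         + max(0, 10 - metrics['val_loss']))
print(score)  # ~65.6722, matching the metadata entry
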
utils/training_integration.py (new file, 204 lines)
@@ -0,0 +1,204 @@
#!/usr/bin/env python3
"""
Training Integration for Checkpoint Management
"""

import logging
import torch
from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path

from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint

logger = logging.getLogger(__name__)

class TrainingIntegration:
    def __init__(self, enable_wandb: bool = True):
        self.checkpoint_manager = get_checkpoint_manager()
        self.enable_wandb = enable_wandb

        if self.enable_wandb:
            self._init_wandb()

    def _init_wandb(self):
        try:
            import wandb

            if wandb.run is None:
                wandb.init(
                    project="gogo2-trading",
                    name=f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                    config={
                        "max_checkpoints_per_model": self.checkpoint_manager.max_checkpoints,
                        "checkpoint_dir": str(self.checkpoint_manager.base_dir)
                    }
                )
                logger.info(f"Initialized W&B run: {wandb.run.id}")

        except ImportError:
            logger.warning("W&B not available - checkpoint management will work without it")
        except Exception as e:
            logger.error(f"Error initializing W&B: {e}")

    def save_cnn_checkpoint(self,
                            cnn_model,
                            model_name: str,
                            epoch: int,
                            train_accuracy: float,
                            val_accuracy: float,
                            train_loss: float,
                            val_loss: float,
                            training_time_hours: float = None) -> bool:
        try:
            performance_metrics = {
                'accuracy': train_accuracy,
                'val_accuracy': val_accuracy,
                'loss': train_loss,
                'val_loss': val_loss
            }

            training_metadata = {
                'epoch': epoch,
                'training_time_hours': training_time_hours,
                'total_parameters': self._count_parameters(cnn_model)
            }

            if self.enable_wandb:
                try:
                    import wandb
                    if wandb.run is not None:
                        wandb.log({
                            f"{model_name}/train_accuracy": train_accuracy,
                            f"{model_name}/val_accuracy": val_accuracy,
                            f"{model_name}/train_loss": train_loss,
                            f"{model_name}/val_loss": val_loss,
                            f"{model_name}/epoch": epoch
                        })
                except Exception as e:
                    logger.warning(f"Error logging to W&B: {e}")

            metadata = save_checkpoint(
                model=cnn_model,
                model_name=model_name,
                model_type='cnn',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )

            if metadata:
                logger.info(f"CNN checkpoint saved: {metadata.checkpoint_id}")
                return True
            else:
                logger.info("CNN checkpoint not saved (performance not improved)")
                return False

        except Exception as e:
            logger.error(f"Error saving CNN checkpoint: {e}")
            return False

    def save_rl_checkpoint(self,
                           rl_agent,
                           model_name: str,
                           episode: int,
                           avg_reward: float,
                           best_reward: float,
                           epsilon: float,
                           total_pnl: float = None) -> bool:
        try:
            performance_metrics = {
                'reward': avg_reward,
                'best_reward': best_reward
            }

            if total_pnl is not None:
                performance_metrics['pnl'] = total_pnl

            training_metadata = {
                'episode': episode,
                'epsilon': epsilon,
                'total_parameters': self._count_parameters(rl_agent)
            }

            if self.enable_wandb:
                try:
                    import wandb
                    if wandb.run is not None:
                        wandb.log({
                            f"{model_name}/avg_reward": avg_reward,
                            f"{model_name}/best_reward": best_reward,
                            f"{model_name}/epsilon": epsilon,
                            f"{model_name}/episode": episode
                        })

                        if total_pnl is not None:
                            wandb.log({f"{model_name}/total_pnl": total_pnl})

                except Exception as e:
                    logger.warning(f"Error logging to W&B: {e}")

            metadata = save_checkpoint(
                model=rl_agent,
                model_name=model_name,
                model_type='rl',
                performance_metrics=performance_metrics,
                training_metadata=training_metadata
            )

            if metadata:
                logger.info(f"RL checkpoint saved: {metadata.checkpoint_id}")
                return True
            else:
                logger.info("RL checkpoint not saved (performance not improved)")
                return False

        except Exception as e:
            logger.error(f"Error saving RL checkpoint: {e}")
            return False

    def load_best_model(self, model_name: str, model_class=None):
        try:
            result = load_best_checkpoint(model_name)
            if not result:
                logger.warning(f"No checkpoint found for model: {model_name}")
                return None

            file_path, metadata = result

            checkpoint = torch.load(file_path, map_location='cpu')

            logger.info(f"Loaded best checkpoint for {model_name}:")
            logger.info(f"  Performance score: {metadata.performance_score:.4f}")
            logger.info(f"  Created: {metadata.created_at}")

            if model_class and 'model_state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
                return model

            return checkpoint

        except Exception as e:
            logger.error(f"Error loading best model {model_name}: {e}")
            return None

    def _count_parameters(self, model) -> int:
        try:
            if hasattr(model, 'parameters'):
                return sum(p.numel() for p in model.parameters())
            elif hasattr(model, 'policy_net'):
                policy_params = sum(p.numel() for p in model.policy_net.parameters())
                target_params = sum(p.numel() for p in model.target_net.parameters()) if hasattr(model, 'target_net') else 0
                return policy_params + target_params
            else:
                return 0
        except Exception:
            return 0

_training_integration = None

def get_training_integration() -> TrainingIntegration:
    global _training_integration
    if _training_integration is None:
        _training_integration = TrainingIntegration()
    return _training_integration
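
The example script earlier in this commit only exercises the CNN and manual paths; for completeness, a minimal sketch of the RL wrapper (the agent, model name, and numbers below are hypothetical stand-ins):

# Illustrative only: saving an RL checkpoint through TrainingIntegration.
import torch.nn as nn
from utils.training_integration import get_training_integration

agent = nn.Linear(8, 4)  # stand-in for a real RL policy network
integration = get_training_integration()
saved = integration.save_rl_checkpoint(
    rl_agent=agent,
    model_name="example_dqn",  # hypothetical model name
    episode=120,
    avg_reward=3.7,
    best_reward=9.2,
    epsilon=0.15,
    total_pnl=42.0,
)
print("saved" if saved else "skipped (performance not improved)")
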