Model checkpoint manager: add a unified model registry and route CNN, DQN, and dashboard model saves through it

Dobromir Popov
2025-09-08 13:31:11 +03:00
parent 060fdd28b4
commit c9fba56622
6 changed files with 838 additions and 142 deletions

View File

@@ -21,6 +21,7 @@ from typing import Dict, Any, Optional, Tuple
# Import checkpoint management
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.model_registry import get_model_registry
# Configure logging
logger = logging.getLogger(__name__)
@@ -774,9 +775,13 @@ class CNNModelTrainer:
# Return realistic loss values based on random baseline performance
return {'main_loss': 0.693, 'total_loss': 0.693, 'accuracy': 0.5} # ln(2) for binary cross-entropy at random chance
def save_model(self, filepath: str, metadata: Optional[Dict] = None):
"""Save model with metadata"""
save_dict = {
def save_model(self, filepath: str = None, metadata: Optional[Dict] = None):
"""Save model with metadata using unified registry"""
try:
from utils.model_registry import save_model
# Prepare model data
model_data = {
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
@@ -790,13 +795,70 @@ class CNNModelTrainer:
}
if metadata:
save_dict['metadata'] = metadata
model_data['metadata'] = metadata
torch.save(save_dict, filepath)
logger.info(f"Enhanced CNN model saved to {filepath}")
# Use unified registry if no filepath specified
if filepath is None or filepath.startswith('models/'):
# Extract model name from filepath or use default
model_name = "enhanced_cnn"
if filepath:
model_name = filepath.split('/')[-1].replace('_latest.pt', '').replace('.pt', '')
def load_model(self, filepath: str) -> Dict:
"""Load model from file"""
success = save_model(
model=self.model,
model_name=model_name,
model_type='cnn',
metadata={'full_checkpoint': model_data}
)
if success:
logger.info(f"Enhanced CNN model saved to unified registry: {model_name}")
return success
else:
# Legacy direct file save
torch.save(model_data, filepath)
logger.info(f"Enhanced CNN model saved to {filepath} (legacy mode)")
return True
except Exception as e:
logger.error(f"Failed to save CNN model: {e}")
return False
def load_model(self, filepath: str = None) -> Dict:
"""Load model from unified registry or file"""
try:
from utils.model_registry import load_model
# Use unified registry if no filepath or if it's a models/ path
if filepath is None or filepath.startswith('models/'):
model_name = "enhanced_cnn"
if filepath:
model_name = filepath.split('/')[-1].replace('_latest.pt', '').replace('.pt', '')
model = load_model(model_name, 'cnn')
if model is None:
logger.warning(f"Could not load model {model_name} from unified registry")
return {}
# Load full checkpoint data from metadata
registry = get_model_registry()
if model_name in registry.metadata['models']:
model_data = registry.metadata['models'][model_name]
if 'full_checkpoint' in model_data:
checkpoint = model_data['full_checkpoint']
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
if 'training_history' in checkpoint:
self.training_history = checkpoint['training_history']
logger.info(f"Enhanced CNN model loaded from unified registry: {model_name}")
return checkpoint.get('metadata', {})
return {}
else:
# Legacy direct file load
checkpoint = torch.load(filepath, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
@@ -808,9 +870,13 @@ class CNNModelTrainer:
if 'training_history' in checkpoint:
self.training_history = checkpoint['training_history']
logger.info(f"Enhanced CNN model loaded from {filepath}")
logger.info(f"Enhanced CNN model loaded from {filepath} (legacy mode)")
return checkpoint.get('metadata', {})
except Exception as e:
logger.error(f"Failed to load CNN model: {e}")
return {}
def create_enhanced_cnn_model(input_size: int = 60,
feature_dim: int = 50,
output_size: int = 2,

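For reference, a minimal usage sketch of the new dual-path save/load in CNNModelTrainer (not part of the commit; the trainer instance and paths below are illustrative). Calls with no filepath, or a path under models/, go through the unified registry under a name derived from the filename; any other path falls back to a plain torch.save/torch.load.

# Illustrative only: 'trainer' is an existing CNNModelTrainer; paths are hypothetical.
ok = trainer.save_model()                           # registry entry "enhanced_cnn" in models/cnn/saved/
ok = trainer.save_model("models/cnn_v2_latest.pt")  # registry entry named "cnn_v2" (suffixes stripped)
ok = trainer.save_model("/tmp/cnn_backup.pt")       # legacy mode: torch.save() straight to this path
meta = trainer.load_model()                         # registry lookup; restores optimizer/scheduler/history
meta = trainer.load_model("/tmp/cnn_backup.pt")     # legacy mode: torch.load() from this path
# Both load paths return the stored metadata dict, or {} on failure.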
View File

@@ -16,6 +16,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(
# Import checkpoint management
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.model_registry import get_model_registry
# Configure logger
logger = logging.getLogger(__name__)
@@ -1329,8 +1330,42 @@ class DQNAgent:
return False # No improvement
def save(self, path: str):
"""Save model and agent state"""
def save(self, path: str = None):
"""Save model and agent state using unified registry"""
try:
from utils.model_registry import save_model
# Use unified registry if no path or if it's a models/ path
if path is None or path.startswith('models/'):
model_name = "dqn_agent"
if path:
model_name = path.split('/')[-1].replace('_agent_state', '').replace('.pt', '')
# Prepare full agent state
agent_state = {
'epsilon': self.epsilon,
'update_count': self.update_count,
'losses': self.losses,
'optimizer_state': self.optimizer.state_dict(),
'best_reward': self.best_reward,
'avg_reward': self.avg_reward,
'policy_net_state': self.policy_net.state_dict(),
'target_net_state': self.target_net.state_dict()
}
success = save_model(
model=self.policy_net, # Save policy net as main model
model_name=model_name,
model_type='dqn',
metadata={'full_agent_state': agent_state}
)
if success:
logger.info(f"DQN agent saved to unified registry: {model_name}")
return
else:
# Legacy direct file save
os.makedirs(os.path.dirname(path), exist_ok=True)
# Save policy network
@@ -1350,10 +1385,59 @@ class DQNAgent:
}
torch.save(state, f"{path}_agent_state.pt")
logger.info(f"Agent state saved to {path}_agent_state.pt")
logger.info(f"Agent state saved to {path}_agent_state.pt (legacy mode)")
def load(self, path: str):
"""Load model and agent state"""
except Exception as e:
logger.error(f"Failed to save DQN agent: {e}")
def load(self, path: str = None):
"""Load model and agent state from unified registry or file"""
try:
from utils.model_registry import load_model
# Use unified registry if no path or if it's a models/ path
if path is None or path.startswith('models/'):
model_name = "dqn_agent"
if path:
model_name = path.split('/')[-1].replace('_agent_state', '').replace('.pt', '')
model = load_model(model_name, 'dqn')
if model is None:
logger.warning(f"Could not load DQN agent {model_name} from unified registry")
return
# Load full agent state from metadata
registry = get_model_registry()
if model_name in registry.metadata['models']:
model_data = registry.metadata['models'][model_name]
if 'full_agent_state' in model_data:
agent_state = model_data['full_agent_state']
# Restore agent state
self.epsilon = agent_state['epsilon']
self.update_count = agent_state['update_count']
self.losses = agent_state['losses']
self.optimizer.load_state_dict(agent_state['optimizer_state'])
# Load additional metrics if they exist
if 'best_reward' in agent_state:
self.best_reward = agent_state['best_reward']
if 'avg_reward' in agent_state:
self.avg_reward = agent_state['avg_reward']
# Load network states
if 'policy_net_state' in agent_state:
self.policy_net.load_state_dict(agent_state['policy_net_state'])
if 'target_net_state' in agent_state:
self.target_net.load_state_dict(agent_state['target_net_state'])
logger.info(f"DQN agent loaded from unified registry: {model_name}")
return
return
else:
# Legacy direct file load
# Load policy network
self.policy_net.load(f"{path}_policy")
@@ -1374,10 +1458,13 @@ class DQNAgent:
if 'avg_reward' in agent_state:
self.avg_reward = agent_state['avg_reward']
logger.info(f"Agent state loaded from {path}_agent_state.pt")
logger.info(f"Agent state loaded from {path}_agent_state.pt (legacy mode)")
except FileNotFoundError:
logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
except Exception as e:
logger.error(f"Failed to load DQN agent: {e}")
def get_position_info(self):
"""Get current position information"""
return {

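The DQN agent gets the same treatment; a hedged sketch of the intended call pattern (illustrative, assumes an existing DQNAgent instance). The registry path stores the policy network as the model payload and keeps the full agent state (epsilon, losses, optimizer, policy/target nets) in the entry's metadata under 'full_agent_state'; other paths keep the legacy multi-file layout.

# Illustrative only: 'agent' is an existing DQNAgent; paths are hypothetical.
agent.save()                     # registry entry "dqn_agent" (policy net + full_agent_state metadata)
agent.save("models/dqn_best")    # registry entry named "dqn_best"
agent.save("/data/run42/dqn")    # legacy: writes /data/run42/dqn_policy and /data/run42/dqn_agent_state.pt
agent.load()                     # registry lookup; restores epsilon, losses, optimizer and both networks
agent.load("/data/run42/dqn")    # legacy load from the files written above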
View File

@@ -16,6 +16,9 @@ import random
WANDB_AVAILABLE = False
# Import model registry
from utils.model_registry import get_model_registry
logger = logging.getLogger(__name__)
@dataclass
@@ -72,15 +75,9 @@ class CheckpointManager:
performance_metrics: Dict[str, float],
training_metadata: Optional[Dict[str, Any]] = None,
force_save: bool = False) -> Optional[CheckpointMetadata]:
"""Save a model checkpoint with improved error handling and validation"""
"""Save a model checkpoint with improved error handling and validation using unified registry"""
try:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_id = f"{model_name}_{timestamp}"
model_dir = self.base_dir / model_name
model_dir.mkdir(exist_ok=True)
checkpoint_path = model_dir / f"{checkpoint_id}.pt"
from utils.model_registry import save_checkpoint as registry_save_checkpoint
performance_score = self._calculate_performance_score(performance_metrics)
@@ -88,19 +85,34 @@ class CheckpointManager:
logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
return None
success = self._save_model_file(model, checkpoint_path, model_type)
# Use unified registry for checkpointing
success = registry_save_checkpoint(
model=model,
model_name=model_name,
model_type=model_type,
performance_score=performance_score,
metadata={
'performance_metrics': performance_metrics,
'training_metadata': training_metadata,
'checkpoint_manager': True
}
)
if not success:
return None
file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)
# Get checkpoint info from registry
registry = get_model_registry()
checkpoint_info = registry.metadata['models'][model_name]['checkpoints'][-1]
# Create CheckpointMetadata object
metadata = CheckpointMetadata(
checkpoint_id=checkpoint_id,
checkpoint_id=checkpoint_info['id'],
model_name=model_name,
model_type=model_type,
file_path=str(checkpoint_path),
created_at=datetime.now(),
file_size_mb=file_size_mb,
file_path=checkpoint_info['path'],
created_at=datetime.fromisoformat(checkpoint_info['timestamp']),
file_size_mb=0.0, # Will be calculated by registry
performance_score=performance_score,
accuracy=performance_metrics.get('accuracy'),
loss=performance_metrics.get('loss'),
@@ -113,8 +125,7 @@ class CheckpointManager:
total_parameters=training_metadata.get('total_parameters') if training_metadata else None
)
# W&B disabled
# Update local checkpoint tracking
self.checkpoints[model_name].append(metadata)
self._rotate_checkpoints(model_name)
self._save_metadata()
@@ -128,7 +139,35 @@ class CheckpointManager:
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
try:
# First, try the standard checkpoint system
from utils.model_registry import load_best_checkpoint as registry_load_checkpoint
# First, try the unified registry
registry_result = registry_load_checkpoint(model_name, 'cnn') # Try CNN type first
if registry_result is None:
registry_result = registry_load_checkpoint(model_name, 'dqn') # Try DQN type
if registry_result:
checkpoint_path, checkpoint_data = registry_result
# Create CheckpointMetadata from registry data
metadata = CheckpointMetadata(
checkpoint_id=f"{model_name}_registry",
model_name=model_name,
model_type=checkpoint_data.get('model_type', 'unknown'),
file_path=checkpoint_path,
created_at=datetime.fromisoformat(checkpoint_data.get('timestamp', datetime.now().isoformat())),
file_size_mb=0.0, # Will be calculated by registry
performance_score=checkpoint_data.get('performance_score', 0.0),
accuracy=checkpoint_data.get('accuracy'),
loss=checkpoint_data.get('loss'),
reward=checkpoint_data.get('reward'),
pnl=checkpoint_data.get('pnl')
)
logger.debug(f"Loading checkpoint from unified registry for {model_name}")
return checkpoint_path, metadata
# Fallback: Try the standard checkpoint system
if model_name in self.checkpoints and self.checkpoints[model_name]:
# Filter out checkpoints with non-existent files
valid_checkpoints = [

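For orientation: the CheckpointManager above now reads checkpoint records straight out of the registry's metadata (registry.metadata['models'][name]['checkpoints'][-1]). A rough sketch of the shape of models/registry_metadata.json as maintained by utils/model_registry.py (introduced below); all values are made up.

# Illustrative shape of models/registry_metadata.json, expressed as a Python dict.
registry_metadata = {
    "last_updated": "2025-09-08T13:31:11",
    "models": {
        "dqn_agent": {
            "type": "dqn",
            "latest_path": "models/dqn/saved/dqn_agent_latest.pt",
            "last_saved": "20250908_133111",
            "save_count": 3,
            "checkpoints": [
                {
                    "id": "dqn_agent_20250908_133111_0.7300",
                    "path": "models/dqn/checkpoints/dqn_agent_20250908_133111_0.7300.pt",
                    "performance_score": 0.73,
                    "timestamp": "20250908_133111"
                }
            ]
        }
    }
}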
utils/model_registry.py (new file, +446 lines)
View File

@@ -0,0 +1,446 @@
#!/usr/bin/env python3
"""
Unified Model Registry for Centralized Model Management
This module provides a unified interface for saving, loading, and managing
all machine learning models in the trading system. It consolidates model
storage from multiple locations into a single, organized structure.
"""
import os
import json
import torch
import logging
import pickle
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
from datetime import datetime
import hashlib
logger = logging.getLogger(__name__)
class ModelRegistry:
"""
Unified model registry for centralized model management.
Handles saving, loading, and organization of all ML models.
"""
def __init__(self, base_dir: str = "models"):
"""
Initialize the model registry.
Args:
base_dir: Base directory for model storage
"""
self.base_dir = Path(base_dir)
self.saved_dir = self.base_dir / "saved"
self.checkpoint_dir = self.base_dir / "checkpoints"
self.archive_dir = self.base_dir / "archive"
# Model type directories
self.model_dirs = {
'cnn': self.base_dir / "cnn",
'dqn': self.base_dir / "dqn",
'transformer': self.base_dir / "transformer",
'hybrid': self.base_dir / "hybrid"
}
# Ensure all directories exist
self._ensure_directories()
# Metadata tracking
self.metadata_file = self.base_dir / "registry_metadata.json"
self.metadata = self._load_metadata()
logger.info(f"Model Registry initialized at {self.base_dir}")
def _ensure_directories(self):
"""Ensure all required directories exist."""
directories = [
self.saved_dir,
self.checkpoint_dir,
self.archive_dir
]
# Add model type directories
for model_dir in self.model_dirs.values():
directories.extend([
model_dir / "saved",
model_dir / "checkpoints",
model_dir / "archive"
])
for directory in directories:
directory.mkdir(parents=True, exist_ok=True)
def _load_metadata(self) -> Dict[str, Any]:
"""Load registry metadata."""
if self.metadata_file.exists():
try:
with open(self.metadata_file, 'r') as f:
return json.load(f)
except Exception as e:
logger.warning(f"Failed to load metadata: {e}")
return {'models': {}, 'last_updated': datetime.now().isoformat()}
def _save_metadata(self):
"""Save registry metadata."""
self.metadata['last_updated'] = datetime.now().isoformat()
try:
with open(self.metadata_file, 'w') as f:
json.dump(self.metadata, f, indent=2)
except Exception as e:
logger.error(f"Failed to save metadata: {e}")
def save_model(self, model: Any, model_name: str, model_type: str = 'cnn',
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Save a model to the unified storage.
Args:
model: The model to save
model_name: Name of the model
model_type: Type of model (cnn, dqn, transformer, hybrid)
metadata: Additional metadata to save
Returns:
bool: True if successful, False otherwise
"""
try:
model_dir = self.model_dirs.get(model_type, self.saved_dir)
save_dir = model_dir / "saved"
# Generate filename with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f"{model_name}_{timestamp}.pt"
filepath = save_dir / filename
# Also save as latest
latest_filepath = save_dir / f"{model_name}_latest.pt"
# Save model
save_dict = {
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
'model_class': model.__class__.__name__,
'model_type': model_type,
'timestamp': timestamp,
'metadata': metadata or {}
}
torch.save(save_dict, filepath)
torch.save(save_dict, latest_filepath)
# Update metadata
if model_name not in self.metadata['models']:
self.metadata['models'][model_name] = {}
self.metadata['models'][model_name].update({
'type': model_type,
'latest_path': str(latest_filepath),
'last_saved': timestamp,
'save_count': self.metadata['models'][model_name].get('save_count', 0) + 1
})
self._save_metadata()
logger.info(f"Model {model_name} saved to {filepath}")
return True
except Exception as e:
logger.error(f"Failed to save model {model_name}: {e}")
return False
def load_model(self, model_name: str, model_type: str = 'cnn',
model_class: Optional[Any] = None) -> Optional[Any]:
"""
Load a model from the unified storage.
Args:
model_name: Name of the model to load
model_type: Type of model (cnn, dqn, transformer, hybrid)
model_class: Model class to instantiate (if needed)
Returns:
The loaded model or None if failed
"""
try:
model_dir = self.model_dirs.get(model_type, self.saved_dir)
save_dir = model_dir / "saved"
latest_filepath = save_dir / f"{model_name}_latest.pt"
if not latest_filepath.exists():
logger.warning(f"Model {model_name} not found at {latest_filepath}")
return None
# Load checkpoint
checkpoint = torch.load(latest_filepath, map_location='cpu')
# Instantiate model if class provided
if model_class is not None:
model = model_class()
model.load_state_dict(checkpoint['model_state_dict'])
else:
# Try to reconstruct model from state_dict
model = type('LoadedModel', (), {})()
model.state_dict = lambda: checkpoint['model_state_dict']
model.load_state_dict = lambda state_dict: None
logger.info(f"Model {model_name} loaded from {latest_filepath}")
return model
except Exception as e:
logger.error(f"Failed to load model {model_name}: {e}")
return None
def save_checkpoint(self, model: Any, model_name: str, model_type: str = 'cnn',
performance_score: float = 0.0,
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Save a model checkpoint.
Args:
model: The model to checkpoint
model_name: Name of the model
model_type: Type of model
performance_score: Performance score for this checkpoint
metadata: Additional metadata
Returns:
bool: True if successful, False otherwise
"""
try:
model_dir = self.model_dirs.get(model_type, self.checkpoint_dir)
checkpoint_dir = model_dir / "checkpoints"
# Generate checkpoint ID
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_id = f"{model_name}_{timestamp}_{performance_score:.4f}"
filepath = checkpoint_dir / f"{checkpoint_id}.pt"
# Save checkpoint
checkpoint_data = {
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
'model_class': model.__class__.__name__,
'model_type': model_type,
'model_name': model_name,
'performance_score': performance_score,
'timestamp': timestamp,
'metadata': metadata or {}
}
torch.save(checkpoint_data, filepath)
# Update metadata
if model_name not in self.metadata['models']:
self.metadata['models'][model_name] = {}
if 'checkpoints' not in self.metadata['models'][model_name]:
self.metadata['models'][model_name]['checkpoints'] = []
checkpoint_info = {
'id': checkpoint_id,
'path': str(filepath),
'performance_score': performance_score,
'timestamp': timestamp
}
self.metadata['models'][model_name]['checkpoints'].append(checkpoint_info)
# Keep only top 5 checkpoints
checkpoints = self.metadata['models'][model_name]['checkpoints']
if len(checkpoints) > 5:
checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)
checkpoints_to_remove = checkpoints[5:]
for checkpoint in checkpoints_to_remove:
try:
os.remove(checkpoint['path'])
except:
pass
self.metadata['models'][model_name]['checkpoints'] = checkpoints[:5]
self._save_metadata()
logger.info(f"Checkpoint {checkpoint_id} saved with score {performance_score}")
return True
except Exception as e:
logger.error(f"Failed to save checkpoint for {model_name}: {e}")
return False
def load_best_checkpoint(self, model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
"""
Load the best checkpoint for a model.
Args:
model_name: Name of the model
model_type: Type of model
Returns:
Tuple of (checkpoint_path, checkpoint_data) or None
"""
try:
if model_name not in self.metadata['models']:
logger.warning(f"No metadata found for model {model_name}")
return None
checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
if not checkpoints:
logger.warning(f"No checkpoints found for model {model_name}")
return None
# Find best checkpoint by performance score
best_checkpoint = max(checkpoints, key=lambda x: x['performance_score'])
checkpoint_path = best_checkpoint['path']
if not os.path.exists(checkpoint_path):
logger.warning(f"Checkpoint file not found: {checkpoint_path}")
return None
checkpoint_data = torch.load(checkpoint_path, map_location='cpu')
logger.info(f"Best checkpoint loaded for {model_name}: {best_checkpoint['id']}")
return checkpoint_path, checkpoint_data
except Exception as e:
logger.error(f"Failed to load best checkpoint for {model_name}: {e}")
return None
def archive_model(self, model_name: str, model_type: str = 'cnn') -> bool:
"""
Archive a model by moving it to archive directory.
Args:
model_name: Name of the model to archive
model_type: Type of model
Returns:
bool: True if successful, False otherwise
"""
try:
model_dir = self.model_dirs.get(model_type, self.saved_dir)
save_dir = model_dir / "saved"
archive_dir = model_dir / "archive"
latest_filepath = save_dir / f"{model_name}_latest.pt"
if not latest_filepath.exists():
logger.warning(f"Model {model_name} not found to archive")
return False
# Move to archive with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
archive_filepath = archive_dir / f"{model_name}_archived_{timestamp}.pt"
os.rename(latest_filepath, archive_filepath)
logger.info(f"Model {model_name} archived to {archive_filepath}")
return True
except Exception as e:
logger.error(f"Failed to archive model {model_name}: {e}")
return False
def list_models(self, model_type: Optional[str] = None) -> Dict[str, Any]:
"""
List all models in the registry.
Args:
model_type: Filter by model type (optional)
Returns:
Dictionary of model information
"""
models_info = {}
for model_name, model_data in self.metadata['models'].items():
if model_type and model_data.get('type') != model_type:
continue
models_info[model_name] = {
'type': model_data.get('type'),
'last_saved': model_data.get('last_saved'),
'save_count': model_data.get('save_count', 0),
'checkpoint_count': len(model_data.get('checkpoints', [])),
'latest_path': model_data.get('latest_path')
}
return models_info
def cleanup_old_checkpoints(self, model_name: str, keep_count: int = 5) -> int:
"""
Clean up old checkpoints, keeping only the best ones.
Args:
model_name: Name of the model
keep_count: Number of checkpoints to keep
Returns:
Number of checkpoints removed
"""
if model_name not in self.metadata['models']:
return 0
checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
if len(checkpoints) <= keep_count:
return 0
# Sort by performance score (descending)
checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)
# Remove old checkpoints
removed_count = 0
for checkpoint in checkpoints[keep_count:]:
try:
os.remove(checkpoint['path'])
removed_count += 1
except:
pass
# Update metadata
self.metadata['models'][model_name]['checkpoints'] = checkpoints[:keep_count]
self._save_metadata()
logger.info(f"Cleaned up {removed_count} old checkpoints for {model_name}")
return removed_count
# Global registry instance
_registry_instance = None
def get_model_registry() -> ModelRegistry:
"""Get the global model registry instance."""
global _registry_instance
if _registry_instance is None:
_registry_instance = ModelRegistry()
return _registry_instance
def save_model(model: Any, model_name: str, model_type: str = 'cnn',
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Convenience function to save a model using the global registry.
"""
return get_model_registry().save_model(model, model_name, model_type, metadata)
def load_model(model_name: str, model_type: str = 'cnn',
model_class: Optional[Any] = None) -> Optional[Any]:
"""
Convenience function to load a model using the global registry.
"""
return get_model_registry().load_model(model_name, model_type, model_class)
def save_checkpoint(model: Any, model_name: str, model_type: str = 'cnn',
performance_score: float = 0.0,
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Convenience function to save a checkpoint using the global registry.
"""
return get_model_registry().save_checkpoint(model, model_name, model_type, performance_score, metadata)
def load_best_checkpoint(model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
"""
Convenience function to load the best checkpoint using the global registry.
"""
return get_model_registry().load_best_checkpoint(model_name, model_type)

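A brief usage sketch of the module-level convenience wrappers above (illustrative; my_model stands in for any object exposing state_dict()):

# Illustrative only: 'my_model' is any torch.nn.Module-like object with a state_dict().
from utils.model_registry import save_model, load_model, save_checkpoint, load_best_checkpoint

save_model(my_model, model_name="enhanced_cnn", model_type="cnn",
           metadata={"notes": "example save"})         # writes <name>_<timestamp>.pt and <name>_latest.pt
loaded = load_model("enhanced_cnn", model_type="cnn")  # returns a thin wrapper unless model_class is given
save_checkpoint(my_model, model_name="enhanced_cnn", model_type="cnn",
                performance_score=0.81, metadata={"loss": 0.42})  # keeps only the 5 best by score
best = load_best_checkpoint("enhanced_cnn", "cnn")
if best is not None:
    ckpt_path, ckpt_data = best
    state = ckpt_data["model_state_dict"]              # raw torch.load() dict of the highest-scoring checkpoint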
View File

@@ -4710,39 +4710,64 @@ class CleanTradingDashboard:
stored_models = []
# Use unified model registry for saving
from utils.model_registry import save_model
# 1. Store DQN model
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
try:
if hasattr(self.orchestrator.rl_agent, 'save'):
save_path = self.orchestrator.rl_agent.save('models/saved/dqn_agent_session')
stored_models.append(('DQN', save_path))
logger.info(f"Stored DQN model: {save_path}")
success = save_model(
model=self.orchestrator.rl_agent.policy_net, # Save policy network
model_name='dqn_agent_session',
model_type='dqn',
metadata={'session_save': True, 'dashboard_save': True}
)
if success:
stored_models.append(('DQN', 'models/dqn/saved/dqn_agent_session_latest.pt'))
logger.info("Stored DQN model via unified registry")
else:
logger.warning("Failed to store DQN model via unified registry")
except Exception as e:
logger.warning(f"Failed to store DQN model: {e}")
# 2. Store CNN model
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
try:
if hasattr(self.orchestrator.cnn_model, 'save'):
save_path = self.orchestrator.cnn_model.save('models/saved/cnn_model_session')
stored_models.append(('CNN', save_path))
logger.info(f"Stored CNN model: {save_path}")
success = save_model(
model=self.orchestrator.cnn_model,
model_name='cnn_model_session',
model_type='cnn',
metadata={'session_save': True, 'dashboard_save': True}
)
if success:
stored_models.append(('CNN', 'models/cnn/saved/cnn_model_session_latest.pt'))
logger.info("Stored CNN model via unified registry")
else:
logger.warning("Failed to store CNN model via unified registry")
except Exception as e:
logger.warning(f"Failed to store CNN model: {e}")
# 3. Store Transformer model
if hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
try:
if hasattr(self.orchestrator.primary_transformer, 'save'):
save_path = self.orchestrator.primary_transformer.save('models/saved/transformer_model_session')
stored_models.append(('Transformer', save_path))
logger.info(f"Stored Transformer model: {save_path}")
success = save_model(
model=self.orchestrator.primary_transformer,
model_name='transformer_model_session',
model_type='transformer',
metadata={'session_save': True, 'dashboard_save': True}
)
if success:
stored_models.append(('Transformer', 'models/transformer/saved/transformer_model_session_latest.pt'))
logger.info("Stored Transformer model via unified registry")
else:
logger.warning("Failed to store Transformer model via unified registry")
except Exception as e:
logger.warning(f"Failed to store Transformer model: {e}")
# 4. Store COB RL model
# 4. Store COB RL model (if exists)
if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
try:
# COB RL model might have different save method
if hasattr(self.orchestrator.cob_rl_agent, 'save'):
save_path = self.orchestrator.cob_rl_agent.save('models/saved/cob_rl_agent_session')
stored_models.append(('COB RL', save_path))
@@ -4750,13 +4775,20 @@ class CleanTradingDashboard:
except Exception as e:
logger.warning(f"Failed to store COB RL model: {e}")
# 5. Store Decision Fusion model
# 5. Store Decision model
if hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
try:
if hasattr(self.orchestrator.decision_model, 'save'):
save_path = self.orchestrator.decision_model.save('models/saved/decision_fusion_session')
stored_models.append(('Decision Fusion', save_path))
logger.info(f"Stored Decision Fusion model: {save_path}")
success = save_model(
model=self.orchestrator.decision_model,
model_name='decision_fusion_session',
model_type='hybrid',
metadata={'session_save': True, 'dashboard_save': True}
)
if success:
stored_models.append(('Decision Fusion', 'models/hybrid/saved/decision_fusion_session_latest.pt'))
logger.info("Stored Decision Fusion model via unified registry")
else:
logger.warning("Failed to store Decision Fusion model via unified registry")
except Exception as e:
logger.warning(f"Failed to store Decision Fusion model: {e}")
@@ -6706,6 +6738,32 @@ class CleanTradingDashboard:
except Exception as e:
logger.error(f"Error saving transformer checkpoint: {e}")
# Use unified registry for checkpoint
try:
from utils.model_registry import save_checkpoint as registry_save_checkpoint
checkpoint_data = torch.load(checkpoint_path, map_location='cpu') if 'checkpoint_path' in locals() else checkpoint_data
success = registry_save_checkpoint(
model=checkpoint_data,
model_name='transformer',
model_type='transformer',
performance_score=training_metrics['accuracy'],
metadata={
'training_samples': len(training_samples),
'loss': training_metrics['total_loss'],
'accuracy': training_metrics['accuracy'],
'checkpoint_source': 'dashboard_training'
}
)
if success:
logger.info("TRANSFORMER: Checkpoint saved via unified registry")
else:
logger.warning("TRANSFORMER: Failed to save checkpoint via unified registry")
except Exception as registry_error:
logger.warning(f"Unified registry save failed: {registry_error}")
# Fallback to direct save
try:
checkpoint_path = f"NN/models/saved/transformer_checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"