model checkpoint manager
@@ -21,6 +21,7 @@ from typing import Dict, Any, Optional, Tuple
 
 # Import checkpoint management
 from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
+from utils.model_registry import get_model_registry
 
 # Configure logging
 logger = logging.getLogger(__name__)
@@ -774,42 +775,107 @@ class CNNModelTrainer:
         # Return realistic loss values based on random baseline performance
         return {'main_loss': 0.693, 'total_loss': 0.693, 'accuracy': 0.5}  # ln(2) for binary cross-entropy at random chance
 
-    def save_model(self, filepath: str, metadata: Optional[Dict] = None):
-        """Save model with metadata"""
-        save_dict = {
-            'model_state_dict': self.model.state_dict(),
-            'optimizer_state_dict': self.optimizer.state_dict(),
-            'scheduler_state_dict': self.scheduler.state_dict(),
-            'training_history': self.training_history,
-            'model_config': {
-                'input_size': self.model.input_size,
-                'feature_dim': self.model.feature_dim,
-                'output_size': self.model.output_size,
-                'base_channels': self.model.base_channels
-            }
-        }
-
-        if metadata:
-            save_dict['metadata'] = metadata
-
-        torch.save(save_dict, filepath)
-        logger.info(f"Enhanced CNN model saved to {filepath}")
-
-    def load_model(self, filepath: str) -> Dict:
-        """Load model from file"""
-        checkpoint = torch.load(filepath, map_location=self.device)
-
-        self.model.load_state_dict(checkpoint['model_state_dict'])
-        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-
-        if 'scheduler_state_dict' in checkpoint:
-            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
-
-        if 'training_history' in checkpoint:
-            self.training_history = checkpoint['training_history']
-
-        logger.info(f"Enhanced CNN model loaded from {filepath}")
-        return checkpoint.get('metadata', {})
+    def save_model(self, filepath: str = None, metadata: Optional[Dict] = None):
+        """Save model with metadata using unified registry"""
+        try:
+            from utils.model_registry import save_model
+
+            # Prepare model data
+            model_data = {
+                'model_state_dict': self.model.state_dict(),
+                'optimizer_state_dict': self.optimizer.state_dict(),
+                'scheduler_state_dict': self.scheduler.state_dict(),
+                'training_history': self.training_history,
+                'model_config': {
+                    'input_size': self.model.input_size,
+                    'feature_dim': self.model.feature_dim,
+                    'output_size': self.model.output_size,
+                    'base_channels': self.model.base_channels
+                }
+            }
+
+            if metadata:
+                model_data['metadata'] = metadata
+
+            # Use unified registry if no filepath specified
+            if filepath is None or filepath.startswith('models/'):
+                # Extract model name from filepath or use default
+                model_name = "enhanced_cnn"
+                if filepath:
+                    model_name = filepath.split('/')[-1].replace('_latest.pt', '').replace('.pt', '')
+
+                success = save_model(
+                    model=self.model,
+                    model_name=model_name,
+                    model_type='cnn',
+                    metadata={'full_checkpoint': model_data}
+                )
+                if success:
+                    logger.info(f"Enhanced CNN model saved to unified registry: {model_name}")
+                return success
+            else:
+                # Legacy direct file save
+                torch.save(model_data, filepath)
+                logger.info(f"Enhanced CNN model saved to {filepath} (legacy mode)")
+                return True
+
+        except Exception as e:
+            logger.error(f"Failed to save CNN model: {e}")
+            return False
+
+    def load_model(self, filepath: str = None) -> Dict:
+        """Load model from unified registry or file"""
+        try:
+            from utils.model_registry import load_model
+
+            # Use unified registry if no filepath or if it's a models/ path
+            if filepath is None or filepath.startswith('models/'):
+                model_name = "enhanced_cnn"
+                if filepath:
+                    model_name = filepath.split('/')[-1].replace('_latest.pt', '').replace('.pt', '')
+
+                model = load_model(model_name, 'cnn')
+                if model is None:
+                    logger.warning(f"Could not load model {model_name} from unified registry")
+                    return {}
+
+                # Load full checkpoint data from metadata
+                registry = get_model_registry()
+                if model_name in registry.metadata['models']:
+                    model_data = registry.metadata['models'][model_name]
+                    if 'full_checkpoint' in model_data:
+                        checkpoint = model_data['full_checkpoint']
+
+                        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                        if 'scheduler_state_dict' in checkpoint:
+                            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+                        if 'training_history' in checkpoint:
+                            self.training_history = checkpoint['training_history']
+
+                        logger.info(f"Enhanced CNN model loaded from unified registry: {model_name}")
+                        return checkpoint.get('metadata', {})
+
+                return {}
+            else:
+                # Legacy direct file load
+                checkpoint = torch.load(filepath, map_location=self.device)
+
+                self.model.load_state_dict(checkpoint['model_state_dict'])
+                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+                if 'scheduler_state_dict' in checkpoint:
+                    self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+                if 'training_history' in checkpoint:
+                    self.training_history = checkpoint['training_history']
+
+                logger.info(f"Enhanced CNN model loaded from {filepath} (legacy mode)")
+                return checkpoint.get('metadata', {})
+
+        except Exception as e:
+            logger.error(f"Failed to load CNN model: {e}")
+            return {}
 
 def create_enhanced_cnn_model(input_size: int = 60,
                               feature_dim: int = 50,
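For orientation, a minimal sketch (not part of the diff) of how the new routing behaves from a caller's side; the trainer setup is hypothetical, and per the code above only a missing filepath or a 'models/'-prefixed one goes through the registry:

    trainer = CNNModelTrainer(model)               # hypothetical setup
    trainer.save_model()                           # -> unified registry, default name "enhanced_cnn"
    trainer.save_model('models/cnn/my_cnn.pt')     # -> registry, name derived as "my_cnn"
    trainer.save_model('/tmp/cnn_backup.pt')       # -> legacy torch.save to the given path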
@@ -16,6 +16,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(
 
 # Import checkpoint management
 from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
+from utils.model_registry import get_model_registry
 
 # Configure logger
 logger = logging.getLogger(__name__)
@@ -1329,54 +1330,140 @@ class DQNAgent:
 
         return False  # No improvement
 
-    def save(self, path: str):
-        """Save model and agent state"""
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-
-        # Save policy network
-        self.policy_net.save(f"{path}_policy")
-
-        # Save target network
-        self.target_net.save(f"{path}_target")
-
-        # Save agent state
-        state = {
-            'epsilon': self.epsilon,
-            'update_count': self.update_count,
-            'losses': self.losses,
-            'optimizer_state': self.optimizer.state_dict(),
-            'best_reward': self.best_reward,
-            'avg_reward': self.avg_reward
-        }
-
-        torch.save(state, f"{path}_agent_state.pt")
-        logger.info(f"Agent state saved to {path}_agent_state.pt")
-
-    def load(self, path: str):
-        """Load model and agent state"""
-        # Load policy network
-        self.policy_net.load(f"{path}_policy")
-
-        # Load target network
-        self.target_net.load(f"{path}_target")
-
-        # Load agent state
-        try:
-            agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False)
-            self.epsilon = agent_state['epsilon']
-            self.update_count = agent_state['update_count']
-            self.losses = agent_state['losses']
-            self.optimizer.load_state_dict(agent_state['optimizer_state'])
-
-            # Load additional metrics if they exist
-            if 'best_reward' in agent_state:
-                self.best_reward = agent_state['best_reward']
-            if 'avg_reward' in agent_state:
-                self.avg_reward = agent_state['avg_reward']
-
-            logger.info(f"Agent state loaded from {path}_agent_state.pt")
-        except FileNotFoundError:
-            logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
+    def save(self, path: str = None):
+        """Save model and agent state using unified registry"""
+        try:
+            from utils.model_registry import save_model
+
+            # Use unified registry if no path or if it's a models/ path
+            if path is None or path.startswith('models/'):
+                model_name = "dqn_agent"
+                if path:
+                    model_name = path.split('/')[-1].replace('_agent_state', '').replace('.pt', '')
+
+                # Prepare full agent state
+                agent_state = {
+                    'epsilon': self.epsilon,
+                    'update_count': self.update_count,
+                    'losses': self.losses,
+                    'optimizer_state': self.optimizer.state_dict(),
+                    'best_reward': self.best_reward,
+                    'avg_reward': self.avg_reward,
+                    'policy_net_state': self.policy_net.state_dict(),
+                    'target_net_state': self.target_net.state_dict()
+                }
+
+                success = save_model(
+                    model=self.policy_net,  # Save policy net as main model
+                    model_name=model_name,
+                    model_type='dqn',
+                    metadata={'full_agent_state': agent_state}
+                )
+
+                if success:
+                    logger.info(f"DQN agent saved to unified registry: {model_name}")
+                return
+
+            else:
+                # Legacy direct file save
+                os.makedirs(os.path.dirname(path), exist_ok=True)
+
+                # Save policy network
+                self.policy_net.save(f"{path}_policy")
+
+                # Save target network
+                self.target_net.save(f"{path}_target")
+
+                # Save agent state
+                state = {
+                    'epsilon': self.epsilon,
+                    'update_count': self.update_count,
+                    'losses': self.losses,
+                    'optimizer_state': self.optimizer.state_dict(),
+                    'best_reward': self.best_reward,
+                    'avg_reward': self.avg_reward
+                }
+
+                torch.save(state, f"{path}_agent_state.pt")
+                logger.info(f"Agent state saved to {path}_agent_state.pt (legacy mode)")
+
+        except Exception as e:
+            logger.error(f"Failed to save DQN agent: {e}")
+
+    def load(self, path: str = None):
+        """Load model and agent state from unified registry or file"""
+        try:
+            from utils.model_registry import load_model
+
+            # Use unified registry if no path or if it's a models/ path
+            if path is None or path.startswith('models/'):
+                model_name = "dqn_agent"
+                if path:
+                    model_name = path.split('/')[-1].replace('_agent_state', '').replace('.pt', '')
+
+                model = load_model(model_name, 'dqn')
+                if model is None:
+                    logger.warning(f"Could not load DQN agent {model_name} from unified registry")
+                    return
+
+                # Load full agent state from metadata
+                registry = get_model_registry()
+                if model_name in registry.metadata['models']:
+                    model_data = registry.metadata['models'][model_name]
+                    if 'full_agent_state' in model_data:
+                        agent_state = model_data['full_agent_state']
+
+                        # Restore agent state
+                        self.epsilon = agent_state['epsilon']
+                        self.update_count = agent_state['update_count']
+                        self.losses = agent_state['losses']
+                        self.optimizer.load_state_dict(agent_state['optimizer_state'])
+
+                        # Load additional metrics if they exist
+                        if 'best_reward' in agent_state:
+                            self.best_reward = agent_state['best_reward']
+                        if 'avg_reward' in agent_state:
+                            self.avg_reward = agent_state['avg_reward']
+
+                        # Load network states
+                        if 'policy_net_state' in agent_state:
+                            self.policy_net.load_state_dict(agent_state['policy_net_state'])
+                        if 'target_net_state' in agent_state:
+                            self.target_net.load_state_dict(agent_state['target_net_state'])
+
+                        logger.info(f"DQN agent loaded from unified registry: {model_name}")
+                        return
+
+                return
+
+            else:
+                # Legacy direct file load
+                # Load policy network
+                self.policy_net.load(f"{path}_policy")
+
+                # Load target network
+                self.target_net.load(f"{path}_target")
+
+                # Load agent state
+                try:
+                    agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device, weights_only=False)
+                    self.epsilon = agent_state['epsilon']
+                    self.update_count = agent_state['update_count']
+                    self.losses = agent_state['losses']
+                    self.optimizer.load_state_dict(agent_state['optimizer_state'])
+
+                    # Load additional metrics if they exist
+                    if 'best_reward' in agent_state:
+                        self.best_reward = agent_state['best_reward']
+                    if 'avg_reward' in agent_state:
+                        self.avg_reward = agent_state['avg_reward']
+
+                    logger.info(f"Agent state loaded from {path}_agent_state.pt (legacy mode)")
+                except FileNotFoundError:
+                    logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
+
+        except Exception as e:
+            logger.error(f"Failed to load DQN agent: {e}")
 
     def get_position_info(self):
         """Get current position information"""
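Both save() and load() derive the registry key from the path the same way; a standalone sketch of the rule (not part of the diff, paths illustrative):

    def derive_name(path, default='dqn_agent'):
        # Mirrors the inline logic above: strip directories, then the
        # '_agent_state' / '.pt' suffixes to recover the registry key.
        if not path:
            return default
        return path.split('/')[-1].replace('_agent_state', '').replace('.pt', '')

    assert derive_name('models/dqn/best_dqn_agent_state.pt') == 'best_dqn'
    assert derive_name(None) == 'dqn_agent'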
BIN  mcp_servers/browser-tools-mcp/BrowserTools-1.2.0-extension.zip  (new file)
Binary file not shown.
@@ -16,6 +16,9 @@ import random
 
 WANDB_AVAILABLE = False
 
+# Import model registry
+from utils.model_registry import get_model_registry
+
 logger = logging.getLogger(__name__)
 
 @dataclass
@@ -72,15 +75,9 @@ class CheckpointManager:
                         performance_metrics: Dict[str, float],
                         training_metadata: Optional[Dict[str, Any]] = None,
                         force_save: bool = False) -> Optional[CheckpointMetadata]:
-        """Save a model checkpoint with improved error handling and validation"""
+        """Save a model checkpoint with improved error handling and validation using unified registry"""
         try:
-            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
-            checkpoint_id = f"{model_name}_{timestamp}"
-
-            model_dir = self.base_dir / model_name
-            model_dir.mkdir(exist_ok=True)
-
-            checkpoint_path = model_dir / f"{checkpoint_id}.pt"
+            from utils.model_registry import save_checkpoint as registry_save_checkpoint
 
             performance_score = self._calculate_performance_score(performance_metrics)
 
@@ -88,19 +85,34 @@ class CheckpointManager:
                 logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
                 return None
 
-            success = self._save_model_file(model, checkpoint_path, model_type)
+            # Use unified registry for checkpointing
+            success = registry_save_checkpoint(
+                model=model,
+                model_name=model_name,
+                model_type=model_type,
+                performance_score=performance_score,
+                metadata={
+                    'performance_metrics': performance_metrics,
+                    'training_metadata': training_metadata,
+                    'checkpoint_manager': True
+                }
+            )
+
             if not success:
                 return None
 
-            file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)
-
+            # Get checkpoint info from registry
+            registry = get_model_registry()
+            checkpoint_info = registry.metadata['models'][model_name]['checkpoints'][-1]
+
+            # Create CheckpointMetadata object
             metadata = CheckpointMetadata(
-                checkpoint_id=checkpoint_id,
+                checkpoint_id=checkpoint_info['id'],
                 model_name=model_name,
                 model_type=model_type,
-                file_path=str(checkpoint_path),
-                created_at=datetime.now(),
-                file_size_mb=file_size_mb,
+                file_path=checkpoint_info['path'],
+                created_at=datetime.fromisoformat(checkpoint_info['timestamp']),
+                file_size_mb=0.0,  # Will be calculated by registry
                 performance_score=performance_score,
                 accuracy=performance_metrics.get('accuracy'),
                 loss=performance_metrics.get('loss'),
@@ -113,8 +125,7 @@ class CheckpointManager:
                 total_parameters=training_metadata.get('total_parameters') if training_metadata else None
             )
 
-            # W&B disabled
-
+            # Update local checkpoint tracking
             self.checkpoints[model_name].append(metadata)
             self._rotate_checkpoints(model_name)
             self._save_metadata()
@@ -128,7 +139,35 @@ class CheckpointManager:
 
     def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
         try:
-            # First, try the standard checkpoint system
+            from utils.model_registry import load_best_checkpoint as registry_load_checkpoint
+
+            # First, try the unified registry
+            registry_result = registry_load_checkpoint(model_name, 'cnn')  # Try CNN type first
+            if registry_result is None:
+                registry_result = registry_load_checkpoint(model_name, 'dqn')  # Try DQN type
+
+            if registry_result:
+                checkpoint_path, checkpoint_data = registry_result
+
+                # Create CheckpointMetadata from registry data
+                metadata = CheckpointMetadata(
+                    checkpoint_id=f"{model_name}_registry",
+                    model_name=model_name,
+                    model_type=checkpoint_data.get('model_type', 'unknown'),
+                    file_path=checkpoint_path,
+                    created_at=datetime.fromisoformat(checkpoint_data.get('timestamp', datetime.now().isoformat())),
+                    file_size_mb=0.0,  # Will be calculated by registry
+                    performance_score=checkpoint_data.get('performance_score', 0.0),
+                    accuracy=checkpoint_data.get('accuracy'),
+                    loss=checkpoint_data.get('loss'),
+                    reward=checkpoint_data.get('reward'),
+                    pnl=checkpoint_data.get('pnl')
+                )
+
+                logger.debug(f"Loading checkpoint from unified registry for {model_name}")
+                return checkpoint_path, metadata
+
+            # Fallback: Try the standard checkpoint system
             if model_name in self.checkpoints and self.checkpoints[model_name]:
                 # Filter out checkpoints with non-existent files
                 valid_checkpoints = [
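Whichever store answers, callers of the manager still receive a (path, CheckpointMetadata) tuple; a hedged usage sketch (not part of the diff; default construction assumed):

    from utils.checkpoint_manager import CheckpointManager

    cm = CheckpointManager()                          # assumed default arguments
    result = cm.load_best_checkpoint('enhanced_cnn')
    if result:
        checkpoint_path, meta = result                # registry hit or legacy fallback
        print(checkpoint_path, meta.performance_score)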
utils/model_registry.py  (new file, 446 lines)
@@ -0,0 +1,446 @@
#!/usr/bin/env python3
"""
Unified Model Registry for Centralized Model Management

This module provides a unified interface for saving, loading, and managing
all machine learning models in the trading system. It consolidates model
storage from multiple locations into a single, organized structure.
"""

import os
import json
import torch
import logging
import pickle
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
from datetime import datetime
import hashlib

logger = logging.getLogger(__name__)

class ModelRegistry:
    """
    Unified model registry for centralized model management.
    Handles saving, loading, and organization of all ML models.
    """

    def __init__(self, base_dir: str = "models"):
        """
        Initialize the model registry.

        Args:
            base_dir: Base directory for model storage
        """
        self.base_dir = Path(base_dir)
        self.saved_dir = self.base_dir / "saved"
        self.checkpoint_dir = self.base_dir / "checkpoints"
        self.archive_dir = self.base_dir / "archive"

        # Model type directories
        self.model_dirs = {
            'cnn': self.base_dir / "cnn",
            'dqn': self.base_dir / "dqn",
            'transformer': self.base_dir / "transformer",
            'hybrid': self.base_dir / "hybrid"
        }

        # Ensure all directories exist
        self._ensure_directories()

        # Metadata tracking
        self.metadata_file = self.base_dir / "registry_metadata.json"
        self.metadata = self._load_metadata()

        logger.info(f"Model Registry initialized at {self.base_dir}")

    def _ensure_directories(self):
        """Ensure all required directories exist."""
        directories = [
            self.saved_dir,
            self.checkpoint_dir,
            self.archive_dir
        ]

        # Add model type directories
        for model_dir in self.model_dirs.values():
            directories.extend([
                model_dir / "saved",
                model_dir / "checkpoints",
                model_dir / "archive"
            ])

        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)

    def _load_metadata(self) -> Dict[str, Any]:
        """Load registry metadata."""
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r') as f:
                    return json.load(f)
            except Exception as e:
                logger.warning(f"Failed to load metadata: {e}")
        return {'models': {}, 'last_updated': datetime.now().isoformat()}

    def _save_metadata(self):
        """Save registry metadata."""
        self.metadata['last_updated'] = datetime.now().isoformat()
        try:
            with open(self.metadata_file, 'w') as f:
                json.dump(self.metadata, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save metadata: {e}")

    def save_model(self, model: Any, model_name: str, model_type: str = 'cnn',
                   metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model to the unified storage.

        Args:
            model: The model to save
            model_name: Name of the model
            model_type: Type of model (cnn, dqn, transformer, hybrid)
            metadata: Additional metadata to save

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            save_dir = model_dir / "saved"

            # Generate filename with timestamp
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"{model_name}_{timestamp}.pt"
            filepath = save_dir / filename

            # Also save as latest
            latest_filepath = save_dir / f"{model_name}_latest.pt"

            # Save model
            save_dict = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'model_type': model_type,
                'timestamp': timestamp,
                'metadata': metadata or {}
            }

            torch.save(save_dict, filepath)
            torch.save(save_dict, latest_filepath)

            # Update metadata
            if model_name not in self.metadata['models']:
                self.metadata['models'][model_name] = {}

            self.metadata['models'][model_name].update({
                'type': model_type,
                'latest_path': str(latest_filepath),
                'last_saved': timestamp,
                'save_count': self.metadata['models'][model_name].get('save_count', 0) + 1
            })

            self._save_metadata()

            logger.info(f"Model {model_name} saved to {filepath}")
            return True

        except Exception as e:
            logger.error(f"Failed to save model {model_name}: {e}")
            return False

    def load_model(self, model_name: str, model_type: str = 'cnn',
                   model_class: Optional[Any] = None) -> Optional[Any]:
        """
        Load a model from the unified storage.

        Args:
            model_name: Name of the model to load
            model_type: Type of model (cnn, dqn, transformer, hybrid)
            model_class: Model class to instantiate (if needed)

        Returns:
            The loaded model or None if failed
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            save_dir = model_dir / "saved"
            latest_filepath = save_dir / f"{model_name}_latest.pt"

            if not latest_filepath.exists():
                logger.warning(f"Model {model_name} not found at {latest_filepath}")
                return None

            # Load checkpoint
            checkpoint = torch.load(latest_filepath, map_location='cpu')

            # Instantiate model if class provided
            if model_class is not None:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # Try to reconstruct model from state_dict
                model = type('LoadedModel', (), {})()
                model.state_dict = lambda: checkpoint['model_state_dict']
                model.load_state_dict = lambda state_dict: None

            logger.info(f"Model {model_name} loaded from {latest_filepath}")
            return model

        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {e}")
            return None

    def save_checkpoint(self, model: Any, model_name: str, model_type: str = 'cnn',
                        performance_score: float = 0.0,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model checkpoint.

        Args:
            model: The model to checkpoint
            model_name: Name of the model
            model_type: Type of model
            performance_score: Performance score for this checkpoint
            metadata: Additional metadata

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.checkpoint_dir)
            checkpoint_dir = model_dir / "checkpoints"

            # Generate checkpoint ID
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            checkpoint_id = f"{model_name}_{timestamp}_{performance_score:.4f}"

            filepath = checkpoint_dir / f"{checkpoint_id}.pt"

            # Save checkpoint
            checkpoint_data = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'model_type': model_type,
                'model_name': model_name,
                'performance_score': performance_score,
                'timestamp': timestamp,
                'metadata': metadata or {}
            }

            torch.save(checkpoint_data, filepath)

            # Update metadata
            if model_name not in self.metadata['models']:
                self.metadata['models'][model_name] = {}

            if 'checkpoints' not in self.metadata['models'][model_name]:
                self.metadata['models'][model_name]['checkpoints'] = []

            checkpoint_info = {
                'id': checkpoint_id,
                'path': str(filepath),
                'performance_score': performance_score,
                'timestamp': timestamp
            }

            self.metadata['models'][model_name]['checkpoints'].append(checkpoint_info)

            # Keep only top 5 checkpoints
            checkpoints = self.metadata['models'][model_name]['checkpoints']
            if len(checkpoints) > 5:
                checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)
                checkpoints_to_remove = checkpoints[5:]

                for checkpoint in checkpoints_to_remove:
                    try:
                        os.remove(checkpoint['path'])
                    except:
                        pass

                self.metadata['models'][model_name]['checkpoints'] = checkpoints[:5]

            self._save_metadata()

            logger.info(f"Checkpoint {checkpoint_id} saved with score {performance_score}")
            return True

        except Exception as e:
            logger.error(f"Failed to save checkpoint for {model_name}: {e}")
            return False

    def load_best_checkpoint(self, model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
        """
        Load the best checkpoint for a model.

        Args:
            model_name: Name of the model
            model_type: Type of model

        Returns:
            Tuple of (checkpoint_path, checkpoint_data) or None
        """
        try:
            if model_name not in self.metadata['models']:
                logger.warning(f"No metadata found for model {model_name}")
                return None

            checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
            if not checkpoints:
                logger.warning(f"No checkpoints found for model {model_name}")
                return None

            # Find best checkpoint by performance score
            best_checkpoint = max(checkpoints, key=lambda x: x['performance_score'])
            checkpoint_path = best_checkpoint['path']

            if not os.path.exists(checkpoint_path):
                logger.warning(f"Checkpoint file not found: {checkpoint_path}")
                return None

            checkpoint_data = torch.load(checkpoint_path, map_location='cpu')

            logger.info(f"Best checkpoint loaded for {model_name}: {best_checkpoint['id']}")
            return checkpoint_path, checkpoint_data

        except Exception as e:
            logger.error(f"Failed to load best checkpoint for {model_name}: {e}")
            return None

    def archive_model(self, model_name: str, model_type: str = 'cnn') -> bool:
        """
        Archive a model by moving it to archive directory.

        Args:
            model_name: Name of the model to archive
            model_type: Type of model

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            save_dir = model_dir / "saved"
            archive_dir = model_dir / "archive"

            latest_filepath = save_dir / f"{model_name}_latest.pt"

            if not latest_filepath.exists():
                logger.warning(f"Model {model_name} not found to archive")
                return False

            # Move to archive with timestamp
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            archive_filepath = archive_dir / f"{model_name}_archived_{timestamp}.pt"

            os.rename(latest_filepath, archive_filepath)

            logger.info(f"Model {model_name} archived to {archive_filepath}")
            return True

        except Exception as e:
            logger.error(f"Failed to archive model {model_name}: {e}")
            return False

    def list_models(self, model_type: Optional[str] = None) -> Dict[str, Any]:
        """
        List all models in the registry.

        Args:
            model_type: Filter by model type (optional)

        Returns:
            Dictionary of model information
        """
        models_info = {}

        for model_name, model_data in self.metadata['models'].items():
            if model_type and model_data.get('type') != model_type:
                continue

            models_info[model_name] = {
                'type': model_data.get('type'),
                'last_saved': model_data.get('last_saved'),
                'save_count': model_data.get('save_count', 0),
                'checkpoint_count': len(model_data.get('checkpoints', [])),
                'latest_path': model_data.get('latest_path')
            }

        return models_info

    def cleanup_old_checkpoints(self, model_name: str, keep_count: int = 5) -> int:
        """
        Clean up old checkpoints, keeping only the best ones.

        Args:
            model_name: Name of the model
            keep_count: Number of checkpoints to keep

        Returns:
            Number of checkpoints removed
        """
        if model_name not in self.metadata['models']:
            return 0

        checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
        if len(checkpoints) <= keep_count:
            return 0

        # Sort by performance score (descending)
        checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)

        # Remove old checkpoints
        removed_count = 0
        for checkpoint in checkpoints[keep_count:]:
            try:
                os.remove(checkpoint['path'])
                removed_count += 1
            except:
                pass

        # Update metadata
        self.metadata['models'][model_name]['checkpoints'] = checkpoints[:keep_count]
        self._save_metadata()

        logger.info(f"Cleaned up {removed_count} old checkpoints for {model_name}")
        return removed_count


# Global registry instance
_registry_instance = None

def get_model_registry() -> ModelRegistry:
    """Get the global model registry instance."""
    global _registry_instance
    if _registry_instance is None:
        _registry_instance = ModelRegistry()
    return _registry_instance

def save_model(model: Any, model_name: str, model_type: str = 'cnn',
               metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Convenience function to save a model using the global registry."""
    return get_model_registry().save_model(model, model_name, model_type, metadata)

def load_model(model_name: str, model_type: str = 'cnn',
               model_class: Optional[Any] = None) -> Optional[Any]:
    """Convenience function to load a model using the global registry."""
    return get_model_registry().load_model(model_name, model_type, model_class)

def save_checkpoint(model: Any, model_name: str, model_type: str = 'cnn',
                    performance_score: float = 0.0,
                    metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Convenience function to save a checkpoint using the global registry."""
    return get_model_registry().save_checkpoint(model, model_name, model_type, performance_score, metadata)

def load_best_checkpoint(model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
    """Convenience function to load the best checkpoint using the global registry."""
    return get_model_registry().load_best_checkpoint(model_name, model_type)
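For quick orientation, a minimal round-trip through the module-level helpers (not part of the diff; nn.Linear stands in for a real model, and the names and score are illustrative):

    import torch.nn as nn
    from utils.model_registry import (save_model, load_model,
                                      save_checkpoint, load_best_checkpoint)

    net = nn.Linear(8, 2)                            # any object with a state_dict works

    save_model(net, 'toy_model', model_type='cnn')   # writes models/cnn/saved/toy_model_latest.pt
    save_checkpoint(net, 'toy_model', model_type='cnn', performance_score=0.91)

    best = load_best_checkpoint('toy_model', 'cnn')  # (path, checkpoint_data) or None
    restored = load_model('toy_model', 'cnn', model_class=lambda: nn.Linear(8, 2))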
@@ -4710,39 +4710,64 @@ class CleanTradingDashboard:
 
             stored_models = []
 
+            # Use unified model registry for saving
+            from utils.model_registry import save_model
+
             # 1. Store DQN model
             if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                 try:
-                    if hasattr(self.orchestrator.rl_agent, 'save'):
-                        save_path = self.orchestrator.rl_agent.save('models/saved/dqn_agent_session')
-                        stored_models.append(('DQN', save_path))
-                        logger.info(f"Stored DQN model: {save_path}")
+                    success = save_model(
+                        model=self.orchestrator.rl_agent.policy_net,  # Save policy network
+                        model_name='dqn_agent_session',
+                        model_type='dqn',
+                        metadata={'session_save': True, 'dashboard_save': True}
+                    )
+                    if success:
+                        stored_models.append(('DQN', 'models/dqn/saved/dqn_agent_session_latest.pt'))
+                        logger.info("Stored DQN model via unified registry")
+                    else:
+                        logger.warning("Failed to store DQN model via unified registry")
                 except Exception as e:
                     logger.warning(f"Failed to store DQN model: {e}")
 
             # 2. Store CNN model
             if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                 try:
-                    if hasattr(self.orchestrator.cnn_model, 'save'):
-                        save_path = self.orchestrator.cnn_model.save('models/saved/cnn_model_session')
-                        stored_models.append(('CNN', save_path))
-                        logger.info(f"Stored CNN model: {save_path}")
+                    success = save_model(
+                        model=self.orchestrator.cnn_model,
+                        model_name='cnn_model_session',
+                        model_type='cnn',
+                        metadata={'session_save': True, 'dashboard_save': True}
+                    )
+                    if success:
+                        stored_models.append(('CNN', 'models/cnn/saved/cnn_model_session_latest.pt'))
+                        logger.info("Stored CNN model via unified registry")
+                    else:
+                        logger.warning("Failed to store CNN model via unified registry")
                 except Exception as e:
                     logger.warning(f"Failed to store CNN model: {e}")
 
             # 3. Store Transformer model
             if hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
                 try:
-                    if hasattr(self.orchestrator.primary_transformer, 'save'):
-                        save_path = self.orchestrator.primary_transformer.save('models/saved/transformer_model_session')
-                        stored_models.append(('Transformer', save_path))
-                        logger.info(f"Stored Transformer model: {save_path}")
+                    success = save_model(
+                        model=self.orchestrator.primary_transformer,
+                        model_name='transformer_model_session',
+                        model_type='transformer',
+                        metadata={'session_save': True, 'dashboard_save': True}
+                    )
+                    if success:
+                        stored_models.append(('Transformer', 'models/transformer/saved/transformer_model_session_latest.pt'))
+                        logger.info("Stored Transformer model via unified registry")
+                    else:
+                        logger.warning("Failed to store Transformer model via unified registry")
                 except Exception as e:
                     logger.warning(f"Failed to store Transformer model: {e}")
 
-            # 4. Store COB RL model
+            # 4. Store COB RL model (if exists)
             if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
                 try:
+                    # COB RL model might have different save method
                     if hasattr(self.orchestrator.cob_rl_agent, 'save'):
                         save_path = self.orchestrator.cob_rl_agent.save('models/saved/cob_rl_agent_session')
                         stored_models.append(('COB RL', save_path))
@@ -4750,13 +4775,20 @@ class CleanTradingDashboard:
                 except Exception as e:
                     logger.warning(f"Failed to store COB RL model: {e}")
 
-            # 5. Store Decision Fusion model
+            # 5. Store Decision model
             if hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
                 try:
-                    if hasattr(self.orchestrator.decision_model, 'save'):
-                        save_path = self.orchestrator.decision_model.save('models/saved/decision_fusion_session')
-                        stored_models.append(('Decision Fusion', save_path))
-                        logger.info(f"Stored Decision Fusion model: {save_path}")
+                    success = save_model(
+                        model=self.orchestrator.decision_model,
+                        model_name='decision_fusion_session',
+                        model_type='hybrid',
+                        metadata={'session_save': True, 'dashboard_save': True}
+                    )
+                    if success:
+                        stored_models.append(('Decision Fusion', 'models/hybrid/saved/decision_fusion_session_latest.pt'))
+                        logger.info("Stored Decision Fusion model via unified registry")
+                    else:
+                        logger.warning("Failed to store Decision Fusion model via unified registry")
                 except Exception as e:
                     logger.warning(f"Failed to store Decision Fusion model: {e}")
 
@@ -6706,13 +6738,39 @@ class CleanTradingDashboard:
 
         except Exception as e:
             logger.error(f"Error saving transformer checkpoint: {e}")
-            # Fallback to direct save
+            # Use unified registry for checkpoint
             try:
-                checkpoint_path = f"NN/models/saved/transformer_checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
-                transformer_trainer.save_model(checkpoint_path)
-                logger.info(f"TRANSFORMER: Fallback checkpoint saved: {checkpoint_path}")
-            except Exception as fallback_error:
-                logger.error(f"Fallback checkpoint save also failed: {fallback_error}")
+                from utils.model_registry import save_checkpoint as registry_save_checkpoint
+
+                checkpoint_data = torch.load(checkpoint_path, map_location='cpu') if 'checkpoint_path' in locals() else checkpoint_data
+
+                success = registry_save_checkpoint(
+                    model=checkpoint_data,
+                    model_name='transformer',
+                    model_type='transformer',
+                    performance_score=training_metrics['accuracy'],
+                    metadata={
+                        'training_samples': len(training_samples),
+                        'loss': training_metrics['total_loss'],
+                        'accuracy': training_metrics['accuracy'],
+                        'checkpoint_source': 'dashboard_training'
+                    }
+                )
+
+                if success:
+                    logger.info("TRANSFORMER: Checkpoint saved via unified registry")
+                else:
+                    logger.warning("TRANSFORMER: Failed to save checkpoint via unified registry")
+
+            except Exception as registry_error:
+                logger.warning(f"Unified registry save failed: {registry_error}")
+                # Fallback to direct save
+                try:
+                    checkpoint_path = f"NN/models/saved/transformer_checkpoint_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
+                    transformer_trainer.save_model(checkpoint_path)
+                    logger.info(f"TRANSFORMER: Fallback checkpoint saved: {checkpoint_path}")
+                except Exception as fallback_error:
+                    logger.error(f"Fallback checkpoint save also failed: {fallback_error}")
 
         logger.info(f"TRANSFORMER: Trained on {len(training_samples)} samples, loss: {training_metrics['total_loss']:.4f}, accuracy: {training_metrics['accuracy']:.4f}")