deduplicate model storage
@@ -457,6 +457,72 @@ class ModelManager:
            logger.error(f"Error getting storage stats: {e}")
            return {'error': str(e)}

    def get_checkpoint_stats(self) -> Dict[str, Any]:
        """Get statistics about managed checkpoints (compatible with old checkpoint_manager interface)"""
        try:
            stats = {
                'total_models': 0,
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {}
            }

            # Count files in different directories as "checkpoints"
            checkpoint_dirs = [
                self.checkpoints_dir / "cnn",
                self.checkpoints_dir / "dqn",
                self.checkpoints_dir / "rl",
                self.checkpoints_dir / "transformer",
                self.checkpoints_dir / "hybrid"
            ]

            total_size = 0
            total_files = 0

            for checkpoint_dir in checkpoint_dirs:
                if checkpoint_dir.exists():
                    model_files = list(checkpoint_dir.rglob('*.pt'))
                    if model_files:
                        model_name = checkpoint_dir.name
                        stats['total_models'] += 1

                        model_size = sum(f.stat().st_size for f in model_files)
                        stats['total_checkpoints'] += len(model_files)
                        stats['total_size_mb'] += model_size / (1024 * 1024)
                        total_size += model_size
                        total_files += len(model_files)

                        # Get the most recent file as "latest"
                        latest_file = max(model_files, key=lambda f: f.stat().st_mtime)

                        stats['models'][model_name] = {
                            'checkpoint_count': len(model_files),
                            'total_size_mb': model_size / (1024 * 1024),
                            'best_performance': 0.0,  # Not tracked in unified system
                            'best_checkpoint_id': latest_file.name,
                            'latest_checkpoint': latest_file.name
                        }

            # Also check saved models directory
            if self.saved_dir.exists():
                saved_files = list(self.saved_dir.rglob('*.pt'))
                if saved_files:
                    stats['total_checkpoints'] += len(saved_files)
                    saved_size = sum(f.stat().st_size for f in saved_files)
                    stats['total_size_mb'] += saved_size / (1024 * 1024)

            return stats

        except Exception as e:
            logger.error(f"Error getting checkpoint stats: {e}")
            return {
                'total_models': 0,
                'total_checkpoints': 0,
                'total_size_mb': 0.0,
                'models': {},
                'error': str(e)
            }
    def get_model_leaderboard(self) -> List[Dict[str, Any]]:
        """Get model performance leaderboard"""
        try:
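A minimal usage sketch of the `get_checkpoint_stats` interface added in this hunk (the construction details of `ModelManager` are assumed here, not shown in the diff):

manager = ModelManager()
stats = manager.get_checkpoint_stats()
print(f"{stats['total_models']} models, {stats['total_checkpoints']} checkpoints, "
      f"{stats['total_size_mb']:.1f} MB on disk")
for name, info in stats['models'].items():
    print(f"  {name}: {info['checkpoint_count']} files, latest {info['latest_checkpoint']}")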
@@ -1,560 +0,0 @@
#!/usr/bin/env python3
"""
Checkpoint Management System for W&B Training
"""

import os
import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict
import torch
import random

WANDB_AVAILABLE = False

# Import model registry
from utils.model_registry import get_model_registry

logger = logging.getLogger(__name__)

@dataclass
class CheckpointMetadata:
    checkpoint_id: str
    model_name: str
    model_type: str
    file_path: str
    created_at: datetime
    file_size_mb: float
    performance_score: float
    accuracy: Optional[float] = None
    loss: Optional[float] = None
    val_accuracy: Optional[float] = None
    val_loss: Optional[float] = None
    reward: Optional[float] = None
    pnl: Optional[float] = None
    epoch: Optional[int] = None
    training_time_hours: Optional[float] = None
    total_parameters: Optional[int] = None
    wandb_run_id: Optional[str] = None
    wandb_artifact_name: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        data = asdict(self)
        data['created_at'] = self.created_at.isoformat()
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
        data['created_at'] = datetime.fromisoformat(data['created_at'])
        return cls(**data)
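Because `to_dict`/`from_dict` route `created_at` through ISO-8601 strings, the metadata survives a JSON round trip. A small sketch with illustrative values:

meta = CheckpointMetadata(
    checkpoint_id="cnn_20240101_000000",
    model_name="enhanced_cnn",
    model_type="cnn",
    file_path="NN/models/saved/cnn_20240101_000000.pt",
    created_at=datetime.now(),
    file_size_mb=12.5,
    performance_score=87.3,
)
restored = CheckpointMetadata.from_dict(json.loads(json.dumps(meta.to_dict())))
assert restored.created_at == meta.created_at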
class CheckpointManager:
    def __init__(self,
                 base_checkpoint_dir: str = "NN/models/saved",
                 max_checkpoints_per_model: int = 5,
                 metadata_file: str = "checkpoint_metadata.json",
                 enable_wandb: bool = False):
        self.base_dir = Path(base_checkpoint_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

        self.max_checkpoints = max_checkpoints_per_model
        self.metadata_file = self.base_dir / metadata_file
        self.enable_wandb = False

        self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
        self._warned_models = set()  # Track models we've warned about to reduce spam
        self._load_metadata()

        logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[CheckpointMetadata]:
        """Save a model checkpoint with improved error handling and validation using the unified registry"""
        try:
            from NN.training.model_manager import save_checkpoint as registry_save_checkpoint

            performance_score = self._calculate_performance_score(performance_metrics)

            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            # Use unified registry for checkpointing
            success = registry_save_checkpoint(
                model=model,
                model_name=model_name,
                model_type=model_type,
                performance_score=performance_score,
                metadata={
                    'performance_metrics': performance_metrics,
                    'training_metadata': training_metadata,
                    'checkpoint_manager': True
                }
            )

            if not success:
                return None

            # Get checkpoint info from registry
            registry = get_model_registry()
            checkpoint_info = registry.metadata['models'][model_name]['checkpoints'][-1]

            # Create CheckpointMetadata object (registry timestamps use the
            # '%Y%m%d_%H%M%S' format, so parse with strptime, not fromisoformat)
            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_info['id'],
                model_name=model_name,
                model_type=model_type,
                file_path=checkpoint_info['path'],
                created_at=datetime.strptime(checkpoint_info['timestamp'], '%Y%m%d_%H%M%S'),
                file_size_mb=0.0,  # Will be calculated by registry
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=training_metadata.get('epoch') if training_metadata else None,
                training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None,
                total_parameters=training_metadata.get('total_parameters') if training_metadata else None
            )

            # Update local checkpoint tracking
            self.checkpoints[model_name].append(metadata)
            self._rotate_checkpoints(model_name)
            self._save_metadata()

            logger.debug(f"Saved checkpoint: {metadata.checkpoint_id} (score: {performance_score:.4f})")
            return metadata

        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None
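A sketch of how a training loop might call this (the `model` object and metric values are hypothetical; the metric keys follow `_calculate_performance_score` below):

manager = get_checkpoint_manager()
meta = manager.save_checkpoint(
    model, "dqn_agent", "dqn",
    performance_metrics={'loss': 0.42, 'reward': 3.1},
    training_metadata={'epoch': 12},
)
if meta is not None:
    print(f"kept {meta.checkpoint_id} (score {meta.performance_score:.2f})")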
    def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
        try:
            from utils.model_registry import load_best_checkpoint as registry_load_checkpoint

            # First, try the unified registry
            registry_result = registry_load_checkpoint(model_name, 'cnn')  # Try CNN type first
            if registry_result is None:
                registry_result = registry_load_checkpoint(model_name, 'dqn')  # Try DQN type

            if registry_result:
                checkpoint_path, checkpoint_data = registry_result

                # Create CheckpointMetadata from registry data
                metadata = CheckpointMetadata(
                    checkpoint_id=f"{model_name}_registry",
                    model_name=model_name,
                    model_type=checkpoint_data.get('model_type', 'unknown'),
                    file_path=checkpoint_path,
                    created_at=datetime.fromisoformat(checkpoint_data.get('timestamp', datetime.now().isoformat())),
                    file_size_mb=0.0,  # Will be calculated by registry
                    performance_score=checkpoint_data.get('performance_score', 0.0),
                    accuracy=checkpoint_data.get('accuracy'),
                    loss=checkpoint_data.get('loss'),
                    reward=checkpoint_data.get('reward'),
                    pnl=checkpoint_data.get('pnl')
                )

                logger.debug(f"Loading checkpoint from unified registry for {model_name}")
                return checkpoint_path, metadata

            # Fallback: Try the standard checkpoint system
            if model_name in self.checkpoints and self.checkpoints[model_name]:
                # Filter out checkpoints with non-existent files
                valid_checkpoints = [
                    cp for cp in self.checkpoints[model_name]
                    if Path(cp.file_path).exists()
                ]

                if valid_checkpoints:
                    best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score)
                    logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
                    return best_checkpoint.file_path, best_checkpoint
                else:
                    # Clean up invalid metadata entries
                    invalid_count = len(self.checkpoints[model_name])
                    logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata")
                    self.checkpoints[model_name] = []
                    self._save_metadata()

            # Fallback: Look for existing saved models in the legacy format
            logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models")
            legacy_model_path = self._find_legacy_model(model_name)

            if legacy_model_path:
                # Create checkpoint metadata for the legacy model using actual file data
                legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path)
                logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
                return str(legacy_model_path), legacy_metadata

            # Only warn once per model to avoid spam
            if model_name not in self._warned_models:
                logger.info(f"No checkpoints found for {model_name}, starting fresh")
                self._warned_models.add(model_name)

            return None

        except Exception as e:
            logger.error(f"Error loading best checkpoint for {model_name}: {e}")
            return None

    def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
        """Calculate performance score with improved sensitivity for training models"""
        score = 0.0

        # Prioritize loss reduction for active training models
        if 'loss' in metrics:
            # Invert loss so lower loss = higher score, with better scaling
            loss_value = metrics['loss']
            if loss_value > 0:
                score += max(0, 100 / (1 + loss_value))  # More sensitive to loss changes
            else:
                score += 100  # Perfect loss

        # Add other metrics with appropriate weights
        if 'accuracy' in metrics:
            score += metrics['accuracy'] * 50  # Reduced weight to balance with loss
        if 'val_accuracy' in metrics:
            score += metrics['val_accuracy'] * 50
        if 'val_loss' in metrics:
            val_loss = metrics['val_loss']
            if val_loss > 0:
                score += max(0, 50 / (1 + val_loss))
        if 'reward' in metrics:
            score += metrics['reward'] * 10
        if 'pnl' in metrics:
            score += metrics['pnl'] * 5
        if 'training_samples' in metrics:
            # Bonus for processing more training samples
            score += min(10, metrics['training_samples'] / 10)

        # Return actual calculated score - NO SYNTHETIC MINIMUM
        return score
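A worked example of the scoring formula (illustrative values): a loss of 0.5 contributes 100 / (1 + 0.5) = 66.67, accuracy 0.8 contributes 0.8 * 50 = 40, and reward 1.2 contributes 12, for a total of about 118.67:

cm = CheckpointManager()
score = cm._calculate_performance_score({'loss': 0.5, 'accuracy': 0.8, 'reward': 1.2})
assert abs(score - (100 / 1.5 + 40.0 + 12.0)) < 1e-9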
    def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
        """Improved checkpoint saving logic with more frequent saves during training"""
        if model_name not in self.checkpoints or not self.checkpoints[model_name]:
            return True  # Always save first checkpoint

        # Allow more checkpoints during active training
        if len(self.checkpoints[model_name]) < self.max_checkpoints:
            return True

        # Get current best and worst scores
        scores = [cp.performance_score for cp in self.checkpoints[model_name]]
        best_score = max(scores)
        worst_score = min(scores)

        # Save if better than worst (more frequent saves)
        if performance_score > worst_score:
            return True

        # For high-performing models (score > 100), be more sensitive to small improvements
        if best_score > 100:
            # Save if within 0.1% of best score (very sensitive for converged models)
            if performance_score >= best_score * 0.999:
                return True
        else:
            # Also save if we're within 10% of best score (capture near-optimal models)
            if performance_score >= best_score * 0.9:
                return True

        # Save more frequently during active training (every 5th attempt instead of 10th)
        if random.random() < 0.2:  # 20% chance to save anyway
            logger.debug(f"Saving checkpoint for {model_name} - periodic save during active training")
            return True

        return False

    def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
        try:
            if hasattr(model, 'state_dict'):
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'model_type': model_type,
                    'saved_at': datetime.now().isoformat()
                }, file_path)
            else:
                torch.save(model, file_path)
            return True
        except Exception as e:
            logger.error(f"Error saving model file {file_path}: {e}")
            return False

    def _rotate_checkpoints(self, model_name: str):
        checkpoint_list = self.checkpoints[model_name]

        if len(checkpoint_list) <= self.max_checkpoints:
            return

        checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)

        to_remove = checkpoint_list[self.max_checkpoints:]
        self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints]

        for checkpoint in to_remove:
            try:
                file_path = Path(checkpoint.file_path)
                if file_path.exists():
                    file_path.unlink()
                    logger.debug(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
            except Exception as e:
                logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")

    def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
        return None

    def _load_metadata(self):
        try:
            if self.metadata_file.exists():
                with open(self.metadata_file, 'r') as f:
                    data = json.load(f)

                for model_name, checkpoint_list in data.items():
                    self.checkpoints[model_name] = [
                        CheckpointMetadata.from_dict(cp_data)
                        for cp_data in checkpoint_list
                    ]

                logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
        except Exception as e:
            logger.error(f"Error loading checkpoint metadata: {e}")

    def _save_metadata(self):
        try:
            data = {}
            for model_name, checkpoint_list in self.checkpoints.items():
                data[model_name] = [cp.to_dict() for cp in checkpoint_list]

            with open(self.metadata_file, 'w') as f:
                json.dump(data, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving checkpoint metadata: {e}")

    def get_checkpoint_stats(self):
        """Get statistics about managed checkpoints"""
        stats = {
            'total_models': len(self.checkpoints),
            'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
            'total_size_mb': 0.0,
            'models': {}
        }

        for model_name, checkpoint_list in self.checkpoints.items():
            if not checkpoint_list:
                continue

            model_size = sum(cp.file_size_mb for cp in checkpoint_list)
            best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)

            stats['models'][model_name] = {
                'checkpoint_count': len(checkpoint_list),
                'total_size_mb': model_size,
                'best_performance': best_checkpoint.performance_score,
                'best_checkpoint_id': best_checkpoint.checkpoint_id,
                'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
            }

            stats['total_size_mb'] += model_size

        return stats
    def _find_legacy_model(self, model_name: str) -> Optional[Path]:
        """Find legacy saved models based on model name patterns"""
        base_dir = Path(self.base_dir)

        # Additional search locations
        search_dirs = [
            base_dir,
            Path("models/saved"),
            Path("NN/models/saved"),
            Path("models"),
            Path("models/archive"),
            Path("models/backtest")
        ]

        # Define model name mappings and patterns for legacy files
        legacy_patterns = {
            'dqn_agent': [
                'dqn_agent_session_policy.pt',
                'dqn_agent_session_agent_state.pt',
                'dqn_agent_best_policy.pt',
                'enhanced_dqn_best_policy.pt',
                'improved_dqn_agent_best_policy.pt',
                'dqn_agent_final_policy.pt',
                'trading_agent_best_pnl.pt'
            ],
            'enhanced_cnn': [
                'cnn_model_session.pt',
                'cnn_model_best.pt',
                'optimized_short_term_model_best.pt',
                'optimized_short_term_model_realtime_best.pt',
                'optimized_short_term_model_ticks_best.pt'
            ],
            'extrema_trainer': [
                'supervised_model_best.pt'
            ],
            'cob_rl': [
                'best_rl_model.pth_policy.pt',
                'rl_agent_best_policy.pt'
            ],
            'decision': [
                # Decision models might be in subdirectories, but check the main dir too
                'decision_best.pt',
                'decision_model_best.pt',
                # Check for transformer models which might be used as decision models
                'enhanced_dqn_best_policy.pt',
                'improved_dqn_agent_best_policy.pt'
            ]
        }

        # Get patterns for this model name
        patterns = legacy_patterns.get(model_name, [])

        # Also try generic patterns based on model name
        patterns.extend([
            f'{model_name}_best.pt',
            f'{model_name}_best_policy.pt',
            f'{model_name}_final.pt',
            f'{model_name}_final_policy.pt'
        ])

        # Search for the model files in all search directories
        for search_dir in search_dirs:
            if not search_dir.exists():
                continue

            for pattern in patterns:
                candidate_path = search_dir / pattern
                if candidate_path.exists():
                    logger.info(f"Found legacy model file: {candidate_path}")
                    return candidate_path

        # Also check subdirectories
        for subdir in base_dir.iterdir():
            if subdir.is_dir() and subdir.name == model_name:
                for pattern in patterns:
                    candidate_path = subdir / pattern
                    if candidate_path.exists():
                        logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
                        return candidate_path

        # Extended search: scan common project model directories for best checkpoints
        try:
            # Attempt to infer project root from base_dir (NN/models/saved -> root)
            project_root = base_dir.resolve().parent.parent.parent
        except Exception:
            project_root = Path(".").resolve()
        additional_dirs = [
            project_root / "models",
            project_root / "models" / "archive",
            project_root / "models" / "backtest",
        ]

        def _match_legacy_name(candidate: Path, model: str) -> bool:
            name = candidate.name.lower()
            model_keys = {
                'dqn_agent': ['dqn', 'agent', 'policy'],
                'enhanced_cnn': ['cnn', 'optimized_short_term'],
                'extrema_trainer': ['supervised', 'extrema'],
                'cob_rl': ['cob', 'rl', 'policy'],
                'decision': ['decision', 'transformer']
            }.get(model, [model])
            return any(k in name for k in model_keys)

        candidates: List[Path] = []
        for adir in additional_dirs:
            if not adir.exists():
                continue
            try:
                for pt in adir.rglob('*.pt'):
                    # Prefer files that indicate "best" and match model hints
                    lname = pt.name.lower()
                    if 'best' in lname and _match_legacy_name(pt, model_name):
                        candidates.append(pt)
                    # Do not add generic fallbacks to avoid mismatched model types
            except Exception:
                # Ignore directory traversal issues
                pass

        if candidates:
            # Pick the most recently modified candidate
            try:
                best = max(candidates, key=lambda p: p.stat().st_mtime)
                logger.debug(f"Found legacy model file in project models dir: {best}")
                return best
            except Exception:
                # If stat fails, just return the first one deterministically
                candidates.sort()
                logger.debug(f"Found legacy model file in project models dir: {candidates[0]}")
                return candidates[0]

        return None

    def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
        """Create metadata for legacy model files using only actual file information"""
        try:
            file_size_mb = file_path.stat().st_size / (1024 * 1024)
            created_time = datetime.fromtimestamp(file_path.stat().st_mtime)

            # NO SYNTHETIC DATA - use only actual file information
            return CheckpointMetadata(
                checkpoint_id=f"legacy_{model_name}_{int(created_time.timestamp())}",
                model_name=model_name,
                model_type=model_name,
                file_path=str(file_path),
                created_at=created_time,
                file_size_mb=file_size_mb,
                performance_score=0.0,  # Unknown performance - use 0, not synthetic values
                accuracy=None,
                loss=None,
                val_accuracy=None,
                val_loss=None,
                reward=None,
                pnl=None,
                epoch=None,
                training_time_hours=None,
                total_parameters=None,
                wandb_run_id=None,
                wandb_artifact_name=None
            )
        except Exception as e:
            logger.error(f"Error creating legacy metadata for {model_name}: {e}")
            # Return a basic metadata with minimal info - NO SYNTHETIC VALUES
            return CheckpointMetadata(
                checkpoint_id=f"legacy_{model_name}",
                model_name=model_name,
                model_type=model_name,
                file_path=str(file_path),
                created_at=datetime.now(),
                file_size_mb=0.0,
                performance_score=0.0  # Unknown - use 0, not synthetic
            )

_checkpoint_manager = None

def get_checkpoint_manager() -> CheckpointManager:
    global _checkpoint_manager
    if _checkpoint_manager is None:
        _checkpoint_manager = CheckpointManager()
    return _checkpoint_manager

def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    return get_checkpoint_manager().save_checkpoint(
        model, model_name, model_type, performance_metrics, training_metadata, force_save
    )

def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    return get_checkpoint_manager().load_best_checkpoint(model_name)
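Typical call sites went through these module-level wrappers; a sketch of resuming from the best checkpoint (the consuming code is hypothetical):

result = load_best_checkpoint("enhanced_cnn")
if result is not None:
    checkpoint_path, meta = result
    state = torch.load(checkpoint_path, map_location='cpu')
    print(f"resuming from {meta.checkpoint_id} created {meta.created_at:%Y-%m-%d}")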
@@ -1,361 +0,0 @@
#!/usr/bin/env python3
"""
Improved Model Saver

A comprehensive model saving utility that handles various model types
and ensures reliable checkpointing with validation.
"""

import logging
import torch
import os
import json
from pathlib import Path
from datetime import datetime
from typing import Dict, Any, Optional, Union
import shutil

logger = logging.getLogger(__name__)

class ImprovedModelSaver:
    """Enhanced model saving with validation and backup strategies"""

    def __init__(self, base_dir: str = "models/saved"):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def save_model_safely(self,
                          model: Any,
                          model_name: str,
                          model_type: str = "unknown",
                          metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model with multiple fallback strategies

        Args:
            model: The model to save
            model_name: Name identifier for the model
            model_type: Type of model (dqn, cnn, rl, etc.)
            metadata: Additional metadata to save

        Returns:
            bool: True if successful, False otherwise
        """
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        model_dir = self.base_dir / model_name
        model_dir.mkdir(parents=True, exist_ok=True)

        # Create backup file names
        main_path = model_dir / f"{model_name}_latest.pt"
        backup_path = model_dir / f"{model_name}_{timestamp}.pt"

        try:
            # Strategy 1: Full PyTorch checkpoint save (state_dict plus selected attributes)
            if hasattr(model, '__dict__') and hasattr(torch, 'save'):
                success = self._save_pytorch_model(model, main_path, backup_path)
                if success:
                    self._save_metadata(model_dir, model_name, model_type, metadata)
                    logger.info(f"Successfully saved {model_name} using PyTorch save")
                    return True

            # Strategy 2: Try state_dict saving for PyTorch models
            if hasattr(model, 'state_dict'):
                success = self._save_state_dict(model, main_path, backup_path)
                if success:
                    self._save_metadata(model_dir, model_name, model_type, metadata)
                    logger.info(f"Successfully saved {model_name} using state_dict")
                    return True

            # Strategy 3: Try component-based saving for complex models
            if hasattr(model, 'policy_net') or hasattr(model, 'target_net'):
                success = self._save_rl_agent_components(model, model_dir, model_name)
                if success:
                    self._save_metadata(model_dir, model_name, model_type, metadata)
                    logger.info(f"Successfully saved {model_name} using component-based saving")
                    return True

            # Strategy 4: Fallback - try pickle
            success = self._save_with_pickle(model, main_path, backup_path)
            if success:
                self._save_metadata(model_dir, model_name, model_type, metadata)
                logger.info(f"Successfully saved {model_name} using pickle fallback")
                return True

            logger.error(f"All save strategies failed for {model_name}")
            return False

        except Exception as e:
            logger.error(f"Critical error saving {model_name}: {e}")
            return False
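From the caller's side the strategy chain is invisible: one call either succeeds by some route or returns False. A brief sketch (the `model` object is hypothetical):

saver = get_improved_model_saver()
if saver.save_model_safely(model, "dqn_agent", "dqn", metadata={"note": "nightly run"}):
    restored = saver.load_model_safely("dqn_agent", model_class=type(model))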
    def _save_pytorch_model(self, model, main_path: Path, backup_path: Path) -> bool:
        """Save using standard PyTorch torch.save"""
        try:
            # Create checkpoint data
            if hasattr(model, 'state_dict'):
                checkpoint = {
                    'model_state_dict': model.state_dict(),
                    'model_class': model.__class__.__name__,
                    'timestamp': datetime.now().isoformat()
                }

                # Add additional attributes
                for attr in ['epsilon', 'total_steps', 'current_reward', 'optimizer']:
                    if hasattr(model, attr):
                        try:
                            value = getattr(model, attr)
                            if attr == 'optimizer' and value is not None:
                                checkpoint['optimizer_state_dict'] = value.state_dict()
                            else:
                                checkpoint[attr] = value
                        except Exception:
                            pass  # Skip problematic attributes
            else:
                checkpoint = {
                    'model': model,
                    'timestamp': datetime.now().isoformat()
                }

            # Save to backup location first
            torch.save(checkpoint, backup_path)

            # Verify backup was saved correctly
            torch.load(backup_path, map_location='cpu')

            # Copy to main location
            shutil.copy2(backup_path, main_path)

            return True

        except Exception as e:
            logger.warning(f"PyTorch save failed: {e}")
            return False

    def _save_state_dict(self, model, main_path: Path, backup_path: Path) -> bool:
        """Save using state_dict only"""
        try:
            state_dict = model.state_dict()

            checkpoint = {
                'state_dict': state_dict,
                'model_class': model.__class__.__name__,
                'timestamp': datetime.now().isoformat()
            }

            torch.save(checkpoint, backup_path)
            torch.load(backup_path, map_location='cpu')  # Verify
            shutil.copy2(backup_path, main_path)

            return True

        except Exception as e:
            logger.warning(f"State dict save failed: {e}")
            return False

    def _save_rl_agent_components(self, model, model_dir: Path, model_name: str) -> bool:
        """Save RL agent components separately"""
        try:
            components_saved = 0

            # Save policy network
            if hasattr(model, 'policy_net') and model.policy_net is not None:
                policy_path = model_dir / f"{model_name}_policy.pt"
                torch.save(model.policy_net.state_dict(), policy_path)
                components_saved += 1

            # Save target network
            if hasattr(model, 'target_net') and model.target_net is not None:
                target_path = model_dir / f"{model_name}_target.pt"
                torch.save(model.target_net.state_dict(), target_path)
                components_saved += 1

            # Save agent state
            agent_state = {}
            for attr in ['epsilon', 'total_steps', 'current_reward', 'memory']:
                if hasattr(model, attr):
                    try:
                        value = getattr(model, attr)
                        if attr == 'memory' and hasattr(value, '__len__'):
                            # Don't save large replay buffers
                            agent_state[attr + '_size'] = len(value)
                        else:
                            agent_state[attr] = value
                    except Exception:
                        pass

            if agent_state:
                state_path = model_dir / f"{model_name}_agent_state.pt"
                torch.save(agent_state, state_path)
                components_saved += 1

            return components_saved > 0

        except Exception as e:
            logger.warning(f"Component-based save failed: {e}")
            return False

    def _save_with_pickle(self, model, main_path: Path, backup_path: Path) -> bool:
        """Fallback: save using pickle"""
        try:
            import pickle

            with open(backup_path.with_suffix('.pkl'), 'wb') as f:
                pickle.dump(model, f)

            # Verify
            with open(backup_path.with_suffix('.pkl'), 'rb') as f:
                pickle.load(f)

            shutil.copy2(backup_path.with_suffix('.pkl'), main_path.with_suffix('.pkl'))

            return True

        except Exception as e:
            logger.warning(f"Pickle save failed: {e}")
            return False

    def _save_metadata(self, model_dir: Path, model_name: str, model_type: str, metadata: Optional[Dict[str, Any]]):
        """Save model metadata"""
        try:
            meta_data = {
                'model_name': model_name,
                'model_type': model_type,
                'saved_at': datetime.now().isoformat(),
                'save_method': 'improved_model_saver'
            }

            if metadata:
                meta_data.update(metadata)

            meta_path = model_dir / f"{model_name}_metadata.json"
            with open(meta_path, 'w') as f:
                json.dump(meta_data, f, indent=2, default=str)

        except Exception as e:
            logger.warning(f"Failed to save metadata: {e}")
    def load_model_safely(self, model_name: str, model_class=None):
        """
        Load a model with multiple strategies

        Args:
            model_name: Name of the model to load
            model_class: Class to instantiate if needed

        Returns:
            Loaded model or None
        """
        model_dir = self.base_dir / model_name

        if not model_dir.exists():
            logger.warning(f"Model directory not found: {model_dir}")
            return None

        # Try different loading strategies
        loaders = [
            self._load_pytorch_checkpoint,
            self._load_state_dict_only,
            self._load_rl_components,
            self._load_pickle_fallback
        ]

        for loader in loaders:
            try:
                result = loader(model_dir, model_name, model_class)
                if result is not None:
                    logger.info(f"Successfully loaded {model_name} using {loader.__name__}")
                    return result
            except Exception as e:
                logger.debug(f"{loader.__name__} failed: {e}")
                continue

        logger.error(f"All load strategies failed for {model_name}")
        return None

    def _load_pytorch_checkpoint(self, model_dir: Path, model_name: str, model_class):
        """Load PyTorch checkpoint"""
        main_path = model_dir / f"{model_name}_latest.pt"

        if main_path.exists():
            checkpoint = torch.load(main_path, map_location='cpu')

            if model_class and 'model_state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])

                # Restore other attributes
                for key, value in checkpoint.items():
                    if key not in ['model_state_dict', 'optimizer_state_dict', 'timestamp', 'model_class']:
                        if hasattr(model, key):
                            setattr(model, key, value)

                return model

            return checkpoint.get('model', checkpoint)

        return None

    def _load_state_dict_only(self, model_dir: Path, model_name: str, model_class):
        """Load state dict only"""
        main_path = model_dir / f"{model_name}_latest.pt"

        if main_path.exists() and model_class:
            checkpoint = torch.load(main_path, map_location='cpu')

            if 'state_dict' in checkpoint:
                model = model_class()
                model.load_state_dict(checkpoint['state_dict'])
                return model

        return None

    def _load_rl_components(self, model_dir: Path, model_name: str, model_class):
        """Load RL agent from components"""
        policy_path = model_dir / f"{model_name}_policy.pt"
        target_path = model_dir / f"{model_name}_target.pt"
        state_path = model_dir / f"{model_name}_agent_state.pt"

        if policy_path.exists() and model_class:
            model = model_class()

            # Load policy network
            if hasattr(model, 'policy_net'):
                model.policy_net.load_state_dict(torch.load(policy_path, map_location='cpu'))

            # Load target network
            if target_path.exists() and hasattr(model, 'target_net'):
                model.target_net.load_state_dict(torch.load(target_path, map_location='cpu'))

            # Load agent state
            if state_path.exists():
                agent_state = torch.load(state_path, map_location='cpu')
                for key, value in agent_state.items():
                    if hasattr(model, key):
                        setattr(model, key, value)

            return model

        return None

    def _load_pickle_fallback(self, model_dir: Path, model_name: str, model_class):
        """Load from pickle"""
        pickle_path = model_dir / f"{model_name}_latest.pkl"

        if pickle_path.exists():
            import pickle
            with open(pickle_path, 'rb') as f:
                return pickle.load(f)

        return None


# Global instance for easy access
_improved_model_saver = None

def get_improved_model_saver() -> ImprovedModelSaver:
    """Get or create the global improved model saver instance"""
    global _improved_model_saver
    if _improved_model_saver is None:
        _improved_model_saver = ImprovedModelSaver()
    return _improved_model_saver
@@ -1,246 +0,0 @@
#!/usr/bin/env python3
"""
Model Checkpoint Saver

Utility to ensure all models can save checkpoints properly.
This will make them show as LOADED instead of FRESH.
"""

import logging
import os
from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path

logger = logging.getLogger(__name__)

class ModelCheckpointSaver:
    """Utility to save checkpoints for all models to fix FRESH status"""

    def __init__(self, orchestrator):
        self.orchestrator = orchestrator

    def save_all_model_checkpoints(self, force: bool = True) -> Dict[str, bool]:
        """Save checkpoints for all initialized models"""
        results = {}

        # Save DQN Agent
        if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
            results['dqn_agent'] = self._save_dqn_checkpoint(force)

        # Save CNN Model
        if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
            results['enhanced_cnn'] = self._save_cnn_checkpoint(force)

        # Save Extrema Trainer
        if hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
            results['extrema_trainer'] = self._save_extrema_checkpoint(force)

        # COB RL model removed - see COB_MODEL_ARCHITECTURE_DOCUMENTATION.md
        # Will recreate when COB data quality is improved

        # Save Transformer
        if hasattr(self.orchestrator, 'transformer_trainer') and self.orchestrator.transformer_trainer:
            results['transformer'] = self._save_transformer_checkpoint(force)

        # Save Decision Model
        if hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
            results['decision'] = self._save_decision_checkpoint(force)

        return results
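The per-model booleans returned above let callers report partial failures; a small sketch (the orchestrator object is assumed):

saver = ModelCheckpointSaver(orchestrator)
results = saver.save_all_model_checkpoints(force=True)
failed = [name for name, ok in results.items() if not ok]
if failed:
    logger.warning(f"Checkpoints failed for: {failed}")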
    def _save_dqn_checkpoint(self, force: bool = True) -> bool:
        """Save DQN agent checkpoint"""
        try:
            if hasattr(self.orchestrator.rl_agent, 'save_checkpoint'):
                success = self.orchestrator.rl_agent.save_checkpoint(force_save=force)
                if success:
                    self.orchestrator.model_states['dqn']['checkpoint_loaded'] = True
                    self.orchestrator.model_states['dqn']['checkpoint_filename'] = f"dqn_agent_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                    logger.info("DQN checkpoint saved successfully")
                    return True

            # Fallback: use improved model saver
            from improved_model_saver import get_improved_model_saver
            saver = get_improved_model_saver()
            success = saver.save_model_safely(
                self.orchestrator.rl_agent,
                "dqn_agent",
                "dqn",
                metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
            )
            if success:
                self.orchestrator.model_states['dqn']['checkpoint_loaded'] = True
                self.orchestrator.model_states['dqn']['checkpoint_filename'] = "dqn_agent_latest"
                logger.info("DQN checkpoint saved using fallback method")
                return True

            return False

        except Exception as e:
            logger.error(f"Failed to save DQN checkpoint: {e}")
            return False

    def _save_cnn_checkpoint(self, force: bool = True) -> bool:
        """Save CNN model checkpoint"""
        try:
            if hasattr(self.orchestrator.cnn_model, 'save_checkpoint'):
                success = self.orchestrator.cnn_model.save_checkpoint(force_save=force)
                if success:
                    self.orchestrator.model_states['cnn']['checkpoint_loaded'] = True
                    self.orchestrator.model_states['cnn']['checkpoint_filename'] = f"enhanced_cnn_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                    logger.info("CNN checkpoint saved successfully")
                    return True

            # Fallback: use improved model saver
            from improved_model_saver import get_improved_model_saver
            saver = get_improved_model_saver()
            success = saver.save_model_safely(
                self.orchestrator.cnn_model,
                "enhanced_cnn",
                "cnn",
                metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
            )
            if success:
                self.orchestrator.model_states['cnn']['checkpoint_loaded'] = True
                self.orchestrator.model_states['cnn']['checkpoint_filename'] = "enhanced_cnn_latest"
                logger.info("CNN checkpoint saved using fallback method")
                return True

            return False

        except Exception as e:
            logger.error(f"Failed to save CNN checkpoint: {e}")
            return False

    def _save_extrema_checkpoint(self, force: bool = True) -> bool:
        """Save Extrema Trainer checkpoint"""
        try:
            if hasattr(self.orchestrator.extrema_trainer, 'save_checkpoint'):
                self.orchestrator.extrema_trainer.save_checkpoint(force_save=force)
                self.orchestrator.model_states['extrema_trainer']['checkpoint_loaded'] = True
                self.orchestrator.model_states['extrema_trainer']['checkpoint_filename'] = f"extrema_trainer_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                logger.info("Extrema Trainer checkpoint saved successfully")
                return True

            return False

        except Exception as e:
            logger.error(f"Failed to save Extrema Trainer checkpoint: {e}")
            return False
    def _save_cob_rl_checkpoint(self, force: bool = True) -> bool:
        """Save COB RL agent checkpoint"""
        try:
            # COB RL may have a different saving mechanism
            from improved_model_saver import get_improved_model_saver
            saver = get_improved_model_saver()
            success = saver.save_model_safely(
                self.orchestrator.cob_rl_agent,
                "cob_rl",
                "cob_rl",
                metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
            )
            if success:
                self.orchestrator.model_states['cob_rl']['checkpoint_loaded'] = True
                self.orchestrator.model_states['cob_rl']['checkpoint_filename'] = "cob_rl_latest"
                logger.info("COB RL checkpoint saved successfully")
                return True

            return False

        except Exception as e:
            logger.error(f"Failed to save COB RL checkpoint: {e}")
            return False

    def _save_transformer_checkpoint(self, force: bool = True) -> bool:
        """Save Transformer model checkpoint"""
        try:
            if hasattr(self.orchestrator.transformer_trainer, 'save_model'):
                # Create a checkpoint file path
                checkpoint_dir = Path("models/saved/transformer")
                checkpoint_dir.mkdir(parents=True, exist_ok=True)
                checkpoint_path = checkpoint_dir / f"transformer_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"

                self.orchestrator.transformer_trainer.save_model(str(checkpoint_path))
                self.orchestrator.model_states['transformer']['checkpoint_loaded'] = True
                self.orchestrator.model_states['transformer']['checkpoint_filename'] = checkpoint_path.name
                logger.info("Transformer checkpoint saved successfully")
                return True

            return False

        except Exception as e:
            logger.error(f"Failed to save Transformer checkpoint: {e}")
            return False

    def _save_decision_checkpoint(self, force: bool = True) -> bool:
        """Save Decision model checkpoint"""
        try:
            from improved_model_saver import get_improved_model_saver
            saver = get_improved_model_saver()
            success = saver.save_model_safely(
                self.orchestrator.decision_model,
                "decision",
                "decision",
                metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
            )
            if success:
                self.orchestrator.model_states['decision']['checkpoint_loaded'] = True
                self.orchestrator.model_states['decision']['checkpoint_filename'] = "decision_latest"
                logger.info("Decision model checkpoint saved successfully")
                return True

            return False

        except Exception as e:
            logger.error(f"Failed to save Decision model checkpoint: {e}")
            return False

    def update_model_status_to_loaded(self, model_name: str):
        """Manually update a model's status to LOADED"""
        if model_name in self.orchestrator.model_states:
            self.orchestrator.model_states[model_name]['checkpoint_loaded'] = True
            if not self.orchestrator.model_states[model_name].get('checkpoint_filename'):
                self.orchestrator.model_states[model_name]['checkpoint_filename'] = f"{model_name}_manual_loaded"
            logger.info(f"Updated {model_name} status to LOADED")

    def force_all_models_to_loaded(self):
        """Force all existing models to show as LOADED"""
        models_updated = []

        for model_name in self.orchestrator.model_states.keys():
            # Check if model actually exists
            model_exists = False

            if model_name == 'dqn' and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                model_exists = True
            elif model_name == 'cnn' and hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                model_exists = True
            elif model_name == 'extrema_trainer' and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
                model_exists = True
            # COB RL model removed - focusing on COB data quality first
            elif model_name == 'transformer' and hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
                model_exists = True
            elif model_name == 'decision' and hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
                model_exists = True

            if model_exists:
                self.update_model_status_to_loaded(model_name)
                models_updated.append(model_name)

        logger.info(f"Force-updated {len(models_updated)} models to LOADED status: {models_updated}")
        return models_updated


def save_all_checkpoints_now(orchestrator):
    """Convenience function to save all checkpoints"""
    saver = ModelCheckpointSaver(orchestrator)
    results = saver.save_all_model_checkpoints(force=True)

    print("Checkpoint saving results:")
    for model_name, success in results.items():
        status = "✅ SUCCESS" if success else "❌ FAILED"
        print(f"  {model_name}: {status}")

    return results
@@ -1,446 +0,0 @@
#!/usr/bin/env python3
"""
Unified Model Registry for Centralized Model Management

This module provides a unified interface for saving, loading, and managing
all machine learning models in the trading system. It consolidates model
storage from multiple locations into a single, organized structure.
"""

import os
import json
import torch
import logging
import pickle
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
from datetime import datetime
import hashlib

logger = logging.getLogger(__name__)

class ModelRegistry:
    """
    Unified model registry for centralized model management.
    Handles saving, loading, and organization of all ML models.
    """

    def __init__(self, base_dir: str = "models"):
        """
        Initialize the model registry.

        Args:
            base_dir: Base directory for model storage
        """
        self.base_dir = Path(base_dir)
        self.saved_dir = self.base_dir / "saved"
        self.checkpoint_dir = self.base_dir / "checkpoints"
        self.archive_dir = self.base_dir / "archive"

        # Model type directories
        self.model_dirs = {
            'cnn': self.base_dir / "cnn",
            'dqn': self.base_dir / "dqn",
            'transformer': self.base_dir / "transformer",
            'hybrid': self.base_dir / "hybrid"
        }

        # Ensure all directories exist
        self._ensure_directories()

        # Metadata tracking
        self.metadata_file = self.base_dir / "registry_metadata.json"
        self.metadata = self._load_metadata()

        logger.info(f"Model Registry initialized at {self.base_dir}")

    def _ensure_directories(self):
        """Ensure all required directories exist."""
        directories = [
            self.saved_dir,
            self.checkpoint_dir,
            self.archive_dir
        ]

        # Add model type directories
        for model_dir in self.model_dirs.values():
            directories.extend([
                model_dir / "saved",
                model_dir / "checkpoints",
                model_dir / "archive"
            ])

        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)
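For the default `base_dir="models"`, `_ensure_directories` produces a layout roughly like the following (a sketch; only the registry-managed folders are shown):

# models/
# ├── saved/
# ├── checkpoints/
# ├── archive/
# ├── cnn/
# │   ├── saved/
# │   ├── checkpoints/
# │   └── archive/
# ├── dqn/          (same three subfolders)
# ├── transformer/  (same three subfolders)
# └── hybrid/       (same three subfolders)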
    def _load_metadata(self) -> Dict[str, Any]:
        """Load registry metadata."""
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r') as f:
                    return json.load(f)
            except Exception as e:
                logger.warning(f"Failed to load metadata: {e}")
        return {'models': {}, 'last_updated': datetime.now().isoformat()}

    def _save_metadata(self):
        """Save registry metadata."""
        self.metadata['last_updated'] = datetime.now().isoformat()
        try:
            with open(self.metadata_file, 'w') as f:
                json.dump(self.metadata, f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save metadata: {e}")

    def save_model(self, model: Any, model_name: str, model_type: str = 'cnn',
                   metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model to the unified storage.

        Args:
            model: The model to save
            model_name: Name of the model
            model_type: Type of model (cnn, dqn, transformer, hybrid)
            metadata: Additional metadata to save

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            save_dir = model_dir / "saved"

            # Generate filename with timestamp
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"{model_name}_{timestamp}.pt"
            filepath = save_dir / filename

            # Also save as latest
            latest_filepath = save_dir / f"{model_name}_latest.pt"

            # Save model
            save_dict = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'model_type': model_type,
                'timestamp': timestamp,
                'metadata': metadata or {}
            }

            torch.save(save_dict, filepath)
            torch.save(save_dict, latest_filepath)

            # Update metadata
            if model_name not in self.metadata['models']:
                self.metadata['models'][model_name] = {}

            self.metadata['models'][model_name].update({
                'type': model_type,
                'latest_path': str(latest_filepath),
                'last_saved': timestamp,
                'save_count': self.metadata['models'][model_name].get('save_count', 0) + 1
            })

            self._save_metadata()

            logger.info(f"Model {model_name} saved to {filepath}")
            return True

        except Exception as e:
            logger.error(f"Failed to save model {model_name}: {e}")
            return False

    def load_model(self, model_name: str, model_type: str = 'cnn',
                   model_class: Optional[Any] = None) -> Optional[Any]:
        """
        Load a model from the unified storage.

        Args:
            model_name: Name of the model to load
            model_type: Type of model (cnn, dqn, transformer, hybrid)
            model_class: Model class to instantiate (if needed)

        Returns:
            The loaded model or None if failed
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            save_dir = model_dir / "saved"
            latest_filepath = save_dir / f"{model_name}_latest.pt"

            if not latest_filepath.exists():
                logger.warning(f"Model {model_name} not found at {latest_filepath}")
                return None

            # Load checkpoint
            checkpoint = torch.load(latest_filepath, map_location='cpu')

            # Instantiate model if class provided
            if model_class is not None:
                model = model_class()
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # No class available: build a stand-in object that exposes the state_dict
                model = type('LoadedModel', (), {})()
                model.state_dict = lambda: checkpoint['model_state_dict']
                model.load_state_dict = lambda state_dict: None

            logger.info(f"Model {model_name} loaded from {latest_filepath}")
            return model

        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {e}")
            return None
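A save/load round trip through the registry (the `TinyNet` class is a hypothetical stand-in for a real model):

import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

registry = ModelRegistry()
registry.save_model(TinyNet(), "tiny", model_type="cnn", metadata={"note": "demo"})
restored = registry.load_model("tiny", model_type="cnn", model_class=TinyNet)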
    def save_checkpoint(self, model: Any, model_name: str, model_type: str = 'cnn',
                        performance_score: float = 0.0,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save a model checkpoint.

        Args:
            model: The model to checkpoint
            model_name: Name of the model
            model_type: Type of model
            performance_score: Performance score for this checkpoint
            metadata: Additional metadata

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.checkpoint_dir)
            checkpoint_dir = model_dir / "checkpoints"

            # Generate checkpoint ID
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            checkpoint_id = f"{model_name}_{timestamp}_{performance_score:.4f}"

            filepath = checkpoint_dir / f"{checkpoint_id}.pt"

            # Save checkpoint
            checkpoint_data = {
                'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
                'model_class': model.__class__.__name__,
                'model_type': model_type,
                'model_name': model_name,
                'performance_score': performance_score,
                'timestamp': timestamp,
                'metadata': metadata or {}
            }

            torch.save(checkpoint_data, filepath)

            # Update metadata
            if model_name not in self.metadata['models']:
                self.metadata['models'][model_name] = {}

            if 'checkpoints' not in self.metadata['models'][model_name]:
                self.metadata['models'][model_name]['checkpoints'] = []

            checkpoint_info = {
                'id': checkpoint_id,
                'path': str(filepath),
                'performance_score': performance_score,
                'timestamp': timestamp
            }

            self.metadata['models'][model_name]['checkpoints'].append(checkpoint_info)

            # Keep only top 5 checkpoints
            checkpoints = self.metadata['models'][model_name]['checkpoints']
            if len(checkpoints) > 5:
                checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)
                checkpoints_to_remove = checkpoints[5:]

                for checkpoint in checkpoints_to_remove:
                    try:
                        os.remove(checkpoint['path'])
                    except OSError:
                        pass

                self.metadata['models'][model_name]['checkpoints'] = checkpoints[:5]

            self._save_metadata()

            logger.info(f"Checkpoint {checkpoint_id} saved with score {performance_score}")
            return True

        except Exception as e:
            logger.error(f"Failed to save checkpoint for {model_name}: {e}")
            return False
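Because only the five highest-scoring checkpoints survive, repeated saves act as a rolling leaderboard. A sketch of the pruning behavior, reusing the hypothetical `TinyNet` from the earlier example (scores are illustrative):

registry = ModelRegistry()
for score in [0.1, 0.5, 0.3, 0.9, 0.7, 0.2, 0.8]:
    registry.save_checkpoint(TinyNet(), "tiny", model_type="cnn", performance_score=score)
kept = registry.metadata['models']['tiny']['checkpoints']
assert len(kept) <= 5  # the lowest-scoring entries (0.1, 0.2) are pruned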
    def load_best_checkpoint(self, model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
        """
        Load the best checkpoint for a model.

        Args:
            model_name: Name of the model
            model_type: Type of model

        Returns:
            Tuple of (checkpoint_path, checkpoint_data) or None
        """
        try:
            if model_name not in self.metadata['models']:
                logger.warning(f"No metadata found for model {model_name}")
                return None

            checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
            if not checkpoints:
                logger.warning(f"No checkpoints found for model {model_name}")
                return None

            # Find best checkpoint by performance score
            best_checkpoint = max(checkpoints, key=lambda x: x['performance_score'])
            checkpoint_path = best_checkpoint['path']

            if not os.path.exists(checkpoint_path):
                logger.warning(f"Checkpoint file not found: {checkpoint_path}")
                return None

            checkpoint_data = torch.load(checkpoint_path, map_location='cpu')

            logger.info(f"Best checkpoint loaded for {model_name}: {best_checkpoint['id']}")
            return checkpoint_path, checkpoint_data

        except Exception as e:
            logger.error(f"Failed to load best checkpoint for {model_name}: {e}")
            return None

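A short usage sketch for restoring weights from the best checkpoint (MyCNN is a placeholder for whatever nn.Module was checkpointed):

    registry = get_model_registry()
    result = registry.load_best_checkpoint("my_cnn", model_type='cnn')
    if result:
        checkpoint_path, checkpoint_data = result
        model = MyCNN()
        model.load_state_dict(checkpoint_data['model_state_dict'])
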
    def archive_model(self, model_name: str, model_type: str = 'cnn') -> bool:
        """
        Archive a model by moving it to the archive directory.

        Args:
            model_name: Name of the model to archive
            model_type: Type of model

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            model_dir = self.model_dirs.get(model_type, self.saved_dir)
            save_dir = model_dir / "saved"
            archive_dir = model_dir / "archive"
            archive_dir.mkdir(parents=True, exist_ok=True)  # Ensure the archive directory exists

            latest_filepath = save_dir / f"{model_name}_latest.pt"
            if not latest_filepath.exists():
                logger.warning(f"Model {model_name} not found to archive")
                return False

            # Move to archive with timestamp
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            archive_filepath = archive_dir / f"{model_name}_archived_{timestamp}.pt"
            os.rename(latest_filepath, archive_filepath)

            logger.info(f"Model {model_name} archived to {archive_filepath}")
            return True

        except Exception as e:
            logger.error(f"Failed to archive model {model_name}: {e}")
            return False

    def list_models(self, model_type: Optional[str] = None) -> Dict[str, Any]:
        """
        List all models in the registry.

        Args:
            model_type: Filter by model type (optional)

        Returns:
            Dictionary of model information
        """
        models_info = {}

        for model_name, model_data in self.metadata['models'].items():
            if model_type and model_data.get('type') != model_type:
                continue

            models_info[model_name] = {
                'type': model_data.get('type'),
                'last_saved': model_data.get('last_saved'),
                'save_count': model_data.get('save_count', 0),
                'checkpoint_count': len(model_data.get('checkpoints', [])),
                'latest_path': model_data.get('latest_path')
            }

        return models_info

    def cleanup_old_checkpoints(self, model_name: str, keep_count: int = 5) -> int:
        """
        Clean up old checkpoints, keeping only the best ones.

        Args:
            model_name: Name of the model
            keep_count: Number of checkpoints to keep

        Returns:
            Number of checkpoints removed
        """
        if model_name not in self.metadata['models']:
            return 0

        checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
        if len(checkpoints) <= keep_count:
            return 0

        # Sort by performance score (descending) so the best checkpoints survive
        checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)

        # Remove checkpoints beyond keep_count
        removed_count = 0
        for checkpoint in checkpoints[keep_count:]:
            try:
                os.remove(checkpoint['path'])
                removed_count += 1
            except OSError as remove_error:
                logger.warning(f"Could not remove checkpoint {checkpoint['path']}: {remove_error}")

        # Update metadata
        self.metadata['models'][model_name]['checkpoints'] = checkpoints[:keep_count]
        self._save_metadata()

        logger.info(f"Cleaned up {removed_count} old checkpoints for {model_name}")
        return removed_count

# Global registry instance
_registry_instance = None

def get_model_registry() -> ModelRegistry:
    """Get the global model registry instance."""
    global _registry_instance
    if _registry_instance is None:
        _registry_instance = ModelRegistry()
    return _registry_instance

def save_model(model: Any, model_name: str, model_type: str = 'cnn',
               metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Convenience function to save a model using the global registry."""
    return get_model_registry().save_model(model, model_name, model_type, metadata)

def load_model(model_name: str, model_type: str = 'cnn',
               model_class: Optional[Any] = None) -> Optional[Any]:
    """Convenience function to load a model using the global registry."""
    return get_model_registry().load_model(model_name, model_type, model_class)

def save_checkpoint(model: Any, model_name: str, model_type: str = 'cnn',
                    performance_score: float = 0.0,
                    metadata: Optional[Dict[str, Any]] = None) -> bool:
    """Convenience function to save a checkpoint using the global registry."""
    return get_model_registry().save_checkpoint(model, model_name, model_type,
                                                performance_score, metadata)

def load_best_checkpoint(model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
    """Convenience function to load the best checkpoint using the global registry."""
    return get_model_registry().load_best_checkpoint(model_name, model_type)

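Putting the convenience functions together, a training loop might resume from the best score at startup and checkpoint each epoch (a sketch; validate_model and num_epochs are placeholders, and the import path assumes these wrappers live in utils.model_registry as elsewhere in this commit):

    from utils.model_registry import save_checkpoint, load_best_checkpoint

    best = load_best_checkpoint("my_cnn")
    if best:
        _, data = best
        model.load_state_dict(data['model_state_dict'])

    for epoch in range(num_epochs):
        score = validate_model(model)  # hypothetical validation helper
        save_checkpoint(model, "my_cnn", model_type='cnn', performance_score=score)
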
@@ -1,109 +0,0 @@
"""
Models Module

Provides model registry and interfaces for the trading system.
This module acts as a bridge between the core system and the NN models.
"""

import logging
from typing import Dict, Any, Optional, List

from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface

logger = logging.getLogger(__name__)

class ModelRegistry:
    """Registry for managing trading models"""

    def __init__(self):
        self.models: Dict[str, ModelInterface] = {}
        self.model_performance: Dict[str, Dict[str, Any]] = {}

    def register_model(self, model: ModelInterface):
        """Register a model in the registry"""
        name = model.name
        self.models[name] = model
        self.model_performance[name] = {
            'correct': 0,
            'total': 0,
            'accuracy': 0.0,
            'last_used': None
        }
        logger.info(f"Registered model: {name}")
        return True

    def get_model(self, name: str) -> Optional[ModelInterface]:
        """Get a model by name"""
        return self.models.get(name)

    def get_all_models(self) -> Dict[str, ModelInterface]:
        """Get all registered models"""
        return self.models.copy()

    def update_performance(self, name: str, correct: bool):
        """Update model performance metrics"""
        if name in self.model_performance:
            self.model_performance[name]['total'] += 1
            if correct:
                self.model_performance[name]['correct'] += 1
            self.model_performance[name]['accuracy'] = (
                self.model_performance[name]['correct'] /
                self.model_performance[name]['total']
            )

    def get_best_model(self, model_type: str = None) -> Optional[str]:
        """Get the name of the best performing model"""
        if not self.model_performance:
            return None

        best_model = None
        best_accuracy = -1.0

        for name, perf in self.model_performance.items():
            if model_type and not name.lower().startswith(model_type.lower()):
                continue
            if perf['accuracy'] > best_accuracy:
                best_accuracy = perf['accuracy']
                best_model = name

        return best_model

    def unregister_model(self, name: str) -> bool:
        """Unregister a model from the registry"""
        if name in self.models:
            del self.models[name]
        if name in self.model_performance:
            del self.model_performance[name]
        logger.info(f"Unregistered model: {name}")
        return True

# Global model registry instance
_model_registry = ModelRegistry()

def get_model_registry() -> ModelRegistry:
    """Get the global model registry instance"""
    return _model_registry

def register_model(model: ModelInterface):
    """Register a model in the global registry"""
    return _model_registry.register_model(model)

def get_model(name: str) -> Optional[ModelInterface]:
    """Get a model from the global registry"""
    return _model_registry.get_model(name)

def get_all_models() -> Dict[str, ModelInterface]:
    """Get all models from the global registry"""
    return _model_registry.get_all_models()

# Export the interfaces
__all__ = [
    'ModelRegistry',
    'get_model_registry',
    'register_model',
    'get_model',
    'get_all_models',
    'ModelInterface',
    'CNNModelInterface',
    'RLAgentInterface',
    'ExtremaTrainerInterface'
]

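For context, typical use of this in-memory registry (removed by this commit in favor of the unified ModelManager) looked roughly like the sketch below; some_model stands in for any ModelInterface implementation:

    registry = get_model_registry()
    registry.register_model(some_model)           # some_model.name is used as the key
    registry.update_performance(some_model.name, correct=True)
    best_name = registry.get_best_model(model_type="cnn")
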
@@ -346,11 +346,58 @@ class TradingOrchestrator:
            logger.warning("Extrema trainer not available")
            self.extrema_trainer = None

        # COB RL Model REMOVED - See COB_MODEL_ARCHITECTURE_DOCUMENTATION.md
        # Reason: Need quality COB data first before evaluating massive parameter benefit
        # Will recreate improved version when COB data pipeline is fixed
        logger.info("COB RL model removed - focusing on COB data quality first")
        self.cob_rl_agent = None
        # Initialize COB RL Model - UNIFIED with ModelManager
        try:
            from NN.models.cob_rl_model import COBRLModelInterface

            # Initialize COB RL model using unified approach
            self.cob_rl_agent = COBRLModelInterface(
                model_checkpoint_dir="@checkpoints/cob_rl",
                device='cuda' if torch.cuda.is_available() else 'cpu'
            )

            # Add COB RL to model states tracking
            self.model_states['cob_rl'] = {
                'initial_loss': None,
                'current_loss': None,
                'best_loss': None,
                'checkpoint_loaded': False
            }

            # Load best checkpoint using unified ModelManager
            checkpoint_loaded = False
            try:
                from NN.training.model_manager import load_best_checkpoint
                result = load_best_checkpoint("cob_rl_agent")
                if result:
                    file_path, metadata = result
                    self.model_states['cob_rl']['initial_loss'] = metadata.loss
                    self.model_states['cob_rl']['current_loss'] = metadata.loss
                    self.model_states['cob_rl']['best_loss'] = metadata.loss
                    self.model_states['cob_rl']['checkpoint_loaded'] = True
                    self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
                    checkpoint_loaded = True
                    loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "N/A"
                    logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
            except Exception as e:
                logger.warning(f"Error loading COB RL checkpoint: {e}")

            if not checkpoint_loaded:
                # New model - no synthetic data, start fresh
                self.model_states['cob_rl']['initial_loss'] = None
                self.model_states['cob_rl']['current_loss'] = None
                self.model_states['cob_rl']['best_loss'] = None
                self.model_states['cob_rl']['checkpoint_filename'] = 'none (fresh start)'
                logger.info("COB RL starting fresh - no checkpoint found")

            logger.info("COB RL Agent initialized and integrated with unified ModelManager")
            logger.info(" - Uses @checkpoints/ directory structure")
            logger.info(" - Follows same load/save/checkpoint flow as other models")
            logger.info(" - Integrated with enhanced real-time training system")

        except ImportError as e:
            logger.warning(f"COB RL Model not available: {e}")
            self.cob_rl_agent = None

        # Initialize TRANSFORMER Model
        try:

@@ -34,7 +34,8 @@ import os
# Local imports
from .cob_integration import COBIntegration
from .trading_executor import TradingExecutor
from NN.models.cob_rl_model import MassiveRLNetwork, COBRLModelInterface
# UNIFIED: Import only the interface, models come from orchestrator
from NN.models.cob_rl_model import COBRLModelInterface

logger = logging.getLogger(__name__)

@@ -98,51 +99,44 @@ class RealtimeRLCOBTrader:
    Real-time RL trader using COB data with comprehensive subscriber system
    """

    def __init__(self,
                 symbols: Optional[List[str]] = None,
                 trading_executor: Optional[TradingExecutor] = None,
                 model_checkpoint_dir: str = "models/realtime_rl_cob",
                 orchestrator: Any = None,  # UNIFIED: Use orchestrator's models
                 inference_interval_ms: int = 200,
                 min_confidence_threshold: float = 0.35,  # Lowered from 0.7 for more aggressive trading
                 required_confident_predictions: int = 3,
                 checkpoint_manager: Any = None):
                 required_confident_predictions: int = 3):

        self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
        self.trading_executor = trading_executor
        self.model_checkpoint_dir = model_checkpoint_dir
        self.orchestrator = orchestrator  # UNIFIED: Use orchestrator's models
        self.inference_interval_ms = inference_interval_ms
        self.min_confidence_threshold = min_confidence_threshold
        self.required_confident_predictions = required_confident_predictions

        # Initialize ModelManager (either provided or get global instance)
        if checkpoint_manager is None:
            from NN.training.model_manager import create_model_manager
            self.checkpoint_manager = create_model_manager()

        # UNIFIED: Use orchestrator's ModelManager instead of creating our own
        if self.orchestrator and hasattr(self.orchestrator, 'model_manager'):
            self.model_manager = self.orchestrator.model_manager
        else:
            self.checkpoint_manager = checkpoint_manager

            from NN.training.model_manager import create_model_manager
            self.model_manager = create_model_manager()

        # Track start time for training duration calculation
        self.start_time = datetime.now()  # Initialize start_time

        # Setup device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        # Initialize models for each symbol
        self.models: Dict[str, MassiveRLNetwork] = {}
        self.optimizers: Dict[str, optim.AdamW] = {}
        self.scalers: Dict[str, torch.cuda.amp.GradScaler] = {}

        for symbol in self.symbols:
            model = MassiveRLNetwork().to(self.device)
            self.models[symbol] = model
            self.optimizers[symbol] = optim.AdamW(
                model.parameters(),
                lr=1e-5,  # Low learning rate for stability
                weight_decay=1e-6,
                betas=(0.9, 0.999)
            )
            self.scalers[symbol] = torch.cuda.amp.GradScaler()
        self.start_time = datetime.now()

        # UNIFIED: Use orchestrator's COB RL model
        if not self.orchestrator or not hasattr(self.orchestrator, 'cob_rl_agent') or not self.orchestrator.cob_rl_agent:
            raise ValueError("RealtimeRLCOBTrader requires orchestrator with COB RL model. Please initialize TradingOrchestrator first.")

        # Use orchestrator's unified COB RL model
        self.cob_rl_model = self.orchestrator.cob_rl_agent
        self.device = self.orchestrator.cob_rl_agent.device if hasattr(self.orchestrator.cob_rl_agent, 'device') else torch.device('cpu')
        logger.info(f"Using orchestrator's unified COB RL model on device: {self.device}")

        # Create unified model references for all symbols
        self.models = {symbol: self.cob_rl_model.model for symbol in self.symbols}
        self.optimizers = {symbol: self.cob_rl_model.optimizer for symbol in self.symbols}
        self.scalers = {symbol: self.cob_rl_model.scaler for symbol in self.symbols}

        # Subscriber system for real-time events
        self.prediction_subscribers: List[Callable[[PredictionResult], None]] = []

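Since prediction_subscribers is a plain list of callables, hooking into the real-time event stream is a one-liner (sketch; the callback body is illustrative):

    def on_prediction(result: PredictionResult) -> None:
        print(result)  # e.g. forward to a dashboard or alerting hook

    trader.prediction_subscribers.append(on_prediction)
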
@@ -906,56 +900,67 @@ class RealtimeRLCOBTrader:
        return reward

    async def _train_batch(self, symbol: str, predictions: List[PredictionResult]) -> float:
        """Train model on a batch of predictions"""
        """Train model on a batch of predictions using unified approach"""
        try:
            model = self.models[symbol]
            optimizer = self.optimizers[symbol]
            scaler = self.scalers[symbol]

            # UNIFIED: Always use orchestrator's COB RL model
            return self._train_batch_unified(predictions)

        except Exception as e:
            logger.error(f"Error training batch for {symbol}: {e}")
            return 0.0

    def _train_batch_unified(self, predictions: List[PredictionResult]) -> float:
        """Train using unified COB RL model from orchestrator"""
        try:
            model = self.cob_rl_model.model
            optimizer = self.cob_rl_model.optimizer
            scaler = self.cob_rl_model.scaler

            model.train()
            optimizer.zero_grad()

            # Prepare batch data
            features = torch.stack([
                torch.from_numpy(p.features) for p in predictions
            ]).to(self.device)

            # Targets
            direction_targets = torch.tensor([
                p.actual_direction for p in predictions
            ], dtype=torch.long).to(self.device)

            value_targets = torch.tensor([
                p.reward for p in predictions
            ], dtype=torch.float32).to(self.device)

            # Forward pass with mixed precision
            with torch.cuda.amp.autocast():
                outputs = model(features)

            # Calculate losses
            direction_loss = nn.CrossEntropyLoss()(outputs['price_logits'], direction_targets)
            value_loss = nn.MSELoss()(outputs['value'].squeeze(), value_targets)

            # Confidence loss (encourage high confidence for correct predictions)
            correct_predictions = (torch.argmax(outputs['price_logits'], dim=1) == direction_targets).float()
            confidence_loss = nn.BCELoss()(outputs['confidence'].squeeze(), correct_predictions)

            # Combined loss
            total_loss = direction_loss + 0.5 * value_loss + 0.3 * confidence_loss

            # Backward pass with gradient scaling
            scaler.scale(total_loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            return total_loss.item()

        except Exception as e:
            logger.error(f"Error training batch for {symbol}: {e}")
            logger.error(f"Error in unified training batch: {e}")
            return 0.0
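
To make the loss arithmetic above concrete, here is a self-contained sketch with dummy tensors (batch of 4, 3 direction classes; the output keys mirror the dict returned by the model):

    import torch
    import torch.nn as nn

    batch, classes = 4, 3
    outputs = {
        'price_logits': torch.randn(batch, classes),
        'value': torch.randn(batch, 1),
        'confidence': torch.rand(batch, 1),  # BCELoss expects probabilities in [0, 1]
    }
    direction_targets = torch.randint(0, classes, (batch,))
    value_targets = torch.randn(batch)

    direction_loss = nn.CrossEntropyLoss()(outputs['price_logits'], direction_targets)
    value_loss = nn.MSELoss()(outputs['value'].squeeze(), value_targets)
    correct = (torch.argmax(outputs['price_logits'], dim=1) == direction_targets).float()
    confidence_loss = nn.BCELoss()(outputs['confidence'].squeeze(), correct)
    total_loss = direction_loss + 0.5 * value_loss + 0.3 * confidence_loss  # same weighting as above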

    async def _train_on_trade_execution(self, symbol: str, signals: List[PredictionResult],
                                        action: str, price: float):

@@ -1015,68 +1020,99 @@ class RealtimeRLCOBTrader:
|
||||
await asyncio.sleep(60)
|
||||
|
||||
def _save_models(self):
|
||||
"""Save all models to disk using CheckpointManager"""
|
||||
"""Save models using unified ModelManager approach"""
|
||||
try:
|
||||
for symbol in self.symbols:
|
||||
model_name = f"cob_rl_{symbol.replace('/', '_').lower()}" # Standardize model name for CheckpointManager
|
||||
|
||||
# Prepare performance metrics for CheckpointManager
|
||||
if self.cob_rl_model:
|
||||
# UNIFIED: Use orchestrator's COB RL model with ModelManager
|
||||
performance_metrics = {
|
||||
'loss': self.training_stats[symbol].get('average_loss', 0.0),
|
||||
'reward': self.training_stats[symbol].get('average_reward', 0.0), # Assuming average_reward is tracked
|
||||
'accuracy': self.training_stats[symbol].get('average_accuracy', 0.0), # Assuming average_accuracy is tracked
|
||||
'loss': self._get_average_loss(),
|
||||
'reward': self._get_average_reward(),
|
||||
'accuracy': self._get_average_accuracy(),
|
||||
}
|
||||
if self.trading_executor: # Add check for trading_executor
|
||||
daily_stats = self.trading_executor.get_daily_stats()
|
||||
performance_metrics['pnl'] = daily_stats.get('total_pnl', 0.0) # Example, get actual pnl
|
||||
performance_metrics['training_samples'] = self.training_stats[symbol].get('total_training_steps', 0)
|
||||
|
||||
# Prepare training metadata for CheckpointManager
|
||||
# Add P&L if trading executor is available
|
||||
if self.trading_executor and hasattr(self.trading_executor, 'get_daily_stats'):
|
||||
try:
|
||||
daily_stats = self.trading_executor.get_daily_stats()
|
||||
performance_metrics['pnl'] = daily_stats.get('total_pnl', 0.0)
|
||||
except Exception:
|
||||
performance_metrics['pnl'] = 0.0
|
||||
|
||||
performance_metrics['training_samples'] = sum(
|
||||
stats.get('total_training_steps', 0) for stats in self.training_stats.values()
|
||||
)
|
||||
|
||||
# Prepare training metadata
|
||||
training_metadata = {
|
||||
'total_parameters': sum(p.numel() for p in self.models[symbol].parameters()),
|
||||
'epoch': self.training_stats[symbol].get('total_training_steps', 0), # Using total_training_steps as pseudo-epoch
|
||||
'total_parameters': sum(p.numel() for p in self.cob_rl_model.model.parameters()),
|
||||
'epoch': max(stats.get('total_training_steps', 0) for stats in self.training_stats.values()),
|
||||
'training_time_hours': (datetime.now() - self.start_time).total_seconds() / 3600
|
||||
}
|
||||
|
||||
self.checkpoint_manager.save_checkpoint(
|
||||
model=self.models[symbol],
|
||||
model_name=model_name,
|
||||
model_type='COB_RL', # Specify model type
|
||||
# Save using unified ModelManager
|
||||
self.model_manager.save_checkpoint(
|
||||
model=self.cob_rl_model.model,
|
||||
model_name="cob_rl_agent",
|
||||
model_type='COB_RL',
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata
|
||||
)
|
||||
|
||||
logger.debug(f"Saved model for {symbol}")
|
||||
|
||||
|
||||
logger.info("COB RL model saved using unified ModelManager")
|
||||
else:
|
||||
# This should not happen with proper initialization
|
||||
logger.error("Unified COB RL model not available - check orchestrator initialization")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving models: {e}")
|
||||
|
||||
|
||||
def _load_models(self):
|
||||
"""Load existing models from disk using CheckpointManager"""
|
||||
"""Load models using unified ModelManager approach"""
|
||||
try:
|
||||
for symbol in self.symbols:
|
||||
model_name = f"cob_rl_{symbol.replace('/', '_').lower()}" # Standardize model name for CheckpointManager
|
||||
|
||||
loaded_checkpoint = self.checkpoint_manager.load_best_checkpoint(model_name)
|
||||
|
||||
if self.cob_rl_model:
|
||||
# UNIFIED: Load using ModelManager
|
||||
loaded_checkpoint = self.model_manager.load_best_checkpoint("cob_rl_agent")
|
||||
|
||||
if loaded_checkpoint:
|
||||
model_path, metadata = loaded_checkpoint
|
||||
checkpoint = torch.load(model_path, map_location=self.device)
|
||||
|
||||
self.models[symbol].load_state_dict(checkpoint['model_state_dict'])
|
||||
self.optimizers[symbol].load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
|
||||
if 'training_stats' in checkpoint:
|
||||
self.training_stats[symbol].update(checkpoint['training_stats'])
|
||||
if 'inference_stats' in checkpoint:
|
||||
self.inference_stats[symbol].update(checkpoint['inference_stats'])
|
||||
|
||||
logger.info(f"Loaded existing model for {symbol} from checkpoint: {metadata.checkpoint_id}")
|
||||
|
||||
self.cob_rl_model.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
self.cob_rl_model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
|
||||
# Update training stats for all symbols with loaded data
|
||||
for symbol in self.symbols:
|
||||
if 'training_stats' in checkpoint:
|
||||
self.training_stats[symbol].update(checkpoint['training_stats'])
|
||||
if 'inference_stats' in checkpoint:
|
||||
self.inference_stats[symbol].update(checkpoint['inference_stats'])
|
||||
|
||||
logger.info(f"Loaded unified COB RL model from checkpoint: {metadata.checkpoint_id}")
|
||||
else:
|
||||
logger.info(f"No existing model found for {symbol} via CheckpointManager, starting fresh.")
|
||||
|
||||
logger.info("No existing COB RL model found via ModelManager, starting fresh.")
|
||||
else:
|
||||
# This should not happen with proper initialization
|
||||
logger.error("Unified COB RL model not available - check orchestrator initialization")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading models: {e}")
|
||||
|
||||
|
||||
def _get_average_loss(self) -> float:
|
||||
"""Get average loss across all symbols"""
|
||||
losses = [stats.get('average_loss', 0.0) for stats in self.training_stats.values() if stats.get('average_loss') is not None]
|
||||
return sum(losses) / len(losses) if losses else 0.0
|
||||
|
||||
def _get_average_reward(self) -> float:
|
||||
"""Get average reward across all symbols"""
|
||||
rewards = [stats.get('average_reward', 0.0) for stats in self.training_stats.values() if stats.get('average_reward') is not None]
|
||||
return sum(rewards) / len(rewards) if rewards else 0.0
|
||||
|
||||
def _get_average_accuracy(self) -> float:
|
||||
"""Get average accuracy across all symbols"""
|
||||
accuracies = [stats.get('average_accuracy', 0.0) for stats in self.training_stats.values() if stats.get('average_accuracy') is not None]
|
||||
return sum(accuracies) / len(accuracies) if accuracies else 0.0
|
||||
|
||||
def get_performance_stats(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive performance statistics"""
|
||||
@@ -1119,36 +1155,49 @@ class RealtimeRLCOBTrader:
|
||||
|
||||
# Example usage
|
||||
async def main():
|
||||
"""Example usage of RealtimeRLCOBTrader"""
|
||||
"""Example usage of unified RealtimeRLCOBTrader"""
|
||||
from ..core.orchestrator import TradingOrchestrator
|
||||
from ..core.trading_executor import TradingExecutor
|
||||
|
||||
|
||||
# Initialize orchestrator (which now includes unified COB RL model)
|
||||
orchestrator = TradingOrchestrator()
|
||||
|
||||
# Initialize trading executor (simulation mode)
|
||||
trading_executor = TradingExecutor()
|
||||
|
||||
# Initialize real-time RL trader
|
||||
|
||||
# Initialize real-time RL trader with unified orchestrator
|
||||
trader = RealtimeRLCOBTrader(
|
||||
symbols=['BTC/USDT', 'ETH/USDT'],
|
||||
trading_executor=trading_executor,
|
||||
orchestrator=orchestrator, # UNIFIED: Use orchestrator's models
|
||||
inference_interval_ms=200,
|
||||
min_confidence_threshold=0.7,
|
||||
required_confident_predictions=3
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# Start the trader
|
||||
# Start the orchestrator first (initializes all models)
|
||||
await orchestrator.start()
|
||||
|
||||
# Start the trader (uses orchestrator's unified COB RL model)
|
||||
await trader.start()
|
||||
|
||||
|
||||
# Run for demonstration
|
||||
logger.info("Real-time RL COB Trader running...")
|
||||
logger.info("Real-time RL COB Trader running with unified orchestrator...")
|
||||
await asyncio.sleep(300) # Run for 5 minutes
|
||||
|
||||
# Print performance stats
|
||||
stats = trader.get_performance_stats()
|
||||
logger.info(f"Performance stats: {json.dumps(stats, indent=2, default=str)}")
|
||||
|
||||
|
||||
# Print performance stats from both systems
|
||||
orchestrator_stats = orchestrator.get_model_stats()
|
||||
trader_stats = trader.get_performance_stats()
|
||||
logger.info("=== ORCHESTRATOR STATS ===")
|
||||
logger.info(f"Model stats: {json.dumps(orchestrator_stats, indent=2, default=str)}")
|
||||
logger.info("=== TRADER STATS ===")
|
||||
logger.info(f"Performance stats: {json.dumps(trader_stats, indent=2, default=str)}")
|
||||
|
||||
finally:
|
||||
# Stop the trader
|
||||
# Stop both systems
|
||||
await trader.stop()
|
||||
await orchestrator.stop()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
main.py
@@ -168,8 +168,8 @@ def start_web_ui(port=8051):
    except ImportError:
        model_registry = {}

    # Initialize checkpoint management for dashboard
    dashboard_checkpoint_manager = get_checkpoint_manager()
    # Initialize unified model management for dashboard
    dashboard_checkpoint_manager = create_model_manager()
    dashboard_training_integration = get_training_integration()

    # Create unified orchestrator for the dashboard

@@ -206,8 +206,8 @@ async def start_training_loop(orchestrator, trading_executor):
    logger.info("STARTING ENHANCED TRAINING LOOP WITH COB INTEGRATION")
    logger.info("=" * 70)

    # Initialize checkpoint management for training loop
    checkpoint_manager = get_checkpoint_manager()
    # Initialize unified model management for training loop
    checkpoint_manager = create_model_manager()
    training_integration = get_training_integration()

    # Training statistics for checkpoint management

@@ -6261,7 +6261,7 @@ class CleanTradingDashboard:
            # Save checkpoint after training
            if loss_count > 0:
                try:
                    from utils.checkpoint_manager import save_checkpoint
                    from NN.training.model_manager import save_checkpoint
                    avg_loss = total_loss / loss_count

                    # Prepare checkpoint data

@@ -6390,7 +6390,7 @@ class CleanTradingDashboard:
            # Save checkpoint after training
            if loss_count > 0:
                try:
                    from utils.checkpoint_manager import save_checkpoint
                    from NN.training.model_manager import save_checkpoint
                    avg_loss = total_loss / loss_count

                    # Prepare checkpoint data

@@ -6878,7 +6878,7 @@ class CleanTradingDashboard:
            # Save checkpoint after training
            if training_samples > 0:
                try:
                    from utils.checkpoint_manager import save_checkpoint
                    from NN.training.model_manager import save_checkpoint
                    avg_loss = total_loss / loss_count if loss_count > 0 else 0.356

                    # Prepare checkpoint data

@@ -443,14 +443,20 @@ class DashboardComponentManager:
        ask_levels = [center_bucket + i * bucket_size for i in range(1, num_levels + 1)]
        bid_levels = [center_bucket - i * bucket_size for i in range(num_levels)]
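
A quick worked example of the level construction (values illustrative): with a $1 bucket around a $3500 center, the asks start one bucket above the center while the bids include the center bucket itself:

    center_bucket, bucket_size, num_levels = 3500.0, 1.0, 3
    ask_levels = [center_bucket + i * bucket_size for i in range(1, num_levels + 1)]  # [3501.0, 3502.0, 3503.0]
    bid_levels = [center_bucket - i * bucket_size for i in range(num_levels)]         # [3500.0, 3499.0, 3498.0]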

        # Debug: Log how many orders we have to work with
        print(f"DEBUG COB: {symbol} - Processing {len(bids)} bids, {len(asks)} asks")
        print(f"DEBUG COB: Mid price: ${mid_price:.2f}, Bucket size: ${bucket_size}")
        print(f"DEBUG COB: Bid buckets: {len(bid_buckets)}, Ask buckets: {len(ask_buckets)}")
        if bid_buckets:
            print(f"DEBUG COB: Bid price range: ${min(bid_buckets.keys()):.2f} - ${max(bid_buckets.keys()):.2f}")
        if ask_buckets:
            print(f"DEBUG COB: Ask price range: ${min(ask_buckets.keys()):.2f} - ${max(ask_buckets.keys()):.2f}")
        # Debug: Combined log for COB ladder panel
        print(
            f"DEBUG COB: {symbol} - {len(bids)} bids, {len(asks)} asks | "
            f"Mid price: ${mid_price:.2f}, ${bucket_size} buckets | "
            f"Bid buckets: {len(bid_buckets)}, Ask buckets: {len(ask_buckets)}"
            + (
                f" | Bid range: ${min(bid_buckets.keys()):.2f} - ${max(bid_buckets.keys()):.2f}"
                if bid_buckets else ""
            )
            + (
                f" | Ask range: ${min(ask_buckets.keys()):.2f} - ${max(ask_buckets.keys()):.2f}"
                if ask_buckets else ""
            )
        )

    def create_bookmap_row(price, bid_data, ask_data, max_vol):
        """Create a Bookmap-style row with horizontal bars extending from center"""