deduplicate model storage

Dobromir Popov
2025-09-09 00:35:13 +03:00
parent c3a94600c8
commit 9671d0d363
11 changed files with 292 additions and 1846 deletions


@@ -457,6 +457,72 @@ class ModelManager:
logger.error(f"Error getting storage stats: {e}")
return {'error': str(e)}
def get_checkpoint_stats(self) -> Dict[str, Any]:
"""Get statistics about managed checkpoints (compatible with old checkpoint_manager interface)"""
try:
stats = {
'total_models': 0,
'total_checkpoints': 0,
'total_size_mb': 0.0,
'models': {}
}
# Count files in different directories as "checkpoints"
checkpoint_dirs = [
self.checkpoints_dir / "cnn",
self.checkpoints_dir / "dqn",
self.checkpoints_dir / "rl",
self.checkpoints_dir / "transformer",
self.checkpoints_dir / "hybrid"
]
total_size = 0
total_files = 0
for checkpoint_dir in checkpoint_dirs:
if checkpoint_dir.exists():
model_files = list(checkpoint_dir.rglob('*.pt'))
if model_files:
model_name = checkpoint_dir.name
stats['total_models'] += 1
model_size = sum(f.stat().st_size for f in model_files)
stats['total_checkpoints'] += len(model_files)
stats['total_size_mb'] += model_size / (1024 * 1024)
total_size += model_size
total_files += len(model_files)
# Get the most recent file as "latest"
latest_file = max(model_files, key=lambda f: f.stat().st_mtime)
stats['models'][model_name] = {
'checkpoint_count': len(model_files),
'total_size_mb': model_size / (1024 * 1024),
'best_performance': 0.0, # Not tracked in unified system
'best_checkpoint_id': latest_file.name,
'latest_checkpoint': latest_file.name
}
# Also check saved models directory
if self.saved_dir.exists():
saved_files = list(self.saved_dir.rglob('*.pt'))
if saved_files:
stats['total_checkpoints'] += len(saved_files)
saved_size = sum(f.stat().st_size for f in saved_files)
stats['total_size_mb'] += saved_size / (1024 * 1024)
return stats
except Exception as e:
logger.error(f"Error getting checkpoint stats: {e}")
return {
'total_models': 0,
'total_checkpoints': 0,
'total_size_mb': 0.0,
'models': {},
'error': str(e)
}
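# Illustrative shape of the stats dict returned on success (hypothetical values):
# {
#     'total_models': 2,
#     'total_checkpoints': 7,
#     'total_size_mb': 412.6,
#     'models': {
#         'cnn': {
#             'checkpoint_count': 4,
#             'total_size_mb': 236.1,
#             'best_performance': 0.0,  # not tracked in the unified system
#             'best_checkpoint_id': 'cnn_20250908_231500.pt',
#             'latest_checkpoint': 'cnn_20250908_231500.pt'
#         },
#         'dqn': {'...': '...'}
#     }
# }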
def get_model_leaderboard(self) -> List[Dict[str, Any]]:
"""Get model performance leaderboard"""
try:


@@ -1,560 +0,0 @@
#!/usr/bin/env python3
"""
Checkpoint Management System for W&B Training
"""
import os
import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict
import torch
import random
WANDB_AVAILABLE = False
# Import model registry
from utils.model_registry import get_model_registry
logger = logging.getLogger(__name__)
@dataclass
class CheckpointMetadata:
checkpoint_id: str
model_name: str
model_type: str
file_path: str
created_at: datetime
file_size_mb: float
performance_score: float
accuracy: Optional[float] = None
loss: Optional[float] = None
val_accuracy: Optional[float] = None
val_loss: Optional[float] = None
reward: Optional[float] = None
pnl: Optional[float] = None
epoch: Optional[int] = None
training_time_hours: Optional[float] = None
total_parameters: Optional[int] = None
wandb_run_id: Optional[str] = None
wandb_artifact_name: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
data = asdict(self)
data['created_at'] = self.created_at.isoformat()
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
data['created_at'] = datetime.fromisoformat(data['created_at'])
return cls(**data)
class CheckpointManager:
def __init__(self,
base_checkpoint_dir: str = "NN/models/saved",
max_checkpoints_per_model: int = 5,
metadata_file: str = "checkpoint_metadata.json",
enable_wandb: bool = False):
self.base_dir = Path(base_checkpoint_dir)
self.base_dir.mkdir(parents=True, exist_ok=True)
self.max_checkpoints = max_checkpoints_per_model
self.metadata_file = self.base_dir / metadata_file
self.enable_wandb = False
self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
self._warned_models = set() # Track models we've warned about to reduce spam
self._load_metadata()
logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")
def save_checkpoint(self, model, model_name: str, model_type: str,
performance_metrics: Dict[str, float],
training_metadata: Optional[Dict[str, Any]] = None,
force_save: bool = False) -> Optional[CheckpointMetadata]:
"""Save a model checkpoint with improved error handling and validation using unified registry"""
try:
from NN.training.model_manager import save_checkpoint as registry_save_checkpoint
performance_score = self._calculate_performance_score(performance_metrics)
if not force_save and not self._should_save_checkpoint(model_name, performance_score):
logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
return None
# Use unified registry for checkpointing
success = registry_save_checkpoint(
model=model,
model_name=model_name,
model_type=model_type,
performance_score=performance_score,
metadata={
'performance_metrics': performance_metrics,
'training_metadata': training_metadata,
'checkpoint_manager': True
}
)
if not success:
return None
# Get checkpoint info from registry
registry = get_model_registry()
checkpoint_info = registry.metadata['models'][model_name]['checkpoints'][-1]
# Create CheckpointMetadata object
metadata = CheckpointMetadata(
checkpoint_id=checkpoint_info['id'],
model_name=model_name,
model_type=model_type,
file_path=checkpoint_info['path'],
created_at=datetime.strptime(checkpoint_info['timestamp'], '%Y%m%d_%H%M%S'),  # registry stores compact timestamps, not ISO strings
file_size_mb=0.0, # Will be calculated by registry
performance_score=performance_score,
accuracy=performance_metrics.get('accuracy'),
loss=performance_metrics.get('loss'),
val_accuracy=performance_metrics.get('val_accuracy'),
val_loss=performance_metrics.get('val_loss'),
reward=performance_metrics.get('reward'),
pnl=performance_metrics.get('pnl'),
epoch=training_metadata.get('epoch') if training_metadata else None,
training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None,
total_parameters=training_metadata.get('total_parameters') if training_metadata else None
)
# Update local checkpoint tracking
self.checkpoints[model_name].append(metadata)
self._rotate_checkpoints(model_name)
self._save_metadata()
logger.debug(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
return metadata
except Exception as e:
logger.error(f"Error saving checkpoint for {model_name}: {e}")
return None
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
try:
from utils.model_registry import load_best_checkpoint as registry_load_checkpoint
# First, try the unified registry
registry_result = registry_load_checkpoint(model_name, 'cnn') # Try CNN type first
if registry_result is None:
registry_result = registry_load_checkpoint(model_name, 'dqn') # Try DQN type
if registry_result:
checkpoint_path, checkpoint_data = registry_result
# Create CheckpointMetadata from registry data
metadata = CheckpointMetadata(
checkpoint_id=f"{model_name}_registry",
model_name=model_name,
model_type=checkpoint_data.get('model_type', 'unknown'),
file_path=checkpoint_path,
created_at=datetime.strptime(checkpoint_data.get('timestamp', datetime.now().strftime('%Y%m%d_%H%M%S')), '%Y%m%d_%H%M%S'),  # registry stores compact timestamps
file_size_mb=0.0, # Will be calculated by registry
performance_score=checkpoint_data.get('performance_score', 0.0),
accuracy=checkpoint_data.get('accuracy'),
loss=checkpoint_data.get('loss'),
reward=checkpoint_data.get('reward'),
pnl=checkpoint_data.get('pnl')
)
logger.debug(f"Loading checkpoint from unified registry for {model_name}")
return checkpoint_path, metadata
# Fallback: Try the standard checkpoint system
if model_name in self.checkpoints and self.checkpoints[model_name]:
# Filter out checkpoints with non-existent files
valid_checkpoints = [
cp for cp in self.checkpoints[model_name]
if Path(cp.file_path).exists()
]
if valid_checkpoints:
best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score)
logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
return best_checkpoint.file_path, best_checkpoint
else:
# Clean up invalid metadata entries
invalid_count = len(self.checkpoints[model_name])
logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata")
self.checkpoints[model_name] = []
self._save_metadata()
# Fallback: Look for existing saved models in the legacy format
logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models")
legacy_model_path = self._find_legacy_model(model_name)
if legacy_model_path:
# Create checkpoint metadata for the legacy model using actual file data
legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path)
logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
return str(legacy_model_path), legacy_metadata
# Only warn once per model to avoid spam
if model_name not in self._warned_models:
logger.info(f"No checkpoints found for {model_name}, starting fresh")
self._warned_models.add(model_name)
return None
except Exception as e:
logger.error(f"Error loading best checkpoint for {model_name}: {e}")
return None
def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
"""Calculate performance score with improved sensitivity for training models"""
score = 0.0
# Prioritize loss reduction for active training models
if 'loss' in metrics:
# Invert loss so lower loss = higher score, with better scaling
loss_value = metrics['loss']
if loss_value > 0:
score += max(0, 100 / (1 + loss_value)) # More sensitive to loss changes
else:
score += 100 # Perfect loss
# Add other metrics with appropriate weights
if 'accuracy' in metrics:
score += metrics['accuracy'] * 50 # Reduced weight to balance with loss
if 'val_accuracy' in metrics:
score += metrics['val_accuracy'] * 50
if 'val_loss' in metrics:
val_loss = metrics['val_loss']
if val_loss > 0:
score += max(0, 50 / (1 + val_loss))
if 'reward' in metrics:
score += metrics['reward'] * 10
if 'pnl' in metrics:
score += metrics['pnl'] * 5
if 'training_samples' in metrics:
# Bonus for processing more training samples
score += min(10, metrics['training_samples'] / 10)
# Return actual calculated score - NO SYNTHETIC MINIMUM
return score
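# Worked example of the scoring above (illustrative metric values):
#   metrics = {'loss': 0.5, 'accuracy': 0.8, 'reward': 2.0}
#   loss term:     100 / (1 + 0.5) = 66.67
#   accuracy term: 0.8 * 50        = 40.00
#   reward term:   2.0 * 10        = 20.00
#   performance_score ~= 126.67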
def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
"""Improved checkpoint saving logic with more frequent saves during training"""
if model_name not in self.checkpoints or not self.checkpoints[model_name]:
return True # Always save first checkpoint
# Allow more checkpoints during active training
if len(self.checkpoints[model_name]) < self.max_checkpoints:
return True
# Get current best and worst scores
scores = [cp.performance_score for cp in self.checkpoints[model_name]]
best_score = max(scores)
worst_score = min(scores)
# Save if better than worst (more frequent saves)
if performance_score > worst_score:
return True
# For high-performing models (score > 100), be more sensitive to small improvements
if best_score > 100:
# Save if within 0.1% of best score (very sensitive for converged models)
if performance_score >= best_score * 0.999:
return True
else:
# Also save if we're within 10% of best score (capture near-optimal models)
if performance_score >= best_score * 0.9:
return True
# Save more frequently during active training (every 5th attempt instead of 10th)
if random.random() < 0.2: # 20% chance to save anyway
logger.debug(f"Saving checkpoint for {model_name} - periodic save during active training")
return True
return False
def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
try:
if hasattr(model, 'state_dict'):
torch.save({
'model_state_dict': model.state_dict(),
'model_type': model_type,
'saved_at': datetime.now().isoformat()
}, file_path)
else:
torch.save(model, file_path)
return True
except Exception as e:
logger.error(f"Error saving model file {file_path}: {e}")
return False
def _rotate_checkpoints(self, model_name: str):
checkpoint_list = self.checkpoints[model_name]
if len(checkpoint_list) <= self.max_checkpoints:
return
checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)
to_remove = checkpoint_list[self.max_checkpoints:]
self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints]
for checkpoint in to_remove:
try:
file_path = Path(checkpoint.file_path)
if file_path.exists():
file_path.unlink()
logger.debug(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
except Exception as e:
logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")
def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
return None
def _load_metadata(self):
try:
if self.metadata_file.exists():
with open(self.metadata_file, 'r') as f:
data = json.load(f)
for model_name, checkpoint_list in data.items():
self.checkpoints[model_name] = [
CheckpointMetadata.from_dict(cp_data)
for cp_data in checkpoint_list
]
logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
except Exception as e:
logger.error(f"Error loading checkpoint metadata: {e}")
def _save_metadata(self):
try:
data = {}
for model_name, checkpoint_list in self.checkpoints.items():
data[model_name] = [cp.to_dict() for cp in checkpoint_list]
with open(self.metadata_file, 'w') as f:
json.dump(data, f, indent=2)
except Exception as e:
logger.error(f"Error saving checkpoint metadata: {e}")
def get_checkpoint_stats(self):
"""Get statistics about managed checkpoints"""
stats = {
'total_models': len(self.checkpoints),
'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
'total_size_mb': 0.0,
'models': {}
}
for model_name, checkpoint_list in self.checkpoints.items():
if not checkpoint_list:
continue
model_size = sum(cp.file_size_mb for cp in checkpoint_list)
best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)
stats['models'][model_name] = {
'checkpoint_count': len(checkpoint_list),
'total_size_mb': model_size,
'best_performance': best_checkpoint.performance_score,
'best_checkpoint_id': best_checkpoint.checkpoint_id,
'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
}
stats['total_size_mb'] += model_size
return stats
def _find_legacy_model(self, model_name: str) -> Optional[Path]:
"""Find legacy saved models based on model name patterns"""
base_dir = Path(self.base_dir)
# Additional search locations
search_dirs = [
base_dir,
Path("models/saved"),
Path("NN/models/saved"),
Path("models"),
Path("models/archive"),
Path("models/backtest")
]
# Define model name mappings and patterns for legacy files
legacy_patterns = {
'dqn_agent': [
'dqn_agent_session_policy.pt',
'dqn_agent_session_agent_state.pt',
'dqn_agent_best_policy.pt',
'enhanced_dqn_best_policy.pt',
'improved_dqn_agent_best_policy.pt',
'dqn_agent_final_policy.pt',
'trading_agent_best_pnl.pt'
],
'enhanced_cnn': [
'cnn_model_session.pt',
'cnn_model_best.pt',
'optimized_short_term_model_best.pt',
'optimized_short_term_model_realtime_best.pt',
'optimized_short_term_model_ticks_best.pt'
],
'extrema_trainer': [
'supervised_model_best.pt'
],
'cob_rl': [
'best_rl_model.pth_policy.pt',
'rl_agent_best_policy.pt'
],
'decision': [
# Decision models might be in subdirectories, but let's check main dir too
'decision_best.pt',
'decision_model_best.pt',
# Check for transformer models which might be used as decision models
'enhanced_dqn_best_policy.pt',
'improved_dqn_agent_best_policy.pt'
]
}
# Get patterns for this model name
patterns = legacy_patterns.get(model_name, [])
# Also try generic patterns based on model name
patterns.extend([
f'{model_name}_best.pt',
f'{model_name}_best_policy.pt',
f'{model_name}_final.pt',
f'{model_name}_final_policy.pt'
])
# Search for the model files in all search directories
for search_dir in search_dirs:
if not search_dir.exists():
continue
for pattern in patterns:
candidate_path = search_dir / pattern
if candidate_path.exists():
logger.info(f"Found legacy model file: {candidate_path}")
return candidate_path
# Also check subdirectories
for subdir in base_dir.iterdir():
if subdir.is_dir() and subdir.name == model_name:
for pattern in patterns:
candidate_path = subdir / pattern
if candidate_path.exists():
logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
return candidate_path
# Extended search: scan common project model directories for best checkpoints
try:
# Attempt to infer project root from base_dir (NN/models/saved -> root)
project_root = base_dir.resolve().parent.parent.parent
except Exception:
project_root = Path(".").resolve()
additional_dirs = [
project_root / "models",
project_root / "models" / "archive",
project_root / "models" / "backtest",
]
def _match_legacy_name(candidate: Path, model: str) -> bool:
name = candidate.name.lower()
model_keys = {
'dqn_agent': ['dqn', 'agent', 'policy'],
'enhanced_cnn': ['cnn', 'optimized_short_term'],
'extrema_trainer': ['supervised', 'extrema'],
'cob_rl': ['cob', 'rl', 'policy'],
'decision': ['decision', 'transformer']
}.get(model, [model])
return any(k in name for k in model_keys)
candidates: List[Path] = []
for adir in additional_dirs:
if not adir.exists():
continue
try:
for pt in adir.rglob('*.pt'):
# Prefer files that indicate "best" and match model hints
lname = pt.name.lower()
if 'best' in lname and _match_legacy_name(pt, model_name):
candidates.append(pt)
# Do not add generic fallbacks to avoid mismatched model types
except Exception:
# Ignore directory traversal issues
pass
if candidates:
# Pick the most recently modified candidate
try:
best = max(candidates, key=lambda p: p.stat().st_mtime)
logger.debug(f"Found legacy model file in project models dir: {best}")
return best
except Exception:
# If stat fails, just return the first one deterministically
candidates.sort()
logger.debug(f"Found legacy model file in project models dir: {candidates[0]}")
return candidates[0]
return None
def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
"""Create metadata for legacy model files using only actual file information"""
try:
file_size_mb = file_path.stat().st_size / (1024 * 1024)
created_time = datetime.fromtimestamp(file_path.stat().st_mtime)
# NO SYNTHETIC DATA - use only actual file information
return CheckpointMetadata(
checkpoint_id=f"legacy_{model_name}_{int(created_time.timestamp())}",
model_name=model_name,
model_type=model_name,
file_path=str(file_path),
created_at=created_time,
file_size_mb=file_size_mb,
performance_score=0.0, # Unknown performance - use 0, not synthetic values
accuracy=None,
loss=None,
val_accuracy=None,
val_loss=None,
reward=None,
pnl=None,
epoch=None,
training_time_hours=None,
total_parameters=None,
wandb_run_id=None,
wandb_artifact_name=None
)
except Exception as e:
logger.error(f"Error creating legacy metadata for {model_name}: {e}")
# Return a basic metadata with minimal info - NO SYNTHETIC VALUES
return CheckpointMetadata(
checkpoint_id=f"legacy_{model_name}",
model_name=model_name,
model_type=model_name,
file_path=str(file_path),
created_at=datetime.now(),
file_size_mb=0.0,
performance_score=0.0 # Unknown - use 0, not synthetic
)
_checkpoint_manager = None
def get_checkpoint_manager() -> CheckpointManager:
global _checkpoint_manager
if _checkpoint_manager is None:
_checkpoint_manager = CheckpointManager()
return _checkpoint_manager
def save_checkpoint(model, model_name: str, model_type: str,
performance_metrics: Dict[str, float],
training_metadata: Optional[Dict[str, Any]] = None,
force_save: bool = False) -> Optional[CheckpointMetadata]:
return get_checkpoint_manager().save_checkpoint(
model, model_name, model_type, performance_metrics, training_metadata, force_save
)
def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
return get_checkpoint_manager().load_best_checkpoint(model_name)
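# Minimal usage sketch of the module-level helpers above (illustrative; the import path
# utils.checkpoint_manager is assumed, and the nn.Linear model is a stand-in):
#
#   import torch.nn as nn
#   from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
#
#   model = nn.Linear(8, 3)  # stand-in for a real trading model
#   meta = save_checkpoint(model, 'enhanced_cnn', 'cnn',
#                          performance_metrics={'loss': 0.42, 'accuracy': 0.71})
#   if meta:
#       print(f"saved {meta.checkpoint_id} (score {meta.performance_score:.2f})")
#   result = load_best_checkpoint('enhanced_cnn')
#   if result:
#       path, meta = result
#       print(f"best checkpoint: {path}")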


@@ -1,361 +0,0 @@
#!/usr/bin/env python3
"""
Improved Model Saver
A comprehensive model saving utility that handles various model types
and ensures reliable checkpointing with validation.
"""
import logging
import torch
import os
import json
from pathlib import Path
from datetime import datetime
from typing import Dict, Any, Optional, Union
import shutil
logger = logging.getLogger(__name__)
class ImprovedModelSaver:
"""Enhanced model saving with validation and backup strategies"""
def __init__(self, base_dir: str = "models/saved"):
self.base_dir = Path(base_dir)
self.base_dir.mkdir(parents=True, exist_ok=True)
def save_model_safely(self,
model: Any,
model_name: str,
model_type: str = "unknown",
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Save a model with multiple fallback strategies
Args:
model: The model to save
model_name: Name identifier for the model
model_type: Type of model (dqn, cnn, rl, etc.)
metadata: Additional metadata to save
Returns:
bool: True if successful, False otherwise
"""
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
model_dir = self.base_dir / model_name
model_dir.mkdir(parents=True, exist_ok=True)
# Create backup file names
main_path = model_dir / f"{model_name}_latest.pt"
backup_path = model_dir / f"{model_name}_{timestamp}.pt"
try:
# Strategy 1: Standard full-checkpoint PyTorch save
if hasattr(model, '__dict__') and hasattr(torch, 'save'):
success = self._save_pytorch_model(model, main_path, backup_path)
if success:
self._save_metadata(model_dir, model_name, model_type, metadata)
logger.info(f"Successfully saved {model_name} using PyTorch save")
return True
# Strategy 2: Try state_dict saving for PyTorch models
if hasattr(model, 'state_dict'):
success = self._save_state_dict(model, main_path, backup_path)
if success:
self._save_metadata(model_dir, model_name, model_type, metadata)
logger.info(f"Successfully saved {model_name} using state_dict")
return True
# Strategy 3: Try component-based saving for complex models
if hasattr(model, 'policy_net') or hasattr(model, 'target_net'):
success = self._save_rl_agent_components(model, model_dir, model_name)
if success:
self._save_metadata(model_dir, model_name, model_type, metadata)
logger.info(f"Successfully saved {model_name} using component-based saving")
return True
# Strategy 4: Fallback - try pickle
success = self._save_with_pickle(model, main_path, backup_path)
if success:
self._save_metadata(model_dir, model_name, model_type, metadata)
logger.info(f"Successfully saved {model_name} using pickle fallback")
return True
logger.error(f"All save strategies failed for {model_name}")
return False
except Exception as e:
logger.error(f"Critical error saving {model_name}: {e}")
return False
def _save_pytorch_model(self, model, main_path: Path, backup_path: Path) -> bool:
"""Save using standard PyTorch torch.save"""
try:
# Create checkpoint data
if hasattr(model, 'state_dict'):
checkpoint = {
'model_state_dict': model.state_dict(),
'model_class': model.__class__.__name__,
'timestamp': datetime.now().isoformat()
}
# Add additional attributes
for attr in ['epsilon', 'total_steps', 'current_reward', 'optimizer']:
if hasattr(model, attr):
try:
value = getattr(model, attr)
if attr == 'optimizer' and value is not None:
checkpoint['optimizer_state_dict'] = value.state_dict()
else:
checkpoint[attr] = value
except Exception:
pass # Skip problematic attributes
else:
checkpoint = {
'model': model,
'timestamp': datetime.now().isoformat()
}
# Save to backup location first
torch.save(checkpoint, backup_path)
# Verify backup was saved correctly
torch.load(backup_path, map_location='cpu')
# Copy to main location
shutil.copy2(backup_path, main_path)
return True
except Exception as e:
logger.warning(f"PyTorch save failed: {e}")
return False
def _save_state_dict(self, model, main_path: Path, backup_path: Path) -> bool:
"""Save using state_dict only"""
try:
state_dict = model.state_dict()
checkpoint = {
'state_dict': state_dict,
'model_class': model.__class__.__name__,
'timestamp': datetime.now().isoformat()
}
torch.save(checkpoint, backup_path)
torch.load(backup_path, map_location='cpu') # Verify
shutil.copy2(backup_path, main_path)
return True
except Exception as e:
logger.warning(f"State dict save failed: {e}")
return False
def _save_rl_agent_components(self, model, model_dir: Path, model_name: str) -> bool:
"""Save RL agent components separately"""
try:
components_saved = 0
# Save policy network
if hasattr(model, 'policy_net') and model.policy_net is not None:
policy_path = model_dir / f"{model_name}_policy.pt"
torch.save(model.policy_net.state_dict(), policy_path)
components_saved += 1
# Save target network
if hasattr(model, 'target_net') and model.target_net is not None:
target_path = model_dir / f"{model_name}_target.pt"
torch.save(model.target_net.state_dict(), target_path)
components_saved += 1
# Save agent state
agent_state = {}
for attr in ['epsilon', 'total_steps', 'current_reward', 'memory']:
if hasattr(model, attr):
try:
value = getattr(model, attr)
if attr == 'memory' and hasattr(value, '__len__'):
# Don't save large replay buffers
agent_state[attr + '_size'] = len(value)
else:
agent_state[attr] = value
except Exception:
pass
if agent_state:
state_path = model_dir / f"{model_name}_agent_state.pt"
torch.save(agent_state, state_path)
components_saved += 1
return components_saved > 0
except Exception as e:
logger.warning(f"Component-based save failed: {e}")
return False
def _save_with_pickle(self, model, main_path: Path, backup_path: Path) -> bool:
"""Fallback: save using pickle"""
try:
import pickle
with open(backup_path.with_suffix('.pkl'), 'wb') as f:
pickle.dump(model, f)
# Verify
with open(backup_path.with_suffix('.pkl'), 'rb') as f:
pickle.load(f)
shutil.copy2(backup_path.with_suffix('.pkl'), main_path.with_suffix('.pkl'))
return True
except Exception as e:
logger.warning(f"Pickle save failed: {e}")
return False
def _save_metadata(self, model_dir: Path, model_name: str, model_type: str, metadata: Optional[Dict[str, Any]]):
"""Save model metadata"""
try:
meta_data = {
'model_name': model_name,
'model_type': model_type,
'saved_at': datetime.now().isoformat(),
'save_method': 'improved_model_saver'
}
if metadata:
meta_data.update(metadata)
meta_path = model_dir / f"{model_name}_metadata.json"
with open(meta_path, 'w') as f:
json.dump(meta_data, f, indent=2, default=str)
except Exception as e:
logger.warning(f"Failed to save metadata: {e}")
def load_model_safely(self, model_name: str, model_class=None):
"""
Load a model with multiple strategies
Args:
model_name: Name of the model to load
model_class: Class to instantiate if needed
Returns:
Loaded model or None
"""
model_dir = self.base_dir / model_name
if not model_dir.exists():
logger.warning(f"Model directory not found: {model_dir}")
return None
# Try different loading strategies
loaders = [
self._load_pytorch_checkpoint,
self._load_state_dict_only,
self._load_rl_components,
self._load_pickle_fallback
]
for loader in loaders:
try:
result = loader(model_dir, model_name, model_class)
if result is not None:
logger.info(f"Successfully loaded {model_name} using {loader.__name__}")
return result
except Exception as e:
logger.debug(f"{loader.__name__} failed: {e}")
continue
logger.error(f"All load strategies failed for {model_name}")
return None
def _load_pytorch_checkpoint(self, model_dir: Path, model_name: str, model_class):
"""Load PyTorch checkpoint"""
main_path = model_dir / f"{model_name}_latest.pt"
if main_path.exists():
checkpoint = torch.load(main_path, map_location='cpu')
if model_class and 'model_state_dict' in checkpoint:
model = model_class()
model.load_state_dict(checkpoint['model_state_dict'])
# Restore other attributes
for key, value in checkpoint.items():
if key not in ['model_state_dict', 'optimizer_state_dict', 'timestamp', 'model_class']:
if hasattr(model, key):
setattr(model, key, value)
return model
return checkpoint.get('model', checkpoint)
return None
def _load_state_dict_only(self, model_dir: Path, model_name: str, model_class):
"""Load state dict only"""
main_path = model_dir / f"{model_name}_latest.pt"
if main_path.exists() and model_class:
checkpoint = torch.load(main_path, map_location='cpu')
if 'state_dict' in checkpoint:
model = model_class()
model.load_state_dict(checkpoint['state_dict'])
return model
return None
def _load_rl_components(self, model_dir: Path, model_name: str, model_class):
"""Load RL agent from components"""
policy_path = model_dir / f"{model_name}_policy.pt"
target_path = model_dir / f"{model_name}_target.pt"
state_path = model_dir / f"{model_name}_agent_state.pt"
if policy_path.exists() and model_class:
model = model_class()
# Load policy network
if hasattr(model, 'policy_net'):
model.policy_net.load_state_dict(torch.load(policy_path, map_location='cpu'))
# Load target network
if target_path.exists() and hasattr(model, 'target_net'):
model.target_net.load_state_dict(torch.load(target_path, map_location='cpu'))
# Load agent state
if state_path.exists():
agent_state = torch.load(state_path, map_location='cpu')
for key, value in agent_state.items():
if hasattr(model, key):
setattr(model, key, value)
return model
return None
def _load_pickle_fallback(self, model_dir: Path, model_name: str, model_class):
"""Load from pickle"""
pickle_path = model_dir / f"{model_name}_latest.pkl"
if pickle_path.exists():
import pickle
with open(pickle_path, 'rb') as f:
return pickle.load(f)
return None
# Global instance for easy access
_improved_model_saver = None
def get_improved_model_saver() -> ImprovedModelSaver:
"""Get or create the global improved model saver instance"""
global _improved_model_saver
if _improved_model_saver is None:
_improved_model_saver = ImprovedModelSaver()
return _improved_model_saver
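# Minimal usage sketch (illustrative; the nn.Linear model is a stand-in, not part of this module):
#
#   import torch.nn as nn
#   from improved_model_saver import get_improved_model_saver
#
#   saver = get_improved_model_saver()
#   model = nn.Linear(16, 4)  # stand-in for a real model
#   saver.save_model_safely(model, 'example_model', model_type='cnn',
#                           metadata={'note': 'illustrative save'})
#   restored = saver.load_model_safely('example_model',
#                                      model_class=lambda: nn.Linear(16, 4))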


@@ -1,246 +0,0 @@
#!/usr/bin/env python3
"""
Model Checkpoint Saver
Utility to ensure all models can save checkpoints properly.
This will make them show as LOADED instead of FRESH.
"""
import logging
import os
from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path
logger = logging.getLogger(__name__)
class ModelCheckpointSaver:
"""Utility to save checkpoints for all models to fix FRESH status"""
def __init__(self, orchestrator):
self.orchestrator = orchestrator
def save_all_model_checkpoints(self, force: bool = True) -> Dict[str, bool]:
"""Save checkpoints for all initialized models"""
results = {}
# Save DQN Agent
if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
results['dqn_agent'] = self._save_dqn_checkpoint(force)
# Save CNN Model
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
results['enhanced_cnn'] = self._save_cnn_checkpoint(force)
# Save Extrema Trainer
if hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
results['extrema_trainer'] = self._save_extrema_checkpoint(force)
# COB RL model removed - see COB_MODEL_ARCHITECTURE_DOCUMENTATION.md
# Will recreate when COB data quality is improved
# Save Transformer
if hasattr(self.orchestrator, 'transformer_trainer') and self.orchestrator.transformer_trainer:
results['transformer'] = self._save_transformer_checkpoint(force)
# Save Decision Model
if hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
results['decision'] = self._save_decision_checkpoint(force)
return results
def _save_dqn_checkpoint(self, force: bool = True) -> bool:
"""Save DQN agent checkpoint"""
try:
if hasattr(self.orchestrator.rl_agent, 'save_checkpoint'):
success = self.orchestrator.rl_agent.save_checkpoint(force_save=force)
if success:
self.orchestrator.model_states['dqn']['checkpoint_loaded'] = True
self.orchestrator.model_states['dqn']['checkpoint_filename'] = f"dqn_agent_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
logger.info("DQN checkpoint saved successfully")
return True
# Fallback: use improved model saver
from improved_model_saver import get_improved_model_saver
saver = get_improved_model_saver()
success = saver.save_model_safely(
self.orchestrator.rl_agent,
"dqn_agent",
"dqn",
metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
)
if success:
self.orchestrator.model_states['dqn']['checkpoint_loaded'] = True
self.orchestrator.model_states['dqn']['checkpoint_filename'] = "dqn_agent_latest"
logger.info("DQN checkpoint saved using fallback method")
return True
return False
except Exception as e:
logger.error(f"Failed to save DQN checkpoint: {e}")
return False
def _save_cnn_checkpoint(self, force: bool = True) -> bool:
"""Save CNN model checkpoint"""
try:
if hasattr(self.orchestrator.cnn_model, 'save_checkpoint'):
success = self.orchestrator.cnn_model.save_checkpoint(force_save=force)
if success:
self.orchestrator.model_states['cnn']['checkpoint_loaded'] = True
self.orchestrator.model_states['cnn']['checkpoint_filename'] = f"enhanced_cnn_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
logger.info("CNN checkpoint saved successfully")
return True
# Fallback: use improved model saver
from improved_model_saver import get_improved_model_saver
saver = get_improved_model_saver()
success = saver.save_model_safely(
self.orchestrator.cnn_model,
"enhanced_cnn",
"cnn",
metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
)
if success:
self.orchestrator.model_states['cnn']['checkpoint_loaded'] = True
self.orchestrator.model_states['cnn']['checkpoint_filename'] = "enhanced_cnn_latest"
logger.info("CNN checkpoint saved using fallback method")
return True
return False
except Exception as e:
logger.error(f"Failed to save CNN checkpoint: {e}")
return False
def _save_extrema_checkpoint(self, force: bool = True) -> bool:
"""Save Extrema Trainer checkpoint"""
try:
if hasattr(self.orchestrator.extrema_trainer, 'save_checkpoint'):
self.orchestrator.extrema_trainer.save_checkpoint(force_save=force)
self.orchestrator.model_states['extrema_trainer']['checkpoint_loaded'] = True
self.orchestrator.model_states['extrema_trainer']['checkpoint_filename'] = f"extrema_trainer_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
logger.info("Extrema Trainer checkpoint saved successfully")
return True
return False
except Exception as e:
logger.error(f"Failed to save Extrema Trainer checkpoint: {e}")
return False
def _save_cob_rl_checkpoint(self, force: bool = True) -> bool:
"""Save COB RL agent checkpoint"""
try:
# COB RL may have a different saving mechanism
from improved_model_saver import get_improved_model_saver
saver = get_improved_model_saver()
success = saver.save_model_safely(
self.orchestrator.cob_rl_agent,
"cob_rl",
"cob_rl",
metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
)
if success:
self.orchestrator.model_states['cob_rl']['checkpoint_loaded'] = True
self.orchestrator.model_states['cob_rl']['checkpoint_filename'] = "cob_rl_latest"
logger.info("COB RL checkpoint saved successfully")
return True
return False
except Exception as e:
logger.error(f"Failed to save COB RL checkpoint: {e}")
return False
def _save_transformer_checkpoint(self, force: bool = True) -> bool:
"""Save Transformer model checkpoint"""
try:
if hasattr(self.orchestrator.transformer_trainer, 'save_model'):
# Create a checkpoint file path
checkpoint_dir = Path("models/saved/transformer")
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = checkpoint_dir / f"transformer_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"
self.orchestrator.transformer_trainer.save_model(str(checkpoint_path))
self.orchestrator.model_states['transformer']['checkpoint_loaded'] = True
self.orchestrator.model_states['transformer']['checkpoint_filename'] = checkpoint_path.name
logger.info("Transformer checkpoint saved successfully")
return True
return False
except Exception as e:
logger.error(f"Failed to save Transformer checkpoint: {e}")
return False
def _save_decision_checkpoint(self, force: bool = True) -> bool:
"""Save Decision model checkpoint"""
try:
from improved_model_saver import get_improved_model_saver
saver = get_improved_model_saver()
success = saver.save_model_safely(
self.orchestrator.decision_model,
"decision",
"decision",
metadata={"saved_by": "checkpoint_saver", "timestamp": datetime.now().isoformat()}
)
if success:
self.orchestrator.model_states['decision']['checkpoint_loaded'] = True
self.orchestrator.model_states['decision']['checkpoint_filename'] = "decision_latest"
logger.info("Decision model checkpoint saved successfully")
return True
return False
except Exception as e:
logger.error(f"Failed to save Decision model checkpoint: {e}")
return False
def update_model_status_to_loaded(self, model_name: str):
"""Manually update a model's status to LOADED"""
if model_name in self.orchestrator.model_states:
self.orchestrator.model_states[model_name]['checkpoint_loaded'] = True
if not self.orchestrator.model_states[model_name].get('checkpoint_filename'):
self.orchestrator.model_states[model_name]['checkpoint_filename'] = f"{model_name}_manual_loaded"
logger.info(f"Updated {model_name} status to LOADED")
def force_all_models_to_loaded(self):
"""Force all existing models to show as LOADED"""
models_updated = []
for model_name in self.orchestrator.model_states.keys():
# Check if model actually exists
model_exists = False
if model_name == 'dqn' and hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
model_exists = True
elif model_name == 'cnn' and hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
model_exists = True
elif model_name == 'extrema_trainer' and hasattr(self.orchestrator, 'extrema_trainer') and self.orchestrator.extrema_trainer:
model_exists = True
# COB RL model removed - focusing on COB data quality first
elif model_name == 'transformer' and hasattr(self.orchestrator, 'transformer_model') and self.orchestrator.transformer_model:
model_exists = True
elif model_name == 'decision' and hasattr(self.orchestrator, 'decision_model') and self.orchestrator.decision_model:
model_exists = True
if model_exists:
self.update_model_status_to_loaded(model_name)
models_updated.append(model_name)
logger.info(f"Force-updated {len(models_updated)} models to LOADED status: {models_updated}")
return models_updated
def save_all_checkpoints_now(orchestrator):
"""Convenience function to save all checkpoints"""
saver = ModelCheckpointSaver(orchestrator)
results = saver.save_all_model_checkpoints(force=True)
print("Checkpoint saving results:")
for model_name, success in results.items():
status = "✅ SUCCESS" if success else "❌ FAILED"
print(f" {model_name}: {status}")
return results
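# Usage sketch (assumes an already-initialized TradingOrchestrator instance named
# `orchestrator` and that this module is importable as model_checkpoint_saver):
#
#   from model_checkpoint_saver import save_all_checkpoints_now
#
#   results = save_all_checkpoints_now(orchestrator)
#   failed = [name for name, ok in results.items() if not ok]
#   if failed:
#       print(f"Models still without checkpoints: {failed}")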


@@ -1,446 +0,0 @@
#!/usr/bin/env python3
"""
Unified Model Registry for Centralized Model Management
This module provides a unified interface for saving, loading, and managing
all machine learning models in the trading system. It consolidates model
storage from multiple locations into a single, organized structure.
"""
import os
import json
import torch
import logging
import pickle
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
from datetime import datetime
import hashlib
logger = logging.getLogger(__name__)
class ModelRegistry:
"""
Unified model registry for centralized model management.
Handles saving, loading, and organization of all ML models.
"""
def __init__(self, base_dir: str = "models"):
"""
Initialize the model registry.
Args:
base_dir: Base directory for model storage
"""
self.base_dir = Path(base_dir)
self.saved_dir = self.base_dir / "saved"
self.checkpoint_dir = self.base_dir / "checkpoints"
self.archive_dir = self.base_dir / "archive"
# Model type directories
self.model_dirs = {
'cnn': self.base_dir / "cnn",
'dqn': self.base_dir / "dqn",
'transformer': self.base_dir / "transformer",
'hybrid': self.base_dir / "hybrid"
}
# Ensure all directories exist
self._ensure_directories()
# Metadata tracking
self.metadata_file = self.base_dir / "registry_metadata.json"
self.metadata = self._load_metadata()
logger.info(f"Model Registry initialized at {self.base_dir}")
def _ensure_directories(self):
"""Ensure all required directories exist."""
directories = [
self.saved_dir,
self.checkpoint_dir,
self.archive_dir
]
# Add model type directories
for model_dir in self.model_dirs.values():
directories.extend([
model_dir / "saved",
model_dir / "checkpoints",
model_dir / "archive"
])
for directory in directories:
directory.mkdir(parents=True, exist_ok=True)
def _load_metadata(self) -> Dict[str, Any]:
"""Load registry metadata."""
if self.metadata_file.exists():
try:
with open(self.metadata_file, 'r') as f:
return json.load(f)
except Exception as e:
logger.warning(f"Failed to load metadata: {e}")
return {'models': {}, 'last_updated': datetime.now().isoformat()}
def _save_metadata(self):
"""Save registry metadata."""
self.metadata['last_updated'] = datetime.now().isoformat()
try:
with open(self.metadata_file, 'w') as f:
json.dump(self.metadata, f, indent=2)
except Exception as e:
logger.error(f"Failed to save metadata: {e}")
def save_model(self, model: Any, model_name: str, model_type: str = 'cnn',
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Save a model to the unified storage.
Args:
model: The model to save
model_name: Name of the model
model_type: Type of model (cnn, dqn, transformer, hybrid)
metadata: Additional metadata to save
Returns:
bool: True if successful, False otherwise
"""
try:
model_dir = self.model_dirs.get(model_type, self.saved_dir)
save_dir = model_dir / "saved"
# Generate filename with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f"{model_name}_{timestamp}.pt"
filepath = save_dir / filename
# Also save as latest
latest_filepath = save_dir / f"{model_name}_latest.pt"
# Save model
save_dict = {
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
'model_class': model.__class__.__name__,
'model_type': model_type,
'timestamp': timestamp,
'metadata': metadata or {}
}
torch.save(save_dict, filepath)
torch.save(save_dict, latest_filepath)
# Update metadata
if model_name not in self.metadata['models']:
self.metadata['models'][model_name] = {}
self.metadata['models'][model_name].update({
'type': model_type,
'latest_path': str(latest_filepath),
'last_saved': timestamp,
'save_count': self.metadata['models'][model_name].get('save_count', 0) + 1
})
self._save_metadata()
logger.info(f"Model {model_name} saved to {filepath}")
return True
except Exception as e:
logger.error(f"Failed to save model {model_name}: {e}")
return False
def load_model(self, model_name: str, model_type: str = 'cnn',
model_class: Optional[Any] = None) -> Optional[Any]:
"""
Load a model from the unified storage.
Args:
model_name: Name of the model to load
model_type: Type of model (cnn, dqn, transformer, hybrid)
model_class: Model class to instantiate (if needed)
Returns:
The loaded model or None if failed
"""
try:
model_dir = self.model_dirs.get(model_type, self.saved_dir)
save_dir = model_dir / "saved"
latest_filepath = save_dir / f"{model_name}_latest.pt"
if not latest_filepath.exists():
logger.warning(f"Model {model_name} not found at {latest_filepath}")
return None
# Load checkpoint
checkpoint = torch.load(latest_filepath, map_location='cpu')
# Instantiate model if class provided
if model_class is not None:
model = model_class()
model.load_state_dict(checkpoint['model_state_dict'])
else:
# Try to reconstruct model from state_dict
model = type('LoadedModel', (), {})()
model.state_dict = lambda: checkpoint['model_state_dict']
model.load_state_dict = lambda state_dict: None
logger.info(f"Model {model_name} loaded from {latest_filepath}")
return model
except Exception as e:
logger.error(f"Failed to load model {model_name}: {e}")
return None
def save_checkpoint(self, model: Any, model_name: str, model_type: str = 'cnn',
performance_score: float = 0.0,
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Save a model checkpoint.
Args:
model: The model to checkpoint
model_name: Name of the model
model_type: Type of model
performance_score: Performance score for this checkpoint
metadata: Additional metadata
Returns:
bool: True if successful, False otherwise
"""
try:
model_dir = self.model_dirs.get(model_type, self.checkpoint_dir)
checkpoint_dir = model_dir / "checkpoints"
# Generate checkpoint ID
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
checkpoint_id = f"{model_name}_{timestamp}_{performance_score:.4f}"
filepath = checkpoint_dir / f"{checkpoint_id}.pt"
# Save checkpoint
checkpoint_data = {
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else {},
'model_class': model.__class__.__name__,
'model_type': model_type,
'model_name': model_name,
'performance_score': performance_score,
'timestamp': timestamp,
'metadata': metadata or {}
}
torch.save(checkpoint_data, filepath)
# Update metadata
if model_name not in self.metadata['models']:
self.metadata['models'][model_name] = {}
if 'checkpoints' not in self.metadata['models'][model_name]:
self.metadata['models'][model_name]['checkpoints'] = []
checkpoint_info = {
'id': checkpoint_id,
'path': str(filepath),
'performance_score': performance_score,
'timestamp': timestamp
}
self.metadata['models'][model_name]['checkpoints'].append(checkpoint_info)
# Keep only top 5 checkpoints
checkpoints = self.metadata['models'][model_name]['checkpoints']
if len(checkpoints) > 5:
checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)
checkpoints_to_remove = checkpoints[5:]
for checkpoint in checkpoints_to_remove:
try:
os.remove(checkpoint['path'])
except:
pass
self.metadata['models'][model_name]['checkpoints'] = checkpoints[:5]
self._save_metadata()
logger.info(f"Checkpoint {checkpoint_id} saved with score {performance_score}")
return True
except Exception as e:
logger.error(f"Failed to save checkpoint for {model_name}: {e}")
return False
def load_best_checkpoint(self, model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
"""
Load the best checkpoint for a model.
Args:
model_name: Name of the model
model_type: Type of model
Returns:
Tuple of (checkpoint_path, checkpoint_data) or None
"""
try:
if model_name not in self.metadata['models']:
logger.warning(f"No metadata found for model {model_name}")
return None
checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
if not checkpoints:
logger.warning(f"No checkpoints found for model {model_name}")
return None
# Find best checkpoint by performance score
best_checkpoint = max(checkpoints, key=lambda x: x['performance_score'])
checkpoint_path = best_checkpoint['path']
if not os.path.exists(checkpoint_path):
logger.warning(f"Checkpoint file not found: {checkpoint_path}")
return None
checkpoint_data = torch.load(checkpoint_path, map_location='cpu')
logger.info(f"Best checkpoint loaded for {model_name}: {best_checkpoint['id']}")
return checkpoint_path, checkpoint_data
except Exception as e:
logger.error(f"Failed to load best checkpoint for {model_name}: {e}")
return None
def archive_model(self, model_name: str, model_type: str = 'cnn') -> bool:
"""
Archive a model by moving it to archive directory.
Args:
model_name: Name of the model to archive
model_type: Type of model
Returns:
bool: True if successful, False otherwise
"""
try:
model_dir = self.model_dirs.get(model_type, self.saved_dir)
save_dir = model_dir / "saved"
archive_dir = model_dir / "archive"
latest_filepath = save_dir / f"{model_name}_latest.pt"
if not latest_filepath.exists():
logger.warning(f"Model {model_name} not found to archive")
return False
# Move to archive with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
archive_filepath = archive_dir / f"{model_name}_archived_{timestamp}.pt"
os.rename(latest_filepath, archive_filepath)
logger.info(f"Model {model_name} archived to {archive_filepath}")
return True
except Exception as e:
logger.error(f"Failed to archive model {model_name}: {e}")
return False
def list_models(self, model_type: Optional[str] = None) -> Dict[str, Any]:
"""
List all models in the registry.
Args:
model_type: Filter by model type (optional)
Returns:
Dictionary of model information
"""
models_info = {}
for model_name, model_data in self.metadata['models'].items():
if model_type and model_data.get('type') != model_type:
continue
models_info[model_name] = {
'type': model_data.get('type'),
'last_saved': model_data.get('last_saved'),
'save_count': model_data.get('save_count', 0),
'checkpoint_count': len(model_data.get('checkpoints', [])),
'latest_path': model_data.get('latest_path')
}
return models_info
def cleanup_old_checkpoints(self, model_name: str, keep_count: int = 5) -> int:
"""
Clean up old checkpoints, keeping only the best ones.
Args:
model_name: Name of the model
keep_count: Number of checkpoints to keep
Returns:
Number of checkpoints removed
"""
if model_name not in self.metadata['models']:
return 0
checkpoints = self.metadata['models'][model_name].get('checkpoints', [])
if len(checkpoints) <= keep_count:
return 0
# Sort by performance score (descending)
checkpoints.sort(key=lambda x: x['performance_score'], reverse=True)
# Remove old checkpoints
removed_count = 0
for checkpoint in checkpoints[keep_count:]:
try:
os.remove(checkpoint['path'])
removed_count += 1
except:
pass
# Update metadata
self.metadata['models'][model_name]['checkpoints'] = checkpoints[:keep_count]
self._save_metadata()
logger.info(f"Cleaned up {removed_count} old checkpoints for {model_name}")
return removed_count
# Global registry instance
_registry_instance = None
def get_model_registry() -> ModelRegistry:
"""Get the global model registry instance."""
global _registry_instance
if _registry_instance is None:
_registry_instance = ModelRegistry()
return _registry_instance
def save_model(model: Any, model_name: str, model_type: str = 'cnn',
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Convenience function to save a model using the global registry.
"""
return get_model_registry().save_model(model, model_name, model_type, metadata)
def load_model(model_name: str, model_type: str = 'cnn',
model_class: Optional[Any] = None) -> Optional[Any]:
"""
Convenience function to load a model using the global registry.
"""
return get_model_registry().load_model(model_name, model_type, model_class)
def save_checkpoint(model: Any, model_name: str, model_type: str = 'cnn',
performance_score: float = 0.0,
metadata: Optional[Dict[str, Any]] = None) -> bool:
"""
Convenience function to save a checkpoint using the global registry.
"""
return get_model_registry().save_checkpoint(model, model_name, model_type, performance_score, metadata)
def load_best_checkpoint(model_name: str, model_type: str = 'cnn') -> Optional[Tuple[str, Any]]:
"""
Convenience function to load the best checkpoint using the global registry.
"""
return get_model_registry().load_best_checkpoint(model_name, model_type)
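# Minimal usage sketch of the convenience functions above (illustrative values; the
# nn.Linear model is a stand-in):
#
#   import torch.nn as nn
#   from utils.model_registry import save_model, save_checkpoint, load_best_checkpoint
#
#   model = nn.Linear(32, 3)  # stand-in for a real model
#   save_model(model, 'example_cnn', model_type='cnn', metadata={'note': 'illustrative'})
#   save_checkpoint(model, 'example_cnn', model_type='cnn', performance_score=87.5)
#   best = load_best_checkpoint('example_cnn', 'cnn')
#   if best:
#       path, data = best
#       print(f"best checkpoint scored {data['performance_score']} at {path}")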


@@ -1,109 +0,0 @@
"""
Models Module
Provides model registry and interfaces for the trading system.
This module acts as a bridge between the core system and the NN models.
"""
import logging
from typing import Dict, Any, Optional, List
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface
logger = logging.getLogger(__name__)
class ModelRegistry:
"""Registry for managing trading models"""
def __init__(self):
self.models: Dict[str, ModelInterface] = {}
self.model_performance: Dict[str, Dict[str, Any]] = {}
def register_model(self, model: ModelInterface):
"""Register a model in the registry"""
name = model.name
self.models[name] = model
self.model_performance[name] = {
'correct': 0,
'total': 0,
'accuracy': 0.0,
'last_used': None
}
logger.info(f"Registered model: {name}")
return True
def get_model(self, name: str) -> Optional[ModelInterface]:
"""Get a model by name"""
return self.models.get(name)
def get_all_models(self) -> Dict[str, ModelInterface]:
"""Get all registered models"""
return self.models.copy()
def update_performance(self, name: str, correct: bool):
"""Update model performance metrics"""
if name in self.model_performance:
self.model_performance[name]['total'] += 1
if correct:
self.model_performance[name]['correct'] += 1
self.model_performance[name]['accuracy'] = (
self.model_performance[name]['correct'] /
self.model_performance[name]['total']
)
def get_best_model(self, model_type: str = None) -> Optional[str]:
"""Get the best performing model"""
if not self.model_performance:
return None
best_model = None
best_accuracy = -1.0
for name, perf in self.model_performance.items():
if model_type and not name.lower().startswith(model_type.lower()):
continue
if perf['accuracy'] > best_accuracy:
best_accuracy = perf['accuracy']
best_model = name
return best_model
def unregister_model(self, name: str) -> bool:
"""Unregister a model from the registry"""
if name in self.models:
del self.models[name]
if name in self.model_performance:
del self.model_performance[name]
logger.info(f"Unregistered model: {name}")
return True
# Global model registry instance
_model_registry = ModelRegistry()
def get_model_registry() -> ModelRegistry:
"""Get the global model registry instance"""
return _model_registry
def register_model(model: ModelInterface):
"""Register a model in the global registry"""
return _model_registry.register_model(model)
def get_model(name: str) -> Optional[ModelInterface]:
"""Get a model from the global registry"""
return _model_registry.get_model(name)
def get_all_models() -> Dict[str, ModelInterface]:
"""Get all models from the global registry"""
return _model_registry.get_all_models()
# Export the interfaces
__all__ = [
'ModelRegistry',
'get_model_registry',
'register_model',
'get_model',
'get_all_models',
'ModelInterface',
'CNNModelInterface',
'RLAgentInterface',
'ExtremaTrainerInterface'
]
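# Minimal usage sketch of this in-memory registry (the _StubModel is a hypothetical
# stand-in; real models implement the interfaces from NN.models.model_interfaces, and
# the import assumes this file is models/__init__.py):
#
#   from models import register_model, get_model_registry
#
#   class _StubModel:
#       name = 'cnn_stub'  # the registry only requires a .name attribute to register
#
#   register_model(_StubModel())
#   registry = get_model_registry()
#   registry.update_performance('cnn_stub', correct=True)
#   print(registry.get_best_model(model_type='cnn'))  # -> 'cnn_stub'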


@@ -346,11 +346,58 @@ class TradingOrchestrator:
logger.warning("Extrema trainer not available")
self.extrema_trainer = None
# COB RL Model REMOVED - See COB_MODEL_ARCHITECTURE_DOCUMENTATION.md
# Reason: Need quality COB data first before evaluating massive parameter benefit
# Will recreate improved version when COB data pipeline is fixed
logger.info("COB RL model removed - focusing on COB data quality first")
self.cob_rl_agent = None
# Initialize COB RL Model - UNIFIED with ModelManager
try:
from NN.models.cob_rl_model import COBRLModelInterface
# Initialize COB RL model using unified approach
self.cob_rl_agent = COBRLModelInterface(
model_checkpoint_dir="@checkpoints/cob_rl",
device='cuda' if torch.cuda.is_available() else 'cpu'
)
# Add COB RL to model states tracking
self.model_states['cob_rl'] = {
'initial_loss': None,
'current_loss': None,
'best_loss': None,
'checkpoint_loaded': False
}
# Load best checkpoint using unified ModelManager
checkpoint_loaded = False
try:
from NN.training.model_manager import load_best_checkpoint
result = load_best_checkpoint("cob_rl_agent")
if result:
file_path, metadata = result
self.model_states['cob_rl']['initial_loss'] = metadata.loss
self.model_states['cob_rl']['current_loss'] = metadata.loss
self.model_states['cob_rl']['best_loss'] = metadata.loss
self.model_states['cob_rl']['checkpoint_loaded'] = True
self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
checkpoint_loaded = True
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "N/A"
logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
except Exception as e:
logger.warning(f"Error loading COB RL checkpoint: {e}")
if not checkpoint_loaded:
# New model - no synthetic data, start fresh
self.model_states['cob_rl']['initial_loss'] = None
self.model_states['cob_rl']['current_loss'] = None
self.model_states['cob_rl']['best_loss'] = None
self.model_states['cob_rl']['checkpoint_filename'] = 'none (fresh start)'
logger.info("COB RL starting fresh - no checkpoint found")
logger.info("COB RL Agent initialized and integrated with unified ModelManager")
logger.info(" - Uses @checkpoints/ directory structure")
logger.info(" - Follows same load/save/checkpoint flow as other models")
logger.info(" - Integrated with enhanced real-time training system")
except ImportError as e:
logger.warning(f"COB RL Model not available: {e}")
self.cob_rl_agent = None
# Initialize TRANSFORMER Model
try:


@@ -34,7 +34,8 @@ import os
# Local imports
from .cob_integration import COBIntegration
from .trading_executor import TradingExecutor
from NN.models.cob_rl_model import MassiveRLNetwork, COBRLModelInterface
# UNIFIED: Import only the interface, models come from orchestrator
from NN.models.cob_rl_model import COBRLModelInterface
logger = logging.getLogger(__name__)
@@ -98,51 +99,44 @@ class RealtimeRLCOBTrader:
Real-time RL trader using COB data with comprehensive subscriber system
"""
def __init__(self,
def __init__(self,
symbols: Optional[List[str]] = None,
trading_executor: Optional[TradingExecutor] = None,
model_checkpoint_dir: str = "models/realtime_rl_cob",
orchestrator: Any = None, # UNIFIED: Use orchestrator's models
inference_interval_ms: int = 200,
min_confidence_threshold: float = 0.35, # Lowered from 0.7 for more aggressive trading
required_confident_predictions: int = 3,
checkpoint_manager: Any = None):
required_confident_predictions: int = 3):
self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
self.trading_executor = trading_executor
self.model_checkpoint_dir = model_checkpoint_dir
self.orchestrator = orchestrator # UNIFIED: Use orchestrator's models
self.inference_interval_ms = inference_interval_ms
self.min_confidence_threshold = min_confidence_threshold
self.required_confident_predictions = required_confident_predictions
# Initialize ModelManager (either provided or get global instance)
if checkpoint_manager is None:
from NN.training.model_manager import create_model_manager
self.checkpoint_manager = create_model_manager()
# UNIFIED: Use orchestrator's ModelManager instead of creating our own
if self.orchestrator and hasattr(self.orchestrator, 'model_manager'):
self.model_manager = self.orchestrator.model_manager
else:
self.checkpoint_manager = checkpoint_manager
from NN.training.model_manager import create_model_manager
self.model_manager = create_model_manager()
# Track start time for training duration calculation
self.start_time = datetime.now() # Initialize start_time
# Setup device
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {self.device}")
# Initialize models for each symbol
self.models: Dict[str, MassiveRLNetwork] = {}
self.optimizers: Dict[str, optim.AdamW] = {}
self.scalers: Dict[str, torch.cuda.amp.GradScaler] = {}
for symbol in self.symbols:
model = MassiveRLNetwork().to(self.device)
self.models[symbol] = model
self.optimizers[symbol] = optim.AdamW(
model.parameters(),
lr=1e-5, # Low learning rate for stability
weight_decay=1e-6,
betas=(0.9, 0.999)
)
self.scalers[symbol] = torch.cuda.amp.GradScaler()
self.start_time = datetime.now()
# UNIFIED: Use orchestrator's COB RL model
if not self.orchestrator or not hasattr(self.orchestrator, 'cob_rl_agent') or not self.orchestrator.cob_rl_agent:
raise ValueError("RealtimeRLCOBTrader requires orchestrator with COB RL model. Please initialize TradingOrchestrator first.")
# Use orchestrator's unified COB RL model
self.cob_rl_model = self.orchestrator.cob_rl_agent
self.device = self.orchestrator.cob_rl_agent.device if hasattr(self.orchestrator.cob_rl_agent, 'device') else torch.device('cpu')
logger.info(f"Using orchestrator's unified COB RL model on device: {self.device}")
# Create unified model references for all symbols
self.models = {symbol: self.cob_rl_model.model for symbol in self.symbols}
self.optimizers = {symbol: self.cob_rl_model.optimizer for symbol in self.symbols}
self.scalers = {symbol: self.cob_rl_model.scaler for symbol in self.symbols}
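# Design note as a runnable sketch, not part of this commit: the comprehensions above
# alias one shared network/optimizer/scaler under every symbol key, so a training step
# for any symbol updates the same weights. The aliasing pattern in isolation:
import torch.nn as nn

shared_net = nn.Linear(4, 2)                        # stand-in for cob_rl_model.model
symbols = ['BTC/USDT', 'ETH/USDT']
models = {symbol: shared_net for symbol in symbols}
assert models['BTC/USDT'] is models['ETH/USDT']     # one object behind both keys
models['BTC/USDT'].weight.data.zero_()              # mutate through one key...
assert models['ETH/USDT'].weight.abs().sum() == 0   # ...and observe it through the other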
# Subscriber system for real-time events
self.prediction_subscribers: List[Callable[[PredictionResult], None]] = []
@@ -906,56 +900,67 @@ class RealtimeRLCOBTrader:
return reward
async def _train_batch(self, symbol: str, predictions: List[PredictionResult]) -> float:
"""Train model on a batch of predictions"""
"""Train model on a batch of predictions using unified approach"""
try:
model = self.models[symbol]
optimizer = self.optimizers[symbol]
scaler = self.scalers[symbol]
# UNIFIED: Always use orchestrator's COB RL model
return self._train_batch_unified(predictions)
except Exception as e:
logger.error(f"Error training batch for {symbol}: {e}")
return 0.0
def _train_batch_unified(self, predictions: List[PredictionResult]) -> float:
"""Train using unified COB RL model from orchestrator"""
try:
model = self.cob_rl_model.model
optimizer = self.cob_rl_model.optimizer
scaler = self.cob_rl_model.scaler
model.train()
optimizer.zero_grad()
# Prepare batch data
features = torch.stack([
torch.from_numpy(p.features) for p in predictions
]).to(self.device)
# Targets
direction_targets = torch.tensor([
p.actual_direction for p in predictions
], dtype=torch.long).to(self.device)
value_targets = torch.tensor([
p.reward for p in predictions
], dtype=torch.float32).to(self.device)
# Forward pass with mixed precision
with torch.cuda.amp.autocast():
outputs = model(features)
# Calculate losses
direction_loss = nn.CrossEntropyLoss()(outputs['price_logits'], direction_targets)
value_loss = nn.MSELoss()(outputs['value'].squeeze(), value_targets)
# Confidence loss (encourage high confidence for correct predictions)
correct_predictions = (torch.argmax(outputs['price_logits'], dim=1) == direction_targets).float()
confidence_loss = nn.BCELoss()(outputs['confidence'].squeeze(), correct_predictions)
# Combined loss
total_loss = direction_loss + 0.5 * value_loss + 0.3 * confidence_loss
# Backward pass with gradient scaling
scaler.scale(total_loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
return total_loss.item()
except Exception as e:
logger.error(f"Error training batch for {symbol}: {e}")
logger.error(f"Error in unified training batch: {e}")
return 0.0
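# Assumption sketch, not part of this commit: _train_batch_unified only touches three
# fields of PredictionResult -- features, actual_direction and reward. The real dataclass
# lives elsewhere in this module; a minimal stand-in that would satisfy the batch code:
from dataclasses import dataclass
import numpy as np

@dataclass
class MinimalPrediction:               # hypothetical stand-in, not the project's class
    features: np.ndarray               # model input, consumed via torch.from_numpy(...)
    actual_direction: int              # class index target for the price_logits head
    reward: float                      # regression target for the value head

batch = [MinimalPrediction(np.zeros(32, dtype=np.float32), 1, 0.5) for _ in range(8)]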
async def _train_on_trade_execution(self, symbol: str, signals: List[PredictionResult],
action: str, price: float):
@@ -1015,68 +1020,99 @@ class RealtimeRLCOBTrader:
await asyncio.sleep(60)
def _save_models(self):
"""Save all models to disk using CheckpointManager"""
"""Save models using unified ModelManager approach"""
try:
for symbol in self.symbols:
model_name = f"cob_rl_{symbol.replace('/', '_').lower()}" # Standardize model name for CheckpointManager
# Prepare performance metrics for CheckpointManager
if self.cob_rl_model:
# UNIFIED: Use orchestrator's COB RL model with ModelManager
performance_metrics = {
'loss': self.training_stats[symbol].get('average_loss', 0.0),
'reward': self.training_stats[symbol].get('average_reward', 0.0), # Assuming average_reward is tracked
'accuracy': self.training_stats[symbol].get('average_accuracy', 0.0), # Assuming average_accuracy is tracked
'loss': self._get_average_loss(),
'reward': self._get_average_reward(),
'accuracy': self._get_average_accuracy(),
}
if self.trading_executor: # Add check for trading_executor
daily_stats = self.trading_executor.get_daily_stats()
performance_metrics['pnl'] = daily_stats.get('total_pnl', 0.0) # Example, get actual pnl
performance_metrics['training_samples'] = self.training_stats[symbol].get('total_training_steps', 0)
# Prepare training metadata for CheckpointManager
# Add P&L if trading executor is available
if self.trading_executor and hasattr(self.trading_executor, 'get_daily_stats'):
try:
daily_stats = self.trading_executor.get_daily_stats()
performance_metrics['pnl'] = daily_stats.get('total_pnl', 0.0)
except Exception:
performance_metrics['pnl'] = 0.0
performance_metrics['training_samples'] = sum(
stats.get('total_training_steps', 0) for stats in self.training_stats.values()
)
# Prepare training metadata
training_metadata = {
'total_parameters': sum(p.numel() for p in self.models[symbol].parameters()),
'epoch': self.training_stats[symbol].get('total_training_steps', 0), # Using total_training_steps as pseudo-epoch
'total_parameters': sum(p.numel() for p in self.cob_rl_model.model.parameters()),
'epoch': max((stats.get('total_training_steps', 0) for stats in self.training_stats.values()), default=0),
'training_time_hours': (datetime.now() - self.start_time).total_seconds() / 3600
}
self.checkpoint_manager.save_checkpoint(
model=self.models[symbol],
model_name=model_name,
model_type='COB_RL', # Specify model type
# Save using unified ModelManager
self.model_manager.save_checkpoint(
model=self.cob_rl_model.model,
model_name="cob_rl_agent",
model_type='COB_RL',
performance_metrics=performance_metrics,
training_metadata=training_metadata
)
logger.debug(f"Saved model for {symbol}")
logger.info("COB RL model saved using unified ModelManager")
else:
# This should not happen with proper initialization
logger.error("Unified COB RL model not available - check orchestrator initialization")
except Exception as e:
logger.error(f"Error saving models: {e}")
def _load_models(self):
"""Load existing models from disk using CheckpointManager"""
"""Load models using unified ModelManager approach"""
try:
for symbol in self.symbols:
model_name = f"cob_rl_{symbol.replace('/', '_').lower()}" # Standardize model name for CheckpointManager
loaded_checkpoint = self.checkpoint_manager.load_best_checkpoint(model_name)
if self.cob_rl_model:
# UNIFIED: Load using ModelManager
loaded_checkpoint = self.model_manager.load_best_checkpoint("cob_rl_agent")
if loaded_checkpoint:
model_path, metadata = loaded_checkpoint
checkpoint = torch.load(model_path, map_location=self.device)
self.models[symbol].load_state_dict(checkpoint['model_state_dict'])
self.optimizers[symbol].load_state_dict(checkpoint['optimizer_state_dict'])
if 'training_stats' in checkpoint:
self.training_stats[symbol].update(checkpoint['training_stats'])
if 'inference_stats' in checkpoint:
self.inference_stats[symbol].update(checkpoint['inference_stats'])
logger.info(f"Loaded existing model for {symbol} from checkpoint: {metadata.checkpoint_id}")
self.cob_rl_model.model.load_state_dict(checkpoint['model_state_dict'])
self.cob_rl_model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Update training stats for all symbols with loaded data
for symbol in self.symbols:
if 'training_stats' in checkpoint:
self.training_stats[symbol].update(checkpoint['training_stats'])
if 'inference_stats' in checkpoint:
self.inference_stats[symbol].update(checkpoint['inference_stats'])
logger.info(f"Loaded unified COB RL model from checkpoint: {metadata.checkpoint_id}")
else:
logger.info(f"No existing model found for {symbol} via CheckpointManager, starting fresh.")
logger.info("No existing COB RL model found via ModelManager, starting fresh.")
else:
# This should not happen with proper initialization
logger.error("Unified COB RL model not available - check orchestrator initialization")
except Exception as e:
logger.error(f"Error loading models: {e}")
def _get_average_loss(self) -> float:
"""Get average loss across all symbols"""
losses = [stats.get('average_loss', 0.0) for stats in self.training_stats.values() if stats.get('average_loss') is not None]
return sum(losses) / len(losses) if losses else 0.0
def _get_average_reward(self) -> float:
"""Get average reward across all symbols"""
rewards = [stats.get('average_reward', 0.0) for stats in self.training_stats.values() if stats.get('average_reward') is not None]
return sum(rewards) / len(rewards) if rewards else 0.0
def _get_average_accuracy(self) -> float:
"""Get average accuracy across all symbols"""
accuracies = [stats.get('average_accuracy', 0.0) for stats in self.training_stats.values() if stats.get('average_accuracy') is not None]
return sum(accuracies) / len(accuracies) if accuracies else 0.0
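# Refactor sketch, not part of this commit: the three helpers above share one pattern
# and could be collapsed into a single metric averager. Shown as an option only.
from typing import Any, Dict

def average_metric(training_stats: Dict[str, Dict[str, Any]], key: str) -> float:
    values = [s.get(key, 0.0) for s in training_stats.values() if s.get(key) is not None]
    return sum(values) / len(values) if values else 0.0

# e.g. average_metric(self.training_stats, 'average_loss')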
def get_performance_stats(self) -> Dict[str, Any]:
"""Get comprehensive performance statistics"""
@@ -1119,36 +1155,49 @@ class RealtimeRLCOBTrader:
# Example usage
async def main():
"""Example usage of RealtimeRLCOBTrader"""
"""Example usage of unified RealtimeRLCOBTrader"""
from ..core.orchestrator import TradingOrchestrator
from ..core.trading_executor import TradingExecutor
# Initialize orchestrator (which now includes unified COB RL model)
orchestrator = TradingOrchestrator()
# Initialize trading executor (simulation mode)
trading_executor = TradingExecutor()
# Initialize real-time RL trader
# Initialize real-time RL trader with unified orchestrator
trader = RealtimeRLCOBTrader(
symbols=['BTC/USDT', 'ETH/USDT'],
trading_executor=trading_executor,
orchestrator=orchestrator, # UNIFIED: Use orchestrator's models
inference_interval_ms=200,
min_confidence_threshold=0.7,
required_confident_predictions=3
)
try:
# Start the trader
# Start the orchestrator first (initializes all models)
await orchestrator.start()
# Start the trader (uses orchestrator's unified COB RL model)
await trader.start()
# Run for demonstration
logger.info("Real-time RL COB Trader running...")
logger.info("Real-time RL COB Trader running with unified orchestrator...")
await asyncio.sleep(300) # Run for 5 minutes
# Print performance stats
stats = trader.get_performance_stats()
logger.info(f"Performance stats: {json.dumps(stats, indent=2, default=str)}")
# Print performance stats from both systems
orchestrator_stats = orchestrator.get_model_stats()
trader_stats = trader.get_performance_stats()
logger.info("=== ORCHESTRATOR STATS ===")
logger.info(f"Model stats: {json.dumps(orchestrator_stats, indent=2, default=str)}")
logger.info("=== TRADER STATS ===")
logger.info(f"Performance stats: {json.dumps(trader_stats, indent=2, default=str)}")
finally:
# Stop the trader
# Stop both systems
await trader.stop()
await orchestrator.stop()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

View File

@@ -168,8 +168,8 @@ def start_web_ui(port=8051):
except ImportError:
model_registry = {}
# Initialize checkpoint management for dashboard
dashboard_checkpoint_manager = get_checkpoint_manager()
# Initialize unified model management for dashboard
dashboard_checkpoint_manager = create_model_manager()
dashboard_training_integration = get_training_integration()
# Create unified orchestrator for the dashboard
@@ -206,8 +206,8 @@ async def start_training_loop(orchestrator, trading_executor):
logger.info("STARTING ENHANCED TRAINING LOOP WITH COB INTEGRATION")
logger.info("=" * 70)
# Initialize checkpoint management for training loop
checkpoint_manager = get_checkpoint_manager()
# Initialize unified model management for training loop
checkpoint_manager = create_model_manager()
training_integration = get_training_integration()
# Training statistics for checkpoint management
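
For context, a minimal sketch of the initialization both of these hunks converge on; the assumption is that create_model_manager() returns an object exposing the same checkpoint surface the old get_checkpoint_manager() result had, so the surrounding code keeps working unchanged.

from NN.training.model_manager import create_model_manager

checkpoint_manager = create_model_manager()   # unified manager, drop-in for get_checkpoint_manager()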

View File

@@ -6261,7 +6261,7 @@ class CleanTradingDashboard:
# Save checkpoint after training
if loss_count > 0:
try:
from utils.checkpoint_manager import save_checkpoint
from NN.training.model_manager import save_checkpoint
avg_loss = total_loss / loss_count
# Prepare checkpoint data
@@ -6390,7 +6390,7 @@ class CleanTradingDashboard:
# Save checkpoint after training
if loss_count > 0:
try:
from utils.checkpoint_manager import save_checkpoint
from NN.training.model_manager import save_checkpoint
avg_loss = total_loss / loss_count
# Prepare checkpoint data
@@ -6878,7 +6878,7 @@ class CleanTradingDashboard:
# Save checkpoint after training
if training_samples > 0:
try:
from utils.checkpoint_manager import save_checkpoint
from NN.training.model_manager import save_checkpoint
avg_loss = total_loss / loss_count if loss_count > 0 else 0.356
# Prepare checkpoint data
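
All three dashboard hunks follow the same save-after-training shape, sketched below with hypothetical helper and argument names; the keyword arguments mirror the ModelManager call shown earlier in this commit, and the module-level save_checkpoint signature is assumed to match it.

from NN.training.model_manager import save_checkpoint

def save_after_training(model, model_name: str, total_loss: float, loss_count: int) -> None:
    """Only checkpoint once at least one loss value has been recorded."""
    if loss_count <= 0:
        return
    save_checkpoint(
        model=model,
        model_name=model_name,
        model_type='dashboard',                          # hypothetical type label
        performance_metrics={'loss': total_loss / loss_count},
    )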

View File

@@ -443,14 +443,20 @@ class DashboardComponentManager:
ask_levels = [center_bucket + i * bucket_size for i in range(1, num_levels + 1)]
bid_levels = [center_bucket - i * bucket_size for i in range(num_levels)]
# Debug: Log how many orders we have to work with
print(f"DEBUG COB: {symbol} - Processing {len(bids)} bids, {len(asks)} asks")
print(f"DEBUG COB: Mid price: ${mid_price:.2f}, Bucket size: ${bucket_size}")
print(f"DEBUG COB: Bid buckets: {len(bid_buckets)}, Ask buckets: {len(ask_buckets)}")
if bid_buckets:
print(f"DEBUG COB: Bid price range: ${min(bid_buckets.keys()):.2f} - ${max(bid_buckets.keys()):.2f}")
if ask_buckets:
print(f"DEBUG COB: Ask price range: ${min(ask_buckets.keys()):.2f} - ${max(ask_buckets.keys()):.2f}")
# Debug: Combined log for COB ladder panel
print(
f"DEBUG COB: {symbol} - {len(bids)} bids, {len(asks)} asks | "
f"Mid price: ${mid_price:.2f}, ${bucket_size} buckets | "
f"Bid buckets: {len(bid_buckets)}, Ask buckets: {len(ask_buckets)}"
+ (
f" | Bid range: ${min(bid_buckets.keys()):.2f} - ${max(bid_buckets.keys()):.2f}"
if bid_buckets else ""
)
+ (
f" | Ask range: ${min(ask_buckets.keys()):.2f} - ${max(ask_buckets.keys()):.2f}"
if ask_buckets else ""
)
)
def create_bookmap_row(price, bid_data, ask_data, max_vol):
"""Create a Bookmap-style row with horizontal bars extending from center"""