checkpoint manager

Dobromir Popov
2025-07-23 21:40:04 +03:00
parent bab39fa68f
commit 45a62443a0
9 changed files with 1587 additions and 709 deletions

View File

@ -0,0 +1,276 @@
"""
CNN Dashboard Integration
This module integrates the EnhancedCNN model with the dashboard, providing real-time
training and visualization of model predictions.
"""
import logging
import threading
import time
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
import os
import json
from .enhanced_cnn_adapter import EnhancedCNNAdapter
from .data_models import BaseDataInput, ModelOutput, create_model_output
from utils.training_integration import get_training_integration
logger = logging.getLogger(__name__)
class CNNDashboardIntegration:
"""
Integrates the EnhancedCNN model with the dashboard
This class:
1. Loads and initializes the CNN model
2. Processes real-time data for model inference
3. Manages continuous training of the model
4. Provides visualization data for the dashboard
"""
def __init__(self, data_provider=None, checkpoint_dir: str = "models/enhanced_cnn"):
"""
Initialize the CNN dashboard integration
Args:
data_provider: Data provider instance
checkpoint_dir: Directory to save checkpoints to
"""
self.data_provider = data_provider
self.checkpoint_dir = checkpoint_dir
self.cnn_adapter = None
self.training_thread = None
self.training_active = False
self.training_interval = 60 # Train every 60 seconds
self.training_samples = []
self.max_training_samples = 1000
self.last_training_time = 0
self.last_predictions = {}
self.performance_metrics = {}
self.model_name = "enhanced_cnn_v1"
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
# Initialize CNN adapter
self._initialize_cnn_adapter()
logger.info(f"CNNDashboardIntegration initialized with checkpoint_dir: {checkpoint_dir}")
def _initialize_cnn_adapter(self):
"""Initialize the CNN adapter"""
try:
# Import here to avoid circular imports
from .enhanced_cnn_adapter import EnhancedCNNAdapter
# Create CNN adapter
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=self.checkpoint_dir)
# Load best checkpoint if available
self.cnn_adapter.load_best_checkpoint()
logger.info("CNN adapter initialized successfully")
except Exception as e:
logger.error(f"Error initializing CNN adapter: {e}")
self.cnn_adapter = None
def start_training_thread(self):
"""Start the training thread"""
if self.training_thread is not None and self.training_thread.is_alive():
logger.info("Training thread already running")
return
self.training_active = True
self.training_thread = threading.Thread(target=self._training_loop, daemon=True)
self.training_thread.start()
logger.info("CNN training thread started")
def stop_training_thread(self):
"""Stop the training thread"""
self.training_active = False
if self.training_thread is not None:
self.training_thread.join(timeout=5)
self.training_thread = None
logger.info("CNN training thread stopped")
def _training_loop(self):
"""Training loop for continuous model training"""
while self.training_active:
try:
# Check if it's time to train
current_time = time.time()
if current_time - self.last_training_time >= self.training_interval and len(self.training_samples) >= 10:
logger.info(f"Training CNN model with {len(self.training_samples)} samples")
# Train model
if self.cnn_adapter is not None:
metrics = self.cnn_adapter.train(epochs=1)
# Update performance metrics
self.performance_metrics = {
'loss': metrics.get('loss', 0.0),
'accuracy': metrics.get('accuracy', 0.0),
'samples': metrics.get('samples', 0),
'last_training': datetime.now().isoformat()
}
# Log training metrics
logger.info(f"CNN training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")
# Update last training time
self.last_training_time = current_time
# Sleep to avoid high CPU usage
time.sleep(1)
except Exception as e:
logger.error(f"Error in CNN training loop: {e}")
time.sleep(5) # Sleep longer on error
def process_data(self, symbol: str, base_data: BaseDataInput) -> Optional[ModelOutput]:
"""
Process data for model inference and training
Args:
symbol: Trading symbol
base_data: Standardized input data
Returns:
Optional[ModelOutput]: Model output, or None if processing failed
"""
try:
if self.cnn_adapter is None:
logger.warning("CNN adapter not initialized")
return None
# Make prediction
model_output = self.cnn_adapter.predict(base_data)
# Store prediction
self.last_predictions[symbol] = model_output
# Store model output in data provider
if self.data_provider is not None:
self.data_provider.store_model_output(model_output)
return model_output
except Exception as e:
logger.error(f"Error processing data for CNN model: {e}")
return None
def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
"""
Add a training sample
Args:
base_data: Standardized input data
actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
reward: Reward received for the action
"""
try:
if self.cnn_adapter is None:
logger.warning("CNN adapter not initialized")
return
# Add training sample to CNN adapter
self.cnn_adapter.add_training_sample(base_data, actual_action, reward)
# Add to local training samples
self.training_samples.append((base_data.symbol, actual_action, reward))
# Limit training samples
if len(self.training_samples) > self.max_training_samples:
self.training_samples = self.training_samples[-self.max_training_samples:]
logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")
except Exception as e:
logger.error(f"Error adding training sample: {e}")
def get_performance_metrics(self) -> Dict[str, Any]:
"""
Get performance metrics
Returns:
Dict[str, Any]: Performance metrics
"""
metrics = self.performance_metrics.copy()
# Add additional metrics
metrics['training_samples'] = len(self.training_samples)
metrics['model_name'] = self.model_name
# Add last prediction metrics
if self.last_predictions:
for symbol, prediction in self.last_predictions.items():
metrics[f'{symbol}_last_action'] = prediction.predictions.get('action', 'UNKNOWN')
metrics[f'{symbol}_last_confidence'] = prediction.confidence
return metrics
def get_visualization_data(self, symbol: str) -> Dict[str, Any]:
"""
Get visualization data for the dashboard
Args:
symbol: Trading symbol
Returns:
Dict[str, Any]: Visualization data
"""
data = {
'model_name': self.model_name,
'symbol': symbol,
'timestamp': datetime.now().isoformat(),
'performance_metrics': self.get_performance_metrics()
}
# Add last prediction
if symbol in self.last_predictions:
prediction = self.last_predictions[symbol]
data['last_prediction'] = {
'action': prediction.predictions.get('action', 'UNKNOWN'),
'confidence': prediction.confidence,
'timestamp': prediction.timestamp.isoformat(),
'buy_probability': prediction.predictions.get('buy_probability', 0.0),
'sell_probability': prediction.predictions.get('sell_probability', 0.0),
'hold_probability': prediction.predictions.get('hold_probability', 0.0)
}
# Add training samples summary
symbol_samples = [s for s in self.training_samples if s[0] == symbol]
data['training_samples'] = {
'total': len(symbol_samples),
'buy': len([s for s in symbol_samples if s[1] == 'BUY']),
'sell': len([s for s in symbol_samples if s[1] == 'SELL']),
'hold': len([s for s in symbol_samples if s[1] == 'HOLD']),
'avg_reward': sum(s[2] for s in symbol_samples) / len(symbol_samples) if symbol_samples else 0.0
}
return data
# Global CNN dashboard integration instance
_cnn_dashboard_integration = None
def get_cnn_dashboard_integration(data_provider=None) -> CNNDashboardIntegration:
"""
Get the global CNN dashboard integration instance
Args:
data_provider: Data provider instance
Returns:
CNNDashboardIntegration: Global CNN dashboard integration instance
"""
global _cnn_dashboard_integration
if _cnn_dashboard_integration is None:
_cnn_dashboard_integration = CNNDashboardIntegration(data_provider=data_provider)
return _cnn_dashboard_integration
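
For orientation, a minimal usage sketch for this integration; the module path core.cnn_dashboard_integration is an assumption inferred from the relative imports above, and the provider calls mirror the test scripts later in this commit:

# Usage sketch; module path and the provider's get_base_data_input() call are
# assumptions taken from other files in this commit, not guaranteed by it
from core.standardized_data_provider import StandardizedDataProvider
from core.cnn_dashboard_integration import get_cnn_dashboard_integration

provider = StandardizedDataProvider(symbols=['ETH/USDT'], timeframes=['1s', '1m', '1h', '1d'])
integration = get_cnn_dashboard_integration(data_provider=provider)
integration.start_training_thread()  # background training roughly every 60 seconds

base_data = provider.get_base_data_input('ETH/USDT')
if base_data is not None:
    output = integration.process_data('ETH/USDT', base_data)  # inference + caching
    if output is not None:
        # feed the realized reward back in once the trade outcome is known
        integration.add_training_sample(base_data, actual_action='BUY', reward=0.05)

print(integration.get_visualization_data('ETH/USDT'))  # dict ready for dashboard rendering
integration.stop_training_thread()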

View File

@ -1467,12 +1467,10 @@ class DataProvider:
        # Update COB data cache for distribution
        binance_symbol = symbol.replace('/', '').upper()
        if binance_symbol not in self.cob_data_cache or self.cob_data_cache[binance_symbol] is None:
-           from collections import deque
            self.cob_data_cache[binance_symbol] = deque(maxlen=300)
        # Ensure the deque is properly initialized
        if not isinstance(self.cob_data_cache[binance_symbol], deque):
-           from collections import deque
            self.cob_data_cache[binance_symbol] = deque(maxlen=300)
        self.cob_data_cache[binance_symbol].append({
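
The removed inline imports presume deque is already imported once at module scope in this file; a minimal sketch of the pattern the hunk relies on (the standalone function and names here are illustrative assumptions):

# Sketch of the module-level import this hunk presumes
from collections import deque

cob_data_cache = {}

def ensure_cache(binance_symbol: str) -> None:
    # Re-create the entry if it is missing or of the wrong type
    if not isinstance(cob_data_cache.get(binance_symbol), deque):
        cob_data_cache[binance_symbol] = deque(maxlen=300)  # keep the last 300 snapshots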

View File

@ -0,0 +1,430 @@
"""
Enhanced CNN Adapter for Standardized Input Format
This module provides an adapter for the EnhancedCNN model to work with the standardized
BaseDataInput format, enabling seamless integration with the multi-modal trading system.
"""
import torch
import numpy as np
import logging
import os
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any, Union
from threading import Lock
from .data_models import BaseDataInput, ModelOutput, create_model_output
from NN.models.enhanced_cnn import EnhancedCNN
logger = logging.getLogger(__name__)
class EnhancedCNNAdapter:
"""
Adapter for EnhancedCNN model to work with standardized BaseDataInput format
This adapter:
1. Converts BaseDataInput to the format expected by EnhancedCNN
2. Processes model outputs to create standardized ModelOutput
3. Manages model training with collected data
4. Handles checkpoint management
"""
def __init__(self, model_path: str = None, checkpoint_dir: str = "models/enhanced_cnn"):
"""
Initialize the EnhancedCNN adapter
Args:
model_path: Path to load model from, if None a new model is created
checkpoint_dir: Directory to save checkpoints to
"""
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model = None
self.model_path = model_path
self.checkpoint_dir = checkpoint_dir
self.training_lock = Lock()
self.training_data = []
self.max_training_samples = 10000
self.batch_size = 32
self.learning_rate = 0.0001
self.model_name = "enhanced_cnn_v1"
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
# Initialize model
self._initialize_model()
logger.info(f"EnhancedCNNAdapter initialized with device: {self.device}")
def _initialize_model(self):
"""Initialize the EnhancedCNN model"""
try:
# Calculate input shape based on BaseDataInput structure
# OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
# BTC OHLCV: 300 frames x 5 features = 1500 features
# COB: ±20 buckets x 4 metrics = 160 features
# MA: 4 timeframes x 10 buckets = 40 features
# Technical indicators: 100 features
# Last predictions: 50 features
# Total: 7850 features
input_shape = 7850
n_actions = 3 # BUY, SELL, HOLD
# Create model
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
self.model.to(self.device)
# Load model if path is provided
if self.model_path:
success = self.model.load(self.model_path)
if success:
logger.info(f"Model loaded from {self.model_path}")
else:
logger.warning(f"Failed to load model from {self.model_path}, using new model")
else:
logger.info("No model path provided, using new model")
except Exception as e:
logger.error(f"Error initializing EnhancedCNN model: {e}")
raise
def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
"""
Convert BaseDataInput to feature vector for EnhancedCNN
Args:
base_data: Standardized input data
Returns:
torch.Tensor: Feature vector for EnhancedCNN
"""
try:
# Use the get_feature_vector method from BaseDataInput
features = base_data.get_feature_vector()
# Convert to torch tensor
features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device)
return features_tensor
except Exception as e:
logger.error(f"Error converting BaseDataInput to features: {e}")
# Return empty tensor with correct shape
return torch.zeros(7850, dtype=torch.float32, device=self.device)
def predict(self, base_data: BaseDataInput) -> ModelOutput:
"""
Make a prediction using the EnhancedCNN model
Args:
base_data: Standardized input data
Returns:
ModelOutput: Standardized model output
"""
try:
# Convert BaseDataInput to features
features = self._convert_base_data_to_features(base_data)
# Ensure features has batch dimension
if features.dim() == 1:
features = features.unsqueeze(0)
# Set model to evaluation mode
self.model.eval()
# Make prediction
with torch.no_grad():
q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features)
# Get action and confidence
action_probs = torch.softmax(q_values, dim=1)
action_idx = torch.argmax(action_probs, dim=1).item()
confidence = float(action_probs[0, action_idx].item())
# Map action index to action string
actions = ['BUY', 'SELL', 'HOLD']
action = actions[action_idx]
# Create predictions dictionary
predictions = {
'action': action,
'buy_probability': float(action_probs[0, 0].item()),
'sell_probability': float(action_probs[0, 1].item()),
'hold_probability': float(action_probs[0, 2].item()),
'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist()
}
# Create hidden states dictionary
hidden_states = {
'features': features_refined.squeeze(0).cpu().numpy().tolist()
}
# Create metadata dictionary
metadata = {
'model_version': '1.0',
'timestamp': datetime.now().isoformat(),
'input_shape': features.shape
}
# Create ModelOutput
model_output = ModelOutput(
model_type='cnn',
model_name=self.model_name,
symbol=base_data.symbol,
timestamp=datetime.now(),
confidence=confidence,
predictions=predictions,
hidden_states=hidden_states,
metadata=metadata
)
return model_output
except Exception as e:
logger.error(f"Error making prediction with EnhancedCNN: {e}")
# Return default ModelOutput
return create_model_output(
model_type='cnn',
model_name=self.model_name,
symbol=base_data.symbol,
action='HOLD',
confidence=0.0
)
def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
"""
Add a training sample to the training data
Args:
base_data: Standardized input data
actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
reward: Reward received for the action
"""
try:
# Convert BaseDataInput to features
features = self._convert_base_data_to_features(base_data)
# Convert action to index
actions = ['BUY', 'SELL', 'HOLD']
action_idx = actions.index(actual_action)
# Add to training data
with self.training_lock:
self.training_data.append((features, action_idx, reward))
# Limit training data size
if len(self.training_data) > self.max_training_samples:
# Sort by reward (highest first) and keep top samples
self.training_data.sort(key=lambda x: x[2], reverse=True)
self.training_data = self.training_data[:self.max_training_samples]
logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")
except Exception as e:
logger.error(f"Error adding training sample: {e}")
def train(self, epochs: int = 1) -> Dict[str, float]:
"""
Train the model with collected data
Args:
epochs: Number of epochs to train for
Returns:
Dict[str, float]: Training metrics
"""
try:
with self.training_lock:
# Check if we have enough data
if len(self.training_data) < self.batch_size:
logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
# Set model to training mode
self.model.train()
# Create optimizer
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
# Training metrics
total_loss = 0.0
correct_predictions = 0
total_predictions = 0
# Train for specified number of epochs
for epoch in range(epochs):
# Shuffle training data
np.random.shuffle(self.training_data)
# Process in batches
for i in range(0, len(self.training_data), self.batch_size):
batch = self.training_data[i:i+self.batch_size]
# Skip if batch is too small
if len(batch) < 2:
continue
# Prepare batch
features = torch.stack([sample[0] for sample in batch])
actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)
# Zero gradients
optimizer.zero_grad()
# Forward pass
q_values, _, _, _, _ = self.model(features)
# Calculate loss (CrossEntropyLoss with reward weighting)
# First, apply softmax to get probabilities
probs = torch.softmax(q_values, dim=1)
# Get probability of chosen action
chosen_probs = probs[torch.arange(len(actions)), actions]
# Calculate negative log likelihood loss
nll_loss = -torch.log(chosen_probs + 1e-10)
# Weight by reward (higher reward = higher weight)
# Normalize rewards to [0, 1] range
min_reward = rewards.min()
max_reward = rewards.max()
if max_reward > min_reward:
normalized_rewards = (rewards - min_reward) / (max_reward - min_reward)
else:
normalized_rewards = torch.ones_like(rewards)
# Apply reward weighting (higher reward = higher weight)
weighted_loss = nll_loss * (normalized_rewards + 0.1) # Add small constant to avoid zero weights
# Mean loss
loss = weighted_loss.mean()
# Backward pass
loss.backward()
# Update weights
optimizer.step()
# Update metrics
total_loss += loss.item()
# Calculate accuracy
predicted_actions = torch.argmax(q_values, dim=1)
correct_predictions += (predicted_actions == actions).sum().item()
total_predictions += len(actions)
# Calculate final metrics
avg_loss = total_loss / (len(self.training_data) / self.batch_size)
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
# Save checkpoint
self._save_checkpoint(avg_loss, accuracy)
logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}")
return {
'loss': avg_loss,
'accuracy': accuracy,
'samples': len(self.training_data)
}
except Exception as e:
logger.error(f"Error training model: {e}")
return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)}
def _save_checkpoint(self, loss: float, accuracy: float):
"""
Save model checkpoint
Args:
loss: Training loss
accuracy: Training accuracy
"""
try:
# Import checkpoint manager
from utils.checkpoint_manager import CheckpointManager
# Create checkpoint manager
checkpoint_manager = CheckpointManager(
checkpoint_dir=self.checkpoint_dir,
max_checkpoints=10,
metric_name="accuracy"
)
# Create temporary model file
temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp")
self.model.save(temp_path)
# Create metrics
metrics = {
'loss': loss,
'accuracy': accuracy,
'samples': len(self.training_data)
}
# Create metadata
metadata = {
'timestamp': datetime.now().isoformat(),
'model_name': self.model_name,
'input_shape': self.model.input_shape,
'n_actions': self.model.n_actions
}
# Save checkpoint
checkpoint_path = checkpoint_manager.save_checkpoint(
model_name=self.model_name,
model_path=f"{temp_path}.pt",
metrics=metrics,
metadata=metadata
)
# Delete temporary model file
if os.path.exists(f"{temp_path}.pt"):
os.remove(f"{temp_path}.pt")
logger.info(f"Model checkpoint saved to {checkpoint_path}")
except Exception as e:
logger.error(f"Error saving checkpoint: {e}")
def load_best_checkpoint(self):
"""Load the best checkpoint based on accuracy"""
try:
# Import checkpoint manager
from utils.checkpoint_manager import CheckpointManager
# Create checkpoint manager
checkpoint_manager = CheckpointManager(
checkpoint_dir=self.checkpoint_dir,
max_checkpoints=10,
metric_name="accuracy"
)
# Load best checkpoint
best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)
if not best_checkpoint_path:
logger.info("No checkpoints found")
return False
# Load model
success = self.model.load(best_checkpoint_path)
if success:
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")
# Log metrics
metrics = best_checkpoint_metadata.get('metrics', {})
logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
return True
else:
logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
return False
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return False
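
The reward-weighted loss inside train() can be sanity-checked in isolation; a self-contained toy reproduction of that computation, where random tensors stand in for real features and rewards:

# Standalone reproduction of the adapter's reward-weighted NLL loss on a toy batch
import torch

torch.manual_seed(0)
q_values = torch.randn(4, 3)                       # batch of 4, actions BUY/SELL/HOLD
actions = torch.tensor([0, 1, 2, 0])               # indices of the taken actions
rewards = torch.tensor([0.08, -0.05, 0.01, 0.03])  # simulated realized rewards

probs = torch.softmax(q_values, dim=1)
chosen_probs = probs[torch.arange(len(actions)), actions]
nll_loss = -torch.log(chosen_probs + 1e-10)

# Normalize rewards to [0, 1]; a degenerate batch gets uniform weights
min_r, max_r = rewards.min(), rewards.max()
normalized = (rewards - min_r) / (max_r - min_r) if max_r > min_r else torch.ones_like(rewards)

loss = (nll_loss * (normalized + 0.1)).mean()      # +0.1 avoids zero weights
print(float(loss))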

View File

@ -1,34 +1,31 @@
""" """
Model Output Manager Model Output Manager
This module provides extensible model output storage and management for the multi-modal trading system. This module provides a centralized storage and management system for model outputs,
Supports CNN, RL, LSTM, Transformer, and future model types with cross-model feeding capabilities. enabling cross-model feeding and evaluation.
""" """
import logging import os
import json import json
import pickle import logging
from datetime import datetime, timedelta import time
from typing import Dict, List, Optional, Any, Union from datetime import datetime
from collections import deque, defaultdict from typing import Dict, List, Optional, Any
from threading import Lock from threading import Lock
from pathlib import Path
from .data_models import ModelOutput, create_model_output from .data_models import ModelOutput
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ModelOutputManager: class ModelOutputManager:
""" """
Extensible model output storage and management system Centralized storage and management system for model outputs
Features: This class:
- Standardized ModelOutput storage for all model types 1. Stores model outputs for all models
- Cross-model feeding with hidden states 2. Provides access to current and historical outputs
- Historical output tracking 3. Handles persistence of outputs to disk
- Metadata management 4. Supports evaluation of model performance
- Persistence and recovery
- Performance analytics
""" """
def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000): def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000):
@ -36,279 +33,226 @@ class ModelOutputManager:
Initialize the model output manager Initialize the model output manager
Args: Args:
cache_dir: Directory for persistent storage cache_dir: Directory to store model outputs
max_history: Maximum number of outputs to keep in memory per model max_history: Maximum number of historical outputs to keep per model
""" """
self.cache_dir = Path(cache_dir) self.cache_dir = cache_dir
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.max_history = max_history self.max_history = max_history
self.outputs_lock = Lock()
# In-memory storage # Current outputs for each model and symbol
self.current_outputs: Dict[str, Dict[str, ModelOutput]] = defaultdict(dict) # {symbol: {model_name: ModelOutput}} # {symbol: {model_name: ModelOutput}}
self.output_history: Dict[str, Dict[str, deque]] = defaultdict(lambda: defaultdict(lambda: deque(maxlen=max_history))) # {symbol: {model_name: deque}} self.current_outputs: Dict[str, Dict[str, ModelOutput]] = {}
self.cross_model_states: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) # {symbol: {model_name: hidden_states}}
# Metadata tracking # Historical outputs for each model and symbol
self.model_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict) # {model_name: metadata} # {symbol: {model_name: List[ModelOutput]}}
self.performance_stats: Dict[str, Dict[str, Any]] = defaultdict(lambda: defaultdict(dict)) # {symbol: {model_name: stats}} self.historical_outputs: Dict[str, Dict[str, List[ModelOutput]]] = {}
# Thread safety # Performance metrics for each model and symbol
self.storage_lock = Lock() # {symbol: {model_name: Dict[str, float]}}
self.performance_metrics: Dict[str, Dict[str, Dict[str, float]]] = {}
# Supported model types # Create cache directory if it doesn't exist
self.supported_model_types = { os.makedirs(cache_dir, exist_ok=True)
'cnn', 'rl', 'lstm', 'transformer', 'orchestrator',
'ensemble', 'hybrid', 'custom' # Extensible for future types
}
logger.info(f"ModelOutputManager initialized with cache dir: {self.cache_dir}") logger.info(f"ModelOutputManager initialized with cache_dir: {cache_dir}")
logger.info(f"Supported model types: {self.supported_model_types}")
def store_output(self, model_output: ModelOutput) -> bool: def store_output(self, model_output: ModelOutput) -> bool:
""" """
Store model output with full extensibility support Store a model output
Args: Args:
model_output: ModelOutput from any model type model_output: Model output to store
Returns: Returns:
bool: True if stored successfully, False otherwise bool: True if successful, False otherwise
""" """
try: try:
with self.storage_lock: symbol = model_output.symbol
symbol = model_output.symbol model_name = model_output.model_name
model_name = model_output.model_name
model_type = model_output.model_type with self.outputs_lock:
# Initialize dictionaries if they don't exist
# Validate model type (extensible) if symbol not in self.current_outputs:
if model_type not in self.supported_model_types: self.current_outputs[symbol] = {}
logger.warning(f"Unknown model type '{model_type}' - adding to supported types") if symbol not in self.historical_outputs:
self.supported_model_types.add(model_type) self.historical_outputs[symbol] = {}
if model_name not in self.historical_outputs[symbol]:
self.historical_outputs[symbol][model_name] = []
# Store current output # Store current output
self.current_outputs[symbol][model_name] = model_output self.current_outputs[symbol][model_name] = model_output
# Add to history # Add to historical outputs
self.output_history[symbol][model_name].append(model_output) self.historical_outputs[symbol][model_name].append(model_output)
# Store cross-model states if available
if model_output.hidden_states:
self.cross_model_states[symbol][model_name] = model_output.hidden_states
# Update model metadata
self._update_model_metadata(model_name, model_type, model_output.metadata)
# Update performance statistics
self._update_performance_stats(symbol, model_name, model_output)
# Persist to disk (async to avoid blocking)
self._persist_output_async(model_output)
logger.debug(f"Stored output from {model_name} ({model_type}) for {symbol}")
return True
# Limit historical outputs
if len(self.historical_outputs[symbol][model_name]) > self.max_history:
self.historical_outputs[symbol][model_name] = self.historical_outputs[symbol][model_name][-self.max_history:]
# Persist output to disk
self._persist_output(model_output)
return True
except Exception as e: except Exception as e:
logger.error(f"Error storing model output: {e}") logger.error(f"Error storing model output: {e}")
return False return False
def get_current_output(self, symbol: str, model_name: str) -> Optional[ModelOutput]: def get_current_output(self, symbol: str, model_name: str) -> Optional[ModelOutput]:
""" """
Get the current (latest) output from a specific model Get the current output for a model and symbol
Args: Args:
symbol: Trading symbol symbol: Symbol to get output for
model_name: Name of the model model_name: Model name to get output for
Returns: Returns:
ModelOutput: Latest output from the model, or None if not available ModelOutput: Current output, or None if not available
""" """
try: try:
return self.current_outputs.get(symbol, {}).get(model_name) with self.outputs_lock:
if symbol in self.current_outputs and model_name in self.current_outputs[symbol]:
return self.current_outputs[symbol][model_name]
return None
except Exception as e: except Exception as e:
logger.error(f"Error getting current output for {model_name}: {e}") logger.error(f"Error getting current output: {e}")
return None return None
def get_all_current_outputs(self, symbol: str) -> Dict[str, ModelOutput]: def get_all_current_outputs(self, symbol: str) -> Dict[str, ModelOutput]:
""" """
Get all current outputs for a symbol (for cross-model feeding) Get all current outputs for a symbol
Args: Args:
symbol: Trading symbol symbol: Symbol to get outputs for
Returns: Returns:
Dict[str, ModelOutput]: Dictionary of current outputs by model name Dict[str, ModelOutput]: Dictionary of model name to output
""" """
try: try:
return dict(self.current_outputs.get(symbol, {})) with self.outputs_lock:
if symbol in self.current_outputs:
return self.current_outputs[symbol].copy()
return {}
except Exception as e: except Exception as e:
logger.error(f"Error getting all current outputs for {symbol}: {e}") logger.error(f"Error getting all current outputs: {e}")
return {} return {}
def get_output_history(self, symbol: str, model_name: str, count: int = 10) -> List[ModelOutput]: def get_historical_outputs(self, symbol: str, model_name: str, limit: int = None) -> List[ModelOutput]:
""" """
Get historical outputs from a model Get historical outputs for a model and symbol
Args: Args:
symbol: Trading symbol symbol: Symbol to get outputs for
model_name: Name of the model model_name: Model name to get outputs for
count: Number of historical outputs to retrieve limit: Maximum number of outputs to return, None for all
Returns: Returns:
List[ModelOutput]: List of historical outputs (most recent first) List[ModelOutput]: List of historical outputs
""" """
try: try:
history = self.output_history.get(symbol, {}).get(model_name, deque()) with self.outputs_lock:
return list(history)[-count:][::-1] # Most recent first if symbol in self.historical_outputs and model_name in self.historical_outputs[symbol]:
outputs = self.historical_outputs[symbol][model_name]
if limit is not None:
outputs = outputs[-limit:]
return outputs.copy()
return []
except Exception as e: except Exception as e:
logger.error(f"Error getting output history for {model_name}: {e}") logger.error(f"Error getting historical outputs: {e}")
return [] return []
def get_cross_model_states(self, symbol: str, requesting_model: str) -> Dict[str, Dict[str, Any]]: def evaluate_model_performance(self, symbol: str, model_name: str) -> Dict[str, float]:
""" """
Get hidden states from other models for cross-model feeding Evaluate model performance based on historical outputs
Args: Args:
symbol: Trading symbol symbol: Symbol to evaluate
requesting_model: Name of the model requesting the states model_name: Model name to evaluate
Returns: Returns:
Dict[str, Dict[str, Any]]: Hidden states from other models Dict[str, float]: Performance metrics
""" """
try: try:
all_states = self.cross_model_states.get(symbol, {}) # Get historical outputs
# Return states from all models except the requesting one outputs = self.get_historical_outputs(symbol, model_name)
return {model_name: states for model_name, states in all_states.items()
if model_name != requesting_model}
except Exception as e:
logger.error(f"Error getting cross-model states for {requesting_model}: {e}")
return {}
def get_model_types_active(self, symbol: str) -> List[str]:
"""
Get list of active model types for a symbol
Args:
symbol: Trading symbol
Returns:
List[str]: List of active model types
"""
try:
current_outputs = self.current_outputs.get(symbol, {})
return [output.model_type for output in current_outputs.values()]
except Exception as e:
logger.error(f"Error getting active model types for {symbol}: {e}")
return []
def get_consensus_prediction(self, symbol: str, confidence_threshold: float = 0.5) -> Optional[Dict[str, Any]]:
"""
Get consensus prediction from all active models
Args:
symbol: Trading symbol
confidence_threshold: Minimum confidence threshold for inclusion
Returns:
Dict containing consensus prediction or None
"""
try:
current_outputs = self.current_outputs.get(symbol, {})
if not current_outputs:
return None
# Filter by confidence threshold if not outputs:
high_confidence_outputs = [ return {'accuracy': 0.0, 'confidence': 0.0, 'samples': 0}
output for output in current_outputs.values()
if output.confidence >= confidence_threshold
]
if not high_confidence_outputs: # Calculate metrics
return None total_outputs = len(outputs)
total_confidence = sum(output.confidence for output in outputs)
avg_confidence = total_confidence / total_outputs if total_outputs > 0 else 0.0
# Calculate consensus # For now, we don't have ground truth to calculate accuracy
buy_votes = sum(1 for output in high_confidence_outputs # In the future, we can add this by comparing predictions to actual market movements
if output.predictions.get('action') == 'BUY')
sell_votes = sum(1 for output in high_confidence_outputs
if output.predictions.get('action') == 'SELL')
hold_votes = sum(1 for output in high_confidence_outputs
if output.predictions.get('action') == 'HOLD')
total_votes = len(high_confidence_outputs) metrics = {
avg_confidence = sum(output.confidence for output in high_confidence_outputs) / total_votes
# Determine consensus action
if buy_votes > sell_votes and buy_votes > hold_votes:
consensus_action = 'BUY'
elif sell_votes > buy_votes and sell_votes > hold_votes:
consensus_action = 'SELL'
else:
consensus_action = 'HOLD'
return {
'action': consensus_action,
'confidence': avg_confidence, 'confidence': avg_confidence,
'votes': {'BUY': buy_votes, 'SELL': sell_votes, 'HOLD': hold_votes}, 'samples': total_outputs,
'total_models': total_votes, 'last_update': datetime.now().isoformat()
'model_types': [output.model_type for output in high_confidence_outputs]
} }
except Exception as e: # Store metrics
logger.error(f"Error calculating consensus prediction for {symbol}: {e}") with self.outputs_lock:
return None if symbol not in self.performance_metrics:
self.performance_metrics[symbol] = {}
def _update_model_metadata(self, model_name: str, model_type: str, metadata: Dict[str, Any]): self.performance_metrics[symbol][model_name] = metrics
"""Update metadata for a model"""
try:
if model_name not in self.model_metadata:
self.model_metadata[model_name] = {
'model_type': model_type,
'first_seen': datetime.now(),
'total_predictions': 0,
'custom_metadata': {}
}
self.model_metadata[model_name]['total_predictions'] += 1 return metrics
self.model_metadata[model_name]['last_seen'] = datetime.now()
# Merge custom metadata
if metadata:
self.model_metadata[model_name]['custom_metadata'].update(metadata)
except Exception as e:
logger.error(f"Error updating model metadata: {e}")
def _update_performance_stats(self, symbol: str, model_name: str, model_output: ModelOutput):
"""Update performance statistics for a model"""
try:
stats = self.performance_stats[symbol][model_name]
if 'prediction_count' not in stats:
stats['prediction_count'] = 0
stats['confidence_sum'] = 0.0
stats['action_counts'] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
stats['first_prediction'] = model_output.timestamp
stats['prediction_count'] += 1
stats['confidence_sum'] += model_output.confidence
stats['avg_confidence'] = stats['confidence_sum'] / stats['prediction_count']
stats['last_prediction'] = model_output.timestamp
action = model_output.predictions.get('action', 'HOLD')
if action in stats['action_counts']:
stats['action_counts'][action] += 1
except Exception as e: except Exception as e:
logger.error(f"Error updating performance stats: {e}") logger.error(f"Error evaluating model performance: {e}")
return {'error': str(e)}
def _persist_output_async(self, model_output: ModelOutput): def get_performance_metrics(self, symbol: str, model_name: str) -> Dict[str, float]:
"""Persist model output to disk (simplified version)""" """
Get performance metrics for a model and symbol
Args:
symbol: Symbol to get metrics for
model_name: Model name to get metrics for
Returns:
Dict[str, float]: Performance metrics
"""
try: try:
# Create filename based on model and timestamp with self.outputs_lock:
timestamp_str = model_output.timestamp.strftime("%Y%m%d_%H%M%S") if symbol in self.performance_metrics and model_name in self.performance_metrics[symbol]:
filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp_str}.json" return self.performance_metrics[symbol][model_name].copy()
filepath = self.cache_dir / filename
# Convert to JSON-serializable format # If no metrics are available, calculate them
return self.evaluate_model_performance(symbol, model_name)
except Exception as e:
logger.error(f"Error getting performance metrics: {e}")
return {'error': str(e)}
def _persist_output(self, model_output: ModelOutput) -> bool:
"""
Persist a model output to disk
Args:
model_output: Model output to persist
Returns:
bool: True if successful, False otherwise
"""
try:
# Create directory if it doesn't exist
symbol_dir = os.path.join(self.cache_dir, model_output.symbol.replace('/', '_'))
os.makedirs(symbol_dir, exist_ok=True)
# Create filename with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp}.json"
filepath = os.path.join(self.cache_dir, filename)
# Convert ModelOutput to dictionary
output_dict = { output_dict = {
'model_type': model_output.model_type, 'model_type': model_output.model_type,
'model_name': model_output.model_name, 'model_name': model_output.model_name,
@ -319,77 +263,120 @@ class ModelOutputManager:
'metadata': model_output.metadata 'metadata': model_output.metadata
} }
# Save to file (in a real implementation, this would be async) # Don't store hidden states in file (too large)
# Write to file
with open(filepath, 'w') as f: with open(filepath, 'w') as f:
json.dump(output_dict, f, indent=2) json.dump(output_dict, f, indent=2)
return True
except Exception as e: except Exception as e:
logger.error(f"Error persisting model output: {e}") logger.error(f"Error persisting model output: {e}")
return False
def get_performance_summary(self, symbol: str) -> Dict[str, Any]: def load_outputs_from_disk(self, symbol: str = None, model_name: str = None) -> int:
""" """
Get performance summary for all models for a symbol Load model outputs from disk
Args: Args:
symbol: Trading symbol symbol: Symbol to load outputs for, None for all
model_name: Model name to load outputs for, None for all
Returns: Returns:
Dict containing performance summary int: Number of outputs loaded
""" """
try: try:
summary = { # Find all output files
'symbol': symbol, import glob
'active_models': len(self.current_outputs.get(symbol, {})),
'model_stats': {}
}
for model_name, stats in self.performance_stats.get(symbol, {}).items(): if symbol and model_name:
summary['model_stats'][model_name] = { pattern = os.path.join(self.cache_dir, f"{model_name}_{symbol.replace('/', '_')}*.json")
'predictions': stats.get('prediction_count', 0), elif symbol:
'avg_confidence': round(stats.get('avg_confidence', 0.0), 3), pattern = os.path.join(self.cache_dir, f"*_{symbol.replace('/', '_')}*.json")
'action_distribution': stats.get('action_counts', {}), elif model_name:
'model_type': self.model_metadata.get(model_name, {}).get('model_type', 'unknown') pattern = os.path.join(self.cache_dir, f"{model_name}_*.json")
} else:
pattern = os.path.join(self.cache_dir, "*.json")
return summary output_files = glob.glob(pattern)
if not output_files:
logger.info(f"No output files found for pattern: {pattern}")
return 0
# Load each file
loaded_count = 0
for filepath in output_files:
try:
with open(filepath, 'r') as f:
output_dict = json.load(f)
# Create ModelOutput
model_output = ModelOutput(
model_type=output_dict['model_type'],
model_name=output_dict['model_name'],
symbol=output_dict['symbol'],
timestamp=datetime.fromisoformat(output_dict['timestamp']),
confidence=output_dict['confidence'],
predictions=output_dict['predictions'],
hidden_states={}, # Don't load hidden states from disk
metadata=output_dict.get('metadata', {})
)
# Store output
self.store_output(model_output)
loaded_count += 1
except Exception as e:
logger.error(f"Error loading output file {filepath}: {e}")
logger.info(f"Loaded {loaded_count} model outputs from disk")
return loaded_count
except Exception as e: except Exception as e:
logger.error(f"Error getting performance summary: {e}") logger.error(f"Error loading outputs from disk: {e}")
return {'symbol': symbol, 'error': str(e)} return 0
def cleanup_old_outputs(self, max_age_hours: int = 24): def cleanup_old_outputs(self, max_age_days: int = 30) -> int:
""" """
Clean up old outputs to manage memory usage Clean up old output files
Args: Args:
max_age_hours: Maximum age of outputs to keep in hours max_age_days: Maximum age of files to keep in days
Returns:
int: Number of files deleted
""" """
try: try:
cutoff_time = datetime.now() - timedelta(hours=max_age_hours) # Find all output files
import glob
output_files = glob.glob(os.path.join(self.cache_dir, "*.json"))
with self.storage_lock: if not output_files:
for symbol in self.output_history: return 0
for model_name in self.output_history[symbol]:
history = self.output_history[symbol][model_name]
# Remove old outputs
while history and history[0].timestamp < cutoff_time:
history.popleft()
logger.info(f"Cleaned up outputs older than {max_age_hours} hours") # Calculate cutoff time
cutoff_time = time.time() - (max_age_days * 24 * 60 * 60)
# Delete old files
deleted_count = 0
for filepath in output_files:
try:
# Get file modification time
mtime = os.path.getmtime(filepath)
# Delete if older than cutoff
if mtime < cutoff_time:
os.remove(filepath)
deleted_count += 1
except Exception as e:
logger.error(f"Error deleting file {filepath}: {e}")
logger.info(f"Deleted {deleted_count} old model output files")
return deleted_count
except Exception as e: except Exception as e:
logger.error(f"Error cleaning up old outputs: {e}") logger.error(f"Error cleaning up old outputs: {e}")
return 0
def add_custom_model_type(self, model_type: str):
"""
Add support for a new custom model type
Args:
model_type: Name of the new model type
"""
self.supported_model_types.add(model_type)
logger.info(f"Added support for custom model type: {model_type}")
def get_supported_model_types(self) -> List[str]:
"""Get list of all supported model types"""
return list(self.supported_model_types)
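
A minimal round-trip sketch for the rewritten manager; the module path core.model_output_manager and the create_model_output signature are assumptions taken from their usage elsewhere in this commit:

# Round-trip sketch for the rewritten ModelOutputManager (paths/signature assumed)
from core.model_output_manager import ModelOutputManager
from core.data_models import create_model_output

manager = ModelOutputManager(cache_dir="cache/model_outputs", max_history=1000)

output = create_model_output(model_type='cnn', model_name='enhanced_cnn_v1',
                             symbol='ETH/USDT', action='BUY', confidence=0.72)
manager.store_output(output)                          # also persists a JSON snapshot

latest = manager.get_current_output('ETH/USDT', 'enhanced_cnn_v1')
history = manager.get_historical_outputs('ETH/USDT', 'enhanced_cnn_v1', limit=10)
print(manager.evaluate_model_performance('ETH/USDT', 'enhanced_cnn_v1'))

manager.cleanup_old_outputs(max_age_days=30)          # prune stale JSON files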

View File

@ -0,0 +1,155 @@
"""
Test Continuous CNN Training
This script demonstrates how the CNN model can be trained with each new inference result
using collected data, implementing a continuous learning loop.
"""
import logging
import time
from datetime import datetime
import random
import os
from core.standardized_data_provider import StandardizedDataProvider
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import create_model_output
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def simulate_market_feedback(action, symbol):
"""
Simulate market feedback for a given action
In a real system, this would be replaced with actual market performance data
Args:
action: Trading action ('BUY', 'SELL', 'HOLD')
symbol: Trading symbol
Returns:
tuple: (actual_action, reward)
"""
# Simulate market movement (random for demonstration)
market_direction = random.choice(['up', 'down', 'sideways'])
# Determine actual best action based on market direction
if market_direction == 'up':
best_action = 'BUY'
elif market_direction == 'down':
best_action = 'SELL'
else:
best_action = 'HOLD'
# Calculate reward based on whether the action matched the best action
if action == best_action:
reward = random.uniform(0.01, 0.1) # Positive reward for correct action
else:
reward = random.uniform(-0.1, -0.01) # Negative reward for incorrect action
logger.info(f"Market went {market_direction}, best action was {best_action}, model chose {action}, reward: {reward:.4f}")
return best_action, reward
def test_continuous_training():
"""Test continuous training of the CNN model with new inference results"""
try:
# Initialize data provider
symbols = ['ETH/USDT', 'BTC/USDT']
timeframes = ['1s', '1m', '1h', '1d']
data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
# Initialize CNN adapter
checkpoint_dir = "models/enhanced_cnn"
os.makedirs(checkpoint_dir, exist_ok=True)
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir)
# Load best checkpoint if available
cnn_adapter.load_best_checkpoint()
# Continuous learning loop
num_iterations = 10
training_frequency = 3 # Train every N iterations
samples_collected = 0
logger.info(f"Starting continuous learning loop with {num_iterations} iterations")
for i in range(num_iterations):
logger.info(f"\nIteration {i+1}/{num_iterations}")
# Get standardized input data
symbol = random.choice(symbols)
logger.info(f"Getting data for {symbol}...")
base_data = data_provider.get_base_data_input(symbol)
if base_data is None:
logger.warning(f"Failed to get base data input for {symbol}, skipping iteration")
continue
# Make prediction
logger.info(f"Making prediction for {symbol}...")
model_output = cnn_adapter.predict(base_data)
# Log prediction
action = model_output.predictions['action']
confidence = model_output.confidence
logger.info(f"Prediction: {action} with confidence {confidence:.4f}")
# Store model output
data_provider.store_model_output(model_output)
# Simulate market feedback
best_action, reward = simulate_market_feedback(action, symbol)
# Add training sample
logger.info(f"Adding training sample: action={best_action}, reward={reward:.4f}")
cnn_adapter.add_training_sample(base_data, best_action, reward)
samples_collected += 1
# Train model periodically
if (i + 1) % training_frequency == 0 and samples_collected >= 3:
logger.info(f"Training model with {samples_collected} samples...")
metrics = cnn_adapter.train(epochs=1)
# Log training metrics
logger.info(f"Training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")
# Simulate time passing
time.sleep(1)
logger.info("\nContinuous learning loop completed")
# Final evaluation
logger.info("Performing final evaluation...")
# Get data for evaluation
symbol = 'ETH/USDT'
base_data = data_provider.get_base_data_input(symbol)
if base_data is not None:
# Make prediction
model_output = cnn_adapter.predict(base_data)
# Log prediction
action = model_output.predictions['action']
confidence = model_output.confidence
logger.info(f"Final prediction for {symbol}: {action} with confidence {confidence:.4f}")
# Get model output manager
output_manager = data_provider.get_model_output_manager()
# Evaluate model performance
metrics = output_manager.evaluate_model_performance(symbol, cnn_adapter.model_name)
logger.info(f"Performance metrics: {metrics}")
else:
logger.warning(f"Failed to get base data input for final evaluation")
logger.info("Test completed successfully")
except Exception as e:
logger.error(f"Error in test: {e}", exc_info=True)
if __name__ == "__main__":
test_continuous_training()
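
simulate_market_feedback notes that a real deployment would use actual market performance. One plausible replacement, assuming entry and exit prices are available from fills (nothing below is part of this commit):

# Hypothetical reward from realized price movement instead of the random simulator
def realized_reward(action: str, entry_price: float, exit_price: float) -> float:
    ret = (exit_price - entry_price) / entry_price  # signed fractional return
    if action == 'BUY':
        return ret           # profit when price rises
    if action == 'SELL':
        return -ret          # profit when price falls
    return -abs(ret) * 0.1   # HOLD: small penalty for sitting out large moves

print(realized_reward('BUY', 3000.0, 3015.0))   # 0.005
print(realized_reward('SELL', 3000.0, 3015.0))  # -0.005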

View File

@ -0,0 +1,87 @@
"""
Test Enhanced CNN Adapter
This script tests the EnhancedCNNAdapter with standardized input format.
"""
import logging
import time
from datetime import datetime
from core.standardized_data_provider import StandardizedDataProvider
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import create_model_output
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_cnn_adapter():
"""Test the EnhancedCNNAdapter with standardized input format"""
try:
# Initialize data provider
symbols = ['ETH/USDT', 'BTC/USDT']
timeframes = ['1s', '1m', '1h', '1d']
data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
# Initialize CNN adapter
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
# Load best checkpoint if available
cnn_adapter.load_best_checkpoint()
# Get standardized input data
logger.info("Getting standardized input data...")
base_data = data_provider.get_base_data_input('ETH/USDT')
if base_data is None:
logger.error("Failed to get base data input")
return
# Make prediction
logger.info("Making prediction...")
model_output = cnn_adapter.predict(base_data)
# Log prediction
logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}")
# Store model output
data_provider.store_model_output(model_output)
# Add training sample (simulated)
logger.info("Adding training sample...")
cnn_adapter.add_training_sample(base_data, 'BUY', 0.05)
# Train model
logger.info("Training model...")
metrics = cnn_adapter.train(epochs=1)
# Log training metrics
logger.info(f"Training metrics: {metrics}")
# Make another prediction
logger.info("Making another prediction...")
model_output = cnn_adapter.predict(base_data)
# Log prediction
logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}")
# Test model output manager
logger.info("Testing model output manager...")
output_manager = data_provider.get_model_output_manager()
# Get current outputs
current_outputs = output_manager.get_all_current_outputs('ETH/USDT')
logger.info(f"Current outputs: {len(current_outputs)} models")
# Evaluate model performance
metrics = output_manager.evaluate_model_performance('ETH/USDT', 'enhanced_cnn_v1')
logger.info(f"Performance metrics: {metrics}")
logger.info("Test completed successfully")
except Exception as e:
logger.error(f"Error in test: {e}", exc_info=True)
if __name__ == "__main__":
test_cnn_adapter()

View File

@ -0,0 +1,3 @@
"""
Utils package for the multi-modal trading system
"""

View File

@ -1,466 +1,408 @@
#!/usr/bin/env python3 """
""" Checkpoint Manager
Checkpoint Management System for W&B Training
This module provides functionality for managing model checkpoints, including:
- Saving checkpoints with metadata
- Loading the best checkpoint based on performance metrics
- Cleaning up old or underperforming checkpoints
""" """
import os import os
import json import json
import glob
import logging import logging
from datetime import datetime, timedelta import shutil
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict
import torch import torch
import random from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass # Global checkpoint manager instance
class CheckpointMetadata: _checkpoint_manager_instance = None
checkpoint_id: str
model_name: str def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_checkpoints: int = 10, metric_name: str = "accuracy") -> 'CheckpointManager':
model_type: str """
file_path: str Get the global checkpoint manager instance
created_at: datetime
file_size_mb: float
performance_score: float
accuracy: Optional[float] = None
loss: Optional[float] = None
val_accuracy: Optional[float] = None
val_loss: Optional[float] = None
reward: Optional[float] = None
pnl: Optional[float] = None
epoch: Optional[int] = None
training_time_hours: Optional[float] = None
total_parameters: Optional[int] = None
wandb_run_id: Optional[str] = None
wandb_artifact_name: Optional[str] = None
def to_dict(self) -> Dict[str, Any]: Args:
data = asdict(self) checkpoint_dir: Directory to store checkpoints
data['created_at'] = self.created_at.isoformat() max_checkpoints: Maximum number of checkpoints to keep
return data metric_name: Metric to use for ranking checkpoints
@classmethod Returns:
def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata': CheckpointManager: Global checkpoint manager instance
data['created_at'] = datetime.fromisoformat(data['created_at']) """
return cls(**data) global _checkpoint_manager_instance
if _checkpoint_manager_instance is None:
_checkpoint_manager_instance = CheckpointManager(
checkpoint_dir=checkpoint_dir,
max_checkpoints=max_checkpoints,
metric_name=metric_name
)
return _checkpoint_manager_instance
def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
"""
Save a checkpoint with metadata
Args:
model: The model to save
model_name: Name of the model
model_type: Type of the model ('cnn', 'rl', etc.)
performance_metrics: Performance metrics
training_metadata: Additional training metadata
checkpoint_dir: Directory to store checkpoints
Returns:
Any: Checkpoint metadata
"""
try:
# Create checkpoint directory
os.makedirs(checkpoint_dir, exist_ok=True)
# Create timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create checkpoint path
model_dir = os.path.join(checkpoint_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp}")
# Save model
if hasattr(model, 'save'):
# Use model's save method if available
model.save(checkpoint_path)
else:
# Otherwise, save state_dict
torch_path = f"{checkpoint_path}.pt"
torch.save({
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
'model_name': model_name,
'model_type': model_type,
'timestamp': timestamp
}, torch_path)
# Create metadata
checkpoint_metadata = {
'model_name': model_name,
'model_type': model_type,
'timestamp': timestamp,
'performance_metrics': performance_metrics,
'training_metadata': training_metadata or {},
'checkpoint_id': f"{model_name}_{timestamp}"
}
# Add performance score for sorting
primary_metric = 'accuracy' if 'accuracy' in performance_metrics else 'reward'
checkpoint_metadata['performance_score'] = performance_metrics.get(primary_metric, 0.0)
checkpoint_metadata['created_at'] = timestamp
# Save metadata
with open(f"{checkpoint_path}_metadata.json", 'w') as f:
json.dump(checkpoint_metadata, f, indent=2)
# Get checkpoint manager and clean up old checkpoints
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_manager._cleanup_checkpoints(model_name)
# Return metadata as an object
class CheckpointMetadata:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
return CheckpointMetadata(checkpoint_metadata)
except Exception as e:
logger.error(f"Error saving checkpoint: {e}")
return None
def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
"""
Load the best checkpoint based on performance metrics
Args:
model_name: Name of the model
checkpoint_dir: Directory to store checkpoints
Returns:
Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found
"""
try:
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_path, checkpoint_metadata = checkpoint_manager.load_best_checkpoint(model_name)
if not checkpoint_path:
return None
# Convert metadata to object
class CheckpointMetadata:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
# Add performance score if not present
if not hasattr(self, 'performance_score'):
metrics = getattr(self, 'metrics', {})
primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
self.performance_score = metrics.get(primary_metric, 0.0)
# Add created_at if not present
if not hasattr(self, 'created_at'):
self.created_at = getattr(self, 'timestamp', 'unknown')
return f"{checkpoint_path}.pt", CheckpointMetadata(checkpoint_metadata)
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return None
# --- Previous CheckpointManager implementation (removed by this commit) ---

class CheckpointManager:
    def __init__(self,
                 base_checkpoint_dir: str = "NN/models/saved",
                 max_checkpoints_per_model: int = 5,
                 metadata_file: str = "checkpoint_metadata.json",
                 enable_wandb: bool = True):
        self.base_dir = Path(base_checkpoint_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints_per_model
        self.metadata_file = self.base_dir / metadata_file
        self.enable_wandb = enable_wandb and WANDB_AVAILABLE
        self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
        self._load_metadata()
        logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[CheckpointMetadata]:
        try:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            checkpoint_id = f"{model_name}_{timestamp}"
            model_dir = self.base_dir / model_name
            model_dir.mkdir(exist_ok=True)
            checkpoint_path = model_dir / f"{checkpoint_id}.pt"
            performance_score = self._calculate_performance_score(performance_metrics)
            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None
            success = self._save_model_file(model, checkpoint_path, model_type)
            if not success:
                return None
            file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)
            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=model_name,
                model_type=model_type,
                file_path=str(checkpoint_path),
                created_at=datetime.now(),
                file_size_mb=file_size_mb,
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=training_metadata.get('epoch') if training_metadata else None,
                training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None,
                total_parameters=training_metadata.get('total_parameters') if training_metadata else None
            )
            if self.enable_wandb and wandb.run is not None:
                artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
                metadata.wandb_run_id = wandb.run.id
                metadata.wandb_artifact_name = artifact_name
            self.checkpoints[model_name].append(metadata)
            self._rotate_checkpoints(model_name)
            self._save_metadata()
            logger.debug(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
            return metadata
        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None
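
    # Illustrative call (hypothetical model and metrics): force_save=True
    # bypasses the _should_save_checkpoint gate, e.g.
    #   mgr.save_checkpoint(model, "dqn_agent", "rl", {"reward": 12.5}, force_save=True)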
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
try:
# First, try the standard checkpoint system
if model_name in self.checkpoints and self.checkpoints[model_name]:
# Filter out checkpoints with non-existent files
valid_checkpoints = [
cp for cp in self.checkpoints[model_name]
if Path(cp.file_path).exists()
]
if valid_checkpoints:
best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score)
logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
return best_checkpoint.file_path, best_checkpoint
else:
# Clean up invalid metadata entries
invalid_count = len(self.checkpoints[model_name])
logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata")
self.checkpoints[model_name] = []
self._save_metadata()
# Fallback: Look for existing saved models in the legacy format
logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models")
legacy_model_path = self._find_legacy_model(model_name)
if legacy_model_path:
# Create checkpoint metadata for the legacy model using actual file data
legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path)
logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
return str(legacy_model_path), legacy_metadata
logger.warning(f"No checkpoints or legacy models found for: {model_name}")
return None
except Exception as e:
logger.error(f"Error loading best checkpoint for {model_name}: {e}")
return None
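
    # Lookup order above (illustrative): (1) the best tracked checkpoint by
    # performance_score whose file still exists, (2) a legacy file resolved by
    # _find_legacy_model(), e.g. "dqn_agent" may map to
    # NN/models/saved/dqn_agent_best_policy.pt under the default base dir.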
def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
"""Calculate performance score with improved sensitivity for training models"""
score = 0.0
# Prioritize loss reduction for active training models
if 'loss' in metrics:
# Invert loss so lower loss = higher score, with better scaling
loss_value = metrics['loss']
if loss_value > 0:
score += max(0, 100 / (1 + loss_value)) # More sensitive to loss changes
else:
score += 100 # Perfect loss
# Add other metrics with appropriate weights
if 'accuracy' in metrics:
score += metrics['accuracy'] * 50 # Reduced weight to balance with loss
if 'val_accuracy' in metrics:
score += metrics['val_accuracy'] * 50
if 'val_loss' in metrics:
val_loss = metrics['val_loss']
if val_loss > 0:
score += max(0, 50 / (1 + val_loss))
if 'reward' in metrics:
score += metrics['reward'] * 10
if 'pnl' in metrics:
score += metrics['pnl'] * 5
if 'training_samples' in metrics:
# Bonus for processing more training samples
score += min(10, metrics['training_samples'] / 10)
# Return actual calculated score - NO SYNTHETIC MINIMUM
return score
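
    # Worked example (illustrative): metrics = {'loss': 0.5, 'accuracy': 0.6}
    # yields score = 100 / (1 + 0.5) + 0.6 * 50 = 66.67 + 30.0 = 96.67, so
    # halving the loss moves the score far more than a comparable accuracy gain.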
def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
"""Improved checkpoint saving logic with more frequent saves during training"""
if model_name not in self.checkpoints or not self.checkpoints[model_name]:
return True # Always save first checkpoint
# Allow more checkpoints during active training
if len(self.checkpoints[model_name]) < self.max_checkpoints:
return True
# Get current best and worst scores
scores = [cp.performance_score for cp in self.checkpoints[model_name]]
best_score = max(scores)
worst_score = min(scores)
# Save if better than worst (more frequent saves)
if performance_score > worst_score:
return True
# For high-performing models (score > 100), be more sensitive to small improvements
if best_score > 100:
# Save if within 0.1% of best score (very sensitive for converged models)
if performance_score >= best_score * 0.999:
return True
else:
# Also save if we're within 10% of best score (capture near-optimal models)
if performance_score >= best_score * 0.9:
return True
# Save more frequently during active training (every 5th attempt instead of 10th)
if random.random() < 0.2: # 20% chance to save anyway
logger.debug(f"Saving checkpoint for {model_name} - periodic save during active training")
return True
return False
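
    # Gating example (illustrative): with the max_checkpoints=5 slots full at
    # scores [120.0, 119.99, 119.98, 119.97, 119.96], a new score of 119.95 is
    # below the worst but still saved by the near-best rule
    # (119.95 >= 120.0 * 0.999); a score of 100.0 only gets the 20%
    # periodic-save fallback.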
def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
try:
if hasattr(model, 'state_dict'):
torch.save({
'model_state_dict': model.state_dict(),
'model_type': model_type,
'saved_at': datetime.now().isoformat()
}, file_path)
else:
torch.save(model, file_path)
return True
except Exception as e:
logger.error(f"Error saving model file {file_path}: {e}")
return False
def _rotate_checkpoints(self, model_name: str):
checkpoint_list = self.checkpoints[model_name]
if len(checkpoint_list) <= self.max_checkpoints:
return
checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)
to_remove = checkpoint_list[self.max_checkpoints:]
self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints]
for checkpoint in to_remove:
try:
file_path = Path(checkpoint.file_path)
if file_path.exists():
file_path.unlink()
logger.debug(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
except Exception as e:
logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")
def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
try:
if not self.enable_wandb or wandb.run is None:
return None
artifact_name = f"{metadata.model_name}_checkpoint"
artifact = wandb.Artifact(artifact_name, type="model")
artifact.add_file(str(file_path))
wandb.log_artifact(artifact)
return artifact_name
except Exception as e:
logger.error(f"Error uploading to W&B: {e}")
return None
def _load_metadata(self):
try:
if self.metadata_file.exists():
with open(self.metadata_file, 'r') as f:
data = json.load(f)
for model_name, checkpoint_list in data.items():
self.checkpoints[model_name] = [
CheckpointMetadata.from_dict(cp_data)
for cp_data in checkpoint_list
]
logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
except Exception as e:
logger.error(f"Error loading checkpoint metadata: {e}")
def _save_metadata(self):
try:
data = {}
for model_name, checkpoint_list in self.checkpoints.items():
data[model_name] = [cp.to_dict() for cp in checkpoint_list]
with open(self.metadata_file, 'w') as f:
json.dump(data, f, indent=2)
except Exception as e:
logger.error(f"Error saving checkpoint metadata: {e}")
def get_checkpoint_stats(self):
"""Get statistics about managed checkpoints"""
stats = {
'total_models': len(self.checkpoints),
'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
'total_size_mb': 0.0,
'models': {}
}
for model_name, checkpoint_list in self.checkpoints.items():
if not checkpoint_list:
continue
model_size = sum(cp.file_size_mb for cp in checkpoint_list)
best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)
stats['models'][model_name] = {
'checkpoint_count': len(checkpoint_list),
'total_size_mb': model_size,
'best_performance': best_checkpoint.performance_score,
'best_checkpoint_id': best_checkpoint.checkpoint_id,
'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
            }
            stats['total_size_mb'] += model_size
        return stats

    def _find_legacy_model(self, model_name: str) -> Optional[Path]:
        """Find legacy saved models based on model name patterns"""
        base_dir = Path(self.base_dir)
        # Define model name mappings and patterns for legacy files
        legacy_patterns = {
'dqn_agent': [
'dqn_agent_best_policy.pt',
'enhanced_dqn_best_policy.pt',
'improved_dqn_agent_best_policy.pt',
'dqn_agent_final_policy.pt'
],
'enhanced_cnn': [
'cnn_model_best.pt',
'optimized_short_term_model_best.pt',
'optimized_short_term_model_realtime_best.pt',
'optimized_short_term_model_ticks_best.pt'
],
'extrema_trainer': [
'supervised_model_best.pt'
],
'cob_rl': [
'best_rl_model.pth_policy.pt',
'rl_agent_best_policy.pt'
],
'decision': [
# Decision models might be in subdirectories, but let's check main dir too
'decision_best.pt',
'decision_model_best.pt',
# Check for transformer models which might be used as decision models
'enhanced_dqn_best_policy.pt',
'improved_dqn_agent_best_policy.pt'
]
}
# Get patterns for this model name
patterns = legacy_patterns.get(model_name, [])
# Also try generic patterns based on model name
patterns.extend([
f'{model_name}_best.pt',
f'{model_name}_best_policy.pt',
f'{model_name}_final.pt',
f'{model_name}_final_policy.pt'
])
# Search for the model files
for pattern in patterns:
candidate_path = base_dir / pattern
if candidate_path.exists():
logger.debug(f"Found legacy model file: {candidate_path}")
return candidate_path
# Also check subdirectories
for subdir in base_dir.iterdir():
if subdir.is_dir() and subdir.name == model_name:
for pattern in patterns:
candidate_path = subdir / pattern
if candidate_path.exists():
logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
return candidate_path
return None
def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
"""Create metadata for legacy model files using only actual file information"""
try:
file_size_mb = file_path.stat().st_size / (1024 * 1024)
created_time = datetime.fromtimestamp(file_path.stat().st_mtime)
# NO SYNTHETIC DATA - use only actual file information
return CheckpointMetadata(
checkpoint_id=f"legacy_{model_name}_{int(created_time.timestamp())}",
model_name=model_name,
model_type=model_name,
file_path=str(file_path),
created_at=created_time,
file_size_mb=file_size_mb,
performance_score=0.0, # Unknown performance - use 0, not synthetic values
accuracy=None,
loss=None,
val_accuracy=None,
val_loss=None,
reward=None,
pnl=None,
epoch=None,
training_time_hours=None,
total_parameters=None,
wandb_run_id=None,
wandb_artifact_name=None
)
        except Exception as e:
            logger.error(f"Error creating legacy metadata for {model_name}: {e}")
            # Return a basic metadata with minimal info - NO SYNTHETIC VALUES
            return CheckpointMetadata(
                checkpoint_id=f"legacy_{model_name}",
                model_name=model_name,
                model_type=model_name,
                file_path=str(file_path),
                created_at=datetime.now(),
                file_size_mb=0.0,
                performance_score=0.0  # Unknown - use 0, not synthetic
            )

_checkpoint_manager = None

def get_checkpoint_manager() -> CheckpointManager:
    global _checkpoint_manager
    if _checkpoint_manager is None:
        _checkpoint_manager = CheckpointManager()
    return _checkpoint_manager

def save_checkpoint(model, model_name: str, model_type: str,
                    performance_metrics: Dict[str, float],
                    training_metadata: Optional[Dict[str, Any]] = None,
                    force_save: bool = False) -> Optional[CheckpointMetadata]:
    return get_checkpoint_manager().save_checkpoint(
        model, model_name, model_type, performance_metrics, training_metadata, force_save
    )

def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
    return get_checkpoint_manager().load_best_checkpoint(model_name)

# --- New CheckpointManager implementation (added by this commit) ---

class CheckpointManager:
    """
    Manages model checkpoints with performance-based optimization

    This class:
    1. Saves checkpoints with metadata
    2. Loads the best checkpoint based on performance metrics
    3. Cleans up old or underperforming checkpoints
    """

    def __init__(self, checkpoint_dir: str, max_checkpoints: int = 10, metric_name: str = "accuracy"):
        """
        Initialize the checkpoint manager

        Args:
            checkpoint_dir: Directory to store checkpoints
            max_checkpoints: Maximum number of checkpoints to keep
            metric_name: Metric to use for ranking checkpoints
        """
        self.checkpoint_dir = checkpoint_dir
        self.max_checkpoints = max_checkpoints
        self.metric_name = metric_name

        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        logger.info(f"CheckpointManager initialized with checkpoint_dir: {checkpoint_dir}")

    def save_checkpoint(self, model_name: str, model_path: str, metrics: Dict[str, float], metadata: Dict[str, Any] = None) -> str:
        """
        Save a checkpoint with metadata

        Args:
            model_name: Name of the model
            model_path: Path to the model file
            metrics: Performance metrics
            metadata: Additional metadata

        Returns:
            str: Path to the saved checkpoint
        """
        try:
            # Create timestamp
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # Create checkpoint directory
            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            os.makedirs(checkpoint_dir, exist_ok=True)

            # Create checkpoint path
            checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_{timestamp}")

            # Copy model file to checkpoint path
            shutil.copy2(model_path, f"{checkpoint_path}.pt")

            # Create metadata
            checkpoint_metadata = {
                'model_name': model_name,
                'timestamp': timestamp,
                'metrics': metrics,
                'metadata': metadata or {}
            }

            # Save metadata
            with open(f"{checkpoint_path}_metadata.json", 'w') as f:
                json.dump(checkpoint_metadata, f, indent=2)

            logger.info(f"Saved checkpoint to {checkpoint_path}")

            # Clean up old checkpoints
            self._cleanup_checkpoints(model_name)

            return checkpoint_path
        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
            return ""

    def load_best_checkpoint(self, model_name: str) -> Tuple[str, Dict[str, Any]]:
        """
        Load the best checkpoint based on performance metrics

        Args:
            model_name: Name of the model

        Returns:
            Tuple[str, Dict[str, Any]]: Path to the best checkpoint and its metadata
        """
        try:
            # Find all checkpoint metadata files
            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
            metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))

            if not metadata_files:
                logger.info(f"No checkpoints found for {model_name}")
                return "", {}

            # Load metadata for each checkpoint
            checkpoints = []
            for metadata_file in metadata_files:
                try:
                    with open(metadata_file, 'r') as f:
                        metadata = json.load(f)

                    # Get checkpoint path (remove _metadata.json)
                    checkpoint_path = metadata_file[:-14]
# Check if model file exists
if not os.path.exists(f"{checkpoint_path}.pt"):
logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
continue
checkpoints.append((checkpoint_path, metadata))
except Exception as e:
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
if not checkpoints:
logger.info(f"No valid checkpoints found for {model_name}")
return "", {}
# Sort by metric (highest first)
checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)
# Return best checkpoint
best_checkpoint_path = checkpoints[0][0]
best_checkpoint_metadata = checkpoints[0][1]
logger.info(f"Best checkpoint for {model_name}: {best_checkpoint_path}")
return best_checkpoint_path, best_checkpoint_metadata
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return "", {}
def _cleanup_checkpoints(self, model_name: str) -> int:
"""
Clean up old or underperforming checkpoints
Args:
model_name: Name of the model
Returns:
int: Number of checkpoints deleted
"""
try:
# Find all checkpoint metadata files
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
if not metadata_files or len(metadata_files) <= self.max_checkpoints:
return 0
# Load metadata for each checkpoint
checkpoints = []
for metadata_file in metadata_files:
try:
with open(metadata_file, 'r') as f:
metadata = json.load(f)
# Get checkpoint path (remove _metadata.json)
checkpoint_path = metadata_file[:-14]
checkpoints.append((checkpoint_path, metadata))
except Exception as e:
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
# Sort by metric (highest first)
checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)
# Keep only the best checkpoints
checkpoints_to_delete = checkpoints[self.max_checkpoints:]
# Delete checkpoints
deleted_count = 0
for checkpoint_path, _ in checkpoints_to_delete:
try:
# Delete model file
if os.path.exists(f"{checkpoint_path}.pt"):
os.remove(f"{checkpoint_path}.pt")
# Delete metadata file
if os.path.exists(f"{checkpoint_path}_metadata.json"):
os.remove(f"{checkpoint_path}_metadata.json")
deleted_count += 1
except Exception as e:
logger.error(f"Error deleting checkpoint {checkpoint_path}: {e}")
logger.info(f"Deleted {deleted_count} old checkpoints for {model_name}")
return deleted_count
except Exception as e:
logger.error(f"Error cleaning up checkpoints: {e}")
return 0
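
    # Cleanup example (illustrative): with max_checkpoints=10 and 12 checkpoints
    # on disk, the two ranked lowest by metrics[self.metric_name] lose both
    # their .pt and _metadata.json files, and deleted_count == 2 is returned.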
def get_all_checkpoints(self, model_name: str) -> List[Tuple[str, Dict[str, Any]]]:
"""
Get all checkpoints for a model
Args:
model_name: Name of the model
Returns:
List[Tuple[str, Dict[str, Any]]]: List of checkpoint paths and metadata
"""
try:
# Find all checkpoint metadata files
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
if not metadata_files:
return []
# Load metadata for each checkpoint
checkpoints = []
for metadata_file in metadata_files:
try:
with open(metadata_file, 'r') as f:
metadata = json.load(f)
# Get checkpoint path (remove _metadata.json)
checkpoint_path = metadata_file[:-14]
# Check if model file exists
if not os.path.exists(f"{checkpoint_path}.pt"):
logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
continue
checkpoints.append((checkpoint_path, metadata))
except Exception as e:
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
# Sort by timestamp (newest first)
checkpoints.sort(key=lambda x: x[1].get('timestamp', ''), reverse=True)
return checkpoints
except Exception as e:
logger.error(f"Error getting all checkpoints: {e}")
return []
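
    # Illustrative end-to-end usage of the new manager (paths and metric values
    # are hypothetical, not part of this commit):
    #
    #   mgr = CheckpointManager("models/checkpoints", max_checkpoints=10, metric_name="accuracy")
    #   ckpt_path = mgr.save_checkpoint(
    #       model_name="enhanced_cnn_v1",
    #       model_path="models/tmp/enhanced_cnn_v1.pt",
    #       metrics={"accuracy": 0.71, "loss": 0.42},
    #   )
    #   best_path, best_meta = mgr.load_best_checkpoint("enhanced_cnn_v1")
    #   for path, meta in mgr.get_all_checkpoints("enhanced_cnn_v1"):
    #       print(path, meta.get("metrics", {}))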

View File

@@ -9,7 +9,7 @@ from datetime import datetime
 from typing import Dict, Any, Optional
 from pathlib import Path
-from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint
+from .checkpoint_manager import get_checkpoint_manager, load_best_checkpoint
 logger = logging.getLogger(__name__)
@@ -78,7 +78,7 @@ class TrainingIntegration:
             except Exception as e:
                 logger.warning(f"Error logging to W&B: {e}")
-            metadata = save_checkpoint(
+            metadata = self.checkpoint_manager.save_checkpoint(
                 model=cnn_model,
                 model_name=model_name,
                 model_type='cnn',
@@ -137,7 +137,7 @@ class TrainingIntegration:
             except Exception as e:
                 logger.warning(f"Error logging to W&B: {e}")
-            metadata = save_checkpoint(
+            metadata = self.checkpoint_manager.save_checkpoint(
                 model=rl_agent,
                 model_name=model_name,
                 model_type='rl',
@@ -158,7 +158,7 @@ class TrainingIntegration:
     def load_best_model(self, model_name: str, model_class=None):
         try:
-            result = load_best_checkpoint(model_name)
+            result = self.checkpoint_manager.load_best_checkpoint(model_name)
             if not result:
                 logger.warning(f"No checkpoint found for model: {model_name}")
                 return None
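
For these hunks to work, TrainingIntegration must already hold a checkpoint_manager attribute; a minimal sketch of the assumed wiring (the constructor body and directory argument are assumptions, not shown in this diff):

    class TrainingIntegration:
        def __init__(self):
            # Assumed: obtain the shared manager; the kwarg mirrors its usage
            # earlier in this commit.
            self.checkpoint_manager = get_checkpoint_manager(checkpoint_dir="models/checkpoints")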