checkpoint manager

Dobromir Popov
2025-07-23 21:40:04 +03:00
parent bab39fa68f
commit 45a62443a0
9 changed files with 1587 additions and 709 deletions


@ -0,0 +1,276 @@
"""
CNN Dashboard Integration
This module integrates the EnhancedCNN model with the dashboard, providing real-time
training and visualization of model predictions.
"""
import logging
import threading
import time
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
import os
import json
from .data_models import BaseDataInput, ModelOutput, create_model_output
from utils.training_integration import get_training_integration
logger = logging.getLogger(__name__)
class CNNDashboardIntegration:
"""
Integrates the EnhancedCNN model with the dashboard
This class:
1. Loads and initializes the CNN model
2. Processes real-time data for model inference
3. Manages continuous training of the model
4. Provides visualization data for the dashboard
"""
def __init__(self, data_provider=None, checkpoint_dir: str = "models/enhanced_cnn"):
"""
Initialize the CNN dashboard integration
Args:
data_provider: Data provider instance
checkpoint_dir: Directory to save checkpoints to
"""
self.data_provider = data_provider
self.checkpoint_dir = checkpoint_dir
self.cnn_adapter = None
self.training_thread = None
self.training_active = False
self.training_interval = 60 # Train every 60 seconds
self.training_samples = []
self.max_training_samples = 1000
self.last_training_time = 0
self.last_predictions = {}
self.performance_metrics = {}
self.model_name = "enhanced_cnn_v1"
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
# Initialize CNN adapter
self._initialize_cnn_adapter()
logger.info(f"CNNDashboardIntegration initialized with checkpoint_dir: {checkpoint_dir}")
def _initialize_cnn_adapter(self):
"""Initialize the CNN adapter"""
try:
# Import here to avoid circular imports
from .enhanced_cnn_adapter import EnhancedCNNAdapter
# Create CNN adapter
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=self.checkpoint_dir)
# Load best checkpoint if available
self.cnn_adapter.load_best_checkpoint()
logger.info("CNN adapter initialized successfully")
except Exception as e:
logger.error(f"Error initializing CNN adapter: {e}")
self.cnn_adapter = None
def start_training_thread(self):
"""Start the training thread"""
if self.training_thread is not None and self.training_thread.is_alive():
logger.info("Training thread already running")
return
self.training_active = True
self.training_thread = threading.Thread(target=self._training_loop, daemon=True)
self.training_thread.start()
logger.info("CNN training thread started")
def stop_training_thread(self):
"""Stop the training thread"""
self.training_active = False
if self.training_thread is not None:
self.training_thread.join(timeout=5)
self.training_thread = None
logger.info("CNN training thread stopped")
def _training_loop(self):
"""Training loop for continuous model training"""
while self.training_active:
try:
# Check if it's time to train
current_time = time.time()
if current_time - self.last_training_time >= self.training_interval and len(self.training_samples) >= 10:
logger.info(f"Training CNN model with {len(self.training_samples)} samples")
# Train model
if self.cnn_adapter is not None:
metrics = self.cnn_adapter.train(epochs=1)
# Update performance metrics
self.performance_metrics = {
'loss': metrics.get('loss', 0.0),
'accuracy': metrics.get('accuracy', 0.0),
'samples': metrics.get('samples', 0),
'last_training': datetime.now().isoformat()
}
# Log training metrics
logger.info(f"CNN training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")
# Update last training time
self.last_training_time = current_time
# Sleep to avoid high CPU usage
time.sleep(1)
except Exception as e:
logger.error(f"Error in CNN training loop: {e}")
time.sleep(5) # Sleep longer on error
def process_data(self, symbol: str, base_data: BaseDataInput) -> Optional[ModelOutput]:
"""
Process data for model inference and training
Args:
symbol: Trading symbol
base_data: Standardized input data
Returns:
Optional[ModelOutput]: Model output, or None if processing failed
"""
try:
if self.cnn_adapter is None:
logger.warning("CNN adapter not initialized")
return None
# Make prediction
model_output = self.cnn_adapter.predict(base_data)
# Store prediction
self.last_predictions[symbol] = model_output
# Store model output in data provider
if self.data_provider is not None:
self.data_provider.store_model_output(model_output)
return model_output
except Exception as e:
logger.error(f"Error processing data for CNN model: {e}")
return None
def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
"""
Add a training sample
Args:
base_data: Standardized input data
actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
reward: Reward received for the action
"""
try:
if self.cnn_adapter is None:
logger.warning("CNN adapter not initialized")
return
# Add training sample to CNN adapter
self.cnn_adapter.add_training_sample(base_data, actual_action, reward)
# Add to local training samples
self.training_samples.append((base_data.symbol, actual_action, reward))
# Limit training samples
if len(self.training_samples) > self.max_training_samples:
self.training_samples = self.training_samples[-self.max_training_samples:]
logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")
except Exception as e:
logger.error(f"Error adding training sample: {e}")
def get_performance_metrics(self) -> Dict[str, Any]:
"""
Get performance metrics
Returns:
Dict[str, Any]: Performance metrics
"""
metrics = self.performance_metrics.copy()
# Add additional metrics
metrics['training_samples'] = len(self.training_samples)
metrics['model_name'] = self.model_name
# Add last prediction metrics
if self.last_predictions:
for symbol, prediction in self.last_predictions.items():
metrics[f'{symbol}_last_action'] = prediction.predictions.get('action', 'UNKNOWN')
metrics[f'{symbol}_last_confidence'] = prediction.confidence
return metrics
def get_visualization_data(self, symbol: str) -> Dict[str, Any]:
"""
Get visualization data for the dashboard
Args:
symbol: Trading symbol
Returns:
Dict[str, Any]: Visualization data
"""
data = {
'model_name': self.model_name,
'symbol': symbol,
'timestamp': datetime.now().isoformat(),
'performance_metrics': self.get_performance_metrics()
}
# Add last prediction
if symbol in self.last_predictions:
prediction = self.last_predictions[symbol]
data['last_prediction'] = {
'action': prediction.predictions.get('action', 'UNKNOWN'),
'confidence': prediction.confidence,
'timestamp': prediction.timestamp.isoformat(),
'buy_probability': prediction.predictions.get('buy_probability', 0.0),
'sell_probability': prediction.predictions.get('sell_probability', 0.0),
'hold_probability': prediction.predictions.get('hold_probability', 0.0)
}
# Add training samples summary
symbol_samples = [s for s in self.training_samples if s[0] == symbol]
data['training_samples'] = {
'total': len(symbol_samples),
'buy': len([s for s in symbol_samples if s[1] == 'BUY']),
'sell': len([s for s in symbol_samples if s[1] == 'SELL']),
'hold': len([s for s in symbol_samples if s[1] == 'HOLD']),
'avg_reward': sum(s[2] for s in symbol_samples) / len(symbol_samples) if symbol_samples else 0.0
}
return data
# Global CNN dashboard integration instance
_cnn_dashboard_integration = None
def get_cnn_dashboard_integration(data_provider=None) -> CNNDashboardIntegration:
"""
Get the global CNN dashboard integration instance
Args:
data_provider: Data provider instance
Returns:
CNNDashboardIntegration: Global CNN dashboard integration instance
"""
global _cnn_dashboard_integration
if _cnn_dashboard_integration is None:
_cnn_dashboard_integration = CNNDashboardIntegration(data_provider=data_provider)
return _cnn_dashboard_integration
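
A minimal usage sketch of the integration above, assuming the module lives at core.cnn_dashboard_integration (the path is not shown in this commit); the data-provider calls match the test scripts further down:

from core.standardized_data_provider import StandardizedDataProvider
from core.cnn_dashboard_integration import get_cnn_dashboard_integration  # module path assumed

data_provider = StandardizedDataProvider(symbols=['ETH/USDT', 'BTC/USDT'],
                                         timeframes=['1s', '1m', '1h', '1d'])
integration = get_cnn_dashboard_integration(data_provider=data_provider)
integration.start_training_thread()

# On each dashboard refresh: run inference, then pull data for the charts
base_data = data_provider.get_base_data_input('ETH/USDT')
if base_data is not None:
    integration.process_data('ETH/USDT', base_data)
    panel_data = integration.get_visualization_data('ETH/USDT')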


@ -1467,12 +1467,10 @@ class DataProvider:
# Update COB data cache for distribution
binance_symbol = symbol.replace('/', '').upper()
if binance_symbol not in self.cob_data_cache or self.cob_data_cache[binance_symbol] is None:
from collections import deque
self.cob_data_cache[binance_symbol] = deque(maxlen=300)
# Ensure the deque is properly initialized
if not isinstance(self.cob_data_cache[binance_symbol], deque):
from collections import deque
self.cob_data_cache[binance_symbol] = deque(maxlen=300)
self.cob_data_cache[binance_symbol].append({


@ -0,0 +1,430 @@
"""
Enhanced CNN Adapter for Standardized Input Format
This module provides an adapter for the EnhancedCNN model to work with the standardized
BaseDataInput format, enabling seamless integration with the multi-modal trading system.
"""
import torch
import numpy as np
import logging
import os
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any, Union
from threading import Lock
from .data_models import BaseDataInput, ModelOutput, create_model_output
from NN.models.enhanced_cnn import EnhancedCNN
logger = logging.getLogger(__name__)
class EnhancedCNNAdapter:
"""
Adapter for EnhancedCNN model to work with standardized BaseDataInput format
This adapter:
1. Converts BaseDataInput to the format expected by EnhancedCNN
2. Processes model outputs to create standardized ModelOutput
3. Manages model training with collected data
4. Handles checkpoint management
"""
def __init__(self, model_path: str = None, checkpoint_dir: str = "models/enhanced_cnn"):
"""
Initialize the EnhancedCNN adapter
Args:
model_path: Path to load model from, if None a new model is created
checkpoint_dir: Directory to save checkpoints to
"""
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model = None
self.model_path = model_path
self.checkpoint_dir = checkpoint_dir
self.training_lock = Lock()
self.training_data = []
self.max_training_samples = 10000
self.batch_size = 32
self.learning_rate = 0.0001
self.model_name = "enhanced_cnn_v1"
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
# Initialize model
self._initialize_model()
logger.info(f"EnhancedCNNAdapter initialized with device: {self.device}")
def _initialize_model(self):
"""Initialize the EnhancedCNN model"""
try:
# Calculate input shape based on BaseDataInput structure
# OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
# BTC OHLCV: 300 frames x 5 features = 1500 features
# COB: ±20 buckets x 4 metrics = 160 features
# MA: 4 timeframes x 10 buckets = 40 features
# Technical indicators: 100 features
# Last predictions: 50 features
# Total: 7850 features
input_shape = 7850
n_actions = 3 # BUY, SELL, HOLD
# Create model
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
self.model.to(self.device)
# Load model if path is provided
if self.model_path:
success = self.model.load(self.model_path)
if success:
logger.info(f"Model loaded from {self.model_path}")
else:
logger.warning(f"Failed to load model from {self.model_path}, using new model")
else:
logger.info("No model path provided, using new model")
except Exception as e:
logger.error(f"Error initializing EnhancedCNN model: {e}")
raise
def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
"""
Convert BaseDataInput to feature vector for EnhancedCNN
Args:
base_data: Standardized input data
Returns:
torch.Tensor: Feature vector for EnhancedCNN
"""
try:
# Use the get_feature_vector method from BaseDataInput
features = base_data.get_feature_vector()
# Convert to torch tensor
features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device)
return features_tensor
except Exception as e:
logger.error(f"Error converting BaseDataInput to features: {e}")
# Return empty tensor with correct shape
return torch.zeros(7850, dtype=torch.float32, device=self.device)
def predict(self, base_data: BaseDataInput) -> ModelOutput:
"""
Make a prediction using the EnhancedCNN model
Args:
base_data: Standardized input data
Returns:
ModelOutput: Standardized model output
"""
try:
# Convert BaseDataInput to features
features = self._convert_base_data_to_features(base_data)
# Ensure features has batch dimension
if features.dim() == 1:
features = features.unsqueeze(0)
# Set model to evaluation mode
self.model.eval()
# Make prediction
with torch.no_grad():
q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features)
# Get action and confidence
action_probs = torch.softmax(q_values, dim=1)
action_idx = torch.argmax(action_probs, dim=1).item()
confidence = float(action_probs[0, action_idx].item())
# Map action index to action string
actions = ['BUY', 'SELL', 'HOLD']
action = actions[action_idx]
# Create predictions dictionary
predictions = {
'action': action,
'buy_probability': float(action_probs[0, 0].item()),
'sell_probability': float(action_probs[0, 1].item()),
'hold_probability': float(action_probs[0, 2].item()),
'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist()
}
# Create hidden states dictionary
hidden_states = {
'features': features_refined.squeeze(0).cpu().numpy().tolist()
}
# Create metadata dictionary
metadata = {
'model_version': '1.0',
'timestamp': datetime.now().isoformat(),
'input_shape': list(features.shape)  # use a list, not torch.Size, so the metadata stays JSON-serializable
}
# Create ModelOutput
model_output = ModelOutput(
model_type='cnn',
model_name=self.model_name,
symbol=base_data.symbol,
timestamp=datetime.now(),
confidence=confidence,
predictions=predictions,
hidden_states=hidden_states,
metadata=metadata
)
return model_output
except Exception as e:
logger.error(f"Error making prediction with EnhancedCNN: {e}")
# Return default ModelOutput
return create_model_output(
model_type='cnn',
model_name=self.model_name,
symbol=base_data.symbol,
action='HOLD',
confidence=0.0
)
def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
"""
Add a training sample to the training data
Args:
base_data: Standardized input data
actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
reward: Reward received for the action
"""
try:
# Convert BaseDataInput to features
features = self._convert_base_data_to_features(base_data)
# Convert action to index
actions = ['BUY', 'SELL', 'HOLD']
action_idx = actions.index(actual_action)
# Add to training data
with self.training_lock:
self.training_data.append((features, action_idx, reward))
# Limit training data size
if len(self.training_data) > self.max_training_samples:
# Sort by reward (highest first) and keep top samples
self.training_data.sort(key=lambda x: x[2], reverse=True)
self.training_data = self.training_data[:self.max_training_samples]
logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")
except Exception as e:
logger.error(f"Error adding training sample: {e}")
def train(self, epochs: int = 1) -> Dict[str, float]:
"""
Train the model with collected data
Args:
epochs: Number of epochs to train for
Returns:
Dict[str, float]: Training metrics
"""
try:
with self.training_lock:
# Check if we have enough data
if len(self.training_data) < self.batch_size:
logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
# Set model to training mode
self.model.train()
# Create optimizer
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
# Training metrics
total_loss = 0.0
correct_predictions = 0
total_predictions = 0
# Train for specified number of epochs
for epoch in range(epochs):
# Shuffle training data
np.random.shuffle(self.training_data)
# Process in batches
for i in range(0, len(self.training_data), self.batch_size):
batch = self.training_data[i:i+self.batch_size]
# Skip if batch is too small
if len(batch) < 2:
continue
# Prepare batch
features = torch.stack([sample[0] for sample in batch])
actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)
# Zero gradients
optimizer.zero_grad()
# Forward pass
q_values, _, _, _, _ = self.model(features)
# Calculate loss (CrossEntropyLoss with reward weighting)
# First, apply softmax to get probabilities
probs = torch.softmax(q_values, dim=1)
# Get probability of chosen action
chosen_probs = probs[torch.arange(len(actions)), actions]
# Calculate negative log likelihood loss
nll_loss = -torch.log(chosen_probs + 1e-10)
# Weight by reward (higher reward = higher weight)
# Normalize rewards to [0, 1] range
min_reward = rewards.min()
max_reward = rewards.max()
if max_reward > min_reward:
normalized_rewards = (rewards - min_reward) / (max_reward - min_reward)
else:
normalized_rewards = torch.ones_like(rewards)
# Apply reward weighting (higher reward = higher weight)
weighted_loss = nll_loss * (normalized_rewards + 0.1) # Add small constant to avoid zero weights
# Mean loss
loss = weighted_loss.mean()
# Backward pass
loss.backward()
# Update weights
optimizer.step()
# Update metrics
total_loss += loss.item()
# Calculate accuracy
predicted_actions = torch.argmax(q_values, dim=1)
correct_predictions += (predicted_actions == actions).sum().item()
total_predictions += len(actions)
# Calculate final metrics
avg_loss = total_loss / (len(self.training_data) / self.batch_size)
accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
# Save checkpoint
self._save_checkpoint(avg_loss, accuracy)
logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}")
return {
'loss': avg_loss,
'accuracy': accuracy,
'samples': len(self.training_data)
}
except Exception as e:
logger.error(f"Error training model: {e}")
return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)}
def _save_checkpoint(self, loss: float, accuracy: float):
"""
Save model checkpoint
Args:
loss: Training loss
accuracy: Training accuracy
"""
try:
# Import checkpoint manager
from utils.checkpoint_manager import CheckpointManager
# Create checkpoint manager
checkpoint_manager = CheckpointManager(
checkpoint_dir=self.checkpoint_dir,
max_checkpoints=10,
metric_name="accuracy"
)
# Create temporary model file
temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp")
self.model.save(temp_path)
# Create metrics
metrics = {
'loss': loss,
'accuracy': accuracy,
'samples': len(self.training_data)
}
# Create metadata
metadata = {
'timestamp': datetime.now().isoformat(),
'model_name': self.model_name,
'input_shape': self.model.input_shape,
'n_actions': self.model.n_actions
}
# Save checkpoint
checkpoint_path = checkpoint_manager.save_checkpoint(
model_name=self.model_name,
model_path=f"{temp_path}.pt",
metrics=metrics,
metadata=metadata
)
# Delete temporary model file
if os.path.exists(f"{temp_path}.pt"):
os.remove(f"{temp_path}.pt")
logger.info(f"Model checkpoint saved to {checkpoint_path}")
except Exception as e:
logger.error(f"Error saving checkpoint: {e}")
def load_best_checkpoint(self):
"""Load the best checkpoint based on accuracy"""
try:
# Import checkpoint manager
from utils.checkpoint_manager import CheckpointManager
# Create checkpoint manager
checkpoint_manager = CheckpointManager(
checkpoint_dir=self.checkpoint_dir,
max_checkpoints=10,
metric_name="accuracy"
)
# Load best checkpoint
best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name) or (None, None)  # guard against a None return when no checkpoint exists
if not best_checkpoint_path:
logger.info("No checkpoints found")
return False
# Load model
success = self.model.load(best_checkpoint_path)
if success:
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")
# Log metrics
metrics = best_checkpoint_metadata.get('metrics', {})
logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
return True
else:
logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
return False
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return False
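
The loss used in train() above is a reward-weighted negative log-likelihood over the chosen actions. The same computation, isolated as a self-contained sketch for clarity (tensor shapes follow the batch layout used in train()):

import torch

def reward_weighted_nll(q_values: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor) -> torch.Tensor:
    """q_values: (batch, n_actions) raw outputs; actions: (batch,) taken-action indices; rewards: (batch,)."""
    probs = torch.softmax(q_values, dim=1)
    chosen = probs[torch.arange(len(actions)), actions]
    nll = -torch.log(chosen + 1e-10)
    # Normalize rewards to [0, 1]; fall back to uniform weights when all rewards are equal
    lo, hi = rewards.min(), rewards.max()
    weights = (rewards - lo) / (hi - lo) if hi > lo else torch.ones_like(rewards)
    return (nll * (weights + 0.1)).mean()  # small constant keeps zero-reward samples from vanishing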


@ -1,34 +1,31 @@
"""
Model Output Manager
This module provides extensible model output storage and management for the multi-modal trading system.
Supports CNN, RL, LSTM, Transformer, and future model types with cross-model feeding capabilities.
This module provides a centralized storage and management system for model outputs,
enabling cross-model feeding and evaluation.
"""
import logging
import os
import json
import pickle
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
from collections import deque, defaultdict
import logging
import time
from datetime import datetime
from typing import Dict, List, Optional, Any
from threading import Lock
from pathlib import Path
from .data_models import ModelOutput, create_model_output
from .data_models import ModelOutput
logger = logging.getLogger(__name__)
class ModelOutputManager:
"""
Extensible model output storage and management system
Centralized storage and management system for model outputs
Features:
- Standardized ModelOutput storage for all model types
- Cross-model feeding with hidden states
- Historical output tracking
- Metadata management
- Persistence and recovery
- Performance analytics
This class:
1. Stores model outputs for all models
2. Provides access to current and historical outputs
3. Handles persistence of outputs to disk
4. Supports evaluation of model performance
"""
def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000):
@ -36,75 +33,66 @@ class ModelOutputManager:
Initialize the model output manager
Args:
cache_dir: Directory for persistent storage
max_history: Maximum number of outputs to keep in memory per model
cache_dir: Directory to store model outputs
max_history: Maximum number of historical outputs to keep per model
"""
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.cache_dir = cache_dir
self.max_history = max_history
self.outputs_lock = Lock()
# In-memory storage
self.current_outputs: Dict[str, Dict[str, ModelOutput]] = defaultdict(dict) # {symbol: {model_name: ModelOutput}}
self.output_history: Dict[str, Dict[str, deque]] = defaultdict(lambda: defaultdict(lambda: deque(maxlen=max_history))) # {symbol: {model_name: deque}}
self.cross_model_states: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) # {symbol: {model_name: hidden_states}}
# Current outputs for each model and symbol
# {symbol: {model_name: ModelOutput}}
self.current_outputs: Dict[str, Dict[str, ModelOutput]] = {}
# Metadata tracking
self.model_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict) # {model_name: metadata}
self.performance_stats: Dict[str, Dict[str, Any]] = defaultdict(lambda: defaultdict(dict)) # {symbol: {model_name: stats}}
# Historical outputs for each model and symbol
# {symbol: {model_name: List[ModelOutput]}}
self.historical_outputs: Dict[str, Dict[str, List[ModelOutput]]] = {}
# Thread safety
self.storage_lock = Lock()
# Performance metrics for each model and symbol
# {symbol: {model_name: Dict[str, float]}}
self.performance_metrics: Dict[str, Dict[str, Dict[str, float]]] = {}
# Supported model types
self.supported_model_types = {
'cnn', 'rl', 'lstm', 'transformer', 'orchestrator',
'ensemble', 'hybrid', 'custom' # Extensible for future types
}
# Create cache directory if it doesn't exist
os.makedirs(cache_dir, exist_ok=True)
logger.info(f"ModelOutputManager initialized with cache dir: {self.cache_dir}")
logger.info(f"Supported model types: {self.supported_model_types}")
logger.info(f"ModelOutputManager initialized with cache_dir: {cache_dir}")
def store_output(self, model_output: ModelOutput) -> bool:
"""
Store model output with full extensibility support
Store a model output
Args:
model_output: ModelOutput from any model type
model_output: Model output to store
Returns:
bool: True if stored successfully, False otherwise
bool: True if successful, False otherwise
"""
try:
with self.storage_lock:
symbol = model_output.symbol
model_name = model_output.model_name
model_type = model_output.model_type
# Validate model type (extensible)
if model_type not in self.supported_model_types:
logger.warning(f"Unknown model type '{model_type}' - adding to supported types")
self.supported_model_types.add(model_type)
with self.outputs_lock:
# Initialize dictionaries if they don't exist
if symbol not in self.current_outputs:
self.current_outputs[symbol] = {}
if symbol not in self.historical_outputs:
self.historical_outputs[symbol] = {}
if model_name not in self.historical_outputs[symbol]:
self.historical_outputs[symbol][model_name] = []
# Store current output
self.current_outputs[symbol][model_name] = model_output
# Add to history
self.output_history[symbol][model_name].append(model_output)
# Add to historical outputs
self.historical_outputs[symbol][model_name].append(model_output)
# Store cross-model states if available
if model_output.hidden_states:
self.cross_model_states[symbol][model_name] = model_output.hidden_states
# Limit historical outputs
if len(self.historical_outputs[symbol][model_name]) > self.max_history:
self.historical_outputs[symbol][model_name] = self.historical_outputs[symbol][model_name][-self.max_history:]
# Update model metadata
self._update_model_metadata(model_name, model_type, model_output.metadata)
# Persist output to disk
self._persist_output(model_output)
# Update performance statistics
self._update_performance_stats(symbol, model_name, model_output)
# Persist to disk (async to avoid blocking)
self._persist_output_async(model_output)
logger.debug(f"Stored output from {model_name} ({model_type}) for {symbol}")
return True
except Exception as e:
@ -113,202 +101,158 @@ class ModelOutputManager:
def get_current_output(self, symbol: str, model_name: str) -> Optional[ModelOutput]:
"""
Get the current (latest) output from a specific model
Get the current output for a model and symbol
Args:
symbol: Trading symbol
model_name: Name of the model
symbol: Symbol to get output for
model_name: Model name to get output for
Returns:
ModelOutput: Latest output from the model, or None if not available
ModelOutput: Current output, or None if not available
"""
try:
return self.current_outputs.get(symbol, {}).get(model_name)
with self.outputs_lock:
if symbol in self.current_outputs and model_name in self.current_outputs[symbol]:
return self.current_outputs[symbol][model_name]
return None
except Exception as e:
logger.error(f"Error getting current output for {model_name}: {e}")
logger.error(f"Error getting current output: {e}")
return None
def get_all_current_outputs(self, symbol: str) -> Dict[str, ModelOutput]:
"""
Get all current outputs for a symbol (for cross-model feeding)
Get all current outputs for a symbol
Args:
symbol: Trading symbol
symbol: Symbol to get outputs for
Returns:
Dict[str, ModelOutput]: Dictionary of current outputs by model name
Dict[str, ModelOutput]: Dictionary of model name to output
"""
try:
return dict(self.current_outputs.get(symbol, {}))
except Exception as e:
logger.error(f"Error getting all current outputs for {symbol}: {e}")
with self.outputs_lock:
if symbol in self.current_outputs:
return self.current_outputs[symbol].copy()
return {}
def get_output_history(self, symbol: str, model_name: str, count: int = 10) -> List[ModelOutput]:
"""
Get historical outputs from a model
Args:
symbol: Trading symbol
model_name: Name of the model
count: Number of historical outputs to retrieve
Returns:
List[ModelOutput]: List of historical outputs (most recent first)
"""
try:
history = self.output_history.get(symbol, {}).get(model_name, deque())
return list(history)[-count:][::-1] # Most recent first
except Exception as e:
logger.error(f"Error getting output history for {model_name}: {e}")
return []
def get_cross_model_states(self, symbol: str, requesting_model: str) -> Dict[str, Dict[str, Any]]:
"""
Get hidden states from other models for cross-model feeding
Args:
symbol: Trading symbol
requesting_model: Name of the model requesting the states
Returns:
Dict[str, Dict[str, Any]]: Hidden states from other models
"""
try:
all_states = self.cross_model_states.get(symbol, {})
# Return states from all models except the requesting one
return {model_name: states for model_name, states in all_states.items()
if model_name != requesting_model}
except Exception as e:
logger.error(f"Error getting cross-model states for {requesting_model}: {e}")
logger.error(f"Error getting all current outputs: {e}")
return {}
def get_model_types_active(self, symbol: str) -> List[str]:
def get_historical_outputs(self, symbol: str, model_name: str, limit: int = None) -> List[ModelOutput]:
"""
Get list of active model types for a symbol
Get historical outputs for a model and symbol
Args:
symbol: Trading symbol
symbol: Symbol to get outputs for
model_name: Model name to get outputs for
limit: Maximum number of outputs to return, None for all
Returns:
List[str]: List of active model types
List[ModelOutput]: List of historical outputs
"""
try:
current_outputs = self.current_outputs.get(symbol, {})
return [output.model_type for output in current_outputs.values()]
except Exception as e:
logger.error(f"Error getting active model types for {symbol}: {e}")
with self.outputs_lock:
if symbol in self.historical_outputs and model_name in self.historical_outputs[symbol]:
outputs = self.historical_outputs[symbol][model_name]
if limit is not None:
outputs = outputs[-limit:]
return outputs.copy()
return []
def get_consensus_prediction(self, symbol: str, confidence_threshold: float = 0.5) -> Optional[Dict[str, Any]]:
except Exception as e:
logger.error(f"Error getting historical outputs: {e}")
return []
def evaluate_model_performance(self, symbol: str, model_name: str) -> Dict[str, float]:
"""
Get consensus prediction from all active models
Evaluate model performance based on historical outputs
Args:
symbol: Trading symbol
confidence_threshold: Minimum confidence threshold for inclusion
symbol: Symbol to evaluate
model_name: Model name to evaluate
Returns:
Dict containing consensus prediction or None
Dict[str, float]: Performance metrics
"""
try:
current_outputs = self.current_outputs.get(symbol, {})
if not current_outputs:
return None
# Get historical outputs
outputs = self.get_historical_outputs(symbol, model_name)
# Filter by confidence threshold
high_confidence_outputs = [
output for output in current_outputs.values()
if output.confidence >= confidence_threshold
]
if not outputs:
return {'accuracy': 0.0, 'confidence': 0.0, 'samples': 0}
if not high_confidence_outputs:
return None
# Calculate metrics
total_outputs = len(outputs)
total_confidence = sum(output.confidence for output in outputs)
avg_confidence = total_confidence / total_outputs if total_outputs > 0 else 0.0
# Calculate consensus
buy_votes = sum(1 for output in high_confidence_outputs
if output.predictions.get('action') == 'BUY')
sell_votes = sum(1 for output in high_confidence_outputs
if output.predictions.get('action') == 'SELL')
hold_votes = sum(1 for output in high_confidence_outputs
if output.predictions.get('action') == 'HOLD')
# For now, we don't have ground truth to calculate accuracy
# In the future, we can add this by comparing predictions to actual market movements
total_votes = len(high_confidence_outputs)
avg_confidence = sum(output.confidence for output in high_confidence_outputs) / total_votes
# Determine consensus action
if buy_votes > sell_votes and buy_votes > hold_votes:
consensus_action = 'BUY'
elif sell_votes > buy_votes and sell_votes > hold_votes:
consensus_action = 'SELL'
else:
consensus_action = 'HOLD'
return {
'action': consensus_action,
metrics = {
'confidence': avg_confidence,
'votes': {'BUY': buy_votes, 'SELL': sell_votes, 'HOLD': hold_votes},
'total_models': total_votes,
'model_types': [output.model_type for output in high_confidence_outputs]
'samples': total_outputs,
'last_update': datetime.now().isoformat()
}
except Exception as e:
logger.error(f"Error calculating consensus prediction for {symbol}: {e}")
return None
# Store metrics
with self.outputs_lock:
if symbol not in self.performance_metrics:
self.performance_metrics[symbol] = {}
self.performance_metrics[symbol][model_name] = metrics
def _update_model_metadata(self, model_name: str, model_type: str, metadata: Dict[str, Any]):
"""Update metadata for a model"""
try:
if model_name not in self.model_metadata:
self.model_metadata[model_name] = {
'model_type': model_type,
'first_seen': datetime.now(),
'total_predictions': 0,
'custom_metadata': {}
}
self.model_metadata[model_name]['total_predictions'] += 1
self.model_metadata[model_name]['last_seen'] = datetime.now()
# Merge custom metadata
if metadata:
self.model_metadata[model_name]['custom_metadata'].update(metadata)
return metrics
except Exception as e:
logger.error(f"Error updating model metadata: {e}")
logger.error(f"Error evaluating model performance: {e}")
return {'error': str(e)}
def _update_performance_stats(self, symbol: str, model_name: str, model_output: ModelOutput):
"""Update performance statistics for a model"""
def get_performance_metrics(self, symbol: str, model_name: str) -> Dict[str, float]:
"""
Get performance metrics for a model and symbol
Args:
symbol: Symbol to get metrics for
model_name: Model name to get metrics for
Returns:
Dict[str, float]: Performance metrics
"""
try:
stats = self.performance_stats[symbol][model_name]
with self.outputs_lock:
if symbol in self.performance_metrics and model_name in self.performance_metrics[symbol]:
return self.performance_metrics[symbol][model_name].copy()
if 'prediction_count' not in stats:
stats['prediction_count'] = 0
stats['confidence_sum'] = 0.0
stats['action_counts'] = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
stats['first_prediction'] = model_output.timestamp
stats['prediction_count'] += 1
stats['confidence_sum'] += model_output.confidence
stats['avg_confidence'] = stats['confidence_sum'] / stats['prediction_count']
stats['last_prediction'] = model_output.timestamp
action = model_output.predictions.get('action', 'HOLD')
if action in stats['action_counts']:
stats['action_counts'][action] += 1
# If no metrics are available, calculate them
return self.evaluate_model_performance(symbol, model_name)
except Exception as e:
logger.error(f"Error updating performance stats: {e}")
logger.error(f"Error getting performance metrics: {e}")
return {'error': str(e)}
def _persist_output_async(self, model_output: ModelOutput):
"""Persist model output to disk (simplified version)"""
def _persist_output(self, model_output: ModelOutput) -> bool:
"""
Persist a model output to disk
Args:
model_output: Model output to persist
Returns:
bool: True if successful, False otherwise
"""
try:
# Create filename based on model and timestamp
timestamp_str = model_output.timestamp.strftime("%Y%m%d_%H%M%S")
filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp_str}.json"
filepath = self.cache_dir / filename
# Create directory if it doesn't exist
symbol_dir = os.path.join(self.cache_dir, model_output.symbol.replace('/', '_'))
os.makedirs(symbol_dir, exist_ok=True)
# Convert to JSON-serializable format
# Create filename with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp}.json"
filepath = os.path.join(self.cache_dir, filename)
# Convert ModelOutput to dictionary
output_dict = {
'model_type': model_output.model_type,
'model_name': model_output.model_name,
@ -319,77 +263,120 @@ class ModelOutputManager:
'metadata': model_output.metadata
}
# Save to file (in a real implementation, this would be async)
# Don't store hidden states in file (too large)
# Write to file
with open(filepath, 'w') as f:
json.dump(output_dict, f, indent=2)
return True
except Exception as e:
logger.error(f"Error persisting model output: {e}")
return False
def get_performance_summary(self, symbol: str) -> Dict[str, Any]:
def load_outputs_from_disk(self, symbol: str = None, model_name: str = None) -> int:
"""
Get performance summary for all models for a symbol
Load model outputs from disk
Args:
symbol: Trading symbol
symbol: Symbol to load outputs for, None for all
model_name: Model name to load outputs for, None for all
Returns:
Dict containing performance summary
int: Number of outputs loaded
"""
try:
summary = {
'symbol': symbol,
'active_models': len(self.current_outputs.get(symbol, {})),
'model_stats': {}
}
# Find all output files
import glob
for model_name, stats in self.performance_stats.get(symbol, {}).items():
summary['model_stats'][model_name] = {
'predictions': stats.get('prediction_count', 0),
'avg_confidence': round(stats.get('avg_confidence', 0.0), 3),
'action_distribution': stats.get('action_counts', {}),
'model_type': self.model_metadata.get(model_name, {}).get('model_type', 'unknown')
}
if symbol and model_name:
pattern = os.path.join(self.cache_dir, f"{model_name}_{symbol.replace('/', '_')}*.json")
elif symbol:
pattern = os.path.join(self.cache_dir, f"*_{symbol.replace('/', '_')}*.json")
elif model_name:
pattern = os.path.join(self.cache_dir, f"{model_name}_*.json")
else:
pattern = os.path.join(self.cache_dir, "*.json")
return summary
output_files = glob.glob(pattern)
if not output_files:
logger.info(f"No output files found for pattern: {pattern}")
return 0
# Load each file
loaded_count = 0
for filepath in output_files:
try:
with open(filepath, 'r') as f:
output_dict = json.load(f)
# Create ModelOutput
model_output = ModelOutput(
model_type=output_dict['model_type'],
model_name=output_dict['model_name'],
symbol=output_dict['symbol'],
timestamp=datetime.fromisoformat(output_dict['timestamp']),
confidence=output_dict['confidence'],
predictions=output_dict['predictions'],
hidden_states={}, # Don't load hidden states from disk
metadata=output_dict.get('metadata', {})
)
# Store output
self.store_output(model_output)
loaded_count += 1
except Exception as e:
logger.error(f"Error getting performance summary: {e}")
return {'symbol': symbol, 'error': str(e)}
logger.error(f"Error loading output file {filepath}: {e}")
def cleanup_old_outputs(self, max_age_hours: int = 24):
logger.info(f"Loaded {loaded_count} model outputs from disk")
return loaded_count
except Exception as e:
logger.error(f"Error loading outputs from disk: {e}")
return 0
def cleanup_old_outputs(self, max_age_days: int = 30) -> int:
"""
Clean up old outputs to manage memory usage
Clean up old output files
Args:
max_age_hours: Maximum age of outputs to keep in hours
max_age_days: Maximum age of files to keep in days
Returns:
int: Number of files deleted
"""
try:
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
# Find all output files
import glob
output_files = glob.glob(os.path.join(self.cache_dir, "*.json"))
with self.storage_lock:
for symbol in self.output_history:
for model_name in self.output_history[symbol]:
history = self.output_history[symbol][model_name]
# Remove old outputs
while history and history[0].timestamp < cutoff_time:
history.popleft()
if not output_files:
return 0
logger.info(f"Cleaned up outputs older than {max_age_hours} hours")
# Calculate cutoff time
cutoff_time = time.time() - (max_age_days * 24 * 60 * 60)
# Delete old files
deleted_count = 0
for filepath in output_files:
try:
# Get file modification time
mtime = os.path.getmtime(filepath)
# Delete if older than cutoff
if mtime < cutoff_time:
os.remove(filepath)
deleted_count += 1
except Exception as e:
logger.error(f"Error deleting file {filepath}: {e}")
logger.info(f"Deleted {deleted_count} old model output files")
return deleted_count
except Exception as e:
logger.error(f"Error cleaning up old outputs: {e}")
def add_custom_model_type(self, model_type: str):
"""
Add support for a new custom model type
Args:
model_type: Name of the new model type
"""
self.supported_model_types.add(model_type)
logger.info(f"Added support for custom model type: {model_type}")
def get_supported_model_types(self) -> List[str]:
"""Get list of all supported model types"""
return list(self.supported_model_types)
return 0
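
Typical usage of the rewritten manager, assuming the module path core.model_output_manager (not shown in the diff); every method below appears in the new implementation:

from core.model_output_manager import ModelOutputManager  # module path assumed
from core.data_models import create_model_output

manager = ModelOutputManager(cache_dir="cache/model_outputs", max_history=1000)

# Store an output produced by any adapter
output = create_model_output(model_type='cnn', model_name='enhanced_cnn_v1',
                             symbol='ETH/USDT', action='BUY', confidence=0.62)
manager.store_output(output)

# Cross-model feeding and evaluation
current = manager.get_all_current_outputs('ETH/USDT')
history = manager.get_historical_outputs('ETH/USDT', 'enhanced_cnn_v1', limit=50)
metrics = manager.evaluate_model_performance('ETH/USDT', 'enhanced_cnn_v1')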


@ -0,0 +1,155 @@
"""
Test Continuous CNN Training
This script demonstrates how the CNN model can be trained with each new inference result
using collected data, implementing a continuous learning loop.
"""
import logging
import time
from datetime import datetime
import random
import os
from core.standardized_data_provider import StandardizedDataProvider
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import create_model_output
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def simulate_market_feedback(action, symbol):
"""
Simulate market feedback for a given action
In a real system, this would be replaced with actual market performance data
Args:
action: Trading action ('BUY', 'SELL', 'HOLD')
symbol: Trading symbol
Returns:
tuple: (actual_action, reward)
"""
# Simulate market movement (random for demonstration)
market_direction = random.choice(['up', 'down', 'sideways'])
# Determine actual best action based on market direction
if market_direction == 'up':
best_action = 'BUY'
elif market_direction == 'down':
best_action = 'SELL'
else:
best_action = 'HOLD'
# Calculate reward based on whether the action matched the best action
if action == best_action:
reward = random.uniform(0.01, 0.1) # Positive reward for correct action
else:
reward = random.uniform(-0.1, -0.01) # Negative reward for incorrect action
logger.info(f"Market went {market_direction}, best action was {best_action}, model chose {action}, reward: {reward:.4f}")
return best_action, reward
def test_continuous_training():
"""Test continuous training of the CNN model with new inference results"""
try:
# Initialize data provider
symbols = ['ETH/USDT', 'BTC/USDT']
timeframes = ['1s', '1m', '1h', '1d']
data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
# Initialize CNN adapter
checkpoint_dir = "models/enhanced_cnn"
os.makedirs(checkpoint_dir, exist_ok=True)
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir)
# Load best checkpoint if available
cnn_adapter.load_best_checkpoint()
# Continuous learning loop
num_iterations = 10
training_frequency = 3 # Train every N iterations
samples_collected = 0
logger.info(f"Starting continuous learning loop with {num_iterations} iterations")
for i in range(num_iterations):
logger.info(f"\nIteration {i+1}/{num_iterations}")
# Get standardized input data
symbol = random.choice(symbols)
logger.info(f"Getting data for {symbol}...")
base_data = data_provider.get_base_data_input(symbol)
if base_data is None:
logger.warning(f"Failed to get base data input for {symbol}, skipping iteration")
continue
# Make prediction
logger.info(f"Making prediction for {symbol}...")
model_output = cnn_adapter.predict(base_data)
# Log prediction
action = model_output.predictions['action']
confidence = model_output.confidence
logger.info(f"Prediction: {action} with confidence {confidence:.4f}")
# Store model output
data_provider.store_model_output(model_output)
# Simulate market feedback
best_action, reward = simulate_market_feedback(action, symbol)
# Add training sample
logger.info(f"Adding training sample: action={best_action}, reward={reward:.4f}")
cnn_adapter.add_training_sample(base_data, best_action, reward)
samples_collected += 1
# Train model periodically
if (i + 1) % training_frequency == 0 and samples_collected >= 3:
logger.info(f"Training model with {samples_collected} samples...")
metrics = cnn_adapter.train(epochs=1)
# Log training metrics
logger.info(f"Training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")
# Simulate time passing
time.sleep(1)
logger.info("\nContinuous learning loop completed")
# Final evaluation
logger.info("Performing final evaluation...")
# Get data for evaluation
symbol = 'ETH/USDT'
base_data = data_provider.get_base_data_input(symbol)
if base_data is not None:
# Make prediction
model_output = cnn_adapter.predict(base_data)
# Log prediction
action = model_output.predictions['action']
confidence = model_output.confidence
logger.info(f"Final prediction for {symbol}: {action} with confidence {confidence:.4f}")
# Get model output manager
output_manager = data_provider.get_model_output_manager()
# Evaluate model performance
metrics = output_manager.evaluate_model_performance(symbol, cnn_adapter.model_name)
logger.info(f"Performance metrics: {metrics}")
else:
logger.warning(f"Failed to get base data input for final evaluation")
logger.info("Test completed successfully")
except Exception as e:
logger.error(f"Error in test: {e}", exc_info=True)
if __name__ == "__main__":
test_continuous_training()
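
As the docstring of simulate_market_feedback notes, a real deployment would derive the reward from market performance rather than a random draw. One possible (hypothetical, not part of this commit) replacement based on the realized return over the evaluation window:

def realized_reward(action: str, entry_price: float, exit_price: float) -> float:
    """Signed return of the position implied by the action; HOLD gets a small penalty for missed moves."""
    ret = (exit_price - entry_price) / entry_price
    if action == 'BUY':
        return ret
    if action == 'SELL':
        return -ret
    return -abs(ret) * 0.1  # HOLD penalty factor is an assumption, tune as needed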


@ -0,0 +1,87 @@
"""
Test Enhanced CNN Adapter
This script tests the EnhancedCNNAdapter with standardized input format.
"""
import logging
import time
from datetime import datetime
from core.standardized_data_provider import StandardizedDataProvider
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
from core.data_models import create_model_output
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_cnn_adapter():
"""Test the EnhancedCNNAdapter with standardized input format"""
try:
# Initialize data provider
symbols = ['ETH/USDT', 'BTC/USDT']
timeframes = ['1s', '1m', '1h', '1d']
data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes)
# Initialize CNN adapter
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
# Load best checkpoint if available
cnn_adapter.load_best_checkpoint()
# Get standardized input data
logger.info("Getting standardized input data...")
base_data = data_provider.get_base_data_input('ETH/USDT')
if base_data is None:
logger.error("Failed to get base data input")
return
# Make prediction
logger.info("Making prediction...")
model_output = cnn_adapter.predict(base_data)
# Log prediction
logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}")
# Store model output
data_provider.store_model_output(model_output)
# Add training sample (simulated)
logger.info("Adding training sample...")
cnn_adapter.add_training_sample(base_data, 'BUY', 0.05)
# Train model
logger.info("Training model...")
metrics = cnn_adapter.train(epochs=1)
# Log training metrics
logger.info(f"Training metrics: {metrics}")
# Make another prediction
logger.info("Making another prediction...")
model_output = cnn_adapter.predict(base_data)
# Log prediction
logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}")
# Test model output manager
logger.info("Testing model output manager...")
output_manager = data_provider.get_model_output_manager()
# Get current outputs
current_outputs = output_manager.get_all_current_outputs('ETH/USDT')
logger.info(f"Current outputs: {len(current_outputs)} models")
# Evaluate model performance
metrics = output_manager.evaluate_model_performance('ETH/USDT', 'enhanced_cnn_v1')
logger.info(f"Performance metrics: {metrics}")
logger.info("Test completed successfully")
except Exception as e:
logger.error(f"Error in test: {e}", exc_info=True)
if __name__ == "__main__":
test_cnn_adapter()


@ -0,0 +1,3 @@
"""
Utils package for the multi-modal trading system
"""


@ -1,466 +1,408 @@
#!/usr/bin/env python3
"""
Checkpoint Management System for W&B Training
"""
Checkpoint Manager
This module provides functionality for managing model checkpoints, including:
- Saving checkpoints with metadata
- Loading the best checkpoint based on performance metrics
- Cleaning up old or underperforming checkpoints
"""
import os
import json
import glob
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
from collections import defaultdict
import shutil
import torch
import random
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
logger = logging.getLogger(__name__)
@dataclass
# Global checkpoint manager instance
_checkpoint_manager_instance = None
def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_checkpoints: int = 10, metric_name: str = "accuracy") -> 'CheckpointManager':
"""
Get the global checkpoint manager instance
Args:
checkpoint_dir: Directory to store checkpoints
max_checkpoints: Maximum number of checkpoints to keep
metric_name: Metric to use for ranking checkpoints
Returns:
CheckpointManager: Global checkpoint manager instance
"""
global _checkpoint_manager_instance
if _checkpoint_manager_instance is None:
_checkpoint_manager_instance = CheckpointManager(
checkpoint_dir=checkpoint_dir,
max_checkpoints=max_checkpoints,
metric_name=metric_name
)
return _checkpoint_manager_instance
def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
"""
Save a checkpoint with metadata
Args:
model: The model to save
model_name: Name of the model
model_type: Type of the model ('cnn', 'rl', etc.)
performance_metrics: Performance metrics
training_metadata: Additional training metadata
checkpoint_dir: Directory to store checkpoints
Returns:
Any: Checkpoint metadata
"""
try:
# Create checkpoint directory
os.makedirs(checkpoint_dir, exist_ok=True)
# Create timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create checkpoint path
model_dir = os.path.join(checkpoint_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp}")
# Save model
if hasattr(model, 'save'):
# Use model's save method if available
model.save(checkpoint_path)
else:
# Otherwise, save state_dict
torch_path = f"{checkpoint_path}.pt"
torch.save({
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
'model_name': model_name,
'model_type': model_type,
'timestamp': timestamp
}, torch_path)
# Create metadata
checkpoint_metadata = {
'model_name': model_name,
'model_type': model_type,
'timestamp': timestamp,
'performance_metrics': performance_metrics,
'training_metadata': training_metadata or {},
'checkpoint_id': f"{model_name}_{timestamp}"
}
# Add performance score for sorting
primary_metric = 'accuracy' if 'accuracy' in performance_metrics else 'reward'
checkpoint_metadata['performance_score'] = performance_metrics.get(primary_metric, 0.0)
checkpoint_metadata['created_at'] = timestamp
# Save metadata
with open(f"{checkpoint_path}_metadata.json", 'w') as f:
json.dump(checkpoint_metadata, f, indent=2)
# Get checkpoint manager and clean up old checkpoints
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_manager._cleanup_checkpoints(model_name)
# Return metadata as an object
class CheckpointMetadata:
checkpoint_id: str
model_name: str
model_type: str
file_path: str
created_at: datetime
file_size_mb: float
performance_score: float
accuracy: Optional[float] = None
loss: Optional[float] = None
val_accuracy: Optional[float] = None
val_loss: Optional[float] = None
reward: Optional[float] = None
pnl: Optional[float] = None
epoch: Optional[int] = None
training_time_hours: Optional[float] = None
total_parameters: Optional[int] = None
wandb_run_id: Optional[str] = None
wandb_artifact_name: Optional[str] = None
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
def to_dict(self) -> Dict[str, Any]:
data = asdict(self)
data['created_at'] = self.created_at.isoformat()
return data
return CheckpointMetadata(checkpoint_metadata)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata':
data['created_at'] = datetime.fromisoformat(data['created_at'])
return cls(**data)
except Exception as e:
logger.error(f"Error saving checkpoint: {e}")
return None
def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
"""
Load the best checkpoint based on performance metrics
Args:
model_name: Name of the model
checkpoint_dir: Directory to store checkpoints
Returns:
Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found
"""
try:
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
checkpoint_path, checkpoint_metadata = checkpoint_manager.load_best_checkpoint(model_name)
if not checkpoint_path:
return None
# Convert metadata to object
class CheckpointMetadata:
def __init__(self, metadata):
for key, value in metadata.items():
setattr(self, key, value)
# Add performance score if not present
if not hasattr(self, 'performance_score'):
metrics = getattr(self, 'metrics', {})
primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
self.performance_score = metrics.get(primary_metric, 0.0)
# Add created_at if not present
if not hasattr(self, 'created_at'):
self.created_at = getattr(self, 'timestamp', 'unknown')
return f"{checkpoint_path}.pt", CheckpointMetadata(checkpoint_metadata)
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return None
class CheckpointManager:
    """
    Manages model checkpoints with performance-based optimization

    This class:
    1. Saves checkpoints with metadata
    2. Loads the best checkpoint based on performance metrics
    3. Cleans up old or underperforming checkpoints
    """

    def __init__(self,
                 base_checkpoint_dir: str = "NN/models/saved",
                 max_checkpoints_per_model: int = 5,
                 metadata_file: str = "checkpoint_metadata.json",
                 enable_wandb: bool = True):
        """
        Initialize the checkpoint manager

        Args:
            base_checkpoint_dir: Directory to store checkpoints
            max_checkpoints_per_model: Maximum number of checkpoints to keep per model
            metadata_file: Name of the metadata index file
            enable_wandb: Whether to upload checkpoints to Weights & Biases
        """
        self.base_dir = Path(base_checkpoint_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

        self.max_checkpoints = max_checkpoints_per_model
        self.metadata_file = self.base_dir / metadata_file
        self.enable_wandb = enable_wandb and WANDB_AVAILABLE

        self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
        self._load_metadata()

        logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")

    def save_checkpoint(self, model, model_name: str, model_type: str,
                        performance_metrics: Dict[str, float],
                        training_metadata: Optional[Dict[str, Any]] = None,
                        force_save: bool = False) -> Optional[CheckpointMetadata]:
        """
        Save a checkpoint with metadata

        Args:
            model: Model instance to save
            model_name: Name of the model
            model_type: Type of the model (e.g. 'cnn', 'rl')
            performance_metrics: Performance metrics used to score the checkpoint
            training_metadata: Additional training metadata
            force_save: Save even if the performance score has not improved

        Returns:
            Optional[CheckpointMetadata]: Metadata of the saved checkpoint, or None if skipped
        """
        try:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            checkpoint_id = f"{model_name}_{timestamp}"

            model_dir = self.base_dir / model_name
            model_dir.mkdir(exist_ok=True)

            checkpoint_path = model_dir / f"{checkpoint_id}.pt"

            performance_score = self._calculate_performance_score(performance_metrics)

            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
                logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
                return None

            success = self._save_model_file(model, checkpoint_path, model_type)
            if not success:
                return None

            file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)

            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name=model_name,
                model_type=model_type,
                file_path=str(checkpoint_path),
                created_at=datetime.now(),
                file_size_mb=file_size_mb,
                performance_score=performance_score,
                accuracy=performance_metrics.get('accuracy'),
                loss=performance_metrics.get('loss'),
                val_accuracy=performance_metrics.get('val_accuracy'),
                val_loss=performance_metrics.get('val_loss'),
                reward=performance_metrics.get('reward'),
                pnl=performance_metrics.get('pnl'),
                epoch=training_metadata.get('epoch') if training_metadata else None,
                training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None,
                total_parameters=training_metadata.get('total_parameters') if training_metadata else None
            )

            if self.enable_wandb and wandb.run is not None:
                artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
                metadata.wandb_run_id = wandb.run.id
                metadata.wandb_artifact_name = artifact_name

            self.checkpoints[model_name].append(metadata)
            self._rotate_checkpoints(model_name)
            self._save_metadata()

            logger.debug(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})")
            return metadata

        except Exception as e:
            logger.error(f"Error saving checkpoint for {model_name}: {e}")
            return None
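# --- Illustrative usage (not part of the module): saving a checkpoint.
# A minimal sketch assuming a trained PyTorch model; the model, metric values,
# and directory below are placeholders, not values taken from this project.
#
#     manager = CheckpointManager(base_checkpoint_dir="NN/models/saved")
#     metadata = manager.save_checkpoint(
#         model=my_model,                      # any nn.Module
#         model_name="enhanced_cnn",
#         model_type="cnn",
#         performance_metrics={"loss": 0.42, "accuracy": 0.61},
#         training_metadata={"epoch": 12, "training_time_hours": 1.5},
#     )
#     if metadata:
#         print(metadata.checkpoint_id, metadata.performance_score)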
def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
try:
# First, try the standard checkpoint system
if model_name in self.checkpoints and self.checkpoints[model_name]:
# Filter out checkpoints with non-existent files
valid_checkpoints = [
cp for cp in self.checkpoints[model_name]
if Path(cp.file_path).exists()
]
if valid_checkpoints:
best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score)
logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}")
return best_checkpoint.file_path, best_checkpoint
else:
# Clean up invalid metadata entries
invalid_count = len(self.checkpoints[model_name])
logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata")
self.checkpoints[model_name] = []
self._save_metadata()
# Fallback: Look for existing saved models in the legacy format
logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models")
legacy_model_path = self._find_legacy_model(model_name)
if legacy_model_path:
# Create checkpoint metadata for the legacy model using actual file data
legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path)
logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}")
return str(legacy_model_path), legacy_metadata
logger.warning(f"No checkpoints or legacy models found for: {model_name}")
return None
except Exception as e:
logger.error(f"Error loading best checkpoint for {model_name}: {e}")
return None
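# --- Illustrative usage (not part of the module): restoring the best checkpoint.
# Sketch only; "enhanced_cnn" and my_model are placeholders. _save_model_file
# (below) stores either {'model_state_dict': ...} or the raw model object.
#
#     result = manager.load_best_checkpoint("enhanced_cnn")
#     if result:
#         path, meta = result
#         state = torch.load(path, map_location="cpu")
#         if isinstance(state, dict) and "model_state_dict" in state:
#             my_model.load_state_dict(state["model_state_dict"])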
def _calculate_performance_score(self, metrics: Dict[str, float]) -> float:
"""Calculate performance score with improved sensitivity for training models"""
score = 0.0
# Prioritize loss reduction for active training models
if 'loss' in metrics:
# Invert loss so lower loss = higher score, with better scaling
loss_value = metrics['loss']
if loss_value > 0:
score += max(0, 100 / (1 + loss_value)) # More sensitive to loss changes
else:
score += 100 # Perfect loss
# Add other metrics with appropriate weights
if 'accuracy' in metrics:
score += metrics['accuracy'] * 50 # Reduced weight to balance with loss
if 'val_accuracy' in metrics:
score += metrics['val_accuracy'] * 50
if 'val_loss' in metrics:
val_loss = metrics['val_loss']
if val_loss > 0:
score += max(0, 50 / (1 + val_loss))
if 'reward' in metrics:
score += metrics['reward'] * 10
if 'pnl' in metrics:
score += metrics['pnl'] * 5
if 'training_samples' in metrics:
# Bonus for processing more training samples
score += min(10, metrics['training_samples'] / 10)
# Return actual calculated score - NO SYNTHETIC MINIMUM
return score
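# --- Worked example of the score (made-up metrics): loss=0.5, accuracy=0.6,
# val_accuracy=0.55 gives 100/(1+0.5) + 0.6*50 + 0.55*50 ~ 66.7 + 30 + 27.5 = 124.17.
#
#     metrics = {"loss": 0.5, "accuracy": 0.6, "val_accuracy": 0.55}
#     score = 100 / (1 + metrics["loss"]) + 50 * metrics["accuracy"] + 50 * metrics["val_accuracy"]
#     # -> 124.17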
def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool:
"""Improved checkpoint saving logic with more frequent saves during training"""
if model_name not in self.checkpoints or not self.checkpoints[model_name]:
return True # Always save first checkpoint
# Allow more checkpoints during active training
if len(self.checkpoints[model_name]) < self.max_checkpoints:
return True
# Get current best and worst scores
scores = [cp.performance_score for cp in self.checkpoints[model_name]]
best_score = max(scores)
worst_score = min(scores)
# Save if better than worst (more frequent saves)
if performance_score > worst_score:
return True
# For high-performing models (score > 100), be more sensitive to small improvements
if best_score > 100:
# Save if within 0.1% of best score (very sensitive for converged models)
if performance_score >= best_score * 0.999:
return True
else:
# Also save if we're within 10% of best score (capture near-optimal models)
if performance_score >= best_score * 0.9:
return True
# Save more frequently during active training (every 5th attempt instead of 10th)
if random.random() < 0.2: # 20% chance to save anyway
logger.debug(f"Saving checkpoint for {model_name} - periodic save during active training")
return True
return False
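# --- Illustration of the high-score branch with hypothetical numbers: if the best
# stored score is 120.0, a candidate score of 119.9 is still saved because
# 119.9 >= 120.0 * 0.999 (= 119.88), capturing small fluctuations around convergence.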
def _save_model_file(self, model, file_path: Path, model_type: str) -> bool:
try:
if hasattr(model, 'state_dict'):
torch.save({
'model_state_dict': model.state_dict(),
'model_type': model_type,
'saved_at': datetime.now().isoformat()
}, file_path)
else:
torch.save(model, file_path)
return True
except Exception as e:
logger.error(f"Error saving model file {file_path}: {e}")
return False
def _rotate_checkpoints(self, model_name: str):
checkpoint_list = self.checkpoints[model_name]
if len(checkpoint_list) <= self.max_checkpoints:
return
checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True)
to_remove = checkpoint_list[self.max_checkpoints:]
self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints]
for checkpoint in to_remove:
try:
file_path = Path(checkpoint.file_path)
if file_path.exists():
file_path.unlink()
logger.debug(f"Rotated out checkpoint: {checkpoint.checkpoint_id}")
except Exception as e:
logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")
def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
try:
if not self.enable_wandb or wandb.run is None:
return None
artifact_name = f"{metadata.model_name}_checkpoint"
artifact = wandb.Artifact(artifact_name, type="model")
artifact.add_file(str(file_path))
wandb.log_artifact(artifact)
return artifact_name
except Exception as e:
logger.error(f"Error uploading to W&B: {e}")
return None
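# --- Hedged sketch (not in this module): retrieving an uploaded checkpoint artifact
# in a later W&B run. The project name is a placeholder.
#
#     run = wandb.init(project="trading-models")
#     artifact = run.use_artifact("enhanced_cnn_checkpoint:latest", type="model")
#     local_dir = artifact.download()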
def _load_metadata(self):
try:
if self.metadata_file.exists():
with open(self.metadata_file, 'r') as f:
data = json.load(f)
for model_name, checkpoint_list in data.items():
self.checkpoints[model_name] = [
CheckpointMetadata.from_dict(cp_data)
for cp_data in checkpoint_list
]
logger.info(f"Loaded metadata for {len(self.checkpoints)} models")
except Exception as e:
logger.error(f"Error loading checkpoint metadata: {e}")
def _save_metadata(self):
try:
data = {}
for model_name, checkpoint_list in self.checkpoints.items():
data[model_name] = [cp.to_dict() for cp in checkpoint_list]
with open(self.metadata_file, 'w') as f:
json.dump(data, f, indent=2)
except Exception as e:
logger.error(f"Error saving checkpoint metadata: {e}")
    def get_checkpoint_stats(self):
        """Get statistics about managed checkpoints"""
        stats = {
            'total_models': len(self.checkpoints),
            'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()),
            'total_size_mb': 0.0,
            'models': {}
        }

        for model_name, checkpoint_list in self.checkpoints.items():
            if not checkpoint_list:
                continue

            model_size = sum(cp.file_size_mb for cp in checkpoint_list)
            best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score)

            stats['models'][model_name] = {
                'checkpoint_count': len(checkpoint_list),
                'total_size_mb': model_size,
                'best_performance': best_checkpoint.performance_score,
                'best_checkpoint_id': best_checkpoint.checkpoint_id,
                'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id
            }
            stats['total_size_mb'] += model_size

        return stats
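# --- Illustrative usage (not part of the module): printing a checkpoint summary.
#
#     stats = manager.get_checkpoint_stats()
#     print(f"{stats['total_models']} models, {stats['total_checkpoints']} checkpoints, "
#           f"{stats['total_size_mb']:.1f} MB on disk")
#     for name, info in stats['models'].items():
#         print(f"  {name}: best={info['best_performance']:.2f} ({info['best_checkpoint_id']})")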
def _find_legacy_model(self, model_name: str) -> Optional[Path]:
"""Find legacy saved models based on model name patterns"""
base_dir = Path(self.base_dir)
# Define model name mappings and patterns for legacy files
legacy_patterns = {
'dqn_agent': [
'dqn_agent_best_policy.pt',
'enhanced_dqn_best_policy.pt',
'improved_dqn_agent_best_policy.pt',
'dqn_agent_final_policy.pt'
],
'enhanced_cnn': [
'cnn_model_best.pt',
'optimized_short_term_model_best.pt',
'optimized_short_term_model_realtime_best.pt',
'optimized_short_term_model_ticks_best.pt'
],
'extrema_trainer': [
'supervised_model_best.pt'
],
'cob_rl': [
'best_rl_model.pth_policy.pt',
'rl_agent_best_policy.pt'
],
'decision': [
# Decision models might be in subdirectories, but let's check main dir too
'decision_best.pt',
'decision_model_best.pt',
# Check for transformer models which might be used as decision models
'enhanced_dqn_best_policy.pt',
'improved_dqn_agent_best_policy.pt'
]
}
# Get patterns for this model name
patterns = legacy_patterns.get(model_name, [])
# Also try generic patterns based on model name
patterns.extend([
f'{model_name}_best.pt',
f'{model_name}_best_policy.pt',
f'{model_name}_final.pt',
f'{model_name}_final_policy.pt'
])
# Search for the model files
for pattern in patterns:
candidate_path = base_dir / pattern
if candidate_path.exists():
logger.debug(f"Found legacy model file: {candidate_path}")
return candidate_path
# Also check subdirectories
for subdir in base_dir.iterdir():
if subdir.is_dir() and subdir.name == model_name:
for pattern in patterns:
candidate_path = subdir / pattern
if candidate_path.exists():
logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
return candidate_path
return None
def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
"""Create metadata for legacy model files using only actual file information"""
try:
file_size_mb = file_path.stat().st_size / (1024 * 1024)
created_time = datetime.fromtimestamp(file_path.stat().st_mtime)
# NO SYNTHETIC DATA - use only actual file information
return CheckpointMetadata(
checkpoint_id=f"legacy_{model_name}_{int(created_time.timestamp())}",
model_name=model_name,
model_type=model_name,
file_path=str(file_path),
created_at=created_time,
file_size_mb=file_size_mb,
performance_score=0.0, # Unknown performance - use 0, not synthetic values
accuracy=None,
loss=None,
val_accuracy=None,
val_loss=None,
reward=None,
pnl=None,
epoch=None,
training_time_hours=None,
total_parameters=None,
wandb_run_id=None,
wandb_artifact_name=None
)
except Exception as e:
logger.error(f"Error creating legacy metadata for {model_name}: {e}")
# Return a basic metadata with minimal info - NO SYNTHETIC VALUES
return CheckpointMetadata(
checkpoint_id=f"legacy_{model_name}",
model_name=model_name,
model_type=model_name,
file_path=str(file_path),
created_at=datetime.now(),
file_size_mb=0.0,
performance_score=0.0 # Unknown - use 0, not synthetic
)
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
_checkpoint_manager = None
if not checkpoints:
logger.info(f"No valid checkpoints found for {model_name}")
return "", {}
def get_checkpoint_manager() -> CheckpointManager:
global _checkpoint_manager
if _checkpoint_manager is None:
_checkpoint_manager = CheckpointManager()
return _checkpoint_manager
# Sort by metric (highest first)
checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)
def save_checkpoint(model, model_name: str, model_type: str,
performance_metrics: Dict[str, float],
training_metadata: Optional[Dict[str, Any]] = None,
force_save: bool = False) -> Optional[CheckpointMetadata]:
return get_checkpoint_manager().save_checkpoint(
model, model_name, model_type, performance_metrics, training_metadata, force_save
)
# Return best checkpoint
best_checkpoint_path = checkpoints[0][0]
best_checkpoint_metadata = checkpoints[0][1]
def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]:
return get_checkpoint_manager().load_best_checkpoint(model_name)
logger.info(f"Best checkpoint for {model_name}: {best_checkpoint_path}")
return best_checkpoint_path, best_checkpoint_metadata
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return "", {}
def _cleanup_checkpoints(self, model_name: str) -> int:
"""
Clean up old or underperforming checkpoints
Args:
model_name: Name of the model
Returns:
int: Number of checkpoints deleted
"""
try:
# Find all checkpoint metadata files
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
if not metadata_files or len(metadata_files) <= self.max_checkpoints:
return 0
# Load metadata for each checkpoint
checkpoints = []
for metadata_file in metadata_files:
try:
with open(metadata_file, 'r') as f:
metadata = json.load(f)
# Get checkpoint path (remove _metadata.json)
checkpoint_path = metadata_file[:-14]
checkpoints.append((checkpoint_path, metadata))
except Exception as e:
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
# Sort by metric (highest first)
checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)
# Keep only the best checkpoints
checkpoints_to_delete = checkpoints[self.max_checkpoints:]
# Delete checkpoints
deleted_count = 0
for checkpoint_path, _ in checkpoints_to_delete:
try:
# Delete model file
if os.path.exists(f"{checkpoint_path}.pt"):
os.remove(f"{checkpoint_path}.pt")
# Delete metadata file
if os.path.exists(f"{checkpoint_path}_metadata.json"):
os.remove(f"{checkpoint_path}_metadata.json")
deleted_count += 1
except Exception as e:
logger.error(f"Error deleting checkpoint {checkpoint_path}: {e}")
logger.info(f"Deleted {deleted_count} old checkpoints for {model_name}")
return deleted_count
except Exception as e:
logger.error(f"Error cleaning up checkpoints: {e}")
return 0
def get_all_checkpoints(self, model_name: str) -> List[Tuple[str, Dict[str, Any]]]:
"""
Get all checkpoints for a model
Args:
model_name: Name of the model
Returns:
List[Tuple[str, Dict[str, Any]]]: List of checkpoint paths and metadata
"""
try:
# Find all checkpoint metadata files
checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
if not metadata_files:
return []
# Load metadata for each checkpoint
checkpoints = []
for metadata_file in metadata_files:
try:
with open(metadata_file, 'r') as f:
metadata = json.load(f)
# Get checkpoint path (remove _metadata.json)
checkpoint_path = metadata_file[:-14]
# Check if model file exists
if not os.path.exists(f"{checkpoint_path}.pt"):
logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
continue
checkpoints.append((checkpoint_path, metadata))
except Exception as e:
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
# Sort by timestamp (newest first)
checkpoints.sort(key=lambda x: x[1].get('timestamp', ''), reverse=True)
return checkpoints
except Exception as e:
logger.error(f"Error getting all checkpoints: {e}")
return []

View File

@ -9,7 +9,7 @@ from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path
-from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint
+from .checkpoint_manager import get_checkpoint_manager, load_best_checkpoint
logger = logging.getLogger(__name__)
@ -78,7 +78,7 @@ class TrainingIntegration:
except Exception as e:
logger.warning(f"Error logging to W&B: {e}")
-metadata = save_checkpoint(
+metadata = self.checkpoint_manager.save_checkpoint(
model=cnn_model,
model_name=model_name,
model_type='cnn',
@ -137,7 +137,7 @@ class TrainingIntegration:
except Exception as e:
logger.warning(f"Error logging to W&B: {e}")
-metadata = save_checkpoint(
+metadata = self.checkpoint_manager.save_checkpoint(
model=rl_agent,
model_name=model_name,
model_type='rl',
@ -158,7 +158,7 @@ class TrainingIntegration:
def load_best_model(self, model_name: str, model_class=None):
try:
-result = load_best_checkpoint(model_name)
+result = self.checkpoint_manager.load_best_checkpoint(model_name)
if not result:
logger.warning(f"No checkpoint found for model: {model_name}")
return None