From 45a62443a064e77b92a2a5b6f310841fae54a6e0 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Wed, 23 Jul 2025 21:40:04 +0300 Subject: [PATCH] checkpoint manager --- core/cnn_dashboard_integration.py | 276 ++++++++++ core/data_provider.py | 2 - core/enhanced_cnn_adapter.py | 430 ++++++++++++++++ core/model_output_manager.py | 507 +++++++++--------- test_continuous_cnn_training.py | 155 ++++++ test_enhanced_cnn_adapter.py | 87 ++++ utils/__init__.py | 3 + utils/checkpoint_manager.py | 828 ++++++++++++++---------------- utils/training_integration.py | 8 +- 9 files changed, 1587 insertions(+), 709 deletions(-) create mode 100644 core/cnn_dashboard_integration.py create mode 100644 core/enhanced_cnn_adapter.py create mode 100644 test_continuous_cnn_training.py create mode 100644 test_enhanced_cnn_adapter.py diff --git a/core/cnn_dashboard_integration.py b/core/cnn_dashboard_integration.py new file mode 100644 index 0000000..ccd30aa --- /dev/null +++ b/core/cnn_dashboard_integration.py @@ -0,0 +1,276 @@ +""" +CNN Dashboard Integration + +This module integrates the EnhancedCNN model with the dashboard, providing real-time +training and visualization of model predictions. +""" + +import logging +import threading +import time +from datetime import datetime +from typing import Dict, List, Optional, Any, Tuple +import os +import json + +from .enhanced_cnn_adapter import EnhancedCNNAdapter +from .data_models import BaseDataInput, ModelOutput, create_model_output +from utils.training_integration import get_training_integration + +logger = logging.getLogger(__name__) + +class CNNDashboardIntegration: + """ + Integrates the EnhancedCNN model with the dashboard + + This class: + 1. Loads and initializes the CNN model + 2. Processes real-time data for model inference + 3. Manages continuous training of the model + 4. 
Provides visualization data for the dashboard + """ + + def __init__(self, data_provider=None, checkpoint_dir: str = "models/enhanced_cnn"): + """ + Initialize the CNN dashboard integration + + Args: + data_provider: Data provider instance + checkpoint_dir: Directory to save checkpoints to + """ + self.data_provider = data_provider + self.checkpoint_dir = checkpoint_dir + self.cnn_adapter = None + self.training_thread = None + self.training_active = False + self.training_interval = 60 # Train every 60 seconds + self.training_samples = [] + self.max_training_samples = 1000 + self.last_training_time = 0 + self.last_predictions = {} + self.performance_metrics = {} + self.model_name = "enhanced_cnn_v1" + + # Create checkpoint directory if it doesn't exist + os.makedirs(checkpoint_dir, exist_ok=True) + + # Initialize CNN adapter + self._initialize_cnn_adapter() + + logger.info(f"CNNDashboardIntegration initialized with checkpoint_dir: {checkpoint_dir}") + + def _initialize_cnn_adapter(self): + """Initialize the CNN adapter""" + try: + # Import here to avoid circular imports + from .enhanced_cnn_adapter import EnhancedCNNAdapter + + # Create CNN adapter + self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=self.checkpoint_dir) + + # Load best checkpoint if available + self.cnn_adapter.load_best_checkpoint() + + logger.info("CNN adapter initialized successfully") + + except Exception as e: + logger.error(f"Error initializing CNN adapter: {e}") + self.cnn_adapter = None + + def start_training_thread(self): + """Start the training thread""" + if self.training_thread is not None and self.training_thread.is_alive(): + logger.info("Training thread already running") + return + + self.training_active = True + self.training_thread = threading.Thread(target=self._training_loop, daemon=True) + self.training_thread.start() + + logger.info("CNN training thread started") + + def stop_training_thread(self): + """Stop the training thread""" + self.training_active = False + if self.training_thread is not None: + self.training_thread.join(timeout=5) + self.training_thread = None + + logger.info("CNN training thread stopped") + + def _training_loop(self): + """Training loop for continuous model training""" + while self.training_active: + try: + # Check if it's time to train + current_time = time.time() + if current_time - self.last_training_time >= self.training_interval and len(self.training_samples) >= 10: + logger.info(f"Training CNN model with {len(self.training_samples)} samples") + + # Train model + if self.cnn_adapter is not None: + metrics = self.cnn_adapter.train(epochs=1) + + # Update performance metrics + self.performance_metrics = { + 'loss': metrics.get('loss', 0.0), + 'accuracy': metrics.get('accuracy', 0.0), + 'samples': metrics.get('samples', 0), + 'last_training': datetime.now().isoformat() + } + + # Log training metrics + logger.info(f"CNN training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}") + + # Update last training time + self.last_training_time = current_time + + # Sleep to avoid high CPU usage + time.sleep(1) + + except Exception as e: + logger.error(f"Error in CNN training loop: {e}") + time.sleep(5) # Sleep longer on error + + def process_data(self, symbol: str, base_data: BaseDataInput) -> Optional[ModelOutput]: + """ + Process data for model inference and training + + Args: + symbol: Trading symbol + base_data: Standardized input data + + Returns: + Optional[ModelOutput]: Model output, or None if processing failed + """ + try: + if 
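base_data is None:  # guard added: callers may pass None when data is unavailable
+                logger.warning("No base data provided for CNN processing")
+                return None
+
+            if 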
self.cnn_adapter is None: + logger.warning("CNN adapter not initialized") + return None + + # Make prediction + model_output = self.cnn_adapter.predict(base_data) + + # Store prediction + self.last_predictions[symbol] = model_output + + # Store model output in data provider + if self.data_provider is not None: + self.data_provider.store_model_output(model_output) + + return model_output + + except Exception as e: + logger.error(f"Error processing data for CNN model: {e}") + return None + + def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float): + """ + Add a training sample + + Args: + base_data: Standardized input data + actual_action: Actual action taken ('BUY', 'SELL', 'HOLD') + reward: Reward received for the action + """ + try: + if self.cnn_adapter is None: + logger.warning("CNN adapter not initialized") + return + + # Add training sample to CNN adapter + self.cnn_adapter.add_training_sample(base_data, actual_action, reward) + + # Add to local training samples + self.training_samples.append((base_data.symbol, actual_action, reward)) + + # Limit training samples + if len(self.training_samples) > self.max_training_samples: + self.training_samples = self.training_samples[-self.max_training_samples:] + + logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}") + + except Exception as e: + logger.error(f"Error adding training sample: {e}") + + def get_performance_metrics(self) -> Dict[str, Any]: + """ + Get performance metrics + + Returns: + Dict[str, Any]: Performance metrics + """ + metrics = self.performance_metrics.copy() + + # Add additional metrics + metrics['training_samples'] = len(self.training_samples) + metrics['model_name'] = self.model_name + + # Add last prediction metrics + if self.last_predictions: + for symbol, prediction in self.last_predictions.items(): + metrics[f'{symbol}_last_action'] = prediction.predictions.get('action', 'UNKNOWN') + metrics[f'{symbol}_last_confidence'] = prediction.confidence + + return metrics + + def get_visualization_data(self, symbol: str) -> Dict[str, Any]: + """ + Get visualization data for the dashboard + + Args: + symbol: Trading symbol + + Returns: + Dict[str, Any]: Visualization data + """ + data = { + 'model_name': self.model_name, + 'symbol': symbol, + 'timestamp': datetime.now().isoformat(), + 'performance_metrics': self.get_performance_metrics() + } + + # Add last prediction + if symbol in self.last_predictions: + prediction = self.last_predictions[symbol] + data['last_prediction'] = { + 'action': prediction.predictions.get('action', 'UNKNOWN'), + 'confidence': prediction.confidence, + 'timestamp': prediction.timestamp.isoformat(), + 'buy_probability': prediction.predictions.get('buy_probability', 0.0), + 'sell_probability': prediction.predictions.get('sell_probability', 0.0), + 'hold_probability': prediction.predictions.get('hold_probability', 0.0) + } + + # Add training samples summary + symbol_samples = [s for s in self.training_samples if s[0] == symbol] + data['training_samples'] = { + 'total': len(symbol_samples), + 'buy': len([s for s in symbol_samples if s[1] == 'BUY']), + 'sell': len([s for s in symbol_samples if s[1] == 'SELL']), + 'hold': len([s for s in symbol_samples if s[1] == 'HOLD']), + 'avg_reward': sum(s[2] for s in symbol_samples) / len(symbol_samples) if symbol_samples else 0.0 + } + + return data + +# Global CNN dashboard integration instance +_cnn_dashboard_integration = None + +def 
get_cnn_dashboard_integration(data_provider=None) -> CNNDashboardIntegration: + """ + Get the global CNN dashboard integration instance + + Args: + data_provider: Data provider instance + + Returns: + CNNDashboardIntegration: Global CNN dashboard integration instance + """ + global _cnn_dashboard_integration + + if _cnn_dashboard_integration is None: + _cnn_dashboard_integration = CNNDashboardIntegration(data_provider=data_provider) + + return _cnn_dashboard_integration \ No newline at end of file diff --git a/core/data_provider.py b/core/data_provider.py index c6f40cf..a3634bc 100644 --- a/core/data_provider.py +++ b/core/data_provider.py @@ -1467,12 +1467,10 @@ class DataProvider: # Update COB data cache for distribution binance_symbol = symbol.replace('/', '').upper() if binance_symbol not in self.cob_data_cache or self.cob_data_cache[binance_symbol] is None: - from collections import deque self.cob_data_cache[binance_symbol] = deque(maxlen=300) # Ensure the deque is properly initialized if not isinstance(self.cob_data_cache[binance_symbol], deque): - from collections import deque self.cob_data_cache[binance_symbol] = deque(maxlen=300) self.cob_data_cache[binance_symbol].append({ diff --git a/core/enhanced_cnn_adapter.py b/core/enhanced_cnn_adapter.py new file mode 100644 index 0000000..586ab38 --- /dev/null +++ b/core/enhanced_cnn_adapter.py @@ -0,0 +1,430 @@ +""" +Enhanced CNN Adapter for Standardized Input Format + +This module provides an adapter for the EnhancedCNN model to work with the standardized +BaseDataInput format, enabling seamless integration with the multi-modal trading system. +""" + +import torch +import numpy as np +import logging +import os +from datetime import datetime +from typing import Dict, List, Optional, Tuple, Any, Union +from threading import Lock + +from .data_models import BaseDataInput, ModelOutput, create_model_output +from NN.models.enhanced_cnn import EnhancedCNN + +logger = logging.getLogger(__name__) + +class EnhancedCNNAdapter: + """ + Adapter for EnhancedCNN model to work with standardized BaseDataInput format + + This adapter: + 1. Converts BaseDataInput to the format expected by EnhancedCNN + 2. Processes model outputs to create standardized ModelOutput + 3. Manages model training with collected data + 4. 
Handles checkpoint management + """ + + def __init__(self, model_path: str = None, checkpoint_dir: str = "models/enhanced_cnn"): + """ + Initialize the EnhancedCNN adapter + + Args: + model_path: Path to load model from, if None a new model is created + checkpoint_dir: Directory to save checkpoints to + """ + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.model = None + self.model_path = model_path + self.checkpoint_dir = checkpoint_dir + self.training_lock = Lock() + self.training_data = [] + self.max_training_samples = 10000 + self.batch_size = 32 + self.learning_rate = 0.0001 + self.model_name = "enhanced_cnn_v1" + + # Create checkpoint directory if it doesn't exist + os.makedirs(checkpoint_dir, exist_ok=True) + + # Initialize model + self._initialize_model() + + logger.info(f"EnhancedCNNAdapter initialized with device: {self.device}") + + def _initialize_model(self): + """Initialize the EnhancedCNN model""" + try: + # Calculate input shape based on BaseDataInput structure + # OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features + # BTC OHLCV: 300 frames x 5 features = 1500 features + # COB: ±20 buckets x 4 metrics = 160 features + # MA: 4 timeframes x 10 buckets = 40 features + # Technical indicators: 100 features + # Last predictions: 50 features + # Total: 7850 features + input_shape = 7850 + n_actions = 3 # BUY, SELL, HOLD + + # Create model + self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions) + self.model.to(self.device) + + # Load model if path is provided + if self.model_path: + success = self.model.load(self.model_path) + if success: + logger.info(f"Model loaded from {self.model_path}") + else: + logger.warning(f"Failed to load model from {self.model_path}, using new model") + else: + logger.info("No model path provided, using new model") + + except Exception as e: + logger.error(f"Error initializing EnhancedCNN model: {e}") + raise + + def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor: + """ + Convert BaseDataInput to feature vector for EnhancedCNN + + Args: + base_data: Standardized input data + + Returns: + torch.Tensor: Feature vector for EnhancedCNN + """ + try: + # Use the get_feature_vector method from BaseDataInput + features = base_data.get_feature_vector() + + # Convert to torch tensor + features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device) + + return features_tensor + + except Exception as e: + logger.error(f"Error converting BaseDataInput to features: {e}") + # Return empty tensor with correct shape + return torch.zeros(7850, dtype=torch.float32, device=self.device) + + def predict(self, base_data: BaseDataInput) -> ModelOutput: + """ + Make a prediction using the EnhancedCNN model + + Args: + base_data: Standardized input data + + Returns: + ModelOutput: Standardized model output + """ + try: + # Convert BaseDataInput to features + features = self._convert_base_data_to_features(base_data) + + # Ensure features has batch dimension + if features.dim() == 1: + features = features.unsqueeze(0) + + # Set model to evaluation mode + self.model.eval() + + # Make prediction + with torch.no_grad(): + q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features) + + # Get action and confidence + action_probs = torch.softmax(q_values, dim=1) + action_idx = torch.argmax(action_probs, dim=1).item() + confidence = float(action_probs[0, action_idx].item()) + + # Map action index to action string + actions = ['BUY', 'SELL', 'HOLD'] 
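+            # NOTE: this index order (0=BUY, 1=SELL, 2=HOLD) is an assumption about
+            # the ordering of EnhancedCNN's action head; it must match the order
+            # used in add_training_sample(), or training targets will be mislabeled.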
+ action = actions[action_idx] + + # Create predictions dictionary + predictions = { + 'action': action, + 'buy_probability': float(action_probs[0, 0].item()), + 'sell_probability': float(action_probs[0, 1].item()), + 'hold_probability': float(action_probs[0, 2].item()), + 'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(), + 'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist() + } + + # Create hidden states dictionary + hidden_states = { + 'features': features_refined.squeeze(0).cpu().numpy().tolist() + } + + # Create metadata dictionary + metadata = { + 'model_version': '1.0', + 'timestamp': datetime.now().isoformat(), + 'input_shape': features.shape + } + + # Create ModelOutput + model_output = ModelOutput( + model_type='cnn', + model_name=self.model_name, + symbol=base_data.symbol, + timestamp=datetime.now(), + confidence=confidence, + predictions=predictions, + hidden_states=hidden_states, + metadata=metadata + ) + + return model_output + + except Exception as e: + logger.error(f"Error making prediction with EnhancedCNN: {e}") + # Return default ModelOutput + return create_model_output( + model_type='cnn', + model_name=self.model_name, + symbol=base_data.symbol, + action='HOLD', + confidence=0.0 + ) + + def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float): + """ + Add a training sample to the training data + + Args: + base_data: Standardized input data + actual_action: Actual action taken ('BUY', 'SELL', 'HOLD') + reward: Reward received for the action + """ + try: + # Convert BaseDataInput to features + features = self._convert_base_data_to_features(base_data) + + # Convert action to index + actions = ['BUY', 'SELL', 'HOLD'] + action_idx = actions.index(actual_action) + + # Add to training data + with self.training_lock: + self.training_data.append((features, action_idx, reward)) + + # Limit training data size + if len(self.training_data) > self.max_training_samples: + # Sort by reward (highest first) and keep top samples + self.training_data.sort(key=lambda x: x[2], reverse=True) + self.training_data = self.training_data[:self.max_training_samples] + + logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}") + + except Exception as e: + logger.error(f"Error adding training sample: {e}") + + def train(self, epochs: int = 1) -> Dict[str, float]: + """ + Train the model with collected data + + Args: + epochs: Number of epochs to train for + + Returns: + Dict[str, float]: Training metrics + """ + try: + with self.training_lock: + # Check if we have enough data + if len(self.training_data) < self.batch_size: + logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}") + return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)} + + # Set model to training mode + self.model.train() + + # Create optimizer + optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) + + # Training metrics + total_loss = 0.0 + correct_predictions = 0 + total_predictions = 0 + + # Train for specified number of epochs + for epoch in range(epochs): + # Shuffle training data + np.random.shuffle(self.training_data) + + # Process in batches + for i in range(0, len(self.training_data), self.batch_size): + batch = self.training_data[i:i+self.batch_size] + + # Skip if batch is too small + if len(batch) < 2: + continue + + # Prepare batch + features = torch.stack([sample[0] for sample in batch]) + actions = 
torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device) + rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device) + + # Zero gradients + optimizer.zero_grad() + + # Forward pass + q_values, _, _, _, _ = self.model(features) + + # Calculate loss (CrossEntropyLoss with reward weighting) + # First, apply softmax to get probabilities + probs = torch.softmax(q_values, dim=1) + + # Get probability of chosen action + chosen_probs = probs[torch.arange(len(actions)), actions] + + # Calculate negative log likelihood loss + nll_loss = -torch.log(chosen_probs + 1e-10) + + # Weight by reward (higher reward = higher weight) + # Normalize rewards to [0, 1] range + min_reward = rewards.min() + max_reward = rewards.max() + if max_reward > min_reward: + normalized_rewards = (rewards - min_reward) / (max_reward - min_reward) + else: + normalized_rewards = torch.ones_like(rewards) + + # Apply reward weighting (higher reward = higher weight) + weighted_loss = nll_loss * (normalized_rewards + 0.1) # Add small constant to avoid zero weights + + # Mean loss + loss = weighted_loss.mean() + + # Backward pass + loss.backward() + + # Update weights + optimizer.step() + + # Update metrics + total_loss += loss.item() + + # Calculate accuracy + predicted_actions = torch.argmax(q_values, dim=1) + correct_predictions += (predicted_actions == actions).sum().item() + total_predictions += len(actions) + + # Calculate final metrics + avg_loss = total_loss / (len(self.training_data) / self.batch_size) + accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0 + + # Save checkpoint + self._save_checkpoint(avg_loss, accuracy) + + logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}") + + return { + 'loss': avg_loss, + 'accuracy': accuracy, + 'samples': len(self.training_data) + } + + except Exception as e: + logger.error(f"Error training model: {e}") + return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)} + + def _save_checkpoint(self, loss: float, accuracy: float): + """ + Save model checkpoint + + Args: + loss: Training loss + accuracy: Training accuracy + """ + try: + # Import checkpoint manager + from utils.checkpoint_manager import CheckpointManager + + # Create checkpoint manager + checkpoint_manager = CheckpointManager( + checkpoint_dir=self.checkpoint_dir, + max_checkpoints=10, + metric_name="accuracy" + ) + + # Create temporary model file + temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp") + self.model.save(temp_path) + + # Create metrics + metrics = { + 'loss': loss, + 'accuracy': accuracy, + 'samples': len(self.training_data) + } + + # Create metadata + metadata = { + 'timestamp': datetime.now().isoformat(), + 'model_name': self.model_name, + 'input_shape': self.model.input_shape, + 'n_actions': self.model.n_actions + } + + # Save checkpoint + checkpoint_path = checkpoint_manager.save_checkpoint( + model_name=self.model_name, + model_path=f"{temp_path}.pt", + metrics=metrics, + metadata=metadata + ) + + # Delete temporary model file + if os.path.exists(f"{temp_path}.pt"): + os.remove(f"{temp_path}.pt") + + logger.info(f"Model checkpoint saved to {checkpoint_path}") + + except Exception as e: + logger.error(f"Error saving checkpoint: {e}") + + def load_best_checkpoint(self): + """Load the best checkpoint based on accuracy""" + try: + # Import checkpoint manager + from utils.checkpoint_manager import CheckpointManager + + # 
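Note: this creates a fresh CheckpointManager pointed at this adapter's own
+            # checkpoint_dir rather than reusing the module-level singleton from
+            # get_checkpoint_manager(), whose default directory is models/checkpoints.
+            # 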
Create checkpoint manager + checkpoint_manager = CheckpointManager( + checkpoint_dir=self.checkpoint_dir, + max_checkpoints=10, + metric_name="accuracy" + ) + + # Load best checkpoint + best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name) + + if not best_checkpoint_path: + logger.info("No checkpoints found") + return False + + # Load model + success = self.model.load(best_checkpoint_path) + + if success: + logger.info(f"Loaded best checkpoint from {best_checkpoint_path}") + + # Log metrics + metrics = best_checkpoint_metadata.get('metrics', {}) + logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}") + + return True + else: + logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}") + return False + + except Exception as e: + logger.error(f"Error loading best checkpoint: {e}") + return False \ No newline at end of file diff --git a/core/model_output_manager.py b/core/model_output_manager.py index adaf8f5..c09124d 100644 --- a/core/model_output_manager.py +++ b/core/model_output_manager.py @@ -1,34 +1,31 @@ """ Model Output Manager -This module provides extensible model output storage and management for the multi-modal trading system. -Supports CNN, RL, LSTM, Transformer, and future model types with cross-model feeding capabilities. +This module provides a centralized storage and management system for model outputs, +enabling cross-model feeding and evaluation. """ -import logging +import os import json -import pickle -from datetime import datetime, timedelta -from typing import Dict, List, Optional, Any, Union -from collections import deque, defaultdict +import logging +import time +from datetime import datetime +from typing import Dict, List, Optional, Any from threading import Lock -from pathlib import Path -from .data_models import ModelOutput, create_model_output +from .data_models import ModelOutput logger = logging.getLogger(__name__) class ModelOutputManager: """ - Extensible model output storage and management system + Centralized storage and management system for model outputs - Features: - - Standardized ModelOutput storage for all model types - - Cross-model feeding with hidden states - - Historical output tracking - - Metadata management - - Persistence and recovery - - Performance analytics + This class: + 1. Stores model outputs for all models + 2. Provides access to current and historical outputs + 3. Handles persistence of outputs to disk + 4. 
Supports evaluation of model performance """ def __init__(self, cache_dir: str = "cache/model_outputs", max_history: int = 1000): @@ -36,279 +33,226 @@ class ModelOutputManager: Initialize the model output manager Args: - cache_dir: Directory for persistent storage - max_history: Maximum number of outputs to keep in memory per model + cache_dir: Directory to store model outputs + max_history: Maximum number of historical outputs to keep per model """ - self.cache_dir = Path(cache_dir) - self.cache_dir.mkdir(parents=True, exist_ok=True) + self.cache_dir = cache_dir self.max_history = max_history + self.outputs_lock = Lock() - # In-memory storage - self.current_outputs: Dict[str, Dict[str, ModelOutput]] = defaultdict(dict) # {symbol: {model_name: ModelOutput}} - self.output_history: Dict[str, Dict[str, deque]] = defaultdict(lambda: defaultdict(lambda: deque(maxlen=max_history))) # {symbol: {model_name: deque}} - self.cross_model_states: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) # {symbol: {model_name: hidden_states}} + # Current outputs for each model and symbol + # {symbol: {model_name: ModelOutput}} + self.current_outputs: Dict[str, Dict[str, ModelOutput]] = {} - # Metadata tracking - self.model_metadata: Dict[str, Dict[str, Any]] = defaultdict(dict) # {model_name: metadata} - self.performance_stats: Dict[str, Dict[str, Any]] = defaultdict(lambda: defaultdict(dict)) # {symbol: {model_name: stats}} + # Historical outputs for each model and symbol + # {symbol: {model_name: List[ModelOutput]}} + self.historical_outputs: Dict[str, Dict[str, List[ModelOutput]]] = {} - # Thread safety - self.storage_lock = Lock() + # Performance metrics for each model and symbol + # {symbol: {model_name: Dict[str, float]}} + self.performance_metrics: Dict[str, Dict[str, Dict[str, float]]] = {} - # Supported model types - self.supported_model_types = { - 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator', - 'ensemble', 'hybrid', 'custom' # Extensible for future types - } + # Create cache directory if it doesn't exist + os.makedirs(cache_dir, exist_ok=True) - logger.info(f"ModelOutputManager initialized with cache dir: {self.cache_dir}") - logger.info(f"Supported model types: {self.supported_model_types}") + logger.info(f"ModelOutputManager initialized with cache_dir: {cache_dir}") def store_output(self, model_output: ModelOutput) -> bool: """ - Store model output with full extensibility support + Store a model output Args: - model_output: ModelOutput from any model type + model_output: Model output to store Returns: - bool: True if stored successfully, False otherwise + bool: True if successful, False otherwise """ try: - with self.storage_lock: - symbol = model_output.symbol - model_name = model_output.model_name - model_type = model_output.model_type - - # Validate model type (extensible) - if model_type not in self.supported_model_types: - logger.warning(f"Unknown model type '{model_type}' - adding to supported types") - self.supported_model_types.add(model_type) + symbol = model_output.symbol + model_name = model_output.model_name + + with self.outputs_lock: + # Initialize dictionaries if they don't exist + if symbol not in self.current_outputs: + self.current_outputs[symbol] = {} + if symbol not in self.historical_outputs: + self.historical_outputs[symbol] = {} + if model_name not in self.historical_outputs[symbol]: + self.historical_outputs[symbol][model_name] = [] # Store current output self.current_outputs[symbol][model_name] = model_output - # Add to history - 
self.output_history[symbol][model_name].append(model_output) - - # Store cross-model states if available - if model_output.hidden_states: - self.cross_model_states[symbol][model_name] = model_output.hidden_states - - # Update model metadata - self._update_model_metadata(model_name, model_type, model_output.metadata) - - # Update performance statistics - self._update_performance_stats(symbol, model_name, model_output) - - # Persist to disk (async to avoid blocking) - self._persist_output_async(model_output) - - logger.debug(f"Stored output from {model_name} ({model_type}) for {symbol}") - return True + # Add to historical outputs + self.historical_outputs[symbol][model_name].append(model_output) + # Limit historical outputs + if len(self.historical_outputs[symbol][model_name]) > self.max_history: + self.historical_outputs[symbol][model_name] = self.historical_outputs[symbol][model_name][-self.max_history:] + + # Persist output to disk + self._persist_output(model_output) + + return True + except Exception as e: logger.error(f"Error storing model output: {e}") return False def get_current_output(self, symbol: str, model_name: str) -> Optional[ModelOutput]: """ - Get the current (latest) output from a specific model + Get the current output for a model and symbol Args: - symbol: Trading symbol - model_name: Name of the model + symbol: Symbol to get output for + model_name: Model name to get output for Returns: - ModelOutput: Latest output from the model, or None if not available + ModelOutput: Current output, or None if not available """ try: - return self.current_outputs.get(symbol, {}).get(model_name) + with self.outputs_lock: + if symbol in self.current_outputs and model_name in self.current_outputs[symbol]: + return self.current_outputs[symbol][model_name] + return None + except Exception as e: - logger.error(f"Error getting current output for {model_name}: {e}") + logger.error(f"Error getting current output: {e}") return None def get_all_current_outputs(self, symbol: str) -> Dict[str, ModelOutput]: """ - Get all current outputs for a symbol (for cross-model feeding) + Get all current outputs for a symbol Args: - symbol: Trading symbol + symbol: Symbol to get outputs for Returns: - Dict[str, ModelOutput]: Dictionary of current outputs by model name + Dict[str, ModelOutput]: Dictionary of model name to output """ try: - return dict(self.current_outputs.get(symbol, {})) + with self.outputs_lock: + if symbol in self.current_outputs: + return self.current_outputs[symbol].copy() + return {} + except Exception as e: - logger.error(f"Error getting all current outputs for {symbol}: {e}") + logger.error(f"Error getting all current outputs: {e}") return {} - def get_output_history(self, symbol: str, model_name: str, count: int = 10) -> List[ModelOutput]: + def get_historical_outputs(self, symbol: str, model_name: str, limit: int = None) -> List[ModelOutput]: """ - Get historical outputs from a model + Get historical outputs for a model and symbol Args: - symbol: Trading symbol - model_name: Name of the model - count: Number of historical outputs to retrieve + symbol: Symbol to get outputs for + model_name: Model name to get outputs for + limit: Maximum number of outputs to return, None for all Returns: - List[ModelOutput]: List of historical outputs (most recent first) + List[ModelOutput]: List of historical outputs """ try: - history = self.output_history.get(symbol, {}).get(model_name, deque()) - return list(history)[-count:][::-1] # Most recent first + with self.outputs_lock: + if symbol in 
self.historical_outputs and model_name in self.historical_outputs[symbol]: + outputs = self.historical_outputs[symbol][model_name] + if limit is not None: + outputs = outputs[-limit:] + return outputs.copy() + return [] + except Exception as e: - logger.error(f"Error getting output history for {model_name}: {e}") + logger.error(f"Error getting historical outputs: {e}") return [] - def get_cross_model_states(self, symbol: str, requesting_model: str) -> Dict[str, Dict[str, Any]]: + def evaluate_model_performance(self, symbol: str, model_name: str) -> Dict[str, float]: """ - Get hidden states from other models for cross-model feeding + Evaluate model performance based on historical outputs Args: - symbol: Trading symbol - requesting_model: Name of the model requesting the states + symbol: Symbol to evaluate + model_name: Model name to evaluate Returns: - Dict[str, Dict[str, Any]]: Hidden states from other models + Dict[str, float]: Performance metrics """ try: - all_states = self.cross_model_states.get(symbol, {}) - # Return states from all models except the requesting one - return {model_name: states for model_name, states in all_states.items() - if model_name != requesting_model} - except Exception as e: - logger.error(f"Error getting cross-model states for {requesting_model}: {e}") - return {} - - def get_model_types_active(self, symbol: str) -> List[str]: - """ - Get list of active model types for a symbol - - Args: - symbol: Trading symbol - - Returns: - List[str]: List of active model types - """ - try: - current_outputs = self.current_outputs.get(symbol, {}) - return [output.model_type for output in current_outputs.values()] - except Exception as e: - logger.error(f"Error getting active model types for {symbol}: {e}") - return [] - - def get_consensus_prediction(self, symbol: str, confidence_threshold: float = 0.5) -> Optional[Dict[str, Any]]: - """ - Get consensus prediction from all active models - - Args: - symbol: Trading symbol - confidence_threshold: Minimum confidence threshold for inclusion - - Returns: - Dict containing consensus prediction or None - """ - try: - current_outputs = self.current_outputs.get(symbol, {}) - if not current_outputs: - return None + # Get historical outputs + outputs = self.get_historical_outputs(symbol, model_name) - # Filter by confidence threshold - high_confidence_outputs = [ - output for output in current_outputs.values() - if output.confidence >= confidence_threshold - ] + if not outputs: + return {'accuracy': 0.0, 'confidence': 0.0, 'samples': 0} - if not high_confidence_outputs: - return None + # Calculate metrics + total_outputs = len(outputs) + total_confidence = sum(output.confidence for output in outputs) + avg_confidence = total_confidence / total_outputs if total_outputs > 0 else 0.0 - # Calculate consensus - buy_votes = sum(1 for output in high_confidence_outputs - if output.predictions.get('action') == 'BUY') - sell_votes = sum(1 for output in high_confidence_outputs - if output.predictions.get('action') == 'SELL') - hold_votes = sum(1 for output in high_confidence_outputs - if output.predictions.get('action') == 'HOLD') + # For now, we don't have ground truth to calculate accuracy + # In the future, we can add this by comparing predictions to actual market movements - total_votes = len(high_confidence_outputs) - avg_confidence = sum(output.confidence for output in high_confidence_outputs) / total_votes - - # Determine consensus action - if buy_votes > sell_votes and buy_votes > hold_votes: - consensus_action = 'BUY' - elif sell_votes 
> buy_votes and sell_votes > hold_votes: - consensus_action = 'SELL' - else: - consensus_action = 'HOLD' - - return { - 'action': consensus_action, + metrics = { 'confidence': avg_confidence, - 'votes': {'BUY': buy_votes, 'SELL': sell_votes, 'HOLD': hold_votes}, - 'total_models': total_votes, - 'model_types': [output.model_type for output in high_confidence_outputs] + 'samples': total_outputs, + 'last_update': datetime.now().isoformat() } - except Exception as e: - logger.error(f"Error calculating consensus prediction for {symbol}: {e}") - return None - - def _update_model_metadata(self, model_name: str, model_type: str, metadata: Dict[str, Any]): - """Update metadata for a model""" - try: - if model_name not in self.model_metadata: - self.model_metadata[model_name] = { - 'model_type': model_type, - 'first_seen': datetime.now(), - 'total_predictions': 0, - 'custom_metadata': {} - } + # Store metrics + with self.outputs_lock: + if symbol not in self.performance_metrics: + self.performance_metrics[symbol] = {} + self.performance_metrics[symbol][model_name] = metrics - self.model_metadata[model_name]['total_predictions'] += 1 - self.model_metadata[model_name]['last_seen'] = datetime.now() - - # Merge custom metadata - if metadata: - self.model_metadata[model_name]['custom_metadata'].update(metadata) - - except Exception as e: - logger.error(f"Error updating model metadata: {e}") - - def _update_performance_stats(self, symbol: str, model_name: str, model_output: ModelOutput): - """Update performance statistics for a model""" - try: - stats = self.performance_stats[symbol][model_name] - - if 'prediction_count' not in stats: - stats['prediction_count'] = 0 - stats['confidence_sum'] = 0.0 - stats['action_counts'] = {'BUY': 0, 'SELL': 0, 'HOLD': 0} - stats['first_prediction'] = model_output.timestamp - - stats['prediction_count'] += 1 - stats['confidence_sum'] += model_output.confidence - stats['avg_confidence'] = stats['confidence_sum'] / stats['prediction_count'] - stats['last_prediction'] = model_output.timestamp - - action = model_output.predictions.get('action', 'HOLD') - if action in stats['action_counts']: - stats['action_counts'][action] += 1 + return metrics except Exception as e: - logger.error(f"Error updating performance stats: {e}") + logger.error(f"Error evaluating model performance: {e}") + return {'error': str(e)} - def _persist_output_async(self, model_output: ModelOutput): - """Persist model output to disk (simplified version)""" + def get_performance_metrics(self, symbol: str, model_name: str) -> Dict[str, float]: + """ + Get performance metrics for a model and symbol + + Args: + symbol: Symbol to get metrics for + model_name: Model name to get metrics for + + Returns: + Dict[str, float]: Performance metrics + """ try: - # Create filename based on model and timestamp - timestamp_str = model_output.timestamp.strftime("%Y%m%d_%H%M%S") - filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp_str}.json" - filepath = self.cache_dir / filename + with self.outputs_lock: + if symbol in self.performance_metrics and model_name in self.performance_metrics[symbol]: + return self.performance_metrics[symbol][model_name].copy() - # Convert to JSON-serializable format + # If no metrics are available, calculate them + return self.evaluate_model_performance(symbol, model_name) + + except Exception as e: + logger.error(f"Error getting performance metrics: {e}") + return {'error': str(e)} + + def _persist_output(self, model_output: ModelOutput) -> bool: + """ + 
Persist a model output to disk
+        
+        Args:
+            model_output: Model output to persist
+            
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        try:
+            # Create filename with timestamp; outputs are written directly into
+            # cache_dir (created in __init__), which is where
+            # load_outputs_from_disk() also looks for them
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            filename = f"{model_output.model_name}_{model_output.symbol.replace('/', '_')}_{timestamp}.json"
+            filepath = os.path.join(self.cache_dir, filename)
+            
+            # Convert ModelOutput to dictionary
             output_dict = {
                 'model_type': model_output.model_type,
                 'model_name': model_output.model_name,
@@ -319,77 +263,120 @@ class ModelOutputManager:
                 'metadata': model_output.metadata
             }
             
-            # Save to file (in a real implementation, this would be async)
+            # Don't store hidden states in file (too large)
+            
+            # Write to file
             with open(filepath, 'w') as f:
                 json.dump(output_dict, f, indent=2)
-                
+            
+            return True
+            
         except Exception as e:
             logger.error(f"Error persisting model output: {e}")
+            return False
     
-    def get_performance_summary(self, symbol: str) -> Dict[str, Any]:
+    def load_outputs_from_disk(self, symbol: Optional[str] = None, model_name: Optional[str] = None) -> int:
         """
-        Get performance summary for all models for a symbol
+        Load model outputs from disk
         
         Args:
-            symbol: Trading symbol
+            symbol: Symbol to load outputs for, None for all
+            model_name: Model name to load outputs for, None for all
         
         Returns:
-            Dict containing performance summary
+            int: Number of outputs loaded
         """
         try:
-            summary = {
-                'symbol': symbol,
-                'active_models': len(self.current_outputs.get(symbol, {})),
-                'model_stats': {}
-            }
+            # Find all output files
+            import glob
            
-            for model_name, stats in self.performance_stats.get(symbol, {}).items():
-                summary['model_stats'][model_name] = {
-                    'predictions': stats.get('prediction_count', 0),
-                    'avg_confidence': round(stats.get('avg_confidence', 0.0), 3),
-                    'action_distribution': stats.get('action_counts', {}),
-                    'model_type': self.model_metadata.get(model_name, {}).get('model_type', 'unknown')
-                }
+            if symbol and model_name:
+                pattern = os.path.join(self.cache_dir, f"{model_name}_{symbol.replace('/', '_')}*.json")
+            elif symbol:
+                pattern = os.path.join(self.cache_dir, f"*_{symbol.replace('/', '_')}*.json")
+            elif model_name:
+                pattern = os.path.join(self.cache_dir, f"{model_name}_*.json")
+            else:
+                pattern = os.path.join(self.cache_dir, "*.json")
            
-            return summary
+            output_files = glob.glob(pattern)
+            
+            if not output_files:
+                logger.info(f"No output files found for pattern: {pattern}")
+                return 0
+            
+            # Load each file
+            loaded_count = 0
+            for filepath in output_files:
+                try:
+                    with open(filepath, 'r') as f:
+                        output_dict = json.load(f)
+                    
+                    # Create ModelOutput
+                    model_output = ModelOutput(
+                        model_type=output_dict['model_type'],
+                        model_name=output_dict['model_name'],
+                        symbol=output_dict['symbol'],
+                        timestamp=datetime.fromisoformat(output_dict['timestamp']),
+                        confidence=output_dict['confidence'],
+                        predictions=output_dict['predictions'],
+                        hidden_states={},  # Don't load hidden states from disk
+                        metadata=output_dict.get('metadata', {})
+                    )
+                    
+                    # Store in memory only; calling store_output() here would
+                    # re-persist a duplicate copy of the file we just read
+                    with self.outputs_lock:
+                        self.current_outputs.setdefault(model_output.symbol, {})[model_output.model_name] = model_output
+                        history = self.historical_outputs.setdefault(model_output.symbol, {}).setdefault(model_output.model_name, [])
+                        history.append(model_output)
+                        if len(history) > self.max_history:
+                            del history[:-self.max_history]
+                    loaded_count += 1
+                    
+                except Exception as e:
+                    logger.error(f"Error loading output file {filepath}: {e}")
+            
+            logger.info(f"Loaded {loaded_count} model outputs from disk")
+            return loaded_count
            
         except Exception as e:
-            logger.error(f"Error getting performance summary: {e}")
-            return {'symbol': symbol, 'error': str(e)}
+            
logger.error(f"Error loading outputs from disk: {e}") + return 0 - def cleanup_old_outputs(self, max_age_hours: int = 24): + def cleanup_old_outputs(self, max_age_days: int = 30) -> int: """ - Clean up old outputs to manage memory usage + Clean up old output files Args: - max_age_hours: Maximum age of outputs to keep in hours + max_age_days: Maximum age of files to keep in days + + Returns: + int: Number of files deleted """ try: - cutoff_time = datetime.now() - timedelta(hours=max_age_hours) + # Find all output files + import glob + output_files = glob.glob(os.path.join(self.cache_dir, "*.json")) - with self.storage_lock: - for symbol in self.output_history: - for model_name in self.output_history[symbol]: - history = self.output_history[symbol][model_name] - # Remove old outputs - while history and history[0].timestamp < cutoff_time: - history.popleft() + if not output_files: + return 0 - logger.info(f"Cleaned up outputs older than {max_age_hours} hours") + # Calculate cutoff time + cutoff_time = time.time() - (max_age_days * 24 * 60 * 60) + + # Delete old files + deleted_count = 0 + for filepath in output_files: + try: + # Get file modification time + mtime = os.path.getmtime(filepath) + + # Delete if older than cutoff + if mtime < cutoff_time: + os.remove(filepath) + deleted_count += 1 + + except Exception as e: + logger.error(f"Error deleting file {filepath}: {e}") + + logger.info(f"Deleted {deleted_count} old model output files") + return deleted_count except Exception as e: logger.error(f"Error cleaning up old outputs: {e}") - - def add_custom_model_type(self, model_type: str): - """ - Add support for a new custom model type - - Args: - model_type: Name of the new model type - """ - self.supported_model_types.add(model_type) - logger.info(f"Added support for custom model type: {model_type}") - - def get_supported_model_types(self) -> List[str]: - """Get list of all supported model types""" - return list(self.supported_model_types) \ No newline at end of file + return 0 \ No newline at end of file diff --git a/test_continuous_cnn_training.py b/test_continuous_cnn_training.py new file mode 100644 index 0000000..0a6778b --- /dev/null +++ b/test_continuous_cnn_training.py @@ -0,0 +1,155 @@ +""" +Test Continuous CNN Training + +This script demonstrates how the CNN model can be trained with each new inference result +using collected data, implementing a continuous learning loop. 
+""" + +import logging +import time +from datetime import datetime +import random +import os + +from core.standardized_data_provider import StandardizedDataProvider +from core.enhanced_cnn_adapter import EnhancedCNNAdapter +from core.data_models import create_model_output + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +def simulate_market_feedback(action, symbol): + """ + Simulate market feedback for a given action + + In a real system, this would be replaced with actual market performance data + + Args: + action: Trading action ('BUY', 'SELL', 'HOLD') + symbol: Trading symbol + + Returns: + tuple: (actual_action, reward) + """ + # Simulate market movement (random for demonstration) + market_direction = random.choice(['up', 'down', 'sideways']) + + # Determine actual best action based on market direction + if market_direction == 'up': + best_action = 'BUY' + elif market_direction == 'down': + best_action = 'SELL' + else: + best_action = 'HOLD' + + # Calculate reward based on whether the action matched the best action + if action == best_action: + reward = random.uniform(0.01, 0.1) # Positive reward for correct action + else: + reward = random.uniform(-0.1, -0.01) # Negative reward for incorrect action + + logger.info(f"Market went {market_direction}, best action was {best_action}, model chose {action}, reward: {reward:.4f}") + + return best_action, reward + +def test_continuous_training(): + """Test continuous training of the CNN model with new inference results""" + try: + # Initialize data provider + symbols = ['ETH/USDT', 'BTC/USDT'] + timeframes = ['1s', '1m', '1h', '1d'] + data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes) + + # Initialize CNN adapter + checkpoint_dir = "models/enhanced_cnn" + os.makedirs(checkpoint_dir, exist_ok=True) + cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=checkpoint_dir) + + # Load best checkpoint if available + cnn_adapter.load_best_checkpoint() + + # Continuous learning loop + num_iterations = 10 + training_frequency = 3 # Train every N iterations + samples_collected = 0 + + logger.info(f"Starting continuous learning loop with {num_iterations} iterations") + + for i in range(num_iterations): + logger.info(f"\nIteration {i+1}/{num_iterations}") + + # Get standardized input data + symbol = random.choice(symbols) + logger.info(f"Getting data for {symbol}...") + base_data = data_provider.get_base_data_input(symbol) + + if base_data is None: + logger.warning(f"Failed to get base data input for {symbol}, skipping iteration") + continue + + # Make prediction + logger.info(f"Making prediction for {symbol}...") + model_output = cnn_adapter.predict(base_data) + + # Log prediction + action = model_output.predictions['action'] + confidence = model_output.confidence + logger.info(f"Prediction: {action} with confidence {confidence:.4f}") + + # Store model output + data_provider.store_model_output(model_output) + + # Simulate market feedback + best_action, reward = simulate_market_feedback(action, symbol) + + # Add training sample + logger.info(f"Adding training sample: action={best_action}, reward={reward:.4f}") + cnn_adapter.add_training_sample(base_data, best_action, reward) + samples_collected += 1 + + # Train model periodically + if (i + 1) % training_frequency == 0 and samples_collected >= 3: + logger.info(f"Training model with {samples_collected} samples...") + metrics = cnn_adapter.train(epochs=1) + + # Log 
training metrics + logger.info(f"Training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}") + + # Simulate time passing + time.sleep(1) + + logger.info("\nContinuous learning loop completed") + + # Final evaluation + logger.info("Performing final evaluation...") + + # Get data for evaluation + symbol = 'ETH/USDT' + base_data = data_provider.get_base_data_input(symbol) + + if base_data is not None: + # Make prediction + model_output = cnn_adapter.predict(base_data) + + # Log prediction + action = model_output.predictions['action'] + confidence = model_output.confidence + logger.info(f"Final prediction for {symbol}: {action} with confidence {confidence:.4f}") + + # Get model output manager + output_manager = data_provider.get_model_output_manager() + + # Evaluate model performance + metrics = output_manager.evaluate_model_performance(symbol, cnn_adapter.model_name) + logger.info(f"Performance metrics: {metrics}") + else: + logger.warning(f"Failed to get base data input for final evaluation") + + logger.info("Test completed successfully") + + except Exception as e: + logger.error(f"Error in test: {e}", exc_info=True) + +if __name__ == "__main__": + test_continuous_training() \ No newline at end of file diff --git a/test_enhanced_cnn_adapter.py b/test_enhanced_cnn_adapter.py new file mode 100644 index 0000000..a03ff1b --- /dev/null +++ b/test_enhanced_cnn_adapter.py @@ -0,0 +1,87 @@ +""" +Test Enhanced CNN Adapter + +This script tests the EnhancedCNNAdapter with standardized input format. +""" + +import logging +import time +from datetime import datetime + +from core.standardized_data_provider import StandardizedDataProvider +from core.enhanced_cnn_adapter import EnhancedCNNAdapter +from core.data_models import create_model_output + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +def test_cnn_adapter(): + """Test the EnhancedCNNAdapter with standardized input format""" + try: + # Initialize data provider + symbols = ['ETH/USDT', 'BTC/USDT'] + timeframes = ['1s', '1m', '1h', '1d'] + data_provider = StandardizedDataProvider(symbols=symbols, timeframes=timeframes) + + # Initialize CNN adapter + cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn") + + # Load best checkpoint if available + cnn_adapter.load_best_checkpoint() + + # Get standardized input data + logger.info("Getting standardized input data...") + base_data = data_provider.get_base_data_input('ETH/USDT') + + if base_data is None: + logger.error("Failed to get base data input") + return + + # Make prediction + logger.info("Making prediction...") + model_output = cnn_adapter.predict(base_data) + + # Log prediction + logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}") + + # Store model output + data_provider.store_model_output(model_output) + + # Add training sample (simulated) + logger.info("Adding training sample...") + cnn_adapter.add_training_sample(base_data, 'BUY', 0.05) + + # Train model + logger.info("Training model...") + metrics = cnn_adapter.train(epochs=1) + + # Log training metrics + logger.info(f"Training metrics: {metrics}") + + # Make another prediction + logger.info("Making another prediction...") + model_output = cnn_adapter.predict(base_data) + + # Log prediction + logger.info(f"Prediction: {model_output.predictions['action']} with confidence {model_output.confidence:.4f}") + + # Test 
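historical output retrieval via the ModelOutputManager API added in this patch
+        # (the limit argument is optional)
+        history = data_provider.get_model_output_manager().get_historical_outputs('ETH/USDT', 'enhanced_cnn_v1', limit=5)
+        logger.info(f"Retrieved {len(history)} historical outputs")
+
+        # Test 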
model output manager + logger.info("Testing model output manager...") + output_manager = data_provider.get_model_output_manager() + + # Get current outputs + current_outputs = output_manager.get_all_current_outputs('ETH/USDT') + logger.info(f"Current outputs: {len(current_outputs)} models") + + # Evaluate model performance + metrics = output_manager.evaluate_model_performance('ETH/USDT', 'enhanced_cnn_v1') + logger.info(f"Performance metrics: {metrics}") + + logger.info("Test completed successfully") + + except Exception as e: + logger.error(f"Error in test: {e}", exc_info=True) + +if __name__ == "__main__": + test_cnn_adapter() \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py index e69de29..71aab0c 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -0,0 +1,3 @@ +""" +Utils package for the multi-modal trading system +""" \ No newline at end of file diff --git a/utils/checkpoint_manager.py b/utils/checkpoint_manager.py index 5d2b078..99b5f19 100644 --- a/utils/checkpoint_manager.py +++ b/utils/checkpoint_manager.py @@ -1,466 +1,408 @@ -#!/usr/bin/env python3 -""" -Checkpoint Management System for W&B Training +""" +Checkpoint Manager + +This module provides functionality for managing model checkpoints, including: +- Saving checkpoints with metadata +- Loading the best checkpoint based on performance metrics +- Cleaning up old or underperforming checkpoints """ import os import json +import glob import logging -from datetime import datetime, timedelta -from pathlib import Path -from typing import Dict, List, Optional, Tuple, Any -from dataclasses import dataclass, asdict -from collections import defaultdict +import shutil import torch -import random - -try: - import wandb - WANDB_AVAILABLE = True -except ImportError: - WANDB_AVAILABLE = False +from datetime import datetime +from typing import Dict, List, Optional, Any, Tuple logger = logging.getLogger(__name__) -@dataclass -class CheckpointMetadata: - checkpoint_id: str - model_name: str - model_type: str - file_path: str - created_at: datetime - file_size_mb: float - performance_score: float - accuracy: Optional[float] = None - loss: Optional[float] = None - val_accuracy: Optional[float] = None - val_loss: Optional[float] = None - reward: Optional[float] = None - pnl: Optional[float] = None - epoch: Optional[int] = None - training_time_hours: Optional[float] = None - total_parameters: Optional[int] = None - wandb_run_id: Optional[str] = None - wandb_artifact_name: Optional[str] = None +# Global checkpoint manager instance +_checkpoint_manager_instance = None + +def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_checkpoints: int = 10, metric_name: str = "accuracy") -> 'CheckpointManager': + """ + Get the global checkpoint manager instance - def to_dict(self) -> Dict[str, Any]: - data = asdict(self) - data['created_at'] = self.created_at.isoformat() - return data + Args: + checkpoint_dir: Directory to store checkpoints + max_checkpoints: Maximum number of checkpoints to keep + metric_name: Metric to use for ranking checkpoints - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'CheckpointMetadata': - data['created_at'] = datetime.fromisoformat(data['created_at']) - return cls(**data) + Returns: + CheckpointManager: Global checkpoint manager instance + """ + global _checkpoint_manager_instance + + if _checkpoint_manager_instance is None: + _checkpoint_manager_instance = CheckpointManager( + checkpoint_dir=checkpoint_dir, + max_checkpoints=max_checkpoints, + 
metric_name=metric_name + ) + + return _checkpoint_manager_instance + +def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any: + """ + Save a checkpoint with metadata + + Args: + model: The model to save + model_name: Name of the model + model_type: Type of the model ('cnn', 'rl', etc.) + performance_metrics: Performance metrics + training_metadata: Additional training metadata + checkpoint_dir: Directory to store checkpoints + + Returns: + Any: Checkpoint metadata + """ + try: + # Create checkpoint directory + os.makedirs(checkpoint_dir, exist_ok=True) + + # Create timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # Create checkpoint path + model_dir = os.path.join(checkpoint_dir, model_name) + os.makedirs(model_dir, exist_ok=True) + checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp}") + + # Save model + if hasattr(model, 'save'): + # Use model's save method if available + model.save(checkpoint_path) + else: + # Otherwise, save state_dict + torch_path = f"{checkpoint_path}.pt" + torch.save({ + 'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None, + 'model_name': model_name, + 'model_type': model_type, + 'timestamp': timestamp + }, torch_path) + + # Create metadata + checkpoint_metadata = { + 'model_name': model_name, + 'model_type': model_type, + 'timestamp': timestamp, + 'performance_metrics': performance_metrics, + 'training_metadata': training_metadata or {}, + 'checkpoint_id': f"{model_name}_{timestamp}" + } + + # Add performance score for sorting + primary_metric = 'accuracy' if 'accuracy' in performance_metrics else 'reward' + checkpoint_metadata['performance_score'] = performance_metrics.get(primary_metric, 0.0) + checkpoint_metadata['created_at'] = timestamp + + # Save metadata + with open(f"{checkpoint_path}_metadata.json", 'w') as f: + json.dump(checkpoint_metadata, f, indent=2) + + # Get checkpoint manager and clean up old checkpoints + checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir) + checkpoint_manager._cleanup_checkpoints(model_name) + + # Return metadata as an object + class CheckpointMetadata: + def __init__(self, metadata): + for key, value in metadata.items(): + setattr(self, key, value) + + return CheckpointMetadata(checkpoint_metadata) + + except Exception as e: + logger.error(f"Error saving checkpoint: {e}") + return None + +def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]: + """ + Load the best checkpoint based on performance metrics + + Args: + model_name: Name of the model + checkpoint_dir: Directory to store checkpoints + + Returns: + Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found + """ + try: + checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir) + checkpoint_path, checkpoint_metadata = checkpoint_manager.load_best_checkpoint(model_name) + + if not checkpoint_path: + return None + + # Convert metadata to object + class CheckpointMetadata: + def __init__(self, metadata): + for key, value in metadata.items(): + setattr(self, key, value) + + # Add performance score if not present + if not hasattr(self, 'performance_score'): + metrics = getattr(self, 'metrics', {}) + primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward' + self.performance_score = metrics.get(primary_metric, 0.0) + + # Add 
 
 class CheckpointManager:
-    def __init__(self, 
-                 base_checkpoint_dir: str = "NN/models/saved",
-                 max_checkpoints_per_model: int = 5,
-                 metadata_file: str = "checkpoint_metadata.json",
-                 enable_wandb: bool = True):
-        self.base_dir = Path(base_checkpoint_dir)
-        self.base_dir.mkdir(parents=True, exist_ok=True)
-        
-        self.max_checkpoints = max_checkpoints_per_model
-        self.metadata_file = self.base_dir / metadata_file
-        self.enable_wandb = enable_wandb and WANDB_AVAILABLE
-        
-        self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
-        self._load_metadata()
-        
-        logger.info(f"Checkpoint Manager initialized - Max checkpoints per model: {self.max_checkpoints}")
+    """
+    Manages model checkpoints with performance-based optimization
     
-    def save_checkpoint(self, model, model_name: str, model_type: str,
-                       performance_metrics: Dict[str, float],
-                       training_metadata: Optional[Dict[str, Any]] = None,
-                       force_save: bool = False) -> Optional[CheckpointMetadata]:
+    This class:
+    1. Saves checkpoints with metadata
+    2. Loads the best checkpoint based on performance metrics
+    3. Cleans up old or underperforming checkpoints
+    """
+    
+    def __init__(self, checkpoint_dir: str, max_checkpoints: int = 10, metric_name: str = "accuracy"):
+        """
+        Initialize the checkpoint manager
+        
+        Args:
+            checkpoint_dir: Directory to store checkpoints
+            max_checkpoints: Maximum number of checkpoints to keep per model
+            metric_name: Metric to use for ranking checkpoints
+        """
+        self.checkpoint_dir = checkpoint_dir
+        self.max_checkpoints = max_checkpoints
+        self.metric_name = metric_name
+        
+        # Create checkpoint directory if it doesn't exist
+        os.makedirs(checkpoint_dir, exist_ok=True)
+        
+        logger.info(f"CheckpointManager initialized with checkpoint_dir: {checkpoint_dir}")
+    
+    def save_checkpoint(self, model_name: str, model_path: str, metrics: Dict[str, float], metadata: Optional[Dict[str, Any]] = None) -> str:
+        """
+        Save a checkpoint with metadata
+        
+        Args:
+            model_name: Name of the model
+            model_path: Path to the model file to copy into the checkpoint store
+            metrics: Performance metrics
+            metadata: Additional metadata
+        
+        Returns:
+            str: Path to the saved checkpoint (without extension), or an empty string on failure
+        """
         try:
-            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
-            checkpoint_id = f"{model_name}_{timestamp}"
+            # Create timestamp
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
             
-            model_dir = self.base_dir / model_name
-            model_dir.mkdir(exist_ok=True)
+            # Create checkpoint directory
+            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
+            os.makedirs(checkpoint_dir, exist_ok=True)
             
-            checkpoint_path = model_dir / f"{checkpoint_id}.pt"
+            # Create checkpoint path
+            checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_{timestamp}")
             
-            performance_score = self._calculate_performance_score(performance_metrics)
+            # Copy model file to checkpoint path
+            shutil.copy2(model_path, f"{checkpoint_path}.pt")
             
-            if not force_save and not self._should_save_checkpoint(model_name, performance_score):
-                logger.debug(f"Skipping checkpoint save for {model_name} - performance not improved")
-                return None
-            
-            success = self._save_model_file(model, checkpoint_path, model_type)
-            if not success:
-                return None
-            
-            file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)
-            
-            metadata = CheckpointMetadata(
-                
checkpoint_id=checkpoint_id, - model_name=model_name, - model_type=model_type, - file_path=str(checkpoint_path), - created_at=datetime.now(), - file_size_mb=file_size_mb, - performance_score=performance_score, - accuracy=performance_metrics.get('accuracy'), - loss=performance_metrics.get('loss'), - val_accuracy=performance_metrics.get('val_accuracy'), - val_loss=performance_metrics.get('val_loss'), - reward=performance_metrics.get('reward'), - pnl=performance_metrics.get('pnl'), - epoch=training_metadata.get('epoch') if training_metadata else None, - training_time_hours=training_metadata.get('training_time_hours') if training_metadata else None, - total_parameters=training_metadata.get('total_parameters') if training_metadata else None - ) - - if self.enable_wandb and wandb.run is not None: - artifact_name = self._upload_to_wandb(checkpoint_path, metadata) - metadata.wandb_run_id = wandb.run.id - metadata.wandb_artifact_name = artifact_name - - self.checkpoints[model_name].append(metadata) - self._rotate_checkpoints(model_name) - self._save_metadata() - - logger.debug(f"Saved checkpoint: {checkpoint_id} (score: {performance_score:.4f})") - return metadata - - except Exception as e: - logger.error(f"Error saving checkpoint for {model_name}: {e}") - return None - - def load_best_checkpoint(self, model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]: - try: - # First, try the standard checkpoint system - if model_name in self.checkpoints and self.checkpoints[model_name]: - # Filter out checkpoints with non-existent files - valid_checkpoints = [ - cp for cp in self.checkpoints[model_name] - if Path(cp.file_path).exists() - ] - - if valid_checkpoints: - best_checkpoint = max(valid_checkpoints, key=lambda x: x.performance_score) - logger.debug(f"Loading best checkpoint for {model_name}: {best_checkpoint.checkpoint_id}") - return best_checkpoint.file_path, best_checkpoint - else: - # Clean up invalid metadata entries - invalid_count = len(self.checkpoints[model_name]) - logger.warning(f"Found {invalid_count} invalid checkpoint entries for {model_name}, cleaning up metadata") - self.checkpoints[model_name] = [] - self._save_metadata() - - # Fallback: Look for existing saved models in the legacy format - logger.debug(f"No valid checkpoints found for model: {model_name}, attempting to find legacy saved models") - legacy_model_path = self._find_legacy_model(model_name) - - if legacy_model_path: - # Create checkpoint metadata for the legacy model using actual file data - legacy_metadata = self._create_legacy_metadata(model_name, legacy_model_path) - logger.debug(f"Found legacy model for {model_name}: {legacy_model_path}") - return str(legacy_model_path), legacy_metadata - - logger.warning(f"No checkpoints or legacy models found for: {model_name}") - return None - - except Exception as e: - logger.error(f"Error loading best checkpoint for {model_name}: {e}") - return None - - def _calculate_performance_score(self, metrics: Dict[str, float]) -> float: - """Calculate performance score with improved sensitivity for training models""" - score = 0.0 - - # Prioritize loss reduction for active training models - if 'loss' in metrics: - # Invert loss so lower loss = higher score, with better scaling - loss_value = metrics['loss'] - if loss_value > 0: - score += max(0, 100 / (1 + loss_value)) # More sensitive to loss changes - else: - score += 100 # Perfect loss - - # Add other metrics with appropriate weights - if 'accuracy' in metrics: - score += metrics['accuracy'] * 50 # Reduced weight to balance with 
loss - if 'val_accuracy' in metrics: - score += metrics['val_accuracy'] * 50 - if 'val_loss' in metrics: - val_loss = metrics['val_loss'] - if val_loss > 0: - score += max(0, 50 / (1 + val_loss)) - if 'reward' in metrics: - score += metrics['reward'] * 10 - if 'pnl' in metrics: - score += metrics['pnl'] * 5 - if 'training_samples' in metrics: - # Bonus for processing more training samples - score += min(10, metrics['training_samples'] / 10) - - # Return actual calculated score - NO SYNTHETIC MINIMUM - return score - - def _should_save_checkpoint(self, model_name: str, performance_score: float) -> bool: - """Improved checkpoint saving logic with more frequent saves during training""" - if model_name not in self.checkpoints or not self.checkpoints[model_name]: - return True # Always save first checkpoint - - # Allow more checkpoints during active training - if len(self.checkpoints[model_name]) < self.max_checkpoints: - return True - - # Get current best and worst scores - scores = [cp.performance_score for cp in self.checkpoints[model_name]] - best_score = max(scores) - worst_score = min(scores) - - # Save if better than worst (more frequent saves) - if performance_score > worst_score: - return True - - # For high-performing models (score > 100), be more sensitive to small improvements - if best_score > 100: - # Save if within 0.1% of best score (very sensitive for converged models) - if performance_score >= best_score * 0.999: - return True - else: - # Also save if we're within 10% of best score (capture near-optimal models) - if performance_score >= best_score * 0.9: - return True - - # Save more frequently during active training (every 5th attempt instead of 10th) - if random.random() < 0.2: # 20% chance to save anyway - logger.debug(f"Saving checkpoint for {model_name} - periodic save during active training") - return True - - return False - - def _save_model_file(self, model, file_path: Path, model_type: str) -> bool: - try: - if hasattr(model, 'state_dict'): - torch.save({ - 'model_state_dict': model.state_dict(), - 'model_type': model_type, - 'saved_at': datetime.now().isoformat() - }, file_path) - else: - torch.save(model, file_path) - return True - except Exception as e: - logger.error(f"Error saving model file {file_path}: {e}") - return False - - def _rotate_checkpoints(self, model_name: str): - checkpoint_list = self.checkpoints[model_name] - - if len(checkpoint_list) <= self.max_checkpoints: - return - - checkpoint_list.sort(key=lambda x: x.performance_score, reverse=True) - - to_remove = checkpoint_list[self.max_checkpoints:] - self.checkpoints[model_name] = checkpoint_list[:self.max_checkpoints] - - for checkpoint in to_remove: - try: - file_path = Path(checkpoint.file_path) - if file_path.exists(): - file_path.unlink() - logger.debug(f"Rotated out checkpoint: {checkpoint.checkpoint_id}") - except Exception as e: - logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}") - - def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]: - try: - if not self.enable_wandb or wandb.run is None: - return None - - artifact_name = f"{metadata.model_name}_checkpoint" - artifact = wandb.Artifact(artifact_name, type="model") - artifact.add_file(str(file_path)) - wandb.log_artifact(artifact) - - return artifact_name - except Exception as e: - logger.error(f"Error uploading to W&B: {e}") - return None - - def _load_metadata(self): - try: - if self.metadata_file.exists(): - with open(self.metadata_file, 'r') as f: - data = 
json.load(f) - - for model_name, checkpoint_list in data.items(): - self.checkpoints[model_name] = [ - CheckpointMetadata.from_dict(cp_data) - for cp_data in checkpoint_list - ] - - logger.info(f"Loaded metadata for {len(self.checkpoints)} models") - except Exception as e: - logger.error(f"Error loading checkpoint metadata: {e}") - - def _save_metadata(self): - try: - data = {} - for model_name, checkpoint_list in self.checkpoints.items(): - data[model_name] = [cp.to_dict() for cp in checkpoint_list] - - with open(self.metadata_file, 'w') as f: - json.dump(data, f, indent=2) - except Exception as e: - logger.error(f"Error saving checkpoint metadata: {e}") - - def get_checkpoint_stats(self): - """Get statistics about managed checkpoints""" - stats = { - 'total_models': len(self.checkpoints), - 'total_checkpoints': sum(len(checkpoints) for checkpoints in self.checkpoints.values()), - 'total_size_mb': 0.0, - 'models': {} - } - - for model_name, checkpoint_list in self.checkpoints.items(): - if not checkpoint_list: - continue - - model_size = sum(cp.file_size_mb for cp in checkpoint_list) - best_checkpoint = max(checkpoint_list, key=lambda x: x.performance_score) - - stats['models'][model_name] = { - 'checkpoint_count': len(checkpoint_list), - 'total_size_mb': model_size, - 'best_performance': best_checkpoint.performance_score, - 'best_checkpoint_id': best_checkpoint.checkpoint_id, - 'latest_checkpoint': max(checkpoint_list, key=lambda x: x.created_at).checkpoint_id + # Create metadata + checkpoint_metadata = { + 'model_name': model_name, + 'timestamp': timestamp, + 'metrics': metrics, + 'metadata': metadata or {} } - stats['total_size_mb'] += model_size - - return stats - - def _find_legacy_model(self, model_name: str) -> Optional[Path]: - """Find legacy saved models based on model name patterns""" - base_dir = Path(self.base_dir) - - # Define model name mappings and patterns for legacy files - legacy_patterns = { - 'dqn_agent': [ - 'dqn_agent_best_policy.pt', - 'enhanced_dqn_best_policy.pt', - 'improved_dqn_agent_best_policy.pt', - 'dqn_agent_final_policy.pt' - ], - 'enhanced_cnn': [ - 'cnn_model_best.pt', - 'optimized_short_term_model_best.pt', - 'optimized_short_term_model_realtime_best.pt', - 'optimized_short_term_model_ticks_best.pt' - ], - 'extrema_trainer': [ - 'supervised_model_best.pt' - ], - 'cob_rl': [ - 'best_rl_model.pth_policy.pt', - 'rl_agent_best_policy.pt' - ], - 'decision': [ - # Decision models might be in subdirectories, but let's check main dir too - 'decision_best.pt', - 'decision_model_best.pt', - # Check for transformer models which might be used as decision models - 'enhanced_dqn_best_policy.pt', - 'improved_dqn_agent_best_policy.pt' - ] - } - - # Get patterns for this model name - patterns = legacy_patterns.get(model_name, []) - - # Also try generic patterns based on model name - patterns.extend([ - f'{model_name}_best.pt', - f'{model_name}_best_policy.pt', - f'{model_name}_final.pt', - f'{model_name}_final_policy.pt' - ]) - - # Search for the model files - for pattern in patterns: - candidate_path = base_dir / pattern - if candidate_path.exists(): - logger.debug(f"Found legacy model file: {candidate_path}") - return candidate_path - - # Also check subdirectories - for subdir in base_dir.iterdir(): - if subdir.is_dir() and subdir.name == model_name: - for pattern in patterns: - candidate_path = subdir / pattern - if candidate_path.exists(): - logger.debug(f"Found legacy model file in subdirectory: {candidate_path}") - return candidate_path - - return None - - def 
_create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata: - """Create metadata for legacy model files using only actual file information""" - try: - file_size_mb = file_path.stat().st_size / (1024 * 1024) - created_time = datetime.fromtimestamp(file_path.stat().st_mtime) + # Save metadata + with open(f"{checkpoint_path}_metadata.json", 'w') as f: + json.dump(checkpoint_metadata, f, indent=2) + + logger.info(f"Saved checkpoint to {checkpoint_path}") + + # Clean up old checkpoints + self._cleanup_checkpoints(model_name) + + return checkpoint_path - # NO SYNTHETIC DATA - use only actual file information - return CheckpointMetadata( - checkpoint_id=f"legacy_{model_name}_{int(created_time.timestamp())}", - model_name=model_name, - model_type=model_name, - file_path=str(file_path), - created_at=created_time, - file_size_mb=file_size_mb, - performance_score=0.0, # Unknown performance - use 0, not synthetic values - accuracy=None, - loss=None, - val_accuracy=None, - val_loss=None, - reward=None, - pnl=None, - epoch=None, - training_time_hours=None, - total_parameters=None, - wandb_run_id=None, - wandb_artifact_name=None - ) except Exception as e: - logger.error(f"Error creating legacy metadata for {model_name}: {e}") - # Return a basic metadata with minimal info - NO SYNTHETIC VALUES - return CheckpointMetadata( - checkpoint_id=f"legacy_{model_name}", - model_name=model_name, - model_type=model_name, - file_path=str(file_path), - created_at=datetime.now(), - file_size_mb=0.0, - performance_score=0.0 # Unknown - use 0, not synthetic - ) - -_checkpoint_manager = None - -def get_checkpoint_manager() -> CheckpointManager: - global _checkpoint_manager - if _checkpoint_manager is None: - _checkpoint_manager = CheckpointManager() - return _checkpoint_manager - -def save_checkpoint(model, model_name: str, model_type: str, - performance_metrics: Dict[str, float], - training_metadata: Optional[Dict[str, Any]] = None, - force_save: bool = False) -> Optional[CheckpointMetadata]: - return get_checkpoint_manager().save_checkpoint( - model, model_name, model_type, performance_metrics, training_metadata, force_save - ) - -def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMetadata]]: - return get_checkpoint_manager().load_best_checkpoint(model_name) + logger.error(f"Error saving checkpoint: {e}") + return "" + + def load_best_checkpoint(self, model_name: str) -> Tuple[str, Dict[str, Any]]: + """ + Load the best checkpoint based on performance metrics + + Args: + model_name: Name of the model + + Returns: + Tuple[str, Dict[str, Any]]: Path to the best checkpoint and its metadata + """ + try: + # Find all checkpoint metadata files + checkpoint_dir = os.path.join(self.checkpoint_dir, model_name) + metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json")) + + if not metadata_files: + logger.info(f"No checkpoints found for {model_name}") + return "", {} + + # Load metadata for each checkpoint + checkpoints = [] + for metadata_file in metadata_files: + try: + with open(metadata_file, 'r') as f: + metadata = json.load(f) + + # Get checkpoint path (remove _metadata.json) + checkpoint_path = metadata_file[:-14] + + # Check if model file exists + if not os.path.exists(f"{checkpoint_path}.pt"): + logger.warning(f"Model file not found for checkpoint {checkpoint_path}") + continue + + checkpoints.append((checkpoint_path, metadata)) + + except Exception as e: + logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}") + + if not 
checkpoints:
+                logger.info(f"No valid checkpoints found for {model_name}")
+                return "", {}
+            
+            # Sort by metric (highest first)
+            checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)
+            
+            # Return best checkpoint
+            best_checkpoint_path = checkpoints[0][0]
+            best_checkpoint_metadata = checkpoints[0][1]
+            
+            logger.info(f"Best checkpoint for {model_name}: {best_checkpoint_path}")
+            
+            return best_checkpoint_path, best_checkpoint_metadata
+        
+        except Exception as e:
+            logger.error(f"Error loading best checkpoint: {e}")
+            return "", {}
+    
+    def _cleanup_checkpoints(self, model_name: str) -> int:
+        """
+        Clean up old or underperforming checkpoints
+        
+        Args:
+            model_name: Name of the model
+        
+        Returns:
+            int: Number of checkpoints deleted
+        """
+        try:
+            # Find all checkpoint metadata files
+            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
+            metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
+            
+            if not metadata_files or len(metadata_files) <= self.max_checkpoints:
+                return 0
+            
+            # Load metadata for each checkpoint
+            checkpoints = []
+            for metadata_file in metadata_files:
+                try:
+                    with open(metadata_file, 'r') as f:
+                        metadata = json.load(f)
+                    
+                    # Get checkpoint path (remove _metadata.json)
+                    checkpoint_path = metadata_file[:-14]
+                    
+                    checkpoints.append((checkpoint_path, metadata))
+                
+                except Exception as e:
+                    logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
+            
+            # Sort by metric (highest first)
+            checkpoints.sort(key=lambda x: x[1].get('metrics', {}).get(self.metric_name, 0.0), reverse=True)
+            
+            # Keep only the best checkpoints
+            checkpoints_to_delete = checkpoints[self.max_checkpoints:]
+            
+            # Delete checkpoints
+            deleted_count = 0
+            for checkpoint_path, _ in checkpoints_to_delete:
+                try:
+                    # Delete model file
+                    if os.path.exists(f"{checkpoint_path}.pt"):
+                        os.remove(f"{checkpoint_path}.pt")
+                    
+                    # Delete metadata file
+                    if os.path.exists(f"{checkpoint_path}_metadata.json"):
+                        os.remove(f"{checkpoint_path}_metadata.json")
+                    
+                    deleted_count += 1
+                
+                except Exception as e:
+                    logger.error(f"Error deleting checkpoint {checkpoint_path}: {e}")
+            
+            logger.info(f"Deleted {deleted_count} old checkpoints for {model_name}")
+            
+            return deleted_count
+        
+        except Exception as e:
+            logger.error(f"Error cleaning up checkpoints: {e}")
+            return 0
+    
+    def get_all_checkpoints(self, model_name: str) -> List[Tuple[str, Dict[str, Any]]]:
+        """
+        Get all checkpoints for a model
+        
+        Args:
+            model_name: Name of the model
+        
+        Returns:
+            List[Tuple[str, Dict[str, Any]]]: List of checkpoint paths and metadata
+        """
+        try:
+            # Find all checkpoint metadata files
+            checkpoint_dir = os.path.join(self.checkpoint_dir, model_name)
+            metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
+            
+            if not metadata_files:
+                return []
+            
+            # Load metadata for each checkpoint
+            checkpoints = []
+            for metadata_file in metadata_files:
+                try:
+                    with open(metadata_file, 'r') as f:
+                        metadata = json.load(f)
+                    
+                    # Get checkpoint path (remove _metadata.json)
+                    checkpoint_path = metadata_file[:-14]
+                    
+                    # Check if model file exists
+                    if not os.path.exists(f"{checkpoint_path}.pt"):
+                        logger.warning(f"Model file not found for checkpoint {checkpoint_path}")
+                        continue
+                    
+                    checkpoints.append((checkpoint_path, metadata))
+                
+                except Exception as e:
+                    logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
+            
+            # Sort by timestamp (newest first)
+            checkpoints.sort(key=lambda x: x[1].get('timestamp', ''), reverse=True)
+            
+            return checkpoints
+        
+        except Exception as e:
+            logger.error(f"Error getting all checkpoints: {e}")
+            return []
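+
+# --- Illustrative usage of CheckpointManager (a sketch; the paths are assumptions) ---
+#
+#   manager = get_checkpoint_manager(checkpoint_dir="models/checkpoints",
+#                                    max_checkpoints=10, metric_name="accuracy")
+#   path = manager.save_checkpoint("enhanced_cnn_v1",
+#                                  model_path="models/enhanced_cnn_v1.pt",
+#                                  metrics={"accuracy": 0.87, "loss": 0.32})
+#   best_path, best_meta = manager.load_best_checkpoint("enhanced_cnn_v1")
+#   for ckpt_path, meta in manager.get_all_checkpoints("enhanced_cnn_v1"):
+#       print(ckpt_path, meta.get("metrics", {}))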
diff --git a/utils/training_integration.py b/utils/training_integration.py
index 0353a84..0acf9d3 100644
--- a/utils/training_integration.py
+++ b/utils/training_integration.py
@@ -9,7 +9,7 @@ from datetime import datetime
 from typing import Dict, Any, Optional
 from pathlib import Path
 
-from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint
+from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_best_checkpoint  # save_checkpoint kept: it accepts live model objects
 
 logger = logging.getLogger(__name__)
 
@@ -78,7 +78,9 @@ class TrainingIntegration:
             except Exception as e:
                 logger.warning(f"Error logging to W&B: {e}")
 
-            metadata = save_checkpoint(
+            # The module-level helper accepts a live model object, whereas
+            # CheckpointManager.save_checkpoint() expects a path to a saved file
+            metadata = save_checkpoint(
                 model=cnn_model,
                 model_name=model_name,
                 model_type='cnn',
@@ -137,7 +137,8 @@ class TrainingIntegration:
             except Exception as e:
                 logger.warning(f"Error logging to W&B: {e}")
 
-            metadata = save_checkpoint(
+            # As above, the module-level helper handles live model objects
+            metadata = save_checkpoint(
                 model=rl_agent,
                 model_name=model_name,
                 model_type='rl',
@@ -158,7 +158,11 @@ class TrainingIntegration:
     def load_best_model(self, model_name: str, model_class=None):
         try:
-            result = load_best_checkpoint(model_name)
-            if not result:
+            # The manager returns ("", {}) when nothing is found; that tuple is
+            # always truthy, so check the path element rather than the tuple itself
+            result = self.checkpoint_manager.load_best_checkpoint(model_name)
+            if not result or not result[0]:
                 logger.warning(f"No checkpoint found for model: {model_name}")
                 return None
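+            # (Illustrative) when a checkpoint exists, `result` is a (path, metadata)
+            # tuple, e.g. ("models/checkpoints/enhanced_cnn_v1/enhanced_cnn_v1_20250723_213000", {...})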