865 lines
40 KiB
Python
865 lines
40 KiB
Python
# """
|
|
# Enhanced CNN Adapter for Standardized Input Format
|
|
|
|
# This module provides an adapter for the EnhancedCNN model to work with the standardized
|
|
# BaseDataInput format, enabling seamless integration with the multi-modal trading system.
|
|
# """
|
|
|
|
# import torch
|
|
# import numpy as np
|
|
# import logging
|
|
# import os
|
|
# import random
|
|
# from datetime import datetime, timedelta
|
|
# from typing import Dict, List, Optional, Tuple, Any, Union
|
|
# from threading import Lock
|
|
|
|
# from .data_models import BaseDataInput, ModelOutput, create_model_output
|
|
# from NN.models.enhanced_cnn import EnhancedCNN
|
|
# from utils.inference_logger import log_model_inference
|
|
|
|
# logger = logging.getLogger(__name__)
|
|
|
|
# class EnhancedCNNAdapter:
|
|
# """
|
|
# Adapter for EnhancedCNN model to work with standardized BaseDataInput format
|
|
|
|
# This adapter:
|
|
# 1. Converts BaseDataInput to the format expected by EnhancedCNN
|
|
# 2. Processes model outputs to create standardized ModelOutput
|
|
# 3. Manages model training with collected data
|
|
# 4. Handles checkpoint management
|
|
# """
|
|
|
|
# def __init__(self, model_path: str = None, checkpoint_dir: str = "models/enhanced_cnn"):
|
|
# """
|
|
# Initialize the EnhancedCNN adapter
|
|
|
|
# Args:
|
|
# model_path: Path to load model from, if None a new model is created
|
|
# checkpoint_dir: Directory to save checkpoints to
|
|
# """
|
|
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
# self.model = None
|
|
# self.model_path = model_path
|
|
# self.checkpoint_dir = checkpoint_dir
|
|
# self.training_lock = Lock()
|
|
# self.training_data = []
|
|
# self.max_training_samples = 10000
|
|
# self.batch_size = 32
|
|
# self.learning_rate = 0.0001
|
|
# self.model_name = "enhanced_cnn"
|
|
|
|
# # Enhanced metrics tracking
|
|
# self.last_inference_time = None
|
|
# self.last_inference_duration = 0.0
|
|
# self.last_prediction_output = None
|
|
# self.last_training_time = None
|
|
# self.last_training_duration = 0.0
|
|
# self.last_training_loss = 0.0
|
|
# self.inference_count = 0
|
|
# self.training_count = 0
|
|
|
|
# # Create checkpoint directory if it doesn't exist
|
|
# os.makedirs(checkpoint_dir, exist_ok=True)
|
|
|
|
# # Initialize the model
|
|
# self._initialize_model()
|
|
|
|
# # Load checkpoint if available
|
|
# if model_path and os.path.exists(model_path):
|
|
# self._load_checkpoint(model_path)
|
|
# else:
|
|
# self._load_best_checkpoint()
|
|
|
|
# # Final device check and move
|
|
# self._ensure_model_on_device()
|
|
|
|
# logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
|
|
|
|
# def _create_realistic_synthetic_features(self, symbol: str) -> torch.Tensor:
|
|
# """Create realistic synthetic features instead of random data"""
|
|
# try:
|
|
# # Create realistic market-like features
|
|
# features = torch.zeros(7850, dtype=torch.float32, device=self.device)
|
|
|
|
# # OHLCV features (6000 features: 300 frames x 4 timeframes x 5 features)
|
|
# ohlcv_start = 0
|
|
# for timeframe_idx in range(4): # 1s, 1m, 1h, 1d
|
|
# base_price = 3500.0 + timeframe_idx * 10 # Slight variation per timeframe
|
|
# for frame_idx in range(300):
|
|
# # Create realistic price movement
|
|
# price_change = torch.sin(torch.tensor(frame_idx * 0.1)) * 0.01 # Cyclical movement
|
|
# current_price = base_price * (1 + price_change)
|
|
|
|
# # Realistic OHLCV values
|
|
# open_price = current_price
|
|
# high_price = current_price * torch.uniform(1.0, 1.005)
|
|
# low_price = current_price * torch.uniform(0.995, 1.0)
|
|
# close_price = current_price * torch.uniform(0.998, 1.002)
|
|
# volume = torch.uniform(500.0, 2000.0)
|
|
|
|
# # Set features
|
|
# feature_idx = ohlcv_start + frame_idx * 5 + timeframe_idx * 1500
|
|
# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume])
|
|
|
|
# # BTC OHLCV features (1500 features: 300 frames x 5 features)
|
|
# btc_start = 6000
|
|
# btc_base_price = 50000.0
|
|
# for frame_idx in range(300):
|
|
# price_change = torch.sin(torch.tensor(frame_idx * 0.05)) * 0.02
|
|
# current_price = btc_base_price * (1 + price_change)
|
|
|
|
# open_price = current_price
|
|
# high_price = current_price * torch.uniform(1.0, 1.01)
|
|
# low_price = current_price * torch.uniform(0.99, 1.0)
|
|
# close_price = current_price * torch.uniform(0.995, 1.005)
|
|
# volume = torch.uniform(100.0, 500.0)
|
|
|
|
# feature_idx = btc_start + frame_idx * 5
|
|
# features[feature_idx:feature_idx+5] = torch.tensor([open_price, high_price, low_price, close_price, volume])
|
|
|
|
# # COB features (200 features) - realistic order book data
|
|
# cob_start = 7500
|
|
# for i in range(200):
|
|
# features[cob_start + i] = torch.uniform(0.0, 1000.0) # Realistic COB values
|
|
|
|
# # Technical indicators (100 features)
|
|
# indicator_start = 7700
|
|
# for i in range(100):
|
|
# features[indicator_start + i] = torch.uniform(-1.0, 1.0) # Normalized indicators
|
|
|
|
# # Last predictions (50 features)
|
|
# prediction_start = 7800
|
|
# for i in range(50):
|
|
# features[prediction_start + i] = torch.uniform(0.0, 1.0) # Probability values
|
|
|
|
# return features
|
|
|
|
# except Exception as e:
|
|
# logger.error(f"Error creating realistic synthetic features: {e}")
|
|
# # Fallback to small random variation
|
|
# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5
|
|
# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1
|
|
# return base_features + noise
|
|
|
|
# def _create_realistic_features(self, symbol: str) -> torch.Tensor:
|
|
# """Create features from real market data if available"""
|
|
# try:
|
|
# # This would need to be implemented to use actual market data
|
|
# # For now, fall back to synthetic features
|
|
# return self._create_realistic_synthetic_features(symbol)
|
|
# except Exception as e:
|
|
# logger.error(f"Error creating realistic features: {e}")
|
|
# return self._create_realistic_synthetic_features(symbol)
|
|
|
|
# def _initialize_model(self):
|
|
# """Initialize the EnhancedCNN model"""
|
|
# try:
|
|
# # Calculate input shape based on BaseDataInput structure
|
|
# # OHLCV: 300 frames x 4 timeframes x 5 features = 6000 features
|
|
# # BTC OHLCV: 300 frames x 5 features = 1500 features
|
|
# # COB: ±20 buckets x 4 metrics = 160 features
|
|
# # MA: 4 timeframes x 10 buckets = 40 features
|
|
# # Technical indicators: 100 features
|
|
# # Last predictions: 50 features
|
|
# # Total: 7850 features
|
|
# input_shape = 7850
|
|
# n_actions = 3 # BUY, SELL, HOLD
|
|
|
|
# # Create model
|
|
# self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
|
|
# # Ensure model is moved to the correct device
|
|
# self.model.to(self.device)
|
|
|
|
# logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
|
|
|
|
# except Exception as e:
|
|
# logger.error(f"Error initializing EnhancedCNN model: {e}")
|
|
# raise
|
|
|
|
# def _load_checkpoint(self, checkpoint_path: str) -> bool:
|
|
# """Load model from checkpoint path"""
|
|
# try:
|
|
# if self.model and os.path.exists(checkpoint_path):
|
|
# success = self.model.load(checkpoint_path)
|
|
# if success:
|
|
# # Ensure model is moved to the correct device after loading
|
|
# self.model.to(self.device)
|
|
# logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
|
|
# return True
|
|
# else:
|
|
# logger.warning(f"Failed to load model from {checkpoint_path}")
|
|
# return False
|
|
# else:
|
|
# logger.warning(f"Checkpoint path does not exist: {checkpoint_path}")
|
|
# return False
|
|
# except Exception as e:
|
|
# logger.error(f"Error loading checkpoint: {e}")
|
|
# return False
|
|
|
|
# def _load_best_checkpoint(self) -> bool:
|
|
# """Load the best available checkpoint"""
|
|
# try:
|
|
# return self.load_best_checkpoint()
|
|
# except Exception as e:
|
|
# logger.error(f"Error loading best checkpoint: {e}")
|
|
# return False
|
|
|
|
# def load_best_checkpoint(self) -> bool:
|
|
# """Load the best checkpoint based on accuracy"""
|
|
# try:
|
|
# # Import checkpoint manager
|
|
# from utils.checkpoint_manager import CheckpointManager
|
|
|
|
# # Create checkpoint manager
|
|
# checkpoint_manager = CheckpointManager(
|
|
# checkpoint_dir=self.checkpoint_dir,
|
|
# max_checkpoints=10,
|
|
# metric_name="accuracy"
|
|
# )
|
|
|
|
# # Load best checkpoint
|
|
# best_checkpoint_path, best_checkpoint_metadata = checkpoint_manager.load_best_checkpoint(self.model_name)
|
|
|
|
# if not best_checkpoint_path:
|
|
# logger.info(f"No checkpoints found for {self.model_name} - starting in COLD START mode")
|
|
# return False
|
|
|
|
# # Load model
|
|
# success = self.model.load(best_checkpoint_path)
|
|
|
|
# if success:
|
|
# # Ensure model is moved to the correct device after loading
|
|
# self.model.to(self.device)
|
|
# logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
|
|
|
|
# # Log metrics
|
|
# metrics = best_checkpoint_metadata.get('metrics', {})
|
|
# logger.info(f"Checkpoint metrics: accuracy={metrics.get('accuracy', 0.0):.4f}, loss={metrics.get('loss', 0.0):.4f}")
|
|
|
|
# return True
|
|
# else:
|
|
# logger.warning(f"Failed to load best checkpoint from {best_checkpoint_path}")
|
|
# return False
|
|
|
|
# except Exception as e:
|
|
# logger.error(f"Error loading best checkpoint: {e}")
|
|
# return False
|
|
|
|
# def _ensure_model_on_device(self):
|
|
# """Ensure model and all its components are on the correct device"""
|
|
# try:
|
|
# if self.model:
|
|
# self.model.to(self.device)
|
|
# # Also ensure the model's internal device is set correctly
|
|
# if hasattr(self.model, 'device'):
|
|
# self.model.device = self.device
|
|
# logger.debug(f"Model ensured on device {self.device}")
|
|
# except Exception as e:
|
|
# logger.error(f"Error ensuring model on device: {e}")
|
|
|
|
# def _create_default_output(self, symbol: str) -> ModelOutput:
|
|
# """Create default output when prediction fails"""
|
|
# return create_model_output(
|
|
# model_type='cnn',
|
|
# model_name=self.model_name,
|
|
# symbol=symbol,
|
|
# action='HOLD',
|
|
# confidence=0.0,
|
|
# metadata={'error': 'Prediction failed, using default output'}
|
|
# )
|
|
|
|
# def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]:
|
|
# """Process hidden states for cross-model feeding"""
|
|
# processed_states = {}
|
|
|
|
# for key, value in hidden_states.items():
|
|
# if isinstance(value, torch.Tensor):
|
|
# # Convert tensor to numpy array
|
|
# processed_states[key] = value.cpu().numpy().tolist()
|
|
# else:
|
|
# processed_states[key] = value
|
|
|
|
# return processed_states
|
|
|
|
|
|
|
|
# def _convert_base_data_to_features(self, base_data: BaseDataInput) -> torch.Tensor:
|
|
# """
|
|
# Convert BaseDataInput to feature vector for EnhancedCNN
|
|
|
|
# Args:
|
|
# base_data: Standardized input data
|
|
|
|
# Returns:
|
|
# torch.Tensor: Feature vector for EnhancedCNN
|
|
# """
|
|
# try:
|
|
# # Use the get_feature_vector method from BaseDataInput
|
|
# features = base_data.get_feature_vector()
|
|
|
|
# # Validate feature quality before using
|
|
# self._validate_feature_quality(features)
|
|
|
|
# # Convert to torch tensor
|
|
# features_tensor = torch.tensor(features, dtype=torch.float32, device=self.device)
|
|
|
|
# return features_tensor
|
|
|
|
# except Exception as e:
|
|
# logger.error(f"Error converting BaseDataInput to features: {e}")
|
|
# # Return empty tensor with correct shape
|
|
# return torch.zeros(7850, dtype=torch.float32, device=self.device)
|
|
|
|
# def _validate_feature_quality(self, features: np.ndarray):
|
|
# """Validate that features are realistic and not synthetic/placeholder data"""
|
|
# try:
|
|
# if len(features) != 7850:
|
|
# logger.warning(f"Feature vector has wrong size: {len(features)} != 7850")
|
|
# return
|
|
|
|
# # Check for all-zero or all-identical features (indicates placeholder data)
|
|
# if np.all(features == 0):
|
|
# logger.warning("Feature vector contains all zeros - likely placeholder data")
|
|
# return
|
|
|
|
# # Check for repetitive patterns in OHLCV data (first 6000 features)
|
|
# ohlcv_features = features[:6000]
|
|
# if len(ohlcv_features) >= 20:
|
|
# # Check if first 20 values are identical (indicates padding with same bar)
|
|
# if np.allclose(ohlcv_features[:20], ohlcv_features[0], atol=1e-6):
|
|
# logger.warning("OHLCV features show repetitive pattern - possible synthetic data")
|
|
|
|
# # Check for unrealistic values
|
|
# if np.any(features > 1e6) or np.any(features < -1e6):
|
|
# logger.warning("Feature vector contains unrealistic values")
|
|
|
|
# # Check for NaN or infinite values
|
|
# if np.any(np.isnan(features)) or np.any(np.isinf(features)):
|
|
# logger.warning("Feature vector contains NaN or infinite values")
|
|
|
|
# except Exception as e:
|
|
# logger.error(f"Error validating feature quality: {e}")
|
|
|
|
# def predict(self, base_data: BaseDataInput) -> ModelOutput:
|
|
# """
|
|
# Make a prediction using the EnhancedCNN model
|
|
|
|
# Args:
|
|
# base_data: Standardized input data
|
|
|
|
# Returns:
|
|
# ModelOutput: Standardized model output
|
|
# """
|
|
# try:
|
|
# # Track inference timing
|
|
# start_time = datetime.now()
|
|
# inference_start = start_time.timestamp()
|
|
|
|
# # Convert BaseDataInput to features
|
|
# features = self._convert_base_data_to_features(base_data)
|
|
|
|
# # Ensure features has batch dimension
|
|
# if features.dim() == 1:
|
|
# features = features.unsqueeze(0)
|
|
|
|
# # Ensure model is on correct device before prediction
|
|
# self._ensure_model_on_device()
|
|
|
|
# # Set model to evaluation mode
|
|
# self.model.eval()
|
|
|
|
# # Make prediction
|
|
# with torch.no_grad():
|
|
# q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.model(features)
|
|
|
|
# # Get action and confidence
|
|
# action_probs = torch.softmax(q_values, dim=1)
|
|
# action_idx = torch.argmax(action_probs, dim=1).item()
|
|
# raw_confidence = float(action_probs[0, action_idx].item())
|
|
|
|
# # Validate confidence - prevent 100% confidence which indicates overfitting
|
|
# if raw_confidence >= 0.99:
|
|
# logger.warning(f"CNN produced suspiciously high confidence: {raw_confidence:.4f} - possible overfitting")
|
|
# # Cap confidence at 0.95 to prevent unrealistic predictions
|
|
# confidence = min(raw_confidence, 0.95)
|
|
# logger.info(f"Capped confidence from {raw_confidence:.4f} to {confidence:.4f}")
|
|
# else:
|
|
# confidence = raw_confidence
|
|
|
|
# # Map action index to action string
|
|
# actions = ['BUY', 'SELL', 'HOLD']
|
|
# action = actions[action_idx]
|
|
|
|
# # Extract pivot price prediction (simplified - take first value from price_pred)
|
|
# pivot_price = None
|
|
# if price_pred is not None and len(price_pred.squeeze()) > 0:
|
|
# # Get current price from base_data for context
|
|
# current_price = 0.0
|
|
# if base_data.ohlcv_1s and len(base_data.ohlcv_1s) > 0:
|
|
# current_price = base_data.ohlcv_1s[-1].close
|
|
|
|
# # Calculate pivot price as current price + predicted change
|
|
# price_change_pct = float(price_pred.squeeze()[0].item()) # First prediction value
|
|
# pivot_price = current_price * (1 + price_change_pct * 0.01) # Convert percentage to price
|
|
|
|
# # Create predictions dictionary
|
|
# predictions = {
|
|
# 'action': action,
|
|
# 'buy_probability': float(action_probs[0, 0].item()),
|
|
# 'sell_probability': float(action_probs[0, 1].item()),
|
|
# 'hold_probability': float(action_probs[0, 2].item()),
|
|
# 'extrema': extrema_pred.squeeze(0).cpu().numpy().tolist(),
|
|
# 'price_prediction': price_pred.squeeze(0).cpu().numpy().tolist(),
|
|
# 'pivot_price': pivot_price
|
|
# }
|
|
|
|
# # Create hidden states dictionary
|
|
# hidden_states = {
|
|
# 'features': features_refined.squeeze(0).cpu().numpy().tolist()
|
|
# }
|
|
|
|
# # Calculate inference duration
|
|
# end_time = datetime.now()
|
|
# inference_duration = (end_time.timestamp() - inference_start) * 1000 # Convert to milliseconds
|
|
|
|
# # Update metrics
|
|
# self.last_inference_time = start_time
|
|
# self.last_inference_duration = inference_duration
|
|
# self.inference_count += 1
|
|
|
|
# # Store last prediction output for dashboard
|
|
# self.last_prediction_output = {
|
|
# 'action': action,
|
|
# 'confidence': confidence,
|
|
# 'pivot_price': pivot_price,
|
|
# 'timestamp': start_time,
|
|
# 'symbol': base_data.symbol
|
|
# }
|
|
|
|
# # Create metadata dictionary
|
|
# metadata = {
|
|
# 'model_version': '1.0',
|
|
# 'timestamp': start_time.isoformat(),
|
|
# 'input_shape': features.shape,
|
|
# 'inference_duration_ms': inference_duration,
|
|
# 'inference_count': self.inference_count
|
|
# }
|
|
|
|
# # Create ModelOutput
|
|
# model_output = ModelOutput(
|
|
# model_type='cnn',
|
|
# model_name=self.model_name,
|
|
# symbol=base_data.symbol,
|
|
# timestamp=start_time,
|
|
# confidence=confidence,
|
|
# predictions=predictions,
|
|
# hidden_states=hidden_states,
|
|
# metadata=metadata
|
|
# )
|
|
|
|
# # Log inference with full input data for training feedback
|
|
# log_model_inference(
|
|
# model_name=self.model_name,
|
|
# symbol=base_data.symbol,
|
|
# action=action,
|
|
# confidence=confidence,
|
|
# probabilities={
|
|
# 'BUY': predictions['buy_probability'],
|
|
# 'SELL': predictions['sell_probability'],
|
|
# 'HOLD': predictions['hold_probability']
|
|
# },
|
|
# input_features=features.cpu().numpy(), # Store full feature vector
|
|
# processing_time_ms=inference_duration,
|
|
# checkpoint_id=None, # Could be enhanced to track checkpoint
|
|
# metadata={
|
|
# 'base_data_input': {
|
|
# 'symbol': base_data.symbol,
|
|
# 'timestamp': base_data.timestamp.isoformat(),
|
|
# 'ohlcv_1s_count': len(base_data.ohlcv_1s),
|
|
# 'ohlcv_1m_count': len(base_data.ohlcv_1m),
|
|
# 'ohlcv_1h_count': len(base_data.ohlcv_1h),
|
|
# 'ohlcv_1d_count': len(base_data.ohlcv_1d),
|
|
# 'btc_ohlcv_1s_count': len(base_data.btc_ohlcv_1s),
|
|
# 'has_cob_data': base_data.cob_data is not None,
|
|
# 'technical_indicators_count': len(base_data.technical_indicators),
|
|
# 'pivot_points_count': len(base_data.pivot_points),
|
|
# 'last_predictions_count': len(base_data.last_predictions)
|
|
# },
|
|
# 'model_predictions': {
|
|
# 'pivot_price': pivot_price,
|
|
# 'extrema_prediction': predictions['extrema'],
|
|
# 'price_prediction': predictions['price_prediction']
|
|
# }
|
|
# }
|
|
# )
|
|
|
|
# return model_output
|
|
|
|
# except Exception as e:
|
|
# logger.error(f"Error making prediction with EnhancedCNN: {e}")
|
|
# # Return default ModelOutput
|
|
# return create_model_output(
|
|
# model_type='cnn',
|
|
# model_name=self.model_name,
|
|
# symbol=base_data.symbol,
|
|
# action='HOLD',
|
|
# confidence=0.0
|
|
# )
|
|
|
|
# def add_training_sample(self, symbol_or_base_data, actual_action: str, reward: float):
|
|
# """
|
|
# Add a training sample to the training data
|
|
|
|
# Args:
|
|
# symbol_or_base_data: Either a symbol string or BaseDataInput object
|
|
# actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
|
|
# reward: Reward received for the action
|
|
# """
|
|
# try:
|
|
# # Handle both symbol string and BaseDataInput object
|
|
# if isinstance(symbol_or_base_data, str):
|
|
# # For cold start mode - create a simple training sample with current features
|
|
# # This is a simplified approach for rapid training
|
|
# symbol = symbol_or_base_data
|
|
|
|
# # Create a realistic feature vector instead of random data
|
|
# # Use actual market data if available, otherwise create realistic synthetic data
|
|
# try:
|
|
# # Try to get real market data first
|
|
# if hasattr(self, 'data_provider') and self.data_provider:
|
|
# # This would need to be implemented in the adapter
|
|
# features = self._create_realistic_features(symbol)
|
|
# else:
|
|
# # Create realistic synthetic features (not random)
|
|
# features = self._create_realistic_synthetic_features(symbol)
|
|
# except Exception as e:
|
|
# logger.warning(f"Could not create realistic features for {symbol}: {e}")
|
|
# # Fallback to small random variation instead of pure random
|
|
# base_features = torch.ones(7850, dtype=torch.float32, device=self.device) * 0.5
|
|
# noise = torch.randn(7850, dtype=torch.float32, device=self.device) * 0.1
|
|
# features = base_features + noise
|
|
|
|
# logger.debug(f"Added realistic training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
|
|
|
|
# else:
|
|
# # Full BaseDataInput object
|
|
# base_data = symbol_or_base_data
|
|
# features = self._convert_base_data_to_features(base_data)
|
|
# symbol = base_data.symbol
|
|
|
|
# logger.debug(f"Added full training sample for {symbol}, action: {actual_action}, reward: {reward:.4f}")
|
|
|
|
# # Convert action to index
|
|
# actions = ['BUY', 'SELL', 'HOLD']
|
|
# action_idx = actions.index(actual_action)
|
|
|
|
# # Add to training data
|
|
# with self.training_lock:
|
|
# self.training_data.append((features, action_idx, reward))
|
|
|
|
# # Limit training data size
|
|
# if len(self.training_data) > self.max_training_samples:
|
|
# # Sort by reward (highest first) and keep top samples
|
|
# self.training_data.sort(key=lambda x: x[2], reverse=True)
|
|
# self.training_data = self.training_data[:self.max_training_samples]
|
|
|
|
# except Exception as e:
|
|
# logger.error(f"Error adding training sample: {e}")
|
|
|
|
# def train(self, epochs: int = 1) -> Dict[str, float]:
|
|
# """
|
|
# Train the model with collected data and inference history
|
|
|
|
# Args:
|
|
# epochs: Number of epochs to train for
|
|
|
|
# Returns:
|
|
# Dict[str, float]: Training metrics
|
|
# """
|
|
# try:
|
|
# # Track training timing
|
|
# training_start_time = datetime.now()
|
|
# training_start = training_start_time.timestamp()
|
|
|
|
# with self.training_lock:
|
|
# # Get additional training data from inference history
|
|
# self._load_training_data_from_inference_history()
|
|
|
|
# # Check if we have enough data
|
|
# if len(self.training_data) < self.batch_size:
|
|
# logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
|
|
# return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
|
|
|
|
# # Ensure model is on correct device before training
|
|
# self._ensure_model_on_device()
|
|
|
|
# # Set model to training mode
|
|
# self.model.train()
|
|
|
|
# # Create optimizer
|
|
# optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
|
|
|
|
# # Training metrics
|
|
# total_loss = 0.0
|
|
# correct_predictions = 0
|
|
# total_predictions = 0
|
|
|
|
# # Train for specified number of epochs
|
|
# for epoch in range(epochs):
|
|
# # Shuffle training data
|
|
# np.random.shuffle(self.training_data)
|
|
|
|
# # Process in batches
|
|
# for i in range(0, len(self.training_data), self.batch_size):
|
|
# batch = self.training_data[i:i+self.batch_size]
|
|
|
|
# # Skip if batch is too small
|
|
# if len(batch) < 2:
|
|
# continue
|
|
|
|
# # Prepare batch - ensure all tensors are on the correct device
|
|
# features = torch.stack([sample[0].to(self.device) for sample in batch])
|
|
# actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
|
|
# rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)
|
|
|
|
# # Zero gradients
|
|
# optimizer.zero_grad()
|
|
|
|
# # Forward pass
|
|
# q_values, _, _, _, _ = self.model(features)
|
|
|
|
# # Calculate loss (CrossEntropyLoss with reward weighting)
|
|
# # First, apply softmax to get probabilities
|
|
# probs = torch.softmax(q_values, dim=1)
|
|
|
|
# # Get probability of chosen action
|
|
# chosen_probs = probs[torch.arange(len(actions)), actions]
|
|
|
|
# # Calculate negative log likelihood loss
|
|
# nll_loss = -torch.log(chosen_probs + 1e-10)
|
|
|
|
# # Weight by reward (higher reward = higher weight)
|
|
# # Normalize rewards to [0, 1] range
|
|
# min_reward = rewards.min()
|
|
# max_reward = rewards.max()
|
|
# if max_reward > min_reward:
|
|
# normalized_rewards = (rewards - min_reward) / (max_reward - min_reward)
|
|
# else:
|
|
# normalized_rewards = torch.ones_like(rewards)
|
|
|
|
# # Apply reward weighting (higher reward = higher weight)
|
|
# weighted_loss = nll_loss * (normalized_rewards + 0.1) # Add small constant to avoid zero weights
|
|
|
|
# # Mean loss
|
|
# loss = weighted_loss.mean()
|
|
|
|
# # Backward pass
|
|
# loss.backward()
|
|
|
|
# # Update weights
|
|
# optimizer.step()
|
|
|
|
# # Update metrics
|
|
# total_loss += loss.item()
|
|
|
|
# # Calculate accuracy
|
|
# predicted_actions = torch.argmax(q_values, dim=1)
|
|
# correct_predictions += (predicted_actions == actions).sum().item()
|
|
# total_predictions += len(actions)
|
|
|
|
# # Validate training - detect overfitting
|
|
# if total_predictions > 0:
|
|
# current_accuracy = correct_predictions / total_predictions
|
|
# if current_accuracy >= 0.99:
|
|
# logger.warning(f"CNN training shows suspiciously high accuracy: {current_accuracy:.4f} - possible overfitting")
|
|
# # Add regularization to prevent overfitting
|
|
# l2_reg = 0.01 * sum(p.pow(2.0).sum() for p in self.model.parameters())
|
|
# loss = loss + l2_reg
|
|
# logger.info("Added L2 regularization to prevent overfitting")
|
|
|
|
# # Calculate final metrics
|
|
# avg_loss = total_loss / (len(self.training_data) / self.batch_size)
|
|
# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
|
|
|
|
# # Calculate training duration
|
|
# training_end_time = datetime.now()
|
|
# training_duration = (training_end_time.timestamp() - training_start) * 1000 # Convert to milliseconds
|
|
|
|
# # Update training metrics
|
|
# self.last_training_time = training_start_time
|
|
# self.last_training_duration = training_duration
|
|
# self.last_training_loss = avg_loss
|
|
# self.training_count += 1
|
|
|
|
# # Save checkpoint
|
|
# self._save_checkpoint(avg_loss, accuracy)
|
|
|
|
# logger.info(f"Training completed: loss={avg_loss:.4f}, accuracy={accuracy:.4f}, samples={len(self.training_data)}, duration={training_duration:.1f}ms")
|
|
|
|
# return {
|
|
# 'loss': avg_loss,
|
|
# 'accuracy': accuracy,
|
|
# 'samples': len(self.training_data),
|
|
# 'duration_ms': training_duration,
|
|
# 'training_count': self.training_count
|
|
# }
|
|
|
|
# except Exception as e:
|
|
# logger.error(f"Error training model: {e}")
|
|
# return {'loss': 0.0, 'accuracy': 0.0, 'samples': 0, 'error': str(e)}
|
|
|
|
# def _save_checkpoint(self, loss: float, accuracy: float):
|
|
# """
|
|
# Save model checkpoint
|
|
|
|
# Args:
|
|
# loss: Training loss
|
|
# accuracy: Training accuracy
|
|
# """
|
|
# try:
|
|
# # Import checkpoint manager
|
|
# from utils.checkpoint_manager import CheckpointManager
|
|
|
|
# # Create checkpoint manager
|
|
# checkpoint_manager = CheckpointManager(
|
|
# checkpoint_dir=self.checkpoint_dir,
|
|
# max_checkpoints=10,
|
|
# metric_name="accuracy"
|
|
# )
|
|
|
|
# # Create temporary model file
|
|
# temp_path = os.path.join(self.checkpoint_dir, f"{self.model_name}_temp")
|
|
# self.model.save(temp_path)
|
|
|
|
# # Create metrics
|
|
# metrics = {
|
|
# 'loss': loss,
|
|
# 'accuracy': accuracy,
|
|
# 'samples': len(self.training_data)
|
|
# }
|
|
|
|
# # Create metadata
|
|
# metadata = {
|
|
# 'timestamp': datetime.now().isoformat(),
|
|
# 'model_name': self.model_name,
|
|
# 'input_shape': self.model.input_shape,
|
|
# 'n_actions': self.model.n_actions
|
|
# }
|
|
|
|
# # Save checkpoint
|
|
# checkpoint_path = checkpoint_manager.save_checkpoint(
|
|
# model_name=self.model_name,
|
|
# model_path=f"{temp_path}.pt",
|
|
# metrics=metrics,
|
|
# metadata=metadata
|
|
# )
|
|
|
|
# # Delete temporary model file
|
|
# if os.path.exists(f"{temp_path}.pt"):
|
|
# os.remove(f"{temp_path}.pt")
|
|
|
|
# logger.info(f"Model checkpoint saved to {checkpoint_path}")
|
|
|
|
# except Exception as e:
|
|
# logger.error(f"Error saving checkpoint: {e}")
|
|
|
|
# def _load_training_data_from_inference_history(self):
|
|
# """Load training data from inference history for continuous learning"""
|
|
# try:
|
|
# from utils.database_manager import get_database_manager
|
|
|
|
# db_manager = get_database_manager()
|
|
|
|
# # Get recent inference records with input features
|
|
# inference_records = db_manager.get_inference_records_for_training(
|
|
# model_name=self.model_name,
|
|
# hours_back=24, # Last 24 hours
|
|
# limit=1000
|
|
# )
|
|
|
|
# if not inference_records:
|
|
# logger.debug("No inference records found for training")
|
|
# return
|
|
|
|
# # Convert inference records to training samples
|
|
# # For now, use a simple approach: treat high-confidence predictions as ground truth
|
|
# for record in inference_records:
|
|
# if record.input_features is not None and record.confidence > 0.7:
|
|
# # Convert action to index
|
|
# actions = ['BUY', 'SELL', 'HOLD']
|
|
# if record.action in actions:
|
|
# action_idx = actions.index(record.action)
|
|
|
|
# # Use confidence as a proxy for reward (high confidence = good prediction)
|
|
# reward = record.confidence * 2 - 1 # Scale to [-1, 1]
|
|
|
|
# # Convert features to tensor
|
|
# features_tensor = torch.tensor(record.input_features, dtype=torch.float32, device=self.device)
|
|
|
|
# # Add to training data if not already present (avoid duplicates)
|
|
# sample_exists = any(
|
|
# torch.equal(features_tensor, existing[0])
|
|
# for existing in self.training_data
|
|
# )
|
|
|
|
# if not sample_exists:
|
|
# self.training_data.append((features_tensor, action_idx, reward))
|
|
|
|
# logger.info(f"Loaded {len(inference_records)} inference records for training, total training samples: {len(self.training_data)}")
|
|
|
|
# except Exception as e:
|
|
# logger.error(f"Error loading training data from inference history: {e}")
|
|
|
|
# def evaluate_predictions_against_outcomes(self, hours_back: int = 1) -> Dict[str, float]:
|
|
# """
|
|
# Evaluate past predictions against actual market outcomes
|
|
|
|
# Args:
|
|
# hours_back: How many hours back to evaluate
|
|
|
|
# Returns:
|
|
# Dict with evaluation metrics
|
|
# """
|
|
# try:
|
|
# from utils.database_manager import get_database_manager
|
|
|
|
# db_manager = get_database_manager()
|
|
|
|
# # Get inference records from the specified time period
|
|
# inference_records = db_manager.get_inference_records_for_training(
|
|
# model_name=self.model_name,
|
|
# hours_back=hours_back,
|
|
# limit=100
|
|
# )
|
|
|
|
# if not inference_records:
|
|
# return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
|
|
|
|
# # For now, use a simple evaluation based on confidence
|
|
# # In a real implementation, this would compare against actual price movements
|
|
# correct_predictions = 0
|
|
# total_predictions = len(inference_records)
|
|
|
|
# # Simple heuristic: high confidence predictions are more likely to be correct
|
|
# for record in inference_records:
|
|
# if record.confidence > 0.8: # High confidence threshold
|
|
# correct_predictions += 1
|
|
# elif record.confidence > 0.6: # Medium confidence
|
|
# correct_predictions += 0.5
|
|
|
|
# accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
|
|
|
|
# logger.info(f"Prediction evaluation: {correct_predictions:.1f}/{total_predictions} = {accuracy:.3f} accuracy")
|
|
|
|
# return {
|
|
# 'accuracy': accuracy,
|
|
# 'total_predictions': total_predictions,
|
|
# 'correct_predictions': correct_predictions
|
|
# }
|
|
|
|
# except Exception as e:
|
|
# logger.error(f"Error evaluating predictions: {e}")
|
|
# return {'accuracy': 0.0, 'total_predictions': 0, 'correct_predictions': 0}
|