gogo2/training/cnn_rl_bridge.py
2025-05-28 23:42:06 +03:00

219 lines
8.9 KiB
Python

"""
CNN-RL Bridge Module
This module provides the interface between CNN models and RL training,
extracting hidden features and predictions from CNN models for use in RL state building.
"""
import logging
import numpy as np
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta
# Module-level logger; every CNNRLBridge method below reports through it.
logger = logging.getLogger(__name__)
class CNNRLBridge:
    """Bridge between CNN models and RL training for feature extraction.

    Keeps a registry of CNN models (``self.cnn_models``, name -> module) and a
    per-symbol cache of the features extracted from them
    (``self.feature_cache``, symbol -> ``{'timestamp', 'features'}``).
    """

    # Length of the hidden-feature vector reported per model (the placeholder
    # extractor below produces vectors of this size).
    HIDDEN_FEATURE_SIZE = 512
    # Number of predictions reported per model: 1s, 1m, 1h, 1d.
    PREDICTION_SIZE = 4

    def __init__(self, config: Dict):
        """Initialize the CNN-RL bridge.

        Args:
            config: Configuration. CNN model definitions are looked up as a
                ``cnn_models`` attribute or, for plain dicts, a ``'cnn_models'``
                key; either way it should map model name -> model config.
        """
        self.config = config
        self.cnn_models: Dict[str, nn.Module] = {}
        # symbol -> {'timestamp': datetime, 'features': dict}
        self.feature_cache: Dict[str, Dict[str, Any]] = {}
        self.cache_timeout = 60  # Cache features for 60 seconds
        # Initialize CNN model registry if available
        self._initialize_cnn_models()
        logger.info("CNN-RL Bridge initialized")

    def _get_models_config(self) -> Optional[Dict]:
        """Return the ``cnn_models`` section of the config, or None if absent."""
        models_config = getattr(self.config, 'cnn_models', None)
        if models_config is None and isinstance(self.config, dict):
            # Dicts expose entries as keys, not attributes, so the attribute
            # lookup above never finds them.
            models_config = self.config.get('cnn_models')
        return models_config

    def _initialize_cnn_models(self):
        """Load CNN models listed in the config into ``self.cnn_models``.

        Failures are logged per model; initialization never raises.
        """
        try:
            models_config = self._get_models_config()
            if models_config:
                for model_name, model_config in models_config.items():
                    try:
                        # Load CNN model (implementation depends on the CNN architecture)
                        model = self._load_cnn_model(model_name, model_config)
                        # Compare against None: container modules such as an
                        # empty nn.Sequential define __len__ and are falsy.
                        if model is not None:
                            self.cnn_models[model_name] = model
                            logger.info("Loaded CNN model: %s", model_name)
                    except Exception as e:
                        logger.warning("Failed to load CNN model %s: %s", model_name, e)
            if not self.cnn_models:
                logger.info("No CNN models available - RL will train without CNN features")
        except Exception as e:
            logger.warning("Error initializing CNN models: %s", e)

    def _load_cnn_model(self, model_name: str, model_config: Dict) -> Optional[nn.Module]:
        """Load a CNN model from configuration.

        Placeholder: always returns None, so no model is loaded from config.
        """
        try:
            # Actual loading of the specific CNN architecture belongs here.
            logger.info("CNN model loading framework ready for %s", model_name)
            return None
        except Exception as e:
            logger.error("Error loading CNN model %s: %s", model_name, e)
            return None

    def get_latest_features_for_symbol(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Get the latest CNN features and predictions for ``symbol``.

        Returns:
            A dict with ``'hidden_features'`` and ``'predictions'`` sub-dicts
            keyed by model name. The cached value is returned when it is
            younger than ``cache_timeout`` seconds. Returns None when nothing
            is cached, no models are registered, or an error occurs.
        """
        try:
            # Check cache first; one entry per symbol, freshness by timestamp.
            cached = self.feature_cache.get(symbol)
            if cached is not None:
                age = (datetime.now() - cached['timestamp']).total_seconds()
                if age < self.cache_timeout:
                    return cached['features']

            # Generate new features only if models are available
            if not self.cnn_models:
                return None

            features = self._extract_cnn_features_for_symbol(symbol)
            self.feature_cache[symbol] = {
                'timestamp': datetime.now(),
                'features': features,
            }
            # Drop entries for symbols that have not been refreshed recently
            self._cleanup_cache()
            return features
        except Exception as e:
            logger.warning("Error getting CNN features for %s: %s", symbol, e)
            return None

    def _extract_cnn_features_for_symbol(self, symbol: str) -> Dict[str, Any]:
        """Collect hidden features and predictions from every registered model.

        A failing model is logged and skipped; the others still contribute.
        """
        try:
            extracted_features = {
                'hidden_features': {},
                'predictions': {},
            }
            for model_name, model in self.cnn_models.items():
                try:
                    hidden_features, predictions = self._extract_model_features(model, symbol)
                    if hidden_features is not None:
                        extracted_features['hidden_features'][model_name] = hidden_features
                    if predictions is not None:
                        extracted_features['predictions'][model_name] = predictions
                except Exception as e:
                    logger.warning("Error extracting features from %s: %s", model_name, e)
            return extracted_features
        except Exception as e:
            logger.error("Error extracting CNN features for %s: %s", symbol, e)
            return {'hidden_features': {}, 'predictions': {}}

    def _extract_model_features(self, model: nn.Module, symbol: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Extract hidden features and predictions from a specific CNN model.

        Placeholder: ``model`` is not run. Returns random float32 hidden
        features of length HIDDEN_FEATURE_SIZE and fixed mock predictions of
        length PREDICTION_SIZE. A real implementation would fetch market data,
        run a forward pass, and read the last hidden layer and the outputs.
        """
        try:
            # Mock hidden features (stand-in for the CNN's last hidden layer)
            hidden_features = np.random.random(self.HIDDEN_FEATURE_SIZE).astype(np.float32)
            # Mock probability of an up move for the 1s, 1m, 1h and 1d timeframes
            predictions = np.array([0.45, 0.52, 0.38, 0.61], dtype=np.float32)
            logger.debug(
                "Extracted CNN features for %s: %d hidden, %d predictions",
                symbol, len(hidden_features), len(predictions),
            )
            return hidden_features, predictions
        except Exception as e:
            logger.warning("Error extracting features from model: %s", e)
            return None, None

    def _cleanup_cache(self):
        """Evict cache entries older than twice ``cache_timeout`` seconds."""
        try:
            now = datetime.now()
            max_age = self.cache_timeout * 2
            # total_seconds(), not .seconds: the latter drops the days component
            # and would make day-old entries look fresh.
            expired_keys = [
                key for key, entry in self.feature_cache.items()
                if (now - entry['timestamp']).total_seconds() > max_age
            ]
            for key in expired_keys:
                del self.feature_cache[key]
        except Exception as e:
            logger.warning("Error cleaning up feature cache: %s", e)

    def register_cnn_model(self, model_name: str, model: nn.Module):
        """Register (or replace) a CNN model for feature extraction.

        Clears the feature cache so the next lookup includes this model.
        """
        try:
            self.cnn_models[model_name] = model
            self.feature_cache.clear()
            logger.info("Registered CNN model: %s", model_name)
        except Exception as e:
            logger.error("Error registering CNN model %s: %s", model_name, e)

    def unregister_cnn_model(self, model_name: str):
        """Unregister a CNN model; unknown names are ignored.

        Clears the feature cache so cached results stop including this model.
        """
        try:
            if model_name in self.cnn_models:
                del self.cnn_models[model_name]
                self.feature_cache.clear()
                logger.info("Unregistered CNN model: %s", model_name)
        except Exception as e:
            logger.error("Error unregistering CNN model %s: %s", model_name, e)

    def get_available_models(self) -> List[str]:
        """Get the names of the registered CNN models."""
        return list(self.cnn_models.keys())

    def is_model_available(self, model_name: str) -> bool:
        """Check whether a CNN model with this name is registered."""
        return model_name in self.cnn_models

    def get_feature_dimensions(self) -> Dict[str, int]:
        """Get the per-model feature sizes and the number of registered models."""
        return {
            'hidden_features_per_model': self.HIDDEN_FEATURE_SIZE,
            'predictions_per_model': self.PREDICTION_SIZE,  # 1s, 1m, 1h, 1d
            'total_models': len(self.cnn_models),
        }

    def validate_cnn_integration(self) -> Dict[str, Any]:
        """Summarize the CNN integration status (models, cache, expected sizes)."""
        model_count = len(self.cnn_models)
        return {
            'models_available': model_count,
            'models_list': list(self.cnn_models.keys()),
            'cache_entries': len(self.feature_cache),
            'integration_ready': model_count > 0,
            'expected_feature_size': model_count * self.HIDDEN_FEATURE_SIZE,  # hidden features
            'expected_prediction_size': model_count * self.PREDICTION_SIZE,  # predictions
        }