# Source listing metadata: 219 lines, 8.9 KiB, Python
"""
|
|
CNN-RL Bridge Module
|
|
|
|
This module provides the interface between CNN models and RL training,
|
|
extracting hidden features and predictions from CNN models for use in RL state building.
|
|
"""
|
|
|
|
import logging
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
from datetime import datetime, timedelta
|
|
|
|
logger = logging.getLogger(__name__)


class CNNRLBridge:
    """Bridge between CNN models and RL training for feature extraction.

    The bridge keeps a registry of CNN models (``cnn_models``) and a
    per-symbol cache of the most recently extracted features
    (``feature_cache``). Cached features are served while they are younger
    than ``cache_timeout`` seconds.

    NOTE: model loading and feature extraction are still placeholders;
    ``_load_cnn_model`` always returns ``None`` and
    ``_extract_model_features`` returns mock arrays.
    """

    # Per-model output sizes. Kept in one place so the mock extractor and the
    # reporting helpers cannot drift apart.
    HIDDEN_FEATURE_SIZE = 512
    PREDICTIONS_PER_MODEL = 4  # 1s, 1m, 1h, 1d

    def __init__(self, config: Dict):
        """Initialize CNN-RL bridge.

        Args:
            config: Configuration holding an optional ``cnn_models`` entry
                (model name -> model config). Both mapping-style
                (``config['cnn_models']``) and attribute-style
                (``config.cnn_models``) configurations are accepted.
        """
        self.config = config
        self.cnn_models: Dict[str, nn.Module] = {}
        # symbol -> {'timestamp': datetime, 'features': dict}
        self.feature_cache: Dict[str, Dict[str, Any]] = {}
        self.cache_timeout = 60  # Cache features for 60 seconds

        # Initialize CNN model registry if available
        self._initialize_cnn_models()

        logger.info("CNN-RL Bridge initialized")

    def _configured_cnn_models(self) -> Dict[str, Any]:
        """Return the ``cnn_models`` section of the config, or an empty dict.

        ``config`` is annotated as a dict, so key lookup is tried first;
        attribute access is kept as a fallback for object-style configs.
        """
        config = self.config
        models = config.get('cnn_models') if isinstance(config, dict) else None
        if not models:
            models = getattr(config, 'cnn_models', None)
        return models or {}

    def _initialize_cnn_models(self):
        """Initialize CNN models from config or model registry.

        Failures are logged and never raised: a model that fails to load is
        skipped, and the bridge simply ends up with fewer (or no) models.
        """
        try:
            for model_name, model_config in self._configured_cnn_models().items():
                try:
                    # Load CNN model (implementation would depend on your CNN architecture)
                    model = self._load_cnn_model(model_name, model_config)
                except Exception as e:
                    logger.warning("Failed to load CNN model %s: %s", model_name, e)
                    continue

                # Compare against None explicitly: container modules such as
                # an empty nn.Sequential define __len__ and would be falsy.
                if model is not None:
                    self.cnn_models[model_name] = model
                    logger.info("Loaded CNN model: %s", model_name)

            if not self.cnn_models:
                logger.info("No CNN models available - RL will train without CNN features")

        except Exception as e:
            logger.warning("Error initializing CNN models: %s", e)

    def _load_cnn_model(self, model_name: str, model_config: Dict) -> Optional[nn.Module]:
        """Load a CNN model from configuration.

        Placeholder: no architecture-specific loading is implemented yet, so
        this always returns ``None`` (meaning "no model available").
        """
        logger.info("CNN model loading framework ready for %s", model_name)
        return None

    def get_latest_features_for_symbol(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Get latest CNN features and predictions for a symbol.

        Args:
            symbol: Symbol to fetch features for; also used as the cache key.

        Returns:
            A dict with ``'hidden_features'`` and ``'predictions'`` (each
            keyed by model name), or ``None`` when there is no fresh cache
            entry and no model is registered, or when an error occurs.
        """
        try:
            now = datetime.now()

            # Check cache first. total_seconds() is required here:
            # timedelta.seconds wraps around every 24 hours.
            cached = self.feature_cache.get(symbol)
            if cached is not None:
                age = (now - cached['timestamp']).total_seconds()
                if age < self.cache_timeout:
                    return cached['features']

            if not self.cnn_models:
                return None

            # Generate new features and remember them for this symbol
            features = self._extract_cnn_features_for_symbol(symbol)
            self.feature_cache[symbol] = {
                'timestamp': now,
                'features': features,
            }

            # Clean old cache entries
            self._cleanup_cache()

            return features

        except Exception as e:
            logger.warning("Error getting CNN features for %s: %s", symbol, e)
            return None

    def _extract_cnn_features_for_symbol(self, symbol: str) -> Dict[str, Any]:
        """Extract CNN hidden features and predictions for a symbol.

        Runs every registered model; a model that fails is logged and left
        out of the result. Always returns a dict with the keys
        ``'hidden_features'`` and ``'predictions'``.
        """
        try:
            extracted_features = {
                'hidden_features': {},
                'predictions': {},
            }

            for model_name, model in self.cnn_models.items():
                try:
                    hidden_features, predictions = self._extract_model_features(model, symbol)
                except Exception as e:
                    logger.warning("Error extracting features from %s: %s", model_name, e)
                    continue

                if hidden_features is not None:
                    extracted_features['hidden_features'][model_name] = hidden_features
                if predictions is not None:
                    extracted_features['predictions'][model_name] = predictions

            return extracted_features

        except Exception as e:
            logger.error("Error extracting CNN features for %s: %s", symbol, e)
            return {'hidden_features': {}, 'predictions': {}}

    def _extract_model_features(self, model: nn.Module, symbol: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Extract hidden features and predictions from a specific CNN model.

        Placeholder: returns MOCK data that only demonstrates the expected
        structure. A real implementation would:
            1. Get market data for the model
            2. Run forward pass through CNN
            3. Extract hidden layer activations
            4. Get model predictions

        Returns:
            ``(hidden_features, predictions)`` as float32 arrays of length
            ``HIDDEN_FEATURE_SIZE`` and ``PREDICTIONS_PER_MODEL``, or
            ``(None, None)`` on error.
        """
        try:
            # Mock hidden features (last hidden layer of CNN) - random values
            hidden_features = np.random.random(self.HIDDEN_FEATURE_SIZE).astype(np.float32)

            # Mock predictions for different timeframes
            predictions = np.array([
                0.45,  # 1s prediction (probability of up move)
                0.52,  # 1m prediction
                0.38,  # 1h prediction
                0.61,  # 1d prediction
            ]).astype(np.float32)

            logger.debug(
                "Extracted CNN features for %s: %d hidden, %d predictions",
                symbol, len(hidden_features), len(predictions),
            )

            return hidden_features, predictions

        except Exception as e:
            logger.warning("Error extracting features from model: %s", e)
            return None, None

    def _cleanup_cache(self):
        """Clean up old cache entries.

        Drops every entry older than twice ``cache_timeout``. Errors are
        logged and swallowed so cache maintenance never breaks a lookup.
        """
        try:
            current_time = datetime.now()
            max_age = self.cache_timeout * 2

            # Collect keys first; deleting while iterating a dict is an error.
            expired_keys = [
                key
                for key, data in self.feature_cache.items()
                if (current_time - data['timestamp']).total_seconds() > max_age
            ]

            for key in expired_keys:
                del self.feature_cache[key]

        except Exception as e:
            logger.warning("Error cleaning up feature cache: %s", e)

    def register_cnn_model(self, model_name: str, model: nn.Module):
        """Register a CNN model for feature extraction.

        An existing model with the same name is replaced. Errors are logged
        rather than raised.
        """
        try:
            self.cnn_models[model_name] = model
            logger.info("Registered CNN model: %s", model_name)
        except Exception as e:
            logger.error("Error registering CNN model %s: %s", model_name, e)

    def unregister_cnn_model(self, model_name: str):
        """Unregister a CNN model; unknown names are ignored silently."""
        try:
            if model_name in self.cnn_models:
                del self.cnn_models[model_name]
                logger.info("Unregistered CNN model: %s", model_name)
        except Exception as e:
            logger.error("Error unregistering CNN model %s: %s", model_name, e)

    def get_available_models(self) -> List[str]:
        """Get list of available CNN models (names, in registration order)."""
        return list(self.cnn_models)

    def is_model_available(self, model_name: str) -> bool:
        """Check if a specific CNN model is available."""
        return model_name in self.cnn_models

    def get_feature_dimensions(self) -> Dict[str, int]:
        """Get the dimensions of features extracted from CNN models."""
        return {
            'hidden_features_per_model': self.HIDDEN_FEATURE_SIZE,
            'predictions_per_model': self.PREDICTIONS_PER_MODEL,  # 1s, 1m, 1h, 1d
            'total_models': len(self.cnn_models),
        }

    def validate_cnn_integration(self) -> Dict[str, Any]:
        """Validate CNN integration status.

        Returns:
            A summary dict: number and names of registered models, number of
            cache entries, whether at least one model is registered, and the
            expected total hidden-feature / prediction sizes across models.
        """
        model_count = len(self.cnn_models)
        status = {
            'models_available': model_count,
            'models_list': list(self.cnn_models),
            'cache_entries': len(self.feature_cache),
            'integration_ready': model_count > 0,
            'expected_feature_size': model_count * self.HIDDEN_FEATURE_SIZE,  # hidden features
            'expected_prediction_size': model_count * self.PREDICTIONS_PER_MODEL,  # predictions
        }

        return status