"""
Multi-Timeframe Prediction System for Enhanced Trading
This module implements a sophisticated multi-timeframe prediction system that allows
models to make predictions for different time horizons (1, 5, 10 minutes) with
appropriate confidence thresholds and position holding strategies.
Key Features:
- Dynamic sequence length adaptation for different timeframes
- Confidence calibration based on prediction horizon
- Position holding logic for longer-term trades
- Risk-adjusted trading strategies
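
Example (sketch only; ``orchestrator`` is whatever orchestrator instance the
surrounding application already provides):

    predictor = MultiTimeframePredictor(orchestrator)
    prediction = predictor.generate_multi_timeframe_prediction('ETH/USDT')
    if prediction:
        should_trade, reason = predictor.should_execute_trade(prediction)
        hold_seconds = predictor.get_position_hold_time(prediction)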
"""

import logging
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta
from dataclasses import dataclass
from enum import Enum

logger = logging.getLogger(__name__)


class PredictionHorizon(Enum):
    """Prediction time horizons"""
    ONE_MINUTE = 1
    FIVE_MINUTES = 5
    TEN_MINUTES = 10


class ConfidenceThreshold(Enum):
    """Confidence thresholds for different prediction horizons"""
    ONE_MINUTE = 0.35    # Lower threshold for quick trades
    FIVE_MINUTES = 0.65  # Higher threshold for 5-minute holds
    TEN_MINUTES = 0.80   # Very high threshold for 10-minute holds


@dataclass
class MultiTimeframePrediction:
    """Container for multi-timeframe predictions"""
    symbol: str
    current_price: float
    predictions: Dict[PredictionHorizon, Dict[str, Any]]
    timestamp: datetime
    market_conditions: Dict[str, Any]


class MultiTimeframePredictor:
    """
    Advanced multi-timeframe prediction system that adapts model behavior
    based on the desired prediction horizon and market conditions.
    """

    def __init__(self, orchestrator):
        self.orchestrator = orchestrator
        self.horizons = {
            PredictionHorizon.ONE_MINUTE: {
                'sequence_length': 60,    # 60 one-minute bars for 1-minute predictions
                'confidence_threshold': ConfidenceThreshold.ONE_MINUTE.value,
                'max_hold_time': 60,      # 1 minute max hold (seconds)
                'risk_multiplier': 1.0
            },
            PredictionHorizon.FIVE_MINUTES: {
                'sequence_length': 300,   # 300 one-minute bars for 5-minute predictions
                'confidence_threshold': ConfidenceThreshold.FIVE_MINUTES.value,
                'max_hold_time': 300,     # 5 minutes max hold (seconds)
                'risk_multiplier': 1.5    # Higher risk for longer holds
            },
            PredictionHorizon.TEN_MINUTES: {
                'sequence_length': 600,   # 600 one-minute bars for 10-minute predictions
                'confidence_threshold': ConfidenceThreshold.TEN_MINUTES.value,
                'max_hold_time': 600,     # 10 minutes max hold (seconds)
                'risk_multiplier': 2.0    # Highest risk for the longest holds
            }
        }

        # Initialize model instances for the different horizons
        self.models = {}
        self._initialize_multi_horizon_models()

    def _initialize_multi_horizon_models(self):
        """Initialize separate model instances for different horizons"""
        try:
            for horizon, config in self.horizons.items():
                # CNN model for this horizon
                if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                    # Create a horizon-specific model configuration
                    horizon_model = self._create_horizon_specific_model(
                        self.orchestrator.cnn_model,
                        config['sequence_length'],
                        horizon
                    )
                    self.models[f'cnn_{horizon.value}min'] = horizon_model

                # COB RL model for this horizon
                if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
                    self.models[f'cob_rl_{horizon.value}min'] = self.orchestrator.cob_rl_agent

                logger.info(f"Initialized {horizon.value}-minute prediction model")
        except Exception as e:
            logger.error(f"Error initializing multi-horizon models: {e}")

    def _create_horizon_specific_model(self, base_model, sequence_length: int, horizon: PredictionHorizon):
        """Create a model instance optimized for a specific prediction horizon"""
        try:
            # For CNN models we need to adjust the input size and potentially the architecture
            if hasattr(base_model, '__class__'):
                model_class = base_model.__class__

                # Calculate an appropriate input size for this horizon
                # (more data for longer predictions)
                adjusted_input_size = min(sequence_length, 300)  # Cap at 300 to avoid memory issues

                # Create a new model instance with horizon-specific parameters
                horizon_model = model_class(
                    input_size=adjusted_input_size,
                    feature_dim=getattr(base_model, 'feature_dim', 50),
                    output_size=getattr(base_model, 'output_size', 2),
                    base_channels=getattr(base_model, 'base_channels', 256),
                    num_blocks=getattr(base_model, 'num_blocks', 12),
                    num_attention_heads=getattr(base_model, 'num_attention_heads', 16),
                    dropout_rate=getattr(base_model, 'dropout_rate', 0.2),
                    prediction_horizon=horizon.value
                )

                # Try to load pre-trained weights if available
                try:
                    if hasattr(base_model, 'state_dict'):
                        # Load base model weights, adapting where necessary
                        base_state = base_model.state_dict()
                        horizon_model.load_state_dict(base_state, strict=False)
                        logger.info(f"Loaded base model weights for {horizon.value}-minute horizon")
                except Exception as e:
                    logger.warning(f"Could not load base weights for {horizon.value}-minute model: {e}")

                return horizon_model
        except Exception as e:
            logger.error(f"Error creating horizon-specific model: {e}")
            return base_model  # Fall back to the base model

    def generate_multi_timeframe_prediction(self, symbol: str) -> Optional[MultiTimeframePrediction]:
        """
        Generate predictions for all timeframes with appropriate confidence thresholds
        """
        try:
            # Get current market data
            current_price = self._get_current_price(symbol)
            if not current_price:
                return None

            # Get market conditions for confidence adjustment
            market_conditions = self._assess_market_conditions(symbol)

            predictions = {}

            # Generate a prediction for each horizon
            for horizon, config in self.horizons.items():
                prediction = self._generate_single_horizon_prediction(
                    symbol, current_price, horizon, config, market_conditions
                )
                if prediction:
                    predictions[horizon] = prediction

            if not predictions:
                return None

            return MultiTimeframePrediction(
                symbol=symbol,
                current_price=current_price,
                predictions=predictions,
                timestamp=datetime.now(),
                market_conditions=market_conditions
            )
        except Exception as e:
            logger.error(f"Error generating multi-timeframe prediction: {e}")
            return None

    def _generate_single_horizon_prediction(self, symbol: str, current_price: float,
                                            horizon: PredictionHorizon, config: Dict,
                                            market_conditions: Dict) -> Optional[Dict[str, Any]]:
        """Generate a prediction for a single timeframe"""
        try:
            # Get appropriate data for this horizon
            sequence_data = self._get_sequence_data_for_horizon(symbol, config['sequence_length'])
            if sequence_data is None:
                return None

            # Generate predictions from the available models
            model_predictions = []

            # CNN prediction
            cnn_key = f'cnn_{horizon.value}min'
            if cnn_key in self.models:
                cnn_pred = self._get_cnn_prediction(
                    self.models[cnn_key], sequence_data, config
                )
                if cnn_pred:
                    model_predictions.append(cnn_pred)

            # COB RL prediction
            cob_key = f'cob_rl_{horizon.value}min'
            if cob_key in self.models:
                cob_pred = self._get_cob_rl_prediction(
                    self.models[cob_key], sequence_data, config
                )
                if cob_pred:
                    model_predictions.append(cob_pred)

            if not model_predictions:
                return None

            # Ensemble the predictions
            ensemble_prediction = self._ensemble_predictions(
                model_predictions, config, market_conditions
            )
            if not ensemble_prediction:
                return None

            # Apply the confidence threshold
            if ensemble_prediction['confidence'] < config['confidence_threshold']:
                return None  # Not confident enough for this horizon

            return ensemble_prediction
        except Exception as e:
            logger.error(f"Error generating {horizon.value}-minute prediction: {e}")
            return None

    def _get_sequence_data_for_horizon(self, symbol: str, sequence_length: int) -> Optional[torch.Tensor]:
        """Get appropriate sequence data for the prediction horizon"""
        try:
            # This would need to be implemented based on your data provider;
            # for now use the generic historical-data interface.
            if hasattr(self.orchestrator, 'data_provider'):
                # Get historical data for the required sequence length
                data = self.orchestrator.data_provider.get_historical_data(
                    symbol, '1m', limit=sequence_length
                )
                if data is not None and len(data) >= sequence_length // 10:  # At least 10% of the required data
                    # Convert to the tensor format expected by the models
                    return self._convert_data_to_tensor(data)
                else:
                    logger.warning(f"Insufficient data for {sequence_length}-point prediction")
                    return None
            return None
        except Exception as e:
            logger.error(f"Error getting sequence data: {e}")
            return None

    def _convert_data_to_tensor(self, data) -> Optional[torch.Tensor]:
        """Convert market data to tensor format"""
        try:
            # This is a placeholder - implement based on your data format
            if hasattr(data, 'values'):
                # Assume a pandas DataFrame with OHLCV columns
                features = ['open', 'high', 'low', 'close', 'volume']
                feature_data = []
                for feature in features:
                    if feature in data.columns:
                        values = data[feature].ffill().fillna(0).values
                        feature_data.append(values)

                if feature_data:
                    # Stack features into (sequence_length, num_features) and add a batch dimension
                    tensor_data = torch.tensor(feature_data, dtype=torch.float32).transpose(0, 1)
                    return tensor_data.unsqueeze(0)
            return None
        except Exception as e:
            logger.error(f"Error converting data to tensor: {e}")
            return None

    def _get_cnn_prediction(self, model, sequence_data: torch.Tensor, config: Dict) -> Optional[Dict]:
        """Get a CNN model prediction"""
        try:
            with torch.no_grad():
                outputs = model(sequence_data)

                if isinstance(outputs, tuple):
                    predictions, confidence = outputs
                    if torch.is_tensor(confidence):
                        confidence = confidence.max().item()
                else:
                    predictions = outputs
                    confidence = torch.softmax(predictions, dim=-1).max().item()

                action_idx = predictions.argmax().item()
                actions = ['SELL', 'BUY']  # Adjust based on your model's output format

                return {
                    'action': actions[action_idx] if action_idx < len(actions) else 'HOLD',
                    'confidence': confidence,
                    'model': 'cnn',
                    'horizon': config.get('max_hold_time', 60)
                }
        except Exception as e:
            logger.error(f"Error getting CNN prediction: {e}")
            return None

    def _get_cob_rl_prediction(self, model, sequence_data: torch.Tensor, config: Dict) -> Optional[Dict]:
        """Get a COB RL model prediction"""
        try:
            # This would need to be implemented based on your COB RL model interface
            if hasattr(model, 'predict'):
                result = model.predict(sequence_data)
                return {
                    'action': result.get('action', 'HOLD'),
                    'confidence': result.get('confidence', 0.5),
                    'model': 'cob_rl',
                    'horizon': config.get('max_hold_time', 60)
                }
            return None
        except Exception as e:
            logger.error(f"Error getting COB RL prediction: {e}")
            return None

    def _ensemble_predictions(self, predictions: List[Dict], config: Dict,
                              market_conditions: Dict) -> Optional[Dict[str, Any]]:
        """Ensemble multiple model predictions"""
        try:
            if not predictions:
                return None

            # Simple confidence-weighted voting ensemble
            action_votes = {}
            confidence_sum = 0

            for pred in predictions:
                action = pred['action']
                confidence = pred['confidence']
                if action not in action_votes:
                    action_votes[action] = 0
                action_votes[action] += confidence
                confidence_sum += confidence

            # Get the winning action and average its confidence over all models
            best_action = max(action_votes, key=action_votes.get)
            ensemble_confidence = action_votes[best_action] / len(predictions)

            # Adjust confidence based on market conditions
            market_confidence_multiplier = market_conditions.get('confidence_multiplier', 1.0)
            final_confidence = min(ensemble_confidence * market_confidence_multiplier, 1.0)

            return {
                'action': best_action,
                'confidence': final_confidence,
                'horizon_minutes': config['max_hold_time'] // 60,
                'risk_multiplier': config['risk_multiplier'],
                'models_used': len(predictions),
                'market_conditions': market_conditions
            }
        except Exception as e:
            logger.error(f"Error in prediction ensemble: {e}")
            return None

    def _assess_market_conditions(self, symbol: str) -> Dict[str, Any]:
        """Assess current market conditions for confidence adjustment"""
        try:
            conditions = {
                'volatility': 'medium',
                'trend': 'sideways',
                'confidence_multiplier': 1.0,
                'risk_level': 'normal'
            }

            # This could be enhanced with actual market analysis;
            # for now, return default conditions.
            return conditions
        except Exception as e:
            logger.error(f"Error assessing market conditions: {e}")
            return {'confidence_multiplier': 1.0}

    def _get_current_price(self, symbol: str) -> Optional[float]:
        """Get the current price for a symbol"""
        try:
            if hasattr(self.orchestrator, 'data_provider'):
                return self.orchestrator.data_provider.get_current_price(symbol)
            return None
        except Exception as e:
            logger.error(f"Error getting current price for {symbol}: {e}")
            return None

    def should_execute_trade(self, prediction: MultiTimeframePrediction) -> Tuple[bool, str]:
        """
        Determine whether a trade should be executed based on multi-timeframe analysis
        """
        try:
            if not prediction or not prediction.predictions:
                return False, "No predictions available"

            # Find the best prediction across all horizons
            best_prediction = None
            best_confidence = 0

            for horizon, pred in prediction.predictions.items():
                if pred['confidence'] > best_confidence:
                    best_confidence = pred['confidence']
                    best_prediction = (horizon, pred)

            if not best_prediction:
                return False, "No valid predictions"

            horizon, pred = best_prediction
            config = self.horizons[horizon]

            # Check that confidence meets the horizon's threshold
            if pred['confidence'] < config['confidence_threshold']:
                return False, (f"Confidence {pred['confidence']:.2f} below "
                               f"{config['confidence_threshold']:.2f} threshold for "
                               f"{horizon.value}-minute horizon")

            # Check market conditions
            market_risk = prediction.market_conditions.get('risk_level', 'normal')
            if market_risk == 'high' and horizon.value >= 5:
                return False, "High market risk - avoiding longer-term predictions"

            return True, f"Valid {horizon.value}-minute prediction with {pred['confidence']:.2f} confidence"
        except Exception as e:
            logger.error(f"Error in trade execution decision: {e}")
            return False, f"Decision error: {e}"

    def get_position_hold_time(self, prediction: MultiTimeframePrediction) -> int:
        """Determine how long to hold a position (in seconds) based on prediction horizon"""
        try:
            if not prediction or not prediction.predictions:
                return 60  # Default: 1 minute

            # Use the longest horizon whose prediction is available and confident
            max_horizon = 1
            for horizon, pred in prediction.predictions.items():
                config = self.horizons[horizon]
                if pred['confidence'] >= config['confidence_threshold']:
                    max_horizon = max(max_horizon, horizon.value)

            return max_horizon * 60  # Convert minutes to seconds
        except Exception as e:
            logger.error(f"Error determining hold time: {e}")
            return 60
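

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption: not part of the production flow). The real
# orchestrator, its data_provider, cnn_model, and cob_rl_agent are supplied
# elsewhere in the application; the bare stub below only demonstrates the call
# sequence and yields no predictions on its own. The symbol 'ETH/USDT' is a
# hypothetical example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    class _StubOrchestrator:
        """Placeholder orchestrator with no models or data provider attached."""
        pass

    predictor = MultiTimeframePredictor(_StubOrchestrator())
    multi_pred = predictor.generate_multi_timeframe_prediction("ETH/USDT")

    if multi_pred is None:
        logger.info("No multi-timeframe prediction available (no models or data wired up)")
    else:
        should_trade, reason = predictor.should_execute_trade(multi_pred)
        hold_seconds = predictor.get_position_hold_time(multi_pred)
        logger.info("Execute trade: %s (%s); hold for %d seconds", should_trade, reason, hold_seconds)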