447 lines
18 KiB
Python
447 lines
18 KiB
Python
"""
|
|
Multi-Timeframe Prediction System for Enhanced Trading

This module implements a multi-timeframe prediction system that lets models
make predictions for different time horizons (1, 5 and 10 minutes), each with
its own confidence threshold and position-holding strategy.

Key Features:
- Dynamic sequence length adaptation for different timeframes
- Confidence calibration based on prediction horizon
- Position holding logic for longer-term trades
- Risk-adjusted trading strategies
|
|
"""
|
|
|
|
import logging
|
|
import torch
|
|
import torch.nn as nn
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
from datetime import datetime, timedelta
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
|
|
# Module-scoped logger used by the predictor class in this file.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
class PredictionHorizon(Enum):
    """Supported prediction horizons; each member's value is the horizon length in minutes."""

    ONE_MINUTE = 1      # shortest horizon
    FIVE_MINUTES = 5    # medium horizon
    TEN_MINUTES = 10    # longest horizon
|
|
|
|
class ConfidenceThreshold(Enum):
    """Minimum ensemble confidence required per horizon.

    Member names mirror ``PredictionHorizon``; longer horizons demand more
    confidence because positions are held longer.
    """

    ONE_MINUTE = 0.35    # permissive: quick trades
    FIVE_MINUTES = 0.65  # stricter: 5-minute holds
    TEN_MINUTES = 0.80   # strictest: 10-minute holds
|
|
|
|
@dataclass
class MultiTimeframePrediction:
    """Snapshot of the per-horizon predictions produced for one symbol.

    Built by ``MultiTimeframePredictor.generate_multi_timeframe_prediction``.
    Only horizons whose ensemble prediction passed that horizon's confidence
    threshold appear in ``predictions``.
    """

    # Instrument identifier supplied by the caller.
    symbol: str
    # Price obtained from the data provider when the snapshot was built.
    current_price: float
    # Ensemble result dict for each horizon that passed its threshold.
    predictions: Dict[PredictionHorizon, Dict[str, Any]]
    # Build time (the predictor fills this with a naive datetime.now()).
    timestamp: datetime
    # Market-condition dict used to scale the ensemble confidences.
    market_conditions: Dict[str, Any]
|
|
|
|
class MultiTimeframePredictor:
    """
    Multi-timeframe prediction system that adapts model behavior to the
    desired prediction horizon and to market conditions.

    For each horizon in ``self.horizons`` it fetches a window of 1-minute
    history, queries the registered CNN / COB RL models, ensembles their votes
    and keeps the result only if its confidence reaches that horizon's
    threshold.
    """

    def __init__(self, orchestrator):
        """Create the predictor and register per-horizon models.

        Args:
            orchestrator: Object that may expose ``cnn_model``,
                ``cob_rl_agent`` and ``data_provider``; each is optional and
                checked before use.
        """
        self.orchestrator = orchestrator
        # Per-horizon settings. 'sequence_length' is the number of 1-minute
        # candles requested from the data provider; 'max_hold_time' is in
        # seconds.
        self.horizons = {
            PredictionHorizon.ONE_MINUTE: {
                'sequence_length': 60,
                'confidence_threshold': ConfidenceThreshold.ONE_MINUTE.value,
                'max_hold_time': 60,  # 1 minute max hold
                'risk_multiplier': 1.0,
            },
            PredictionHorizon.FIVE_MINUTES: {
                'sequence_length': 300,
                'confidence_threshold': ConfidenceThreshold.FIVE_MINUTES.value,
                'max_hold_time': 300,  # 5 minutes max hold
                'risk_multiplier': 1.5,  # Higher risk for longer holds
            },
            PredictionHorizon.TEN_MINUTES: {
                'sequence_length': 600,
                'confidence_threshold': ConfidenceThreshold.TEN_MINUTES.value,
                'max_hold_time': 600,  # 10 minutes max hold
                'risk_multiplier': 2.0,  # Highest risk for longest holds
            },
        }

        # Model registry keyed 'cnn_<N>min' / 'cob_rl_<N>min'.
        self.models = {}
        self._initialize_multi_horizon_models()

    def _initialize_multi_horizon_models(self):
        """Populate ``self.models`` with per-horizon CNN and COB RL entries.

        A horizon-specific CNN is derived from ``orchestrator.cnn_model`` when
        it is present and truthy. The orchestrator's ``cob_rl_agent`` is shared
        by every horizon. A failure for one horizon is logged and does not
        prevent the remaining horizons from being initialized.
        """
        cnn_model = getattr(self.orchestrator, 'cnn_model', None)
        cob_rl_agent = getattr(self.orchestrator, 'cob_rl_agent', None)

        for horizon, config in self.horizons.items():
            try:
                if cnn_model:
                    self.models[f'cnn_{horizon.value}min'] = self._create_horizon_specific_model(
                        cnn_model, config['sequence_length'], horizon
                    )
                if cob_rl_agent:
                    self.models[f'cob_rl_{horizon.value}min'] = cob_rl_agent
                logger.info(f"Initialized {horizon.value}-minute prediction model")
            except Exception as e:
                logger.error(f"Error initializing {horizon.value}-minute models: {e}")

    def _create_horizon_specific_model(self, base_model, sequence_length: int, horizon: 'PredictionHorizon'):
        """Build a new instance of ``base_model``'s class configured for ``horizon``.

        Hyper-parameters are read from ``base_model`` via ``getattr`` (with
        fallbacks), and the input size is capped at 300. Base weights whose
        names and shapes match the new model are copied over. Tensors whose
        shape differs, for example because of the changed input size, are
        skipped instead of aborting the whole load.

        Args:
            base_model: Model whose class and hyper-parameters are reused.
            sequence_length: Requested history length for the horizon.
            horizon: Horizon the model is built for.

        Returns:
            The new model switched to eval mode, or ``base_model`` itself if
            the new instance cannot be constructed.
        """
        try:
            model_class = type(base_model)
            # More history for longer horizons, capped to bound memory use.
            adjusted_input_size = min(sequence_length, 300)
            horizon_model = model_class(
                input_size=adjusted_input_size,
                feature_dim=getattr(base_model, 'feature_dim', 50),
                output_size=getattr(base_model, 'output_size', 2),
                base_channels=getattr(base_model, 'base_channels', 256),
                num_blocks=getattr(base_model, 'num_blocks', 12),
                num_attention_heads=getattr(base_model, 'num_attention_heads', 16),
                dropout_rate=getattr(base_model, 'dropout_rate', 0.2),
                prediction_horizon=horizon.value,
            )
        except Exception as e:
            logger.error(f"Error creating horizon-specific model: {e}")
            return base_model  # Fallback to base model

        try:
            if hasattr(base_model, 'state_dict'):
                base_state = base_model.state_dict()
                target_state = horizon_model.state_dict()
                # strict=False only tolerates missing/unexpected keys; shape
                # mismatches still raise, so keep only shape-compatible tensors.
                compatible = {
                    name: tensor
                    for name, tensor in base_state.items()
                    if name in target_state
                    and getattr(target_state[name], 'shape', None) == getattr(tensor, 'shape', None)
                }
                horizon_model.load_state_dict(compatible, strict=False)
                logger.info(
                    f"Loaded {len(compatible)}/{len(base_state)} base model tensors "
                    f"for {horizon.value}-minute horizon"
                )
        except Exception as e:
            logger.warning(f"Could not load base weights for {horizon.value}-minute model: {e}")

        # A freshly constructed module starts in training mode. Switch to eval
        # so dropout and similar layers are disabled during inference.
        if hasattr(horizon_model, 'eval'):
            horizon_model.eval()
        return horizon_model

    def generate_multi_timeframe_prediction(self, symbol: str) -> Optional['MultiTimeframePrediction']:
        """
        Generate predictions for all timeframes with appropriate confidence thresholds.

        Returns:
            A ``MultiTimeframePrediction`` containing every horizon whose
            prediction passed its threshold. Returns None when no price is
            available, no horizon qualified, or an error occurred.
        """
        try:
            current_price = self._get_current_price(symbol)
            if not current_price:
                return None

            # Market conditions scale the ensemble confidence.
            market_conditions = self._assess_market_conditions(symbol)

            predictions = {}
            for horizon, config in self.horizons.items():
                prediction = self._generate_single_horizon_prediction(
                    symbol, current_price, horizon, config, market_conditions
                )
                if prediction:
                    predictions[horizon] = prediction

            if not predictions:
                return None

            return MultiTimeframePrediction(
                symbol=symbol,
                current_price=current_price,
                predictions=predictions,
                timestamp=datetime.now(),
                market_conditions=market_conditions,
            )

        except Exception as e:
            logger.error(f"Error generating multi-timeframe prediction: {e}")
            return None

    def _generate_single_horizon_prediction(self, symbol: str, current_price: float,
                                            horizon: 'PredictionHorizon', config: Dict,
                                            market_conditions: Dict) -> Optional[Dict[str, Any]]:
        """Generate the ensemble prediction for a single horizon.

        Returns:
            The ensemble dict from ``_ensemble_predictions``, or None if data
            or models are unavailable or the confidence is below
            ``config['confidence_threshold']``.
        """
        try:
            sequence_data = self._get_sequence_data_for_horizon(symbol, config['sequence_length'])
            # Compare with None explicitly: truth-testing a multi-element
            # tensor raises "Boolean value of Tensor ... is ambiguous".
            if sequence_data is None:
                return None

            model_predictions = []

            cnn_key = f'cnn_{horizon.value}min'
            if cnn_key in self.models:
                cnn_pred = self._get_cnn_prediction(self.models[cnn_key], sequence_data, config)
                if cnn_pred:
                    model_predictions.append(cnn_pred)

            cob_key = f'cob_rl_{horizon.value}min'
            if cob_key in self.models:
                cob_pred = self._get_cob_rl_prediction(self.models[cob_key], sequence_data, config)
                if cob_pred:
                    model_predictions.append(cob_pred)

            if not model_predictions:
                return None

            ensemble_prediction = self._ensemble_predictions(
                model_predictions, config, market_conditions
            )
            if ensemble_prediction is None:
                return None

            # Not confident enough for this horizon.
            if ensemble_prediction['confidence'] < config['confidence_threshold']:
                return None

            return ensemble_prediction

        except Exception as e:
            logger.error(f"Error generating {horizon.value}-minute prediction: {e}")
            return None

    def _get_sequence_data_for_horizon(self, symbol: str, sequence_length: int) -> Optional[torch.Tensor]:
        """Fetch up to ``sequence_length`` 1-minute candles and convert them to a tensor.

        Partial history is accepted as long as at least 10% of the requested
        candles are available.

        Returns:
            The tensor from ``_convert_data_to_tensor``, or None if there is
            no data provider, too little data, or an error occurred.
        """
        try:
            data_provider = getattr(self.orchestrator, 'data_provider', None)
            if data_provider is None:
                return None

            data = data_provider.get_historical_data(symbol, '1m', limit=sequence_length)
            if data is None or len(data) < sequence_length // 10:
                logger.warning(f"Insufficient data for {sequence_length}-point prediction")
                return None

            return self._convert_data_to_tensor(data)

        except Exception as e:
            logger.error(f"Error getting sequence data: {e}")
            return None

    def _convert_data_to_tensor(self, data) -> Optional[torch.Tensor]:
        """Convert a DataFrame-like OHLCV table to a float32 tensor.

        The open/high/low/close/volume columns that are present are used, in
        that order. Gaps are forward-filled, and any remaining leading NaNs
        become 0.

        Returns:
            A tensor of shape ``(1, rows, n_present_columns)``, or None if
            ``data`` has no ``columns`` attribute or none of those columns.
        """
        try:
            columns = getattr(data, 'columns', None)
            if columns is None:
                return None

            present = [name for name in ('open', 'high', 'low', 'close', 'volume') if name in columns]
            if not present:
                return None

            # DataFrame.ffill() replaces the deprecated fillna(method='ffill').
            # Converting the whole frame at once avoids building a tensor from
            # a list of ndarrays.
            values = data[present].ffill().fillna(0).to_numpy(dtype='float32')
            return torch.tensor(values, dtype=torch.float32).unsqueeze(0)  # Add batch dimension

        except Exception as e:
            logger.error(f"Error converting data to tensor: {e}")
            return None

    def _get_cnn_prediction(self, model, sequence_data: torch.Tensor, config: Dict) -> Optional[Dict]:
        """Run a CNN-style model and map its output to an action.

        If the model returns a ``(predictions, confidence)`` tuple, that
        confidence is used. Otherwise confidence is the max softmax
        probability of the output. Index 0 maps to 'SELL' and index 1 to
        'BUY'; any other index maps to 'HOLD'.
        """
        try:
            with torch.no_grad():
                outputs = model(sequence_data)
                if isinstance(outputs, tuple):
                    predictions, confidence = outputs
                    if isinstance(confidence, torch.Tensor):
                        # Reduce to a scalar so downstream formatting and
                        # comparisons operate on a plain number.
                        confidence = confidence.max().item()
                else:
                    predictions = outputs
                    confidence = torch.softmax(predictions, dim=-1).max().item()
                action_idx = int(predictions.argmax().item())

            actions = ['SELL', 'BUY']  # Adjust based on your model's output format
            return {
                'action': actions[action_idx] if action_idx < len(actions) else 'HOLD',
                'confidence': float(confidence),
                'model': 'cnn',
                'horizon': config.get('max_hold_time', 60),
            }

        except Exception as e:
            logger.error(f"Error getting CNN prediction: {e}")
            return None

    def _get_cob_rl_prediction(self, model, sequence_data: torch.Tensor, config: Dict) -> Optional[Dict]:
        """Query a COB RL agent that exposes ``predict``.

        Assumes ``model.predict`` returns a dict with optional 'action' and
        'confidence' keys (TODO confirm against the agent's interface).
        Missing keys default to 'HOLD' / 0.5.
        """
        try:
            if not hasattr(model, 'predict'):
                return None

            result = model.predict(sequence_data)
            return {
                'action': result.get('action', 'HOLD'),
                'confidence': float(result.get('confidence', 0.5)),
                'model': 'cob_rl',
                'horizon': config.get('max_hold_time', 60),
            }

        except Exception as e:
            logger.error(f"Error getting COB RL prediction: {e}")
            return None

    def _ensemble_predictions(self, predictions: List[Dict], config: Dict,
                              market_conditions: Dict) -> Optional[Dict[str, Any]]:
        """Combine model predictions by confidence-weighted voting.

        The winning action's summed confidence is divided by the total number
        of models, so disagreement lowers the ensemble confidence. The result
        is then scaled by ``market_conditions['confidence_multiplier']`` and
        capped at 1.0.

        Returns:
            The ensemble dict, or None for an empty input or on error.
        """
        if not predictions:
            return None

        try:
            action_votes: Dict[str, float] = {}
            for pred in predictions:
                action_votes[pred['action']] = action_votes.get(pred['action'], 0.0) + pred['confidence']

            best_action = max(action_votes, key=action_votes.get)
            ensemble_confidence = action_votes[best_action] / len(predictions)

            market_confidence_multiplier = market_conditions.get('confidence_multiplier', 1.0)
            final_confidence = min(ensemble_confidence * market_confidence_multiplier, 1.0)

            return {
                'action': best_action,
                'confidence': final_confidence,
                'horizon_minutes': config['max_hold_time'] // 60,
                'risk_multiplier': config['risk_multiplier'],
                'models_used': len(predictions),
                'market_conditions': market_conditions,
            }

        except Exception as e:
            logger.error(f"Error in prediction ensemble: {e}")
            return None

    def _assess_market_conditions(self, symbol: str) -> Dict[str, Any]:
        """Return market conditions used for confidence adjustment.

        Currently always returns fixed defaults (multiplier 1.0). This is a
        placeholder for real market analysis.
        """
        try:
            conditions = {
                'volatility': 'medium',
                'trend': 'sideways',
                'confidence_multiplier': 1.0,
                'risk_level': 'normal',
            }
            return conditions

        except Exception as e:
            logger.error(f"Error assessing market conditions: {e}")
            return {'confidence_multiplier': 1.0}

    def _get_current_price(self, symbol: str) -> Optional[float]:
        """Return the data provider's current price for ``symbol``, or None if unavailable."""
        try:
            data_provider = getattr(self.orchestrator, 'data_provider', None)
            if data_provider is None:
                return None
            return data_provider.get_current_price(symbol)
        except Exception as e:
            logger.error(f"Error getting current price for {symbol}: {e}")
            return None

    def should_execute_trade(self, prediction: 'MultiTimeframePrediction') -> Tuple[bool, str]:
        """
        Decide whether to trade, based on the most confident horizon.

        Returns:
            ``(decision, reason)``. The trade is rejected when no prediction
            exists, when the best confidence is below its horizon threshold,
            or when market risk is 'high' and the best horizon is 5 minutes
            or longer.
        """
        try:
            if not prediction or not prediction.predictions:
                return False, "No predictions available"

            best_prediction = None
            best_confidence = 0
            for horizon, pred in prediction.predictions.items():
                if pred['confidence'] > best_confidence:
                    best_confidence = pred['confidence']
                    best_prediction = (horizon, pred)

            if not best_prediction:
                return False, "No valid predictions"

            horizon, pred = best_prediction
            config = self.horizons[horizon]

            if pred['confidence'] < config['confidence_threshold']:
                return False, (
                    f"Confidence {pred['confidence']:.2f} below {horizon.value}-minute "
                    f"threshold {config['confidence_threshold']:.2f}"
                )

            market_risk = prediction.market_conditions.get('risk_level', 'normal')
            if market_risk == 'high' and horizon.value >= 5:
                return False, "High market risk - avoiding longer-term predictions"

            return True, f"Valid {horizon.value}-minute prediction with {pred['confidence']:.2f} confidence"

        except Exception as e:
            logger.error(f"Error in trade execution decision: {e}")
            return False, f"Decision error: {e}"

    def get_position_hold_time(self, prediction: 'MultiTimeframePrediction') -> int:
        """Return the hold time in seconds.

        The hold time is the longest horizon whose prediction meets its
        threshold, defaulting to 60 seconds.
        """
        try:
            if not prediction or not prediction.predictions:
                return 60  # Default 1 minute

            max_horizon = 1
            for horizon, pred in prediction.predictions.items():
                config = self.horizons[horizon]
                if pred['confidence'] >= config['confidence_threshold']:
                    max_horizon = max(max_horizon, horizon.value)

            return max_horizon * 60  # Convert minutes to seconds

        except Exception as e:
            logger.error(f"Error determining hold time: {e}")
            return 60
|