increase prediction horizon
This commit is contained in:
@@ -140,14 +140,15 @@ class EnhancedCNNModel(nn.Module):
|
||||
- Large capacity for complex pattern learning
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(self,
|
||||
input_size: int = 60,
|
||||
feature_dim: int = 50,
|
||||
output_size: int = 2, # BUY/SELL for 2-action system
|
||||
base_channels: int = 256, # Increased from 128 to 256
|
||||
num_blocks: int = 12, # Increased from 6 to 12
|
||||
num_attention_heads: int = 16, # Increased from 8 to 16
|
||||
dropout_rate: float = 0.2):
|
||||
dropout_rate: float = 0.2,
|
||||
prediction_horizon: int = 1): # New: Prediction horizon in minutes
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
|
446
NN/models/multi_timeframe_predictor.py
Normal file
446
NN/models/multi_timeframe_predictor.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
Multi-Timeframe Prediction System for Enhanced Trading
|
||||
|
||||
This module implements a sophisticated multi-timeframe prediction system that allows
|
||||
models to make predictions for different time horizons (1, 5, 10 minutes) with
|
||||
appropriate confidence thresholds and position holding strategies.
|
||||
|
||||
Key Features:
|
||||
- Dynamic sequence length adaptation for different timeframes
|
||||
- Confidence calibration based on prediction horizon
|
||||
- Position holding logic for longer-term trades
|
||||
- Risk-adjusted trading strategies
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PredictionHorizon(Enum):
|
||||
"""Prediction time horizons"""
|
||||
ONE_MINUTE = 1
|
||||
FIVE_MINUTES = 5
|
||||
TEN_MINUTES = 10
|
||||
|
||||
class ConfidenceThreshold(Enum):
|
||||
"""Confidence thresholds for different horizons"""
|
||||
ONE_MINUTE = 0.35 # Lower threshold for quick trades
|
||||
FIVE_MINUTES = 0.65 # Higher threshold for 5-minute holds
|
||||
TEN_MINUTES = 0.80 # Very high threshold for 10-minute holds
|
||||
|
||||
@dataclass
|
||||
class MultiTimeframePrediction:
|
||||
"""Container for multi-timeframe predictions"""
|
||||
symbol: str
|
||||
current_price: float
|
||||
predictions: Dict[PredictionHorizon, Dict[str, Any]]
|
||||
timestamp: datetime
|
||||
market_conditions: Dict[str, Any]
|
||||
|
||||
class MultiTimeframePredictor:
|
||||
"""
|
||||
Advanced multi-timeframe prediction system that adapts model behavior
|
||||
based on desired prediction horizon and market conditions.
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator):
|
||||
self.orchestrator = orchestrator
|
||||
self.horizons = {
|
||||
PredictionHorizon.ONE_MINUTE: {
|
||||
'sequence_length': 60, # 60 minutes for 1-minute predictions
|
||||
'confidence_threshold': ConfidenceThreshold.ONE_MINUTE.value,
|
||||
'max_hold_time': 60, # 1 minute max hold
|
||||
'risk_multiplier': 1.0
|
||||
},
|
||||
PredictionHorizon.FIVE_MINUTES: {
|
||||
'sequence_length': 300, # 300 minutes for 5-minute predictions
|
||||
'confidence_threshold': ConfidenceThreshold.FIVE_MINUTES.value,
|
||||
'max_hold_time': 300, # 5 minutes max hold
|
||||
'risk_multiplier': 1.5 # Higher risk for longer holds
|
||||
},
|
||||
PredictionHorizon.TEN_MINUTES: {
|
||||
'sequence_length': 600, # 600 minutes for 10-minute predictions
|
||||
'confidence_threshold': ConfidenceThreshold.TEN_MINUTES.value,
|
||||
'max_hold_time': 600, # 10 minutes max hold
|
||||
'risk_multiplier': 2.0 # Highest risk for longest holds
|
||||
}
|
||||
}
|
||||
|
||||
# Initialize models for different horizons
|
||||
self.models = {}
|
||||
self._initialize_multi_horizon_models()
|
||||
|
||||
def _initialize_multi_horizon_models(self):
|
||||
"""Initialize separate model instances for different horizons"""
|
||||
try:
|
||||
for horizon, config in self.horizons.items():
|
||||
# CNN Model for this horizon
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
# Create horizon-specific model configuration
|
||||
horizon_model = self._create_horizon_specific_model(
|
||||
self.orchestrator.cnn_model,
|
||||
config['sequence_length'],
|
||||
horizon
|
||||
)
|
||||
self.models[f'cnn_{horizon.value}min'] = horizon_model
|
||||
|
||||
# COB RL Model for this horizon
|
||||
if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
|
||||
self.models[f'cob_rl_{horizon.value}min'] = self.orchestrator.cob_rl_agent
|
||||
|
||||
logger.info(f"Initialized {horizon.value}-minute prediction model")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing multi-horizon models: {e}")
|
||||
|
||||
def _create_horizon_specific_model(self, base_model, sequence_length: int, horizon: PredictionHorizon):
|
||||
"""Create a model instance optimized for specific prediction horizon"""
|
||||
try:
|
||||
# For CNN models, we need to adjust input size and potentially architecture
|
||||
if hasattr(base_model, '__class__'):
|
||||
model_class = base_model.__class__
|
||||
|
||||
# Calculate appropriate input size for horizon
|
||||
# More data for longer predictions
|
||||
adjusted_input_size = min(sequence_length, 300) # Cap at 300 to avoid memory issues
|
||||
|
||||
# Create new model instance with horizon-specific parameters
|
||||
horizon_model = model_class(
|
||||
input_size=adjusted_input_size,
|
||||
feature_dim=getattr(base_model, 'feature_dim', 50),
|
||||
output_size=getattr(base_model, 'output_size', 2),
|
||||
base_channels=getattr(base_model, 'base_channels', 256),
|
||||
num_blocks=getattr(base_model, 'num_blocks', 12),
|
||||
num_attention_heads=getattr(base_model, 'num_attention_heads', 16),
|
||||
dropout_rate=getattr(base_model, 'dropout_rate', 0.2),
|
||||
prediction_horizon=horizon.value
|
||||
)
|
||||
|
||||
# Try to load pre-trained weights if available
|
||||
try:
|
||||
if hasattr(base_model, 'state_dict'):
|
||||
# Load base model weights and adapt if necessary
|
||||
base_state = base_model.state_dict()
|
||||
horizon_model.load_state_dict(base_state, strict=False)
|
||||
logger.info(f"Loaded base model weights for {horizon.value}-minute horizon")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not load base weights for {horizon.value}-minute model: {e}")
|
||||
|
||||
return horizon_model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating horizon-specific model: {e}")
|
||||
return base_model # Fallback to base model
|
||||
|
||||
def generate_multi_timeframe_prediction(self, symbol: str) -> Optional[MultiTimeframePrediction]:
|
||||
"""
|
||||
Generate predictions for all timeframes with appropriate confidence thresholds
|
||||
"""
|
||||
try:
|
||||
# Get current market data
|
||||
current_price = self._get_current_price(symbol)
|
||||
if not current_price:
|
||||
return None
|
||||
|
||||
# Get market conditions for confidence adjustment
|
||||
market_conditions = self._assess_market_conditions(symbol)
|
||||
|
||||
predictions = {}
|
||||
|
||||
# Generate predictions for each horizon
|
||||
for horizon, config in self.horizons.items():
|
||||
prediction = self._generate_single_horizon_prediction(
|
||||
symbol, current_price, horizon, config, market_conditions
|
||||
)
|
||||
if prediction:
|
||||
predictions[horizon] = prediction
|
||||
|
||||
if not predictions:
|
||||
return None
|
||||
|
||||
return MultiTimeframePrediction(
|
||||
symbol=symbol,
|
||||
current_price=current_price,
|
||||
predictions=predictions,
|
||||
timestamp=datetime.now(),
|
||||
market_conditions=market_conditions
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating multi-timeframe prediction: {e}")
|
||||
return None
|
||||
|
||||
def _generate_single_horizon_prediction(self, symbol: str, current_price: float,
|
||||
horizon: PredictionHorizon, config: Dict,
|
||||
market_conditions: Dict) -> Optional[Dict[str, Any]]:
|
||||
"""Generate prediction for single timeframe"""
|
||||
try:
|
||||
# Get appropriate data for this horizon
|
||||
sequence_data = self._get_sequence_data_for_horizon(symbol, config['sequence_length'])
|
||||
|
||||
if not sequence_data:
|
||||
return None
|
||||
|
||||
# Generate predictions from available models
|
||||
model_predictions = []
|
||||
|
||||
# CNN prediction
|
||||
cnn_key = f'cnn_{horizon.value}min'
|
||||
if cnn_key in self.models:
|
||||
cnn_pred = self._get_cnn_prediction(
|
||||
self.models[cnn_key], sequence_data, config
|
||||
)
|
||||
if cnn_pred:
|
||||
model_predictions.append(cnn_pred)
|
||||
|
||||
# COB RL prediction
|
||||
cob_key = f'cob_rl_{horizon.value}min'
|
||||
if cob_key in self.models:
|
||||
cob_pred = self._get_cob_rl_prediction(
|
||||
self.models[cob_key], sequence_data, config
|
||||
)
|
||||
if cob_pred:
|
||||
model_predictions.append(cob_pred)
|
||||
|
||||
if not model_predictions:
|
||||
return None
|
||||
|
||||
# Ensemble predictions
|
||||
ensemble_prediction = self._ensemble_predictions(
|
||||
model_predictions, config, market_conditions
|
||||
)
|
||||
|
||||
# Apply confidence threshold
|
||||
if ensemble_prediction['confidence'] < config['confidence_threshold']:
|
||||
return None # Not confident enough for this horizon
|
||||
|
||||
return ensemble_prediction
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating {horizon.value}-minute prediction: {e}")
|
||||
return None
|
||||
|
||||
def _get_sequence_data_for_horizon(self, symbol: str, sequence_length: int) -> Optional[torch.Tensor]:
|
||||
"""Get appropriate sequence data for prediction horizon"""
|
||||
try:
|
||||
# This would need to be implemented based on your data provider
|
||||
# For now, return a placeholder
|
||||
if hasattr(self.orchestrator, 'data_provider'):
|
||||
# Get historical data for the required sequence length
|
||||
data = self.orchestrator.data_provider.get_historical_data(
|
||||
symbol, '1m', limit=sequence_length
|
||||
)
|
||||
|
||||
if data is not None and len(data) >= sequence_length // 10: # At least 10% of required data
|
||||
# Convert to tensor format expected by models
|
||||
return self._convert_data_to_tensor(data)
|
||||
else:
|
||||
logger.warning(f"Insufficient data for {sequence_length}-point prediction")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting sequence data: {e}")
|
||||
return None
|
||||
|
||||
def _convert_data_to_tensor(self, data) -> torch.Tensor:
|
||||
"""Convert market data to tensor format"""
|
||||
try:
|
||||
# This is a placeholder - implement based on your data format
|
||||
if hasattr(data, 'values'):
|
||||
# Assume pandas DataFrame
|
||||
features = ['open', 'high', 'low', 'close', 'volume']
|
||||
feature_data = []
|
||||
|
||||
for feature in features:
|
||||
if feature in data.columns:
|
||||
values = data[feature].fillna(method='ffill').fillna(0).values
|
||||
feature_data.append(values)
|
||||
|
||||
if feature_data:
|
||||
# Stack features
|
||||
tensor_data = torch.tensor(feature_data, dtype=torch.float32).transpose(0, 1)
|
||||
return tensor_data.unsqueeze(0) # Add batch dimension
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting data to tensor: {e}")
|
||||
return None
|
||||
|
||||
def _get_cnn_prediction(self, model, sequence_data: torch.Tensor, config: Dict) -> Optional[Dict]:
|
||||
"""Get CNN model prediction"""
|
||||
try:
|
||||
with torch.no_grad():
|
||||
outputs = model(sequence_data)
|
||||
if isinstance(outputs, tuple):
|
||||
predictions, confidence = outputs
|
||||
else:
|
||||
predictions = outputs
|
||||
confidence = torch.softmax(predictions, dim=-1).max().item()
|
||||
|
||||
action_idx = predictions.argmax().item()
|
||||
actions = ['SELL', 'BUY'] # Adjust based on your model's output format
|
||||
|
||||
return {
|
||||
'action': actions[action_idx] if action_idx < len(actions) else 'HOLD',
|
||||
'confidence': confidence,
|
||||
'model': 'cnn',
|
||||
'horizon': config.get('max_hold_time', 60)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting CNN prediction: {e}")
|
||||
return None
|
||||
|
||||
def _get_cob_rl_prediction(self, model, sequence_data: torch.Tensor, config: Dict) -> Optional[Dict]:
|
||||
"""Get COB RL model prediction"""
|
||||
try:
|
||||
# This would need to be implemented based on your COB RL model interface
|
||||
if hasattr(model, 'predict'):
|
||||
result = model.predict(sequence_data)
|
||||
return {
|
||||
'action': result.get('action', 'HOLD'),
|
||||
'confidence': result.get('confidence', 0.5),
|
||||
'model': 'cob_rl',
|
||||
'horizon': config.get('max_hold_time', 60)
|
||||
}
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting COB RL prediction: {e}")
|
||||
return None
|
||||
|
||||
def _ensemble_predictions(self, predictions: List[Dict], config: Dict,
|
||||
market_conditions: Dict) -> Dict[str, Any]:
|
||||
"""Ensemble multiple model predictions"""
|
||||
try:
|
||||
if not predictions:
|
||||
return None
|
||||
|
||||
# Simple voting ensemble
|
||||
action_votes = {}
|
||||
confidence_sum = 0
|
||||
|
||||
for pred in predictions:
|
||||
action = pred['action']
|
||||
confidence = pred['confidence']
|
||||
|
||||
if action not in action_votes:
|
||||
action_votes[action] = 0
|
||||
action_votes[action] += confidence
|
||||
confidence_sum += confidence
|
||||
|
||||
# Get winning action
|
||||
best_action = max(action_votes, key=action_votes.get)
|
||||
ensemble_confidence = action_votes[best_action] / len(predictions)
|
||||
|
||||
# Adjust confidence based on market conditions
|
||||
market_confidence_multiplier = market_conditions.get('confidence_multiplier', 1.0)
|
||||
final_confidence = min(ensemble_confidence * market_confidence_multiplier, 1.0)
|
||||
|
||||
return {
|
||||
'action': best_action,
|
||||
'confidence': final_confidence,
|
||||
'horizon_minutes': config['max_hold_time'] // 60,
|
||||
'risk_multiplier': config['risk_multiplier'],
|
||||
'models_used': len(predictions),
|
||||
'market_conditions': market_conditions
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in prediction ensemble: {e}")
|
||||
return None
|
||||
|
||||
def _assess_market_conditions(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Assess current market conditions for confidence adjustment"""
|
||||
try:
|
||||
conditions = {
|
||||
'volatility': 'medium',
|
||||
'trend': 'sideways',
|
||||
'confidence_multiplier': 1.0,
|
||||
'risk_level': 'normal'
|
||||
}
|
||||
|
||||
# This could be enhanced with actual market analysis
|
||||
# For now, return default conditions
|
||||
return conditions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error assessing market conditions: {e}")
|
||||
return {'confidence_multiplier': 1.0}
|
||||
|
||||
def _get_current_price(self, symbol: str) -> Optional[float]:
|
||||
"""Get current price for symbol"""
|
||||
try:
|
||||
if hasattr(self.orchestrator, 'data_provider'):
|
||||
ticker = self.orchestrator.data_provider.get_current_price(symbol)
|
||||
return ticker
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current price for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def should_execute_trade(self, prediction: MultiTimeframePrediction) -> Tuple[bool, str]:
|
||||
"""
|
||||
Determine if a trade should be executed based on multi-timeframe analysis
|
||||
"""
|
||||
try:
|
||||
if not prediction or not prediction.predictions:
|
||||
return False, "No predictions available"
|
||||
|
||||
# Find the best prediction across all horizons
|
||||
best_prediction = None
|
||||
best_confidence = 0
|
||||
|
||||
for horizon, pred in prediction.predictions.items():
|
||||
if pred['confidence'] > best_confidence:
|
||||
best_confidence = pred['confidence']
|
||||
best_prediction = (horizon, pred)
|
||||
|
||||
if not best_prediction:
|
||||
return False, "No valid predictions"
|
||||
|
||||
horizon, pred = best_prediction
|
||||
config = self.horizons[horizon]
|
||||
|
||||
# Check if confidence meets threshold
|
||||
if pred['confidence'] < config['confidence_threshold']:
|
||||
return False, ".2f"
|
||||
|
||||
# Check market conditions
|
||||
market_risk = prediction.market_conditions.get('risk_level', 'normal')
|
||||
if market_risk == 'high' and horizon.value >= 5:
|
||||
return False, "High market risk - avoiding longer-term predictions"
|
||||
|
||||
return True, f"Valid {horizon.value}-minute prediction with {pred['confidence']:.2f} confidence"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in trade execution decision: {e}")
|
||||
return False, f"Decision error: {e}"
|
||||
|
||||
def get_position_hold_time(self, prediction: MultiTimeframePrediction) -> int:
|
||||
"""Determine how long to hold a position based on prediction horizon"""
|
||||
try:
|
||||
if not prediction or not prediction.predictions:
|
||||
return 60 # Default 1 minute
|
||||
|
||||
# Use the longest horizon prediction that's available and confident
|
||||
max_horizon = 1
|
||||
for horizon, pred in prediction.predictions.items():
|
||||
config = self.horizons[horizon]
|
||||
if pred['confidence'] >= config['confidence_threshold']:
|
||||
max_horizon = max(max_horizon, horizon.value)
|
||||
|
||||
return max_horizon * 60 # Convert minutes to seconds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error determining hold time: {e}")
|
||||
return 60
|
Reference in New Issue
Block a user