"""
|
|
Multi-Timeframe Prediction System for Enhanced Trading
|
|
|
|
This module implements a sophisticated multi-timeframe prediction system that allows
|
|
models to make predictions for different time horizons (1, 5, 10 minutes) with
|
|
appropriate confidence thresholds and position holding strategies.
|
|
|
|
Key Features:
|
|
- Dynamic sequence length adaptation for different timeframes
|
|
- Confidence calibration based on prediction horizon
|
|
- Position holding logic for longer-term trades
|
|
- Risk-adjusted trading strategies
|
|
"""
|
|
|
|
import logging
|
|
import torch
|
|
import torch.nn as nn
|
|
from typing import Dict, List, Optional, Tuple, Any
|
|
from datetime import datetime, timedelta
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
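
# Example wiring (a minimal sketch, not an authoritative API: it only assumes an
# orchestrator object exposing `cnn_model`, `cob_rl_agent`, and a `data_provider`
# with `get_historical_data(symbol, timeframe, limit)` and `get_current_price(symbol)`,
# exactly as the methods below reference them):
#
#     predictor = MultiTimeframePredictor(orchestrator)
#     prediction = predictor.generate_multi_timeframe_prediction("ETH/USDT")
#     if prediction is not None:
#         should_trade, reason = predictor.should_execute_trade(prediction)
#         hold_seconds = predictor.get_position_hold_time(prediction)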


class PredictionHorizon(Enum):
    """Prediction time horizons"""
    ONE_MINUTE = 1
    FIVE_MINUTES = 5
    TEN_MINUTES = 10


class ConfidenceThreshold(Enum):
    """Confidence thresholds for different horizons"""
    ONE_MINUTE = 0.35    # Lower threshold for quick trades
    FIVE_MINUTES = 0.65  # Higher threshold for 5-minute holds
    TEN_MINUTES = 0.80   # Very high threshold for 10-minute holds


@dataclass
class MultiTimeframePrediction:
    """Container for multi-timeframe predictions"""
    symbol: str
    current_price: float
    predictions: Dict[PredictionHorizon, Dict[str, Any]]
    timestamp: datetime
    market_conditions: Dict[str, Any]


class MultiTimeframePredictor:
    """
    Advanced multi-timeframe prediction system that adapts model behavior
    based on the desired prediction horizon and current market conditions.
    """

    def __init__(self, orchestrator):
        self.orchestrator = orchestrator
        self.horizons = {
            PredictionHorizon.ONE_MINUTE: {
                'sequence_length': 60,    # 60 minutes of history for 1-minute predictions
                'confidence_threshold': ConfidenceThreshold.ONE_MINUTE.value,
                'max_hold_time': 60,      # max hold time in seconds (1 minute)
                'risk_multiplier': 1.0
            },
            PredictionHorizon.FIVE_MINUTES: {
                'sequence_length': 300,   # 300 minutes of history for 5-minute predictions
                'confidence_threshold': ConfidenceThreshold.FIVE_MINUTES.value,
                'max_hold_time': 300,     # max hold time in seconds (5 minutes)
                'risk_multiplier': 1.5    # Higher risk for longer holds
            },
            PredictionHorizon.TEN_MINUTES: {
                'sequence_length': 600,   # 600 minutes of history for 10-minute predictions
                'confidence_threshold': ConfidenceThreshold.TEN_MINUTES.value,
                'max_hold_time': 600,     # max hold time in seconds (10 minutes)
                'risk_multiplier': 2.0    # Highest risk for the longest holds
            }
        }

        # Initialize models for different horizons
        self.models = {}
        self._initialize_multi_horizon_models()

    def _initialize_multi_horizon_models(self):
        """Initialize separate model instances for different horizons"""
        try:
            for horizon, config in self.horizons.items():
                # CNN model for this horizon
                if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                    # Create a horizon-specific model configuration
                    horizon_model = self._create_horizon_specific_model(
                        self.orchestrator.cnn_model,
                        config['sequence_length'],
                        horizon
                    )
                    self.models[f'cnn_{horizon.value}min'] = horizon_model

                # COB RL model for this horizon
                if hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent:
                    self.models[f'cob_rl_{horizon.value}min'] = self.orchestrator.cob_rl_agent

                logger.info(f"Initialized {horizon.value}-minute prediction model")

        except Exception as e:
            logger.error(f"Error initializing multi-horizon models: {e}")

    def _create_horizon_specific_model(self, base_model, sequence_length: int, horizon: PredictionHorizon):
        """Create a model instance optimized for a specific prediction horizon"""
        try:
            # For CNN models we need to adjust the input size and potentially the architecture
            if hasattr(base_model, '__class__'):
                model_class = base_model.__class__

                # Calculate an appropriate input size for the horizon:
                # more data for longer predictions, capped at 300 to avoid memory issues
                adjusted_input_size = min(sequence_length, 300)

                # Create a new model instance with horizon-specific parameters,
                # using only the parameters the model actually accepts
                try:
                    horizon_model = model_class(
                        input_size=adjusted_input_size,
                        feature_dim=getattr(base_model, 'feature_dim', 50),
                        output_size=5,  # Always use 5 for OHLCV predictions
                        prediction_horizon=horizon.value
                    )
                except TypeError:
                    # If the model doesn't accept these parameters, create it with defaults
                    logger.warning(f"Model {model_class.__name__} doesn't accept expected parameters, using defaults")
                    horizon_model = model_class()

                # Try to load pre-trained weights if available
                try:
                    if hasattr(base_model, 'state_dict'):
                        # Load base model weights and adapt where necessary
                        base_state = base_model.state_dict()
                        horizon_model.load_state_dict(base_state, strict=False)
                        logger.info(f"Loaded base model weights for {horizon.value}-minute horizon")
                except Exception as e:
                    logger.warning(f"Could not load base weights for {horizon.value}-minute model: {e}")

                return horizon_model

        except Exception as e:
            logger.error(f"Error creating horizon-specific model: {e}")
            return base_model  # Fall back to the base model

    def generate_multi_timeframe_prediction(self, symbol: str) -> Optional[MultiTimeframePrediction]:
        """
        Generate predictions for all timeframes with appropriate confidence thresholds
        """
        try:
            # Get current market data
            current_price = self._get_current_price(symbol)
            if not current_price:
                return None

            # Get market conditions for confidence adjustment
            market_conditions = self._assess_market_conditions(symbol)

            predictions = {}

            # Generate predictions for each horizon
            for horizon, config in self.horizons.items():
                prediction = self._generate_single_horizon_prediction(
                    symbol, current_price, horizon, config, market_conditions
                )
                if prediction:
                    predictions[horizon] = prediction

            if not predictions:
                return None

            return MultiTimeframePrediction(
                symbol=symbol,
                current_price=current_price,
                predictions=predictions,
                timestamp=datetime.now(),
                market_conditions=market_conditions
            )

        except Exception as e:
            logger.error(f"Error generating multi-timeframe prediction: {e}")
            return None

    def _generate_single_horizon_prediction(self, symbol: str, current_price: float,
                                            horizon: PredictionHorizon, config: Dict,
                                            market_conditions: Dict) -> Optional[Dict[str, Any]]:
        """Generate a prediction for a single timeframe using iterative candle prediction"""
        try:
            # Get base historical data (use a shorter sequence for iterative prediction)
            base_sequence_length = min(60, config['sequence_length'] // 2)  # Use half for base data
            base_data = self._get_sequence_data_for_horizon(symbol, base_sequence_length)

            if base_data is None:
                return None

            # Generate iterative predictions for this horizon
            iterative_predictions = self._generate_iterative_predictions(
                symbol, base_data, horizon.value, market_conditions
            )

            if not iterative_predictions:
                return None

            # Analyze the predicted price movement over the horizon
            horizon_prediction = self._analyze_horizon_prediction(
                iterative_predictions, config, market_conditions
            )

            if not horizon_prediction:
                return None

            # Apply the confidence threshold
            if horizon_prediction['confidence'] < config['confidence_threshold']:
                return None  # Not confident enough for this horizon

            return horizon_prediction

        except Exception as e:
            logger.error(f"Error generating {horizon.value}-minute prediction: {e}")
            return None

    def _get_sequence_data_for_horizon(self, symbol: str, sequence_length: int) -> Optional[torch.Tensor]:
        """Get appropriate sequence data for the prediction horizon"""
        try:
            # Pull history from the data provider; fall back to mock data when unavailable
            if hasattr(self.orchestrator, 'data_provider'):
                # Get historical data for the required sequence length
                data = self.orchestrator.data_provider.get_historical_data(
                    symbol, '1m', limit=sequence_length
                )

                if data is not None and len(data) >= sequence_length // 10:  # At least 10% of the required data
                    # Convert to the tensor format expected by the models
                    tensor_data = self._convert_data_to_tensor(data)
                    if tensor_data is not None:
                        logger.debug(f"✅ Converted {len(data)} data points to tensor shape: {tensor_data.shape}")
                        return tensor_data
                    else:
                        logger.warning("Failed to convert data to tensor")
                        return None
                else:
                    logger.warning(f"Insufficient data for {sequence_length}-point prediction: {len(data) if data is not None else 'None'}")
                    return None

            # Fallback: create mock data if no data provider is available
            logger.warning("No data provider available - creating mock sequence data")
            return self._create_mock_sequence_data(sequence_length)

        except Exception as e:
            logger.error(f"Error getting sequence data: {e}")
            # Fallback: create mock data on error
            logger.warning("Creating mock sequence data due to error")
            return self._create_mock_sequence_data(sequence_length)

    def _convert_data_to_tensor(self, data) -> Optional[torch.Tensor]:
        """Convert market data to tensor format"""
        try:
            # Placeholder implementation - adapt to your data format
            if hasattr(data, 'values'):
                # Assume a pandas DataFrame with OHLCV columns
                features = ['open', 'high', 'low', 'close', 'volume']
                feature_data = []

                for feature in features:
                    if feature in data.columns:
                        values = data[feature].ffill().fillna(0).values
                        feature_data.append(values)

                if feature_data:
                    # Ensure all feature arrays have the same length
                    min_length = min(len(arr) for arr in feature_data)
                    feature_data = [arr[:min_length] for arr in feature_data]

                    # Stack features into shape [seq_len, num_features]
                    tensor_data = torch.tensor(feature_data, dtype=torch.float32).transpose(0, 1)

                    # Validate tensor data
                    if torch.any(torch.isnan(tensor_data)) or torch.any(torch.isinf(tensor_data)):
                        logger.warning("Found NaN or Inf values in tensor data, replacing with zeros")
                        tensor_data = torch.nan_to_num(tensor_data, nan=0.0, posinf=0.0, neginf=0.0)

                    return tensor_data.unsqueeze(0)  # Add batch dimension

            return None

        except Exception as e:
            logger.error(f"Error converting data to tensor: {e}")
            return None

    def _get_cnn_prediction(self, model, sequence_data: torch.Tensor, config: Dict) -> Optional[Dict]:
        """Get a CNN model prediction using OHLCV prediction"""
        try:
            # Use the predict method, which handles OHLCV predictions
            if hasattr(model, 'predict'):
                if sequence_data.dim() == 3:  # [batch, seq, features]
                    sequence_data_flat = sequence_data.squeeze(0)  # Remove batch dim
                else:
                    sequence_data_flat = sequence_data

                prediction = model.predict(sequence_data_flat)

                if prediction and 'action_name' in prediction:
                    return {
                        'action': prediction['action_name'],
                        'confidence': prediction.get('action_confidence', 0.5),
                        'model': 'cnn',
                        'horizon': config.get('max_hold_time', 60),
                        'ohlcv_prediction': prediction.get('ohlcv_prediction'),
                        'price_change_pct': prediction.get('price_change_pct', 0)
                    }

            # Fall back to a direct forward pass if the predict method is not available
            with torch.no_grad():
                outputs = model(sequence_data)
                if isinstance(outputs, dict) and 'ohlcv' in outputs:
                    ohlcv = outputs['ohlcv'].cpu().numpy()[0]
                    confidence = outputs['confidence'].cpu().numpy()[0] if hasattr(outputs['confidence'], 'cpu') else outputs['confidence']

                    # Determine the action from the predicted OHLCV candle
                    price_change_pct = ((ohlcv[3] - ohlcv[0]) / ohlcv[0]) * 100 if ohlcv[0] != 0 else 0

                    if price_change_pct > 0.1:
                        action = 'BUY'
                    elif price_change_pct < -0.1:
                        action = 'SELL'
                    else:
                        action = 'HOLD'

                    return {
                        'action': action,
                        'confidence': float(confidence),
                        'model': 'cnn',
                        'horizon': config.get('max_hold_time', 60),
                        'ohlcv_prediction': {
                            'open': float(ohlcv[0]),
                            'high': float(ohlcv[1]),
                            'low': float(ohlcv[2]),
                            'close': float(ohlcv[3]),
                            'volume': float(ohlcv[4])
                        },
                        'price_change_pct': price_change_pct
                    }

            return None

        except Exception as e:
            logger.error(f"Error getting CNN prediction: {e}")
            return None

    def _get_cob_rl_prediction(self, model, sequence_data: torch.Tensor, config: Dict) -> Optional[Dict]:
        """Get a COB RL model prediction"""
        try:
            # This depends on the COB RL model interface
            if hasattr(model, 'predict'):
                result = model.predict(sequence_data)
                return {
                    'action': result.get('action', 'HOLD'),
                    'confidence': result.get('confidence', 0.5),
                    'model': 'cob_rl',
                    'horizon': config.get('max_hold_time', 60)
                }
            return None

        except Exception as e:
            logger.error(f"Error getting COB RL prediction: {e}")
            return None

    def _ensemble_predictions(self, predictions: List[Dict], config: Dict,
                              market_conditions: Dict) -> Optional[Dict[str, Any]]:
        """Ensemble multiple model predictions using OHLCV data"""
        try:
            if not predictions:
                return None

            # Enhanced ensemble considering both action and price movement
            action_votes = {}
            confidence_sum = 0
            price_change_indicators = []

            for pred in predictions:
                action = pred['action']
                confidence = pred['confidence']

                # Weight votes by confidence
                if action not in action_votes:
                    action_votes[action] = 0
                action_votes[action] += confidence
                confidence_sum += confidence

                # Collect price change indicators for ensemble analysis
                if 'price_change_pct' in pred:
                    price_change_indicators.append(pred['price_change_pct'])

            # Get the winning action
            if action_votes:
                best_action = max(action_votes, key=action_votes.get)
                ensemble_confidence = action_votes[best_action] / len(predictions)
            else:
                best_action = 'HOLD'
                ensemble_confidence = 0.1

            # Analyze price movement consensus
            if price_change_indicators:
                avg_price_change = sum(price_change_indicators) / len(price_change_indicators)
                price_consensus = abs(avg_price_change) / 0.1  # Normalized around the 0.1% threshold (informational only)

                # Boost confidence if the predicted price movements are consistent
                if len(price_change_indicators) > 1:
                    price_std = torch.std(torch.tensor(price_change_indicators)).item()
                    if price_std < 0.05:    # Low variability in predictions
                        ensemble_confidence *= 1.2
                    elif price_std > 0.15:  # High variability
                        ensemble_confidence *= 0.8

                # Override the action on strong price consensus
                if abs(avg_price_change) > 0.2:  # Strong price movement
                    if avg_price_change > 0:
                        best_action = 'BUY'
                    else:
                        best_action = 'SELL'
                    ensemble_confidence = min(ensemble_confidence * 1.3, 0.9)

            # Adjust confidence based on market conditions
            market_confidence_multiplier = market_conditions.get('confidence_multiplier', 1.0)
            final_confidence = min(ensemble_confidence * market_confidence_multiplier, 1.0)

            return {
                'action': best_action,
                'confidence': final_confidence,
                'horizon_minutes': config['max_hold_time'] // 60,
                'risk_multiplier': config['risk_multiplier'],
                'models_used': len(predictions),
                'market_conditions': market_conditions,
                'price_change_indicators': price_change_indicators,
                'avg_price_change_pct': sum(price_change_indicators) / len(price_change_indicators) if price_change_indicators else 0
            }

        except Exception as e:
            logger.error(f"Error in prediction ensemble: {e}")
            return None

    def _assess_market_conditions(self, symbol: str) -> Dict[str, Any]:
        """Assess current market conditions for confidence adjustment"""
        try:
            conditions = {
                'volatility': 'medium',
                'trend': 'sideways',
                'confidence_multiplier': 1.0,
                'risk_level': 'normal'
            }

            # This could be enhanced with actual market analysis;
            # for now, return the default conditions
            return conditions

        except Exception as e:
            logger.error(f"Error assessing market conditions: {e}")
            return {'confidence_multiplier': 1.0}
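
    # Possible refinement of _assess_market_conditions above (an illustrative sketch only,
    # kept as comments; it assumes the same `get_historical_data(symbol, '1m', limit=...)`
    # call and 'close' column used elsewhere in this module):
    #
    #     data = self.orchestrator.data_provider.get_historical_data(symbol, '1m', limit=60)
    #     returns = data['close'].pct_change().dropna()
    #     realized_vol = float(returns.std())
    #     if realized_vol > 0.005:       # > 0.5% per-minute moves: treat as high risk
    #         conditions.update({'volatility': 'high', 'risk_level': 'high',
    #                            'confidence_multiplier': 0.8})
    #     elif realized_vol < 0.001:     # very quiet market
    #         conditions.update({'volatility': 'low', 'confidence_multiplier': 1.1})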

    def _get_current_price(self, symbol: str) -> Optional[float]:
        """Get the current price for a symbol"""
        try:
            if hasattr(self.orchestrator, 'data_provider'):
                ticker = self.orchestrator.data_provider.get_current_price(symbol)
                return ticker
            return None
        except Exception as e:
            logger.error(f"Error getting current price for {symbol}: {e}")
            return None

    def should_execute_trade(self, prediction: MultiTimeframePrediction) -> Tuple[bool, str]:
        """
        Determine whether a trade should be executed based on multi-timeframe analysis
        """
        try:
            if not prediction or not prediction.predictions:
                return False, "No predictions available"

            # Find the best prediction across all horizons
            best_prediction = None
            best_confidence = 0

            for horizon, pred in prediction.predictions.items():
                if pred['confidence'] > best_confidence:
                    best_confidence = pred['confidence']
                    best_prediction = (horizon, pred)

            if not best_prediction:
                return False, "No valid predictions"

            horizon, pred = best_prediction
            config = self.horizons[horizon]

            # Check whether the confidence meets the threshold
            if pred['confidence'] < config['confidence_threshold']:
                return False, (f"Confidence {pred['confidence']:.2f} below threshold "
                               f"{config['confidence_threshold']:.2f}")

            # Check market conditions
            market_risk = prediction.market_conditions.get('risk_level', 'normal')
            if market_risk == 'high' and horizon.value >= 5:
                return False, "High market risk - avoiding longer-term predictions"

            return True, f"Valid {horizon.value}-minute prediction with {pred['confidence']:.2f} confidence"

        except Exception as e:
            logger.error(f"Error in trade execution decision: {e}")
            return False, f"Decision error: {e}"

    def get_position_hold_time(self, prediction: MultiTimeframePrediction) -> int:
        """Determine how long to hold a position (in seconds) based on the prediction horizon"""
        try:
            if not prediction or not prediction.predictions:
                return 60  # Default: 1 minute

            # Use the longest horizon prediction that is available and confident
            max_horizon = 1
            for horizon, pred in prediction.predictions.items():
                config = self.horizons[horizon]
                if pred['confidence'] >= config['confidence_threshold']:
                    max_horizon = max(max_horizon, horizon.value)

            return max_horizon * 60  # Convert minutes to seconds

        except Exception as e:
            logger.error(f"Error determining hold time: {e}")
            return 60

    def _generate_iterative_predictions(self, symbol: str, base_data: torch.Tensor,
                                        num_steps: int, market_conditions: Dict) -> Optional[List[Dict]]:
        """Generate iterative candle predictions for the specified number of steps"""
        try:
            predictions = []
            current_data = base_data.clone()  # Start with the base historical data

            # Get a CNN model for iterative prediction
            cnn_model = None
            for model_key, model in self.models.items():
                if model_key.startswith('cnn_'):
                    cnn_model = model
                    break

            if not cnn_model:
                logger.warning("No CNN model available for iterative prediction")
                return None

            # Check whether the CNN model has a predict method
            if not hasattr(cnn_model, 'predict'):
                logger.warning("CNN model does not have a predict method - trying alternative approach")
                # Try to use the orchestrator's CNN model directly
                if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                    cnn_model = self.orchestrator.cnn_model
                    logger.info("Using orchestrator's CNN model for predictions")

                    # Check whether the orchestrator's CNN model also lacks a predict method
                    if not hasattr(cnn_model, 'predict'):
                        logger.error("Orchestrator's CNN model also lacks a predict method - creating mock predictions")
                        return self._create_mock_predictions(num_steps)
                else:
                    logger.error("No CNN model with a predict method available - creating mock predictions")
                    # Create mock predictions for testing
                    return self._create_mock_predictions(num_steps)

            for step in range(num_steps):
                # Use the CNN model to predict the next candle
                try:
                    with torch.no_grad():
                        # Prepare data in the format expected by the predict method
                        if current_data.dim() == 3:  # [batch, seq, features]
                            current_data_flat = current_data.squeeze(0)  # Remove batch dim
                        else:
                            current_data_flat = current_data

                        prediction = cnn_model.predict(current_data_flat)

                        if prediction and 'ohlcv_prediction' in prediction:
                            # Attach a timestamp to the prediction
                            prediction_time = datetime.now() + timedelta(minutes=step + 1)
                            prediction['timestamp'] = prediction_time
                            predictions.append(prediction)
                            logger.debug(f"📊 Step {step}: Added prediction for {prediction_time}, close: {prediction['ohlcv_prediction']['close']:.2f}")

                            # Extract the predicted OHLCV values
                            ohlcv = prediction['ohlcv_prediction']
                            new_candle = torch.tensor([
                                ohlcv['open'],
                                ohlcv['high'],
                                ohlcv['low'],
                                ohlcv['close'],
                                ohlcv['volume']
                            ], dtype=current_data.dtype)

                            # Roll the window: drop the oldest candle and append the new prediction
                            if current_data.dim() == 3:
                                current_data = torch.cat([
                                    current_data[:, 1:, :],               # Remove oldest candle
                                    new_candle.unsqueeze(0).unsqueeze(0)  # Add new prediction
                                ], dim=1)
                            else:
                                current_data = torch.cat([
                                    current_data[1:, :],     # Remove oldest candle
                                    new_candle.unsqueeze(0)  # Add new prediction
                                ], dim=0)
                        else:
                            logger.warning(f"❌ Step {step}: Invalid prediction format")
                            break

                except Exception as e:
                    logger.error(f"Error in iterative prediction step {step}: {e}")
                    break

            return predictions if predictions else None

        except Exception as e:
            logger.error(f"Error in iterative predictions: {e}")
            return None

    def _create_mock_predictions(self, num_steps: int) -> List[Dict]:
        """Create mock predictions for testing when no CNN model is available"""
        try:
            logger.info(f"Creating {num_steps} mock predictions for testing")
            predictions = []
            current_time = datetime.now()
            base_price = 4300.0  # Mock base price

            for step in range(num_steps):
                prediction_time = current_time + timedelta(minutes=step + 1)
                price_change = (step - num_steps // 2) * 2.0  # Mock price movement
                predicted_price = base_price + price_change

                mock_prediction = {
                    'timestamp': prediction_time,
                    'ohlcv_prediction': {
                        'open': predicted_price,
                        'high': predicted_price + 1.0,
                        'low': predicted_price - 1.0,
                        'close': predicted_price + 0.5,
                        'volume': 1000
                    },
                    'confidence': max(0.3, 0.8 - step * 0.05),  # Decreasing confidence
                    'action': 0 if price_change > 0 else 1,
                    'action_name': 'BUY' if price_change > 0 else 'SELL'
                }
                predictions.append(mock_prediction)

            logger.info(f"✅ Created {len(predictions)} mock predictions")
            return predictions

        except Exception as e:
            logger.error(f"Error creating mock predictions: {e}")
            return []

    def _create_mock_sequence_data(self, sequence_length: int) -> torch.Tensor:
        """Create mock sequence data for testing when real data is not available"""
        try:
            logger.info(f"Creating mock sequence data with {sequence_length} points")

            # Create mock OHLCV data
            base_price = 4300.0
            mock_data = []

            for i in range(sequence_length):
                # Simulate price movement
                price_change = (i - sequence_length // 2) * 0.5
                price = base_price + price_change

                # Create an OHLCV candle
                candle = [
                    price,        # open
                    price + 1.0,  # high
                    price - 1.0,  # low
                    price + 0.5,  # close
                    1000.0        # volume
                ]
                mock_data.append(candle)

            # Convert to a tensor with a batch dimension
            tensor_data = torch.tensor(mock_data, dtype=torch.float32)
            tensor_data = tensor_data.unsqueeze(0)  # Add batch dimension

            logger.debug(f"✅ Created mock sequence data shape: {tensor_data.shape}")
            return tensor_data

        except Exception as e:
            logger.error(f"Error creating mock sequence data: {e}")
            # Return a minimal valid tensor
            return torch.zeros((1, 10, 5), dtype=torch.float32)

    def _analyze_horizon_prediction(self, iterative_predictions: List[Dict],
                                    config: Dict, market_conditions: Dict) -> Optional[Dict[str, Any]]:
        """Analyze the series of iterative predictions to determine the overall horizon movement"""
        try:
            if not iterative_predictions:
                return None

            # Extract price data from the predictions
            predicted_prices = []
            confidences = []
            actions = []

            for pred in iterative_predictions:
                if 'ohlcv_prediction' in pred:
                    close_price = pred['ohlcv_prediction']['close']
                    predicted_prices.append(close_price)

                # Predictions carry 'action_confidence' (model predict path) or 'confidence' (mock path)
                confidence = pred.get('action_confidence', pred.get('confidence', 0.5))
                confidences.append(confidence)

                action = pred.get('action', 2)  # Default to HOLD
                actions.append(action)

            if not predicted_prices:
                return None

            # Calculate the overall price movement
            start_price = predicted_prices[0]
            end_price = predicted_prices[-1]
            total_change = end_price - start_price
            total_change_pct = (total_change / start_price) * 100 if start_price != 0 else 0

            # Calculate volatility (relative to the starting price) and average confidence
            price_std = torch.std(torch.tensor(predicted_prices)).item()
            price_volatility = price_std / abs(start_price) if start_price != 0 else 0.0
            avg_confidence = sum(confidences) / len(confidences)

            # Determine the overall action based on price movement and confidence
            if total_change_pct > 0.5:     # Overall bullish movement
                action = 0  # BUY
                action_name = 'BUY'
                confidence_multiplier = 1.2
            elif total_change_pct < -0.5:  # Overall bearish movement
                action = 1  # SELL
                action_name = 'SELL'
                confidence_multiplier = 1.2
            else:  # Sideways movement
                # Use a majority vote from the individual predictions
                buy_count = sum(1 for a in actions if a == 0)
                sell_count = sum(1 for a in actions if a == 1)

                if buy_count > sell_count:
                    action = 0
                    action_name = 'BUY'
                    confidence_multiplier = 0.8  # Reduce confidence for mixed signals
                elif sell_count > buy_count:
                    action = 1
                    action_name = 'SELL'
                    confidence_multiplier = 0.8
                else:
                    action = 2  # HOLD
                    action_name = 'HOLD'
                    confidence_multiplier = 0.5

            # Calculate the final confidence
            final_confidence = avg_confidence * confidence_multiplier

            # Adjust for market conditions
            market_multiplier = market_conditions.get('confidence_multiplier', 1.0)
            final_confidence *= market_multiplier

            # Cap confidence at reasonable levels
            final_confidence = min(0.95, max(0.1, final_confidence))

            # Adjust for volatility
            if price_volatility > 0.02:  # High volatility (> 2%) across the predicted path
                final_confidence *= 0.9

            return {
                'action': action,
                'action_name': action_name,
                'confidence': final_confidence,
                'horizon_minutes': config['max_hold_time'] // 60,
                'total_price_change_pct': total_change_pct,
                'price_volatility': price_volatility,
                'avg_prediction_confidence': avg_confidence,
                'num_predictions': len(iterative_predictions),
                'risk_multiplier': config['risk_multiplier'],
                'market_conditions': market_conditions,
                'prediction_series': {
                    'prices': predicted_prices,
                    'confidences': confidences,
                    'actions': actions
                }
            }

        except Exception as e:
            logger.error(f"Error analyzing horizon prediction: {e}")
            return None
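

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumption: no live orchestrator or data provider is
    # available here, so only the mock fallback paths of this module are exercised).
    logging.basicConfig(level=logging.INFO)

    class _StubOrchestrator:
        """Bare stub without cnn_model / cob_rl_agent / data_provider attributes."""
        pass

    predictor = MultiTimeframePredictor(_StubOrchestrator())

    # Build mock iterative predictions and analyze them for the 5-minute horizon,
    # mirroring what _generate_single_horizon_prediction does with real model output.
    mock_steps = predictor._create_mock_predictions(PredictionHorizon.FIVE_MINUTES.value)
    horizon_config = predictor.horizons[PredictionHorizon.FIVE_MINUTES]
    analysis = predictor._analyze_horizon_prediction(
        mock_steps, horizon_config, {'confidence_multiplier': 1.0}
    )
    if analysis:
        print(f"{analysis['action_name']} with confidence {analysis['confidence']:.2f} "
              f"over {analysis['horizon_minutes']} minutes")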