716 lines
30 KiB
Python
716 lines
30 KiB
Python
#!/usr/bin/env python3
"""
Multi-Horizon Prediction Manager

This module generates min/max price-range predictions for several time
horizons (1m, 5m, 15m and 60m) once per minute. Each prediction covers the
period from the moment it is made until the end of its own horizon.

Every prediction is stored together with a snapshot of the model inputs that
produced it, so that it can be used for training once its outcome is known.
"""
|
|
|
|
import logging
|
|
import threading
|
|
import time
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Any, Optional, Tuple
|
|
from dataclasses import dataclass, field
|
|
import numpy as np
|
|
import pandas as pd
|
|
from collections import deque
|
|
|
|
# Module-level logger used by everything in this module.
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class PredictionSnapshot:
    """A stored prediction plus everything needed to train on it later.

    The identity, prediction and input fields are filled when the prediction
    is made. The ``actual_*`` and ``outcome_*`` fields keep their defaults
    until the prediction is validated after ``target_time``.
    """

    # --- identity and timing ---
    prediction_id: str
    symbol: str
    prediction_time: datetime
    target_horizon_minutes: int
    target_time: datetime

    # --- the prediction itself ---
    current_price: float  # price when the prediction was made
    predicted_min_price: float
    predicted_max_price: float
    confidence: float

    # --- inputs captured for future training ---
    model_inputs: Dict[str, Any]
    market_state: Dict[str, Any]
    technical_indicators: Dict[str, Any]
    pivot_analysis: Dict[str, Any]
    prediction_metadata: Dict[str, Any] = field(default_factory=dict)

    # --- outcome, filled in once the target time has been reached ---
    actual_min_price: Optional[float] = field(default=None)
    actual_max_price: Optional[float] = field(default=None)
    outcome_known: bool = field(default=False)
    outcome_timestamp: Optional[datetime] = field(default=None)
@dataclass
class HorizonPrediction:
    """Predicted price range for one horizon from one source or from the ensemble."""

    # Length of the look-ahead window, in minutes.
    horizon_minutes: int
    # Lowest price expected within the window.
    predicted_min: float
    # Highest price expected within the window.
    predicted_max: float
    confidence: float
    # Which source produced it: 'cnn', 'rl', 'technical' or 'ensemble'.
    prediction_basis: str
class MultiHorizonPredictionManager:
    """Manages multi-timeframe predictions for the trading system.

    Once started, a background thread periodically builds a min/max price-range
    prediction for every symbol in ``self.symbols`` and every horizon in
    ``self.horizons``. Each prediction whose confidence clears
    ``self.min_confidence_threshold`` is stored as a ``PredictionSnapshot``,
    together with the inputs that produced it. While the snapshot is pending,
    the prices seen up to its target time are tracked, and the snapshot is
    scored once that time has passed.
    """

    # Seconds the loop waits between passes (stop() interrupts the wait).
    _LOOP_SLEEP_SECONDS = 10
    # Back-off after an unexpected error inside the loop.
    _ERROR_SLEEP_SECONDS = 30
    # Assumed volatility (fraction of price per hour) when none was measured.
    _DEFAULT_HOURLY_VOLATILITY = 0.02

    def __init__(self, orchestrator=None, data_provider=None, config: Optional[Dict[str, Any]] = None):
        """Initialize the multi-horizon prediction manager.

        Args:
            orchestrator: Optional object that may expose ``cnn_model``,
                ``rl_agent`` and ``williams_structure``. Each one that is
                missing (or ``None``) is skipped.
            data_provider: Optional object exposing ``current_prices`` (a
                mapping keyed by symbol without '/', e.g. ``'ETHUSDT'``) and
                ``get_historical_data(symbol, timeframe, limit=...)``.
            config: Optional configuration dict, stored as ``self.config``.
        """
        self.orchestrator = orchestrator
        self.data_provider = data_provider
        self.config = config or {}

        # Symbols predicted and validated by the background loop
        self.symbols: List[str] = ['ETH/USDT', 'BTC/USDT']

        # Prediction horizons in minutes
        self.horizons = [1, 5, 15, 60]

        # Prediction frequency (every minute)
        self.prediction_interval_seconds = 60

        # Storage for prediction snapshots: {horizon: deque of PredictionSnapshot}
        self.max_snapshots_per_horizon = 1000
        self.prediction_snapshots: Dict[int, deque] = {
            horizon: deque(maxlen=self.max_snapshots_per_horizon) for horizon in self.horizons
        }

        # Running [min, max] price observed for each pending prediction, keyed by
        # prediction_id; maintained by _validate_pending_predictions.
        self._observed_price_ranges: Dict[str, List[float]] = {}

        # Threading
        self.prediction_thread: Optional[threading.Thread] = None
        self.is_running = False
        self.last_prediction_time = 0.0
        # Lets stop() wake the loop immediately instead of waiting out a sleep.
        self._stop_event = threading.Event()

        # Performance tracking
        self.prediction_stats = {
            'total_predictions': 0,
            'predictions_by_horizon': {h: 0 for h in self.horizons},
            'validated_predictions': 0,
            'accurate_predictions': 0,
            'avg_confidence': 0.0,
            'last_prediction_time': None
        }

        # Minimum confidence threshold for storing predictions
        self.min_confidence_threshold = 0.3

        logger.info("MultiHorizonPredictionManager initialized")
        logger.info(f"Prediction horizons: {self.horizons} minutes")
        logger.info(f"Prediction interval: {self.prediction_interval_seconds} seconds")

    # ------------------------------------------------------------------ #
    # Lifecycle
    # ------------------------------------------------------------------ #

    def start(self):
        """Start the background prediction thread (no-op if already running)."""
        if self.is_running:
            logger.warning("Prediction manager already running")
            return

        self.is_running = True
        self._stop_event.clear()
        self.prediction_thread = threading.Thread(
            target=self._prediction_loop,
            daemon=True,
            name="MultiHorizonPredictor"
        )
        self.prediction_thread.start()
        logger.info("MultiHorizonPredictionManager started")

    def stop(self):
        """Stop the prediction thread, waiting up to 10 seconds for it to exit."""
        self.is_running = False
        # Wake the loop out of its wait so it notices is_running == False promptly.
        self._stop_event.set()
        if self.prediction_thread and self.prediction_thread.is_alive():
            self.prediction_thread.join(timeout=10)
        logger.info("MultiHorizonPredictionManager stopped")

    def _prediction_loop(self):
        """Run the main loop: predict once per interval and validate on every pass."""
        while self.is_running:
            try:
                current_time = time.time()

                # Check if it's time for new predictions
                if current_time - self.last_prediction_time >= self.prediction_interval_seconds:
                    self._generate_all_horizon_predictions()
                    self.last_prediction_time = current_time

                # Validate pending predictions (also samples prices for pending ones)
                self._validate_pending_predictions()

                self._stop_event.wait(self._LOOP_SLEEP_SECONDS)

            except Exception as e:
                logger.error(f"Error in prediction loop: {e}")
                self._stop_event.wait(self._ERROR_SLEEP_SECONDS)  # Longer back-off on error

    # ------------------------------------------------------------------ #
    # Prediction generation
    # ------------------------------------------------------------------ #

    def _generate_all_horizon_predictions(self):
        """Generate, filter and store predictions for every symbol and horizon."""
        try:
            prediction_time = datetime.now()

            for symbol in self.symbols:
                market_state = self._get_current_market_state(symbol)
                if not market_state:
                    continue

                current_price = market_state['current_price']

                for horizon_minutes in self.horizons:
                    try:
                        prediction, component_count = self._predict_with_components(
                            symbol, horizon_minutes, market_state
                        )
                        if prediction is None or prediction.confidence < self.min_confidence_threshold:
                            continue

                        snapshot = self._create_prediction_snapshot(
                            symbol, horizon_minutes, prediction_time, current_price,
                            prediction, market_state, ensemble_components=component_count
                        )
                        self.prediction_snapshots[horizon_minutes].append(snapshot)
                        self._record_prediction_stats(horizon_minutes, prediction.confidence)

                        logger.info(f"Generated {horizon_minutes}m prediction for {symbol}: "
                                    f"min={prediction.predicted_min:.4f}, max={prediction.predicted_max:.4f}, "
                                    f"confidence={prediction.confidence:.2f}")

                    except Exception as e:
                        logger.error(f"Error generating {horizon_minutes}m prediction for {symbol}: {e}")

            self.prediction_stats['last_prediction_time'] = prediction_time

        except Exception as e:
            logger.error(f"Error generating all horizon predictions: {e}")

    def _record_prediction_stats(self, horizon_minutes: int, confidence: float):
        """Update the counters and the running mean confidence for one stored prediction."""
        stats = self.prediction_stats
        stats['total_predictions'] += 1
        by_horizon = stats['predictions_by_horizon']
        by_horizon[horizon_minutes] = by_horizon.get(horizon_minutes, 0) + 1
        # Incremental mean: avoids keeping every confidence value around.
        stats['avg_confidence'] += (confidence - stats['avg_confidence']) / stats['total_predictions']

    def _generate_horizon_prediction(self, symbol: str, horizon_minutes: int,
                                     prediction_time: datetime,
                                     market_state: Dict[str, Any]) -> Optional['HorizonPrediction']:
        """Generate the (ensemble) prediction for a specific horizon.

        ``prediction_time`` is accepted for interface compatibility and is not used.
        """
        prediction, _ = self._predict_with_components(symbol, horizon_minutes, market_state)
        return prediction

    def _predict_with_components(self, symbol: str, horizon_minutes: int,
                                 market_state: Dict[str, Any]) -> Tuple[Optional['HorizonPrediction'], int]:
        """Combine the CNN, RL and technical predictions for one horizon.

        Returns:
            ``(prediction, component_count)``. ``component_count`` is the number
            of sources that contributed, or 0 when no prediction could be made.
        """
        try:
            candidates = (
                self._get_cnn_prediction(symbol, horizon_minutes, market_state),
                self._get_rl_prediction(symbol, horizon_minutes, market_state),
                self._get_technical_prediction(symbol, horizon_minutes, market_state),
            )
            components = [p for p in candidates if p is not None]

            if not components:
                # Fallback to technical analysis only, with a confidence floor
                fallback = self._get_technical_prediction(symbol, horizon_minutes, market_state, fallback=True)
                return fallback, (1 if fallback is not None else 0)

            return self._ensemble_predictions(components, market_state['current_price']), len(components)

        except Exception as e:
            logger.error(f"Error generating horizon prediction: {e}")
            return None, 0

    def _get_cnn_prediction(self, symbol: str, horizon_minutes: int,
                            market_state: Dict[str, Any]) -> Optional['HorizonPrediction']:
        """Return a CNN-based prediction, or None if no CNN model is available or it fails."""
        try:
            cnn_model = getattr(self.orchestrator, 'cnn_model', None)
            if cnn_model is None:
                return None

            features = self._prepare_cnn_features_for_horizon(market_state, horizon_minutes)
            prediction_output = cnn_model.predict(features)

            predicted_min, predicted_max, confidence = self._interpret_cnn_output(
                prediction_output, market_state['current_price'], horizon_minutes
            )

            return HorizonPrediction(
                horizon_minutes=horizon_minutes,
                predicted_min=predicted_min,
                predicted_max=predicted_max,
                confidence=confidence,
                prediction_basis='cnn'
            )

        except Exception as e:
            logger.debug(f"CNN prediction failed: {e}")
            return None

    def _get_rl_prediction(self, symbol: str, horizon_minutes: int,
                           market_state: Dict[str, Any]) -> Optional['HorizonPrediction']:
        """Return an RL-based prediction, or None if no RL agent is available or it fails."""
        try:
            rl_agent = getattr(self.orchestrator, 'rl_agent', None)
            if rl_agent is None:
                return None

            rl_state = self._prepare_rl_state_for_horizon(market_state, horizon_minutes)
            action = rl_agent.act(rl_state, explore=False)

            current_price = market_state['current_price']
            predicted_min, predicted_max, confidence = self._convert_rl_action_to_price_prediction(
                action, current_price, horizon_minutes, rl_agent
            )

            return HorizonPrediction(
                horizon_minutes=horizon_minutes,
                predicted_min=predicted_min,
                predicted_max=predicted_max,
                confidence=confidence,
                prediction_basis='rl'
            )

        except Exception as e:
            logger.debug(f"RL prediction failed: {e}")
            return None

    def _get_technical_prediction(self, symbol: str, horizon_minutes: int,
                                  market_state: Dict[str, Any], fallback: bool = False) -> Optional['HorizonPrediction']:
        """Predict a price range from volatility and the pivot-based trend.

        The range width grows with the square root of the horizon and is skewed
        toward the trend direction. Confidence rises with trend strength and
        falls with the horizon length. When ``fallback`` is set, confidence is
        floored at 0.2.
        """
        try:
            current_price = market_state['current_price']

            pivot_analysis = market_state.get('pivot_analysis', {})
            technical_indicators = market_state.get('technical_indicators', {})

            trend_direction = pivot_analysis.get('trend_direction', 'SIDEWAYS')
            trend_strength = pivot_analysis.get('trend_strength', 0.0)

            per_minute_volatility = technical_indicators.get('volatility')
            if per_minute_volatility is None:
                # No measured volatility: assume a default hourly figure.
                expected_range_percent = self._DEFAULT_HOURLY_VOLATILITY * np.sqrt(horizon_minutes / 60.0)
            else:
                # 'volatility' is the std-dev of 1-minute returns (1m candles in
                # _get_current_market_state), so it scales with sqrt(minutes).
                expected_range_percent = per_minute_volatility * np.sqrt(horizon_minutes)

            if trend_direction == 'UPTREND':
                # Bias toward higher prices
                predicted_min = current_price * (1 - expected_range_percent * 0.3)
                predicted_max = current_price * (1 + expected_range_percent * 1.2)
            elif trend_direction == 'DOWNTREND':
                # Bias toward lower prices
                predicted_min = current_price * (1 - expected_range_percent * 1.2)
                predicted_max = current_price * (1 + expected_range_percent * 0.3)
            else:
                # Symmetric range for sideways
                range_half = expected_range_percent * current_price
                predicted_min = current_price - range_half
                predicted_max = current_price + range_half

            base_confidence = 0.4 + (trend_strength * 0.4)  # 0.4 to 0.8 for strength in [0, 1]

            # Reduce confidence for longer horizons
            horizon_factor = max(0.3, 1.0 - (horizon_minutes - 1) / 120.0)
            confidence = base_confidence * horizon_factor

            if fallback:
                confidence = max(confidence, 0.2)  # Minimum confidence for fallback

            return HorizonPrediction(
                horizon_minutes=horizon_minutes,
                predicted_min=predicted_min,
                predicted_max=predicted_max,
                confidence=confidence,
                prediction_basis='technical'
            )

        except Exception as e:
            logger.error(f"Technical prediction failed: {e}")
            return None

    def _ensemble_predictions(self, predictions: List['HorizonPrediction'],
                              current_price: float) -> Optional['HorizonPrediction']:
        """Combine component predictions into a confidence-weighted ensemble.

        If every component has zero confidence, the components are weighted
        equally. If the result is not a valid range (min >= max), it is replaced
        by a symmetric +/-2% range around ``current_price``. A single component
        keeps its own ``prediction_basis``. Returns None for an empty input or
        on error.
        """
        try:
            if not predictions:
                return None

            weights = [p.confidence for p in predictions]
            total_weight = sum(weights)
            if total_weight <= 0:
                # Equal weights; otherwise the weighted prices would collapse to 0.
                weights = [1.0] * len(predictions)
                total_weight = float(len(predictions))

            weighted_min = sum(w * p.predicted_min for w, p in zip(weights, predictions)) / total_weight
            weighted_max = sum(w * p.predicted_max for w, p in zip(weights, predictions)) / total_weight

            avg_confidence = sum(p.confidence for p in predictions) / len(predictions)

            # Ensure min < max and reasonable bounds
            if weighted_min >= weighted_max:
                range_half = abs(current_price * 0.02)  # 2% range
                weighted_min = current_price - range_half
                weighted_max = current_price + range_half

            basis = predictions[0].prediction_basis if len(predictions) == 1 else 'ensemble'

            return HorizonPrediction(
                horizon_minutes=predictions[0].horizon_minutes,
                predicted_min=weighted_min,
                predicted_max=weighted_max,
                confidence=min(avg_confidence, 0.95),  # Cap at 95%
                prediction_basis=basis
            )

        except Exception as e:
            logger.error(f"Ensemble prediction failed: {e}")
            return None

    # ------------------------------------------------------------------ #
    # Market data
    # ------------------------------------------------------------------ #

    def _get_current_price(self, symbol: str) -> Optional[float]:
        """Return the latest price for ``symbol`` from ``data_provider.current_prices``, or None."""
        current_prices = getattr(self.data_provider, 'current_prices', None)
        if current_prices is None:
            return None
        return current_prices.get(symbol.replace('/', '').upper())

    def _get_current_market_state(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Collect the market state used for prediction.

        Returns:
            A dict with ``current_price``, ``ohlcv_data`` (last 100 1m candles),
            ``technical_indicators``, ``pivot_analysis`` and ``timestamp``. Returns
            None when there is no provider or price, fewer than 20 candles are
            available, or an error occurs.
        """
        try:
            if not self.data_provider:
                return None

            current_price = self._get_current_price(symbol)
            if current_price is None:
                logger.debug(f"No current price available for {symbol}")
                return None

            ohlcv_data = self.data_provider.get_historical_data(symbol, '1m', limit=100)
            if ohlcv_data is None or len(ohlcv_data) < 20:
                logger.debug(f"Insufficient OHLCV data for {symbol}")
                return None

            technical_indicators = self._calculate_technical_indicators(ohlcv_data)
            pivot_analysis = self._get_pivot_analysis(symbol, ohlcv_data)

            return {
                'current_price': current_price,
                'ohlcv_data': ohlcv_data,
                'technical_indicators': technical_indicators,
                'pivot_analysis': pivot_analysis,
                'timestamp': datetime.now()
            }

        except Exception as e:
            logger.error(f"Error getting market state for {symbol}: {e}")
            return None

    @staticmethod
    def _ohlcv_columns(ohlcv_data) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Return ``(highs, lows, closes, volumes)`` as float arrays.

        Accepts a 2-D array using the positional layout this class has always
        assumed: columns 2-5 are high/low/close/volume (column 0 is presumably a
        timestamp and column 1 the open). It also accepts a pandas DataFrame. A
        DataFrame that has ``high``/``low``/``close``/``volume`` columns
        (case-insensitive) is read by name; any other DataFrame falls back to
        the positional layout.
        """
        if isinstance(ohlcv_data, pd.DataFrame):
            by_name = {str(col).lower(): col for col in ohlcv_data.columns}
            wanted = ('high', 'low', 'close', 'volume')
            if all(name in by_name for name in wanted):
                highs, lows, closes, volumes = (
                    ohlcv_data[by_name[name]].to_numpy(dtype=float) for name in wanted
                )
                return highs, lows, closes, volumes
            ohlcv_data = ohlcv_data.to_numpy()

        data = np.asarray(ohlcv_data)
        return (data[:, 2].astype(float), data[:, 3].astype(float),
                data[:, 4].astype(float), data[:, 5].astype(float))

    @staticmethod
    def _calculate_rsi(prices: np.ndarray, period: int = 14) -> float:
        """Simple-average RSI over the last ``period`` price changes (50.0 if too short)."""
        if len(prices) < period + 1:
            return 50.0
        deltas = np.diff(np.asarray(prices, dtype=float)[-(period + 1):])
        avg_gain = float(np.mean(np.clip(deltas, 0.0, None)))
        avg_loss = float(np.mean(np.clip(-deltas, 0.0, None)))
        if avg_loss == 0:
            return 100.0
        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    @staticmethod
    def _relative_change(closes: np.ndarray, candles_back: int) -> float:
        """Fractional change of the last close versus the close ``candles_back`` candles earlier."""
        if len(closes) <= candles_back:
            return 0.0
        base = closes[-1 - candles_back]
        return float((closes[-1] - base) / base)

    def _calculate_technical_indicators(self, ohlcv_data) -> Dict[str, Any]:
        """Calculate SMA, RSI, volatility, volume and price-change indicators.

        ``volatility`` is the standard deviation of per-candle returns, which is
        per-minute for the 1m candles this class requests. The
        ``price_change_*`` values compare the last close with the close exactly
        5 and 15 candles earlier. Returns ``{}`` for fewer than 20 candles or
        on error.
        """
        try:
            if len(ohlcv_data) < 20:
                return {}

            _, _, closes, volumes = self._ohlcv_columns(ohlcv_data)

            sma_5 = np.mean(closes[-5:])
            sma_20 = np.mean(closes[-20:])
            rsi = self._calculate_rsi(closes)

            # Volatility (standard deviation of returns)
            returns = np.diff(closes) / closes[:-1]
            volatility = np.std(returns) if len(returns) > 0 else 0.02

            # Volume analysis
            avg_volume = np.mean(volumes[-20:])
            volume_ratio = volumes[-1] / avg_volume if avg_volume > 0 else 1.0

            return {
                'sma_5': float(sma_5),
                'sma_20': float(sma_20),
                'rsi': float(rsi),
                'volatility': float(volatility),
                'volume_ratio': float(volume_ratio),
                'price_change_5m': self._relative_change(closes, 5),
                'price_change_15m': self._relative_change(closes, 15)
            }

        except Exception as e:
            logger.error(f"Error calculating technical indicators: {e}")
            return {}

    def _get_pivot_analysis(self, symbol: str, ohlcv_data) -> Dict[str, Any]:
        """Return the trend direction/strength and support/resistance levels.

        The orchestrator's ``williams_structure`` is used when it is present.
        If it is missing, returns nothing, or raises, a basic 20-candle high/low
        pivot is used instead. Returns ``{}`` on unexpected errors.
        """
        try:
            williams = getattr(self.orchestrator, 'williams_structure', None)
            if williams is not None:
                try:
                    pivot_levels = williams.calculate_recursive_pivot_points(ohlcv_data)
                    if pivot_levels:
                        # Keys are assumed to look like '<name>_<n>'; take the highest n.
                        latest_level = max(pivot_levels.keys(), key=lambda x: int(x.split('_')[1]))
                        level_data = pivot_levels[latest_level]

                        return {
                            'trend_direction': level_data.trend_direction,
                            'trend_strength': level_data.trend_strength,
                            'support_levels': level_data.support_levels,
                            'resistance_levels': level_data.resistance_levels
                        }
                except Exception as e:
                    logger.debug(f"Williams pivot analysis failed, using basic pivots: {e}")

            # Fallback to basic pivot analysis
            if len(ohlcv_data) >= 20:
                highs, lows, _, _ = self._ohlcv_columns(ohlcv_data)
                return {
                    'trend_direction': 'SIDEWAYS',
                    'trend_strength': 0.5,
                    'support_levels': [np.min(lows[-20:])],
                    'resistance_levels': [np.max(highs[-20:])]
                }

            return {
                'trend_direction': 'SIDEWAYS',
                'trend_strength': 0.0,
                'support_levels': [],
                'resistance_levels': []
            }

        except Exception as e:
            logger.error(f"Error getting pivot analysis: {e}")
            return {}

    # ------------------------------------------------------------------ #
    # Snapshots and validation
    # ------------------------------------------------------------------ #

    def _create_prediction_snapshot(self, symbol: str, horizon_minutes: int,
                                    prediction_time: datetime, current_price: float,
                                    prediction: 'HorizonPrediction', market_state: Dict[str, Any],
                                    ensemble_components: Optional[int] = None) -> 'PredictionSnapshot':
        """Create a prediction snapshot for future training.

        Args:
            ensemble_components: Number of sources that contributed to
                ``prediction``. When omitted, it is guessed as 3 for an
                ensemble and 1 otherwise.
        """
        prediction_id = f"{symbol.replace('/', '')}_{horizon_minutes}m_{int(prediction_time.timestamp())}"
        target_time = prediction_time + timedelta(minutes=horizon_minutes)

        if ensemble_components is None:
            ensemble_components = 3 if prediction.prediction_basis == 'ensemble' else 1

        return PredictionSnapshot(
            prediction_id=prediction_id,
            symbol=symbol,
            prediction_time=prediction_time,
            target_horizon_minutes=horizon_minutes,
            target_time=target_time,
            current_price=current_price,
            predicted_min_price=prediction.predicted_min,
            predicted_max_price=prediction.predicted_max,
            confidence=prediction.confidence,
            model_inputs=self._extract_model_inputs(market_state),
            market_state=market_state,
            technical_indicators=market_state.get('technical_indicators', {}),
            pivot_analysis=market_state.get('pivot_analysis', {}),
            prediction_metadata={
                'prediction_basis': prediction.prediction_basis,
                'ensemble_components': ensemble_components
            }
        )

    def _extract_model_inputs(self, market_state: Dict[str, Any]) -> Dict[str, Any]:
        """Extract model inputs for future training (``{}`` on error).

        The CNN and RL inputs are always built for the 60m horizon, for
        consistency. Note that the feature builders are currently random
        placeholders.
        """
        try:
            return {
                'cnn_features': self._prepare_cnn_features_for_horizon(market_state, 60),
                'rl_state': self._prepare_rl_state_for_horizon(market_state, 60),
                'current_price': market_state['current_price'],
                # Last 50 candles; np.asarray also accepts a DataFrame.
                'ohlcv_sequence': np.asarray(market_state['ohlcv_data'])[-50:].tolist(),
            }

        except Exception as e:
            logger.error(f"Error extracting model inputs: {e}")
            return {}

    def _validate_pending_predictions(self):
        """Track prices for pending predictions and validate the ones that are due.

        On each call, the current price is sampled once per symbol. For every
        unvalidated snapshot of that symbol, the sample widens the running
        [min, max] of prices seen since the prediction was made (seeded with
        the snapshot's ``current_price``). When ``target_time`` has passed,
        that range becomes the snapshot's actual min/max and the snapshot is
        scored. Prices are only sampled when this method runs (every loop
        pass), so the observed range approximates the true intra-horizon range.
        """
        try:
            current_time = datetime.now()
            price_cache: Dict[str, Optional[float]] = {}
            still_pending = set()

            for snapshots in self.prediction_snapshots.values():
                for snapshot in list(snapshots):
                    if snapshot.outcome_known:
                        continue

                    if snapshot.symbol not in price_cache:
                        price_cache[snapshot.symbol] = self._get_current_price(snapshot.symbol)
                    price = price_cache[snapshot.symbol]
                    if price is None:
                        # No price for this symbol right now; retry on a later pass.
                        still_pending.add(snapshot.prediction_id)
                        continue

                    observed = self._observed_price_ranges.setdefault(
                        snapshot.prediction_id, [snapshot.current_price, snapshot.current_price]
                    )
                    observed[0] = min(observed[0], price)
                    observed[1] = max(observed[1], price)

                    if current_time < snapshot.target_time:
                        still_pending.add(snapshot.prediction_id)
                        continue

                    snapshot.actual_min_price, snapshot.actual_max_price = observed
                    snapshot.outcome_known = True
                    snapshot.outcome_timestamp = current_time
                    self._observed_price_ranges.pop(snapshot.prediction_id, None)

                    self._process_validated_prediction(snapshot)

            # Forget ranges of snapshots that were evicted from their deque unvalidated.
            for stale_id in set(self._observed_price_ranges) - still_pending:
                del self._observed_price_ranges[stale_id]

        except Exception as e:
            logger.error(f"Error validating pending predictions: {e}")

    def _process_validated_prediction(self, snapshot: 'PredictionSnapshot'):
        """Score a validated prediction and update the accuracy statistics.

        A prediction counts as accurate when the predicted and actual ranges
        overlap by more than 50% (see ``_calculate_range_overlap``).
        """
        try:
            self.prediction_stats['validated_predictions'] += 1

            if snapshot.actual_min_price is not None and snapshot.actual_max_price is not None:
                range_overlap = self._calculate_range_overlap(
                    (snapshot.predicted_min_price, snapshot.predicted_max_price),
                    (snapshot.actual_min_price, snapshot.actual_max_price)
                )

                if range_overlap > 0.5:  # 50% overlap threshold
                    self.prediction_stats['accurate_predictions'] += 1

            # Here we would trigger training with the snapshot data
            # For now, just log the result
            accuracy_rate = (self.prediction_stats['accurate_predictions'] /
                             max(1, self.prediction_stats['validated_predictions']))

            logger.info(f"Validated {snapshot.target_horizon_minutes}m prediction for {snapshot.symbol}: "
                        f"confidence={snapshot.confidence:.2f}, accuracy_rate={accuracy_rate:.2f}")

        except Exception as e:
            logger.error(f"Error processing validated prediction: {e}")

    def _calculate_range_overlap(self, range1: Tuple[float, float], range2: Tuple[float, float]) -> float:
        """Return the overlap between two price ranges as a fraction in [0.0, 1.0].

        For two ranges of non-zero width, this is intersection-over-union. A
        zero-width range (a single price) has no length to compare, so it
        counts as 1.0 when it lies inside the other range (bounds inclusive)
        and 0.0 otherwise. Malformed input yields 0.0.
        """
        try:
            min1, max1 = range1
            min2, max2 = range2

            if max1 - min1 == 0:
                return 1.0 if min2 <= min1 <= max2 else 0.0
            if max2 - min2 == 0:
                return 1.0 if min1 <= min2 <= max1 else 0.0

            overlap_min = max(min1, min2)
            overlap_max = min(max1, max2)
            if overlap_max <= overlap_min:
                return 0.0

            union_size = max(max1, max2) - min(min1, min2)
            return (overlap_max - overlap_min) / union_size if union_size > 0 else 0.0

        except Exception:
            return 0.0

    # ------------------------------------------------------------------ #
    # Public accessors
    # ------------------------------------------------------------------ #

    def get_prediction_stats(self) -> Dict[str, Any]:
        """Return a copy of the prediction statistics, including ``accuracy_rate``."""
        stats = dict(self.prediction_stats)
        # Copy the nested dict too, so callers cannot mutate internal counters.
        stats['predictions_by_horizon'] = dict(self.prediction_stats['predictions_by_horizon'])

        validated = stats['validated_predictions']
        stats['accuracy_rate'] = stats['accurate_predictions'] / validated if validated > 0 else 0.0

        return stats

    def get_recent_predictions(self, horizon_minutes: int, limit: int = 10) -> List['PredictionSnapshot']:
        """Return up to ``limit`` most recent snapshots for a horizon, oldest first."""
        snapshots = self.prediction_snapshots.get(horizon_minutes)
        if not snapshots or limit <= 0:
            return []
        return list(snapshots)[-limit:]

    # ------------------------------------------------------------------ #
    # Placeholder methods for CNN and RL feature preparation - to be implemented
    # ------------------------------------------------------------------ #

    def _prepare_cnn_features_for_horizon(self, market_state: Dict[str, Any], horizon: int) -> np.ndarray:
        """Prepare CNN features for a specific horizon (placeholder: random values)."""
        return np.random.rand(50)  # Placeholder

    def _prepare_rl_state_for_horizon(self, market_state: Dict[str, Any], horizon: int) -> np.ndarray:
        """Prepare the RL state for a specific horizon (placeholder: random values)."""
        return np.random.rand(100)  # Placeholder

    def _interpret_cnn_output(self, cnn_output, current_price: float, horizon: int) -> Tuple[float, float, float]:
        """Convert CNN output to (min, max, confidence) (placeholder: fixed +/-5% band)."""
        return (current_price * 0.95, current_price * 1.05, 0.6)  # Placeholder

    def _convert_rl_action_to_price_prediction(self, action: int, current_price: float,
                                               horizon: int, rl_agent) -> Tuple[float, float, float]:
        """Map an RL action to (min, max, confidence) (placeholder fixed bands)."""
        if action == 0:  # BUY
            return (current_price * 0.98, current_price * 1.03, 0.7)
        elif action == 1:  # SELL
            return (current_price * 0.97, current_price * 1.02, 0.7)
        else:  # HOLD
            return (current_price * 0.99, current_price * 1.01, 0.5)