#!/usr/bin/env python3
"""
Multi-Horizon Backtesting Framework

This module provides backtesting capabilities for the multi-horizon prediction
system, replaying historical data to validate prediction accuracy.
"""

import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from .data_provider import DataProvider
from .multi_horizon_prediction_manager import MultiHorizonPredictionManager, PredictionSnapshot

logger = logging.getLogger(__name__)

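# Input data assumption (inferred from the accessors below, not stated in the
# original): a pandas DataFrame of 1-minute OHLCV bars with 'high', 'low',
# 'close', and 'volume' columns, timestamped either via a DatetimeIndex or a
# 'timestamp' column.
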
class MultiHorizonBacktester:
    """Backtesting framework for multi-horizon predictions"""

    def __init__(self, data_provider: Optional[DataProvider] = None):
        """Initialize the backtester"""
        self.data_provider = data_provider

        # Backtesting configuration
        self.horizons = [1, 5, 15, 60]  # minutes
        self.prediction_interval_minutes = 1  # generate predictions every minute
        self.min_data_points = 100  # minimum data points needed for backtesting

        # Results storage
        self.backtest_results = {}

        logger.info("MultiHorizonBacktester initialized")

    def run_backtest(self, symbol: str, start_date: datetime, end_date: datetime,
                     cache_dir: str = "cache") -> Dict[str, Any]:
        """Run a backtest for a symbol over a date range"""
        try:
            logger.info(f"Starting backtest for {symbol} from {start_date} to {end_date}")

            # Get historical data
            historical_data = self._load_historical_data(symbol, start_date, end_date, cache_dir)
            if historical_data is None or len(historical_data) < self.min_data_points:
                return {'error': 'Insufficient historical data'}

            # Run the backtest simulation; propagate simulation errors directly
            # rather than risking a KeyError on 'total_predictions' below
            results = self._simulate_predictions(historical_data, symbol)
            if 'error' in results:
                return results

            # Store results
            backtest_id = f"{symbol.replace('/', '_')}_{start_date.strftime('%Y%m%d')}_{end_date.strftime('%Y%m%d')}"
            self.backtest_results[backtest_id] = {
                'symbol': symbol,
                'start_date': start_date,
                'end_date': end_date,
                'total_predictions': results['total_predictions'],
                'results': results
            }

            logger.info(f"Backtest completed: {results['total_predictions']} predictions evaluated")
            return results

        except Exception as e:
            logger.error(f"Error running backtest: {e}")
            return {'error': str(e)}

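    # Shape of a successful run_backtest() result, as assembled by
    # _simulate_predictions() below (listed here for reference):
    # {
    #     'total_predictions': int,
    #     'horizon_results': {horizon_minutes: {'predictions', 'accurate',
    #                                           'accuracy', 'avg_range_error',
    #                                           'avg_confidence', ...}},
    #     'overall_accuracy': float,
    #     'avg_confidence': float,
    #     'profitability_analysis': {...}
    # }
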
    def _load_historical_data(self, symbol: str, start_date: datetime,
                              end_date: datetime, cache_dir: str) -> Optional[pd.DataFrame]:
        """Load historical data for backtesting"""
        try:
            # Load from the data provider (use available cached data)
            if self.data_provider:
                # Get 1-minute data
                data = self.data_provider.get_historical_data(
                    symbol=symbol,
                    timeframe='1m',
                    limit=50000  # get a large amount of recent data
                )

                if data is not None and len(data) >= self.min_data_points:
                    # Filter to the date range if the data has timestamps
                    if isinstance(data.index, pd.DatetimeIndex):
                        data = data[(data.index >= start_date) & (data.index <= end_date)]

                    # Ensure we still have enough data after filtering
                    if len(data) >= self.min_data_points:
                        logger.info(f"Loaded {len(data)} historical records for backtesting")
                        return data

            # Fallback: try to load from existing cache files
            cache_path = Path(cache_dir) / f"{symbol.replace('/', '_')}_1m.parquet"
            if cache_path.exists():
                df = pd.read_parquet(cache_path)
                if len(df) >= self.min_data_points:
                    logger.info(f"Loaded {len(df)} historical records from cache")
                    return df

            logger.warning(f"No historical data available for {symbol} (need at least {self.min_data_points} points)")
            return None

        except Exception as e:
            logger.error(f"Error loading historical data: {e}")
            return None

    def _simulate_predictions(self, historical_data: pd.DataFrame, symbol: str) -> Dict[str, Any]:
        """Simulate predictions over historical data"""
        try:
            results = {
                'total_predictions': 0,
                'horizon_results': {},
                'overall_accuracy': 0.0,
                'avg_confidence': 0.0,
                'profitability_analysis': {}
            }

            # Sort chronologically; the data may carry timestamps in the index
            # or in a 'timestamp' column depending on its source
            if isinstance(historical_data.index, pd.DatetimeIndex):
                historical_data = historical_data.sort_index()
            elif 'timestamp' in historical_data.columns:
                historical_data = historical_data.sort_values('timestamp').reset_index(drop=True)

            # Process data in chunks for memory efficiency
            chunk_size = 1000
            all_predictions = []

            for i in range(0, len(historical_data) - max(self.horizons) - 1, self.prediction_interval_minutes):
                chunk_end = min(i + chunk_size, len(historical_data))

                # Generate predictions for this time point
                predictions = self._generate_historical_predictions(
                    historical_data.iloc[i:chunk_end], i, symbol
                )
                all_predictions.extend(predictions)

                # Process the predictions that can be validated against future bars
                validated_predictions = self._validate_predictions(predictions, historical_data, i)

                # Accumulate per-horizon results
                for pred in validated_predictions:
                    horizon = pred['target_horizon_minutes']
                    if horizon not in results['horizon_results']:
                        results['horizon_results'][horizon] = {
                            'predictions': 0,
                            'accurate': 0,
                            'total_error': 0.0,
                            'avg_confidence': 0.0,
                            'confidence_accuracy_correlation': 0.0  # placeholder, not yet computed
                        }

                    results['horizon_results'][horizon]['predictions'] += 1
                    if pred['accurate']:
                        results['horizon_results'][horizon]['accurate'] += 1

                    results['horizon_results'][horizon]['total_error'] += pred['range_error']
                    results['horizon_results'][horizon]['avg_confidence'] += pred['confidence']

            # Calculate final metrics
            total_accurate = 0
            total_predictions = 0
            total_confidence = 0.0

            for horizon, h_results in results['horizon_results'].items():
                if h_results['predictions'] > 0:
                    h_results['accuracy'] = h_results['accurate'] / h_results['predictions']
                    h_results['avg_range_error'] = h_results['total_error'] / h_results['predictions']
                    h_results['avg_confidence'] = h_results['avg_confidence'] / h_results['predictions']

                    total_accurate += h_results['accurate']
                    total_predictions += h_results['predictions']
                    total_confidence += h_results['avg_confidence'] * h_results['predictions']

            results['total_predictions'] = total_predictions
            results['overall_accuracy'] = total_accurate / total_predictions if total_predictions > 0 else 0.0
            results['avg_confidence'] = total_confidence / total_predictions if total_predictions > 0 else 0.0

            # Analyze profitability
            results['profitability_analysis'] = self._analyze_profitability(all_predictions)

            return results

        except Exception as e:
            logger.error(f"Error simulating predictions: {e}")
            return {'error': str(e)}

    def _generate_historical_predictions(self, data_chunk: pd.DataFrame,
                                         start_idx: int, symbol: str) -> List[Dict[str, Any]]:
        """Generate predictions for a historical data chunk"""
        try:
            predictions = []

            # Use the first data point of the chunk as the prediction starting point
            if len(data_chunk) < 10:  # need some history
                return predictions

            current_row = data_chunk.iloc[0]
            current_price = current_row['close']
            # Use the DataFrame index for the timestamp if available, otherwise the current time
            if isinstance(data_chunk.index, pd.DatetimeIndex):
                current_time = data_chunk.index[0]
            else:
                current_time = datetime.now()

            # Calculate technical indicators
            tech_indicators = self._calculate_technical_indicators(data_chunk)

            # Generate predictions for each horizon
            for horizon in self.horizons:
                try:
                    # Check that the chunk contains enough future bars; the horizon
                    # is relative to the start of the chunk, not to start_idx
                    if horizon >= len(data_chunk):
                        continue

                    # Get the actual future price range (inclusive of the current
                    # bar, matching _validate_predictions)
                    future_data = data_chunk.iloc[:horizon + 1]
                    actual_min = future_data['low'].min()
                    actual_max = future_data['high'].max()

                    # Generate a prediction using technical analysis (simplified model)
                    predicted_min, predicted_max, confidence = self._predict_price_range(
                        current_price, tech_indicators, horizon
                    )

                    prediction = {
                        'prediction_id': f"backtest_{symbol}_{start_idx}_{horizon}m",
                        'symbol': symbol,
                        'prediction_time': current_time,
                        'target_horizon_minutes': horizon,
                        'target_time': current_time + timedelta(minutes=horizon),
                        'current_price': current_price,
                        'predicted_min_price': predicted_min,
                        'predicted_max_price': predicted_max,
                        'confidence': confidence,
                        'actual_min_price': actual_min,
                        'actual_max_price': actual_max,
                        'accurate': False,  # set during validation
                        'range_error': 0.0  # calculated during validation
                    }

                    predictions.append(prediction)

                except Exception as e:
                    logger.debug(f"Error generating prediction for horizon {horizon}: {e}")

            return predictions

        except Exception as e:
            logger.error(f"Error generating historical predictions: {e}")
            return []

    def _calculate_technical_indicators(self, data: pd.DataFrame) -> Dict[str, Any]:
        """Calculate technical indicators for prediction"""
        try:
            closes = data['close'].values
            highs = data['high'].values
            lows = data['low'].values
            volumes = data['volume'].values

            # Simple moving averages
            if len(closes) >= 20:
                sma_5 = np.mean(closes[-5:])
                sma_20 = np.mean(closes[-20:])
            else:
                sma_5 = np.mean(closes)
                sma_20 = np.mean(closes)

            # RSI (simple-average variant over the last `period` changes,
            # not Wilder's smoothed RSI)
            def calculate_rsi(prices, period=14):
                if len(prices) < period + 1:
                    return 50.0
                gains = []
                losses = []
                for i in range(1, min(len(prices), period + 1)):
                    change = prices[-i] - prices[-i - 1]
                    if change > 0:
                        gains.append(change)
                        losses.append(0)
                    else:
                        gains.append(0)
                        losses.append(abs(change))
                avg_gain = np.mean(gains) if gains else 0
                avg_loss = np.mean(losses) if losses else 0
                if avg_loss == 0:
                    return 100.0
                rs = avg_gain / avg_loss
                return 100 - (100 / (1 + rs))

            rsi = calculate_rsi(closes)

            # Volatility (standard deviation of per-bar returns)
            returns = np.diff(closes) / closes[:-1]
            volatility = np.std(returns) if len(returns) > 0 else 0.02

            # Trend: signed slope of a linear fit over the last 10 closes,
            # normalized by price. The sign must be preserved so that
            # _predict_price_range can distinguish uptrends from downtrends.
            if len(closes) >= 10:
                recent_trend = np.polyfit(range(10), closes[-10:], 1)[0]
                trend_strength = recent_trend / np.mean(closes[-10:])
            else:
                trend_strength = 0.0

            return {
                'sma_5': float(sma_5),
                'sma_20': float(sma_20),
                'rsi': float(rsi),
                'volatility': float(volatility),
                'trend_strength': float(trend_strength),
                'price_change_5m': float((closes[-1] - closes[-5]) / closes[-5]) if len(closes) >= 5 else 0.0
            }

        except Exception as e:
            logger.error(f"Error calculating technical indicators: {e}")
            return {}

    def _predict_price_range(self, current_price: float, tech_indicators: Dict[str, Any],
                             horizon: int) -> Tuple[float, float, float]:
        """Predict price range using technical analysis"""
        try:
            volatility = tech_indicators.get('volatility', 0.02)
            trend_strength = tech_indicators.get('trend_strength', 0.0)
            rsi = tech_indicators.get('rsi', 50.0)

            # Base the range on volatility and horizon, scaled by sqrt(time)
            expected_range_percent = volatility * np.sqrt(horizon / 60.0)

            # Adjust for trend
            if trend_strength > 0.001:  # uptrend: skew the band upward
                range_center = current_price * (1 + trend_strength * horizon / 60.0)
                predicted_min = range_center * (1 - expected_range_percent * 0.7)
                predicted_max = range_center * (1 + expected_range_percent * 1.3)
            elif trend_strength < -0.001:  # downtrend: skew the band downward
                range_center = current_price * (1 + trend_strength * horizon / 60.0)
                predicted_min = range_center * (1 - expected_range_percent * 1.3)
                predicted_max = range_center * (1 + expected_range_percent * 0.7)
            else:  # sideways: symmetric band
                predicted_min = current_price * (1 - expected_range_percent)
                predicted_max = current_price * (1 + expected_range_percent)

            # Adjust confidence based on indicators
            base_confidence = 0.5

            # Higher confidence with a clear trend
            if abs(trend_strength) > 0.002:
                base_confidence += 0.2

            # Lower confidence for extreme RSI
            if rsi > 70 or rsi < 30:
                base_confidence -= 0.1

            # Reduce confidence for longer horizons
            horizon_factor = max(0.3, 1.0 - (horizon - 1) / 120.0)
            confidence = base_confidence * horizon_factor

            confidence = np.clip(confidence, 0.1, 0.9)

            return predicted_min, predicted_max, confidence

        except Exception as e:
            logger.error(f"Error predicting price range: {e}")
            # Fallback prediction
            range_percent = 0.05
            return (current_price * (1 - range_percent),
                    current_price * (1 + range_percent),
                    0.3)

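    # Worked example of the sqrt-of-time scaling above (illustrative numbers
    # only): with per-bar volatility 0.02 and a 15-minute horizon,
    # expected_range_percent = 0.02 * sqrt(15 / 60) = 0.01, so a sideways
    # market yields a band of roughly ±1% around the current price.
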
    def _validate_predictions(self, predictions: List[Dict[str, Any]],
                              historical_data: pd.DataFrame, start_idx: int) -> List[Dict[str, Any]]:
        """Validate predictions against actual historical data"""
        try:
            validated = []

            for prediction in predictions:
                try:
                    horizon = prediction['target_horizon_minutes']

                    # Check that we have enough future data
                    if start_idx + horizon >= len(historical_data):
                        continue

                    # Get the actual price range over the prediction horizon
                    future_data = historical_data.iloc[start_idx:start_idx + horizon + 1]
                    actual_min = future_data['low'].min()
                    actual_max = future_data['high'].max()

                    prediction['actual_min_price'] = actual_min
                    prediction['actual_max_price'] = actual_max

                    # Calculate accuracy metrics
                    range_overlap = self._calculate_range_overlap(
                        (prediction['predicted_min_price'], prediction['predicted_max_price']),
                        (actual_min, actual_max)
                    )

                    # Range error (normalized by the actual range)
                    predicted_range = prediction['predicted_max_price'] - prediction['predicted_min_price']
                    actual_range = actual_max - actual_min
                    range_error = abs(predicted_range - actual_range) / actual_range if actual_range > 0 else 1.0

                    prediction['accurate'] = range_overlap > 0.5  # 50% overlap threshold
                    prediction['range_error'] = range_error
                    prediction['range_overlap'] = range_overlap

                    validated.append(prediction)

                except Exception as e:
                    logger.debug(f"Error validating prediction: {e}")

            return validated

        except Exception as e:
            logger.error(f"Error validating predictions: {e}")
            return []

    def _calculate_range_overlap(self, range1: Tuple[float, float], range2: Tuple[float, float]) -> float:
        """Calculate overlap between two price ranges (intersection over union)"""
        try:
            min1, max1 = range1
            min2, max2 = range2

            overlap_min = max(min1, min2)
            overlap_max = min(max1, max2)

            if overlap_max <= overlap_min:
                return 0.0

            overlap_size = overlap_max - overlap_min
            union_size = max(max1, max2) - min(min1, min2)

            return overlap_size / union_size if union_size > 0 else 0.0

        except Exception:
            return 0.0

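    # Example: predicted range (100, 110) vs actual range (105, 120) intersect
    # on (105, 110) and span (100, 120), so the overlap score is 5 / 20 = 0.25.
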
    def _analyze_profitability(self, predictions: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Analyze profitability of predictions"""
        try:
            analysis = {
                'total_trades': 0,
                'profitable_trades': 0,
                'total_return': 0.0,
                'avg_return_per_trade': 0.0,
                'win_rate': 0.0,
                'confidence_win_rate_correlation': 0.0  # placeholder, not yet computed
            }

            if not predictions:
                return analysis

            # Simulate trades based on predictions. Note that only predictions
            # already validated as accurate are traded, so the resulting returns
            # are optimistic (the filter uses future information).
            trades = []

            for pred in predictions:
                if not pred.get('accurate', False):
                    continue

                # Simple strategy: buy if the predicted range center is above
                # the current price, sell otherwise
                predicted_center = (pred['predicted_min_price'] + pred['predicted_max_price']) / 2
                actual_center = (pred['actual_min_price'] + pred['actual_max_price']) / 2

                if predicted_center > pred['current_price']:
                    # Buy prediction: profit if price rises to the actual center
                    entry_price = pred['current_price']
                    exit_price = actual_center
                    trade_return = (exit_price - entry_price) / entry_price
                else:
                    # Sell prediction: profit if price falls to the actual center
                    entry_price = pred['current_price']
                    exit_price = actual_center
                    trade_return = (entry_price - exit_price) / entry_price

                trades.append({
                    'return': trade_return,
                    'confidence': pred['confidence'],
                    'profitable': trade_return > 0
                })

            if trades:
                analysis['total_trades'] = len(trades)
                analysis['profitable_trades'] = sum(1 for t in trades if t['profitable'])
                analysis['total_return'] = sum(t['return'] for t in trades)
                analysis['avg_return_per_trade'] = analysis['total_return'] / len(trades)
                analysis['win_rate'] = analysis['profitable_trades'] / len(trades)

            return analysis

        except Exception as e:
            logger.error(f"Error analyzing profitability: {e}")
            return {'error': str(e)}

    def get_backtest_results(self, backtest_id: Optional[str] = None) -> Dict[str, Any]:
        """Get backtest results"""
        if backtest_id:
            return self.backtest_results.get(backtest_id, {})

        return self.backtest_results

    def save_results(self, output_dir: str = "reports"):
        """Save backtest results to files"""
        try:
            output_path = Path(output_dir)
            output_path.mkdir(parents=True, exist_ok=True)

            for backtest_id, results in self.backtest_results.items():
                file_path = output_path / f"backtest_{backtest_id}.json"

                with open(file_path, 'w') as f:
                    # default=str serializes the datetime values
                    json.dump(results, f, indent=2, default=str)

                logger.info(f"Saved backtest results to {file_path}")

        except Exception as e:
            logger.error(f"Error saving backtest results: {e}")

    def generate_report(self, backtest_id: str) -> str:
        """Generate a human-readable report for a backtest"""
        try:
            if backtest_id not in self.backtest_results:
                return f"Backtest {backtest_id} not found"

            results = self.backtest_results[backtest_id]

            report = f"""
Multi-Horizon Prediction Backtest Report
========================================

Symbol: {results['symbol']}
Period: {results['start_date']} to {results['end_date']}
Total Predictions: {results['total_predictions']}

Overall Performance:
- Accuracy: {results['results'].get('overall_accuracy', 0):.2%}
- Average Confidence: {results['results'].get('avg_confidence', 0):.2%}

Horizon Performance:
"""

            for horizon, h_results in results['results'].get('horizon_results', {}).items():
                report += f"""
{horizon}min Horizon:
- Predictions: {h_results['predictions']}
- Accuracy: {h_results.get('accuracy', 0):.2%}
- Avg Range Error: {h_results.get('avg_range_error', 0):.4f}
- Avg Confidence: {h_results.get('avg_confidence', 0):.2%}
"""

            # Profitability analysis
            profit_analysis = results['results'].get('profitability_analysis', {})
            if profit_analysis:
                report += f"""
Profitability Analysis:
- Total Simulated Trades: {profit_analysis.get('total_trades', 0)}
- Win Rate: {profit_analysis.get('win_rate', 0):.2%}
- Total Return: {profit_analysis.get('total_return', 0):.4f}
- Avg Return per Trade: {profit_analysis.get('avg_return_per_trade', 0):.4f}
"""

            return report

        except Exception as e:
            logger.error(f"Error generating report: {e}")
            return f"Error generating report: {e}"
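

# Minimal usage sketch (an illustration, not part of the original module): runs
# a one-week backtest against cached parquet data only, so no live DataProvider
# is required. The symbol, date range, and paths are assumptions. Because of the
# relative imports above, run this as a module, e.g.
# `python -m <package>.multi_horizon_backtester`.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    backtester = MultiHorizonBacktester(data_provider=None)  # falls back to cache files
    results = backtester.run_backtest(
        symbol="ETH/USDT",  # assumes cache/ETH_USDT_1m.parquet exists
        start_date=datetime.now() - timedelta(days=7),
        end_date=datetime.now(),
        cache_dir="cache",
    )

    if 'error' in results:
        print(f"Backtest failed: {results['error']}")
    else:
        backtest_id = next(iter(backtester.backtest_results))
        print(backtester.generate_report(backtest_id))
        backtester.save_results("reports")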