Add backtesting feature
@@ -21,6 +21,10 @@ from typing import Optional, Dict, List, Any
import json
import pandas as pd
import numpy as np
import threading
import uuid
import time
import torch

# Import core components from main system
try:
@@ -94,6 +98,337 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
logger.info(f"Logging to: {log_file}")


class BacktestRunner:
    """Runs a backtest candle-by-candle with model predictions and tracks PnL"""

    def __init__(self):
        self.active_backtests = {}  # backtest_id -> state
        self.lock = threading.Lock()

    def start_backtest(self, backtest_id: str, model, data_provider, symbol: str, timeframe: str,
                       start_time: Optional[str] = None, end_time: Optional[str] = None):
        """Start a backtest in a background thread"""

        # Initialize backtest state
        state = {
            'status': 'running',
            'candles_processed': 0,
            'total_candles': 0,
            'pnl': 0.0,
            'total_trades': 0,
            'wins': 0,
            'losses': 0,
            'new_predictions': [],
            'position': None,  # {'type': 'long'|'short', 'entry_price': float, 'entry_time': str}
            'error': None,
            'stop_requested': False
        }

        with self.lock:
            self.active_backtests[backtest_id] = state

        # Run the backtest in a background thread
        thread = threading.Thread(
            target=self._run_backtest,
            args=(backtest_id, model, data_provider, symbol, timeframe, start_time, end_time)
        )
        thread.daemon = True
        thread.start()
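        # Daemon thread: if the server process exits, any in-flight backtest is
        # simply abandoned; its state lives only in memory (self.active_backtests).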

    def _run_backtest(self, backtest_id: str, model, data_provider, symbol: str, timeframe: str,
                      start_time: Optional[str] = None, end_time: Optional[str] = None):
        """Execute the backtest candle-by-candle"""
        try:
            state = self.active_backtests[backtest_id]

            # Fetch historical data
            logger.info(f"Backtest {backtest_id}: Fetching data for {symbol} {timeframe}")

            if start_time and end_time:
                # A time range was requested; fetch a larger window. Note the
                # provider is only given a candle limit here, not the start/end
                # bounds themselves.
                df = data_provider.get_historical_data(
                    symbol=symbol,
                    timeframe=timeframe,
                    limit=1000  # Max candles
                )
            else:
                # Default: use the last 500 candles
                df = data_provider.get_historical_data(
                    symbol=symbol,
                    timeframe=timeframe,
                    limit=500
                )

            if df is None or df.empty:
                state['status'] = 'error'
                state['error'] = 'No data available'
                return

            logger.info(f"Backtest {backtest_id}: Processing {len(df)} candles")
            state['total_candles'] = len(df)

            # Prepare for inference
            model.eval()

            # IMPORTANT: Use CPU for the backtest to avoid ROCm/HIP compatibility
            # issues; GPU inference has kernel compatibility problems with some
            # model architectures
            device = torch.device('cpu')
            model.to(device)
            logger.info(f"Backtest {backtest_id}: Using CPU for stable inference (avoiding ROCm/HIP issues)")

            # Need at least 200 candles of context
            min_context = 200
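            # Each step i below predicts from the window df.iloc[i-200:i] and is
            # marked to the close of candle i, so the model never sees the candle
            # it is being evaluated on.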

            # Process candles one by one
            for i in range(min_context, len(df)):
                if state['stop_requested']:
                    state['status'] = 'stopped'
                    break

                # Get context (the last 200 candles)
                context = df.iloc[i-200:i]
                current_candle = df.iloc[i]
                current_time = current_candle.name
                current_price = float(current_candle['close'])

                # Make prediction
                prediction = self._make_prediction(model, device, context, symbol, timeframe)

                if prediction:
                    # Store prediction for display
                    pred_data = {
                        'timestamp': str(current_time),
                        'price': current_price,
                        'action': prediction['action'],
                        'confidence': prediction['confidence'],
                        'timeframe': timeframe
                    }
                    state['new_predictions'].append(pred_data)

                    # Execute trade logic
                    self._execute_trade_logic(state, prediction, current_price, current_time)

                # Update progress
                state['candles_processed'] = i - min_context + 1

                # Simulate real-time pacing (optional; leave commented out for a faster backtest)
                # time.sleep(0.01)  # 10ms per candle

            # Close any open position at the end
            if state['position']:
                self._close_position(state, current_price, 'backtest_end')

            # Calculate final stats
            total_trades = state['total_trades']
            wins = state['wins']
            state['win_rate'] = wins / total_trades if total_trades > 0 else 0

            state['status'] = 'complete'
            logger.info(f"Backtest {backtest_id}: Complete. PnL=${state['pnl']:.2f}, Trades={total_trades}, Win Rate={state['win_rate']:.1%}")

        except Exception as e:
            logger.error(f"Backtest {backtest_id} error: {e}", exc_info=True)
            state['status'] = 'error'
            state['error'] = str(e)

    def _make_prediction(self, model, device, context_df, symbol, timeframe):
        """Make a model prediction on the context data"""
        try:
            # Convert context to model input format: extract the OHLCV columns
            candles = context_df[['open', 'high', 'low', 'close', 'volume']].values

            # Min-max normalize prices and volume independently
            candles_normalized = candles.copy()
            price_data = candles[:, :4]
            volume_data = candles[:, 4:5]

            price_min = price_data.min()
            price_max = price_data.max()
            if price_max > price_min:
                candles_normalized[:, :4] = (price_data - price_min) / (price_max - price_min)

            volume_min = volume_data.min()
            volume_max = volume_data.max()
            if volume_max > volume_min:
                candles_normalized[:, 4:5] = (volume_data - volume_min) / (volume_max - volume_min)
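            # Worked example of the min-max scaling above: closes [100, 110, 105]
            # with min=100, max=110 map to [0.0, 1.0, 0.5]. The min/max come from
            # the 200-candle context only, so scaling is local to each window.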

            # Convert to a tensor of shape [1, 200, 5]. The device is CPU by
            # default (set in _run_backtest), but fall back explicitly in case a
            # GPU device was passed in and tensor creation fails on it.
            try:
                price_tensor = torch.tensor(candles_normalized, dtype=torch.float32).unsqueeze(0).to(device)
                tech_data = torch.zeros(1, 40, dtype=torch.float32).to(device)
                market_data = torch.zeros(1, 30, dtype=torch.float32).to(device)
                use_cpu = False
            except Exception as gpu_error:
                logger.warning(f"GPU tensor creation failed, using CPU: {gpu_error}")
                device = torch.device('cpu')
                model.to(device)
                price_tensor = torch.tensor(candles_normalized, dtype=torch.float32).unsqueeze(0)
                tech_data = torch.zeros(1, 40, dtype=torch.float32)
                market_data = torch.zeros(1, 30, dtype=torch.float32)
                use_cpu = True

            # Make the prediction; only the slot matching the requested timeframe
            # receives the price tensor
            with torch.no_grad():
                try:
                    outputs = model(
                        price_data_1m=price_tensor if timeframe == '1m' else None,
                        price_data_1s=price_tensor if timeframe == '1s' else None,
                        price_data_1h=price_tensor if timeframe == '1h' else None,
                        price_data_1d=price_tensor if timeframe == '1d' else None,
                        tech_data=tech_data,
                        market_data=market_data
                    )
                except RuntimeError as model_error:
                    # GPU inference failed; retry on CPU
                    if not use_cpu and 'HIP' in str(model_error):
                        logger.warning(f"GPU inference failed, retrying on CPU: {model_error}")
                        device = torch.device('cpu')
                        model.to(device)
                        price_tensor = price_tensor.cpu()
                        tech_data = tech_data.cpu()
                        market_data = market_data.cpu()

                        outputs = model(
                            price_data_1m=price_tensor if timeframe == '1m' else None,
                            price_data_1s=price_tensor if timeframe == '1s' else None,
                            price_data_1h=price_tensor if timeframe == '1h' else None,
                            price_data_1d=price_tensor if timeframe == '1d' else None,
                            tech_data=tech_data,
                            market_data=market_data
                        )
                    else:
                        raise

            # Get the action prediction
            action_probs = outputs.get('action_probs', outputs.get('trend_probs'))
            if action_probs is not None:
                action_idx = torch.argmax(action_probs, dim=-1).item()
                confidence = action_probs[0, action_idx].item()

                # Map to BUY/SELL/HOLD
                actions = ['BUY', 'SELL', 'HOLD']
                if action_idx < len(actions):
                    action = actions[action_idx]
                else:
                    # 4-class case (model has 4 trend directions): map to 3 actions
                    action = 'HOLD' if action_idx == 1 else ('BUY' if action_idx in [2, 3] else 'SELL')
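                    # NOTE: indices 0-2 are consumed by the 3-action list above,
                    # so only index 3 actually reaches this expression (mapped to
                    # BUY); the action_idx == 1 and == 2 cases here are dead.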

                return {
                    'action': action,
                    'confidence': confidence
                }

            return None

        except Exception as e:
            logger.error(f"Prediction error: {e}", exc_info=True)
            return None

    def _execute_trade_logic(self, state, prediction, current_price, current_time):
        """Execute trading logic based on a prediction"""
        action = prediction['action']
        confidence = prediction['confidence']

        # Only trade on high confidence
        if confidence < 0.6:
            return

        position = state['position']

        if action == 'BUY' and position is None:
            # Enter a long position
            state['position'] = {
                'type': 'long',
                'entry_price': current_price,
                'entry_time': current_time
            }
            logger.debug(f"Backtest: ENTER LONG @ ${current_price}")

        elif action == 'SELL' and position is None:
            # Enter a short position
            state['position'] = {
                'type': 'short',
                'entry_price': current_price,
                'entry_time': current_time
            }
            logger.debug(f"Backtest: ENTER SHORT @ ${current_price}")

        elif position is not None:
            # Check whether to exit: an opposing signal closes the position
            should_exit = False

            if position['type'] == 'long' and action == 'SELL':
                should_exit = True
            elif position['type'] == 'short' and action == 'BUY':
                should_exit = True

            if should_exit:
                self._close_position(state, current_price, 'signal')

    def _close_position(self, state, exit_price, reason):
        """Close the current position and update PnL"""
        position = state['position']
        if not position:
            return

        entry_price = position['entry_price']

        # Calculate PnL
        if position['type'] == 'long':
            pnl = exit_price - entry_price
        else:  # short
            pnl = entry_price - exit_price
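        # PnL here is the raw price difference, i.e. a position size of 1 unit
        # with no fees or slippage. Example: a long from 3000.00 to 3015.50
        # gives +15.50; a short over the same move gives -15.50.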

        # Update state
        state['pnl'] += pnl
        state['total_trades'] += 1

        if pnl > 0:
            state['wins'] += 1
        elif pnl < 0:
            state['losses'] += 1

        logger.debug(f"Backtest: CLOSE {position['type'].upper()} @ ${exit_price:.2f}, PnL=${pnl:.2f} ({reason})")

        state['position'] = None

    def get_progress(self, backtest_id: str) -> Dict:
        """Get backtest progress"""
        with self.lock:
            state = self.active_backtests.get(backtest_id)
            if not state:
                return {'success': False, 'error': 'Backtest not found'}

            # Get and clear new predictions (they'll be sent to the frontend)
            new_predictions = state['new_predictions']
            state['new_predictions'] = []

            return {
                'success': True,
                'status': state['status'],
                'candles_processed': state['candles_processed'],
                'total_candles': state['total_candles'],
                'pnl': state['pnl'],
                'total_trades': state['total_trades'],
                'wins': state['wins'],
                'losses': state['losses'],
                'win_rate': state['wins'] / state['total_trades'] if state['total_trades'] > 0 else 0,
                'new_predictions': new_predictions,
                'error': state['error']
            }

    def stop_backtest(self, backtest_id: str):
        """Request that a backtest stop"""
        with self.lock:
            state = self.active_backtests.get(backtest_id)
            if state:
                state['stop_requested'] = True
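
# Usage sketch (illustrative, not part of this commit): driving BacktestRunner
# directly, assuming a loaded `model` and `data_provider` like the ones the
# dashboard holds. Symbol/timeframe values below are placeholders.
#
#   runner = BacktestRunner()
#   bt_id = str(uuid.uuid4())
#   runner.start_backtest(bt_id, model, data_provider, symbol='ETH/USDT', timeframe='1m')
#   while True:
#       progress = runner.get_progress(bt_id)
#       if not progress['success'] or progress['status'] in ('complete', 'error', 'stopped'):
#           break
#       time.sleep(0.5)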


class AnnotationDashboard:
    """Main annotation dashboard application"""

@@ -190,6 +525,8 @@ class AnnotationDashboard:
        self.annotation_manager = AnnotationManager()
        # Use REAL training adapter - NO SIMULATION!
        self.training_adapter = RealTrainingAdapter(None, self.data_provider)
        # Backtest runner for replaying the visible chart with predictions
        self.backtest_runner = BacktestRunner()

        # Don't auto-load models - wait for the user to click the LOAD button
        logger.info("Models available for lazy loading: " + ", ".join(self.available_models))

@@ -1310,6 +1647,89 @@ class AnnotationDashboard:
                }
            })

        # Backtest API endpoints
        @self.server.route('/api/backtest', methods=['POST'])
        def start_backtest():
            """Start a backtest on the visible chart data"""
            try:
                data = request.get_json()
                model_name = data['model_name']
                symbol = data['symbol']
                timeframe = data['timeframe']
                start_time = data.get('start_time')
                end_time = data.get('end_time')

                # Get the loaded model
                if model_name not in self.loaded_models:
                    return jsonify({
                        'success': False,
                        'error': f'Model {model_name} not loaded. Please load it first.'
                    })

                model = self.loaded_models[model_name]

                # Generate a backtest ID
                backtest_id = str(uuid.uuid4())

                # Start the backtest in the background
                self.backtest_runner.start_backtest(
                    backtest_id=backtest_id,
                    model=model,
                    data_provider=self.data_provider,
                    symbol=symbol,
                    timeframe=timeframe,
                    start_time=start_time,
                    end_time=end_time
                )

                # Get the initial state
                progress = self.backtest_runner.get_progress(backtest_id)

                return jsonify({
                    'success': True,
                    'backtest_id': backtest_id,
                    'total_candles': progress.get('total_candles', 0)
                })

            except Exception as e:
                logger.error(f"Error starting backtest: {e}", exc_info=True)
                return jsonify({
                    'success': False,
                    'error': str(e)
                })

        @self.server.route('/api/backtest/progress/<backtest_id>', methods=['GET'])
        def get_backtest_progress(backtest_id):
            """Get backtest progress"""
            try:
                progress = self.backtest_runner.get_progress(backtest_id)
                return jsonify(progress)
            except Exception as e:
                logger.error(f"Error getting backtest progress: {e}")
                return jsonify({
                    'success': False,
                    'error': str(e)
                })

        @self.server.route('/api/backtest/stop', methods=['POST'])
        def stop_backtest():
            """Stop a running backtest"""
            try:
                data = request.get_json()
                backtest_id = data['backtest_id']

                self.backtest_runner.stop_backtest(backtest_id)

                return jsonify({
                    'success': True
                })
            except Exception as e:
                logger.error(f"Error stopping backtest: {e}")
                return jsonify({
                    'success': False,
                    'error': str(e)
                })
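        # Client flow sketch (illustrative): the frontend is expected to
        #   1. POST /api/backtest with {"model_name", "symbol", "timeframe"}
        #      and keep the returned backtest_id;
        #   2. poll GET /api/backtest/progress/<backtest_id>, draining
        #      new_predictions on each poll, until status is 'complete',
        #      'error', or 'stopped';
        #   3. optionally POST /api/backtest/stop with {"backtest_id"} to end early.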

        @self.server.route('/api/active-training', methods=['GET'])
        def get_active_training():
            """