new backtesting feature

commit ebb062bdae
parent 37e90a1c3c
Author: Dobromir Popov
Date:   2025-11-17 19:13:30 +02:00

5 changed files with 1106 additions and 36 deletions

@@ -21,6 +21,10 @@ from typing import Optional, Dict, List, Any
import json
import pandas as pd
import numpy as np
import threading
import uuid
import time
import torch
# Import core components from main system
try:
@@ -94,6 +98,337 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
logger.info(f"Logging to: {log_file}")
class BacktestRunner:
    """Runs backtest candle-by-candle with model predictions and tracks PnL"""

    def __init__(self):
        self.active_backtests = {}  # backtest_id -> state
        self.lock = threading.Lock()

    def start_backtest(self, backtest_id: str, model, data_provider, symbol: str, timeframe: str,
                       start_time: Optional[str] = None, end_time: Optional[str] = None):
        """Start backtest in background thread"""
        # Initialize backtest state
        state = {
            'status': 'running',
            'candles_processed': 0,
            'total_candles': 0,
            'pnl': 0.0,
            'total_trades': 0,
            'wins': 0,
            'losses': 0,
            'new_predictions': [],
            'position': None,  # {'type': 'long'/'short', 'entry_price': float, 'entry_time': str}
            'error': None,
            'stop_requested': False
        }
        with self.lock:
            self.active_backtests[backtest_id] = state

        # Run backtest in background thread
        thread = threading.Thread(
            target=self._run_backtest,
            args=(backtest_id, model, data_provider, symbol, timeframe, start_time, end_time)
        )
        thread.daemon = True
        thread.start()
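
    # Note: the worker runs as a daemon thread, so an in-flight backtest is
    # abandoned if the process exits, and finished runs stay in
    # active_backtests for the life of the process (no eviction).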
    def _run_backtest(self, backtest_id: str, model, data_provider, symbol: str, timeframe: str,
                      start_time: Optional[str] = None, end_time: Optional[str] = None):
        """Execute backtest candle-by-candle"""
        # Look up state before the try block so the except handler can always reach it
        state = self.active_backtests[backtest_id]
        try:
            # Get historical data
            logger.info(f"Backtest {backtest_id}: Fetching data for {symbol} {timeframe}")

            # Get candles for the time range
            if start_time and end_time:
                # Fetch a wide window, then trim it to the requested range
                df = data_provider.get_historical_data(
                    symbol=symbol,
                    timeframe=timeframe,
                    limit=1000  # Max candles
                )
                if df is not None and not df.empty:
                    df = df[(df.index >= pd.to_datetime(start_time)) &
                            (df.index <= pd.to_datetime(end_time))]
            else:
                # Use last 500 candles
                df = data_provider.get_historical_data(
                    symbol=symbol,
                    timeframe=timeframe,
                    limit=500
                )

            if df is None or df.empty:
                state['status'] = 'error'
                state['error'] = 'No data available'
                return

            logger.info(f"Backtest {backtest_id}: Processing {len(df)} candles")
            state['total_candles'] = len(df)

            # Prepare for inference
            model.eval()

            # IMPORTANT: Use CPU for backtest to avoid ROCm/HIP compatibility issues
            # GPU inference has kernel compatibility problems with some model architectures
            device = torch.device('cpu')
            model.to(device)
            logger.info(f"Backtest {backtest_id}: Using CPU for stable inference (avoiding ROCm/HIP issues)")

            # Need at least 200 candles for context
            min_context = 200
            if len(df) <= min_context:
                state['status'] = 'error'
                state['error'] = f'Not enough data: need more than {min_context} candles, got {len(df)}'
                return

            # Process candles one by one
            for i in range(min_context, len(df)):
                if state['stop_requested']:
                    state['status'] = 'stopped'
                    break

                # Get context (last 200 candles)
                context = df.iloc[i - min_context:i]
                current_candle = df.iloc[i]
                current_time = current_candle.name
                current_price = float(current_candle['close'])

                # Make prediction
                prediction = self._make_prediction(model, device, context, symbol, timeframe)

                if prediction:
                    # Store prediction for display
                    pred_data = {
                        'timestamp': str(current_time),
                        'price': current_price,
                        'action': prediction['action'],
                        'confidence': prediction['confidence'],
                        'timeframe': timeframe
                    }
                    state['new_predictions'].append(pred_data)

                    # Execute trade logic
                    self._execute_trade_logic(state, prediction, current_price, current_time)

                # Update progress
                state['candles_processed'] = i - min_context + 1

                # Simulate real-time (optional, remove for faster backtest)
                # time.sleep(0.01)  # 10ms per candle

            # Close any open position at end
            if state['position']:
                self._close_position(state, current_price, 'backtest_end')

            # Calculate final stats (don't overwrite a 'stopped' status)
            total_trades = state['total_trades']
            wins = state['wins']
            state['win_rate'] = wins / total_trades if total_trades > 0 else 0
            if state['status'] != 'stopped':
                state['status'] = 'complete'
            logger.info(f"Backtest {backtest_id}: {state['status']}. PnL=${state['pnl']:.2f}, "
                        f"Trades={total_trades}, Win Rate={state['win_rate']:.1%}")
        except Exception as e:
            logger.error(f"Backtest {backtest_id} error: {e}", exc_info=True)
            state['status'] = 'error'
            state['error'] = str(e)
    def _make_prediction(self, model, device, context_df, symbol, timeframe):
        """Make model prediction on context data"""
        try:
            # Convert context to model input format
            # Extract OHLCV data
            candles = context_df[['open', 'high', 'low', 'close', 'volume']].values

            # Normalize
            candles_normalized = candles.copy()
            price_data = candles[:, :4]
            volume_data = candles[:, 4:5]

            price_min = price_data.min()
            price_max = price_data.max()
            if price_max > price_min:
                candles_normalized[:, :4] = (price_data - price_min) / (price_max - price_min)

            volume_min = volume_data.min()
            volume_max = volume_data.max()
            if volume_max > volume_min:
                candles_normalized[:, 4:5] = (volume_data - volume_min) / (volume_max - volume_min)
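
            # Worked example of the scaling above (illustrative): prices spanning
            # min=100 and max=110 map to 100 -> 0.0, 105 -> 0.5, 110 -> 1.0;
            # volume is scaled the same way but independently, so price and
            # volume each cover [0, 1] within every 200-candle context window.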
            # Convert to tensor [1, 200, 5] on the requested device;
            # fall back to CPU if tensor creation on that device fails
            try:
                price_tensor = torch.tensor(candles_normalized, dtype=torch.float32).unsqueeze(0).to(device)
                tech_data = torch.zeros(1, 40, dtype=torch.float32).to(device)
                market_data = torch.zeros(1, 30, dtype=torch.float32).to(device)
                use_cpu = False
            except Exception as gpu_error:
                logger.warning(f"GPU tensor creation failed, using CPU: {gpu_error}")
                device = torch.device('cpu')
                model.to(device)
                price_tensor = torch.tensor(candles_normalized, dtype=torch.float32).unsqueeze(0)
                tech_data = torch.zeros(1, 40, dtype=torch.float32)
                market_data = torch.zeros(1, 30, dtype=torch.float32)
                use_cpu = True

            # Make prediction
            with torch.no_grad():
                try:
                    outputs = model(
                        price_data_1m=price_tensor if timeframe == '1m' else None,
                        price_data_1s=price_tensor if timeframe == '1s' else None,
                        price_data_1h=price_tensor if timeframe == '1h' else None,
                        price_data_1d=price_tensor if timeframe == '1d' else None,
                        tech_data=tech_data,
                        market_data=market_data
                    )
                except RuntimeError as model_error:
                    # GPU inference failed, retry on CPU
                    if not use_cpu and 'HIP' in str(model_error):
                        logger.warning(f"GPU inference failed, retrying on CPU: {model_error}")
                        device = torch.device('cpu')
                        model.to(device)
                        price_tensor = price_tensor.cpu()
                        tech_data = tech_data.cpu()
                        market_data = market_data.cpu()
                        outputs = model(
                            price_data_1m=price_tensor if timeframe == '1m' else None,
                            price_data_1s=price_tensor if timeframe == '1s' else None,
                            price_data_1h=price_tensor if timeframe == '1h' else None,
                            price_data_1d=price_tensor if timeframe == '1d' else None,
                            tech_data=tech_data,
                            market_data=market_data
                        )
                    else:
                        raise

            # Get action prediction
            action_probs = outputs.get('action_probs', outputs.get('trend_probs'))
            if action_probs is not None:
                action_idx = torch.argmax(action_probs, dim=-1).item()
                confidence = action_probs[0, action_idx].item()

                # Map to BUY/SELL/HOLD
                actions = ['BUY', 'SELL', 'HOLD']
                if action_idx < len(actions):
                    action = actions[action_idx]
                else:
                    # If 4 actions (model has 4 trend directions), map to 3 actions
                    action = 'HOLD' if action_idx == 1 else ('BUY' if action_idx in [2, 3] else 'SELL')

                return {
                    'action': action,
                    'confidence': confidence
                }

            return None
        except Exception as e:
            logger.error(f"Prediction error: {e}", exc_info=True)
            return None
    def _execute_trade_logic(self, state, prediction, current_price, current_time):
        """Execute trading logic based on prediction"""
        action = prediction['action']
        confidence = prediction['confidence']

        # Only trade on high confidence
        if confidence < 0.6:
            return

        position = state['position']

        if action == 'BUY' and position is None:
            # Enter long position
            state['position'] = {
                'type': 'long',
                'entry_price': current_price,
                'entry_time': current_time
            }
            logger.debug(f"Backtest: ENTER LONG @ ${current_price}")
        elif action == 'SELL' and position is None:
            # Enter short position
            state['position'] = {
                'type': 'short',
                'entry_price': current_price,
                'entry_time': current_time
            }
            logger.debug(f"Backtest: ENTER SHORT @ ${current_price}")
        elif position is not None:
            # Check if the signal reverses the open position
            should_exit = False
            if position['type'] == 'long' and action == 'SELL':
                should_exit = True
            elif position['type'] == 'short' and action == 'BUY':
                should_exit = True
            if should_exit:
                self._close_position(state, current_price, 'signal')
    def _close_position(self, state, exit_price, reason):
        """Close current position and update PnL"""
        position = state['position']
        if not position:
            return

        entry_price = position['entry_price']

        # Calculate PnL
        if position['type'] == 'long':
            pnl = exit_price - entry_price
        else:  # short
            pnl = entry_price - exit_price

        # Update state
        state['pnl'] += pnl
        state['total_trades'] += 1
        if pnl > 0:
            state['wins'] += 1
        elif pnl < 0:
            state['losses'] += 1

        logger.debug(f"Backtest: CLOSE {position['type'].upper()} @ ${exit_price:.2f}, PnL=${pnl:.2f} ({reason})")
        state['position'] = None
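
    # Worked example (illustrative): a long entered at $100 and closed at $105
    # yields pnl = 105 - 100 = +5.00; a short entered at $100 and closed at
    # $105 yields pnl = 100 - 105 = -5.00. PnL is per unit of the asset; this
    # runner models no position sizing, fees, or slippage.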
    def get_progress(self, backtest_id: str) -> Dict:
        """Get backtest progress"""
        with self.lock:
            state = self.active_backtests.get(backtest_id)
            if not state:
                return {'success': False, 'error': 'Backtest not found'}

            # Get and clear new predictions (they'll be sent to the frontend once)
            new_predictions = state['new_predictions']
            state['new_predictions'] = []

            return {
                'success': True,
                'status': state['status'],
                'candles_processed': state['candles_processed'],
                'total_candles': state['total_candles'],
                'pnl': state['pnl'],
                'total_trades': state['total_trades'],
                'wins': state['wins'],
                'losses': state['losses'],
                'win_rate': state['wins'] / state['total_trades'] if state['total_trades'] > 0 else 0,
                'new_predictions': new_predictions,
                'error': state['error']
            }

    def stop_backtest(self, backtest_id: str):
        """Request backtest to stop"""
        with self.lock:
            state = self.active_backtests.get(backtest_id)
            if state:
                state['stop_requested'] = True
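
# Minimal driving sketch for BacktestRunner outside the dashboard (illustrative
# only; assumes model and data_provider objects compatible with the calls above,
# and a hypothetical symbol):
#
#     runner = BacktestRunner()
#     bt_id = str(uuid.uuid4())
#     runner.start_backtest(bt_id, model, data_provider, symbol='ETH/USDT', timeframe='1m')
#     while True:
#         progress = runner.get_progress(bt_id)
#         if progress['status'] in ('complete', 'error', 'stopped'):
#             break
#         time.sleep(0.5)
#     print(f"PnL=${progress['pnl']:.2f} over {progress['total_trades']} trades")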
class AnnotationDashboard:
    """Main annotation dashboard application"""

@@ -190,6 +525,8 @@ class AnnotationDashboard:
        self.annotation_manager = AnnotationManager()

        # Use REAL training adapter - NO SIMULATION!
        self.training_adapter = RealTrainingAdapter(None, self.data_provider)

        # Backtest runner for replaying visible chart with predictions
        self.backtest_runner = BacktestRunner()

        # Don't auto-load models - wait for user to click LOAD button
        logger.info("Models available for lazy loading: " + ", ".join(self.available_models))
@@ -1310,6 +1647,89 @@ class AnnotationDashboard:
                    }
                })

        # Backtest API Endpoints
        @self.server.route('/api/backtest', methods=['POST'])
        def start_backtest():
            """Start backtest on visible chart data"""
            try:
                data = request.get_json()
                model_name = data['model_name']
                symbol = data['symbol']
                timeframe = data['timeframe']
                start_time = data.get('start_time')
                end_time = data.get('end_time')

                # Get the loaded model
                if model_name not in self.loaded_models:
                    return jsonify({
                        'success': False,
                        'error': f'Model {model_name} not loaded. Please load it first.'
                    })

                model = self.loaded_models[model_name]

                # Generate backtest ID
                backtest_id = str(uuid.uuid4())

                # Start backtest in background
                self.backtest_runner.start_backtest(
                    backtest_id=backtest_id,
                    model=model,
                    data_provider=self.data_provider,
                    symbol=symbol,
                    timeframe=timeframe,
                    start_time=start_time,
                    end_time=end_time
                )

                # Get initial state
                progress = self.backtest_runner.get_progress(backtest_id)

                return jsonify({
                    'success': True,
                    'backtest_id': backtest_id,
                    'total_candles': progress.get('total_candles', 0)
                })
            except Exception as e:
                logger.error(f"Error starting backtest: {e}", exc_info=True)
                return jsonify({
                    'success': False,
                    'error': str(e)
                })

        @self.server.route('/api/backtest/progress/<backtest_id>', methods=['GET'])
        def get_backtest_progress(backtest_id):
            """Get backtest progress"""
            try:
                progress = self.backtest_runner.get_progress(backtest_id)
                return jsonify(progress)
            except Exception as e:
                logger.error(f"Error getting backtest progress: {e}")
                return jsonify({
                    'success': False,
                    'error': str(e)
                })

        @self.server.route('/api/backtest/stop', methods=['POST'])
        def stop_backtest():
            """Stop running backtest"""
            try:
                data = request.get_json()
                backtest_id = data['backtest_id']
                self.backtest_runner.stop_backtest(backtest_id)
                return jsonify({
                    'success': True
                })
            except Exception as e:
                logger.error(f"Error stopping backtest: {e}")
                return jsonify({
                    'success': False,
                    'error': str(e)
                })

        @self.server.route('/api/active-training', methods=['GET'])
        def get_active_training():
            """

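The three routes above form a small polling API: POST /api/backtest starts a
run, GET /api/backtest/progress/<backtest_id> is polled for incremental
results and new predictions, and POST /api/backtest/stop cancels. A minimal
client sketch of that flow (the base URL, model name, and symbol are
illustrative assumptions, not values taken from this commit):

```python
import time
import requests

BASE = 'http://localhost:8051'  # assumed dashboard address

# Start a backtest (the named model must already be loaded in the dashboard)
resp = requests.post(f'{BASE}/api/backtest', json={
    'model_name': 'Transformer',  # hypothetical model name
    'symbol': 'ETH/USDT',         # hypothetical symbol
    'timeframe': '1m',
}).json()
assert resp['success'], resp.get('error')
backtest_id = resp['backtest_id']

# Poll progress until the run finishes (mirrors the 500 ms polling loop in the JS)
while True:
    progress = requests.get(f'{BASE}/api/backtest/progress/{backtest_id}').json()
    if not progress.get('success'):
        raise RuntimeError(progress.get('error', 'progress lookup failed'))
    if progress['status'] in ('complete', 'error', 'stopped'):
        break
    time.sleep(0.5)

print(f"PnL=${progress['pnl']:.2f}, trades={progress['total_trades']}, "
      f"win rate={progress['win_rate']:.1%}")

# A running backtest can be cancelled early with:
# requests.post(f'{BASE}/api/backtest/stop', json={'backtest_id': backtest_id})
```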

@@ -87,6 +87,32 @@
                </button>
            </div>

            <!-- Backtest on Visible Chart -->
            <div class="mb-3">
                <label class="form-label small">Backtest on Visible Data</label>
                <button class="btn btn-warning btn-sm w-100" id="start-backtest-btn">
                    <i class="fas fa-history"></i>
                    Backtest Visible Chart
                </button>
                <button class="btn btn-danger btn-sm w-100 mt-1" id="stop-backtest-btn" style="display: none;">
                    <i class="fas fa-stop"></i>
                    Stop Backtest
                </button>

                <!-- Backtest Results -->
                <div id="backtest-results" style="display: none;" class="mt-2">
                    <div class="alert alert-success py-2 px-2 mb-0">
                        <strong class="small">Backtest Results</strong>
                        <div class="small mt-1">
                            <div>PnL: <span id="backtest-pnl" class="fw-bold">--</span></div>
                            <div>Trades: <span id="backtest-trades">--</span></div>
                            <div>Win Rate: <span id="backtest-winrate">--</span></div>
                            <div>Progress: <span id="backtest-progress">0</span>/<span id="backtest-total">0</span></div>
                        </div>
                    </div>
                </div>
            </div>
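            <!-- The spans above (backtest-pnl, backtest-trades, backtest-winrate,
                 backtest-progress, backtest-total) are written by the backtest
                 polling script in this template's script block. -->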
            <!-- Multi-Step Inference Control -->
            <div class="mb-3" id="inference-controls" style="display: none;">
                <label for="prediction-steps-slider" class="form-label small text-muted">

@@ -569,6 +595,198 @@
        });
    });
    // Backtest controls
    let currentBacktestId = null;
    let backtestPollInterval = null;
    let backtestMarkers = []; // Store markers to clear later

    document.getElementById('start-backtest-btn').addEventListener('click', function () {
        const modelName = document.getElementById('model-select').value;
        if (!modelName) {
            showError('Please select a model first');
            return;
        }

        // Get current chart state
        const primaryTimeframe = document.getElementById('primary-timeframe-select').value;
        const symbol = appState.currentSymbol;

        // Get visible chart range from the chart (if available)
        const chart = document.getElementById('main-chart');
        let startTime = null;
        let endTime = null;

        // Try to get visible range from the chart's x-axis
        if (chart && chart.layout && chart.layout.xaxis) {
            const xaxis = chart.layout.xaxis;
            if (xaxis.range) {
                startTime = xaxis.range[0];
                endTime = xaxis.range[1];
            }
        }

        // Clear previous backtest markers
        if (backtestMarkers.length > 0) {
            clearBacktestMarkers();
        }

        // Start backtest
        fetch('/api/backtest', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({
                model_name: modelName,
                symbol: symbol,
                timeframe: primaryTimeframe,
                start_time: startTime,
                end_time: endTime
            })
        })
            .then(response => response.json())
            .then(data => {
                if (data.success) {
                    currentBacktestId = data.backtest_id;

                    // Update UI
                    document.getElementById('start-backtest-btn').style.display = 'none';
                    document.getElementById('stop-backtest-btn').style.display = 'block';
                    document.getElementById('backtest-results').style.display = 'block';

                    // Reset results
                    document.getElementById('backtest-pnl').textContent = '$0.00';
                    document.getElementById('backtest-trades').textContent = '0';
                    document.getElementById('backtest-winrate').textContent = '0%';
                    document.getElementById('backtest-progress').textContent = '0';
                    document.getElementById('backtest-total').textContent = data.total_candles || '?';

                    // Start polling for backtest progress
                    startBacktestPolling();
                    showSuccess('Backtest started');
                } else {
                    showError('Failed to start backtest: ' + (data.error || 'Unknown error'));
                }
            })
            .catch(error => {
                showError('Network error: ' + error.message);
            });
    });
    document.getElementById('stop-backtest-btn').addEventListener('click', function () {
        if (!currentBacktestId) return;

        // Stop backtest
        fetch('/api/backtest/stop', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ backtest_id: currentBacktestId })
        })
            .then(response => response.json())
            .then(data => {
                // Update UI
                document.getElementById('start-backtest-btn').style.display = 'block';
                document.getElementById('stop-backtest-btn').style.display = 'none';

                // Stop polling
                stopBacktestPolling();
                currentBacktestId = null;
                showSuccess('Backtest stopped');
            })
            .catch(error => {
                showError('Network error: ' + error.message);
            });
    });
    function startBacktestPolling() {
        if (backtestPollInterval) {
            clearInterval(backtestPollInterval);
        }

        backtestPollInterval = setInterval(() => {
            if (!currentBacktestId) {
                stopBacktestPolling();
                return;
            }

            fetch(`/api/backtest/progress/${currentBacktestId}`)
                .then(response => response.json())
                .then(data => {
                    if (data.success) {
                        updateBacktestUI(data);

                        // If complete, stop polling
                        if (data.status === 'complete' || data.status === 'error') {
                            stopBacktestPolling();
                            document.getElementById('start-backtest-btn').style.display = 'block';
                            document.getElementById('stop-backtest-btn').style.display = 'none';
                            currentBacktestId = null;

                            if (data.status === 'complete') {
                                showSuccess('Backtest complete');
                            } else {
                                showError('Backtest error: ' + (data.error || 'Unknown'));
                            }
                        }
                    }
                })
                .catch(error => {
                    console.error('Backtest polling error:', error);
                });
        }, 500); // Poll every 500ms for backtest progress
    }

    function stopBacktestPolling() {
        if (backtestPollInterval) {
            clearInterval(backtestPollInterval);
            backtestPollInterval = null;
        }
    }
    function updateBacktestUI(data) {
        // Update progress
        document.getElementById('backtest-progress').textContent = data.candles_processed || 0;
        document.getElementById('backtest-total').textContent = data.total_candles || 0;

        // Update PnL (green for profit, red for loss)
        const pnl = data.pnl || 0;
        const pnlElement = document.getElementById('backtest-pnl');
        pnlElement.textContent = `$${pnl.toFixed(2)}`;
        pnlElement.className = pnl >= 0 ? 'fw-bold text-success' : 'fw-bold text-danger';

        // Update trades
        document.getElementById('backtest-trades').textContent = data.total_trades || 0;

        // Update win rate
        const winRate = data.win_rate || 0;
        document.getElementById('backtest-winrate').textContent = `${(winRate * 100).toFixed(1)}%`;

        // Add new predictions to chart
        if (data.new_predictions && data.new_predictions.length > 0) {
            addBacktestMarkersToChart(data.new_predictions);
        }
    }

    function addBacktestMarkersToChart(predictions) {
        // Store markers for later clearing
        predictions.forEach(pred => {
            backtestMarkers.push(pred);
        });

        // Trigger chart update with the full marker set
        if (window.updateBacktestMarkers) {
            window.updateBacktestMarkers(backtestMarkers);
        }
    }

    function clearBacktestMarkers() {
        backtestMarkers = [];
        if (window.clearBacktestMarkers) {
            window.clearBacktestMarkers();
        }
    }
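
    // Note: window.updateBacktestMarkers / window.clearBacktestMarkers are
    // optional hooks the chart module may provide. Each marker has the shape
    // emitted by the backend's new_predictions entries:
    // { timestamp, price, action, confidence, timeframe }.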
    function updatePredictionHistory() {
        const historyDiv = document.getElementById('prediction-history');

        if (predictionHistory.length === 0) {