This commit is contained in:
Dobromir Popov
2025-11-23 02:16:34 +02:00
parent 24aeefda9d
commit 53ce4a355a
8 changed files with 1088 additions and 155 deletions

View File

@@ -17,7 +17,7 @@ from flask import Flask, render_template, request, jsonify, send_file
from dash import Dash, html
import logging
from datetime import datetime, timezone, timedelta
from typing import Optional, Dict, List, Any
from typing import Optional, Dict, List, Any, Tuple
import json
import pandas as pd
import numpy as np
@@ -370,8 +370,8 @@ class BacktestRunner:
action_idx = torch.argmax(action_probs, dim=-1).item()
confidence = action_probs[0, action_idx].item()
# Map to BUY/SELL/HOLD
actions = ['BUY', 'SELL', 'HOLD']
# Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
actions = ['HOLD', 'BUY', 'SELL']
if action_idx < len(actions):
action = actions[action_idx]
else:
@@ -490,6 +490,194 @@ class BacktestRunner:
state['stop_requested'] = True
class TrainingStrategyManager:
    """
    Manages training strategies and decisions - separates business logic from model interface.

    Training Modes:
    - 'none': No training (inference only)
    - 'every_candle': Train on every completed candle
    - 'pivots_only': Train only on pivot points (BUY at L pivots, SELL at H pivots)
    - 'manual': Training triggered manually by user button
    """

    # Relative close-to-close change (0.0005 == 0.05%) that must be exceeded
    # before a non-pivot candle is labelled BUY (up) or SELL (down).
    MOMENTUM_THRESHOLD = 0.0005

    def __init__(self, data_provider, training_adapter):
        """
        Args:
            data_provider: Object exposing get_historical_data(symbol, timeframe, limit=...)
            training_adapter: Training adapter instance (stored, not called by this class)
        """
        self.data_provider = data_provider
        self.training_adapter = training_adapter
        self.mode = 'none'  # Default: no training
        self.dashboard = None  # Set by dashboard after initialization

        # Statistics tracking.
        # NOTE(review): 'profitable' is initialised here but never updated in this class.
        self.stats = {
            'total_trained': 0,
            'by_action': {'BUY': 0, 'SELL': 0, 'HOLD': 0},
            'profitable': 0
        }

    def should_train_on_candle(self, symbol: str, timeframe: str, candle_timestamp,
                               pivot_markers: Optional[Dict] = None) -> Tuple[bool, Optional[Dict]]:
        """
        Decide if we should train on this candle based on current mode.

        Args:
            symbol: Trading symbol
            timeframe: Candle timeframe
            candle_timestamp: Timestamp of the candle
            pivot_markers: Dict of pivot markers (timestamp -> pivot data)

        Returns:
            Tuple of (should_train: bool, action_data: Optional[Dict]).
            action_data always carries 'action' ('BUY'/'SELL'/'HOLD') and 'reason';
            pivot results add 'pivot_level' and 'pivot_strength', price-movement
            results add 'change_pct'.
        """
        if self.mode == 'every_candle':
            # Train on every candle - action comes from pivots or price movement
            action_data = self._get_action_for_candle(symbol, timeframe, candle_timestamp, pivot_markers)
            return True, action_data

        if self.mode == 'pivots_only':
            # Train only on pivot candles
            return self._is_pivot_candle(candle_timestamp, pivot_markers)

        # 'none', 'manual' (never auto-trains) and any unknown mode
        return False, None

    def _get_action_for_candle(self, symbol: str, timeframe: str, candle_timestamp,
                               pivot_markers: Optional[Dict] = None) -> Dict:
        """
        Determine action for any candle (pivot or non-pivot).

        For pivot candles: BUY at L, SELL at H.
        For non-pivot candles: compare the latest close with the close two
        candles earlier and apply MOMENTUM_THRESHOLD.

        Returns:
            Action dict with at least 'action' and 'reason'.
        """
        # Pivot candles take precedence over the momentum heuristic
        is_pivot, pivot_action = self._is_pivot_candle(candle_timestamp, pivot_markers)
        if is_pivot and pivot_action:
            return pivot_action

        # Not a pivot - use recent candles to estimate short-term momentum
        df = self.data_provider.get_historical_data(symbol, timeframe, limit=5)
        if df is None or len(df) < 3:
            return {'action': 'HOLD', 'reason': 'insufficient_data'}

        latest_close = df.iloc[-1]['close']
        base_close = df.iloc[-3]['close']
        if base_close == 0:
            # A zero reference price makes the relative change inf/NaN (numpy
            # does not raise), which would otherwise produce a bogus label.
            return {'action': 'HOLD', 'reason': 'invalid_price'}

        recent_change = (latest_close - base_close) / base_close
        if recent_change > self.MOMENTUM_THRESHOLD:
            action = 'BUY'
        elif recent_change < -self.MOMENTUM_THRESHOLD:
            action = 'SELL'
        else:
            action = 'HOLD'

        return {
            'action': action,
            'reason': 'price_movement',
            'change_pct': recent_change * 100
        }

    def _is_pivot_candle(self, timestamp, pivot_markers: Optional[Dict] = None) -> Tuple[bool, Optional[Dict]]:
        """
        Check if candle is a pivot point and return action.

        The lookup key is str(timestamp). When a candle carries both low and
        high pivots, the low pivot (BUY) wins because it is checked first.

        Returns:
            Tuple of (is_pivot: bool, action_data: Optional[Dict])
        """
        if not pivot_markers:
            return False, None

        candle_pivots = pivot_markers.get(str(timestamp), {})
        if not candle_pivots:
            return False, None

        # BUY at L pivots (lows - support levels)
        lows = candle_pivots.get('lows')
        if lows:
            best_low = max(lows, key=lambda p: p.get('level', 0))
            pivot_level = best_low.get('level', 1)
            pivot_strength = best_low.get('strength', 0.5)
            logger.info(f"L{pivot_level}L pivot detected @ {timestamp}, strength={pivot_strength:.2f} → BUY signal")
            return True, {
                'action': 'BUY',
                'pivot_level': pivot_level,
                'pivot_strength': pivot_strength,
                'reason': 'low_pivot'
            }

        # SELL at H pivots (highs - resistance levels)
        highs = candle_pivots.get('highs')
        if highs:
            best_high = max(highs, key=lambda p: p.get('level', 0))
            pivot_level = best_high.get('level', 1)
            pivot_strength = best_high.get('strength', 0.5)
            logger.info(f"L{pivot_level}H pivot detected @ {timestamp}, strength={pivot_strength:.2f} → SELL signal")
            return True, {
                'action': 'SELL',
                'pivot_level': pivot_level,
                'pivot_strength': pivot_strength,
                'reason': 'high_pivot'
            }

        return False, None

    def train_manually(self, symbol: str, timeframe: str, action: str) -> Dict:
        """
        Record a manually triggered training request with the specified action.

        This method only logs the request and updates the statistics; it does
        not call the training adapter itself.

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            action: Action to train ('BUY', 'SELL', or 'HOLD')

        Returns:
            Dict with 'success', 'action' and 'triggered_by'
        """
        logger.info(f"Manual training triggered: {action} on {symbol} {timeframe}")

        # Update stats (.get keeps an unexpected action label from raising KeyError)
        self.stats['total_trained'] += 1
        self.stats['by_action'][action] = self.stats['by_action'].get(action, 0) + 1

        return {
            'success': True,
            'action': action,
            'triggered_by': 'manual'
        }

    def get_stats(self) -> Dict:
        """
        Get training statistics.

        Returns:
            Dict with 'total_trained', 'by_action' (share of each action as a
            percentage string) and the current 'mode'.
        """
        total = self.stats['total_trained']
        if total == 0:
            return {
                'total_trained': 0,
                'by_action': {'BUY': '0%', 'SELL': '0%', 'HOLD': '0%'},
                'mode': self.mode
            }

        counts = self.stats['by_action']
        return {
            'total_trained': total,
            'by_action': {
                name: f"{(counts.get(name, 0) / total * 100):.1f}%"
                for name in ('BUY', 'SELL', 'HOLD')
            },
            'mode': self.mode
        }
class AnnotationDashboard:
"""Main annotation dashboard application"""
@@ -586,12 +774,19 @@ class AnnotationDashboard:
self.annotation_manager = AnnotationManager()
# Use REAL training adapter - NO SIMULATION!
self.training_adapter = RealTrainingAdapter(None, self.data_provider)
# Initialize training strategy manager (controls training decisions)
self.training_strategy = TrainingStrategyManager(self.data_provider, self.training_adapter)
self.training_strategy.dashboard = self
# Pass socketio to training adapter for live trade updates
if self.has_socketio and self.socketio:
self.training_adapter.socketio = self.socketio
# Backtest runner for replaying visible chart with predictions
self.backtest_runner = BacktestRunner()
# Prediction cache for training: stores inference inputs/outputs to compare with actual candles
# Format: {symbol: {timeframe: [{'timestamp': ts, 'inputs': {...}, 'outputs': {...}, 'norm_params': {...}}, ...]}}
self.prediction_cache = {}
# Check if we should auto-load a model at startup
auto_load_model = os.getenv('AUTO_LOAD_MODEL', 'Transformer') # Default: Transformer
@@ -2121,14 +2316,21 @@ class AnnotationDashboard:
@self.server.route('/api/realtime-inference/start', methods=['POST'])
def start_realtime_inference():
"""Start real-time inference mode with optional training modes"""
"""Start real-time inference mode with configurable training strategy"""
try:
data = request.get_json()
model_name = data.get('model_name')
symbol = data.get('symbol', 'ETH/USDT')
timeframe = data.get('timeframe', '1m')
enable_live_training = data.get('enable_live_training', False) # Pivot-based training
train_every_candle = data.get('train_every_candle', False) # Per-candle training
# New unified training_mode parameter
training_mode = data.get('training_mode', 'none') # 'none', 'every_candle', 'pivots_only', 'manual'
# Backward compatibility with old parameters
if 'enable_live_training' in data or 'train_every_candle' in data:
enable_live_training = data.get('enable_live_training', False)
train_every_candle = data.get('train_every_candle', False)
training_mode = 'every_candle' if train_every_candle else ('pivots_only' if enable_live_training else 'none')
if not self.training_adapter:
return jsonify({
@@ -2139,18 +2341,21 @@ class AnnotationDashboard:
}
})
# Start real-time inference with optional training modes
# Set training mode on strategy manager
self.training_strategy.mode = training_mode
logger.info(f"Training strategy mode set to: {training_mode}")
# Start real-time inference - pass strategy manager for training decisions
inference_id = self.training_adapter.start_realtime_inference(
model_name=model_name,
symbol=symbol,
data_provider=self.data_provider,
enable_live_training=enable_live_training,
train_every_candle=train_every_candle,
timeframe=timeframe
enable_live_training=(training_mode != 'none'),
train_every_candle=(training_mode == 'every_candle'),
timeframe=timeframe,
training_strategy=self.training_strategy # Pass strategy manager
)
training_mode = "per-candle" if train_every_candle else ("pivot-based" if enable_live_training else "inference-only")
return jsonify({
'success': True,
'inference_id': inference_id,
@@ -2259,20 +2464,17 @@ class AnnotationDashboard:
if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions:
transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol])
if transformer_preds:
# Use the most recent stored prediction (from inference loop)
predictions['transformer'] = transformer_preds[-1]
logger.debug(f"Using stored prediction: {list(transformer_preds[-1].keys())}")
else:
# Fallback: generate new prediction if no stored predictions
transformer_pred = self._get_live_transformer_prediction(symbol)
if transformer_pred:
predictions['transformer'] = transformer_pred
# Convert any remaining tensors to Python types before JSON serialization
transformer_pred = transformer_preds[-1].copy()
predictions['transformer'] = self._serialize_prediction(transformer_pred)
if predictions:
response['prediction'] = predictions
except Exception as e:
logger.debug(f"Error getting predictions: {e}")
import traceback
logger.debug(traceback.format_exc())
return jsonify(response)
@@ -2322,10 +2524,101 @@ class AnnotationDashboard:
}
})
@self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
def train_manual():
    """
    Manually trigger training on current candle with specified action.

    Expects a JSON body with 'inference_id' and optional 'action'
    ('BUY', 'SELL' or 'HOLD'; defaults to 'HOLD'). Always responds with a
    JSON object carrying 'success' and either 'metrics' or 'error'.
    """
    try:
        # silent=True: a missing or malformed JSON body yields None instead of
        # raising, so the handler reports a normal "session not found" error.
        data = request.get_json(silent=True) or {}
        inference_id = data.get('inference_id')

        # Reject anything that is not one of the three labels used elsewhere
        # in this file (0=HOLD, 1=BUY, 2=SELL) before it reaches training.
        action = str(data.get('action', 'HOLD')).upper()
        if action not in ('BUY', 'SELL', 'HOLD'):
            return jsonify({
                'success': False,
                'error': f"Invalid action '{action}' (expected BUY, SELL or HOLD)"
            })

        if not self.training_adapter:
            return jsonify({
                'success': False,
                'error': 'Training adapter not available'
            })

        # Get active inference session
        if not hasattr(self.training_adapter, 'inference_sessions'):
            return jsonify({
                'success': False,
                'error': 'No active inference sessions'
            })

        session = self.training_adapter.inference_sessions.get(inference_id)
        if not session:
            return jsonify({
                'success': False,
                'error': 'Inference session not found'
            })

        # Set pending action for training.
        # NOTE(review): presumably consumed by _train_on_new_candle - confirm;
        # it is left on the session if training fails.
        session['pending_action'] = action

        # Get session parameters
        symbol = session.get('symbol', 'ETH/USDT')
        timeframe = session.get('timeframe', '1m')
        data_provider = session.get('data_provider')

        # Call training method; treat a None result as a failed run
        train_result = self.training_adapter._train_on_new_candle(
            session, symbol, timeframe, data_provider
        ) or {}

        if not train_result.get('success'):
            return jsonify({
                'success': False,
                'error': train_result.get('error', 'Training failed')
            })

        return jsonify({
            'success': True,
            'action': action,
            'metrics': {
                'loss': train_result.get('loss', 0.0),
                'accuracy': train_result.get('accuracy', 0.0),
                'training_steps': train_result.get('training_steps', 0)
            }
        })

    except Exception as e:
        # Route boundary: report the failure to the client instead of a 500
        logger.error(f"Error in manual training: {e}")
        return jsonify({
            'success': False,
            'error': str(e)
        })
# WebSocket event handlers (if SocketIO is available)
if self.has_socketio:
self._setup_websocket_handlers()
def _serialize_prediction(self, prediction: Dict) -> Dict:
"""Convert PyTorch tensors in prediction dict to JSON-serializable Python types"""
try:
import torch
serialized = {}
for key, value in prediction.items():
if isinstance(value, torch.Tensor):
if value.numel() == 1: # Scalar tensor
serialized[key] = value.item()
else: # Multi-element tensor
serialized[key] = value.detach().cpu().tolist()
elif isinstance(value, dict):
serialized[key] = self._serialize_prediction(value) # Recursive
elif isinstance(value, (list, tuple)):
serialized[key] = [
v.item() if isinstance(v, torch.Tensor) and v.numel() == 1 else
(v.detach().cpu().tolist() if isinstance(v, torch.Tensor) else v)
for v in value
]
else:
serialized[key] = value
return serialized
except Exception as e:
logger.warning(f"Error serializing prediction: {e}")
# Fallback: return as-is (might fail JSON serialization but won't crash)
return prediction
def _setup_websocket_handlers(self):
"""Setup WebSocket event handlers for real-time updates"""
if not self.has_socketio:
@@ -2748,35 +3041,209 @@ class AnnotationDashboard:
return {}
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
"""Get live prediction from model"""
"""
Get live prediction from model using trainer inference
Caches inference data (inputs/outputs) for later training when actual candle arrives.
This allows us to:
1. Compare predicted vs actual candle values
2. Calculate loss
3. Do backpropagation with correct outputs
Returns:
Dict with prediction results including predicted_candle for ghost candle display
"""
try:
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
if not self.orchestrator:
return None
# Get recent candles for prediction
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=200)
if not candles or len(candles) < 200:
# Get trainer from orchestrator
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if not trainer or not trainer.model:
logger.debug("No transformer trainer available for live prediction")
return None
# TODO: Implement actual prediction logic
# For now, return placeholder
import random
# Get market data using training adapter's method (reuses existing logic)
if not hasattr(self.training_adapter, '_get_realtime_market_data'):
logger.warning("Training adapter missing _get_realtime_market_data method")
return None
market_data, norm_params = self.training_adapter._get_realtime_market_data(symbol, self.data_provider)
if not market_data:
logger.debug(f"No market data available for {symbol} {timeframe}")
return None
# Make prediction with model
import torch
timestamp = datetime.now(timezone.utc)
with torch.no_grad():
trainer.model.eval()
outputs = trainer.model(**market_data)
# Extract action prediction
action_probs = outputs.get('action_probs')
if action_probs is None:
logger.debug("No action_probs in model output")
return None
action_idx = torch.argmax(action_probs, dim=-1).item()
confidence = action_probs[0, action_idx].item()
# Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
actions = ['HOLD', 'BUY', 'SELL']
action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
# Extract predicted candles and denormalize
predicted_candles_raw = {}
if 'next_candles' in outputs:
for tf, tensor in outputs['next_candles'].items():
predicted_candles_raw[tf] = tensor.detach().cpu().numpy().tolist()
# Denormalize predicted candles
predicted_candles_denorm = {}
if predicted_candles_raw and norm_params:
for tf, raw_candle in predicted_candles_raw.items():
if tf in norm_params:
params = norm_params[tf]
price_min = params['price_min']
price_max = params['price_max']
vol_min = params['volume_min']
vol_max = params['volume_max']
# raw_candle is [1, 5] list
candle_values = raw_candle[0]
denorm_candle = [
candle_values[0] * (price_max - price_min) + price_min, # Open
candle_values[1] * (price_max - price_min) + price_min, # High
candle_values[2] * (price_max - price_min) + price_min, # Low
candle_values[3] * (price_max - price_min) + price_min, # Close
candle_values[4] * (vol_max - vol_min) + vol_min # Volume
]
predicted_candles_denorm[tf] = denorm_candle
# Get predicted price from candle close
predicted_price = None
if timeframe in predicted_candles_denorm:
predicted_price = predicted_candles_denorm[timeframe][3] # Close
elif '1m' in predicted_candles_denorm:
predicted_price = predicted_candles_denorm['1m'][3]
elif '1s' in predicted_candles_denorm:
predicted_price = predicted_candles_denorm['1s'][3]
# CACHE inference data for later training
# Store inputs, outputs, and normalization params so we can train when actual candle arrives
if symbol not in self.prediction_cache:
self.prediction_cache[symbol] = {}
if timeframe not in self.prediction_cache[symbol]:
self.prediction_cache[symbol][timeframe] = []
# Store cached inference data (convert tensors to CPU for storage)
cached_data = {
'timestamp': timestamp,
'symbol': symbol,
'timeframe': timeframe,
'model_inputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
for k, v in market_data.items()},
'model_outputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
for k, v in outputs.items()},
'normalization_params': norm_params,
'predicted_candle': predicted_candles_denorm.get(timeframe),
'prediction_steps': prediction_steps
}
self.prediction_cache[symbol][timeframe].append(cached_data)
# Keep only last 100 predictions per symbol/timeframe to prevent memory bloat
if len(self.prediction_cache[symbol][timeframe]) > 100:
self.prediction_cache[symbol][timeframe] = self.prediction_cache[symbol][timeframe][-100:]
logger.debug(f"Cached prediction for {symbol} {timeframe} @ {timestamp.isoformat()}")
# Return prediction result (same format as before for compatibility)
return {
'symbol': symbol,
'timeframe': timeframe,
'timestamp': datetime.now(timezone.utc).isoformat(),
'action': random.choice(['BUY', 'SELL', 'HOLD']),
'confidence': random.uniform(0.6, 0.95),
'predicted_price': candles[-1].get('close', 0) * (1 + random.uniform(-0.01, 0.01)),
'timestamp': timestamp.isoformat(),
'action': action,
'confidence': confidence,
'predicted_price': predicted_price,
'predicted_candle': predicted_candles_denorm,
'prediction_steps': prediction_steps
}
except Exception as e:
logger.error(f"Error getting live prediction: {e}")
import traceback
logger.debug(traceback.format_exc())
return None
def run(self, host='127.0.0.1', port=8052, debug=False):
def get_cached_predictions_for_training(self, symbol: str, timeframe: str, actual_candle_timestamp) -> List[Dict]:
    """
    Retrieve cached predictions that match a specific candle timestamp for training.

    When an actual candle arrives, we can:
    1. Find cached predictions made before this candle
    2. Compare predicted vs actual candle values
    3. Calculate loss and do backpropagation

    Args:
        symbol: Trading symbol
        timeframe: Timeframe
        actual_candle_timestamp: Timestamp of the actual candle that just arrived
            (datetime or ISO-8601 string, a trailing 'Z' is accepted)

    Returns:
        List of cached prediction dicts made strictly before the candle
        timestamp; empty list when nothing is cached or on error.
    """
    def _as_utc(value):
        # Normalize to a timezone-aware datetime so comparisons never mix
        # naive and aware values (that raises TypeError).
        # NOTE(review): naive timestamps are assumed to be UTC - confirm with the data provider.
        if not isinstance(value, datetime):
            value = datetime.fromisoformat(str(value).replace('Z', '+00:00'))
        if value.tzinfo is None:
            value = value.replace(tzinfo=timezone.utc)
        return value

    try:
        cached_predictions = self.prediction_cache.get(symbol, {}).get(timeframe)
        if not cached_predictions:
            return []

        actual_time = _as_utc(actual_candle_timestamp)

        # We match predictions that were made before the actual candle closed
        return [
            cached_pred for cached_pred in cached_predictions
            if _as_utc(cached_pred['timestamp']) < actual_time
        ]

    except Exception as e:
        logger.error(f"Error getting cached predictions for training: {e}")
        return []
def clear_old_cached_predictions(self, symbol: str, timeframe: str, before_timestamp: datetime):
    """
    Clear cached predictions older than a certain timestamp.

    Useful for cleaning up old predictions that are no longer needed.
    Entries whose timestamp is at or after before_timestamp are kept.

    Args:
        symbol: Trading symbol
        timeframe: Timeframe
        before_timestamp: Cut-off; entries strictly older than this are dropped
    """
    def _as_utc(value):
        # Normalize to a timezone-aware datetime so comparisons never mix
        # naive and aware values (that raises TypeError and would leave the
        # cache untouched).
        # NOTE(review): naive timestamps are assumed to be UTC - confirm with callers.
        if not isinstance(value, datetime):
            value = datetime.fromisoformat(str(value).replace('Z', '+00:00'))
        if value.tzinfo is None:
            value = value.replace(tzinfo=timezone.utc)
        return value

    try:
        timeframe_cache = self.prediction_cache.get(symbol, {}).get(timeframe)
        if timeframe_cache is None:
            return

        cutoff = _as_utc(before_timestamp)
        self.prediction_cache[symbol][timeframe] = [
            pred for pred in timeframe_cache
            if _as_utc(pred['timestamp']) >= cutoff
        ]
    except Exception as e:
        logger.debug(f"Error clearing old cached predictions: {e}")
def run(self, host='127.0.0.1', port=8051, debug=False):
"""Run the application"""
logger.info(f"Starting Annotation Dashboard on http://{host}:{port}")