wip
This commit is contained in:
@@ -17,7 +17,7 @@ from flask import Flask, render_template, request, jsonify, send_file
|
||||
from dash import Dash, html
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional, Dict, List, Any
|
||||
from typing import Optional, Dict, List, Any, Tuple
|
||||
import json
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -370,8 +370,8 @@ class BacktestRunner:
|
||||
action_idx = torch.argmax(action_probs, dim=-1).item()
|
||||
confidence = action_probs[0, action_idx].item()
|
||||
|
||||
# Map to BUY/SELL/HOLD
|
||||
actions = ['BUY', 'SELL', 'HOLD']
|
||||
# Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
|
||||
actions = ['HOLD', 'BUY', 'SELL']
|
||||
if action_idx < len(actions):
|
||||
action = actions[action_idx]
|
||||
else:
|
||||
@@ -490,6 +490,194 @@ class BacktestRunner:
|
||||
state['stop_requested'] = True
|
||||
|
||||
|
||||
class TrainingStrategyManager:
|
||||
"""
|
||||
Manages training strategies and decisions - Separates business logic from model interface
|
||||
|
||||
Training Modes:
|
||||
- 'none': No training (inference only)
|
||||
- 'every_candle': Train on every completed candle
|
||||
- 'pivots_only': Train only on pivot points (BUY at L pivots, SELL at H pivots)
|
||||
- 'manual': Training triggered manually by user button
|
||||
"""
|
||||
|
||||
def __init__(self, data_provider, training_adapter):
|
||||
self.data_provider = data_provider
|
||||
self.training_adapter = training_adapter
|
||||
self.mode = 'none' # Default: no training
|
||||
self.dashboard = None # Set by dashboard after initialization
|
||||
|
||||
# Statistics tracking
|
||||
self.stats = {
|
||||
'total_trained': 0,
|
||||
'by_action': {'BUY': 0, 'SELL': 0, 'HOLD': 0},
|
||||
'profitable': 0
|
||||
}
|
||||
|
||||
def should_train_on_candle(self, symbol: str, timeframe: str, candle_timestamp, pivot_markers: Dict = None) -> Tuple[bool, Optional[Dict]]:
|
||||
"""
|
||||
Decide if we should train on this candle based on current mode
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Candle timeframe
|
||||
candle_timestamp: Timestamp of the candle
|
||||
pivot_markers: Dict of pivot markers (timestamp -> pivot data)
|
||||
|
||||
Returns:
|
||||
Tuple of (should_train: bool, action_data: Optional[Dict])
|
||||
action_data contains: {'action': 'BUY'/'SELL'/'HOLD', 'pivot_level': int, 'pivot_strength': float}
|
||||
"""
|
||||
if self.mode == 'none':
|
||||
return False, None
|
||||
|
||||
elif self.mode == 'every_candle':
|
||||
# Train on every candle - determine action from price movement or pivots
|
||||
action_data = self._get_action_for_candle(symbol, timeframe, candle_timestamp, pivot_markers)
|
||||
return True, action_data
|
||||
|
||||
elif self.mode == 'pivots_only':
|
||||
# Train only on pivot candles
|
||||
return self._is_pivot_candle(candle_timestamp, pivot_markers)
|
||||
|
||||
elif self.mode == 'manual':
|
||||
# Manual training - don't auto-train
|
||||
return False, None
|
||||
|
||||
return False, None
|
||||
|
||||
def _get_action_for_candle(self, symbol: str, timeframe: str, candle_timestamp, pivot_markers: Dict = None) -> Dict:
|
||||
"""
|
||||
Determine action for any candle (pivot or non-pivot)
|
||||
For pivot candles: BUY at L, SELL at H
|
||||
For non-pivot candles: Use price movement thresholds
|
||||
"""
|
||||
# First check if it's a pivot candle
|
||||
is_pivot, pivot_action = self._is_pivot_candle(candle_timestamp, pivot_markers)
|
||||
if is_pivot and pivot_action:
|
||||
return pivot_action
|
||||
|
||||
# Not a pivot - use price movement based logic
|
||||
# Get recent candles to determine trend
|
||||
df = self.data_provider.get_historical_data(symbol, timeframe, limit=5)
|
||||
if df is None or len(df) < 3:
|
||||
return {'action': 'HOLD', 'reason': 'insufficient_data'}
|
||||
|
||||
# Simple momentum: if price going up, BUY, if going down, SELL
|
||||
recent_change = (df.iloc[-1]['close'] - df.iloc[-3]['close']) / df.iloc[-3]['close']
|
||||
|
||||
if recent_change > 0.0005: # 0.05% up
|
||||
action = 'BUY'
|
||||
elif recent_change < -0.0005: # 0.05% down
|
||||
action = 'SELL'
|
||||
else:
|
||||
action = 'HOLD'
|
||||
|
||||
return {
|
||||
'action': action,
|
||||
'reason': 'price_movement',
|
||||
'change_pct': recent_change * 100
|
||||
}
|
||||
|
||||
def _is_pivot_candle(self, timestamp, pivot_markers: Dict = None) -> Tuple[bool, Optional[Dict]]:
|
||||
"""
|
||||
Check if candle is a pivot point and return action
|
||||
|
||||
Returns:
|
||||
Tuple of (is_pivot: bool, action_data: Optional[Dict])
|
||||
"""
|
||||
if not pivot_markers:
|
||||
return False, None
|
||||
|
||||
candle_timestamp = str(timestamp)
|
||||
candle_pivots = pivot_markers.get(candle_timestamp, {})
|
||||
|
||||
if not candle_pivots:
|
||||
return False, None
|
||||
|
||||
# BUY at L pivots (lows - support levels)
|
||||
if 'lows' in candle_pivots and len(candle_pivots['lows']) > 0:
|
||||
best_low = max(candle_pivots['lows'], key=lambda p: p.get('level', 0))
|
||||
pivot_level = best_low.get('level', 1)
|
||||
pivot_strength = best_low.get('strength', 0.5)
|
||||
|
||||
logger.info(f"L{pivot_level}L pivot detected @ {timestamp}, strength={pivot_strength:.2f} → BUY signal")
|
||||
|
||||
return True, {
|
||||
'action': 'BUY',
|
||||
'pivot_level': pivot_level,
|
||||
'pivot_strength': pivot_strength,
|
||||
'reason': 'low_pivot'
|
||||
}
|
||||
|
||||
# SELL at H pivots (highs - resistance levels)
|
||||
elif 'highs' in candle_pivots and len(candle_pivots['highs']) > 0:
|
||||
best_high = max(candle_pivots['highs'], key=lambda p: p.get('level', 0))
|
||||
pivot_level = best_high.get('level', 1)
|
||||
pivot_strength = best_high.get('strength', 0.5)
|
||||
|
||||
logger.info(f"L{pivot_level}H pivot detected @ {timestamp}, strength={pivot_strength:.2f} → SELL signal")
|
||||
|
||||
return True, {
|
||||
'action': 'SELL',
|
||||
'pivot_level': pivot_level,
|
||||
'pivot_strength': pivot_strength,
|
||||
'reason': 'high_pivot'
|
||||
}
|
||||
|
||||
return False, None
|
||||
|
||||
def train_manually(self, symbol: str, timeframe: str, action: str) -> Dict:
|
||||
"""
|
||||
Manually trigger training with specified action
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe
|
||||
action: Action to train ('BUY', 'SELL', or 'HOLD')
|
||||
|
||||
Returns:
|
||||
Training result dict with metrics
|
||||
"""
|
||||
logger.info(f"Manual training triggered: {action} on {symbol} {timeframe}")
|
||||
|
||||
# Create action data
|
||||
action_data = {
|
||||
'action': action,
|
||||
'reason': 'manual_trigger'
|
||||
}
|
||||
|
||||
# Update stats
|
||||
self.stats['total_trained'] += 1
|
||||
self.stats['by_action'][action] = self.stats['by_action'].get(action, 0) + 1
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'action': action,
|
||||
'triggered_by': 'manual'
|
||||
}
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""Get training statistics"""
|
||||
total = self.stats['total_trained']
|
||||
if total == 0:
|
||||
return {
|
||||
'total_trained': 0,
|
||||
'by_action': {'BUY': '0%', 'SELL': '0%', 'HOLD': '0%'},
|
||||
'mode': self.mode
|
||||
}
|
||||
|
||||
return {
|
||||
'total_trained': total,
|
||||
'by_action': {
|
||||
'BUY': f"{(self.stats['by_action'].get('BUY', 0) / total * 100):.1f}%",
|
||||
'SELL': f"{(self.stats['by_action'].get('SELL', 0) / total * 100):.1f}%",
|
||||
'HOLD': f"{(self.stats['by_action'].get('HOLD', 0) / total * 100):.1f}%"
|
||||
},
|
||||
'mode': self.mode
|
||||
}
|
||||
|
||||
|
||||
class AnnotationDashboard:
|
||||
"""Main annotation dashboard application"""
|
||||
|
||||
@@ -586,12 +774,19 @@ class AnnotationDashboard:
|
||||
self.annotation_manager = AnnotationManager()
|
||||
# Use REAL training adapter - NO SIMULATION!
|
||||
self.training_adapter = RealTrainingAdapter(None, self.data_provider)
|
||||
# Initialize training strategy manager (controls training decisions)
|
||||
self.training_strategy = TrainingStrategyManager(self.data_provider, self.training_adapter)
|
||||
self.training_strategy.dashboard = self
|
||||
# Pass socketio to training adapter for live trade updates
|
||||
if self.has_socketio and self.socketio:
|
||||
self.training_adapter.socketio = self.socketio
|
||||
# Backtest runner for replaying visible chart with predictions
|
||||
self.backtest_runner = BacktestRunner()
|
||||
|
||||
# Prediction cache for training: stores inference inputs/outputs to compare with actual candles
|
||||
# Format: {symbol: {timeframe: [{'timestamp': ts, 'inputs': {...}, 'outputs': {...}, 'norm_params': {...}}, ...]}}
|
||||
self.prediction_cache = {}
|
||||
|
||||
# Check if we should auto-load a model at startup
|
||||
auto_load_model = os.getenv('AUTO_LOAD_MODEL', 'Transformer') # Default: Transformer
|
||||
|
||||
@@ -2121,14 +2316,21 @@ class AnnotationDashboard:
|
||||
|
||||
@self.server.route('/api/realtime-inference/start', methods=['POST'])
|
||||
def start_realtime_inference():
|
||||
"""Start real-time inference mode with optional training modes"""
|
||||
"""Start real-time inference mode with configurable training strategy"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
model_name = data.get('model_name')
|
||||
symbol = data.get('symbol', 'ETH/USDT')
|
||||
timeframe = data.get('timeframe', '1m')
|
||||
enable_live_training = data.get('enable_live_training', False) # Pivot-based training
|
||||
train_every_candle = data.get('train_every_candle', False) # Per-candle training
|
||||
|
||||
# New unified training_mode parameter
|
||||
training_mode = data.get('training_mode', 'none') # 'none', 'every_candle', 'pivots_only', 'manual'
|
||||
|
||||
# Backward compatibility with old parameters
|
||||
if 'enable_live_training' in data or 'train_every_candle' in data:
|
||||
enable_live_training = data.get('enable_live_training', False)
|
||||
train_every_candle = data.get('train_every_candle', False)
|
||||
training_mode = 'every_candle' if train_every_candle else ('pivots_only' if enable_live_training else 'none')
|
||||
|
||||
if not self.training_adapter:
|
||||
return jsonify({
|
||||
@@ -2139,18 +2341,21 @@ class AnnotationDashboard:
|
||||
}
|
||||
})
|
||||
|
||||
# Start real-time inference with optional training modes
|
||||
# Set training mode on strategy manager
|
||||
self.training_strategy.mode = training_mode
|
||||
logger.info(f"Training strategy mode set to: {training_mode}")
|
||||
|
||||
# Start real-time inference - pass strategy manager for training decisions
|
||||
inference_id = self.training_adapter.start_realtime_inference(
|
||||
model_name=model_name,
|
||||
symbol=symbol,
|
||||
data_provider=self.data_provider,
|
||||
enable_live_training=enable_live_training,
|
||||
train_every_candle=train_every_candle,
|
||||
timeframe=timeframe
|
||||
enable_live_training=(training_mode != 'none'),
|
||||
train_every_candle=(training_mode == 'every_candle'),
|
||||
timeframe=timeframe,
|
||||
training_strategy=self.training_strategy # Pass strategy manager
|
||||
)
|
||||
|
||||
training_mode = "per-candle" if train_every_candle else ("pivot-based" if enable_live_training else "inference-only")
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'inference_id': inference_id,
|
||||
@@ -2259,20 +2464,17 @@ class AnnotationDashboard:
|
||||
if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions:
|
||||
transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol])
|
||||
if transformer_preds:
|
||||
# Use the most recent stored prediction (from inference loop)
|
||||
predictions['transformer'] = transformer_preds[-1]
|
||||
logger.debug(f"Using stored prediction: {list(transformer_preds[-1].keys())}")
|
||||
else:
|
||||
# Fallback: generate new prediction if no stored predictions
|
||||
transformer_pred = self._get_live_transformer_prediction(symbol)
|
||||
if transformer_pred:
|
||||
predictions['transformer'] = transformer_pred
|
||||
# Convert any remaining tensors to Python types before JSON serialization
|
||||
transformer_pred = transformer_preds[-1].copy()
|
||||
predictions['transformer'] = self._serialize_prediction(transformer_pred)
|
||||
|
||||
if predictions:
|
||||
response['prediction'] = predictions
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting predictions: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
|
||||
return jsonify(response)
|
||||
|
||||
@@ -2322,10 +2524,101 @@ class AnnotationDashboard:
|
||||
}
|
||||
})
|
||||
|
||||
@self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
|
||||
def train_manual():
|
||||
"""Manually trigger training on current candle with specified action"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
inference_id = data.get('inference_id')
|
||||
action = data.get('action', 'HOLD')
|
||||
|
||||
if not self.training_adapter:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': 'Training adapter not available'
|
||||
})
|
||||
|
||||
# Get active inference session
|
||||
if not hasattr(self.training_adapter, 'inference_sessions'):
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': 'No active inference sessions'
|
||||
})
|
||||
|
||||
session = self.training_adapter.inference_sessions.get(inference_id)
|
||||
if not session:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': 'Inference session not found'
|
||||
})
|
||||
|
||||
# Set pending action for training
|
||||
session['pending_action'] = action
|
||||
|
||||
# Get session parameters
|
||||
symbol = session.get('symbol', 'ETH/USDT')
|
||||
timeframe = session.get('timeframe', '1m')
|
||||
data_provider = session.get('data_provider')
|
||||
|
||||
# Call training method
|
||||
train_result = self.training_adapter._train_on_new_candle(
|
||||
session, symbol, timeframe, data_provider
|
||||
)
|
||||
|
||||
if train_result.get('success'):
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'action': action,
|
||||
'metrics': {
|
||||
'loss': train_result.get('loss', 0.0),
|
||||
'accuracy': train_result.get('accuracy', 0.0),
|
||||
'training_steps': train_result.get('training_steps', 0)
|
||||
}
|
||||
})
|
||||
else:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': train_result.get('error', 'Training failed')
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in manual training: {e}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
# WebSocket event handlers (if SocketIO is available)
|
||||
if self.has_socketio:
|
||||
self._setup_websocket_handlers()
|
||||
|
||||
def _serialize_prediction(self, prediction: Dict) -> Dict:
|
||||
"""Convert PyTorch tensors in prediction dict to JSON-serializable Python types"""
|
||||
try:
|
||||
import torch
|
||||
serialized = {}
|
||||
for key, value in prediction.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
if value.numel() == 1: # Scalar tensor
|
||||
serialized[key] = value.item()
|
||||
else: # Multi-element tensor
|
||||
serialized[key] = value.detach().cpu().tolist()
|
||||
elif isinstance(value, dict):
|
||||
serialized[key] = self._serialize_prediction(value) # Recursive
|
||||
elif isinstance(value, (list, tuple)):
|
||||
serialized[key] = [
|
||||
v.item() if isinstance(v, torch.Tensor) and v.numel() == 1 else
|
||||
(v.detach().cpu().tolist() if isinstance(v, torch.Tensor) else v)
|
||||
for v in value
|
||||
]
|
||||
else:
|
||||
serialized[key] = value
|
||||
return serialized
|
||||
except Exception as e:
|
||||
logger.warning(f"Error serializing prediction: {e}")
|
||||
# Fallback: return as-is (might fail JSON serialization but won't crash)
|
||||
return prediction
|
||||
|
||||
def _setup_websocket_handlers(self):
|
||||
"""Setup WebSocket event handlers for real-time updates"""
|
||||
if not self.has_socketio:
|
||||
@@ -2748,35 +3041,209 @@ class AnnotationDashboard:
|
||||
return {}
|
||||
|
||||
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
|
||||
"""Get live prediction from model"""
|
||||
"""
|
||||
Get live prediction from model using trainer inference
|
||||
|
||||
Caches inference data (inputs/outputs) for later training when actual candle arrives.
|
||||
This allows us to:
|
||||
1. Compare predicted vs actual candle values
|
||||
2. Calculate loss
|
||||
3. Do backpropagation with correct outputs
|
||||
|
||||
Returns:
|
||||
Dict with prediction results including predicted_candle for ghost candle display
|
||||
"""
|
||||
try:
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
|
||||
if not self.orchestrator:
|
||||
return None
|
||||
|
||||
# Get recent candles for prediction
|
||||
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=200)
|
||||
if not candles or len(candles) < 200:
|
||||
# Get trainer from orchestrator
|
||||
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
|
||||
if not trainer or not trainer.model:
|
||||
logger.debug("No transformer trainer available for live prediction")
|
||||
return None
|
||||
|
||||
# TODO: Implement actual prediction logic
|
||||
# For now, return placeholder
|
||||
import random
|
||||
# Get market data using training adapter's method (reuses existing logic)
|
||||
if not hasattr(self.training_adapter, '_get_realtime_market_data'):
|
||||
logger.warning("Training adapter missing _get_realtime_market_data method")
|
||||
return None
|
||||
|
||||
market_data, norm_params = self.training_adapter._get_realtime_market_data(symbol, self.data_provider)
|
||||
if not market_data:
|
||||
logger.debug(f"No market data available for {symbol} {timeframe}")
|
||||
return None
|
||||
|
||||
# Make prediction with model
|
||||
import torch
|
||||
timestamp = datetime.now(timezone.utc)
|
||||
|
||||
with torch.no_grad():
|
||||
trainer.model.eval()
|
||||
outputs = trainer.model(**market_data)
|
||||
|
||||
# Extract action prediction
|
||||
action_probs = outputs.get('action_probs')
|
||||
if action_probs is None:
|
||||
logger.debug("No action_probs in model output")
|
||||
return None
|
||||
|
||||
action_idx = torch.argmax(action_probs, dim=-1).item()
|
||||
confidence = action_probs[0, action_idx].item()
|
||||
|
||||
# Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
|
||||
actions = ['HOLD', 'BUY', 'SELL']
|
||||
action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
|
||||
|
||||
# Extract predicted candles and denormalize
|
||||
predicted_candles_raw = {}
|
||||
if 'next_candles' in outputs:
|
||||
for tf, tensor in outputs['next_candles'].items():
|
||||
predicted_candles_raw[tf] = tensor.detach().cpu().numpy().tolist()
|
||||
|
||||
# Denormalize predicted candles
|
||||
predicted_candles_denorm = {}
|
||||
if predicted_candles_raw and norm_params:
|
||||
for tf, raw_candle in predicted_candles_raw.items():
|
||||
if tf in norm_params:
|
||||
params = norm_params[tf]
|
||||
price_min = params['price_min']
|
||||
price_max = params['price_max']
|
||||
vol_min = params['volume_min']
|
||||
vol_max = params['volume_max']
|
||||
|
||||
# raw_candle is [1, 5] list
|
||||
candle_values = raw_candle[0]
|
||||
|
||||
denorm_candle = [
|
||||
candle_values[0] * (price_max - price_min) + price_min, # Open
|
||||
candle_values[1] * (price_max - price_min) + price_min, # High
|
||||
candle_values[2] * (price_max - price_min) + price_min, # Low
|
||||
candle_values[3] * (price_max - price_min) + price_min, # Close
|
||||
candle_values[4] * (vol_max - vol_min) + vol_min # Volume
|
||||
]
|
||||
predicted_candles_denorm[tf] = denorm_candle
|
||||
|
||||
# Get predicted price from candle close
|
||||
predicted_price = None
|
||||
if timeframe in predicted_candles_denorm:
|
||||
predicted_price = predicted_candles_denorm[timeframe][3] # Close
|
||||
elif '1m' in predicted_candles_denorm:
|
||||
predicted_price = predicted_candles_denorm['1m'][3]
|
||||
elif '1s' in predicted_candles_denorm:
|
||||
predicted_price = predicted_candles_denorm['1s'][3]
|
||||
|
||||
# CACHE inference data for later training
|
||||
# Store inputs, outputs, and normalization params so we can train when actual candle arrives
|
||||
if symbol not in self.prediction_cache:
|
||||
self.prediction_cache[symbol] = {}
|
||||
if timeframe not in self.prediction_cache[symbol]:
|
||||
self.prediction_cache[symbol][timeframe] = []
|
||||
|
||||
# Store cached inference data (convert tensors to CPU for storage)
|
||||
cached_data = {
|
||||
'timestamp': timestamp,
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'model_inputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in market_data.items()},
|
||||
'model_outputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in outputs.items()},
|
||||
'normalization_params': norm_params,
|
||||
'predicted_candle': predicted_candles_denorm.get(timeframe),
|
||||
'prediction_steps': prediction_steps
|
||||
}
|
||||
|
||||
self.prediction_cache[symbol][timeframe].append(cached_data)
|
||||
|
||||
# Keep only last 100 predictions per symbol/timeframe to prevent memory bloat
|
||||
if len(self.prediction_cache[symbol][timeframe]) > 100:
|
||||
self.prediction_cache[symbol][timeframe] = self.prediction_cache[symbol][timeframe][-100:]
|
||||
|
||||
logger.debug(f"Cached prediction for {symbol} {timeframe} @ {timestamp.isoformat()}")
|
||||
|
||||
# Return prediction result (same format as before for compatibility)
|
||||
return {
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'timestamp': datetime.now(timezone.utc).isoformat(),
|
||||
'action': random.choice(['BUY', 'SELL', 'HOLD']),
|
||||
'confidence': random.uniform(0.6, 0.95),
|
||||
'predicted_price': candles[-1].get('close', 0) * (1 + random.uniform(-0.01, 0.01)),
|
||||
'timestamp': timestamp.isoformat(),
|
||||
'action': action,
|
||||
'confidence': confidence,
|
||||
'predicted_price': predicted_price,
|
||||
'predicted_candle': predicted_candles_denorm,
|
||||
'prediction_steps': prediction_steps
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting live prediction: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def run(self, host='127.0.0.1', port=8052, debug=False):
|
||||
def get_cached_predictions_for_training(self, symbol: str, timeframe: str, actual_candle_timestamp) -> List[Dict]:
|
||||
"""
|
||||
Retrieve cached predictions that match a specific candle timestamp for training
|
||||
|
||||
When an actual candle arrives, we can:
|
||||
1. Find cached predictions made before this candle
|
||||
2. Compare predicted vs actual candle values
|
||||
3. Calculate loss and do backpropagation
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe
|
||||
actual_candle_timestamp: Timestamp of the actual candle that just arrived
|
||||
|
||||
Returns:
|
||||
List of cached prediction dicts that should be trained on
|
||||
"""
|
||||
try:
|
||||
if symbol not in self.prediction_cache:
|
||||
return []
|
||||
if timeframe not in self.prediction_cache[symbol]:
|
||||
return []
|
||||
|
||||
# Find predictions made before this candle timestamp
|
||||
# Predictions should be for candles that have now completed
|
||||
matching_predictions = []
|
||||
actual_time = actual_candle_timestamp if isinstance(actual_candle_timestamp, datetime) else datetime.fromisoformat(str(actual_candle_timestamp).replace('Z', '+00:00'))
|
||||
|
||||
for cached_pred in self.prediction_cache[symbol][timeframe]:
|
||||
pred_time = cached_pred['timestamp']
|
||||
if isinstance(pred_time, str):
|
||||
pred_time = datetime.fromisoformat(pred_time.replace('Z', '+00:00'))
|
||||
|
||||
# Prediction should be for a candle that comes after the prediction time
|
||||
# We match predictions that were made before the actual candle closed
|
||||
if pred_time < actual_time:
|
||||
matching_predictions.append(cached_pred)
|
||||
|
||||
return matching_predictions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting cached predictions for training: {e}")
|
||||
return []
|
||||
|
||||
def clear_old_cached_predictions(self, symbol: str, timeframe: str, before_timestamp: datetime):
|
||||
"""
|
||||
Clear cached predictions older than a certain timestamp
|
||||
|
||||
Useful for cleaning up old predictions that are no longer needed
|
||||
"""
|
||||
try:
|
||||
if symbol not in self.prediction_cache:
|
||||
return
|
||||
if timeframe not in self.prediction_cache[symbol]:
|
||||
return
|
||||
|
||||
self.prediction_cache[symbol][timeframe] = [
|
||||
pred for pred in self.prediction_cache[symbol][timeframe]
|
||||
if pred['timestamp'] >= before_timestamp
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error clearing old cached predictions: {e}")
|
||||
|
||||
def run(self, host='127.0.0.1', port=8051, debug=False):
|
||||
"""Run the application"""
|
||||
logger.info(f"Starting Annotation Dashboard on http://{host}:{port}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user