""" Manual Trade Annotation UI - Main Application A web-based interface for manually marking profitable buy/sell signals on historical market data to generate training test cases for machine learning models. """ import os import sys from pathlib import Path # Add parent directory to path for imports parent_dir = Path(__file__).parent.parent.parent sys.path.insert(0, str(parent_dir)) from flask import Flask, render_template, request, jsonify, send_file from dash import Dash, html import logging from datetime import datetime, timezone, timedelta from typing import Optional, Dict, List, Any, Tuple import json import pandas as pd import numpy as np import threading import uuid import time import torch from utils.logging_config import get_channel_logger, LogChannel # Import core components from main system try: from core.data_provider import DataProvider from core.orchestrator import TradingOrchestrator from core.config import get_config from core.williams_market_structure import WilliamsMarketStructure except ImportError as e: print(f"Warning: Could not import main system components: {e}") print("Running in standalone mode with limited functionality") DataProvider = None WilliamsMarketStructure = None TradingOrchestrator = None get_config = lambda: {} # Import ANNOTATE modules annotate_dir = Path(__file__).parent.parent sys.path.insert(0, str(annotate_dir)) try: from core.annotation_manager import AnnotationManager from core.real_training_adapter import RealTrainingAdapter from core.data_loader import HistoricalDataLoader, TimeRangeManager except ImportError: # Try alternative import path import importlib.util # Load annotation_manager ann_spec = importlib.util.spec_from_file_location( "annotation_manager", annotate_dir / "core" / "annotation_manager.py" ) ann_module = importlib.util.module_from_spec(ann_spec) ann_spec.loader.exec_module(ann_module) AnnotationManager = ann_module.AnnotationManager # Load real_training_adapter (NO SIMULATION!) train_spec = importlib.util.spec_from_file_location( "real_training_adapter", annotate_dir / "core" / "real_training_adapter.py" ) train_module = importlib.util.module_from_spec(train_spec) train_spec.loader.exec_module(train_module) RealTrainingAdapter = train_module.RealTrainingAdapter # Load data_loader data_spec = importlib.util.spec_from_file_location( "data_loader", annotate_dir / "core" / "data_loader.py" ) data_module = importlib.util.module_from_spec(data_spec) data_spec.loader.exec_module(data_module) HistoricalDataLoader = data_module.HistoricalDataLoader TimeRangeManager = data_module.TimeRangeManager # Setup logging - configure before any logging occurs log_dir = Path(__file__).parent.parent / 'logs' log_dir.mkdir(exist_ok=True) log_file = log_dir / 'annotate_app.log' # Configure logging to both file and console # File mode 'w' truncates the file on each run logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(log_file, mode='w'), # Truncate on each run logging.StreamHandler(sys.stdout) # Also print to console ] ) logger = logging.getLogger(__name__) logger.info(f"Logging to: {log_file}") # Create channel-specific loggers pivot_logger = get_channel_logger(__name__, LogChannel.PIVOTS) api_logger = get_channel_logger(__name__, LogChannel.API) webui_logger = get_channel_logger(__name__, LogChannel.WEBUI) class BacktestRunner: """Runs backtest candle-by-candle with model predictions and tracks PnL""" def __init__(self): self.active_backtests = {} # backtest_id -> state self.lock = threading.Lock() def start_backtest(self, backtest_id: str, model, data_provider, symbol: str, timeframe: str, orchestrator=None, start_time: Optional[str] = None, end_time: Optional[str] = None): """Start backtest in background thread""" # Initialize backtest state state = { 'status': 'running', 'candles_processed': 0, 'total_candles': 0, 'pnl': 0.0, 'total_trades': 0, 'wins': 0, 'losses': 0, 'new_predictions': [], 'position': None, # {'type': 'long/short', 'entry_price': float, 'entry_time': str} 'error': None, 'stop_requested': False, 'orchestrator': orchestrator, 'symbol': symbol } # Clear previous predictions from orchestrator if orchestrator and hasattr(orchestrator, 'recent_transformer_predictions'): if symbol in orchestrator.recent_transformer_predictions: orchestrator.recent_transformer_predictions[symbol].clear() if symbol in orchestrator.recent_cnn_predictions: orchestrator.recent_cnn_predictions[symbol].clear() if symbol in orchestrator.recent_dqn_predictions: orchestrator.recent_dqn_predictions[symbol].clear() logger.info(f"Cleared previous predictions for backtest on {symbol}") with self.lock: self.active_backtests[backtest_id] = state # Run backtest in background thread thread = threading.Thread( target=self._run_backtest, args=(backtest_id, model, data_provider, symbol, timeframe, orchestrator, start_time, end_time) ) thread.daemon = True thread.start() def _run_backtest(self, backtest_id: str, model, data_provider, symbol: str, timeframe: str, orchestrator=None, start_time: Optional[str] = None, end_time: Optional[str] = None): """Execute backtest candle-by-candle""" try: state = self.active_backtests[backtest_id] # Get historical data logger.info(f"Backtest {backtest_id}: Fetching data for {symbol} {timeframe}") # Get candles for the time range if start_time and end_time: # Parse time range and fetch data df = data_provider.get_historical_data( symbol=symbol, timeframe=timeframe, limit=1000 # Max candles ) else: # Use last 500 candles df = data_provider.get_historical_data( symbol=symbol, timeframe=timeframe, limit=500 ) if df is None or df.empty: state['status'] = 'error' state['error'] = 'No data available' return logger.info(f"Backtest {backtest_id}: Processing {len(df)} candles") state['total_candles'] = len(df) # Prepare for inference model.eval() # IMPORTANT: Use CPU for backtest to avoid ROCm/HIP compatibility issues # GPU inference has kernel compatibility problems with some model architectures device = torch.device('cpu') model.to(device) logger.info(f"Backtest {backtest_id}: Using CPU for stable inference (avoiding ROCm/HIP issues)") # Need at least 200 candles for context min_context = 200 # Process candles one by one for i in range(min_context, len(df)): if state['stop_requested']: state['status'] = 'stopped' break # Get context (last 200 candles) context = df.iloc[i-200:i] current_candle = df.iloc[i] current_time = current_candle.name current_price = float(current_candle['close']) # Make prediction prediction = self._make_prediction(model, device, context, symbol, timeframe) if prediction: # Store prediction for display pred_data = { 'timestamp': str(current_time), 'price': current_price, 'action': prediction['action'], 'confidence': prediction['confidence'], 'timeframe': timeframe, 'current_price': current_price } state['new_predictions'].append(pred_data) # Store in orchestrator for visualization if orchestrator and hasattr(orchestrator, 'store_transformer_prediction'): # Determine model type from model class name model_type = model.__class__.__name__.lower() logger.debug(f"Backtest: Storing prediction for model type: {model_type}") # Store in appropriate prediction collection if 'transformer' in model_type: orchestrator.store_transformer_prediction(symbol, { 'timestamp': current_time, 'current_price': current_price, 'predicted_price': current_price * (1.01 if prediction['action'] == 'BUY' else 0.99), 'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0, 'confidence': prediction['confidence'], 'action': prediction['action'], 'horizon_minutes': 10 }) logger.debug(f"Backtest: Stored transformer prediction: {prediction['action']} @ {current_price}") elif 'cnn' in model_type: if hasattr(orchestrator, 'recent_cnn_predictions'): if symbol not in orchestrator.recent_cnn_predictions: from collections import deque orchestrator.recent_cnn_predictions[symbol] = deque(maxlen=50) orchestrator.recent_cnn_predictions[symbol].append({ 'timestamp': current_time, 'current_price': current_price, 'predicted_price': current_price * (1.01 if prediction['action'] == 'BUY' else 0.99), 'confidence': prediction['confidence'], 'direction': 2 if prediction['action'] == 'BUY' else 0 }) elif 'dqn' in model_type or 'rl' in model_type: if hasattr(orchestrator, 'recent_dqn_predictions'): if symbol not in orchestrator.recent_dqn_predictions: from collections import deque orchestrator.recent_dqn_predictions[symbol] = deque(maxlen=100) orchestrator.recent_dqn_predictions[symbol].append({ 'timestamp': current_time, 'current_price': current_price, 'action': prediction['action'], 'confidence': prediction['confidence'] }) # Execute trade logic self._execute_trade_logic(state, prediction, current_price, current_time) # Update progress state['candles_processed'] = i - min_context + 1 # Simulate real-time (optional, remove for faster backtest) # time.sleep(0.01) # 10ms per candle # Close any open position at end if state['position']: self._close_position(state, current_price, 'backtest_end') # Calculate final stats total_trades = state['total_trades'] wins = state['wins'] state['win_rate'] = wins / total_trades if total_trades > 0 else 0 state['status'] = 'complete' logger.info(f"Backtest {backtest_id}: Complete. PnL=${state['pnl']:.2f}, Trades={total_trades}, Win Rate={state['win_rate']:.1%}") except Exception as e: logger.error(f"Backtest {backtest_id} error: {e}", exc_info=True) state['status'] = 'error' state['error'] = str(e) def _make_prediction(self, model, device, context_df, symbol, timeframe): """Make model prediction on context data""" try: # Convert context to model input format # Extract OHLCV data candles = context_df[['open', 'high', 'low', 'close', 'volume']].values # Normalize candles_normalized = candles.copy() price_data = candles[:, :4] volume_data = candles[:, 4:5] price_min = price_data.min() price_max = price_data.max() if price_max > price_min: candles_normalized[:, :4] = (price_data - price_min) / (price_max - price_min) volume_min = volume_data.min() volume_max = volume_data.max() if volume_max > volume_min: candles_normalized[:, 4:5] = (volume_data - volume_min) / (volume_max - volume_min) # Convert to tensor [1, 200, 5] # Try GPU first, fallback to CPU if GPU fails try: price_tensor = torch.tensor(candles_normalized, dtype=torch.float32).unsqueeze(0).to(device) tech_data = torch.zeros(1, 40, dtype=torch.float32).to(device) market_data = torch.zeros(1, 30, dtype=torch.float32).to(device) use_cpu = False except Exception as gpu_error: logger.warning(f"GPU tensor creation failed, using CPU: {gpu_error}") device = torch.device('cpu') model.to(device) price_tensor = torch.tensor(candles_normalized, dtype=torch.float32).unsqueeze(0) tech_data = torch.zeros(1, 40, dtype=torch.float32) market_data = torch.zeros(1, 30, dtype=torch.float32) use_cpu = True # Make prediction with torch.no_grad(): try: outputs = model( price_data_1m=price_tensor if timeframe == '1m' else None, price_data_1s=price_tensor if timeframe == '1s' else None, price_data_1h=price_tensor if timeframe == '1h' else None, price_data_1d=price_tensor if timeframe == '1d' else None, tech_data=tech_data, market_data=market_data ) except RuntimeError as model_error: # GPU inference failed, retry on CPU if not use_cpu and 'HIP' in str(model_error): logger.warning(f"GPU inference failed, retrying on CPU: {model_error}") device = torch.device('cpu') model.to(device) price_tensor = price_tensor.cpu() tech_data = tech_data.cpu() market_data = market_data.cpu() outputs = model( price_data_1m=price_tensor if timeframe == '1m' else None, price_data_1s=price_tensor if timeframe == '1s' else None, price_data_1h=price_tensor if timeframe == '1h' else None, price_data_1d=price_tensor if timeframe == '1d' else None, tech_data=tech_data, market_data=market_data ) else: raise # Get action prediction action_probs = outputs.get('action_probs', outputs.get('trend_probs')) if action_probs is not None: action_idx = torch.argmax(action_probs, dim=-1).item() confidence = action_probs[0, action_idx].item() # Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL) actions = ['HOLD', 'BUY', 'SELL'] if action_idx < len(actions): action = actions[action_idx] else: # If 4 actions (model has 4 trend directions), map to 3 actions action = 'HOLD' if action_idx == 1 else ('BUY' if action_idx in [2, 3] else 'SELL') return { 'action': action, 'confidence': confidence } return None except Exception as e: logger.error(f"Prediction error: {e}", exc_info=True) return None def _execute_trade_logic(self, state, prediction, current_price, current_time): """Execute trading logic based on prediction""" action = prediction['action'] confidence = prediction['confidence'] # Only trade on high confidence if confidence < 0.6: return position = state['position'] if action == 'BUY' and position is None: # Enter long position state['position'] = { 'type': 'long', 'entry_price': current_price, 'entry_time': current_time } logger.debug(f"Backtest: ENTER LONG @ ${current_price}") elif action == 'SELL' and position is None: # Enter short position state['position'] = { 'type': 'short', 'entry_price': current_price, 'entry_time': current_time } logger.debug(f"Backtest: ENTER SHORT @ ${current_price}") elif position is not None: # Check if should exit should_exit = False if position['type'] == 'long' and action == 'SELL': should_exit = True elif position['type'] == 'short' and action == 'BUY': should_exit = True if should_exit: self._close_position(state, current_price, 'signal') def _close_position(self, state, exit_price, reason): """Close current position and update PnL""" position = state['position'] if not position: return entry_price = position['entry_price'] # Calculate PnL if position['type'] == 'long': pnl = exit_price - entry_price else: # short pnl = entry_price - exit_price # Update state state['pnl'] += pnl state['total_trades'] += 1 if pnl > 0: state['wins'] += 1 elif pnl < 0: state['losses'] += 1 logger.debug(f"Backtest: CLOSE {position['type'].upper()} @ ${exit_price:.2f}, PnL=${pnl:.2f} ({reason})") state['position'] = None def get_progress(self, backtest_id: str) -> Dict: """Get backtest progress""" with self.lock: state = self.active_backtests.get(backtest_id) if not state: return {'success': False, 'error': 'Backtest not found'} # Get and clear new predictions (they'll be sent to frontend) new_predictions = state['new_predictions'] state['new_predictions'] = [] return { 'success': True, 'status': state['status'], 'candles_processed': state['candles_processed'], 'total_candles': state['total_candles'], 'pnl': state['pnl'], 'total_trades': state['total_trades'], 'wins': state['wins'], 'losses': state['losses'], 'win_rate': state['wins'] / state['total_trades'] if state['total_trades'] > 0 else 0, 'new_predictions': new_predictions, 'error': state['error'] } def stop_backtest(self, backtest_id: str): """Request backtest to stop""" with self.lock: state = self.active_backtests.get(backtest_id) if state: state['stop_requested'] = True class TrainingStrategyManager: """ Manages training strategies and decisions - Separates business logic from model interface Training Modes: - 'none': No training (inference only) - 'every_candle': Train on every completed candle - 'pivots_only': Train only on pivot points (BUY at L pivots, SELL at H pivots) - 'manual': Training triggered manually by user button """ def __init__(self, data_provider, training_adapter): self.data_provider = data_provider self.training_adapter = training_adapter self.mode = 'none' # Default: no training self.dashboard = None # Set by dashboard after initialization # Statistics tracking self.stats = { 'total_trained': 0, 'by_action': {'BUY': 0, 'SELL': 0, 'HOLD': 0}, 'profitable': 0 } def should_train_on_candle(self, symbol: str, timeframe: str, candle_timestamp, pivot_markers: Dict = None) -> Tuple[bool, Optional[Dict]]: """ Decide if we should train on this candle based on current mode Args: symbol: Trading symbol timeframe: Candle timeframe candle_timestamp: Timestamp of the candle pivot_markers: Dict of pivot markers (timestamp -> pivot data) Returns: Tuple of (should_train: bool, action_data: Optional[Dict]) action_data contains: {'action': 'BUY'/'SELL'/'HOLD', 'pivot_level': int, 'pivot_strength': float} """ if self.mode == 'none': return False, None elif self.mode == 'every_candle': # Train on every candle - determine action from price movement or pivots action_data = self._get_action_for_candle(symbol, timeframe, candle_timestamp, pivot_markers) return True, action_data elif self.mode == 'pivots_only': # Train only on pivot candles return self._is_pivot_candle(candle_timestamp, pivot_markers) elif self.mode == 'manual': # Manual training - don't auto-train return False, None return False, None def _get_action_for_candle(self, symbol: str, timeframe: str, candle_timestamp, pivot_markers: Dict = None) -> Dict: """ Determine action for any candle (pivot or non-pivot) For pivot candles: BUY at L, SELL at H For non-pivot candles: Use price movement thresholds """ # First check if it's a pivot candle is_pivot, pivot_action = self._is_pivot_candle(candle_timestamp, pivot_markers) if is_pivot and pivot_action: return pivot_action # Not a pivot - use price movement based logic # Get recent candles to determine trend df = self.data_provider.get_historical_data(symbol, timeframe, limit=5) if df is None or len(df) < 3: return {'action': 'HOLD', 'reason': 'insufficient_data'} # Simple momentum: if price going up, BUY, if going down, SELL recent_change = (df.iloc[-1]['close'] - df.iloc[-3]['close']) / df.iloc[-3]['close'] if recent_change > 0.0005: # 0.05% up action = 'BUY' elif recent_change < -0.0005: # 0.05% down action = 'SELL' else: action = 'HOLD' return { 'action': action, 'reason': 'price_movement', 'change_pct': recent_change * 100 } def _is_pivot_candle(self, timestamp, pivot_markers: Dict = None) -> Tuple[bool, Optional[Dict]]: """ Check if candle is a pivot point and return action Returns: Tuple of (is_pivot: bool, action_data: Optional[Dict]) """ if not pivot_markers: return False, None candle_timestamp = str(timestamp) candle_pivots = pivot_markers.get(candle_timestamp, {}) if not candle_pivots: return False, None # BUY at L pivots (lows - support levels) if 'lows' in candle_pivots and len(candle_pivots['lows']) > 0: best_low = max(candle_pivots['lows'], key=lambda p: p.get('level', 0)) pivot_level = best_low.get('level', 1) pivot_strength = best_low.get('strength', 0.5) logger.info(f"L{pivot_level}L pivot detected @ {timestamp}, strength={pivot_strength:.2f} → BUY signal") return True, { 'action': 'BUY', 'pivot_level': pivot_level, 'pivot_strength': pivot_strength, 'reason': 'low_pivot' } # SELL at H pivots (highs - resistance levels) elif 'highs' in candle_pivots and len(candle_pivots['highs']) > 0: best_high = max(candle_pivots['highs'], key=lambda p: p.get('level', 0)) pivot_level = best_high.get('level', 1) pivot_strength = best_high.get('strength', 0.5) logger.info(f"L{pivot_level}H pivot detected @ {timestamp}, strength={pivot_strength:.2f} → SELL signal") return True, { 'action': 'SELL', 'pivot_level': pivot_level, 'pivot_strength': pivot_strength, 'reason': 'high_pivot' } return False, None def train_manually(self, symbol: str, timeframe: str, action: str) -> Dict: """ Manually trigger training with specified action Args: symbol: Trading symbol timeframe: Timeframe action: Action to train ('BUY', 'SELL', or 'HOLD') Returns: Training result dict with metrics """ logger.info(f"Manual training triggered: {action} on {symbol} {timeframe}") # Create action data action_data = { 'action': action, 'reason': 'manual_trigger' } # Update stats self.stats['total_trained'] += 1 self.stats['by_action'][action] = self.stats['by_action'].get(action, 0) + 1 return { 'success': True, 'action': action, 'triggered_by': 'manual' } def get_stats(self) -> Dict: """Get training statistics""" total = self.stats['total_trained'] if total == 0: return { 'total_trained': 0, 'by_action': {'BUY': '0%', 'SELL': '0%', 'HOLD': '0%'}, 'mode': self.mode } return { 'total_trained': total, 'by_action': { 'BUY': f"{(self.stats['by_action'].get('BUY', 0) / total * 100):.1f}%", 'SELL': f"{(self.stats['by_action'].get('SELL', 0) / total * 100):.1f}%", 'HOLD': f"{(self.stats['by_action'].get('HOLD', 0) / total * 100):.1f}%" }, 'mode': self.mode } class AnnotationDashboard: """Main annotation dashboard application""" def __init__(self): """Initialize the dashboard""" # Load configuration try: # Always try YAML loading first since get_config might not work in standalone mode import yaml with open('config.yaml', 'r') as f: self.config = yaml.safe_load(f) logger.info(f"Loaded config via YAML: {len(self.config)} keys") except Exception as e: logger.warning(f"Could not load config via YAML: {e}") try: # Fallback to get_config if available if get_config: self.config = get_config() logger.info(f"Loaded config via get_config: {len(self.config)} keys") else: raise Exception("get_config not available") except Exception as e2: logger.warning(f"Could not load config via get_config: {e2}") # Final fallback config with SOL/USDT self.config = { 'symbols': ['ETH/USDT', 'BTC/USDT', 'SOL/USDT'], 'timeframes': ['1s', '1m', '1h', '1d'] } logger.info("Using fallback config") # Initialize Flask app self.server = Flask( __name__, template_folder='templates', static_folder='static' ) # Initialize SocketIO for WebSocket support try: from flask_socketio import SocketIO, emit self.socketio = SocketIO( self.server, cors_allowed_origins="*", async_mode='threading', logger=False, engineio_logger=False ) self.has_socketio = True logger.info("SocketIO initialized for real-time updates") except ImportError: self.socketio = None self.has_socketio = False logger.warning("flask-socketio not installed - live updates will use polling") # Suppress werkzeug request logs (reduce noise from polling endpoints) werkzeug_logger = logging.getLogger('werkzeug') werkzeug_logger.setLevel(logging.WARNING) # Only show warnings and errors, not INFO # Initialize Dash app (optional component) self.app = Dash( __name__, server=self.server, url_base_pathname='/dash/', external_stylesheets=[ 'https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css', 'https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css' ] ) # Set a simple Dash layout to avoid NoLayoutException self.app.layout = html.Div([ html.H1("ANNOTATE Dashboard", className="text-center mb-4"), html.Div([ html.P("This is the Dash component of the ANNOTATE system."), html.P("The main interface is available at the Flask routes."), html.A("Go to Main Interface", href="/", className="btn btn-primary") ], className="container") ]) # Initialize core components (skip initial load for fast startup) self.data_provider = DataProvider(skip_initial_load=True) if DataProvider else None # Enable unified storage for real-time data access if self.data_provider: self._enable_unified_storage_async() # ANNOTATE doesn't need orchestrator immediately - lazy load on demand self.orchestrator = None self.models_loading = False self.available_models = ['DQN', 'CNN', 'Transformer'] # Models that CAN be loaded self.loaded_models = {} # Models that ARE loaded: {name: model_instance} # Initialize ANNOTATE components self.annotation_manager = AnnotationManager() # Use REAL training adapter - NO SIMULATION! self.training_adapter = RealTrainingAdapter(None, self.data_provider) # Initialize training strategy manager (controls training decisions) self.training_strategy = TrainingStrategyManager(self.data_provider, self.training_adapter) self.training_strategy.dashboard = self # Pass socketio to training adapter for live trade updates if self.has_socketio and self.socketio: self.training_adapter.socketio = self.socketio # Backtest runner for replaying visible chart with predictions self.backtest_runner = BacktestRunner() # Prediction cache for training: stores inference inputs/outputs to compare with actual candles # Format: {symbol: {timeframe: [{'timestamp': ts, 'inputs': {...}, 'outputs': {...}, 'norm_params': {...}}, ...]}} self.prediction_cache = {} # Check if we should auto-load a model at startup auto_load_model = os.getenv('AUTO_LOAD_MODEL', 'Transformer') # Default: Transformer if auto_load_model and auto_load_model.lower() != 'none': logger.info(f"Auto-loading model: {auto_load_model}") self._auto_load_model(auto_load_model) else: logger.info("Auto-load disabled. Models available for lazy loading: " + ", ".join(self.available_models)) # Initialize data loader with existing DataProvider self.data_loader = HistoricalDataLoader(self.data_provider) if self.data_provider else None self.time_range_manager = TimeRangeManager(self.data_loader) if self.data_loader else None # Setup routes self._setup_routes() # Start background data refresh after startup if self.data_loader: self._start_background_data_refresh() logger.info("Annotation Dashboard initialized") def _auto_load_model(self, model_name: str): """ Auto-load a model at startup in background thread Args: model_name: Name of model to load (DQN, CNN, or Transformer) """ def load_in_background(): try: logger.info(f"Starting auto-load for {model_name}...") # Initialize orchestrator if not already done if not self.orchestrator: logger.info("Initializing TradingOrchestrator...") self.orchestrator = TradingOrchestrator( data_provider=self.data_provider ) self.training_adapter.orchestrator = self.orchestrator logger.info("TradingOrchestrator initialized") # Check if the specific model is already initialized if model_name == 'Transformer': logger.info("Checking Transformer model...") if self.orchestrator.primary_transformer: self.loaded_models['Transformer'] = self.orchestrator.primary_transformer logger.info("Transformer model loaded successfully") else: logger.warning("Transformer model not initialized in orchestrator") return elif model_name == 'CNN': logger.info("Checking CNN model...") if self.orchestrator.cnn_model: self.loaded_models['CNN'] = self.orchestrator.cnn_model logger.info("CNN model loaded successfully") else: logger.warning("CNN model not initialized in orchestrator") return elif model_name == 'DQN': logger.info("Checking DQN model...") if self.orchestrator.rl_agent: self.loaded_models['DQN'] = self.orchestrator.rl_agent logger.info("DQN model loaded successfully") else: logger.warning("DQN model not initialized in orchestrator") return else: logger.warning(f"Unknown model name: {model_name}") return self.models_loading = False logger.info(f"{model_name} model ready for inference and training") except Exception as e: logger.error(f"Error auto-loading {model_name} model: {e}") import traceback logger.error(traceback.format_exc()) self.models_loading = False # Start loading in background thread self.models_loading = True thread = threading.Thread(target=load_in_background, daemon=True) thread.start() def _get_best_checkpoint_info(self, model_name: str) -> Optional[Dict]: """ Get best checkpoint info for a model without loading it First tries database, then falls back to filename parsing Args: model_name: Name of the model Returns: Dict with checkpoint info or None if no checkpoint found """ try: # Try to get from database first (has full metadata) try: from utils.database_manager import get_database_manager db_manager = get_database_manager() # Get active checkpoint for this model with db_manager._get_connection() as conn: cursor = conn.execute(""" SELECT checkpoint_id, performance_metrics, timestamp, file_path FROM checkpoint_metadata WHERE model_name = ? AND is_active = TRUE ORDER BY timestamp DESC LIMIT 1 """, (model_name.lower(),)) row = cursor.fetchone() if row: import json checkpoint_id, metrics_json, timestamp, file_path = row metrics = json.loads(metrics_json) if metrics_json else {} checkpoint_info = { 'filename': os.path.basename(file_path) if file_path else checkpoint_id, 'epoch': metrics.get('epoch', 0), 'loss': metrics.get('loss'), 'accuracy': metrics.get('accuracy'), 'source': 'database' } logger.info(f"Loaded checkpoint info from database for {model_name}: E{checkpoint_info['epoch']}, Loss={checkpoint_info['loss']}, Acc={checkpoint_info['accuracy']}") return checkpoint_info except Exception as db_error: logger.debug(f"Could not load from database: {db_error}") # Fallback to filename parsing import glob import re # Map model names to checkpoint directories checkpoint_dirs = { 'Transformer': 'models/checkpoints/transformer', 'CNN': 'models/checkpoints/enhanced_cnn', 'DQN': 'models/checkpoints/dqn_agent' } checkpoint_dir = checkpoint_dirs.get(model_name) if not checkpoint_dir: return None if not os.path.exists(checkpoint_dir): logger.debug(f"Checkpoint directory not found: {checkpoint_dir}") return None # Find all checkpoint files checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pt')) if not checkpoint_files: logger.debug(f"No checkpoint files found in {checkpoint_dir}") return None logger.debug(f"Found {len(checkpoint_files)} checkpoints for {model_name}") # Parse filenames to extract epoch info # Format: transformer_epoch5_20251110_123620.pt best_checkpoint = None best_epoch = -1 for cp_file in checkpoint_files: try: filename = os.path.basename(cp_file) # Extract epoch number from filename match = re.search(r'epoch(\d+)', filename, re.IGNORECASE) if match: epoch = int(match.group(1)) if epoch > best_epoch: best_epoch = epoch best_checkpoint = { 'filename': filename, 'epoch': epoch, 'loss': None, # Can't get without loading 'accuracy': None, # Can't get without loading 'source': 'filename' } logger.debug(f"Found checkpoint: {filename}, epoch {epoch}") except Exception as e: logger.debug(f"Could not parse checkpoint {cp_file}: {e}") continue if best_checkpoint: logger.info(f"Best checkpoint for {model_name}: {best_checkpoint['filename']} (E{best_checkpoint['epoch']})") return best_checkpoint except Exception as e: logger.error(f"Error getting checkpoint info for {model_name}: {e}") import traceback logger.error(traceback.format_exc()) return None def _load_model_lazy(self, model_name: str) -> dict: """ Lazy load a specific model on demand Args: model_name: Name of model to load ('DQN', 'CNN', 'Transformer') Returns: dict: Result with success status and message """ try: # Check if already loaded if model_name in self.loaded_models: return { 'success': True, 'message': f'{model_name} already loaded', 'already_loaded': True } # Check if model is available if model_name not in self.available_models: return { 'success': False, 'error': f'{model_name} is not in available models list' } logger.info(f"Loading {model_name} model...") # Initialize orchestrator if not already done if not self.orchestrator: if not TradingOrchestrator: return { 'success': False, 'error': 'TradingOrchestrator class not available' } logger.info("Creating TradingOrchestrator instance...") self.orchestrator = TradingOrchestrator( data_provider=self.data_provider, enhanced_rl_training=True ) logger.info("Orchestrator created") # Update training adapter self.training_adapter.orchestrator = self.orchestrator # Load specific model if model_name == 'DQN': if not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent: # Initialize RL agent self.orchestrator._initialize_rl_agent() self.loaded_models['DQN'] = self.orchestrator.rl_agent elif model_name == 'CNN': if not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model: # Initialize CNN model self.orchestrator._initialize_cnn_model() self.loaded_models['CNN'] = self.orchestrator.cnn_model elif model_name == 'Transformer': if not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer: # Initialize Transformer model self.orchestrator._initialize_transformer_model() self.loaded_models['Transformer'] = self.orchestrator.primary_transformer else: return { 'success': False, 'error': f'Unknown model: {model_name}' } logger.info(f"{model_name} model loaded successfully") return { 'success': True, 'message': f'{model_name} loaded successfully', 'loaded_models': list(self.loaded_models.keys()) } except Exception as e: logger.error(f"Error loading {model_name}: {e}") import traceback logger.error(f"Traceback:\n{traceback.format_exc()}") return { 'success': False, 'error': str(e) } def _enable_unified_storage_async(self): """Enable unified storage system in background thread""" def enable_storage(): try: import asyncio import threading loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) # Enable unified storage success = loop.run_until_complete( self.data_provider.enable_unified_storage() ) if success: logger.info(" ANNOTATE: Unified storage enabled for real-time data") # Get statistics stats = self.data_provider.get_unified_storage_stats() if stats.get('initialized'): logger.info(" Real-time data access: <10ms") logger.info(" Historical data access: <100ms") logger.info(" Annotation data: Available at any timestamp") else: logger.warning(" ANNOTATE: Unified storage not available, using cached data only") except Exception as e: logger.warning(f"ANNOTATE: Could not enable unified storage: {e}") logger.info("ANNOTATE: Continuing with cached data access") # Start in background thread import threading storage_thread = threading.Thread(target=enable_storage, daemon=True) storage_thread.start() def _start_background_data_refresh(self): """Start background task to refresh recent data after startup - ONCE ONLY""" def refresh_recent_data(): try: import time # Wait for app to fully start time.sleep(5) logger.info(" Starting one-time background data refresh (fetching only recent missing data)") # Disable startup mode to fetch fresh data self.data_loader.disable_startup_mode() # Use the new on-demand refresh method logger.info("Using on-demand refresh for recent data") self.data_provider.refresh_data_on_demand() logger.info(" One-time background data refresh completed") except Exception as e: logger.error(f"Error in background data refresh: {e}") # Start refresh in background thread import threading refresh_thread = threading.Thread(target=refresh_recent_data, daemon=True) refresh_thread.start() logger.info("One-time background data refresh scheduled") def _get_pivot_markers_for_timeframe(self, symbol: str, timeframe: str, df: pd.DataFrame) -> dict: """ Get pivot markers for a specific timeframe using WilliamsMarketStructure directly Returns dict with all pivot points and identifies which are the last high/low per level """ try: if WilliamsMarketStructure is None: logger.warning("WilliamsMarketStructure not available") return {} if df is None or len(df) < 10: logger.warning(f"Insufficient data for pivot calculation: {len(df) if df is not None else 0} bars") return {} # Convert DataFrame to numpy array format expected by Williams Market Structure ohlcv_array = df[['open', 'high', 'low', 'close', 'volume']].copy() # Add timestamp as first column (convert to milliseconds) timestamps = df.index.astype(np.int64) // 10**6 # pandas index is ns -> convert to ms ohlcv_array.insert(0, 'timestamp', timestamps) ohlcv_array = ohlcv_array.to_numpy() # Initialize Williams Market Structure with default distance # We'll override it in the calculation call williams = WilliamsMarketStructure(min_pivot_distance=1) # Calculate recursive pivot points with min_pivot_distance=2 # This ensures 5 candles per pivot (tip + 2 prev + 2 next) pivot_levels = williams.calculate_recursive_pivot_points( ohlcv_array, min_pivot_distance=2 ) if not pivot_levels: logger.debug(f"No pivot levels found for {symbol} {timeframe}") return {} # Build a map of timestamp -> pivot info # Also track last high/low per level for drawing horizontal lines pivot_map = {} last_pivots = {} # {level: {'high': (ts_str, idx), 'low': (ts_str, idx)}} # For each level (1-5), collect ALL pivot points for level_num, trend_level in pivot_levels.items(): if not hasattr(trend_level, 'pivot_points') or not trend_level.pivot_points: continue last_pivots[level_num] = {'high': None, 'low': None} # Add ALL pivot points to the map for pivot in trend_level.pivot_points: ts_str = pivot.timestamp.strftime('%Y-%m-%d %H:%M:%S') if ts_str not in pivot_map: pivot_map[ts_str] = {'highs': [], 'lows': []} pivot_info = { 'level': level_num, 'price': pivot.price, 'strength': pivot.strength, 'is_last': False # Will be updated below } if pivot.pivot_type == 'high': pivot_map[ts_str]['highs'].append(pivot_info) last_pivots[level_num]['high'] = (ts_str, len(pivot_map[ts_str]['highs']) - 1) elif pivot.pivot_type == 'low': pivot_map[ts_str]['lows'].append(pivot_info) last_pivots[level_num]['low'] = (ts_str, len(pivot_map[ts_str]['lows']) - 1) # Mark the last high and last low for each level for level_num, last_info in last_pivots.items(): if last_info['high']: ts_str, idx = last_info['high'] pivot_map[ts_str]['highs'][idx]['is_last'] = True if last_info['low']: ts_str, idx = last_info['low'] pivot_map[ts_str]['lows'][idx]['is_last'] = True pivot_logger.info(f"Found {len(pivot_map)} pivot candles for {symbol} {timeframe} (from {len(df)} candles)") return pivot_map except Exception as e: logger.error(f"Error getting pivot markers for {timeframe}: {e}") import traceback logger.error(traceback.format_exc()) return {} def _setup_routes(self): """Setup Flask routes""" @self.server.route('/favicon.ico') def favicon(): """Serve favicon to prevent 404 errors""" from flask import Response # Return a simple 1x1 transparent pixel as favicon favicon_data = b'\x00\x00\x01\x00\x01\x00\x10\x10\x00\x00\x01\x00\x20\x00\x68\x04\x00\x00\x16\x00\x00\x00\x28\x00\x00\x00\x10\x00\x00\x00\x20\x00\x00\x00\x01\x00\x20\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' return Response(favicon_data, mimetype='image/x-icon') @self.server.route('/') def index(): """Main dashboard page - loads existing annotations""" try: # Get symbols and timeframes from config symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) timeframes = self.config.get('timeframes', ['1s', '1m', '1h', '1d']) current_symbol = symbols[0] if symbols else 'ETH/USDT' # Get annotations filtered by current symbol annotations = self.annotation_manager.get_annotations(symbol=current_symbol) # Convert to serializable format annotations_data = [] for ann in annotations: if hasattr(ann, '__dict__'): ann_dict = ann.__dict__ else: ann_dict = ann # Ensure all fields are JSON serializable annotations_data.append({ 'annotation_id': ann_dict.get('annotation_id'), 'symbol': ann_dict.get('symbol'), 'timeframe': ann_dict.get('timeframe'), 'entry': ann_dict.get('entry'), 'exit': ann_dict.get('exit'), 'direction': ann_dict.get('direction'), 'profit_loss_pct': ann_dict.get('profit_loss_pct'), 'notes': ann_dict.get('notes', ''), 'created_at': ann_dict.get('created_at') }) logger.info(f"Loading dashboard with {len(annotations_data)} annotations for {current_symbol}") # Prepare template data template_data = { 'current_symbol': current_symbol, 'symbols': symbols, 'timeframes': timeframes, 'annotations': annotations_data } return render_template('annotation_dashboard.html', **template_data) except Exception as e: logger.error(f"Error rendering main page: {e}") # Fallback simple HTML page return f""" ANNOTATE - Manual Trade Annotation UI

📝 ANNOTATE - Manual Trade Annotation UI

System Status

Annotation Manager: Active

Data Provider: {'Available' if self.data_provider else 'Not Available (Standalone Mode)'}

Trading Orchestrator: {'Available' if self.orchestrator else 'Not Available (Standalone Mode)'}

Available Features

  • Manual trade annotation
  • Test case generation
  • Annotation export
  • Real model training

API Endpoints

  • POST /api/chart-data - Get chart data
  • POST /api/save-annotation - Save annotation
  • POST /api/delete-annotation - Delete annotation
  • POST /api/generate-test-case - Generate test case
  • POST /api/export-annotations - Export annotations
Go to Dash Interface
""" @self.server.route('/api/recalculate-pivots', methods=['POST']) def recalculate_pivots(): """Recalculate pivot points for merged data using cached data from data_loader""" try: data = request.get_json() symbol = data.get('symbol', 'ETH/USDT') timeframe = data.get('timeframe') # We don't use timestamps/ohlcv from frontend anymore, we use our own consistent data source if not timeframe: return jsonify({ 'success': False, 'error': {'code': 'INVALID_REQUEST', 'message': 'Missing timeframe'} }) pivot_logger.info(f"Recalculating pivots for {symbol} {timeframe} using backend data") if not self.data_loader: return jsonify({ 'success': False, 'error': {'code': 'DATA_LOADER_UNAVAILABLE', 'message': 'Data loader not available'} }) # Fetch latest data from data_loader (which should have the updated cache/DB from previous calls) # We get enough history for proper pivot calculation df = self.data_loader.get_data( symbol=symbol, timeframe=timeframe, limit=2500, # Enough for context direction='latest' ) if df is None or df.empty: logger.warning(f"No data found for {symbol} {timeframe} to recalculate pivots") return jsonify({ 'success': True, 'pivot_markers': {} }) # Recalculate pivot markers pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df) pivot_logger.info(f"Recalculated {len(pivot_markers)} pivot candles") return jsonify({ 'success': True, 'pivot_markers': pivot_markers }) except Exception as e: logger.error(f"Error recalculating pivots: {e}") return jsonify({ 'success': False, 'error': {'code': 'RECALC_ERROR', 'message': str(e)} }) @self.server.route('/api/chart-data', methods=['POST']) def get_chart_data(): """Get chart data for specified symbol and timeframes with infinite scroll support""" try: data = request.get_json() symbol = data.get('symbol', 'ETH/USDT') timeframes = data.get('timeframes', ['1s', '1m', '1h', '1d']) start_time_str = data.get('start_time') end_time_str = data.get('end_time') limit = data.get('limit', 2500) # Default 2500 candles for training direction = data.get('direction', 'latest') # 'latest', 'before', or 'after' webui_logger.info(f"Chart data request: {symbol} {timeframes} direction={direction} limit={limit}") if start_time_str: webui_logger.info(f" start_time: {start_time_str}") if end_time_str: webui_logger.info(f" end_time: {end_time_str}") if not self.data_loader: return jsonify({ 'success': False, 'error': { 'code': 'DATA_LOADER_UNAVAILABLE', 'message': 'Data loader not available' } }) # Parse time strings if provided start_time = datetime.fromisoformat(start_time_str.replace('Z', '+00:00')) if start_time_str else None end_time = datetime.fromisoformat(end_time_str.replace('Z', '+00:00')) if end_time_str else None # Fetch data for each timeframe using data loader # This will automatically: # 1. Check DuckDB first # 2. Fetch from API if not in cache # 3. Store in DuckDB for future use chart_data = {} for timeframe in timeframes: df = self.data_loader.get_data( symbol=symbol, timeframe=timeframe, start_time=start_time, end_time=end_time, limit=limit, direction=direction ) if df is not None and not df.empty: webui_logger.info(f" {timeframe}: {len(df)} candles ({df.index[0]} to {df.index[-1]})") # Get pivot points for this timeframe (only if we have enough context) pivot_markers = {} if len(df) >= 50: pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df) # Convert to format suitable for Plotly chart_data[timeframe] = { 'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(), 'open': df['open'].tolist(), 'high': df['high'].tolist(), 'low': df['low'].tolist(), 'close': df['close'].tolist(), 'volume': df['volume'].tolist(), 'pivot_markers': pivot_markers # Optional: only present if pivots exist } else: logger.warning(f" {timeframe}: No data returned") # Get pivot bounds for the symbol pivot_bounds = None if self.data_provider: try: pivot_bounds = self.data_provider.get_pivot_bounds(symbol) if pivot_bounds: logger.info(f"Found pivot bounds for {symbol}: {len(pivot_bounds.pivot_support_levels)} support, {len(pivot_bounds.pivot_resistance_levels)} resistance") except Exception as e: logger.error(f"Error getting pivot bounds: {e}") return jsonify({ 'success': True, 'chart_data': chart_data, 'pivot_bounds': { 'support_levels': pivot_bounds.pivot_support_levels if pivot_bounds else [], 'resistance_levels': pivot_bounds.pivot_resistance_levels if pivot_bounds else [], 'price_range': { 'min': pivot_bounds.price_min if pivot_bounds else None, 'max': pivot_bounds.price_max if pivot_bounds else None }, 'volume_range': { 'min': pivot_bounds.volume_min if pivot_bounds else None, 'max': pivot_bounds.volume_max if pivot_bounds else None }, 'timeframe': '1m', # Pivot bounds are calculated from 1m data 'period': '30 days', # Monthly data 'total_levels': len(pivot_bounds.pivot_support_levels) + len(pivot_bounds.pivot_resistance_levels) if pivot_bounds else 0 } if pivot_bounds else None }) except Exception as e: logger.error(f"Error fetching chart data: {e}") return jsonify({ 'success': False, 'error': { 'code': 'CHART_DATA_ERROR', 'message': str(e) } }) @self.server.route('/api/save-annotation', methods=['POST']) def save_annotation(): """Save a new annotation with full market context""" try: data = request.get_json() # Capture market state at entry and exit times using data provider entry_market_state = {} exit_market_state = {} if self.data_provider: try: # Parse timestamps entry_time = datetime.fromisoformat(data['entry']['timestamp'].replace('Z', '+00:00')) exit_time = datetime.fromisoformat(data['exit']['timestamp'].replace('Z', '+00:00')) # Use the new data provider method to get market state at entry time entry_market_state = self.data_provider.get_market_state_at_time( symbol=data['symbol'], timestamp=entry_time, context_window_minutes=5 ) # Use the new data provider method to get market state at exit time exit_market_state = self.data_provider.get_market_state_at_time( symbol=data['symbol'], timestamp=exit_time, context_window_minutes=5 ) logger.info(f"Captured market state: {len(entry_market_state)} timeframes at entry, {len(exit_market_state)} at exit") except Exception as e: logger.error(f"Error capturing market state: {e}") import traceback traceback.print_exc() # Create annotation with market context annotation = self.annotation_manager.create_annotation( entry_point=data['entry'], exit_point=data['exit'], symbol=data['symbol'], timeframe=data['timeframe'], entry_market_state=entry_market_state, exit_market_state=exit_market_state ) # Collect market snapshots for SQLite storage market_snapshots = {} if self.data_loader: try: # Get OHLCV data for all timeframes around the annotation time entry_time = datetime.fromisoformat(data['entry']['timestamp'].replace('Z', '+00:00')) exit_time = datetime.fromisoformat(data['exit']['timestamp'].replace('Z', '+00:00')) # Get data from 5 minutes before entry to 5 minutes after exit start_time = entry_time - timedelta(minutes=5) end_time = exit_time + timedelta(minutes=5) for timeframe in ['1s', '1m', '1h', '1d']: df = self.data_loader.get_data( symbol=data['symbol'], timeframe=timeframe, start_time=start_time, end_time=end_time, limit=1500 ) if df is not None and not df.empty: market_snapshots[timeframe] = df logger.info(f"Collected {len(market_snapshots)} timeframes for annotation storage") except Exception as e: logger.error(f"Error collecting market snapshots: {e}") # Save annotation with market snapshots self.annotation_manager.save_annotation( annotation=annotation, market_snapshots=market_snapshots ) # Automatically generate test case with ±5min data try: test_case = self.annotation_manager.generate_test_case( annotation, data_provider=self.data_provider, auto_save=True ) # Log test case details market_state = test_case.get('market_state', {}) timeframes_with_data = [k for k in market_state.keys() if k.startswith('ohlcv_')] logger.info(f"Auto-generated test case: {test_case['test_case_id']}") logger.info(f" Timeframes: {timeframes_with_data}") for tf_key in timeframes_with_data: candle_count = len(market_state[tf_key].get('timestamps', [])) logger.info(f" {tf_key}: {candle_count} candles") if 'training_labels' in market_state: logger.info(f" Training labels: {len(market_state['training_labels'].get('labels_1m', []))} labels") except Exception as e: logger.error(f"Failed to auto-generate test case: {e}") import traceback traceback.print_exc() return jsonify({ 'success': True, 'annotation': annotation.__dict__ if hasattr(annotation, '__dict__') else annotation }) except Exception as e: logger.error(f"Error saving annotation: {e}") return jsonify({ 'success': False, 'error': { 'code': 'SAVE_ANNOTATION_ERROR', 'message': str(e) } }) @self.server.route('/api/delete-annotation', methods=['POST']) def delete_annotation(): """Delete an annotation""" try: data = request.get_json() annotation_id = data['annotation_id'] # Delete annotation and check if it was found deleted = self.annotation_manager.delete_annotation(annotation_id) if deleted: return jsonify({ 'success': True, 'message': 'Annotation deleted successfully' }) else: return jsonify({ 'success': False, 'error': { 'code': 'ANNOTATION_NOT_FOUND', 'message': f'Annotation {annotation_id} not found' } }) except Exception as e: logger.error(f"Error deleting annotation: {e}", exc_info=True) return jsonify({ 'success': False, 'error': { 'code': 'DELETE_ANNOTATION_ERROR', 'message': str(e) } }) @self.server.route('/api/clear-all-annotations', methods=['POST']) def clear_all_annotations(): """Clear all annotations""" try: data = request.get_json() or {} symbol = data.get('symbol', None) # Use the efficient clear_all_annotations method deleted_count = self.annotation_manager.clear_all_annotations(symbol=symbol) if deleted_count == 0: return jsonify({ 'success': True, 'deleted_count': 0, 'message': 'No annotations to clear' }) logger.info(f"Cleared {deleted_count} annotations" + (f" for symbol {symbol}" if symbol else "")) return jsonify({ 'success': True, 'deleted_count': deleted_count, 'message': f'Cleared {deleted_count} annotations' }) except Exception as e: logger.error(f"Error clearing all annotations: {e}") import traceback logger.error(traceback.format_exc()) return jsonify({ 'success': False, 'error': { 'code': 'CLEAR_ALL_ANNOTATIONS_ERROR', 'message': str(e) } }) @self.server.route('/api/refresh-data', methods=['POST']) def refresh_data(): """Refresh chart data from data provider""" try: data = request.get_json() symbol = data.get('symbol', 'ETH/USDT') timeframes = data.get('timeframes', ['1s', '1m', '1h', '1d']) logger.info(f"Refreshing data for {symbol} with timeframes: {timeframes}") # Force refresh data from data provider chart_data = {} if self.data_provider: for timeframe in timeframes: try: # Force refresh by setting refresh=True df = self.data_provider.get_historical_data( symbol=symbol, timeframe=timeframe, limit=1000, refresh=True ) if df is not None and not df.empty: # Get pivot markers for this timeframe pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df) chart_data[timeframe] = { 'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(), 'open': df['open'].tolist(), 'high': df['high'].tolist(), 'low': df['low'].tolist(), 'close': df['close'].tolist(), 'volume': df['volume'].tolist(), 'pivot_markers': pivot_markers # Optional: only present if pivots exist } logger.info(f"Refreshed {timeframe}: {len(df)} candles") else: logger.warning(f"No data available for {timeframe}") except Exception as e: logger.error(f"Error refreshing {timeframe} data: {e}") # Get pivot bounds for the symbol pivot_bounds = None if self.data_provider: try: pivot_bounds = self.data_provider.get_pivot_bounds(symbol) if pivot_bounds: logger.info(f"Found pivot bounds for {symbol}: {len(pivot_bounds.pivot_support_levels)} support, {len(pivot_bounds.pivot_resistance_levels)} resistance") except Exception as e: logger.error(f"Error getting pivot bounds: {e}") return jsonify({ 'success': True, 'chart_data': chart_data, 'pivot_bounds': { 'support_levels': pivot_bounds.pivot_support_levels if pivot_bounds else [], 'resistance_levels': pivot_bounds.pivot_resistance_levels if pivot_bounds else [], 'price_range': { 'min': pivot_bounds.price_min if pivot_bounds else None, 'max': pivot_bounds.price_max if pivot_bounds else None }, 'volume_range': { 'min': pivot_bounds.volume_min if pivot_bounds else None, 'max': pivot_bounds.volume_max if pivot_bounds else None }, 'timeframe': '1m', # Pivot bounds are calculated from 1m data 'period': '30 days', # Monthly data 'total_levels': len(pivot_bounds.pivot_support_levels) + len(pivot_bounds.pivot_resistance_levels) if pivot_bounds else 0 } if pivot_bounds else None, 'message': f'Refreshed data for {symbol}' }) except Exception as e: logger.error(f"Error refreshing data: {e}") return jsonify({ 'success': False, 'error': { 'code': 'REFRESH_DATA_ERROR', 'message': str(e) } }) @self.server.route('/api/generate-test-case', methods=['POST']) def generate_test_case(): """Generate test case from annotation""" try: data = request.get_json() annotation_id = data['annotation_id'] # Get annotation annotations = self.annotation_manager.get_annotations() annotation = next((a for a in annotations if (a.annotation_id if hasattr(a, 'annotation_id') else a.get('annotation_id')) == annotation_id), None) if not annotation: return jsonify({ 'success': False, 'error': { 'code': 'ANNOTATION_NOT_FOUND', 'message': 'Annotation not found' } }) # Generate test case with market context test_case = self.annotation_manager.generate_test_case( annotation, data_provider=self.data_provider ) return jsonify({ 'success': True, 'test_case': test_case }) except Exception as e: logger.error(f"Error generating test case: {e}") return jsonify({ 'success': False, 'error': { 'code': 'GENERATE_TESTCASE_ERROR', 'message': str(e) } }) @self.server.route('/api/get-annotations', methods=['POST']) def get_annotations_api(): """Get annotations filtered by symbol""" try: data = request.get_json() symbol = data.get('symbol', 'ETH/USDT') # Get annotations for this symbol annotations = self.annotation_manager.get_annotations(symbol=symbol) # Convert to serializable format annotations_data = [] for ann in annotations: if hasattr(ann, '__dict__'): ann_dict = ann.__dict__ else: ann_dict = ann annotations_data.append({ 'annotation_id': ann_dict.get('annotation_id'), 'symbol': ann_dict.get('symbol'), 'timeframe': ann_dict.get('timeframe'), 'entry': ann_dict.get('entry'), 'exit': ann_dict.get('exit'), 'direction': ann_dict.get('direction'), 'profit_loss_pct': ann_dict.get('profit_loss_pct'), 'notes': ann_dict.get('notes', ''), 'created_at': ann_dict.get('created_at') }) logger.info(f"Returning {len(annotations_data)} annotations for {symbol}") return jsonify({ 'success': True, 'annotations': annotations_data, 'symbol': symbol, 'count': len(annotations_data) }) except Exception as e: logger.error(f"Error getting annotations: {e}") return jsonify({ 'success': False, 'error': str(e) }) @self.server.route('/api/export-annotations', methods=['POST']) def export_annotations(): """Export annotations to file""" try: data = request.get_json() symbol = data.get('symbol') format_type = data.get('format', 'json') # Get annotations annotations = self.annotation_manager.get_annotations(symbol=symbol) # Export to file output_path = self.annotation_manager.export_annotations( annotations=annotations, format_type=format_type ) return send_file(output_path, as_attachment=True) except Exception as e: logger.error(f"Error exporting annotations: {e}") return jsonify({ 'success': False, 'error': { 'code': 'EXPORT_ERROR', 'message': str(e) } }) @self.server.route('/api/train-model', methods=['POST']) def train_model(): """Start model training with annotations""" try: if not self.training_adapter: return jsonify({ 'success': False, 'error': { 'code': 'TRAINING_UNAVAILABLE', 'message': 'Real training adapter not available' } }) data = request.get_json() model_name = data['model_name'] annotation_ids = data.get('annotation_ids', []) # CRITICAL: Get current symbol to filter annotations current_symbol = data.get('symbol', 'ETH/USDT') # Get primary timeframe for display (optional) timeframe = data.get('timeframe', '1m') # If no specific annotations provided, use all for current symbol if not annotation_ids: annotations = self.annotation_manager.get_annotations(symbol=current_symbol) annotation_ids = [ a.annotation_id if hasattr(a, 'annotation_id') else a.get('annotation_id') for a in annotations ] logger.info(f"Using all {len(annotation_ids)} annotations for {current_symbol}") # Load test cases from disk (they were auto-generated when annotations were saved) # Filter by current symbol to avoid cross-symbol training all_test_cases = self.annotation_manager.get_all_test_cases(symbol=current_symbol) # Filter to selected annotations test_cases = [ tc for tc in all_test_cases if tc['test_case_id'].replace('annotation_', '') in annotation_ids ] if not test_cases: return jsonify({ 'success': False, 'error': { 'code': 'NO_TEST_CASES', 'message': f'No test cases found for {len(annotation_ids)} annotations' } }) logger.info(f"Starting REAL training with {len(test_cases)} test cases ({len(annotation_ids)} annotations) for model {model_name} on {timeframe}") # Start REAL training (NO SIMULATION!) training_id = self.training_adapter.start_training( model_name=model_name, test_cases=test_cases, annotation_count=len(annotation_ids), timeframe=timeframe ) return jsonify({ 'success': True, 'training_id': training_id, 'test_cases_count': len(test_cases) }) except Exception as e: logger.error(f"Error starting training: {e}") return jsonify({ 'success': False, 'error': { 'code': 'TRAINING_ERROR', 'message': str(e) } }) @self.server.route('/api/training-progress', methods=['POST']) def get_training_progress(): """Get training progress""" try: if not self.training_adapter: return jsonify({ 'success': False, 'error': { 'code': 'TRAINING_UNAVAILABLE', 'message': 'Real training adapter not available' } }) data = request.get_json() training_id = data['training_id'] progress = self.training_adapter.get_training_progress(training_id) return jsonify({ 'success': True, 'progress': progress }) except Exception as e: logger.error(f"Error getting training progress: {e}") return jsonify({ 'success': False, 'error': { 'code': 'PROGRESS_ERROR', 'message': str(e) } }) # Backtest API Endpoints @self.server.route('/api/backtest', methods=['POST']) def start_backtest(): """Start backtest on visible chart data""" try: data = request.get_json() model_name = data['model_name'] symbol = data['symbol'] timeframe = data['timeframe'] start_time = data.get('start_time') end_time = data.get('end_time') # Get the loaded model if model_name not in self.loaded_models: return jsonify({ 'success': False, 'error': f'Model {model_name} not loaded. Please load it first.' }) model = self.loaded_models[model_name] # Generate backtest ID backtest_id = str(uuid.uuid4()) # Start backtest in background self.backtest_runner.start_backtest( backtest_id=backtest_id, model=model, data_provider=self.data_provider, symbol=symbol, timeframe=timeframe, orchestrator=self.orchestrator, start_time=start_time, end_time=end_time ) # Get initial state progress = self.backtest_runner.get_progress(backtest_id) return jsonify({ 'success': True, 'backtest_id': backtest_id, 'total_candles': progress.get('total_candles', 0) }) except Exception as e: logger.error(f"Error starting backtest: {e}", exc_info=True) return jsonify({ 'success': False, 'error': str(e) }) @self.server.route('/api/backtest/progress/', methods=['GET']) def get_backtest_progress(backtest_id): """Get backtest progress""" try: progress = self.backtest_runner.get_progress(backtest_id) return jsonify(progress) except Exception as e: logger.error(f"Error getting backtest progress: {e}") return jsonify({ 'success': False, 'error': str(e) }) @self.server.route('/api/backtest/stop', methods=['POST']) def stop_backtest(): """Stop running backtest""" try: data = request.get_json() backtest_id = data['backtest_id'] self.backtest_runner.stop_backtest(backtest_id) return jsonify({ 'success': True }) except Exception as e: logger.error(f"Error stopping backtest: {e}") return jsonify({ 'success': False, 'error': str(e) }) @self.server.route('/api/active-training', methods=['GET']) def get_active_training(): """ Get currently active training session (if any) Allows UI to resume tracking after page reload or across multiple clients """ try: if not self.training_adapter: return jsonify({ 'success': False, 'active': False, 'error': { 'code': 'TRAINING_UNAVAILABLE', 'message': 'Real training adapter not available' } }) active_session = self.training_adapter.get_active_training_session() if active_session: return jsonify({ 'success': True, 'active': True, 'session': active_session }) else: return jsonify({ 'success': True, 'active': False }) except Exception as e: logger.error(f"Error getting active training: {e}") return jsonify({ 'success': False, 'active': False, 'error': { 'code': 'ACTIVE_TRAINING_ERROR', 'message': str(e) } }) # Live Training API Endpoints @self.server.route('/api/live-training/start', methods=['POST']) def start_live_training(): """Start live inference and training mode""" try: if not self.orchestrator: return jsonify({ 'success': False, 'error': 'Orchestrator not available' }), 500 if self.orchestrator.start_live_training(): return jsonify({ 'success': True, 'status': 'started', 'message': 'Live training mode started' }) else: return jsonify({ 'success': False, 'error': 'Failed to start live training' }), 500 except Exception as e: logger.error(f"Error starting live training: {e}") return jsonify({ 'success': False, 'error': str(e) }), 500 @self.server.route('/api/live-training/stop', methods=['POST']) def stop_live_training(): """Stop live inference and training mode""" try: if not self.orchestrator: return jsonify({ 'success': False, 'error': 'Orchestrator not available' }), 500 if self.orchestrator.stop_live_training(): return jsonify({ 'success': True, 'status': 'stopped', 'message': 'Live training mode stopped' }) else: return jsonify({ 'success': False, 'error': 'Failed to stop live training' }), 500 except Exception as e: logger.error(f"Error stopping live training: {e}") return jsonify({ 'success': False, 'error': str(e) }), 500 @self.server.route('/api/live-training/status', methods=['GET']) def get_live_training_status(): """Get live training status and statistics""" try: if not self.orchestrator: return jsonify({ 'success': False, 'active': False, 'error': 'Orchestrator not available' }) is_active = self.orchestrator.is_live_training_active() stats = self.orchestrator.get_live_training_stats() if is_active else {} return jsonify({ 'success': True, 'active': is_active, 'stats': stats }) except Exception as e: logger.error(f"Error getting live training status: {e}") return jsonify({ 'success': False, 'active': False, 'error': str(e) }) @self.server.route('/api/available-models', methods=['GET']) def get_available_models(): """Get list of available models with their load status""" try: # Ensure self.available_models is a list if not isinstance(self.available_models, list): logger.warning(f"self.available_models is not a list: {type(self.available_models)}. Resetting to default.") self.available_models = ['Transformer', 'COB_RL', 'CNN', 'DQN'] # Ensure self.loaded_models exists (it's a dict) if not hasattr(self, 'loaded_models'): self.loaded_models = {} # Build model state dict with checkpoint info logger.info(f"Building model states for {len(self.available_models)} models: {self.available_models}") logger.info(f"Currently loaded models: {list(self.loaded_models.keys())}") model_states = [] for model_name in self.available_models: # Check if model is in loaded_models dict is_loaded = model_name in self.loaded_models and self.loaded_models[model_name] is not None # Get checkpoint info (even for unloaded models) checkpoint_info = None # If loaded, get from orchestrator if is_loaded and self.orchestrator: checkpoint_attr = f"{model_name.lower()}_checkpoint_info" if hasattr(self.orchestrator, checkpoint_attr): cp_info = getattr(self.orchestrator, checkpoint_attr) if cp_info and cp_info.get('status') == 'loaded': checkpoint_info = { 'filename': cp_info.get('filename', 'unknown'), 'epoch': cp_info.get('epoch', 0), 'loss': cp_info.get('loss', 0.0), 'accuracy': cp_info.get('accuracy', 0.0), 'loaded_at': cp_info.get('loaded_at', ''), 'source': 'loaded' } # If not loaded, try to read best checkpoint from disk (filename parsing only) if not checkpoint_info: try: cp_info = self._get_best_checkpoint_info(model_name) if cp_info: checkpoint_info = cp_info checkpoint_info['source'] = 'disk' except Exception as e: logger.warning(f"Could not read checkpoint for {model_name}: {e}") # Continue without checkpoint info - not critical model_states.append({ 'name': model_name, 'loaded': is_loaded, 'can_train': is_loaded, 'can_infer': is_loaded, 'checkpoint': checkpoint_info # Checkpoint metadata (loaded or from disk) }) logger.info(f"Returning {len(model_states)} model states") return jsonify({ 'success': True, 'models': model_states, 'loaded_count': len(self.loaded_models), 'available_count': len(self.available_models) }) except Exception as e: logger.error(f"Error getting available models: {e}") import traceback logger.error(f"Traceback: {traceback.format_exc()}") # Return a fallback list so the UI doesn't hang return jsonify({ 'success': True, 'models': [ {'name': 'Transformer', 'loaded': False, 'can_train': False, 'can_infer': False}, {'name': 'COB_RL', 'loaded': False, 'can_train': False, 'can_infer': False} ], 'loaded_count': 0, 'available_count': 2, 'error': str(e) }) @self.server.route('/api/load-model', methods=['POST']) def load_model(): """Load a specific model on demand""" try: data = request.get_json() model_name = data.get('model_name') if not model_name: return jsonify({ 'success': False, 'error': 'model_name is required' }) # Load the model result = self._load_model_lazy(model_name) return jsonify(result) except Exception as e: logger.error(f"Error in load_model endpoint: {e}") return jsonify({ 'success': False, 'error': str(e) }) @self.server.route('/api/realtime-inference/start', methods=['POST']) def start_realtime_inference(): """Start real-time inference mode with configurable training strategy""" try: data = request.get_json() model_name = data.get('model_name') symbol = data.get('symbol', 'ETH/USDT') timeframe = data.get('timeframe', '1m') # New unified training_mode parameter training_mode = data.get('training_mode', 'none') # 'none', 'every_candle', 'pivots_only', 'manual' # Backward compatibility with old parameters if 'enable_live_training' in data or 'train_every_candle' in data: enable_live_training = data.get('enable_live_training', False) train_every_candle = data.get('train_every_candle', False) training_mode = 'every_candle' if train_every_candle else ('pivots_only' if enable_live_training else 'none') if not self.training_adapter: return jsonify({ 'success': False, 'error': { 'code': 'TRAINING_UNAVAILABLE', 'message': 'Real training adapter not available' } }) # Set training mode on strategy manager self.training_strategy.mode = training_mode logger.info(f"Training strategy mode set to: {training_mode}") # Start real-time inference - pass strategy manager for training decisions inference_id = self.training_adapter.start_realtime_inference( model_name=model_name, symbol=symbol, data_provider=self.data_provider, enable_live_training=(training_mode != 'none'), train_every_candle=(training_mode == 'every_candle'), timeframe=timeframe, training_strategy=self.training_strategy # Pass strategy manager ) return jsonify({ 'success': True, 'inference_id': inference_id, 'training_mode': training_mode, 'timeframe': timeframe }) except Exception as e: logger.error(f"Error starting real-time inference: {e}") return jsonify({ 'success': False, 'error': { 'code': 'INFERENCE_START_ERROR', 'message': str(e) } }) @self.server.route('/api/realtime-inference/stop', methods=['POST']) def stop_realtime_inference(): """Stop real-time inference mode""" try: data = request.get_json() inference_id = data.get('inference_id') if not self.training_adapter: return jsonify({ 'success': False, 'error': { 'code': 'TRAINING_UNAVAILABLE', 'message': 'Real training adapter not available' } }) self.training_adapter.stop_realtime_inference(inference_id) return jsonify({ 'success': True }) except Exception as e: logger.error(f"Error stopping real-time inference: {e}") return jsonify({ 'success': False, 'error': { 'code': 'INFERENCE_STOP_ERROR', 'message': str(e) } }) @self.server.route('/api/live-updates', methods=['POST']) def get_live_updates(): """Get live chart and prediction updates (polling endpoint)""" try: data = request.get_json() symbol = data.get('symbol', 'ETH/USDT') timeframe = data.get('timeframe', '1m') response = { 'success': True, 'chart_update': None, 'prediction': None } # Get latest candle for the requested timeframe if self.orchestrator and self.orchestrator.data_provider: try: # Get latest candle ohlcv_data = self.orchestrator.data_provider.get_ohlcv_data(symbol, timeframe, limit=1) if ohlcv_data and len(ohlcv_data) > 0: latest_candle = ohlcv_data[-1] response['chart_update'] = { 'symbol': symbol, 'timeframe': timeframe, 'candle': { 'timestamp': latest_candle[0], 'open': float(latest_candle[1]), 'high': float(latest_candle[2]), 'low': float(latest_candle[3]), 'close': float(latest_candle[4]), 'volume': float(latest_candle[5]) } } except Exception as e: logger.debug(f"Error getting latest candle: {e}") # Get latest model predictions if self.orchestrator: try: # Get latest predictions from orchestrator predictions = {} # DQN predictions if hasattr(self.orchestrator, 'recent_dqn_predictions') and symbol in self.orchestrator.recent_dqn_predictions: dqn_preds = list(self.orchestrator.recent_dqn_predictions[symbol]) if dqn_preds: predictions['dqn'] = dqn_preds[-1] # CNN predictions if hasattr(self.orchestrator, 'recent_cnn_predictions') and symbol in self.orchestrator.recent_cnn_predictions: cnn_preds = list(self.orchestrator.recent_cnn_predictions[symbol]) if cnn_preds: predictions['cnn'] = cnn_preds[-1] # Transformer predictions with next_candles for ghost candles # First check if there are stored predictions from the inference loop if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions: transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol]) if transformer_preds: # Convert any remaining tensors to Python types before JSON serialization transformer_pred = transformer_preds[-1].copy() predictions['transformer'] = self._serialize_prediction(transformer_pred) if predictions: response['prediction'] = predictions except Exception as e: logger.debug(f"Error getting predictions: {e}") import traceback logger.debug(traceback.format_exc()) return jsonify(response) except Exception as e: logger.error(f"Error in live updates: {e}") return jsonify({ 'success': False, 'error': str(e) }) @self.server.route('/api/realtime-inference/signals', methods=['GET']) def get_realtime_signals(): """Get latest real-time inference signals""" try: if not self.training_adapter: return jsonify({ 'success': False, 'error': { 'code': 'TRAINING_UNAVAILABLE', 'message': 'Real training adapter not available' } }) signals = self.training_adapter.get_latest_signals() # Get metrics from active inference sessions metrics = {'accuracy': 0.0, 'loss': 0.0} if hasattr(self.training_adapter, 'inference_sessions'): for session in self.training_adapter.inference_sessions.values(): if 'metrics' in session: metrics = session['metrics'] break return jsonify({ 'success': True, 'signals': signals, 'metrics': metrics }) except Exception as e: logger.error(f"Error getting signals: {e}") return jsonify({ 'success': False, 'error': { 'code': 'SIGNALS_ERROR', 'message': str(e) } }) @self.server.route('/api/realtime-inference/train-manual', methods=['POST']) def train_manual(): """Manually trigger training on current candle with specified action""" try: data = request.get_json() inference_id = data.get('inference_id') action = data.get('action', 'HOLD') if not self.training_adapter: return jsonify({ 'success': False, 'error': 'Training adapter not available' }) # Get active inference session if not hasattr(self.training_adapter, 'inference_sessions'): return jsonify({ 'success': False, 'error': 'No active inference sessions' }) session = self.training_adapter.inference_sessions.get(inference_id) if not session: return jsonify({ 'success': False, 'error': 'Inference session not found' }) # Set pending action for training session['pending_action'] = action # Get session parameters symbol = session.get('symbol', 'ETH/USDT') timeframe = session.get('timeframe', '1m') data_provider = session.get('data_provider') # Call training method train_result = self.training_adapter._train_on_new_candle( session, symbol, timeframe, data_provider ) if train_result.get('success'): return jsonify({ 'success': True, 'action': action, 'metrics': { 'loss': train_result.get('loss', 0.0), 'accuracy': train_result.get('accuracy', 0.0), 'training_steps': train_result.get('training_steps', 0) } }) else: return jsonify({ 'success': False, 'error': train_result.get('error', 'Training failed') }) except Exception as e: logger.error(f"Error in manual training: {e}") return jsonify({ 'success': False, 'error': str(e) }) # WebSocket event handlers (if SocketIO is available) if self.has_socketio: self._setup_websocket_handlers() def _serialize_prediction(self, prediction: Dict) -> Dict: """Convert PyTorch tensors in prediction dict to JSON-serializable Python types""" try: import torch serialized = {} for key, value in prediction.items(): if isinstance(value, torch.Tensor): if value.numel() == 1: # Scalar tensor serialized[key] = value.item() else: # Multi-element tensor serialized[key] = value.detach().cpu().tolist() elif isinstance(value, dict): serialized[key] = self._serialize_prediction(value) # Recursive elif isinstance(value, (list, tuple)): serialized[key] = [ v.item() if isinstance(v, torch.Tensor) and v.numel() == 1 else (v.detach().cpu().tolist() if isinstance(v, torch.Tensor) else v) for v in value ] else: serialized[key] = value return serialized except Exception as e: logger.warning(f"Error serializing prediction: {e}") # Fallback: return as-is (might fail JSON serialization but won't crash) return prediction def _setup_websocket_handlers(self): """Setup WebSocket event handlers for real-time updates""" if not self.has_socketio: return @self.socketio.on('connect') def handle_connect(): """Handle client connection""" logger.info(f"WebSocket client connected") from flask_socketio import emit emit('connection_response', {'status': 'connected', 'message': 'Connected to ANNOTATE live updates'}) @self.socketio.on('disconnect') def handle_disconnect(): """Handle client disconnection""" logger.info(f"WebSocket client disconnected") @self.socketio.on('subscribe_live_updates') def handle_subscribe(data): """Subscribe to live chart and prediction updates""" from flask_socketio import emit, join_room symbol = data.get('symbol', 'ETH/USDT') timeframe = data.get('timeframe', '1s') room = f"{symbol}_{timeframe}" join_room(room) logger.info(f"Client subscribed to live updates: {room}") emit('subscription_confirmed', {'room': room, 'symbol': symbol, 'timeframe': timeframe}) # Start live update thread if not already running if not hasattr(self, '_live_update_thread') or not self._live_update_thread.is_alive(): self._start_live_update_thread() @self.socketio.on('request_prediction') def handle_prediction_request(data): """Handle manual prediction request""" from flask_socketio import emit try: symbol = data.get('symbol', 'ETH/USDT') timeframe = data.get('timeframe', '1s') prediction_steps = data.get('prediction_steps', 1) # Get prediction from model prediction = self._get_live_prediction(symbol, timeframe, prediction_steps) emit('prediction_update', prediction) except Exception as e: logger.error(f"Error handling prediction request: {e}") emit('prediction_error', {'error': str(e)}) @self.socketio.on('prediction_accuracy') def handle_prediction_accuracy(data): """ Handle validated prediction accuracy - trigger incremental training This is called when frontend validates a prediction against actual candle. We use this data to incrementally train the model for continuous improvement. """ from flask_socketio import emit try: timeframe = data.get('timeframe') timestamp = data.get('timestamp') predicted = data.get('predicted') # [O, H, L, C, V] actual = data.get('actual') # [O, H, L, C] errors = data.get('errors') # {open, high, low, close} pct_errors = data.get('pctErrors') direction_correct = data.get('directionCorrect') accuracy = data.get('accuracy') if not all([timeframe, timestamp, predicted, actual]): logger.warning("Incomplete prediction accuracy data received") return logger.info(f"[{timeframe}] Prediction validated: {accuracy:.1f}% accuracy, direction: {direction_correct}") logger.debug(f" Errors: O={pct_errors['open']:.2f}% H={pct_errors['high']:.2f}% L={pct_errors['low']:.2f}% C={pct_errors['close']:.2f}%") # Trigger incremental training on this validated prediction self._train_on_validated_prediction( timeframe=timeframe, timestamp=timestamp, predicted=predicted, actual=actual, errors=errors, direction_correct=direction_correct, accuracy=accuracy ) # Send confirmation back to frontend emit('training_update', { 'status': 'training_triggered', 'timestamp': timestamp, 'accuracy': accuracy, 'message': f'Incremental training triggered on validated prediction' }) except Exception as e: logger.error(f"Error handling prediction accuracy: {e}", exc_info=True) emit('training_error', {'error': str(e)}) def _start_live_update_thread(self): """Start background thread for live updates""" import threading def live_update_worker(): """Background worker for live updates""" import time from flask_socketio import emit logger.info("Live update thread started") while True: try: # Get active rooms (symbol_timeframe combinations) # For now, update all subscribed clients every second # Get latest chart data if self.data_provider: for symbol in ['ETH/USDT', 'BTC/USDT']: # TODO: Get from active subscriptions for timeframe in ['1s', '1m']: room = f"{symbol}_{timeframe}" # Get latest candles (need last 2 to determine confirmation status) try: candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=2) if candles and len(candles) > 0: latest_candle = candles[-1] # Determine if candle is confirmed (closed) # For 1s: candle is confirmed when next candle starts (2s delay) # For others: candle is confirmed when next candle starts is_confirmed = len(candles) >= 2 # If we have 2 candles, the first is confirmed # Format timestamp consistently timestamp = latest_candle.get('timestamp') if isinstance(timestamp, str): # Already formatted formatted_timestamp = timestamp else: # Convert to ISO string then format from datetime import datetime if isinstance(timestamp, datetime): formatted_timestamp = timestamp.strftime('%Y-%m-%d %H:%M:%S') else: formatted_timestamp = str(timestamp) # Emit chart update with full candle data self.socketio.emit('chart_update', { 'symbol': symbol, 'timeframe': timeframe, 'candle': { 'timestamp': formatted_timestamp, 'open': float(latest_candle.get('open', 0)), 'high': float(latest_candle.get('high', 0)), 'low': float(latest_candle.get('low', 0)), 'close': float(latest_candle.get('close', 0)), 'volume': float(latest_candle.get('volume', 0)) }, 'is_confirmed': is_confirmed, # True if this candle is closed/confirmed 'has_previous': len(candles) >= 2 # True if we have previous candle for validation }, room=room) # Get prediction if model is loaded if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer'): prediction = self._get_live_prediction(symbol, timeframe, 1) if prediction: self.socketio.emit('prediction_update', prediction, room=room) except Exception as e: logger.debug(f"Error getting data for {symbol} {timeframe}: {e}") time.sleep(1) # Update every second except Exception as e: logger.error(f"Error in live update thread: {e}") time.sleep(5) # Wait longer on error self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True) self._live_update_thread.start() def _get_live_transformer_prediction(self, symbol: str = 'ETH/USDT'): """ Generate live transformer prediction with next_candles for ghost candle display """ try: if not self.orchestrator: logger.debug("No orchestrator - cannot generate predictions") return None if not hasattr(self.orchestrator, 'primary_transformer'): logger.debug("Orchestrator has no primary_transformer - enable training first") return None transformer = self.orchestrator.primary_transformer if not transformer: logger.debug("primary_transformer is None - model not loaded yet") return None transformer.eval() # Get recent market data price_data_1s = self.data_provider.get_ohlcv(symbol, '1s', limit=200) if self.data_provider else None price_data_1m = self.data_provider.get_ohlcv(symbol, '1m', limit=150) if self.data_provider else None price_data_1h = self.data_provider.get_ohlcv(symbol, '1h', limit=24) if self.data_provider else None price_data_1d = self.data_provider.get_ohlcv(symbol, '1d', limit=14) if self.data_provider else None btc_data_1m = self.data_provider.get_ohlcv('BTC/USDT', '1m', limit=150) if self.data_provider else None if not price_data_1m or len(price_data_1m) < 10: return None import torch import numpy as np device = next(transformer.parameters()).device def ohlcv_to_tensor(data, limit=None): if not data: return None data = data[-limit:] if limit and len(data) > limit else data arr = np.array([[d['open'], d['high'], d['low'], d['close'], d['volume']] for d in data], dtype=np.float32) return torch.from_numpy(arr).unsqueeze(0).to(device) inputs = { 'price_data_1s': ohlcv_to_tensor(price_data_1s, 200), 'price_data_1m': ohlcv_to_tensor(price_data_1m, 150), 'price_data_1h': ohlcv_to_tensor(price_data_1h, 24), 'price_data_1d': ohlcv_to_tensor(price_data_1d, 14), 'btc_data_1m': ohlcv_to_tensor(btc_data_1m, 150) } # Forward pass with torch.no_grad(): outputs = transformer(**inputs) # Extract next_candles next_candles = outputs.get('next_candles', {}) if not next_candles: return None # Convert to JSON-serializable format predicted_candle = {} for tf, candle_tensor in next_candles.items(): if candle_tensor is not None: candle_values = candle_tensor.squeeze(0).cpu().numpy().tolist() predicted_candle[tf] = candle_values current_price = price_data_1m[-1]['close'] predicted_1m_close = predicted_candle.get('1m', [0,0,0,current_price,0])[3] price_change = (predicted_1m_close - current_price) / current_price if price_change > 0.001: action = 'BUY' elif price_change < -0.001: action = 'SELL' else: action = 'HOLD' confidence = 0.7 if 'confidence' in outputs: conf_tensor = outputs['confidence'] confidence = float(conf_tensor.squeeze(0).cpu().numpy()[0]) prediction = { 'timestamp': datetime.now().isoformat(), 'symbol': symbol, 'action': action, 'confidence': confidence, 'predicted_price': predicted_1m_close, 'current_price': current_price, 'price_change': price_change, 'predicted_candle': predicted_candle, # This is what frontend needs! 'type': 'transformer_prediction' } # Store for tracking self.orchestrator.store_transformer_prediction(symbol, prediction) logger.debug(f"Generated transformer prediction with {len(predicted_candle)} timeframes for ghost candles") return prediction except Exception as e: logger.error(f"Error generating live transformer prediction: {e}", exc_info=True) return None def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list, actual: list, errors: dict, direction_correct: bool, accuracy: float): """ Incrementally train model on validated prediction This implements online learning where each validated prediction becomes a training sample, with loss weighting based on prediction accuracy. """ try: if not self.training_adapter: logger.warning("Training adapter not available for incremental training") return if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'): logger.warning("Transformer model not available for incremental training") return # Get the transformer trainer trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None) if not trainer: logger.warning("Transformer trainer not available") return # Calculate sample weight based on accuracy # Low accuracy predictions get higher weight (we need to learn from mistakes) # High accuracy predictions get lower weight (model already knows this) if accuracy < 50: sample_weight = 3.0 # Learn hard from bad predictions elif accuracy < 70: sample_weight = 2.0 # Moderate learning elif accuracy < 85: sample_weight = 1.0 # Normal learning else: sample_weight = 0.5 # Light touch-up for good predictions # Also weight by direction correctness if not direction_correct: sample_weight *= 1.5 # Wrong direction is critical - learn more logger.info(f"[{timeframe}] Incremental training: accuracy={accuracy:.1f}%, weight={sample_weight:.1f}x") # Create training sample from validated prediction # We need to fetch the market state at that timestamp symbol = 'ETH/USDT' # TODO: Get from active trading pair training_sample = { 'symbol': symbol, 'timestamp': timestamp, 'predicted_candle': predicted, # [O, H, L, C, V] 'actual_candle': actual, # [O, H, L, C] 'errors': errors, 'accuracy': accuracy, 'direction_correct': direction_correct, 'sample_weight': sample_weight } # Get market state at that timestamp try: market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe) training_sample['market_state'] = market_state except Exception as e: logger.warning(f"Could not fetch market state: {e}") return # Convert to transformer batch format batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe) if not batch: logger.warning("Could not convert validated prediction to training batch") return # Train on this batch with sample weighting with torch.enable_grad(): trainer.model.train() result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight) if result: loss = result.get('total_loss', 0) candle_accuracy = result.get('candle_accuracy', 0) logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}") # Save checkpoint periodically (every 10 incremental steps) if not hasattr(self, '_incremental_training_steps'): self._incremental_training_steps = 0 self._incremental_training_steps += 1 if self._incremental_training_steps % 10 == 0: logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps") trainer.save_checkpoint( filepath=None, # Auto-generate path metadata={ 'training_type': 'incremental_online', 'steps': self._incremental_training_steps, 'last_accuracy': accuracy } ) except Exception as e: logger.error(f"Error in incremental training: {e}", exc_info=True) def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict: """Fetch market state at a specific timestamp for training""" try: from datetime import datetime import pandas as pd # Parse timestamp ts = pd.Timestamp(timestamp) # Get historical data for multiple timeframes market_state = {'timeframes': {}, 'secondary_timeframes': {}} for tf in ['1s', '1m', '1h']: try: df = self.data_provider.get_historical_data(symbol, tf, limit=200) if df is not None and not df.empty: # Find data up to (but not including) the target timestamp df_before = df[df.index < ts] if not df_before.empty: recent = df_before.tail(200) market_state['timeframes'][tf] = { 'timestamps': recent.index.strftime('%Y-%m-%d %H:%M:%S').tolist(), 'open': recent['open'].tolist(), 'high': recent['high'].tolist(), 'low': recent['low'].tolist(), 'close': recent['close'].tolist(), 'volume': recent['volume'].tolist() } except Exception as e: logger.warning(f"Could not fetch {tf} data: {e}") return market_state except Exception as e: logger.error(f"Error fetching market state: {e}") return {} def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1): """ Get live prediction from model using trainer inference Caches inference data (inputs/outputs) for later training when actual candle arrives. This allows us to: 1. Compare predicted vs actual candle values 2. Calculate loss 3. Do backpropagation with correct outputs Returns: Dict with prediction results including predicted_candle for ghost candle display """ try: if not self.orchestrator: return None # Get trainer from orchestrator trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None) if not trainer or not trainer.model: logger.debug("No transformer trainer available for live prediction") return None # Get market data using training adapter's method (reuses existing logic) if not hasattr(self.training_adapter, '_get_realtime_market_data'): logger.warning("Training adapter missing _get_realtime_market_data method") return None market_data, norm_params = self.training_adapter._get_realtime_market_data(symbol, self.data_provider) if not market_data: logger.debug(f"No market data available for {symbol} {timeframe}") return None # Make prediction with model import torch timestamp = datetime.now(timezone.utc) with torch.no_grad(): trainer.model.eval() outputs = trainer.model(**market_data) # Extract action prediction action_probs = outputs.get('action_probs') if action_probs is None: logger.debug("No action_probs in model output") return None action_idx = torch.argmax(action_probs, dim=-1).item() confidence = action_probs[0, action_idx].item() # Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL) actions = ['HOLD', 'BUY', 'SELL'] action = actions[action_idx] if action_idx < len(actions) else 'HOLD' # Extract predicted candles and denormalize predicted_candles_raw = {} if 'next_candles' in outputs: for tf, tensor in outputs['next_candles'].items(): predicted_candles_raw[tf] = tensor.detach().cpu().numpy().tolist() # Denormalize predicted candles predicted_candles_denorm = {} if predicted_candles_raw and norm_params: for tf, raw_candle in predicted_candles_raw.items(): if tf in norm_params: params = norm_params[tf] price_min = params['price_min'] price_max = params['price_max'] vol_min = params['volume_min'] vol_max = params['volume_max'] # raw_candle is [1, 5] list candle_values = raw_candle[0] denorm_candle = [ candle_values[0] * (price_max - price_min) + price_min, # Open candle_values[1] * (price_max - price_min) + price_min, # High candle_values[2] * (price_max - price_min) + price_min, # Low candle_values[3] * (price_max - price_min) + price_min, # Close candle_values[4] * (vol_max - vol_min) + vol_min # Volume ] predicted_candles_denorm[tf] = denorm_candle # Get predicted price from candle close predicted_price = None if timeframe in predicted_candles_denorm: predicted_price = predicted_candles_denorm[timeframe][3] # Close elif '1m' in predicted_candles_denorm: predicted_price = predicted_candles_denorm['1m'][3] elif '1s' in predicted_candles_denorm: predicted_price = predicted_candles_denorm['1s'][3] # CACHE inference data for later training # Store inputs, outputs, and normalization params so we can train when actual candle arrives if symbol not in self.prediction_cache: self.prediction_cache[symbol] = {} if timeframe not in self.prediction_cache[symbol]: self.prediction_cache[symbol][timeframe] = [] # Store cached inference data (convert tensors to CPU for storage) cached_data = { 'timestamp': timestamp, 'symbol': symbol, 'timeframe': timeframe, 'model_inputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v for k, v in market_data.items()}, 'model_outputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v for k, v in outputs.items()}, 'normalization_params': norm_params, 'predicted_candle': predicted_candles_denorm.get(timeframe), 'prediction_steps': prediction_steps } self.prediction_cache[symbol][timeframe].append(cached_data) # Keep only last 100 predictions per symbol/timeframe to prevent memory bloat if len(self.prediction_cache[symbol][timeframe]) > 100: self.prediction_cache[symbol][timeframe] = self.prediction_cache[symbol][timeframe][-100:] logger.debug(f"Cached prediction for {symbol} {timeframe} @ {timestamp.isoformat()}") # Return prediction result (same format as before for compatibility) return { 'symbol': symbol, 'timeframe': timeframe, 'timestamp': timestamp.isoformat(), 'action': action, 'confidence': confidence, 'predicted_price': predicted_price, 'predicted_candle': predicted_candles_denorm, 'prediction_steps': prediction_steps } except Exception as e: logger.error(f"Error getting live prediction: {e}") import traceback logger.debug(traceback.format_exc()) return None def get_cached_predictions_for_training(self, symbol: str, timeframe: str, actual_candle_timestamp) -> List[Dict]: """ Retrieve cached predictions that match a specific candle timestamp for training When an actual candle arrives, we can: 1. Find cached predictions made before this candle 2. Compare predicted vs actual candle values 3. Calculate loss and do backpropagation Args: symbol: Trading symbol timeframe: Timeframe actual_candle_timestamp: Timestamp of the actual candle that just arrived Returns: List of cached prediction dicts that should be trained on """ try: if symbol not in self.prediction_cache: return [] if timeframe not in self.prediction_cache[symbol]: return [] # Find predictions made before this candle timestamp # Predictions should be for candles that have now completed matching_predictions = [] actual_time = actual_candle_timestamp if isinstance(actual_candle_timestamp, datetime) else datetime.fromisoformat(str(actual_candle_timestamp).replace('Z', '+00:00')) for cached_pred in self.prediction_cache[symbol][timeframe]: pred_time = cached_pred['timestamp'] if isinstance(pred_time, str): pred_time = datetime.fromisoformat(pred_time.replace('Z', '+00:00')) # Prediction should be for a candle that comes after the prediction time # We match predictions that were made before the actual candle closed if pred_time < actual_time: matching_predictions.append(cached_pred) return matching_predictions except Exception as e: logger.error(f"Error getting cached predictions for training: {e}") return [] def clear_old_cached_predictions(self, symbol: str, timeframe: str, before_timestamp: datetime): """ Clear cached predictions older than a certain timestamp Useful for cleaning up old predictions that are no longer needed """ try: if symbol not in self.prediction_cache: return if timeframe not in self.prediction_cache[symbol]: return self.prediction_cache[symbol][timeframe] = [ pred for pred in self.prediction_cache[symbol][timeframe] if pred['timestamp'] >= before_timestamp ] except Exception as e: logger.debug(f"Error clearing old cached predictions: {e}") def run(self, host='127.0.0.1', port=8051, debug=False): """Run the application""" logger.info(f"Starting Annotation Dashboard on http://{host}:{port}") if self.has_socketio: logger.info("Running with WebSocket support (SocketIO)") self.socketio.run(self.server, host=host, port=port, debug=debug, allow_unsafe_werkzeug=True) else: logger.warning("Running without WebSocket support - install flask-socketio for live updates") self.server.run(host=host, port=port, debug=debug) def main(): """Main entry point""" logger.info("=" * 80) logger.info("ANNOTATE Application Starting") logger.info(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") logger.info("=" * 80) # Print logging channel configuration from utils.logging_config import print_channel_status print_channel_status() dashboard = AnnotationDashboard() dashboard.run(debug=True) if __name__ == '__main__': main()