"""
Manual Trade Annotation UI - Main Application
A web-based interface for manually marking profitable buy/sell signals on historical
market data to generate training test cases for machine learning models.
"""
import os
import sys
from pathlib import Path
# Add parent directory to path for imports
parent_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(parent_dir))
from flask import Flask, render_template, request, jsonify, send_file
from dash import Dash, html
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, List, Any
import json
import pandas as pd
import numpy as np
import threading
import uuid
import time
import torch
# Import core components from main system
try:
from core.data_provider import DataProvider
from core.orchestrator import TradingOrchestrator
from core.config import get_config
from core.williams_market_structure import WilliamsMarketStructure
except ImportError as e:
print(f"Warning: Could not import main system components: {e}")
print("Running in standalone mode with limited functionality")
DataProvider = None
WilliamsMarketStructure = None
TradingOrchestrator = None
get_config = lambda: {}
# Import ANNOTATE modules
annotate_dir = Path(__file__).parent.parent
sys.path.insert(0, str(annotate_dir))
try:
from core.annotation_manager import AnnotationManager
from core.real_training_adapter import RealTrainingAdapter
from core.data_loader import HistoricalDataLoader, TimeRangeManager
except ImportError:
# Try alternative import path
import importlib.util
# Load annotation_manager
ann_spec = importlib.util.spec_from_file_location(
"annotation_manager",
annotate_dir / "core" / "annotation_manager.py"
)
ann_module = importlib.util.module_from_spec(ann_spec)
ann_spec.loader.exec_module(ann_module)
AnnotationManager = ann_module.AnnotationManager
# Load real_training_adapter (NO SIMULATION!)
train_spec = importlib.util.spec_from_file_location(
"real_training_adapter",
annotate_dir / "core" / "real_training_adapter.py"
)
train_module = importlib.util.module_from_spec(train_spec)
train_spec.loader.exec_module(train_module)
RealTrainingAdapter = train_module.RealTrainingAdapter
# Load data_loader
data_spec = importlib.util.spec_from_file_location(
"data_loader",
annotate_dir / "core" / "data_loader.py"
)
data_module = importlib.util.module_from_spec(data_spec)
data_spec.loader.exec_module(data_module)
HistoricalDataLoader = data_module.HistoricalDataLoader
TimeRangeManager = data_module.TimeRangeManager
# Setup logging - configure before any logging occurs
log_dir = Path(__file__).parent.parent / 'logs'
log_dir.mkdir(exist_ok=True)
log_file = log_dir / 'annotate_app.log'
# Configure logging to both file and console
# File mode 'w' truncates the file on each run
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_file, mode='w'), # Truncate on each run
logging.StreamHandler(sys.stdout) # Also print to console
]
)
logger = logging.getLogger(__name__)
logger.info(f"Logging to: {log_file}")
class BacktestRunner:
"""Runs backtest candle-by-candle with model predictions and tracks PnL"""
def __init__(self):
self.active_backtests = {} # backtest_id -> state
self.lock = threading.Lock()
def start_backtest(self, backtest_id: str, model, data_provider, symbol: str, timeframe: str,
start_time: Optional[str] = None, end_time: Optional[str] = None):
"""Start backtest in background thread"""
# Initialize backtest state
state = {
'status': 'running',
'candles_processed': 0,
'total_candles': 0,
'pnl': 0.0,
'total_trades': 0,
'wins': 0,
'losses': 0,
'new_predictions': [],
'position': None, # {'type': 'long/short', 'entry_price': float, 'entry_time': str}
'error': None,
'stop_requested': False
}
with self.lock:
self.active_backtests[backtest_id] = state
# Run backtest in background thread
thread = threading.Thread(
target=self._run_backtest,
args=(backtest_id, model, data_provider, symbol, timeframe, start_time, end_time)
)
thread.daemon = True
thread.start()
def _run_backtest(self, backtest_id: str, model, data_provider, symbol: str, timeframe: str,
start_time: Optional[str] = None, end_time: Optional[str] = None):
"""Execute backtest candle-by-candle"""
try:
state = self.active_backtests[backtest_id]
# Get historical data
logger.info(f"Backtest {backtest_id}: Fetching data for {symbol} {timeframe}")
# Get candles for the time range
if start_time and end_time:
# NOTE: the requested start/end range is not applied here yet; we simply fetch the most recent candles
df = data_provider.get_historical_data(
symbol=symbol,
timeframe=timeframe,
limit=1000 # Max candles
)
else:
# Use last 500 candles
df = data_provider.get_historical_data(
symbol=symbol,
timeframe=timeframe,
limit=500
)
if df is None or df.empty:
state['status'] = 'error'
state['error'] = 'No data available'
return
logger.info(f"Backtest {backtest_id}: Processing {len(df)} candles")
state['total_candles'] = len(df)
# Prepare for inference
model.eval()
# IMPORTANT: Use CPU for backtest to avoid ROCm/HIP compatibility issues
# GPU inference has kernel compatibility problems with some model architectures
device = torch.device('cpu')
model.to(device)
logger.info(f"Backtest {backtest_id}: Using CPU for stable inference (avoiding ROCm/HIP issues)")
# Need at least 200 candles for context
min_context = 200
# Process candles one by one
for i in range(min_context, len(df)):
if state['stop_requested']:
state['status'] = 'stopped'
break
# Get context (last 200 candles)
context = df.iloc[i-200:i]
current_candle = df.iloc[i]
current_time = current_candle.name
current_price = float(current_candle['close'])
# Make prediction
prediction = self._make_prediction(model, device, context, symbol, timeframe)
if prediction:
# Store prediction for display
pred_data = {
'timestamp': str(current_time),
'price': current_price,
'action': prediction['action'],
'confidence': prediction['confidence'],
'timeframe': timeframe
}
state['new_predictions'].append(pred_data)
# Execute trade logic
self._execute_trade_logic(state, prediction, current_price, current_time)
# Update progress
state['candles_processed'] = i - min_context + 1
# Simulate real-time (optional, remove for faster backtest)
# time.sleep(0.01) # 10ms per candle
# Close any open position at end
if state['position']:
self._close_position(state, current_price, 'backtest_end')
# Calculate final stats
total_trades = state['total_trades']
wins = state['wins']
state['win_rate'] = wins / total_trades if total_trades > 0 else 0
state['status'] = 'complete'
logger.info(f"Backtest {backtest_id}: Complete. PnL=${state['pnl']:.2f}, Trades={total_trades}, Win Rate={state['win_rate']:.1%}")
except Exception as e:
logger.error(f"Backtest {backtest_id} error: {e}", exc_info=True)
state['status'] = 'error'
state['error'] = str(e)
def _make_prediction(self, model, device, context_df, symbol, timeframe):
"""Make model prediction on context data"""
try:
# Convert context to model input format
# Extract OHLCV data
candles = context_df[['open', 'high', 'low', 'close', 'volume']].values
# Normalize
candles_normalized = candles.copy()
price_data = candles[:, :4]
volume_data = candles[:, 4:5]
price_min = price_data.min()
price_max = price_data.max()
if price_max > price_min:
candles_normalized[:, :4] = (price_data - price_min) / (price_max - price_min)
volume_min = volume_data.min()
volume_max = volume_data.max()
if volume_max > volume_min:
candles_normalized[:, 4:5] = (volume_data - volume_min) / (volume_max - volume_min)
# Convert to tensor [1, 200, 5]
# Create tensors on the selected device; fall back to CPU if tensor creation fails
try:
price_tensor = torch.tensor(candles_normalized, dtype=torch.float32).unsqueeze(0).to(device)
tech_data = torch.zeros(1, 40, dtype=torch.float32).to(device)
market_data = torch.zeros(1, 30, dtype=torch.float32).to(device)
use_cpu = False
except Exception as gpu_error:
logger.warning(f"GPU tensor creation failed, using CPU: {gpu_error}")
device = torch.device('cpu')
model.to(device)
price_tensor = torch.tensor(candles_normalized, dtype=torch.float32).unsqueeze(0)
tech_data = torch.zeros(1, 40, dtype=torch.float32)
market_data = torch.zeros(1, 30, dtype=torch.float32)
use_cpu = True
# Make prediction
with torch.no_grad():
try:
outputs = model(
price_data_1m=price_tensor if timeframe == '1m' else None,
price_data_1s=price_tensor if timeframe == '1s' else None,
price_data_1h=price_tensor if timeframe == '1h' else None,
price_data_1d=price_tensor if timeframe == '1d' else None,
tech_data=tech_data,
market_data=market_data
)
except RuntimeError as model_error:
# GPU inference failed, retry on CPU
if not use_cpu and 'HIP' in str(model_error):
logger.warning(f"GPU inference failed, retrying on CPU: {model_error}")
device = torch.device('cpu')
model.to(device)
price_tensor = price_tensor.cpu()
tech_data = tech_data.cpu()
market_data = market_data.cpu()
outputs = model(
price_data_1m=price_tensor if timeframe == '1m' else None,
price_data_1s=price_tensor if timeframe == '1s' else None,
price_data_1h=price_tensor if timeframe == '1h' else None,
price_data_1d=price_tensor if timeframe == '1d' else None,
tech_data=tech_data,
market_data=market_data
)
else:
raise
# Get action prediction
action_probs = outputs.get('action_probs', outputs.get('trend_probs'))
if action_probs is not None:
action_idx = torch.argmax(action_probs, dim=-1).item()
confidence = action_probs[0, action_idx].item()
# Map to BUY/SELL/HOLD
actions = ['BUY', 'SELL', 'HOLD']
if action_idx < len(actions):
action = actions[action_idx]
else:
# If 4 actions (model has 4 trend directions), map to 3 actions
action = 'HOLD' if action_idx == 1 else ('BUY' if action_idx in [2, 3] else 'SELL')
return {
'action': action,
'confidence': confidence
}
return None
except Exception as e:
logger.error(f"Prediction error: {e}", exc_info=True)
return None
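# _make_prediction returns None on failure, otherwise a dict of the form
# {'action': 'BUY' | 'SELL' | 'HOLD', 'confidence': <float in [0, 1]>},
# which _execute_trade_logic consumes below.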
def _execute_trade_logic(self, state, prediction, current_price, current_time):
"""Execute trading logic based on prediction"""
action = prediction['action']
confidence = prediction['confidence']
# Only trade on high confidence
if confidence < 0.6:
return
position = state['position']
if action == 'BUY' and position is None:
# Enter long position
state['position'] = {
'type': 'long',
'entry_price': current_price,
'entry_time': current_time
}
logger.debug(f"Backtest: ENTER LONG @ ${current_price}")
elif action == 'SELL' and position is None:
# Enter short position
state['position'] = {
'type': 'short',
'entry_price': current_price,
'entry_time': current_time
}
logger.debug(f"Backtest: ENTER SHORT @ ${current_price}")
elif position is not None:
# Check if should exit
should_exit = False
if position['type'] == 'long' and action == 'SELL':
should_exit = True
elif position['type'] == 'short' and action == 'BUY':
should_exit = True
if should_exit:
self._close_position(state, current_price, 'signal')
def _close_position(self, state, exit_price, reason):
"""Close current position and update PnL"""
position = state['position']
if not position:
return
entry_price = position['entry_price']
# Calculate PnL
if position['type'] == 'long':
pnl = exit_price - entry_price
else: # short
pnl = entry_price - exit_price
# Update state
state['pnl'] += pnl
state['total_trades'] += 1
if pnl > 0:
state['wins'] += 1
elif pnl < 0:
state['losses'] += 1
logger.debug(f"Backtest: CLOSE {position['type'].upper()} @ ${exit_price:.2f}, PnL=${pnl:.2f} ({reason})")
state['position'] = None
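# Worked example (illustrative numbers): a long entered at 3000.00 and closed at
# 3012.50 adds +12.50 to state['pnl']; the same prices on a short add -12.50.
# PnL is tracked per unit of the asset, with no fees or position sizing applied.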
def get_progress(self, backtest_id: str) -> Dict:
"""Get backtest progress"""
with self.lock:
state = self.active_backtests.get(backtest_id)
if not state:
return {'success': False, 'error': 'Backtest not found'}
# Get and clear new predictions (they'll be sent to frontend)
new_predictions = state['new_predictions']
state['new_predictions'] = []
return {
'success': True,
'status': state['status'],
'candles_processed': state['candles_processed'],
'total_candles': state['total_candles'],
'pnl': state['pnl'],
'total_trades': state['total_trades'],
'wins': state['wins'],
'losses': state['losses'],
'win_rate': state['wins'] / state['total_trades'] if state['total_trades'] > 0 else 0,
'new_predictions': new_predictions,
'error': state['error']
}
def stop_backtest(self, backtest_id: str):
"""Request backtest to stop"""
with self.lock:
state = self.active_backtests.get(backtest_id)
if state:
state['stop_requested'] = True
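# Usage sketch (illustrative; in practice the runner is driven by the /api/backtest
# routes below, and 'model' / 'data_provider' come from the dashboard instance):
#
#   runner = BacktestRunner()
#   bt_id = str(uuid.uuid4())
#   runner.start_backtest(bt_id, model, data_provider, symbol='ETH/USDT', timeframe='1m')
#   progress = runner.get_progress(bt_id)   # poll until progress['status'] in ('complete', 'error', 'stopped')
#   runner.stop_backtest(bt_id)             # optional early stop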
class AnnotationDashboard:
"""Main annotation dashboard application"""
def __init__(self):
"""Initialize the dashboard"""
# Load configuration
try:
# Always try YAML loading first since get_config might not work in standalone mode
import yaml
with open('config.yaml', 'r') as f:
self.config = yaml.safe_load(f)
logger.info(f"Loaded config via YAML: {len(self.config)} keys")
except Exception as e:
logger.warning(f"Could not load config via YAML: {e}")
try:
# Fallback to get_config if available
if get_config:
self.config = get_config()
logger.info(f"Loaded config via get_config: {len(self.config)} keys")
else:
raise Exception("get_config not available")
except Exception as e2:
logger.warning(f"Could not load config via get_config: {e2}")
# Final fallback config with SOL/USDT
self.config = {
'symbols': ['ETH/USDT', 'BTC/USDT', 'SOL/USDT'],
'timeframes': ['1s', '1m', '1h', '1d']
}
logger.info("Using fallback config")
# Initialize Flask app
self.server = Flask(
__name__,
template_folder='templates',
static_folder='static'
)
# Initialize SocketIO for WebSocket support
try:
from flask_socketio import SocketIO, emit
self.socketio = SocketIO(
self.server,
cors_allowed_origins="*",
async_mode='threading',
logger=False,
engineio_logger=False
)
self.has_socketio = True
logger.info("SocketIO initialized for real-time updates")
except ImportError:
self.socketio = None
self.has_socketio = False
logger.warning("flask-socketio not installed - live updates will use polling")
# Suppress werkzeug request logs (reduce noise from polling endpoints)
werkzeug_logger = logging.getLogger('werkzeug')
werkzeug_logger.setLevel(logging.WARNING) # Only show warnings and errors, not INFO
# Initialize Dash app (optional component)
self.app = Dash(
__name__,
server=self.server,
url_base_pathname='/dash/',
external_stylesheets=[
'https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css',
'https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css'
]
)
# Set a simple Dash layout to avoid NoLayoutException
self.app.layout = html.Div([
html.H1("ANNOTATE Dashboard", className="text-center mb-4"),
html.Div([
html.P("This is the Dash component of the ANNOTATE system."),
html.P("The main interface is available at the Flask routes."),
html.A("Go to Main Interface", href="/", className="btn btn-primary")
], className="container")
])
# Initialize core components (skip initial load for fast startup)
self.data_provider = DataProvider(skip_initial_load=True) if DataProvider else None
# Enable unified storage for real-time data access
if self.data_provider:
self._enable_unified_storage_async()
# ANNOTATE doesn't need orchestrator immediately - lazy load on demand
self.orchestrator = None
self.models_loading = False
self.available_models = ['DQN', 'CNN', 'Transformer'] # Models that CAN be loaded
self.loaded_models = {} # Models that ARE loaded: {name: model_instance}
# Initialize ANNOTATE components
self.annotation_manager = AnnotationManager()
# Use REAL training adapter - NO SIMULATION!
self.training_adapter = RealTrainingAdapter(None, self.data_provider)
# Backtest runner for replaying visible chart with predictions
self.backtest_runner = BacktestRunner()
# Don't auto-load models - wait for user to click LOAD button
logger.info("Models available for lazy loading: " + ", ".join(self.available_models))
# Initialize data loader with existing DataProvider
self.data_loader = HistoricalDataLoader(self.data_provider) if self.data_provider else None
self.time_range_manager = TimeRangeManager(self.data_loader) if self.data_loader else None
# Setup routes
self._setup_routes()
# Start background data refresh after startup
if self.data_loader:
self._start_background_data_refresh()
logger.info("Annotation Dashboard initialized")
def _get_best_checkpoint_info(self, model_name: str) -> Optional[Dict]:
"""
Get best checkpoint info for a model without loading it
First tries database, then falls back to filename parsing
Args:
model_name: Name of the model
Returns:
Dict with checkpoint info or None if no checkpoint found
"""
try:
# Try to get from database first (has full metadata)
try:
from utils.database_manager import get_database_manager
db_manager = get_database_manager()
# Get active checkpoint for this model
with db_manager._get_connection() as conn:
cursor = conn.execute("""
SELECT checkpoint_id, performance_metrics, timestamp, file_path
FROM checkpoint_metadata
WHERE model_name = ? AND is_active = TRUE
ORDER BY timestamp DESC
LIMIT 1
""", (model_name.lower(),))
row = cursor.fetchone()
if row:
import json
checkpoint_id, metrics_json, timestamp, file_path = row
metrics = json.loads(metrics_json) if metrics_json else {}
checkpoint_info = {
'filename': os.path.basename(file_path) if file_path else checkpoint_id,
'epoch': metrics.get('epoch', 0),
'loss': metrics.get('loss'),
'accuracy': metrics.get('accuracy'),
'source': 'database'
}
logger.info(f"Loaded checkpoint info from database for {model_name}: E{checkpoint_info['epoch']}, Loss={checkpoint_info['loss']}, Acc={checkpoint_info['accuracy']}")
return checkpoint_info
except Exception as db_error:
logger.debug(f"Could not load from database: {db_error}")
# Fallback to filename parsing
import glob
import re
# Map model names to checkpoint directories
checkpoint_dirs = {
'Transformer': 'models/checkpoints/transformer',
'CNN': 'models/checkpoints/enhanced_cnn',
'DQN': 'models/checkpoints/dqn_agent'
}
checkpoint_dir = checkpoint_dirs.get(model_name)
if not checkpoint_dir:
return None
if not os.path.exists(checkpoint_dir):
logger.debug(f"Checkpoint directory not found: {checkpoint_dir}")
return None
# Find all checkpoint files
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pt'))
if not checkpoint_files:
logger.debug(f"No checkpoint files found in {checkpoint_dir}")
return None
logger.debug(f"Found {len(checkpoint_files)} checkpoints for {model_name}")
# Parse filenames to extract epoch info
# Format: transformer_epoch5_20251110_123620.pt
best_checkpoint = None
best_epoch = -1
for cp_file in checkpoint_files:
try:
filename = os.path.basename(cp_file)
# Extract epoch number from filename
match = re.search(r'epoch(\d+)', filename, re.IGNORECASE)
if match:
epoch = int(match.group(1))
if epoch > best_epoch:
best_epoch = epoch
best_checkpoint = {
'filename': filename,
'epoch': epoch,
'loss': None, # Can't get without loading
'accuracy': None, # Can't get without loading
'source': 'filename'
}
logger.debug(f"Found checkpoint: {filename}, epoch {epoch}")
except Exception as e:
logger.debug(f"Could not parse checkpoint {cp_file}: {e}")
continue
if best_checkpoint:
logger.info(f"Best checkpoint for {model_name}: {best_checkpoint['filename']} (E{best_checkpoint['epoch']})")
return best_checkpoint
except Exception as e:
logger.error(f"Error getting checkpoint info for {model_name}: {e}")
import traceback
logger.error(traceback.format_exc())
return None
def _load_model_lazy(self, model_name: str) -> dict:
"""
Lazy load a specific model on demand
Args:
model_name: Name of model to load ('DQN', 'CNN', 'Transformer')
Returns:
dict: Result with success status and message
"""
try:
# Check if already loaded
if model_name in self.loaded_models:
return {
'success': True,
'message': f'{model_name} already loaded',
'already_loaded': True
}
# Check if model is available
if model_name not in self.available_models:
return {
'success': False,
'error': f'{model_name} is not in available models list'
}
logger.info(f"Loading {model_name} model...")
# Initialize orchestrator if not already done
if not self.orchestrator:
if not TradingOrchestrator:
return {
'success': False,
'error': 'TradingOrchestrator class not available'
}
logger.info("Creating TradingOrchestrator instance...")
self.orchestrator = TradingOrchestrator(
data_provider=self.data_provider,
enhanced_rl_training=True
)
logger.info("Orchestrator created")
# Update training adapter
self.training_adapter.orchestrator = self.orchestrator
# Load specific model
if model_name == 'DQN':
if not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
# Initialize RL agent
self.orchestrator._initialize_rl_agent()
self.loaded_models['DQN'] = self.orchestrator.rl_agent
elif model_name == 'CNN':
if not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
# Initialize CNN model
self.orchestrator._initialize_cnn_model()
self.loaded_models['CNN'] = self.orchestrator.cnn_model
elif model_name == 'Transformer':
if not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
# Initialize Transformer model
self.orchestrator._initialize_transformer_model()
self.loaded_models['Transformer'] = self.orchestrator.primary_transformer
else:
return {
'success': False,
'error': f'Unknown model: {model_name}'
}
logger.info(f"{model_name} model loaded successfully")
return {
'success': True,
'message': f'{model_name} loaded successfully',
'loaded_models': list(self.loaded_models.keys())
}
except Exception as e:
logger.error(f"Error loading {model_name}: {e}")
import traceback
logger.error(f"Traceback:\n{traceback.format_exc()}")
return {
'success': False,
'error': str(e)
}
def _enable_unified_storage_async(self):
"""Enable unified storage system in background thread"""
def enable_storage():
try:
import asyncio
import threading
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Enable unified storage
success = loop.run_until_complete(
self.data_provider.enable_unified_storage()
)
if success:
logger.info(" ANNOTATE: Unified storage enabled for real-time data")
# Get statistics
stats = self.data_provider.get_unified_storage_stats()
if stats.get('initialized'):
logger.info(" Real-time data access: <10ms")
logger.info(" Historical data access: <100ms")
logger.info(" Annotation data: Available at any timestamp")
else:
logger.warning(" ANNOTATE: Unified storage not available, using cached data only")
except Exception as e:
logger.warning(f"ANNOTATE: Could not enable unified storage: {e}")
logger.info("ANNOTATE: Continuing with cached data access")
# Start in background thread
import threading
storage_thread = threading.Thread(target=enable_storage, daemon=True)
storage_thread.start()
def _start_background_data_refresh(self):
"""Start background task to refresh recent data after startup - ONCE ONLY"""
def refresh_recent_data():
try:
import time
# Wait for app to fully start
time.sleep(5)
logger.info(" Starting one-time background data refresh (fetching only recent missing data)")
# Disable startup mode to fetch fresh data
self.data_loader.disable_startup_mode()
# Use the new on-demand refresh method
logger.info("Using on-demand refresh for recent data")
self.data_provider.refresh_data_on_demand()
logger.info(" One-time background data refresh completed")
except Exception as e:
logger.error(f"Error in background data refresh: {e}")
# Start refresh in background thread
import threading
refresh_thread = threading.Thread(target=refresh_recent_data, daemon=True)
refresh_thread.start()
logger.info("One-time background data refresh scheduled")
def _get_pivot_markers_for_timeframe(self, symbol: str, timeframe: str, df: pd.DataFrame) -> dict:
"""
Get pivot markers for a specific timeframe using WilliamsMarketStructure directly
Returns dict with all pivot points and identifies which are the last high/low per level
"""
try:
if WilliamsMarketStructure is None:
logger.warning("WilliamsMarketStructure not available")
return {}
if df is None or len(df) < 10:
logger.warning(f"Insufficient data for pivot calculation: {len(df) if df is not None else 0} bars")
return {}
# Convert DataFrame to numpy array format expected by Williams Market Structure
ohlcv_array = df[['open', 'high', 'low', 'close', 'volume']].copy()
# Add timestamp as first column (convert to milliseconds)
timestamps = df.index.astype(np.int64) // 10**6 # pandas index is ns -> convert to ms
ohlcv_array.insert(0, 'timestamp', timestamps)
ohlcv_array = ohlcv_array.to_numpy()
# Initialize Williams Market Structure with default distance
# We'll override it in the calculation call
williams = WilliamsMarketStructure(min_pivot_distance=1)
# Calculate recursive pivot points with min_pivot_distance=2
# This ensures 5 candles per pivot (tip + 2 prev + 2 next)
pivot_levels = williams.calculate_recursive_pivot_points(
ohlcv_array,
min_pivot_distance=2
)
if not pivot_levels:
logger.debug(f"No pivot levels found for {symbol} {timeframe}")
return {}
# Build a map of timestamp -> pivot info
# Also track last high/low per level for drawing horizontal lines
pivot_map = {}
last_pivots = {} # {level: {'high': (ts_str, idx), 'low': (ts_str, idx)}}
# For each level (1-5), collect ALL pivot points
for level_num, trend_level in pivot_levels.items():
if not hasattr(trend_level, 'pivot_points') or not trend_level.pivot_points:
continue
last_pivots[level_num] = {'high': None, 'low': None}
# Add ALL pivot points to the map
for pivot in trend_level.pivot_points:
ts_str = pivot.timestamp.strftime('%Y-%m-%d %H:%M:%S')
if ts_str not in pivot_map:
pivot_map[ts_str] = {'highs': [], 'lows': []}
pivot_info = {
'level': level_num,
'price': pivot.price,
'strength': pivot.strength,
'is_last': False # Will be updated below
}
if pivot.pivot_type == 'high':
pivot_map[ts_str]['highs'].append(pivot_info)
last_pivots[level_num]['high'] = (ts_str, len(pivot_map[ts_str]['highs']) - 1)
elif pivot.pivot_type == 'low':
pivot_map[ts_str]['lows'].append(pivot_info)
last_pivots[level_num]['low'] = (ts_str, len(pivot_map[ts_str]['lows']) - 1)
# Mark the last high and last low for each level
for level_num, last_info in last_pivots.items():
if last_info['high']:
ts_str, idx = last_info['high']
pivot_map[ts_str]['highs'][idx]['is_last'] = True
if last_info['low']:
ts_str, idx = last_info['low']
pivot_map[ts_str]['lows'][idx]['is_last'] = True
logger.info(f"Found {len(pivot_map)} pivot candles for {symbol} {timeframe} (from {len(df)} candles)")
return pivot_map
except Exception as e:
logger.error(f"Error getting pivot markers for {timeframe}: {e}")
import traceback
logger.error(traceback.format_exc())
return {}
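# Return shape sketch (keys are '%Y-%m-%d %H:%M:%S' timestamps; values illustrative):
# {
#   '2025-11-17 12:00:00': {
#       'highs': [{'level': 1, 'price': 3012.5, 'strength': 0.8, 'is_last': True}],
#       'lows':  []
#   },
#   ...
# }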
def _setup_routes(self):
"""Setup Flask routes"""
@self.server.route('/favicon.ico')
def favicon():
"""Serve favicon to prevent 404 errors"""
from flask import Response
# Return a minimal ICO payload as a placeholder favicon
favicon_data = b'\x00\x00\x01\x00\x01\x00\x10\x10\x00\x00\x01\x00\x20\x00\x68\x04\x00\x00\x16\x00\x00\x00\x28\x00\x00\x00\x10\x00\x00\x00\x20\x00\x00\x00\x01\x00\x20\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
return Response(favicon_data, mimetype='image/x-icon')
@self.server.route('/')
def index():
"""Main dashboard page - loads existing annotations"""
try:
# Get symbols and timeframes from config
symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT'])
timeframes = self.config.get('timeframes', ['1s', '1m', '1h', '1d'])
current_symbol = symbols[0] if symbols else 'ETH/USDT'
# Get annotations filtered by current symbol
annotations = self.annotation_manager.get_annotations(symbol=current_symbol)
# Convert to serializable format
annotations_data = []
for ann in annotations:
if hasattr(ann, '__dict__'):
ann_dict = ann.__dict__
else:
ann_dict = ann
# Ensure all fields are JSON serializable
annotations_data.append({
'annotation_id': ann_dict.get('annotation_id'),
'symbol': ann_dict.get('symbol'),
'timeframe': ann_dict.get('timeframe'),
'entry': ann_dict.get('entry'),
'exit': ann_dict.get('exit'),
'direction': ann_dict.get('direction'),
'profit_loss_pct': ann_dict.get('profit_loss_pct'),
'notes': ann_dict.get('notes', ''),
'created_at': ann_dict.get('created_at')
})
logger.info(f"Loading dashboard with {len(annotations_data)} annotations for {current_symbol}")
# Prepare template data
template_data = {
'current_symbol': current_symbol,
'symbols': symbols,
'timeframes': timeframes,
'annotations': annotations_data
}
return render_template('annotation_dashboard.html', **template_data)
except Exception as e:
logger.error(f"Error rendering main page: {e}")
# Fallback simple HTML page
return f"""
<html>
<head>
<title>ANNOTATE - Manual Trade Annotation UI</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
</head>
<body>
<div class="container mt-5">
<h1 class="text-center">📝 ANNOTATE - Manual Trade Annotation UI</h1>
<div class="alert alert-info">
<h4>System Status</h4>
<p> Annotation Manager: Active</p>
<p> Data Provider: {'Available' if self.data_provider else 'Not Available (Standalone Mode)'}</p>
<p> Trading Orchestrator: {'Available' if self.orchestrator else 'Not Available (Standalone Mode)'}</p>
</div>
<div class="row">
<div class="col-md-6">
<h3>Available Features</h3>
<ul>
<li>Manual trade annotation</li>
<li>Test case generation</li>
<li>Annotation export</li>
<li>Real model training</li>
</ul>
</div>
<div class="col-md-6">
<h3>API Endpoints</h3>
<ul>
<li><code>POST /api/chart-data</code> - Get chart data</li>
<li><code>POST /api/save-annotation</code> - Save annotation</li>
<li><code>POST /api/delete-annotation</code> - Delete annotation</li>
<li><code>POST /api/generate-test-case</code> - Generate test case</li>
<li><code>POST /api/export-annotations</code> - Export annotations</li>
</ul>
</div>
</div>
<div class="text-center mt-4">
<a href="/dash/" class="btn btn-primary">Go to Dash Interface</a>
</div>
</div>
</body>
</html>
"""
@self.server.route('/api/recalculate-pivots', methods=['POST'])
def recalculate_pivots():
"""Recalculate pivot points for merged data"""
try:
data = request.get_json()
symbol = data.get('symbol', 'ETH/USDT')
timeframe = data.get('timeframe')
timestamps = data.get('timestamps', [])
ohlcv_data = data.get('ohlcv', {})
if not timeframe or not timestamps:
return jsonify({
'success': False,
'error': {'code': 'INVALID_REQUEST', 'message': 'Missing timeframe or timestamps'}
})
logger.info(f" Recalculating pivots for {symbol} {timeframe} with {len(timestamps)} candles")
# Convert to DataFrame
df = pd.DataFrame({
'open': ohlcv_data.get('open', []),
'high': ohlcv_data.get('high', []),
'low': ohlcv_data.get('low', []),
'close': ohlcv_data.get('close', []),
'volume': ohlcv_data.get('volume', [])
})
df.index = pd.to_datetime(timestamps)
# Recalculate pivot markers
pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df)
logger.info(f" Recalculated {len(pivot_markers)} pivot candles")
return jsonify({
'success': True,
'pivot_markers': pivot_markers
})
except Exception as e:
logger.error(f"Error recalculating pivots: {e}")
return jsonify({
'success': False,
'error': {'code': 'RECALC_ERROR', 'message': str(e)}
})
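# Example request body for /api/recalculate-pivots (values illustrative):
# {
#   "symbol": "ETH/USDT",
#   "timeframe": "1m",
#   "timestamps": ["2025-11-17 12:00:00", "2025-11-17 12:01:00", ...],
#   "ohlcv": {"open": [...], "high": [...], "low": [...], "close": [...], "volume": [...]}
# }
# On success the response is {'success': True, 'pivot_markers': {...}}.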
@self.server.route('/api/chart-data', methods=['POST'])
def get_chart_data():
"""Get chart data for specified symbol and timeframes with infinite scroll support"""
try:
data = request.get_json()
symbol = data.get('symbol', 'ETH/USDT')
timeframes = data.get('timeframes', ['1s', '1m', '1h', '1d'])
start_time_str = data.get('start_time')
end_time_str = data.get('end_time')
limit = data.get('limit', 2500) # Default 2500 candles for training
direction = data.get('direction', 'latest') # 'latest', 'before', or 'after'
logger.info(f"Chart data request: {symbol} {timeframes} direction={direction} limit={limit}")
if start_time_str:
logger.info(f" start_time: {start_time_str}")
if end_time_str:
logger.info(f" end_time: {end_time_str}")
if not self.data_loader:
return jsonify({
'success': False,
'error': {
'code': 'DATA_LOADER_UNAVAILABLE',
'message': 'Data loader not available'
}
})
# Parse time strings if provided
start_time = datetime.fromisoformat(start_time_str.replace('Z', '+00:00')) if start_time_str else None
end_time = datetime.fromisoformat(end_time_str.replace('Z', '+00:00')) if end_time_str else None
# Fetch data for each timeframe using data loader
# This will automatically:
# 1. Check DuckDB first
# 2. Fetch from API if not in cache
# 3. Store in DuckDB for future use
chart_data = {}
for timeframe in timeframes:
df = self.data_loader.get_data(
symbol=symbol,
timeframe=timeframe,
start_time=start_time,
end_time=end_time,
limit=limit,
direction=direction
)
if df is not None and not df.empty:
logger.info(f" {timeframe}: {len(df)} candles ({df.index[0]} to {df.index[-1]})")
# Get pivot points for this timeframe
pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df)
# Convert to format suitable for Plotly
chart_data[timeframe] = {
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': df['open'].tolist(),
'high': df['high'].tolist(),
'low': df['low'].tolist(),
'close': df['close'].tolist(),
'volume': df['volume'].tolist(),
'pivot_markers': pivot_markers # Optional: only present if pivots exist
}
else:
logger.warning(f" {timeframe}: No data returned")
# Get pivot bounds for the symbol
pivot_bounds = None
if self.data_provider:
try:
pivot_bounds = self.data_provider.get_pivot_bounds(symbol)
if pivot_bounds:
logger.info(f"Found pivot bounds for {symbol}: {len(pivot_bounds.pivot_support_levels)} support, {len(pivot_bounds.pivot_resistance_levels)} resistance")
except Exception as e:
logger.error(f"Error getting pivot bounds: {e}")
return jsonify({
'success': True,
'chart_data': chart_data,
'pivot_bounds': {
'support_levels': pivot_bounds.pivot_support_levels if pivot_bounds else [],
'resistance_levels': pivot_bounds.pivot_resistance_levels if pivot_bounds else [],
'price_range': {
'min': pivot_bounds.price_min if pivot_bounds else None,
'max': pivot_bounds.price_max if pivot_bounds else None
},
'volume_range': {
'min': pivot_bounds.volume_min if pivot_bounds else None,
'max': pivot_bounds.volume_max if pivot_bounds else None
},
'timeframe': '1m', # Pivot bounds are calculated from 1m data
'period': '30 days', # Monthly data
'total_levels': len(pivot_bounds.pivot_support_levels) + len(pivot_bounds.pivot_resistance_levels) if pivot_bounds else 0
} if pivot_bounds else None
})
except Exception as e:
logger.error(f"Error fetching chart data: {e}")
return jsonify({
'success': False,
'error': {
'code': 'CHART_DATA_ERROR',
'message': str(e)
}
})
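# Example request body for /api/chart-data (values illustrative; start_time/end_time optional):
# {
#   "symbol": "ETH/USDT",
#   "timeframes": ["1m", "1h"],
#   "start_time": "2025-11-17T00:00:00Z",
#   "end_time": "2025-11-17T12:00:00Z",
#   "limit": 2500,
#   "direction": "latest"          # one of 'latest', 'before', 'after'
# }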
@self.server.route('/api/save-annotation', methods=['POST'])
def save_annotation():
"""Save a new annotation with full market context"""
try:
data = request.get_json()
# Capture market state at entry and exit times using data provider
entry_market_state = {}
exit_market_state = {}
if self.data_provider:
try:
# Parse timestamps
entry_time = datetime.fromisoformat(data['entry']['timestamp'].replace('Z', '+00:00'))
exit_time = datetime.fromisoformat(data['exit']['timestamp'].replace('Z', '+00:00'))
# Use the new data provider method to get market state at entry time
entry_market_state = self.data_provider.get_market_state_at_time(
symbol=data['symbol'],
timestamp=entry_time,
context_window_minutes=5
)
# Use the new data provider method to get market state at exit time
exit_market_state = self.data_provider.get_market_state_at_time(
symbol=data['symbol'],
timestamp=exit_time,
context_window_minutes=5
)
logger.info(f"Captured market state: {len(entry_market_state)} timeframes at entry, {len(exit_market_state)} at exit")
except Exception as e:
logger.error(f"Error capturing market state: {e}")
import traceback
traceback.print_exc()
# Create annotation with market context
annotation = self.annotation_manager.create_annotation(
entry_point=data['entry'],
exit_point=data['exit'],
symbol=data['symbol'],
timeframe=data['timeframe'],
entry_market_state=entry_market_state,
exit_market_state=exit_market_state
)
# Collect market snapshots for SQLite storage
market_snapshots = {}
if self.data_loader:
try:
# Get OHLCV data for all timeframes around the annotation time
entry_time = datetime.fromisoformat(data['entry']['timestamp'].replace('Z', '+00:00'))
exit_time = datetime.fromisoformat(data['exit']['timestamp'].replace('Z', '+00:00'))
# Get data from 5 minutes before entry to 5 minutes after exit
start_time = entry_time - timedelta(minutes=5)
end_time = exit_time + timedelta(minutes=5)
for timeframe in ['1s', '1m', '1h', '1d']:
df = self.data_loader.get_data(
symbol=data['symbol'],
timeframe=timeframe,
start_time=start_time,
end_time=end_time,
limit=1500
)
if df is not None and not df.empty:
market_snapshots[timeframe] = df
logger.info(f"Collected {len(market_snapshots)} timeframes for annotation storage")
except Exception as e:
logger.error(f"Error collecting market snapshots: {e}")
# Save annotation with market snapshots
self.annotation_manager.save_annotation(
annotation=annotation,
market_snapshots=market_snapshots
)
# Automatically generate test case with ±5min data
try:
test_case = self.annotation_manager.generate_test_case(
annotation,
data_provider=self.data_provider,
auto_save=True
)
# Log test case details
market_state = test_case.get('market_state', {})
timeframes_with_data = [k for k in market_state.keys() if k.startswith('ohlcv_')]
logger.info(f"Auto-generated test case: {test_case['test_case_id']}")
logger.info(f" Timeframes: {timeframes_with_data}")
for tf_key in timeframes_with_data:
candle_count = len(market_state[tf_key].get('timestamps', []))
logger.info(f" {tf_key}: {candle_count} candles")
if 'training_labels' in market_state:
logger.info(f" Training labels: {len(market_state['training_labels'].get('labels_1m', []))} labels")
except Exception as e:
logger.error(f"Failed to auto-generate test case: {e}")
import traceback
traceback.print_exc()
return jsonify({
'success': True,
'annotation': annotation.__dict__ if hasattr(annotation, '__dict__') else annotation
})
except Exception as e:
logger.error(f"Error saving annotation: {e}")
return jsonify({
'success': False,
'error': {
'code': 'SAVE_ANNOTATION_ERROR',
'message': str(e)
}
})
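# Example request body for /api/save-annotation (values illustrative; only the
# 'timestamp' fields inside entry/exit are read directly here, any other entry/exit
# fields are passed through to AnnotationManager.create_annotation):
# {
#   "symbol": "ETH/USDT",
#   "timeframe": "1m",
#   "entry": {"timestamp": "2025-11-17T10:00:00Z", ...},
#   "exit":  {"timestamp": "2025-11-17T10:30:00Z", ...}
# }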
@self.server.route('/api/delete-annotation', methods=['POST'])
def delete_annotation():
"""Delete an annotation"""
try:
data = request.get_json()
annotation_id = data['annotation_id']
# Delete annotation and check if it was found
deleted = self.annotation_manager.delete_annotation(annotation_id)
if deleted:
return jsonify({
'success': True,
'message': 'Annotation deleted successfully'
})
else:
return jsonify({
'success': False,
'error': {
'code': 'ANNOTATION_NOT_FOUND',
'message': f'Annotation {annotation_id} not found'
}
})
except Exception as e:
logger.error(f"Error deleting annotation: {e}", exc_info=True)
return jsonify({
'success': False,
'error': {
'code': 'DELETE_ANNOTATION_ERROR',
'message': str(e)
}
})
@self.server.route('/api/clear-all-annotations', methods=['POST'])
def clear_all_annotations():
"""Clear all annotations"""
try:
data = request.get_json() or {}
symbol = data.get('symbol', None)
# Use the efficient clear_all_annotations method
deleted_count = self.annotation_manager.clear_all_annotations(symbol=symbol)
if deleted_count == 0:
return jsonify({
'success': True,
'deleted_count': 0,
'message': 'No annotations to clear'
})
logger.info(f"Cleared {deleted_count} annotations" + (f" for symbol {symbol}" if symbol else ""))
return jsonify({
'success': True,
'deleted_count': deleted_count,
'message': f'Cleared {deleted_count} annotations'
})
except Exception as e:
logger.error(f"Error clearing all annotations: {e}")
import traceback
logger.error(traceback.format_exc())
return jsonify({
'success': False,
'error': {
'code': 'CLEAR_ALL_ANNOTATIONS_ERROR',
'message': str(e)
}
})
@self.server.route('/api/refresh-data', methods=['POST'])
def refresh_data():
"""Refresh chart data from data provider"""
try:
data = request.get_json()
symbol = data.get('symbol', 'ETH/USDT')
timeframes = data.get('timeframes', ['1s', '1m', '1h', '1d'])
logger.info(f"Refreshing data for {symbol} with timeframes: {timeframes}")
# Force refresh data from data provider
chart_data = {}
if self.data_provider:
for timeframe in timeframes:
try:
# Force refresh by setting refresh=True
df = self.data_provider.get_historical_data(
symbol=symbol,
timeframe=timeframe,
limit=1000,
refresh=True
)
if df is not None and not df.empty:
# Get pivot markers for this timeframe
pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df)
chart_data[timeframe] = {
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': df['open'].tolist(),
'high': df['high'].tolist(),
'low': df['low'].tolist(),
'close': df['close'].tolist(),
'volume': df['volume'].tolist(),
'pivot_markers': pivot_markers # Optional: only present if pivots exist
}
logger.info(f"Refreshed {timeframe}: {len(df)} candles")
else:
logger.warning(f"No data available for {timeframe}")
except Exception as e:
logger.error(f"Error refreshing {timeframe} data: {e}")
# Get pivot bounds for the symbol
pivot_bounds = None
if self.data_provider:
try:
pivot_bounds = self.data_provider.get_pivot_bounds(symbol)
if pivot_bounds:
logger.info(f"Found pivot bounds for {symbol}: {len(pivot_bounds.pivot_support_levels)} support, {len(pivot_bounds.pivot_resistance_levels)} resistance")
except Exception as e:
logger.error(f"Error getting pivot bounds: {e}")
return jsonify({
'success': True,
'chart_data': chart_data,
'pivot_bounds': {
'support_levels': pivot_bounds.pivot_support_levels if pivot_bounds else [],
'resistance_levels': pivot_bounds.pivot_resistance_levels if pivot_bounds else [],
'price_range': {
'min': pivot_bounds.price_min if pivot_bounds else None,
'max': pivot_bounds.price_max if pivot_bounds else None
},
'volume_range': {
'min': pivot_bounds.volume_min if pivot_bounds else None,
'max': pivot_bounds.volume_max if pivot_bounds else None
},
'timeframe': '1m', # Pivot bounds are calculated from 1m data
'period': '30 days', # Monthly data
'total_levels': len(pivot_bounds.pivot_support_levels) + len(pivot_bounds.pivot_resistance_levels) if pivot_bounds else 0
} if pivot_bounds else None,
'message': f'Refreshed data for {symbol}'
})
except Exception as e:
logger.error(f"Error refreshing data: {e}")
return jsonify({
'success': False,
'error': {
'code': 'REFRESH_DATA_ERROR',
'message': str(e)
}
})
@self.server.route('/api/generate-test-case', methods=['POST'])
def generate_test_case():
"""Generate test case from annotation"""
try:
data = request.get_json()
annotation_id = data['annotation_id']
# Get annotation
annotations = self.annotation_manager.get_annotations()
annotation = next((a for a in annotations
if (a.annotation_id if hasattr(a, 'annotation_id')
else a.get('annotation_id')) == annotation_id), None)
if not annotation:
return jsonify({
'success': False,
'error': {
'code': 'ANNOTATION_NOT_FOUND',
'message': 'Annotation not found'
}
})
# Generate test case with market context
test_case = self.annotation_manager.generate_test_case(
annotation,
data_provider=self.data_provider
)
return jsonify({
'success': True,
'test_case': test_case
})
except Exception as e:
logger.error(f"Error generating test case: {e}")
return jsonify({
'success': False,
'error': {
'code': 'GENERATE_TESTCASE_ERROR',
'message': str(e)
}
})
@self.server.route('/api/get-annotations', methods=['POST'])
def get_annotations_api():
"""Get annotations filtered by symbol"""
try:
data = request.get_json()
symbol = data.get('symbol', 'ETH/USDT')
# Get annotations for this symbol
annotations = self.annotation_manager.get_annotations(symbol=symbol)
# Convert to serializable format
annotations_data = []
for ann in annotations:
if hasattr(ann, '__dict__'):
ann_dict = ann.__dict__
else:
ann_dict = ann
annotations_data.append({
'annotation_id': ann_dict.get('annotation_id'),
'symbol': ann_dict.get('symbol'),
'timeframe': ann_dict.get('timeframe'),
'entry': ann_dict.get('entry'),
'exit': ann_dict.get('exit'),
'direction': ann_dict.get('direction'),
'profit_loss_pct': ann_dict.get('profit_loss_pct'),
'notes': ann_dict.get('notes', ''),
'created_at': ann_dict.get('created_at')
})
logger.info(f"Returning {len(annotations_data)} annotations for {symbol}")
return jsonify({
'success': True,
'annotations': annotations_data,
'symbol': symbol,
'count': len(annotations_data)
})
except Exception as e:
logger.error(f"Error getting annotations: {e}")
return jsonify({
'success': False,
'error': str(e)
})
@self.server.route('/api/export-annotations', methods=['POST'])
def export_annotations():
"""Export annotations to file"""
try:
data = request.get_json()
symbol = data.get('symbol')
format_type = data.get('format', 'json')
# Get annotations
annotations = self.annotation_manager.get_annotations(symbol=symbol)
# Export to file
output_path = self.annotation_manager.export_annotations(
annotations=annotations,
format_type=format_type
)
return send_file(output_path, as_attachment=True)
except Exception as e:
logger.error(f"Error exporting annotations: {e}")
return jsonify({
'success': False,
'error': {
'code': 'EXPORT_ERROR',
'message': str(e)
}
})
@self.server.route('/api/train-model', methods=['POST'])
def train_model():
"""Start model training with annotations"""
try:
if not self.training_adapter:
return jsonify({
'success': False,
'error': {
'code': 'TRAINING_UNAVAILABLE',
'message': 'Real training adapter not available'
}
})
data = request.get_json()
model_name = data['model_name']
annotation_ids = data.get('annotation_ids', [])
# CRITICAL: Get current symbol to filter annotations
current_symbol = data.get('symbol', 'ETH/USDT')
# If no specific annotations provided, use all for current symbol
if not annotation_ids:
annotations = self.annotation_manager.get_annotations(symbol=current_symbol)
annotation_ids = [
a.annotation_id if hasattr(a, 'annotation_id') else a.get('annotation_id')
for a in annotations
]
logger.info(f"Using all {len(annotation_ids)} annotations for {current_symbol}")
# Load test cases from disk (they were auto-generated when annotations were saved)
# Filter by current symbol to avoid cross-symbol training
all_test_cases = self.annotation_manager.get_all_test_cases(symbol=current_symbol)
# Filter to selected annotations
test_cases = [
tc for tc in all_test_cases
if tc['test_case_id'].replace('annotation_', '') in annotation_ids
]
if not test_cases:
return jsonify({
'success': False,
'error': {
'code': 'NO_TEST_CASES',
'message': f'No test cases found for {len(annotation_ids)} annotations'
}
})
logger.info(f"Starting REAL training with {len(test_cases)} test cases for model {model_name}")
# Start REAL training (NO SIMULATION!)
training_id = self.training_adapter.start_training(
model_name=model_name,
test_cases=test_cases
)
return jsonify({
'success': True,
'training_id': training_id,
'test_cases_count': len(test_cases)
})
except Exception as e:
logger.error(f"Error starting training: {e}")
return jsonify({
'success': False,
'error': {
'code': 'TRAINING_ERROR',
'message': str(e)
}
})
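# Example request body for /api/train-model (values illustrative):
# {"model_name": "Transformer", "symbol": "ETH/USDT", "annotation_ids": []}
# An empty annotation_ids list selects all annotations for the given symbol; test
# cases are loaded from disk (auto-generated when the annotations were saved).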
@self.server.route('/api/training-progress', methods=['POST'])
def get_training_progress():
"""Get training progress"""
try:
if not self.training_adapter:
return jsonify({
'success': False,
'error': {
'code': 'TRAINING_UNAVAILABLE',
'message': 'Real training adapter not available'
}
})
data = request.get_json()
training_id = data['training_id']
progress = self.training_adapter.get_training_progress(training_id)
return jsonify({
'success': True,
'progress': progress
})
except Exception as e:
logger.error(f"Error getting training progress: {e}")
return jsonify({
'success': False,
'error': {
'code': 'PROGRESS_ERROR',
'message': str(e)
}
})
# Backtest API Endpoints
@self.server.route('/api/backtest', methods=['POST'])
def start_backtest():
"""Start backtest on visible chart data"""
try:
data = request.get_json()
model_name = data['model_name']
symbol = data['symbol']
timeframe = data['timeframe']
start_time = data.get('start_time')
end_time = data.get('end_time')
# Get the loaded model
if model_name not in self.loaded_models:
return jsonify({
'success': False,
'error': f'Model {model_name} not loaded. Please load it first.'
})
model = self.loaded_models[model_name]
# Generate backtest ID
backtest_id = str(uuid.uuid4())
# Start backtest in background
self.backtest_runner.start_backtest(
backtest_id=backtest_id,
model=model,
data_provider=self.data_provider,
symbol=symbol,
timeframe=timeframe,
start_time=start_time,
end_time=end_time
)
# Get initial state
progress = self.backtest_runner.get_progress(backtest_id)
return jsonify({
'success': True,
'backtest_id': backtest_id,
'total_candles': progress.get('total_candles', 0)
})
except Exception as e:
logger.error(f"Error starting backtest: {e}", exc_info=True)
return jsonify({
'success': False,
'error': str(e)
})
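# Example request body for /api/backtest (values illustrative; start_time/end_time optional):
# {"model_name": "Transformer", "symbol": "ETH/USDT", "timeframe": "1m"}
# The model must already appear in self.loaded_models (see /api/load-model);
# poll GET /api/backtest/progress/<backtest_id> for PnL, trade count and new predictions.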
@self.server.route('/api/backtest/progress/<backtest_id>', methods=['GET'])
def get_backtest_progress(backtest_id):
"""Get backtest progress"""
try:
progress = self.backtest_runner.get_progress(backtest_id)
return jsonify(progress)
except Exception as e:
logger.error(f"Error getting backtest progress: {e}")
return jsonify({
'success': False,
'error': str(e)
})
@self.server.route('/api/backtest/stop', methods=['POST'])
def stop_backtest():
"""Stop running backtest"""
try:
data = request.get_json()
backtest_id = data['backtest_id']
self.backtest_runner.stop_backtest(backtest_id)
return jsonify({
'success': True
})
except Exception as e:
logger.error(f"Error stopping backtest: {e}")
return jsonify({
'success': False,
'error': str(e)
})
@self.server.route('/api/active-training', methods=['GET'])
def get_active_training():
"""
Get currently active training session (if any)
Allows UI to resume tracking after page reload or across multiple clients
"""
try:
if not self.training_adapter:
return jsonify({
'success': False,
'active': False,
'error': {
'code': 'TRAINING_UNAVAILABLE',
'message': 'Real training adapter not available'
}
})
active_session = self.training_adapter.get_active_training_session()
if active_session:
return jsonify({
'success': True,
'active': True,
'session': active_session
})
else:
return jsonify({
'success': True,
'active': False
})
except Exception as e:
logger.error(f"Error getting active training: {e}")
return jsonify({
'success': False,
'active': False,
'error': {
'code': 'ACTIVE_TRAINING_ERROR',
'message': str(e)
}
})
# Live Training API Endpoints
@self.server.route('/api/live-training/start', methods=['POST'])
def start_live_training():
"""Start live inference and training mode"""
try:
if not self.orchestrator:
return jsonify({
'success': False,
'error': 'Orchestrator not available'
}), 500
if self.orchestrator.start_live_training():
return jsonify({
'success': True,
'status': 'started',
'message': 'Live training mode started'
})
else:
return jsonify({
'success': False,
'error': 'Failed to start live training'
}), 500
except Exception as e:
logger.error(f"Error starting live training: {e}")
return jsonify({
'success': False,
'error': str(e)
}), 500
@self.server.route('/api/live-training/stop', methods=['POST'])
def stop_live_training():
"""Stop live inference and training mode"""
try:
if not self.orchestrator:
return jsonify({
'success': False,
'error': 'Orchestrator not available'
}), 500
if self.orchestrator.stop_live_training():
return jsonify({
'success': True,
'status': 'stopped',
'message': 'Live training mode stopped'
})
else:
return jsonify({
'success': False,
'error': 'Failed to stop live training'
}), 500
except Exception as e:
logger.error(f"Error stopping live training: {e}")
return jsonify({
'success': False,
'error': str(e)
}), 500
@self.server.route('/api/live-training/status', methods=['GET'])
def get_live_training_status():
"""Get live training status and statistics"""
try:
if not self.orchestrator:
return jsonify({
'success': False,
'active': False,
'error': 'Orchestrator not available'
})
is_active = self.orchestrator.is_live_training_active()
stats = self.orchestrator.get_live_training_stats() if is_active else {}
return jsonify({
'success': True,
'active': is_active,
'stats': stats
})
except Exception as e:
logger.error(f"Error getting live training status: {e}")
return jsonify({
'success': False,
'active': False,
'error': str(e)
})
@self.server.route('/api/available-models', methods=['GET'])
def get_available_models():
"""Get list of available models with their load status"""
try:
# Use self.available_models which is a simple list of strings
# Don't call training_adapter.get_available_models() as it may return objects
# Build model state dict with checkpoint info
logger.info(f"Building model states for {len(self.available_models)} models: {self.available_models}")
model_states = []
for model_name in self.available_models:
is_loaded = model_name in self.loaded_models
# Get checkpoint info (even for unloaded models)
checkpoint_info = None
# If loaded, get from orchestrator
if is_loaded and self.orchestrator:
if model_name == 'Transformer' and hasattr(self.orchestrator, 'transformer_checkpoint_info'):
cp_info = self.orchestrator.transformer_checkpoint_info
if cp_info and cp_info.get('status') == 'loaded':
checkpoint_info = {
'filename': cp_info.get('filename', 'unknown'),
'epoch': cp_info.get('epoch', 0),
'loss': cp_info.get('loss', 0.0),
'accuracy': cp_info.get('accuracy', 0.0),
'loaded_at': cp_info.get('loaded_at', ''),
'source': 'loaded'
}
# If not loaded, try to read best checkpoint from disk (filename parsing only)
if not checkpoint_info:
try:
cp_info = self._get_best_checkpoint_info(model_name)
if cp_info:
checkpoint_info = cp_info
checkpoint_info['source'] = 'disk'
except Exception as e:
logger.warning(f"Could not read checkpoint for {model_name}: {e}")
# Continue without checkpoint info - not critical
model_states.append({
'name': model_name,
'loaded': is_loaded,
'can_train': is_loaded,
'can_infer': is_loaded,
'checkpoint': checkpoint_info # Checkpoint metadata (loaded or from disk)
})
logger.info(f"Returning {len(model_states)} model states")
return jsonify({
'success': True,
'models': model_states,
'loaded_count': len(self.loaded_models),
'available_count': len(self.available_models)
})
except Exception as e:
logger.error(f"Error getting available models: {e}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
return jsonify({
'success': False,
'error': {
'code': 'MODEL_LIST_ERROR',
'message': str(e)
}
})
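# Example response shape for /api/available-models (values illustrative):
# {"success": true, "loaded_count": 0, "available_count": 3,
#  "models": [{"name": "Transformer", "loaded": false, "can_train": false,
#              "can_infer": false, "checkpoint": {"filename": "...", "epoch": 5,
#              "loss": null, "accuracy": null, "source": "disk"}}]}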
@self.server.route('/api/load-model', methods=['POST'])
def load_model():
"""Load a specific model on demand"""
try:
data = request.get_json()
model_name = data.get('model_name')
if not model_name:
return jsonify({
'success': False,
'error': 'model_name is required'
})
# Load the model
result = self._load_model_lazy(model_name)
return jsonify(result)
except Exception as e:
logger.error(f"Error in load_model endpoint: {e}")
return jsonify({
'success': False,
'error': str(e)
})
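# Example request body for /api/load-model: {"model_name": "Transformer"}
# Valid names are those in self.available_models ('DQN', 'CNN', 'Transformer').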
@self.server.route('/api/realtime-inference/start', methods=['POST'])
def start_realtime_inference():
"""Start real-time inference mode with optional live training on L2 pivots"""
try:
data = request.get_json()
model_name = data.get('model_name')
symbol = data.get('symbol', 'ETH/USDT')
enable_live_training = data.get('enable_live_training', True) # Default: enabled
if not self.training_adapter:
return jsonify({
'success': False,
'error': {
'code': 'TRAINING_UNAVAILABLE',
'message': 'Real training adapter not available'
}
})
# Start real-time inference with optional live training
inference_id = self.training_adapter.start_realtime_inference(
model_name=model_name,
symbol=symbol,
data_provider=self.data_provider,
enable_live_training=enable_live_training
)
return jsonify({
'success': True,
'inference_id': inference_id,
'live_training_enabled': enable_live_training
})
except Exception as e:
logger.error(f"Error starting real-time inference: {e}")
return jsonify({
'success': False,
'error': {
'code': 'INFERENCE_START_ERROR',
'message': str(e)
}
})
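# Example request body for /api/realtime-inference/start (values illustrative):
# {"model_name": "Transformer", "symbol": "ETH/USDT", "enable_live_training": true}
# The returned inference_id is later passed to /api/realtime-inference/stop.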
@self.server.route('/api/realtime-inference/stop', methods=['POST'])
def stop_realtime_inference():
"""Stop real-time inference mode"""
try:
data = request.get_json()
inference_id = data.get('inference_id')
if not self.training_adapter:
return jsonify({
'success': False,
'error': {
'code': 'TRAINING_UNAVAILABLE',
'message': 'Real training adapter not available'
}
})
self.training_adapter.stop_realtime_inference(inference_id)
return jsonify({
'success': True
})
except Exception as e:
logger.error(f"Error stopping real-time inference: {e}")
return jsonify({
'success': False,
'error': {
'code': 'INFERENCE_STOP_ERROR',
'message': str(e)
}
})
@self.server.route('/api/realtime-inference/signals', methods=['GET'])
def get_realtime_signals():
"""Get latest real-time inference signals"""
try:
if not self.training_adapter:
return jsonify({
'success': False,
'error': {
'code': 'TRAINING_UNAVAILABLE',
'message': 'Real training adapter not available'
}
})
signals = self.training_adapter.get_latest_signals()
return jsonify({
'success': True,
'signals': signals
})
except Exception as e:
logger.error(f"Error getting signals: {e}")
return jsonify({
'success': False,
'error': {
'code': 'SIGNALS_ERROR',
'message': str(e)
}
})
# WebSocket event handlers (if SocketIO is available)
if self.has_socketio:
self._setup_websocket_handlers()
def _setup_websocket_handlers(self):
"""Setup WebSocket event handlers for real-time updates"""
if not self.has_socketio:
return
@self.socketio.on('connect')
def handle_connect():
"""Handle client connection"""
logger.info(f"WebSocket client connected")
from flask_socketio import emit
emit('connection_response', {'status': 'connected', 'message': 'Connected to ANNOTATE live updates'})
@self.socketio.on('disconnect')
def handle_disconnect():
"""Handle client disconnection"""
logger.info(f"WebSocket client disconnected")
@self.socketio.on('subscribe_live_updates')
def handle_subscribe(data):
"""Subscribe to live chart and prediction updates"""
from flask_socketio import emit, join_room
symbol = data.get('symbol', 'ETH/USDT')
timeframe = data.get('timeframe', '1s')
room = f"{symbol}_{timeframe}"
join_room(room)
logger.info(f"Client subscribed to live updates: {room}")
emit('subscription_confirmed', {'room': room, 'symbol': symbol, 'timeframe': timeframe})
# Start live update thread if not already running
if not hasattr(self, '_live_update_thread') or not self._live_update_thread.is_alive():
self._start_live_update_thread()
@self.socketio.on('request_prediction')
def handle_prediction_request(data):
"""Handle manual prediction request"""
from flask_socketio import emit
try:
symbol = data.get('symbol', 'ETH/USDT')
timeframe = data.get('timeframe', '1s')
prediction_steps = data.get('prediction_steps', 1)
# Get prediction from model
prediction = self._get_live_prediction(symbol, timeframe, prediction_steps)
emit('prediction_update', prediction)
except Exception as e:
logger.error(f"Error handling prediction request: {e}")
emit('prediction_error', {'error': str(e)})
def _start_live_update_thread(self):
"""Start background thread for live updates"""
import threading
def live_update_worker():
"""Background worker for live updates"""
import time
from flask_socketio import emit
logger.info("Live update thread started")
while True:
try:
# Get active rooms (symbol_timeframe combinations)
# For now, update all subscribed clients every second
# Get latest chart data
if self.data_provider:
for symbol in ['ETH/USDT', 'BTC/USDT']: # TODO: Get from active subscriptions
for timeframe in ['1s', '1m']:
room = f"{symbol}_{timeframe}"
# Get latest candle
try:
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=1)
if candles and len(candles) > 0:
latest_candle = candles[-1]
# Emit chart update
self.socketio.emit('chart_update', {
'symbol': symbol,
'timeframe': timeframe,
'candle': {
'timestamp': latest_candle.get('timestamp'),
'open': latest_candle.get('open'),
'high': latest_candle.get('high'),
'low': latest_candle.get('low'),
'close': latest_candle.get('close'),
'volume': latest_candle.get('volume')
}
}, room=room)
# Get prediction if model is loaded
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer'):
prediction = self._get_live_prediction(symbol, timeframe, 1)
if prediction:
self.socketio.emit('prediction_update', prediction, room=room)
except Exception as e:
logger.debug(f"Error getting data for {symbol} {timeframe}: {e}")
time.sleep(1) # Update every second
except Exception as e:
logger.error(f"Error in live update thread: {e}")
time.sleep(5) # Wait longer on error
self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True)
self._live_update_thread.start()
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
"""Get live prediction from model"""
try:
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
return None
# Get recent candles for prediction
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=200)
if not candles or len(candles) < 200:
return None
# TODO: Implement actual prediction logic
# For now, return placeholder
import random
return {
'symbol': symbol,
'timeframe': timeframe,
'timestamp': datetime.now().isoformat(),
'action': random.choice(['BUY', 'SELL', 'HOLD']),
'confidence': random.uniform(0.6, 0.95),
'predicted_price': candles[-1].get('close', 0) * (1 + random.uniform(-0.01, 0.01)),
'prediction_steps': prediction_steps
}
except Exception as e:
logger.error(f"Error getting live prediction: {e}")
return None
def run(self, host='127.0.0.1', port=8051, debug=False):
"""Run the application"""
logger.info(f"Starting Annotation Dashboard on http://{host}:{port}")
if self.has_socketio:
logger.info("Running with WebSocket support (SocketIO)")
self.socketio.run(self.server, host=host, port=port, debug=debug, allow_unsafe_werkzeug=True)
else:
logger.warning("Running without WebSocket support - install flask-socketio for live updates")
self.server.run(host=host, port=port, debug=debug)
def main():
"""Main entry point"""
logger.info("=" * 80)
logger.info("ANNOTATE Application Starting")
logger.info(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
logger.info("=" * 80)
dashboard = AnnotationDashboard()
dashboard.run(debug=True)
if __name__ == '__main__':
main()
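# Run sketch (assumed invocation; adjust the path to your checkout):
#   python ANNOTATE/web/app.py
# By default the dashboard listens on http://127.0.0.1:8051 (see AnnotationDashboard.run).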