"""
|
|
Manual Trade Annotation UI - Main Application
|
|
|
|
A web-based interface for manually marking profitable buy/sell signals on historical
|
|
market data to generate training test cases for machine learning models.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Add parent directory to path for imports
|
|
parent_dir = Path(__file__).parent.parent.parent
|
|
sys.path.insert(0, str(parent_dir))
|
|
|
|
from flask import Flask, render_template, request, jsonify, send_file
|
|
from dash import Dash, html
|
|
import logging
|
|
from datetime import datetime, timezone, timedelta
|
|
from typing import Optional, Dict, List, Any, Tuple
|
|
import json
|
|
import pandas as pd
|
|
import numpy as np
|
|
import threading
|
|
import uuid
|
|
import time
|
|
import torch
|
|
from utils.logging_config import get_channel_logger, LogChannel
|
|
|
|
# Import core components from main system
|
|
try:
|
|
from core.data_provider import DataProvider
|
|
from core.orchestrator import TradingOrchestrator
|
|
from core.config import get_config
|
|
from core.williams_market_structure import WilliamsMarketStructure
|
|
except ImportError as e:
|
|
print(f"Warning: Could not import main system components: {e}")
|
|
print("Running in standalone mode with limited functionality")
|
|
DataProvider = None
|
|
WilliamsMarketStructure = None
|
|
TradingOrchestrator = None
|
|
get_config = lambda: {}
|
|
|
|
# Import ANNOTATE modules
|
|
annotate_dir = Path(__file__).parent.parent
|
|
sys.path.insert(0, str(annotate_dir))
|
|
|
|
try:
|
|
from core.annotation_manager import AnnotationManager
|
|
from core.real_training_adapter import RealTrainingAdapter
|
|
# Using main DataProvider directly instead of duplicate data_loader
|
|
except ImportError:
|
|
# Try alternative import path
|
|
import importlib.util
|
|
|
|
# Load annotation_manager
|
|
ann_spec = importlib.util.spec_from_file_location(
|
|
"annotation_manager",
|
|
annotate_dir / "core" / "annotation_manager.py"
|
|
)
|
|
ann_module = importlib.util.module_from_spec(ann_spec)
|
|
ann_spec.loader.exec_module(ann_module)
|
|
AnnotationManager = ann_module.AnnotationManager
|
|
|
|
# Load real_training_adapter (NO SIMULATION!)
|
|
train_spec = importlib.util.spec_from_file_location(
|
|
"real_training_adapter",
|
|
annotate_dir / "core" / "real_training_adapter.py"
|
|
)
|
|
train_module = importlib.util.module_from_spec(train_spec)
|
|
train_spec.loader.exec_module(train_module)
|
|
RealTrainingAdapter = train_module.RealTrainingAdapter
|
|
|
|
# Using main DataProvider directly - no need for duplicate data_loader
|
|
HistoricalDataLoader = None
|
|
TimeRangeManager = None
|
|
|
|
# Setup logging - configure before any logging occurs
|
|
log_dir = Path(__file__).parent.parent / 'logs'
|
|
log_dir.mkdir(exist_ok=True)
|
|
log_file = log_dir / 'annotate_app.log'
|
|
|
|
# Configure logging to both file and console
|
|
# File mode 'w' truncates the file on each run
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
handlers=[
|
|
logging.FileHandler(log_file, mode='w'), # Truncate on each run
|
|
logging.StreamHandler(sys.stdout) # Also print to console
|
|
]
|
|
)
|
|
logger = logging.getLogger(__name__)
|
|
logger.info(f"Logging to: {log_file}")
|
|
|
|
# Create channel-specific loggers
|
|
pivot_logger = get_channel_logger(__name__, LogChannel.PIVOTS)
|
|
api_logger = get_channel_logger(__name__, LogChannel.API)
|
|
webui_logger = get_channel_logger(__name__, LogChannel.WEBUI)
|
|
|
|
|
|


class BacktestRunner:
    """Runs backtest candle-by-candle with model predictions and tracks PnL"""

    def __init__(self):
        self.active_backtests = {}  # backtest_id -> state
        self.lock = threading.Lock()

    def start_backtest(self, backtest_id: str, model, data_provider, symbol: str, timeframe: str,
                       orchestrator=None, start_time: Optional[str] = None, end_time: Optional[str] = None):
        """Start backtest in background thread"""

        # Initialize backtest state
        state = {
            'status': 'running',
            'candles_processed': 0,
            'total_candles': 0,
            'pnl': 0.0,
            'total_trades': 0,
            'wins': 0,
            'losses': 0,
            'new_predictions': [],
            'position': None,  # {'type': 'long/short', 'entry_price': float, 'entry_time': str}
            'error': None,
            'stop_requested': False,
            'orchestrator': orchestrator,
            'symbol': symbol
        }

        # Clear previous predictions from orchestrator
        if orchestrator and hasattr(orchestrator, 'recent_transformer_predictions'):
            if symbol in orchestrator.recent_transformer_predictions:
                orchestrator.recent_transformer_predictions[symbol].clear()
            if symbol in orchestrator.recent_cnn_predictions:
                orchestrator.recent_cnn_predictions[symbol].clear()
            if symbol in orchestrator.recent_dqn_predictions:
                orchestrator.recent_dqn_predictions[symbol].clear()
            logger.info(f"Cleared previous predictions for backtest on {symbol}")

        with self.lock:
            self.active_backtests[backtest_id] = state

        # Run backtest in background thread
        thread = threading.Thread(
            target=self._run_backtest,
            args=(backtest_id, model, data_provider, symbol, timeframe, orchestrator, start_time, end_time)
        )
        thread.daemon = True
        thread.start()

    def _run_backtest(self, backtest_id: str, model, data_provider, symbol: str, timeframe: str,
                      orchestrator=None, start_time: Optional[str] = None, end_time: Optional[str] = None):
        """Execute backtest candle-by-candle"""
        try:
            state = self.active_backtests[backtest_id]

            # Get historical data
            logger.info(f"Backtest {backtest_id}: Fetching data for {symbol} {timeframe}")

            # Get candles for the time range
            if start_time and end_time:
                # Parse time range and fetch data
                df = data_provider.get_historical_data(
                    symbol=symbol,
                    timeframe=timeframe,
                    limit=1000  # Max candles
                )
            else:
                # Use last 500 candles
                df = data_provider.get_historical_data(
                    symbol=symbol,
                    timeframe=timeframe,
                    limit=500
                )

            if df is None or df.empty:
                state['status'] = 'error'
                state['error'] = 'No data available'
                return

            logger.info(f"Backtest {backtest_id}: Processing {len(df)} candles")
            state['total_candles'] = len(df)

            # Prepare for inference
            model.eval()

            # IMPORTANT: Use CPU for backtest to avoid ROCm/HIP compatibility issues.
            # GPU inference has kernel compatibility problems with some model architectures.
            device = torch.device('cpu')
            model.to(device)
            logger.info(f"Backtest {backtest_id}: Using CPU for stable inference (avoiding ROCm/HIP issues)")

            # Need at least 200 candles for context
            min_context = 200

            # Process candles one by one
            for i in range(min_context, len(df)):
                if state['stop_requested']:
                    state['status'] = 'stopped'
                    break

                # Get context (last 200 candles)
                context = df.iloc[i-200:i]
                current_candle = df.iloc[i]
                current_time = current_candle.name
                current_price = float(current_candle['close'])

                # Make prediction
                prediction = self._make_prediction(model, device, context, symbol, timeframe)

                if prediction:
                    # Store prediction for display
                    pred_data = {
                        'timestamp': str(current_time),
                        'price': current_price,
                        'action': prediction['action'],
                        'confidence': prediction['confidence'],
                        'timeframe': timeframe,
                        'current_price': current_price
                    }
                    state['new_predictions'].append(pred_data)

                    # Store in orchestrator for visualization
                    if orchestrator and hasattr(orchestrator, 'store_transformer_prediction'):
                        # Determine model type from model class name
                        model_type = model.__class__.__name__.lower()
                        logger.debug(f"Backtest: Storing prediction for model type: {model_type}")

                        # Store in appropriate prediction collection
                        if 'transformer' in model_type:
                            orchestrator.store_transformer_prediction(symbol, {
                                'timestamp': current_time,
                                'current_price': current_price,
                                'predicted_price': current_price * (1.01 if prediction['action'] == 'BUY' else 0.99),
                                'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
                                'confidence': prediction['confidence'],
                                'action': prediction['action'],
                                'horizon_minutes': 10
                            })
                            logger.debug(f"Backtest: Stored transformer prediction: {prediction['action']} @ {current_price}")
                        elif 'cnn' in model_type:
                            if hasattr(orchestrator, 'recent_cnn_predictions'):
                                if symbol not in orchestrator.recent_cnn_predictions:
                                    from collections import deque
                                    orchestrator.recent_cnn_predictions[symbol] = deque(maxlen=50)
                                orchestrator.recent_cnn_predictions[symbol].append({
                                    'timestamp': current_time,
                                    'current_price': current_price,
                                    'predicted_price': current_price * (1.01 if prediction['action'] == 'BUY' else 0.99),
                                    'confidence': prediction['confidence'],
                                    'direction': 2 if prediction['action'] == 'BUY' else 0
                                })
                        elif 'dqn' in model_type or 'rl' in model_type:
                            if hasattr(orchestrator, 'recent_dqn_predictions'):
                                if symbol not in orchestrator.recent_dqn_predictions:
                                    from collections import deque
                                    orchestrator.recent_dqn_predictions[symbol] = deque(maxlen=100)
                                orchestrator.recent_dqn_predictions[symbol].append({
                                    'timestamp': current_time,
                                    'current_price': current_price,
                                    'action': prediction['action'],
                                    'confidence': prediction['confidence']
                                })

                    # Execute trade logic
                    self._execute_trade_logic(state, prediction, current_price, current_time)

                # Update progress
                state['candles_processed'] = i - min_context + 1

                # Simulate real-time (optional, remove for faster backtest)
                # time.sleep(0.01)  # 10ms per candle

            # Close any open position at end
            if state['position']:
                self._close_position(state, current_price, 'backtest_end')

            # Calculate final stats
            total_trades = state['total_trades']
            wins = state['wins']
            state['win_rate'] = wins / total_trades if total_trades > 0 else 0

            # Preserve 'stopped' status if a stop was requested; otherwise mark complete
            if state['status'] != 'stopped':
                state['status'] = 'complete'
            logger.info(f"Backtest {backtest_id}: Complete. PnL=${state['pnl']:.2f}, Trades={total_trades}, Win Rate={state['win_rate']:.1%}")

        except Exception as e:
            logger.error(f"Backtest {backtest_id} error: {e}", exc_info=True)
            state['status'] = 'error'
            state['error'] = str(e)

    def _make_prediction(self, model, device, context_df, symbol, timeframe):
        """Make model prediction on context data"""
        try:
            # Convert context to model input format
            # Extract OHLCV data
            candles = context_df[['open', 'high', 'low', 'close', 'volume']].values

            # Normalize
            candles_normalized = candles.copy()
            price_data = candles[:, :4]
            volume_data = candles[:, 4:5]

            price_min = price_data.min()
            price_max = price_data.max()
            if price_max > price_min:
                candles_normalized[:, :4] = (price_data - price_min) / (price_max - price_min)

            volume_min = volume_data.min()
            volume_max = volume_data.max()
            if volume_max > volume_min:
                candles_normalized[:, 4:5] = (volume_data - volume_min) / (volume_max - volume_min)

            # Convert to tensor [1, 200, 5]
            # Try GPU first, fallback to CPU if GPU fails
            try:
                price_tensor = torch.tensor(candles_normalized, dtype=torch.float32).unsqueeze(0).to(device)
                tech_data = torch.zeros(1, 40, dtype=torch.float32).to(device)
                market_data = torch.zeros(1, 30, dtype=torch.float32).to(device)
                use_cpu = False
            except Exception as gpu_error:
                logger.warning(f"GPU tensor creation failed, using CPU: {gpu_error}")
                device = torch.device('cpu')
                model.to(device)
                price_tensor = torch.tensor(candles_normalized, dtype=torch.float32).unsqueeze(0)
                tech_data = torch.zeros(1, 40, dtype=torch.float32)
                market_data = torch.zeros(1, 30, dtype=torch.float32)
                use_cpu = True

            # Make prediction
            with torch.no_grad():
                try:
                    outputs = model(
                        price_data_1m=price_tensor if timeframe == '1m' else None,
                        price_data_1s=price_tensor if timeframe == '1s' else None,
                        price_data_1h=price_tensor if timeframe == '1h' else None,
                        price_data_1d=price_tensor if timeframe == '1d' else None,
                        tech_data=tech_data,
                        market_data=market_data
                    )
                except RuntimeError as model_error:
                    # GPU inference failed, retry on CPU
                    if not use_cpu and 'HIP' in str(model_error):
                        logger.warning(f"GPU inference failed, retrying on CPU: {model_error}")
                        device = torch.device('cpu')
                        model.to(device)
                        price_tensor = price_tensor.cpu()
                        tech_data = tech_data.cpu()
                        market_data = market_data.cpu()

                        outputs = model(
                            price_data_1m=price_tensor if timeframe == '1m' else None,
                            price_data_1s=price_tensor if timeframe == '1s' else None,
                            price_data_1h=price_tensor if timeframe == '1h' else None,
                            price_data_1d=price_tensor if timeframe == '1d' else None,
                            tech_data=tech_data,
                            market_data=market_data
                        )
                    else:
                        raise

            # Get action prediction
            action_probs = outputs.get('action_probs', outputs.get('trend_probs'))
            if action_probs is not None:
                action_idx = torch.argmax(action_probs, dim=-1).item()
                confidence = action_probs[0, action_idx].item()

                # Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
                actions = ['HOLD', 'BUY', 'SELL']
                if action_idx < len(actions):
                    action = actions[action_idx]
                else:
                    # If 4 actions (model has 4 trend directions), map to 3 actions
                    action = 'HOLD' if action_idx == 1 else ('BUY' if action_idx in [2, 3] else 'SELL')

                return {
                    'action': action,
                    'confidence': confidence
                }

            return None

        except Exception as e:
            logger.error(f"Prediction error: {e}", exc_info=True)
            return None
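
    # Worked example (illustrative only) of the min-max normalization used in
    # _make_prediction above: with price_min=100.0 and price_max=110.0 over the
    # 200-candle context, a close of 105.0 becomes (105.0 - 100.0) / (110.0 - 100.0) = 0.5.
    # Volume is scaled separately with its own min/max so it does not distort the price scale.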

    def _execute_trade_logic(self, state, prediction, current_price, current_time):
        """Execute trading logic based on prediction"""
        action = prediction['action']
        confidence = prediction['confidence']

        # Only trade on high confidence
        if confidence < 0.6:
            return

        position = state['position']

        if action == 'BUY' and position is None:
            # Enter long position
            state['position'] = {
                'type': 'long',
                'entry_price': current_price,
                'entry_time': current_time
            }
            logger.debug(f"Backtest: ENTER LONG @ ${current_price}")

        elif action == 'SELL' and position is None:
            # Enter short position
            state['position'] = {
                'type': 'short',
                'entry_price': current_price,
                'entry_time': current_time
            }
            logger.debug(f"Backtest: ENTER SHORT @ ${current_price}")

        elif position is not None:
            # Check if should exit
            should_exit = False

            if position['type'] == 'long' and action == 'SELL':
                should_exit = True
            elif position['type'] == 'short' and action == 'BUY':
                should_exit = True

            if should_exit:
                self._close_position(state, current_price, 'signal')

    def _close_position(self, state, exit_price, reason):
        """Close current position and update PnL"""
        position = state['position']
        if not position:
            return

        entry_price = position['entry_price']

        # Calculate PnL
        if position['type'] == 'long':
            pnl = exit_price - entry_price
        else:  # short
            pnl = entry_price - exit_price

        # Update state
        state['pnl'] += pnl
        state['total_trades'] += 1

        if pnl > 0:
            state['wins'] += 1
        elif pnl < 0:
            state['losses'] += 1

        logger.debug(f"Backtest: CLOSE {position['type'].upper()} @ ${exit_price:.2f}, PnL=${pnl:.2f} ({reason})")

        state['position'] = None
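
    # PnL accounting example (illustrative, single-unit position): a long entered at
    # $2,000 and closed at $2,015 adds +15.0 to state['pnl'] and counts as a win; a
    # short entered at $2,000 and closed at $2,015 adds -15.0 and counts as a loss.
    # PnL is therefore expressed in quote-currency price units, not percent.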

    def get_progress(self, backtest_id: str) -> Dict:
        """Get backtest progress"""
        with self.lock:
            state = self.active_backtests.get(backtest_id)
            if not state:
                return {'success': False, 'error': 'Backtest not found'}

            # Get and clear new predictions (they'll be sent to frontend)
            new_predictions = state['new_predictions']
            state['new_predictions'] = []

            return {
                'success': True,
                'status': state['status'],
                'candles_processed': state['candles_processed'],
                'total_candles': state['total_candles'],
                'pnl': state['pnl'],
                'total_trades': state['total_trades'],
                'wins': state['wins'],
                'losses': state['losses'],
                'win_rate': state['wins'] / state['total_trades'] if state['total_trades'] > 0 else 0,
                'new_predictions': new_predictions,
                'error': state['error']
            }

    def stop_backtest(self, backtest_id: str):
        """Request backtest to stop"""
        with self.lock:
            state = self.active_backtests.get(backtest_id)
            if state:
                state['stop_requested'] = True
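

# Illustrative usage of BacktestRunner (placeholder names such as `my_model` and
# `provider`; the AnnotationDashboard below owns the real instance as self.backtest_runner):
#
#   runner = BacktestRunner()
#   backtest_id = str(uuid.uuid4())
#   runner.start_backtest(backtest_id, my_model, provider, 'ETH/USDT', '1m')
#   progress = runner.get_progress(backtest_id)   # polled by the frontend; drains new_predictions
#   runner.stop_backtest(backtest_id)             # request early termination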


class TrainingStrategyManager:
    """
    Manages training strategies and decisions - separates business logic from the model interface

    Training Modes:
    - 'none': No training (inference only)
    - 'every_candle': Train on every completed candle
    - 'pivots_only': Train only on pivot points (BUY at L pivots, SELL at H pivots)
    - 'manual': Training triggered manually by user button
    """

    def __init__(self, data_provider, training_adapter):
        self.data_provider = data_provider
        self.training_adapter = training_adapter
        self.mode = 'none'  # Default: no training
        self.dashboard = None  # Set by dashboard after initialization

        # Statistics tracking
        self.stats = {
            'total_trained': 0,
            'by_action': {'BUY': 0, 'SELL': 0, 'HOLD': 0},
            'profitable': 0
        }

    def should_train_on_candle(self, symbol: str, timeframe: str, candle_timestamp, pivot_markers: Dict = None) -> Tuple[bool, Optional[Dict]]:
        """
        Decide if we should train on this candle based on the current mode

        Args:
            symbol: Trading symbol
            timeframe: Candle timeframe
            candle_timestamp: Timestamp of the candle
            pivot_markers: Dict of pivot markers (timestamp -> pivot data)

        Returns:
            Tuple of (should_train: bool, action_data: Optional[Dict])
            action_data contains: {'action': 'BUY'/'SELL'/'HOLD', 'pivot_level': int, 'pivot_strength': float}
        """
        if self.mode == 'none':
            return False, None

        elif self.mode == 'every_candle':
            # Train on every candle - determine action from price movement or pivots
            action_data = self._get_action_for_candle(symbol, timeframe, candle_timestamp, pivot_markers)
            return True, action_data

        elif self.mode == 'pivots_only':
            # Train only on pivot candles
            return self._is_pivot_candle(candle_timestamp, pivot_markers)

        elif self.mode == 'manual':
            # Manual training - don't auto-train
            return False, None

        return False, None

    def _get_action_for_candle(self, symbol: str, timeframe: str, candle_timestamp, pivot_markers: Dict = None) -> Dict:
        """
        Determine action for any candle (pivot or non-pivot)
        For pivot candles: BUY at L, SELL at H
        For non-pivot candles: use price movement thresholds
        """
        # First check if it's a pivot candle
        is_pivot, pivot_action = self._is_pivot_candle(candle_timestamp, pivot_markers)
        if is_pivot and pivot_action:
            return pivot_action

        # Not a pivot - use price-movement-based logic
        # Get recent candles to determine trend
        df = self.data_provider.get_historical_data(symbol, timeframe, limit=5)
        if df is None or len(df) < 3:
            return {'action': 'HOLD', 'reason': 'insufficient_data'}

        # Simple momentum: if price is going up, BUY; if going down, SELL
        recent_change = (df.iloc[-1]['close'] - df.iloc[-3]['close']) / df.iloc[-3]['close']

        if recent_change > 0.0005:  # 0.05% up
            action = 'BUY'
        elif recent_change < -0.0005:  # 0.05% down
            action = 'SELL'
        else:
            action = 'HOLD'

        return {
            'action': action,
            'reason': 'price_movement',
            'change_pct': recent_change * 100
        }

    def _is_pivot_candle(self, timestamp, pivot_markers: Dict = None) -> Tuple[bool, Optional[Dict]]:
        """
        Check if candle is a pivot point and return the corresponding action

        Returns:
            Tuple of (is_pivot: bool, action_data: Optional[Dict])
        """
        if not pivot_markers:
            return False, None

        candle_timestamp = str(timestamp)
        candle_pivots = pivot_markers.get(candle_timestamp, {})

        if not candle_pivots:
            return False, None

        # BUY at L pivots (lows - support levels)
        if 'lows' in candle_pivots and len(candle_pivots['lows']) > 0:
            best_low = max(candle_pivots['lows'], key=lambda p: p.get('level', 0))
            pivot_level = best_low.get('level', 1)
            pivot_strength = best_low.get('strength', 0.5)

            logger.info(f"L{pivot_level}L pivot detected @ {timestamp}, strength={pivot_strength:.2f} → BUY signal")

            return True, {
                'action': 'BUY',
                'pivot_level': pivot_level,
                'pivot_strength': pivot_strength,
                'reason': 'low_pivot'
            }

        # SELL at H pivots (highs - resistance levels)
        elif 'highs' in candle_pivots and len(candle_pivots['highs']) > 0:
            best_high = max(candle_pivots['highs'], key=lambda p: p.get('level', 0))
            pivot_level = best_high.get('level', 1)
            pivot_strength = best_high.get('strength', 0.5)

            logger.info(f"L{pivot_level}H pivot detected @ {timestamp}, strength={pivot_strength:.2f} → SELL signal")

            return True, {
                'action': 'SELL',
                'pivot_level': pivot_level,
                'pivot_strength': pivot_strength,
                'reason': 'high_pivot'
            }

        return False, None
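
    # Illustrative example of the data _is_pivot_candle() consumes and returns, assuming
    # the pivot_markers shape produced by AnnotationDashboard._get_pivot_markers_for_timeframe
    # further below (timestamp keys are ISO strings with a 'Z' suffix; the numbers are made up):
    #
    #   pivot_markers = {'2025-12-08T21:00:00Z': {'highs': [], 'lows': [{'level': 2, 'price': 2875.4, 'strength': 0.8, 'is_last': True}]}}
    #   is_pivot, action_data = strategy._is_pivot_candle('2025-12-08T21:00:00Z', pivot_markers)
    #   # -> (True, {'action': 'BUY', 'pivot_level': 2, 'pivot_strength': 0.8, 'reason': 'low_pivot'})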

    def train_manually(self, symbol: str, timeframe: str, action: str) -> Dict:
        """
        Manually trigger training with specified action

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            action: Action to train ('BUY', 'SELL', or 'HOLD')

        Returns:
            Training result dict with metrics
        """
        logger.info(f"Manual training triggered: {action} on {symbol} {timeframe}")

        # Create action data
        action_data = {
            'action': action,
            'reason': 'manual_trigger'
        }

        # Update stats
        self.stats['total_trained'] += 1
        self.stats['by_action'][action] = self.stats['by_action'].get(action, 0) + 1

        return {
            'success': True,
            'action': action,
            'triggered_by': 'manual'
        }

    def get_stats(self) -> Dict:
        """Get training statistics"""
        total = self.stats['total_trained']
        if total == 0:
            return {
                'total_trained': 0,
                'by_action': {'BUY': '0%', 'SELL': '0%', 'HOLD': '0%'},
                'mode': self.mode
            }

        return {
            'total_trained': total,
            'by_action': {
                'BUY': f"{(self.stats['by_action'].get('BUY', 0) / total * 100):.1f}%",
                'SELL': f"{(self.stats['by_action'].get('SELL', 0) / total * 100):.1f}%",
                'HOLD': f"{(self.stats['by_action'].get('HOLD', 0) / total * 100):.1f}%"
            },
            'mode': self.mode
        }
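

# Illustrative sketch of driving TrainingStrategyManager (placeholder variables; the
# dashboard below owns the real instance as self.training_strategy):
#
#   strategy = TrainingStrategyManager(data_provider, training_adapter)
#   strategy.mode = 'pivots_only'   # one of: 'none', 'every_candle', 'pivots_only', 'manual'
#   should_train, action_data = strategy.should_train_on_candle('ETH/USDT', '1m', ts, pivot_markers)
#   if should_train:
#       ...  # hand action_data to the training adapter
#   print(strategy.get_stats())     # e.g. {'total_trained': 0, 'by_action': {...}, 'mode': 'pivots_only'}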


class AnnotationDashboard:
    """Main annotation dashboard application"""

    def __init__(self):
        """Initialize the dashboard"""
        # Load configuration
        try:
            # Always try YAML loading first since get_config might not work in standalone mode
            import yaml
            with open('config.yaml', 'r') as f:
                self.config = yaml.safe_load(f)
            logger.info(f"Loaded config via YAML: {len(self.config)} keys")
        except Exception as e:
            logger.warning(f"Could not load config via YAML: {e}")
            try:
                # Fallback to get_config if available
                if get_config:
                    self.config = get_config()
                    logger.info(f"Loaded config via get_config: {len(self.config)} keys")
                else:
                    raise Exception("get_config not available")
            except Exception as e2:
                logger.warning(f"Could not load config via get_config: {e2}")
                # Final fallback config with SOL/USDT
                self.config = {
                    'symbols': ['ETH/USDT', 'BTC/USDT', 'SOL/USDT'],
                    'timeframes': ['1s', '1m', '1h', '1d']
                }
                logger.info("Using fallback config")

        # Initialize Flask app
        self.server = Flask(
            __name__,
            template_folder='templates',
            static_folder='static'
        )

        # WebSocket support removed - using HTTP polling only
        self.socketio = None
        self.has_socketio = False

        # Suppress werkzeug request logs (reduce noise from polling endpoints)
        werkzeug_logger = logging.getLogger('werkzeug')
        werkzeug_logger.setLevel(logging.WARNING)  # Only show warnings and errors, not INFO

        # Initialize Dash app (optional component)
        self.app = Dash(
            __name__,
            server=self.server,
            url_base_pathname='/dash/',
            external_stylesheets=[
                'https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css',
                'https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css'
            ]
        )

        # Set a simple Dash layout to avoid NoLayoutException
        self.app.layout = html.Div([
            html.H1("ANNOTATE Dashboard", className="text-center mb-4"),
            html.Div([
                html.P("This is the Dash component of the ANNOTATE system."),
                html.P("The main interface is available at the Flask routes."),
                html.A("Go to Main Interface", href="/", className="btn btn-primary")
            ], className="container")
        ])

        # Initialize core components (skip initial load for fast startup)
        try:
            if DataProvider:
                config = get_config()
                self.data_provider = DataProvider(skip_initial_load=True)
                logger.info("DataProvider initialized successfully")
            else:
                self.data_provider = None
                logger.warning("DataProvider class not available")
        except Exception as e:
            logger.error(f"Failed to initialize DataProvider: {e}")
            self.data_provider = None

        # Enable unified storage for real-time data access
        if self.data_provider:
            self._enable_unified_storage_async()

        # ANNOTATE doesn't need the orchestrator immediately - lazy load on demand
        self.orchestrator = None
        self.models_loading = False
        self.available_models = ['DQN', 'CNN', 'Transformer']  # Models that CAN be loaded
        self.loaded_models = {}  # Models that ARE loaded: {name: model_instance}

        # Initialize ANNOTATE components
        self.annotation_manager = AnnotationManager()
        # Use REAL training adapter - NO SIMULATION!
        self.training_adapter = RealTrainingAdapter(None, self.data_provider)
        # Initialize training strategy manager (controls training decisions)
        self.training_strategy = TrainingStrategyManager(self.data_provider, self.training_adapter)
        self.training_strategy.dashboard = self
        # WebSocket removed - using HTTP polling only
        # Backtest runner for replaying visible chart with predictions
        self.backtest_runner = BacktestRunner()

        # NOTE: Prediction caching is now handled by the InferenceFrameReference system.
        # See ANNOTATE/core/inference_training_system.py for the unified implementation.

        # Check if we should auto-load a model at startup
        auto_load_model = os.getenv('AUTO_LOAD_MODEL', 'Transformer')  # Default: Transformer

        if auto_load_model and auto_load_model.lower() != 'none':
            logger.info(f"Auto-loading model: {auto_load_model}")
            self._auto_load_model(auto_load_model)
        else:
            logger.info("Auto-load disabled. Models available for lazy loading: " + ", ".join(self.available_models))

        # Use main DataProvider directly instead of duplicate data_loader
        self.data_loader = None  # Deprecated - using data_provider directly
        self.time_range_manager = None  # Deprecated

        # Setup routes
        self._setup_routes()

        # Start background data refresh after startup
        if self.data_provider:
            self._start_background_data_refresh()

        logger.info("Annotation Dashboard initialized")

    def _auto_load_model(self, model_name: str):
        """
        Auto-load a model at startup in a background thread

        Args:
            model_name: Name of model to load (DQN, CNN, or Transformer)
        """
        def load_in_background():
            try:
                logger.info(f"Starting auto-load for {model_name}...")

                # Initialize orchestrator if not already done
                if not self.orchestrator:
                    logger.info("Initializing TradingOrchestrator...")
                    self.orchestrator = TradingOrchestrator(
                        data_provider=self.data_provider
                    )
                    self.training_adapter.orchestrator = self.orchestrator
                    logger.info("TradingOrchestrator initialized")

                # Initialize TradingExecutor for trade execution
                try:
                    from core.trading_executor import TradingExecutor
                    self.trading_executor = TradingExecutor()
                    self.orchestrator.set_trading_executor(self.trading_executor)
                    logger.info("TradingExecutor initialized and connected to orchestrator")

                    # Start continuous trading loop
                    import asyncio
                    try:
                        loop = asyncio.get_event_loop()
                        if loop.is_running():
                            # Create task in existing loop
                            loop.create_task(self.orchestrator.start_continuous_trading())
                            logger.info("Continuous trading loop started")
                        else:
                            logger.warning("No running event loop - trading loop will start when loop is available")
                    except RuntimeError:
                        logger.warning("No event loop available - trading loop will start when loop is available")

                except Exception as e:
                    logger.warning(f"Could not initialize TradingExecutor: {e}")
                    self.trading_executor = None

                # Check if the specific model is already initialized
                if model_name == 'Transformer':
                    logger.info("Checking Transformer model...")
                    if hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
                        self.loaded_models['Transformer'] = self.orchestrator.primary_transformer
                        logger.info("Transformer model loaded successfully")
                    else:
                        logger.warning("Transformer model not initialized in orchestrator")
                        return

                elif model_name == 'CNN':
                    logger.info("Checking CNN model...")
                    if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
                        self.loaded_models['CNN'] = self.orchestrator.cnn_model
                        logger.info("CNN model loaded successfully")
                    else:
                        logger.warning("CNN model not initialized in orchestrator")
                        return

                elif model_name == 'DQN':
                    logger.info("Checking DQN model...")
                    if hasattr(self.orchestrator, 'rl_agent') and self.orchestrator.rl_agent:
                        self.loaded_models['DQN'] = self.orchestrator.rl_agent
                        logger.info("DQN model loaded successfully")
                    else:
                        logger.warning("DQN model not initialized in orchestrator")
                        return

                else:
                    logger.warning(f"Unknown model name: {model_name}")
                    return

                self.models_loading = False
                logger.info(f"{model_name} model ready for inference and training")

            except Exception as e:
                logger.error(f"Error auto-loading {model_name} model: {e}")
                import traceback
                logger.error(traceback.format_exc())
                self.models_loading = False

        # Start loading in background thread
        self.models_loading = True
        thread = threading.Thread(target=load_in_background, daemon=True)
        thread.start()

    def _get_best_checkpoint_info(self, model_name: str) -> Optional[Dict]:
        """
        Get best checkpoint info for a model without loading it.
        First tries the database, then falls back to filename parsing.

        Args:
            model_name: Name of the model

        Returns:
            Dict with checkpoint info or None if no checkpoint found
        """
        try:
            # Try to get from database first (has full metadata)
            try:
                from utils.database_manager import get_database_manager
                db_manager = get_database_manager()

                # Get active checkpoint for this model
                with db_manager._get_connection() as conn:
                    cursor = conn.execute("""
                        SELECT checkpoint_id, performance_metrics, timestamp, file_path
                        FROM checkpoint_metadata
                        WHERE model_name = ? AND is_active = TRUE
                        ORDER BY timestamp DESC
                        LIMIT 1
                    """, (model_name.lower(),))

                    row = cursor.fetchone()
                    if row:
                        checkpoint_id, metrics_json, timestamp, file_path = row
                        metrics = json.loads(metrics_json) if metrics_json else {}

                        checkpoint_info = {
                            'filename': os.path.basename(file_path) if file_path else checkpoint_id,
                            'epoch': metrics.get('epoch', 0),
                            'loss': metrics.get('loss'),
                            'accuracy': metrics.get('accuracy'),
                            'source': 'database'
                        }

                        logger.info(f"Loaded checkpoint info from database for {model_name}: E{checkpoint_info['epoch']}, Loss={checkpoint_info['loss']}, Acc={checkpoint_info['accuracy']}")
                        return checkpoint_info
            except Exception as db_error:
                logger.debug(f"Could not load from database: {db_error}")

            # Fallback to filename parsing
            import glob
            import re

            # Map model names to checkpoint directories
            checkpoint_dirs = {
                'Transformer': 'models/checkpoints/transformer',
                'CNN': 'models/checkpoints/enhanced_cnn',
                'DQN': 'models/checkpoints/dqn_agent'
            }

            checkpoint_dir = checkpoint_dirs.get(model_name)
            if not checkpoint_dir:
                return None

            if not os.path.exists(checkpoint_dir):
                logger.debug(f"Checkpoint directory not found: {checkpoint_dir}")
                return None

            # Find all checkpoint files
            checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pt'))
            if not checkpoint_files:
                logger.debug(f"No checkpoint files found in {checkpoint_dir}")
                return None

            logger.debug(f"Found {len(checkpoint_files)} checkpoints for {model_name}")

            # Parse filenames to extract epoch info
            # Format: transformer_epoch5_20251110_123620.pt
            best_checkpoint = None
            best_epoch = -1

            for cp_file in checkpoint_files:
                try:
                    filename = os.path.basename(cp_file)

                    # Extract epoch number from filename
                    match = re.search(r'epoch(\d+)', filename, re.IGNORECASE)
                    if match:
                        epoch = int(match.group(1))
                        if epoch > best_epoch:
                            best_epoch = epoch
                            best_checkpoint = {
                                'filename': filename,
                                'epoch': epoch,
                                'loss': None,  # Can't get without loading
                                'accuracy': None,  # Can't get without loading
                                'source': 'filename'
                            }
                            logger.debug(f"Found checkpoint: {filename}, epoch {epoch}")
                except Exception as e:
                    logger.debug(f"Could not parse checkpoint {cp_file}: {e}")
                    continue

            if best_checkpoint:
                logger.info(f"Best checkpoint for {model_name}: {best_checkpoint['filename']} (E{best_checkpoint['epoch']})")

            return best_checkpoint

        except Exception as e:
            logger.error(f"Error getting checkpoint info for {model_name}: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return None
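
    # Illustrative return values for _get_best_checkpoint_info() (shapes follow the code
    # above; the loss/accuracy numbers are made up):
    #   from the database:     {'filename': 'transformer_epoch5_20251110_123620.pt', 'epoch': 5, 'loss': 0.41, 'accuracy': 0.62, 'source': 'database'}
    #   from filename parsing: {'filename': 'transformer_epoch5_20251110_123620.pt', 'epoch': 5, 'loss': None, 'accuracy': None, 'source': 'filename'}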

    def _load_model_lazy(self, model_name: str) -> dict:
        """
        Lazy load a specific model on demand

        Args:
            model_name: Name of model to load ('DQN', 'CNN', 'Transformer')

        Returns:
            dict: Result with success status and message
        """
        try:
            # Check if already loaded
            if model_name in self.loaded_models:
                return {
                    'success': True,
                    'message': f'{model_name} already loaded',
                    'already_loaded': True
                }

            # Check if model is available
            if model_name not in self.available_models:
                return {
                    'success': False,
                    'error': f'{model_name} is not in available models list'
                }

            logger.info(f"Loading {model_name} model...")

            # Initialize orchestrator if not already done
            if not self.orchestrator:
                if not TradingOrchestrator:
                    return {
                        'success': False,
                        'error': 'TradingOrchestrator class not available'
                    }

                logger.info("Creating TradingOrchestrator instance...")
                self.orchestrator = TradingOrchestrator(
                    data_provider=self.data_provider,
                    enhanced_rl_training=True
                )
                logger.info("Orchestrator created")

                # Initialize TradingExecutor for trade execution
                try:
                    from core.trading_executor import TradingExecutor
                    self.trading_executor = TradingExecutor()
                    self.orchestrator.set_trading_executor(self.trading_executor)
                    logger.info("TradingExecutor initialized and connected to orchestrator")

                    # Start continuous trading loop
                    import asyncio
                    try:
                        loop = asyncio.get_event_loop()
                        if loop.is_running():
                            # Create task in existing loop
                            loop.create_task(self.orchestrator.start_continuous_trading())
                            logger.info("Continuous trading loop started")
                        else:
                            logger.warning("No running event loop - trading loop will start when loop is available")
                    except RuntimeError:
                        logger.warning("No event loop available - trading loop will start when loop is available")

                except Exception as e:
                    logger.warning(f"Could not initialize TradingExecutor: {e}")
                    self.trading_executor = None

                # Update training adapter
                self.training_adapter.orchestrator = self.orchestrator

            # Load specific model
            if model_name == 'DQN':
                if not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
                    # Initialize RL agent
                    self.orchestrator._initialize_rl_agent()
                self.loaded_models['DQN'] = self.orchestrator.rl_agent

            elif model_name == 'CNN':
                if not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
                    # Initialize CNN model
                    self.orchestrator._initialize_cnn_model()
                self.loaded_models['CNN'] = self.orchestrator.cnn_model

            elif model_name == 'Transformer':
                if not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
                    # Initialize Transformer model
                    self.orchestrator._initialize_transformer_model()
                self.loaded_models['Transformer'] = self.orchestrator.primary_transformer

            else:
                return {
                    'success': False,
                    'error': f'Unknown model: {model_name}'
                }

            logger.info(f"{model_name} model loaded successfully")

            return {
                'success': True,
                'message': f'{model_name} loaded successfully',
                'loaded_models': list(self.loaded_models.keys())
            }

        except Exception as e:
            logger.error(f"Error loading {model_name}: {e}")
            import traceback
            logger.error(f"Traceback:\n{traceback.format_exc()}")
            return {
                'success': False,
                'error': str(e)
            }
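
    # Illustrative call (e.g. from a model-management endpoint; that wiring is outside this
    # excerpt, so treat it as an assumption):
    #   result = dashboard._load_model_lazy('Transformer')
    #   # -> {'success': True, 'message': 'Transformer loaded successfully', 'loaded_models': ['Transformer']}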

    def _enable_unified_storage_async(self):
        """Enable unified storage system in a background thread"""
        def enable_storage():
            try:
                import asyncio

                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

                # Enable unified storage
                success = loop.run_until_complete(
                    self.data_provider.enable_unified_storage()
                )

                if success:
                    logger.info("ANNOTATE: Unified storage enabled for real-time data")

                    # Get statistics
                    stats = self.data_provider.get_unified_storage_stats()
                    if stats.get('initialized'):
                        logger.info("  Real-time data access: <10ms")
                        logger.info("  Historical data access: <100ms")
                        logger.info("  Annotation data: Available at any timestamp")
                else:
                    logger.warning("ANNOTATE: Unified storage not available, using cached data only")

            except Exception as e:
                logger.warning(f"ANNOTATE: Could not enable unified storage: {e}")
                logger.info("ANNOTATE: Continuing with cached data access")

        # Start in background thread
        storage_thread = threading.Thread(target=enable_storage, daemon=True)
        storage_thread.start()

    def _start_background_data_refresh(self):
        """Start background task to refresh recent data after startup - ONCE ONLY"""
        def refresh_recent_data():
            try:
                # Wait for app to fully start
                time.sleep(5)

                logger.info("Starting one-time background data refresh (fetching only recent missing data)")

                if self.data_provider:
                    # Disable startup mode to fetch fresh data
                    self.data_provider.disable_startup_mode()

                    # Use the on-demand refresh method for recent data
                    logger.info("Using on-demand refresh for recent data")
                    self.data_provider.refresh_data_on_demand()

                logger.info("One-time background data refresh completed")

            except Exception as e:
                logger.error(f"Error in background data refresh: {e}")

        # Start refresh in background thread
        refresh_thread = threading.Thread(target=refresh_recent_data, daemon=True)
        refresh_thread.start()
        logger.info("One-time background data refresh scheduled")

    def _get_pivot_markers_for_timeframe(self, symbol: str, timeframe: str, df: pd.DataFrame) -> dict:
        """
        Get pivot markers for a specific timeframe using WilliamsMarketStructure directly.
        Returns a dict with all pivot points and identifies which are the last high/low per level.
        """
        try:
            if WilliamsMarketStructure is None:
                logger.warning("WilliamsMarketStructure not available")
                return {}

            if df is None or len(df) < 10:
                logger.warning(f"Insufficient data for pivot calculation: {len(df) if df is not None else 0} bars")
                return {}

            # Convert DataFrame to numpy array format expected by Williams Market Structure
            ohlcv_array = df[['open', 'high', 'low', 'close', 'volume']].copy()

            # Add timestamp as first column (convert to milliseconds)
            timestamps = df.index.astype(np.int64) // 10**6  # pandas index is ns -> convert to ms
            ohlcv_array.insert(0, 'timestamp', timestamps)
            ohlcv_array = ohlcv_array.to_numpy()

            # Initialize Williams Market Structure with default distance
            # We'll override it in the calculation call
            williams = WilliamsMarketStructure(min_pivot_distance=1)

            # Calculate recursive pivot points with min_pivot_distance=2
            # This ensures 5 candles per pivot (tip + 2 prev + 2 next)
            pivot_levels = williams.calculate_recursive_pivot_points(
                ohlcv_array,
                min_pivot_distance=2
            )

            if not pivot_levels:
                logger.debug(f"No pivot levels found for {symbol} {timeframe}")
                return {}

            # Build a map of timestamp -> pivot info.
            # Also track the last high/low per level for drawing horizontal lines.
            pivot_map = {}
            last_pivots = {}  # {level: {'high': (ts_str, idx), 'low': (ts_str, idx)}}

            # For each level (1-5), collect ALL pivot points
            for level_num, trend_level in pivot_levels.items():
                if not hasattr(trend_level, 'pivot_points') or not trend_level.pivot_points:
                    continue

                last_pivots[level_num] = {'high': None, 'low': None}

                # Add ALL pivot points to the map
                for pivot in trend_level.pivot_points:
                    ts_str = self._format_timestamp_utc(pivot.timestamp)

                    if ts_str not in pivot_map:
                        pivot_map[ts_str] = {'highs': [], 'lows': []}

                    pivot_info = {
                        'level': level_num,
                        'price': pivot.price,
                        'strength': pivot.strength,
                        'is_last': False  # Will be updated below
                    }

                    if pivot.pivot_type == 'high':
                        pivot_map[ts_str]['highs'].append(pivot_info)
                        last_pivots[level_num]['high'] = (ts_str, len(pivot_map[ts_str]['highs']) - 1)
                    elif pivot.pivot_type == 'low':
                        pivot_map[ts_str]['lows'].append(pivot_info)
                        last_pivots[level_num]['low'] = (ts_str, len(pivot_map[ts_str]['lows']) - 1)

            # Mark the last high and last low for each level
            for level_num, last_info in last_pivots.items():
                if last_info['high']:
                    ts_str, idx = last_info['high']
                    pivot_map[ts_str]['highs'][idx]['is_last'] = True
                if last_info['low']:
                    ts_str, idx = last_info['low']
                    pivot_map[ts_str]['lows'][idx]['is_last'] = True

            pivot_logger.info(f"Found {len(pivot_map)} pivot candles for {symbol} {timeframe} (from {len(df)} candles)")
            return pivot_map

        except Exception as e:
            logger.error(f"Error getting pivot markers for {timeframe}: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return {}
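
    # Illustrative shape of the returned pivot_map (keys are UTC ISO timestamps produced by
    # _format_timestamp_utc below; values list every pivot found on that candle; the concrete
    # prices and levels here are made-up examples, not real output):
    #   {
    #       '2025-12-08T21:00:00Z': {
    #           'highs': [{'level': 1, 'price': 2891.2, 'strength': 0.7, 'is_last': False}],
    #           'lows': []
    #       },
    #       ...
    #   }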

    def _format_timestamp_utc(self, ts):
        """
        Format a timestamp in ISO format with UTC indicator ('Z' suffix).
        This ensures frontend JavaScript parses it as UTC, not local time.

        Args:
            ts: pandas Timestamp or datetime object

        Returns:
            str: ISO format timestamp with 'Z' suffix (e.g., '2025-12-08T21:00:00Z')
        """
        try:
            # Ensure timestamp is UTC
            if hasattr(ts, 'tz'):
                if ts.tz is not None:
                    ts_utc = ts.tz_convert('UTC') if hasattr(ts, 'tz_convert') else ts
                else:
                    try:
                        ts_utc = ts.tz_localize('UTC') if hasattr(ts, 'tz_localize') else ts
                    except Exception:
                        ts_utc = ts
            else:
                ts_utc = ts

            # Format as ISO with 'Z' for UTC
            if hasattr(ts_utc, 'strftime'):
                return ts_utc.strftime('%Y-%m-%dT%H:%M:%SZ')
            else:
                return str(ts_utc)
        except Exception as e:
            logger.debug(f"Error formatting timestamp: {e}")
            return str(ts)

    def _format_timestamps_utc(self, timestamp_series):
        """
        Format a series of timestamps in ISO format with UTC indicator

        Args:
            timestamp_series: pandas Index or Series with timestamps

        Returns:
            list: List of ISO format timestamps with 'Z' suffix
        """
        return [self._format_timestamp_utc(ts) for ts in timestamp_series]
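
    # Example of the formatting above: a tz-naive pandas Timestamp('2025-12-08 21:00:00')
    # is treated as UTC and serialized as '2025-12-08T21:00:00Z'; a tz-aware timestamp is
    # converted to UTC first. This matches the example in the _format_timestamp_utc docstring.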

    def _setup_routes(self):
        """Setup Flask routes"""

        @self.server.route('/favicon.ico')
        def favicon():
            """Serve favicon to prevent 404 errors"""
            from flask import Response
            # Return a minimal ICO payload as favicon
            favicon_data = b'\x00\x00\x01\x00\x01\x00\x10\x10\x00\x00\x01\x00\x20\x00\x68\x04\x00\x00\x16\x00\x00\x00\x28\x00\x00\x00\x10\x00\x00\x00\x20\x00\x00\x00\x01\x00\x20\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
            return Response(favicon_data, mimetype='image/x-icon')

        @self.server.route('/')
        def index():
            """Main dashboard page - loads existing annotations"""
            try:
                # Get symbols and timeframes from config
                symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT'])
                timeframes = self.config.get('timeframes', ['1s', '1m', '1h', '1d'])
                current_symbol = symbols[0] if symbols else 'ETH/USDT'

                # Get annotations filtered by current symbol
                annotations = self.annotation_manager.get_annotations(symbol=current_symbol)

                # Convert to serializable format
                annotations_data = []
                for ann in annotations:
                    if hasattr(ann, '__dict__'):
                        ann_dict = ann.__dict__
                    else:
                        ann_dict = ann

                    # Ensure all fields are JSON serializable
                    annotations_data.append({
                        'annotation_id': ann_dict.get('annotation_id'),
                        'symbol': ann_dict.get('symbol'),
                        'timeframe': ann_dict.get('timeframe'),
                        'entry': ann_dict.get('entry'),
                        'exit': ann_dict.get('exit'),
                        'direction': ann_dict.get('direction'),
                        'profit_loss_pct': ann_dict.get('profit_loss_pct'),
                        'notes': ann_dict.get('notes', ''),
                        'created_at': ann_dict.get('created_at')
                    })

                logger.info(f"Loading dashboard with {len(annotations_data)} annotations for {current_symbol}")

                # Prepare template data
                template_data = {
                    'current_symbol': current_symbol,
                    'symbols': symbols,
                    'timeframes': timeframes,
                    'annotations': annotations_data
                }

                return render_template('annotation_dashboard.html', **template_data)
            except Exception as e:
                logger.error(f"Error rendering main page: {e}")
                # Fallback simple HTML page
                return f"""
                <html>
                <head>
                    <title>ANNOTATE - Manual Trade Annotation UI</title>
                    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
                </head>
                <body>
                    <div class="container mt-5">
                        <h1 class="text-center">📝 ANNOTATE - Manual Trade Annotation UI</h1>
                        <div class="alert alert-info">
                            <h4>System Status</h4>
                            <p>Annotation Manager: Active</p>
                            <p>Data Provider: {'Available' if self.data_provider else 'Not Available (Standalone Mode)'}</p>
                            <p>Trading Orchestrator: {'Available' if self.orchestrator else 'Not Available (Standalone Mode)'}</p>
                        </div>
                        <div class="row">
                            <div class="col-md-6">
                                <h3>Available Features</h3>
                                <ul>
                                    <li>Manual trade annotation</li>
                                    <li>Test case generation</li>
                                    <li>Annotation export</li>
                                    <li>Real model training</li>
                                </ul>
                            </div>
                            <div class="col-md-6">
                                <h3>API Endpoints</h3>
                                <ul>
                                    <li><code>POST /api/chart-data</code> - Get chart data</li>
                                    <li><code>POST /api/save-annotation</code> - Save annotation</li>
                                    <li><code>POST /api/delete-annotation</code> - Delete annotation</li>
                                    <li><code>POST /api/generate-test-case</code> - Generate test case</li>
                                    <li><code>POST /api/export-annotations</code> - Export annotations</li>
                                </ul>
                            </div>
                        </div>
                        <div class="text-center mt-4">
                            <a href="/dash/" class="btn btn-primary">Go to Dash Interface</a>
                        </div>
                    </div>
                </body>
                </html>
                """

        @self.server.route('/api/recalculate-pivots', methods=['POST'])
        def recalculate_pivots():
            """Recalculate pivot points for merged data using cached data from the data provider"""
            try:
                data = request.get_json()
                symbol = data.get('symbol', 'ETH/USDT')
                timeframe = data.get('timeframe')
                # We no longer use timestamps/ohlcv from the frontend; we use our own consistent data source

                if not timeframe:
                    return jsonify({
                        'success': False,
                        'error': {'code': 'INVALID_REQUEST', 'message': 'Missing timeframe'}
                    })

                pivot_logger.info(f"Recalculating pivots for {symbol} {timeframe} using backend data")

                if not self.data_provider:
                    return jsonify({
                        'success': False,
                        'error': {'code': 'DATA_PROVIDER_UNAVAILABLE', 'message': 'Data provider not available'}
                    })

                # Fetch latest data from data_provider for pivot calculation
                df = self.data_provider.get_data_for_annotation(
                    symbol=symbol,
                    timeframe=timeframe,
                    limit=2500,  # Enough for context
                    direction='latest'
                )

                if df is None or df.empty:
                    logger.warning(f"No data found for {symbol} {timeframe} to recalculate pivots")
                    return jsonify({
                        'success': True,
                        'pivot_markers': {}
                    })

                # Recalculate pivot markers
                pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df)

                pivot_logger.info(f"Recalculated {len(pivot_markers)} pivot candles")

                return jsonify({
                    'success': True,
                    'pivot_markers': pivot_markers
                })

            except Exception as e:
                logger.error(f"Error recalculating pivots: {e}")
                return jsonify({
                    'success': False,
                    'error': {'code': 'RECALC_ERROR', 'message': str(e)}
                })

        @self.server.route('/api/chart-data', methods=['GET'])
        def get_chart_data_get():
            """GET endpoint for chart data (used by initial chart load)"""
            try:
                symbol = request.args.get('symbol', 'ETH/USDT')
                timeframe = request.args.get('timeframe', '1m')
                limit = int(request.args.get('limit', 2500))

                webui_logger.info(f"Chart data GET request: {symbol} {timeframe} limit={limit}")

                if not self.data_provider:
                    return jsonify({
                        'success': False,
                        'error': {'code': 'DATA_PROVIDER_UNAVAILABLE', 'message': 'Data provider not available'}
                    })

                # Fetch data using main data provider
                df = self.data_provider.get_data_for_annotation(
                    symbol=symbol,
                    timeframe=timeframe,
                    limit=limit,
                    direction='latest'
                )

                if df is not None and not df.empty:
                    webui_logger.info(f"  {timeframe}: {len(df)} candles")

                    # Get pivot points
                    pivot_markers = {}
                    if len(df) >= 50:
                        pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df)

                    chart_data = {
                        timeframe: {
                            'timestamps': self._format_timestamps_utc(df.index),
                            'open': df['open'].tolist(),
                            'high': df['high'].tolist(),
                            'low': df['low'].tolist(),
                            'close': df['close'].tolist(),
                            'volume': df['volume'].tolist(),
                            'pivot_markers': pivot_markers
                        }
                    }

                    return jsonify({'success': True, 'data': chart_data})
                else:
                    return jsonify({
                        'success': False,
                        'error': {'code': 'NO_DATA', 'message': f'No data available for {symbol} {timeframe}'}
                    })

            except Exception as e:
                webui_logger.error(f"Error in chart-data GET: {e}")
                return jsonify({'success': False, 'error': {'code': 'ERROR', 'message': str(e)}})

        @self.server.route('/api/chart-data', methods=['POST'])
        def get_chart_data():
            """Get chart data for specified symbol and timeframes with infinite scroll support"""
            try:
                data = request.get_json()
                symbol = data.get('symbol', 'ETH/USDT')
                timeframes = data.get('timeframes', ['1s', '1m', '1h', '1d'])
                start_time_str = data.get('start_time')
                end_time_str = data.get('end_time')
                limit = data.get('limit', 2500)  # Default 2500 candles for training
                direction = data.get('direction', 'latest')  # 'latest', 'before', or 'after'

                webui_logger.info(f"Chart data request: {symbol} {timeframes} direction={direction} limit={limit}")
                if start_time_str:
                    webui_logger.info(f"  start_time: {start_time_str}")
                if end_time_str:
                    webui_logger.info(f"  end_time: {end_time_str}")

                if not self.data_provider:
                    return jsonify({
                        'success': False,
                        'error': {
                            'code': 'DATA_PROVIDER_UNAVAILABLE',
                            'message': 'Data provider not available'
                        }
                    })

                # Parse time strings if provided
                start_time = datetime.fromisoformat(start_time_str.replace('Z', '+00:00')) if start_time_str else None
                end_time = datetime.fromisoformat(end_time_str.replace('Z', '+00:00')) if end_time_str else None

                # Fetch data for each timeframe using the data provider.
                # This will automatically:
                # 1. Check DuckDB first
                # 2. Fetch from the API if not in cache
                # 3. Store in DuckDB for future use
                chart_data = {}
                for timeframe in timeframes:
                    df = self.data_provider.get_data_for_annotation(
                        symbol=symbol,
                        timeframe=timeframe,
                        start_time=start_time,
                        end_time=end_time,
                        limit=limit,
                        direction=direction
                    )

                    if df is not None and not df.empty:
                        webui_logger.info(f"  {timeframe}: {len(df)} candles ({df.index[0]} to {df.index[-1]})")

                        # Get pivot points for this timeframe (only if we have enough context)
                        pivot_markers = {}
                        if len(df) >= 50:
                            pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df)

                        # Convert to format suitable for Plotly
                        chart_data[timeframe] = {
                            'timestamps': self._format_timestamps_utc(df.index),
                            'open': df['open'].tolist(),
                            'high': df['high'].tolist(),
                            'low': df['low'].tolist(),
                            'close': df['close'].tolist(),
                            'volume': df['volume'].tolist(),
                            'pivot_markers': pivot_markers  # Optional: only present if pivots exist
                        }
                    else:
                        logger.warning(f"  {timeframe}: No data returned")

                # Get pivot bounds for the symbol
                pivot_bounds = None
                if self.data_provider:
                    try:
                        pivot_bounds = self.data_provider.get_pivot_bounds(symbol)
                        if pivot_bounds:
                            logger.info(f"Found pivot bounds for {symbol}: {len(pivot_bounds.pivot_support_levels)} support, {len(pivot_bounds.pivot_resistance_levels)} resistance")
                    except Exception as e:
                        logger.error(f"Error getting pivot bounds: {e}")

                return jsonify({
                    'success': True,
                    'chart_data': chart_data,
                    'pivot_bounds': {
                        'support_levels': pivot_bounds.pivot_support_levels if pivot_bounds else [],
                        'resistance_levels': pivot_bounds.pivot_resistance_levels if pivot_bounds else [],
                        'price_range': {
                            'min': pivot_bounds.price_min if pivot_bounds else None,
                            'max': pivot_bounds.price_max if pivot_bounds else None
                        },
                        'volume_range': {
                            'min': pivot_bounds.volume_min if pivot_bounds else None,
                            'max': pivot_bounds.volume_max if pivot_bounds else None
                        },
                        'timeframe': '1m',  # Pivot bounds are calculated from 1m data
                        'period': '30 days',  # Monthly data
                        'total_levels': len(pivot_bounds.pivot_support_levels) + len(pivot_bounds.pivot_resistance_levels) if pivot_bounds else 0
                    } if pivot_bounds else None
                })

            except Exception as e:
                logger.error(f"Error fetching chart data: {e}")
                return jsonify({
                    'success': False,
                    'error': {
                        'code': 'CHART_DATA_ERROR',
                        'message': str(e)
                    }
                })
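
        # Illustrative /api/chart-data POST payload (field names follow the parsing above;
        # the timestamps are made-up examples):
        #   {
        #       "symbol": "ETH/USDT",
        #       "timeframes": ["1m", "1h"],
        #       "start_time": "2025-12-08T20:00:00Z",
        #       "end_time": "2025-12-08T22:00:00Z",
        #       "limit": 2500,
        #       "direction": "before"   # 'latest', 'before', or 'after' for infinite scroll
        #   }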

        @self.server.route('/api/save-annotation', methods=['POST'])
        def save_annotation():
            """Save a new annotation with full market context"""
            try:
                data = request.get_json()

                # Capture market state at entry and exit times using data provider
                entry_market_state = {}
                exit_market_state = {}

                if self.data_provider:
                    try:
                        # Parse timestamps
                        entry_time = datetime.fromisoformat(data['entry']['timestamp'].replace('Z', '+00:00'))
                        exit_time = datetime.fromisoformat(data['exit']['timestamp'].replace('Z', '+00:00'))

                        # Use the new data provider method to get market state at entry time
                        entry_market_state = self.data_provider.get_market_state_at_time(
                            symbol=data['symbol'],
                            timestamp=entry_time,
                            context_window_minutes=5
                        )

                        # Use the new data provider method to get market state at exit time
                        exit_market_state = self.data_provider.get_market_state_at_time(
                            symbol=data['symbol'],
                            timestamp=exit_time,
                            context_window_minutes=5
                        )

                        logger.info(f"Captured market state: {len(entry_market_state)} timeframes at entry, {len(exit_market_state)} at exit")

                    except Exception as e:
                        logger.error(f"Error capturing market state: {e}")
                        import traceback
                        traceback.print_exc()

                # Create annotation with market context
                annotation = self.annotation_manager.create_annotation(
                    entry_point=data['entry'],
                    exit_point=data['exit'],
                    symbol=data['symbol'],
                    timeframe=data['timeframe'],
                    entry_market_state=entry_market_state,
                    exit_market_state=exit_market_state
                )

                # Collect market snapshots for SQLite storage
                market_snapshots = {}
                if self.data_provider:
                    try:
                        # Get OHLCV data for all timeframes around the annotation time
                        entry_time = datetime.fromisoformat(data['entry']['timestamp'].replace('Z', '+00:00'))
                        exit_time = datetime.fromisoformat(data['exit']['timestamp'].replace('Z', '+00:00'))

                        # Get data from 5 minutes before entry to 5 minutes after exit
                        start_time = entry_time - timedelta(minutes=5)
                        end_time = exit_time + timedelta(minutes=5)

                        for timeframe in ['1s', '1m', '1h', '1d']:
                            df = self.data_provider.get_data_for_annotation(
                                symbol=data['symbol'],
                                timeframe=timeframe,
                                start_time=start_time,
                                end_time=end_time,
                                limit=1500
                            )
                            if df is not None and not df.empty:
                                market_snapshots[timeframe] = df

                        logger.info(f"Collected {len(market_snapshots)} timeframes for annotation storage")
                    except Exception as e:
                        logger.error(f"Error collecting market snapshots: {e}")

                # Save annotation with market snapshots
                self.annotation_manager.save_annotation(
                    annotation=annotation,
                    market_snapshots=market_snapshots
                )

                # Automatically generate a test case with ±5min of data
                try:
                    test_case = self.annotation_manager.generate_test_case(
                        annotation,
                        data_provider=self.data_provider,
                        auto_save=True
                    )

                    # Log test case details
                    market_state = test_case.get('market_state', {})
                    timeframes_with_data = [k for k in market_state.keys() if k.startswith('ohlcv_')]
                    logger.info(f"Auto-generated test case: {test_case['test_case_id']}")
                    logger.info(f"  Timeframes: {timeframes_with_data}")
                    for tf_key in timeframes_with_data:
                        candle_count = len(market_state[tf_key].get('timestamps', []))
                        logger.info(f"  {tf_key}: {candle_count} candles")

                    if 'training_labels' in market_state:
                        logger.info(f"  Training labels: {len(market_state['training_labels'].get('labels_1m', []))} labels")

                except Exception as e:
                    logger.error(f"Failed to auto-generate test case: {e}")
                    import traceback
                    traceback.print_exc()

                return jsonify({
                    'success': True,
                    'annotation': annotation.__dict__ if hasattr(annotation, '__dict__') else annotation
                })

            except Exception as e:
                logger.error(f"Error saving annotation: {e}")
                return jsonify({
                    'success': False,
                    'error': {
                        'code': 'SAVE_ANNOTATION_ERROR',
                        'message': str(e)
                    }
                })
@self.server.route('/api/delete-annotation', methods=['POST'])
|
|
def delete_annotation():
|
|
"""Delete an annotation"""
|
|
try:
|
|
data = request.get_json()
|
|
annotation_id = data['annotation_id']
|
|
|
|
# Delete annotation and check if it was found
|
|
deleted = self.annotation_manager.delete_annotation(annotation_id)
|
|
|
|
if deleted:
|
|
return jsonify({
|
|
'success': True,
|
|
'message': 'Annotation deleted successfully'
|
|
})
|
|
else:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'ANNOTATION_NOT_FOUND',
|
|
'message': f'Annotation {annotation_id} not found'
|
|
}
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error deleting annotation: {e}", exc_info=True)
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'DELETE_ANNOTATION_ERROR',
|
|
'message': str(e)
|
|
}
|
|
})
|
|
|
|
@self.server.route('/api/clear-all-annotations', methods=['POST'])
|
|
def clear_all_annotations():
|
|
"""Clear all annotations"""
|
|
try:
|
|
data = request.get_json() or {}
|
|
symbol = data.get('symbol', None)
|
|
|
|
# Use the efficient clear_all_annotations method
|
|
deleted_count = self.annotation_manager.clear_all_annotations(symbol=symbol)
|
|
|
|
if deleted_count == 0:
|
|
return jsonify({
|
|
'success': True,
|
|
'deleted_count': 0,
|
|
'message': 'No annotations to clear'
|
|
})
|
|
|
|
logger.info(f"Cleared {deleted_count} annotations" + (f" for symbol {symbol}" if symbol else ""))
|
|
|
|
return jsonify({
|
|
'success': True,
|
|
'deleted_count': deleted_count,
|
|
'message': f'Cleared {deleted_count} annotations'
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error clearing all annotations: {e}")
|
|
import traceback
|
|
logger.error(traceback.format_exc())
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'CLEAR_ALL_ANNOTATIONS_ERROR',
|
|
'message': str(e)
|
|
}
|
|
})
|
|
|
|
@self.server.route('/api/refresh-data', methods=['POST'])
|
|
def refresh_data():
|
|
"""Refresh chart data from data provider"""
|
|
try:
|
|
data = request.get_json()
|
|
symbol = data.get('symbol', 'ETH/USDT')
|
|
timeframes = data.get('timeframes', ['1s', '1m', '1h', '1d'])
|
|
|
|
logger.info(f"Refreshing data for {symbol} with timeframes: {timeframes}")
|
|
|
|
# Force refresh data from data provider
|
|
chart_data = {}
|
|
|
|
if self.data_provider:
|
|
for timeframe in timeframes:
|
|
try:
|
|
# Force refresh by setting refresh=True
|
|
df = self.data_provider.get_historical_data(
|
|
symbol=symbol,
|
|
timeframe=timeframe,
|
|
limit=1000,
|
|
refresh=True
|
|
)
|
|
|
|
if df is not None and not df.empty:
|
|
# Get pivot markers for this timeframe
|
|
pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df)
|
|
|
|
# CRITICAL FIX: Format timestamps in ISO format with UTC indicator
|
|
# This ensures frontend parses them as UTC, not local time
|
|
timestamps = []
|
|
for ts in df.index:
|
|
# Ensure timestamp is UTC
|
|
if hasattr(ts, 'tz'):
|
|
if ts.tz is not None:
|
|
ts_utc = ts.tz_convert('UTC') if hasattr(ts, 'tz_convert') else ts
|
|
else:
|
|
try:
|
|
ts_utc = ts.tz_localize('UTC') if hasattr(ts, 'tz_localize') else ts
|
|
except Exception:
|
|
ts_utc = ts
|
|
else:
|
|
ts_utc = ts
|
|
|
|
# Format as ISO with 'Z' for UTC: 'YYYY-MM-DDTHH:MM:SSZ'
|
|
# Plotly handles ISO format correctly
|
|
if hasattr(ts_utc, 'strftime'):
|
|
timestamps.append(ts_utc.strftime('%Y-%m-%dT%H:%M:%SZ'))
|
|
else:
|
|
timestamps.append(str(ts_utc))
|
|
|
|
chart_data[timeframe] = {
|
|
'timestamps': timestamps,
|
|
'open': df['open'].tolist(),
|
|
'high': df['high'].tolist(),
|
|
'low': df['low'].tolist(),
|
|
'close': df['close'].tolist(),
|
|
'volume': df['volume'].tolist(),
|
|
'pivot_markers': pivot_markers # Optional: only present if pivots exist
|
|
}
|
|
logger.info(f"Refreshed {timeframe}: {len(df)} candles")
|
|
else:
|
|
logger.warning(f"No data available for {timeframe}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error refreshing {timeframe} data: {e}")
|
|
|
|
# Get pivot bounds for the symbol
|
|
pivot_bounds = None
|
|
if self.data_provider:
|
|
try:
|
|
pivot_bounds = self.data_provider.get_pivot_bounds(symbol)
|
|
if pivot_bounds:
|
|
logger.info(f"Found pivot bounds for {symbol}: {len(pivot_bounds.pivot_support_levels)} support, {len(pivot_bounds.pivot_resistance_levels)} resistance")
|
|
except Exception as e:
|
|
logger.error(f"Error getting pivot bounds: {e}")
|
|
|
|
return jsonify({
|
|
'success': True,
|
|
'chart_data': chart_data,
|
|
'pivot_bounds': {
|
|
'support_levels': pivot_bounds.pivot_support_levels if pivot_bounds else [],
|
|
'resistance_levels': pivot_bounds.pivot_resistance_levels if pivot_bounds else [],
|
|
'price_range': {
|
|
'min': pivot_bounds.price_min if pivot_bounds else None,
|
|
'max': pivot_bounds.price_max if pivot_bounds else None
|
|
},
|
|
'volume_range': {
|
|
'min': pivot_bounds.volume_min if pivot_bounds else None,
|
|
'max': pivot_bounds.volume_max if pivot_bounds else None
|
|
},
|
|
'timeframe': '1m', # Pivot bounds are calculated from 1m data
|
|
'period': '30 days', # Monthly data
|
|
'total_levels': len(pivot_bounds.pivot_support_levels) + len(pivot_bounds.pivot_resistance_levels) if pivot_bounds else 0
|
|
} if pivot_bounds else None,
|
|
'message': f'Refreshed data for {symbol}'
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error refreshing data: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'REFRESH_DATA_ERROR',
|
|
'message': str(e)
|
|
}
|
|
})
|
|
|
|
@self.server.route('/api/generate-test-case', methods=['POST'])
|
|
def generate_test_case():
|
|
"""Generate test case from annotation"""
|
|
try:
|
|
data = request.get_json()
|
|
annotation_id = data['annotation_id']
|
|
|
|
# Get annotation
|
|
annotations = self.annotation_manager.get_annotations()
|
|
annotation = next((a for a in annotations
|
|
if (a.annotation_id if hasattr(a, 'annotation_id')
|
|
else a.get('annotation_id')) == annotation_id), None)
|
|
|
|
if not annotation:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'ANNOTATION_NOT_FOUND',
|
|
'message': 'Annotation not found'
|
|
}
|
|
})
|
|
|
|
# Generate test case with market context
|
|
test_case = self.annotation_manager.generate_test_case(
|
|
annotation,
|
|
data_provider=self.data_provider
|
|
)
|
|
|
|
return jsonify({
|
|
'success': True,
|
|
'test_case': test_case
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error generating test case: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'GENERATE_TESTCASE_ERROR',
|
|
'message': str(e)
|
|
}
|
|
})
|
|
|
|
@self.server.route('/api/get-annotations', methods=['POST'])
|
|
def get_annotations_api():
|
|
"""Get annotations filtered by symbol"""
|
|
try:
|
|
data = request.get_json()
|
|
symbol = data.get('symbol', 'ETH/USDT')
|
|
|
|
# Get annotations for this symbol
|
|
annotations = self.annotation_manager.get_annotations(symbol=symbol)
|
|
|
|
# Convert to serializable format
|
|
annotations_data = []
|
|
for ann in annotations:
|
|
if hasattr(ann, '__dict__'):
|
|
ann_dict = ann.__dict__
|
|
else:
|
|
ann_dict = ann
|
|
|
|
annotations_data.append({
|
|
'annotation_id': ann_dict.get('annotation_id'),
|
|
'symbol': ann_dict.get('symbol'),
|
|
'timeframe': ann_dict.get('timeframe'),
|
|
'entry': ann_dict.get('entry'),
|
|
'exit': ann_dict.get('exit'),
|
|
'direction': ann_dict.get('direction'),
|
|
'profit_loss_pct': ann_dict.get('profit_loss_pct'),
|
|
'notes': ann_dict.get('notes', ''),
|
|
'created_at': ann_dict.get('created_at')
|
|
})
|
|
|
|
logger.info(f"Returning {len(annotations_data)} annotations for {symbol}")
|
|
|
|
return jsonify({
|
|
'success': True,
|
|
'annotations': annotations_data,
|
|
'symbol': symbol,
|
|
'count': len(annotations_data)
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error getting annotations: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': str(e)
|
|
})
|
|
|
|
@self.server.route('/api/export-annotations', methods=['POST'])
|
|
def export_annotations():
|
|
"""Export annotations to file"""
|
|
try:
|
|
data = request.get_json()
|
|
symbol = data.get('symbol')
|
|
format_type = data.get('format', 'json')
|
|
|
|
# Get annotations
|
|
annotations = self.annotation_manager.get_annotations(symbol=symbol)
|
|
|
|
# Export to file
|
|
output_path = self.annotation_manager.export_annotations(
|
|
annotations=annotations,
|
|
format_type=format_type
|
|
)
|
|
|
|
return send_file(output_path, as_attachment=True)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error exporting annotations: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'EXPORT_ERROR',
|
|
'message': str(e)
|
|
}
|
|
})
|
|
|
|
@self.server.route('/api/train-model', methods=['POST'])
|
|
def train_model():
|
|
"""Start model training with annotations"""
|
|
try:
|
|
if not self.training_adapter:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'TRAINING_UNAVAILABLE',
|
|
'message': 'Real training adapter not available'
|
|
}
|
|
})
|
|
|
|
data = request.get_json()
|
|
model_name = data['model_name']
|
|
annotation_ids = data.get('annotation_ids', [])
|
|
|
|
# CRITICAL: Get current symbol to filter annotations
|
|
current_symbol = data.get('symbol', 'ETH/USDT')
|
|
|
|
# Get primary timeframe for display (optional)
|
|
timeframe = data.get('timeframe', '1m')
|
|
|
|
# If no specific annotations provided, use all for current symbol
|
|
if not annotation_ids:
|
|
annotations = self.annotation_manager.get_annotations(symbol=current_symbol)
|
|
annotation_ids = [
|
|
a.annotation_id if hasattr(a, 'annotation_id') else a.get('annotation_id')
|
|
for a in annotations
|
|
]
|
|
logger.info(f"Using all {len(annotation_ids)} annotations for {current_symbol}")
|
|
|
|
# Load test cases from disk (they were auto-generated when annotations were saved)
|
|
# Filter by current symbol to avoid cross-symbol training
|
|
all_test_cases = self.annotation_manager.get_all_test_cases(symbol=current_symbol)
|
|
|
|
# Filter to selected annotations
|
|
test_cases = [
|
|
tc for tc in all_test_cases
|
|
if tc['test_case_id'].replace('annotation_', '') in annotation_ids
|
|
]
|
|
|
|
if not test_cases:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'NO_TEST_CASES',
|
|
'message': f'No test cases found for {len(annotation_ids)} annotations'
|
|
}
|
|
})
|
|
|
|
logger.info(f"Starting REAL training with {len(test_cases)} test cases ({len(annotation_ids)} annotations) for model {model_name} on {timeframe}")
|
|
|
|
# Start REAL training (NO SIMULATION!)
|
|
training_id = self.training_adapter.start_training(
|
|
model_name=model_name,
|
|
test_cases=test_cases,
|
|
annotation_count=len(annotation_ids),
|
|
timeframe=timeframe
|
|
)
|
|
|
|
return jsonify({
|
|
'success': True,
|
|
'training_id': training_id,
|
|
'test_cases_count': len(test_cases)
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error starting training: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'TRAINING_ERROR',
|
|
'message': str(e)
|
|
}
|
|
})
|
|
|
|
@self.server.route('/api/training-progress', methods=['POST'])
|
|
def get_training_progress():
|
|
"""Get training progress"""
|
|
try:
|
|
if not self.training_adapter:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'TRAINING_UNAVAILABLE',
|
|
'message': 'Real training adapter not available'
|
|
}
|
|
})
|
|
|
|
data = request.get_json()
|
|
training_id = data['training_id']
|
|
|
|
progress = self.training_adapter.get_training_progress(training_id)
|
|
|
|
return jsonify({
|
|
'success': True,
|
|
'progress': progress
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error getting training progress: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'PROGRESS_ERROR',
|
|
'message': str(e)
|
|
}
|
|
})
|
|
|
|
# Backtest API Endpoints
|
|
@self.server.route('/api/backtest', methods=['POST'])
|
|
def start_backtest():
|
|
"""Start backtest on visible chart data"""
|
|
try:
|
|
data = request.get_json()
|
|
model_name = data['model_name']
|
|
symbol = data['symbol']
|
|
timeframe = data['timeframe']
|
|
start_time = data.get('start_time')
|
|
end_time = data.get('end_time')
|
|
|
|
# Get the loaded model
|
|
if model_name not in self.loaded_models:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': f'Model {model_name} not loaded. Please load it first.'
|
|
})
|
|
|
|
model = self.loaded_models[model_name]
|
|
|
|
# Generate backtest ID
|
|
backtest_id = str(uuid.uuid4())
|
|
|
|
# Start backtest in background
|
|
self.backtest_runner.start_backtest(
|
|
backtest_id=backtest_id,
|
|
model=model,
|
|
data_provider=self.data_provider,
|
|
symbol=symbol,
|
|
timeframe=timeframe,
|
|
orchestrator=self.orchestrator,
|
|
start_time=start_time,
|
|
end_time=end_time
|
|
)
|
|
|
|
# Get initial state
|
|
progress = self.backtest_runner.get_progress(backtest_id)
|
|
|
|
return jsonify({
|
|
'success': True,
|
|
'backtest_id': backtest_id,
|
|
'total_candles': progress.get('total_candles', 0)
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error starting backtest: {e}", exc_info=True)
|
|
return jsonify({
|
|
'success': False,
|
|
'error': str(e)
|
|
})
|
|
|
|
@self.server.route('/api/backtest/progress/<backtest_id>', methods=['GET'])
|
|
def get_backtest_progress(backtest_id):
|
|
"""Get backtest progress"""
|
|
try:
|
|
progress = self.backtest_runner.get_progress(backtest_id)
|
|
return jsonify(progress)
|
|
except Exception as e:
|
|
logger.error(f"Error getting backtest progress: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': str(e)
|
|
})
|
|
|
|
@self.server.route('/api/backtest/stop', methods=['POST'])
|
|
def stop_backtest():
|
|
"""Stop running backtest"""
|
|
try:
|
|
data = request.get_json()
|
|
backtest_id = data['backtest_id']
|
|
|
|
self.backtest_runner.stop_backtest(backtest_id)
|
|
|
|
return jsonify({
|
|
'success': True
|
|
})
|
|
except Exception as e:
|
|
logger.error(f"Error stopping backtest: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': str(e)
|
|
})
|
|
|
|
@self.server.route('/api/active-training', methods=['GET'])
|
|
def get_active_training():
|
|
"""
|
|
Get currently active training session (if any)
|
|
Allows UI to resume tracking after page reload or across multiple clients
|
|
"""
|
|
try:
|
|
if not self.training_adapter:
|
|
return jsonify({
|
|
'success': False,
|
|
'active': False,
|
|
'error': {
|
|
'code': 'TRAINING_UNAVAILABLE',
|
|
'message': 'Real training adapter not available'
|
|
}
|
|
})
|
|
|
|
active_session = self.training_adapter.get_active_training_session()
|
|
|
|
if active_session:
|
|
return jsonify({
|
|
'success': True,
|
|
'active': True,
|
|
'session': active_session
|
|
})
|
|
else:
|
|
return jsonify({
|
|
'success': True,
|
|
'active': False
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error getting active training: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'active': False,
|
|
'error': {
|
|
'code': 'ACTIVE_TRAINING_ERROR',
|
|
'message': str(e)
|
|
}
|
|
})
|
|
|
|
# Live Training API Endpoints
|
|
@self.server.route('/api/live-training/start', methods=['POST'])
|
|
def start_live_training():
|
|
"""Start live inference and training mode"""
|
|
try:
|
|
if not self.orchestrator:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': 'Orchestrator not available'
|
|
}), 500
|
|
|
|
if self.orchestrator.start_live_training():
|
|
return jsonify({
|
|
'success': True,
|
|
'status': 'started',
|
|
'message': 'Live training mode started'
|
|
})
|
|
else:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': 'Failed to start live training'
|
|
}), 500
|
|
except Exception as e:
|
|
logger.error(f"Error starting live training: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': str(e)
|
|
}), 500
|
|
|
|
@self.server.route('/api/live-training/stop', methods=['POST'])
|
|
def stop_live_training():
|
|
"""Stop live inference and training mode"""
|
|
try:
|
|
if not self.orchestrator:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': 'Orchestrator not available'
|
|
}), 500
|
|
|
|
if self.orchestrator.stop_live_training():
|
|
return jsonify({
|
|
'success': True,
|
|
'status': 'stopped',
|
|
'message': 'Live training mode stopped'
|
|
})
|
|
else:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': 'Failed to stop live training'
|
|
}), 500
|
|
except Exception as e:
|
|
logger.error(f"Error stopping live training: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': str(e)
|
|
}), 500
|
|
|
|
@self.server.route('/api/live-training/status', methods=['GET'])
|
|
def get_live_training_status():
|
|
"""Get live training status and statistics"""
|
|
try:
|
|
if not self.orchestrator:
|
|
return jsonify({
|
|
'success': False,
|
|
'active': False,
|
|
'error': 'Orchestrator not available'
|
|
})
|
|
|
|
is_active = self.orchestrator.is_live_training_active()
|
|
stats = self.orchestrator.get_live_training_stats() if is_active else {}
|
|
|
|
return jsonify({
|
|
'success': True,
|
|
'active': is_active,
|
|
'stats': stats
|
|
})
|
|
except Exception as e:
|
|
logger.error(f"Error getting live training status: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'active': False,
|
|
'error': str(e)
|
|
})
|
|
|
|
@self.server.route('/api/available-models', methods=['GET'])
|
|
def get_available_models():
|
|
"""Get list of available models with their load status"""
|
|
try:
|
|
# Ensure self.available_models is a list
|
|
if not isinstance(self.available_models, list):
|
|
logger.warning(f"self.available_models is not a list: {type(self.available_models)}. Resetting to default.")
|
|
self.available_models = ['Transformer', 'COB_RL', 'CNN', 'DQN']
|
|
|
|
# Ensure self.loaded_models exists (it's a dict)
|
|
if not hasattr(self, 'loaded_models'):
|
|
self.loaded_models = {}
|
|
|
|
# Build model state dict with checkpoint info
|
|
logger.info(f"Building model states for {len(self.available_models)} models: {self.available_models}")
|
|
logger.info(f"Currently loaded models: {list(self.loaded_models.keys())}")
|
|
model_states = []
|
|
for model_name in self.available_models:
|
|
# Check if model is in loaded_models dict
|
|
is_loaded = model_name in self.loaded_models and self.loaded_models[model_name] is not None
|
|
|
|
# Get checkpoint info (even for unloaded models)
|
|
checkpoint_info = None
|
|
|
|
# If loaded, get from orchestrator
|
|
if is_loaded and self.orchestrator:
|
|
checkpoint_attr = f"{model_name.lower()}_checkpoint_info"
|
|
|
|
if hasattr(self.orchestrator, checkpoint_attr):
|
|
cp_info = getattr(self.orchestrator, checkpoint_attr)
|
|
if cp_info and cp_info.get('status') == 'loaded':
|
|
checkpoint_info = {
|
|
'filename': cp_info.get('filename', 'unknown'),
|
|
'epoch': cp_info.get('epoch', 0),
|
|
'loss': cp_info.get('loss', 0.0),
|
|
'accuracy': cp_info.get('accuracy', 0.0),
|
|
'loaded_at': cp_info.get('loaded_at', ''),
|
|
'source': 'loaded'
|
|
}
|
|
|
|
# If not loaded, try to read best checkpoint from disk (filename parsing only)
|
|
if not checkpoint_info:
|
|
try:
|
|
cp_info = self._get_best_checkpoint_info(model_name)
|
|
if cp_info:
|
|
checkpoint_info = cp_info
|
|
checkpoint_info['source'] = 'disk'
|
|
except Exception as e:
|
|
logger.warning(f"Could not read checkpoint for {model_name}: {e}")
|
|
# Continue without checkpoint info - not critical
|
|
|
|
model_states.append({
|
|
'name': model_name,
|
|
'loaded': is_loaded,
|
|
'can_train': is_loaded,
|
|
'can_infer': is_loaded,
|
|
'checkpoint': checkpoint_info # Checkpoint metadata (loaded or from disk)
|
|
})
|
|
|
|
logger.info(f"Returning {len(model_states)} model states")
|
|
return jsonify({
|
|
'success': True,
|
|
'models': model_states,
|
|
'loaded_count': len(self.loaded_models),
|
|
'available_count': len(self.available_models)
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error getting available models: {e}")
|
|
import traceback
|
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
|
# Return a fallback list so the UI doesn't hang
|
|
return jsonify({
|
|
'success': True,
|
|
'models': [
|
|
{'name': 'Transformer', 'loaded': False, 'can_train': False, 'can_infer': False, 'checkpoint': None},
|
|
{'name': 'COB_RL', 'loaded': False, 'can_train': False, 'can_infer': False, 'checkpoint': None}
|
|
],
|
|
'loaded_count': 0,
|
|
'available_count': 2,
|
|
'error': str(e)
|
|
})
|
|
|
|
@self.server.route('/api/load-model', methods=['POST'])
|
|
def load_model():
|
|
"""Load a specific model on demand"""
|
|
try:
|
|
data = request.get_json()
|
|
model_name = data.get('model_name')
|
|
|
|
if not model_name:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': 'model_name is required'
|
|
})
|
|
|
|
# Load the model
|
|
result = self._load_model_lazy(model_name)
|
|
|
|
return jsonify(result)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in load_model endpoint: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': str(e)
|
|
})
|
|
|
|
@self.server.route('/api/realtime-inference/start', methods=['POST'])
|
|
def start_realtime_inference():
|
|
"""Start real-time inference mode with configurable training strategy"""
|
|
try:
|
|
data = request.get_json()
|
|
model_name = data.get('model_name')
|
|
symbol = data.get('symbol', 'ETH/USDT')
|
|
timeframe = data.get('timeframe', '1m')
|
|
|
|
# New unified training_mode parameter
|
|
training_mode = data.get('training_mode', 'none') # 'none', 'every_candle', 'pivots_only', 'manual'
|
|
|
|
# Backward compatibility with old parameters
|
|
if 'enable_live_training' in data or 'train_every_candle' in data:
|
|
enable_live_training = data.get('enable_live_training', False)
|
|
train_every_candle = data.get('train_every_candle', False)
|
|
training_mode = 'every_candle' if train_every_candle else ('pivots_only' if enable_live_training else 'none')
|
|
|
|
if not self.training_adapter:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'TRAINING_UNAVAILABLE',
|
|
'message': 'Real training adapter not available'
|
|
}
|
|
})
|
|
|
|
# Set training mode on strategy manager
|
|
self.training_strategy.mode = training_mode
|
|
logger.info(f"Training strategy mode set to: {training_mode}")
|
|
|
|
# Start real-time inference - pass strategy manager for training decisions
|
|
inference_id = self.training_adapter.start_realtime_inference(
|
|
model_name=model_name,
|
|
symbol=symbol,
|
|
data_provider=self.data_provider,
|
|
enable_live_training=(training_mode != 'none'),
|
|
train_every_candle=(training_mode == 'every_candle'),
|
|
timeframe=timeframe,
|
|
training_strategy=self.training_strategy # Pass strategy manager
|
|
)
|
|
|
|
return jsonify({
|
|
'success': True,
|
|
'inference_id': inference_id,
|
|
'training_mode': training_mode,
|
|
'timeframe': timeframe
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error starting real-time inference: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'INFERENCE_START_ERROR',
|
|
'message': str(e)
|
|
}
|
|
})
|
|
|
|
@self.server.route('/api/realtime-inference/stop', methods=['POST'])
|
|
def stop_realtime_inference():
|
|
"""Stop real-time inference mode"""
|
|
try:
|
|
data = request.get_json()
|
|
inference_id = data.get('inference_id')
|
|
|
|
if not self.training_adapter:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'TRAINING_UNAVAILABLE',
|
|
'message': 'Real training adapter not available'
|
|
}
|
|
})
|
|
|
|
self.training_adapter.stop_realtime_inference(inference_id)
|
|
|
|
return jsonify({
|
|
'success': True
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error stopping real-time inference: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'INFERENCE_STOP_ERROR',
|
|
'message': str(e)
|
|
}
|
|
})
|
|
|
|
@self.server.route('/api/live-updates', methods=['GET', 'POST'])
|
|
def get_live_updates():
|
|
"""Get live chart and prediction updates (polling endpoint)"""
|
|
try:
|
|
# Support both GET and POST
|
|
if request.method == 'POST':
|
|
data = request.get_json() or {}
|
|
else:
|
|
data = {}
|
|
symbol = data.get('symbol', request.args.get('symbol', 'ETH/USDT'))
|
|
timeframe = data.get('timeframe', request.args.get('timeframe', '1m'))
|
|
|
|
response = {
|
|
'success': True,
|
|
'chart_update': None,
|
|
'prediction': None
|
|
}
|
|
|
|
# Get latest candle for the requested timeframe using data_provider
|
|
if self.data_provider:
|
|
try:
|
|
# Get latest candle from data_provider (includes real-time data)
|
|
df = self.data_provider.get_data_for_annotation(symbol, timeframe, limit=2, direction='latest')
|
|
if df is not None and not df.empty:
|
|
latest_candle = df.iloc[-1]
|
|
|
|
# Format timestamp as ISO string (ensure UTC format for frontend)
|
|
timestamp = latest_candle.name
|
|
if hasattr(timestamp, 'isoformat'):
|
|
# If timezone-aware, convert to UTC ISO string
|
|
if timestamp.tzinfo is not None:
|
|
timestamp_str = timestamp.astimezone(timezone.utc).isoformat()
|
|
else:
|
|
# Assume UTC if no timezone info
|
|
timestamp_str = timestamp.isoformat() + 'Z'
|
|
else:
|
|
timestamp_str = str(timestamp)
|
|
|
|
# Determine if candle is confirmed (we have 2 candles, so previous is confirmed)
|
|
is_confirmed = len(df) >= 2
|
|
|
|
response['chart_update'] = {
|
|
'symbol': symbol,
|
|
'timeframe': timeframe,
|
|
'candle': {
|
|
'timestamp': timestamp_str,
|
|
'open': float(latest_candle['open']),
|
|
'high': float(latest_candle['high']),
|
|
'low': float(latest_candle['low']),
|
|
'close': float(latest_candle['close']),
|
|
'volume': float(latest_candle['volume'])
|
|
},
|
|
'is_confirmed': is_confirmed
|
|
}
|
|
except Exception as e:
|
|
logger.debug(f"Error getting latest candle from data_provider: {e}", exc_info=True)
|
|
else:
|
|
logger.debug("Data provider not available for live updates")
|
|
|
|
# Get latest model predictions
|
|
if self.orchestrator:
|
|
try:
|
|
# Get latest predictions from orchestrator
|
|
predictions = {}
|
|
|
|
# DQN predictions
|
|
if hasattr(self.orchestrator, 'recent_dqn_predictions') and symbol in self.orchestrator.recent_dqn_predictions:
|
|
dqn_preds = list(self.orchestrator.recent_dqn_predictions[symbol])
|
|
if dqn_preds:
|
|
predictions['dqn'] = dqn_preds[-1]
|
|
|
|
# CNN predictions
|
|
if hasattr(self.orchestrator, 'recent_cnn_predictions') and symbol in self.orchestrator.recent_cnn_predictions:
|
|
cnn_preds = list(self.orchestrator.recent_cnn_predictions[symbol])
|
|
if cnn_preds:
|
|
predictions['cnn'] = cnn_preds[-1]
|
|
|
|
# Transformer predictions with next_candles for ghost candles
|
|
# First check if there are stored predictions from the inference loop
|
|
if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions:
|
|
transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol])
|
|
if transformer_preds:
|
|
# Convert any remaining tensors to Python types before JSON serialization
|
|
transformer_pred = transformer_preds[-1].copy()
|
|
|
|
# CRITICAL: Log prediction structure to debug missing predicted_candle
|
|
logger.debug(f"Transformer prediction keys: {list(transformer_pred.keys())}")
|
|
if 'predicted_candle' in transformer_pred:
|
|
logger.debug(f"predicted_candle timeframes: {list(transformer_pred['predicted_candle'].keys()) if isinstance(transformer_pred['predicted_candle'], dict) else 'not a dict'}")
|
|
|
|
predictions['transformer'] = self._serialize_prediction(transformer_pred)
|
|
|
|
# Verify predicted_candle is preserved after serialization
|
|
if 'predicted_candle' not in predictions['transformer'] and 'predicted_candle' in transformer_pred:
|
|
logger.warning("predicted_candle was lost during serialization!")
|
|
|
|
if predictions:
|
|
response['prediction'] = predictions
|
|
|
|
except Exception as e:
|
|
logger.debug(f"Error getting predictions: {e}")
|
|
import traceback
|
|
logger.debug(traceback.format_exc())
|
|
|
|
return jsonify(response)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in live updates: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': str(e)
|
|
})
|
|
|
|
@self.server.route('/api/live-updates-batch', methods=['POST'])
|
|
def get_live_updates_batch():
|
|
"""Get live chart and prediction updates for multiple timeframes (optimized batch endpoint)"""
|
|
try:
|
|
data = request.get_json() or {}
|
|
symbol = data.get('symbol', 'ETH/USDT')
|
|
timeframes = data.get('timeframes', ['1m'])
|
|
|
|
response = {
|
|
'success': True,
|
|
'server_time': datetime.now(timezone.utc).isoformat(), # Add server timestamp to detect stale data
|
|
'chart_updates': {}, # Dict of timeframe -> chart_update
|
|
'prediction': None # Single prediction for all timeframes
|
|
}
|
|
|
|
# Get latest candle for each requested timeframe
|
|
if self.data_provider:
|
|
for timeframe in timeframes:
|
|
try:
|
|
df = self.data_provider.get_data_for_annotation(symbol, timeframe, limit=2, direction='latest')
|
|
if df is not None and not df.empty:
|
|
latest_candle = df.iloc[-1]
|
|
|
|
# Format timestamp as ISO string
|
|
timestamp = latest_candle.name
|
|
if hasattr(timestamp, 'isoformat'):
|
|
if timestamp.tzinfo is not None:
|
|
timestamp_str = timestamp.astimezone(timezone.utc).isoformat()
|
|
else:
|
|
timestamp_str = timestamp.isoformat() + 'Z'
|
|
else:
|
|
timestamp_str = str(timestamp)
|
|
|
|
is_confirmed = len(df) >= 2
|
|
|
|
response['chart_updates'][timeframe] = {
|
|
'symbol': symbol,
|
|
'timeframe': timeframe,
|
|
'candle': {
|
|
'timestamp': timestamp_str,
|
|
'open': float(latest_candle['open']),
|
|
'high': float(latest_candle['high']),
|
|
'low': float(latest_candle['low']),
|
|
'close': float(latest_candle['close']),
|
|
'volume': float(latest_candle['volume'])
|
|
},
|
|
'is_confirmed': is_confirmed
|
|
}
|
|
except Exception as e:
|
|
logger.debug(f"Error getting candle for {timeframe}: {e}")
|
|
|
|
# Get latest model predictions (same for all timeframes)
|
|
if self.orchestrator:
|
|
try:
|
|
predictions = {}
|
|
|
|
# DQN predictions
|
|
if hasattr(self.orchestrator, 'recent_dqn_predictions') and symbol in self.orchestrator.recent_dqn_predictions:
|
|
dqn_preds = list(self.orchestrator.recent_dqn_predictions[symbol])
|
|
if dqn_preds:
|
|
predictions['dqn'] = dqn_preds[-1]
|
|
|
|
# CNN predictions
|
|
if hasattr(self.orchestrator, 'recent_cnn_predictions') and symbol in self.orchestrator.recent_cnn_predictions:
|
|
cnn_preds = list(self.orchestrator.recent_cnn_predictions[symbol])
|
|
if cnn_preds:
|
|
predictions['cnn'] = cnn_preds[-1]
|
|
|
|
# Transformer predictions
|
|
if hasattr(self.orchestrator, 'recent_transformer_predictions') and symbol in self.orchestrator.recent_transformer_predictions:
|
|
transformer_preds = list(self.orchestrator.recent_transformer_predictions[symbol])
|
|
if transformer_preds:
|
|
transformer_pred = transformer_preds[-1].copy()
|
|
predictions['transformer'] = self._serialize_prediction(transformer_pred)
|
|
|
|
if predictions:
|
|
response['prediction'] = predictions
|
|
|
|
except Exception as e:
|
|
logger.debug(f"Error getting predictions: {e}")
|
|
|
|
return jsonify(response)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in batch live updates: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': str(e)
|
|
})
|
|
|
|
@self.server.route('/api/realtime-inference/signals', methods=['GET'])
|
|
def get_realtime_signals():
|
|
"""Get latest real-time inference signals"""
|
|
try:
|
|
if not self.training_adapter:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'TRAINING_UNAVAILABLE',
|
|
'message': 'Real training adapter not available'
|
|
}
|
|
})
|
|
|
|
signals = self.training_adapter.get_latest_signals()
|
|
|
|
# Get metrics from active inference sessions or orchestrator
|
|
metrics = {'accuracy': 0.0, 'loss': 0.0}
|
|
|
|
# Try to get metrics from orchestrator first (most recent)
|
|
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer_trainer'):
|
|
trainer = self.orchestrator.primary_transformer_trainer
|
|
if trainer and hasattr(trainer, 'training_history'):
|
|
history = trainer.training_history
|
|
if history.get('train_accuracy'):
|
|
metrics['accuracy'] = history['train_accuracy'][-1] if history['train_accuracy'] else 0.0
|
|
if history.get('train_loss'):
|
|
metrics['loss'] = history['train_loss'][-1] if history['train_loss'] else 0.0
|
|
|
|
# Fallback to inference session metrics
|
|
if metrics['accuracy'] == 0.0 and metrics['loss'] == 0.0:
|
|
if hasattr(self.training_adapter, 'inference_sessions'):
|
|
for session in self.training_adapter.inference_sessions.values():
|
|
if 'metrics' in session and session['metrics']:
|
|
metrics = session['metrics'].copy()
|
|
break
|
|
|
|
# CRITICAL FIX: Include position state and session metrics for UI state restoration
|
|
position_state = None
|
|
session_metrics = None
|
|
|
|
# Get position state and session metrics from orchestrator if available
|
|
if self.orchestrator and hasattr(self.orchestrator, 'get_position_state'):
|
|
try:
|
|
position_state = self.orchestrator.get_position_state()
|
|
except Exception:
|
|
pass
|
|
|
|
if self.orchestrator and hasattr(self.orchestrator, 'get_session_metrics'):
|
|
try:
|
|
session_metrics = self.orchestrator.get_session_metrics()
|
|
except Exception:
|
|
pass
|
|
|
|
# Add position state and session metrics to metrics dict
|
|
if position_state:
|
|
metrics['position_state'] = position_state
|
|
if session_metrics:
|
|
metrics['session_pnl'] = session_metrics.get('total_pnl', 0.0)
|
|
metrics['session_metrics'] = session_metrics
|
|
|
|
return jsonify({
|
|
'success': True,
|
|
'signals': signals,
|
|
'metrics': metrics
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error getting signals: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': {
|
|
'code': 'SIGNALS_ERROR',
|
|
'message': str(e)
|
|
}
|
|
})
|
|
|
|
@self.server.route('/api/train-validated-prediction', methods=['POST'])
|
|
def train_validated_prediction():
|
|
"""Train model on a validated prediction (online learning)"""
|
|
try:
|
|
data = request.get_json()
|
|
|
|
timeframe = data.get('timeframe')
|
|
timestamp = data.get('timestamp')
|
|
predicted = data.get('predicted')
|
|
actual = data.get('actual')
|
|
errors = data.get('errors')
|
|
direction_correct = data.get('direction_correct')
|
|
accuracy = data.get('accuracy')
|
|
|
|
logger.info(f"[ONLINE LEARNING] Received validation for {timeframe}: accuracy={accuracy:.1f}%, direction={'✓' if direction_correct else '✗'}")
|
|
|
|
# Trigger training and get metrics
|
|
metrics = self._train_on_validated_prediction(
|
|
timeframe, timestamp, predicted, actual,
|
|
errors, direction_correct, accuracy
|
|
)
|
|
|
|
return jsonify({
|
|
'success': True,
|
|
'message': 'Training triggered',
|
|
'metrics': metrics or {}
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error training on validated prediction: {e}", exc_info=True)
|
|
return jsonify({
|
|
'success': False,
|
|
'error': str(e)
|
|
}), 500
|
|
|
|
@self.server.route('/api/training-metrics', methods=['GET'])
|
|
def get_training_metrics():
|
|
"""Get current training metrics for display (loss, accuracy, etc.)"""
|
|
try:
|
|
metrics = {
|
|
'loss': 0.0,
|
|
'accuracy': 0.0,
|
|
'steps': 0,
|
|
'recent_history': []
|
|
}
|
|
|
|
# Get metrics from training adapter if available
|
|
if self.training_adapter and hasattr(self.training_adapter, 'realtime_training_metrics'):
|
|
rt_metrics = self.training_adapter.realtime_training_metrics
|
|
metrics['loss'] = rt_metrics.get('last_loss', 0.0)
|
|
metrics['accuracy'] = rt_metrics.get('last_accuracy', 0.0)
|
|
metrics['steps'] = rt_metrics.get('total_steps', 0)
|
|
# Add best checkpoint metrics
|
|
metrics['best_loss'] = rt_metrics.get('best_loss', float('inf'))
|
|
metrics['best_accuracy'] = rt_metrics.get('best_accuracy', 0.0)
|
|
if metrics['best_loss'] == float('inf'):
|
|
metrics['best_loss'] = None
|
|
|
|
# Get incremental training metrics
|
|
if hasattr(self, '_incremental_training_steps'):
|
|
metrics['incremental_steps'] = self._incremental_training_steps
|
|
if hasattr(self, '_training_metrics_history') and self._training_metrics_history:
|
|
# Get last 10 metrics for display
|
|
metrics['recent_history'] = self._training_metrics_history[-10:]
|
|
# Update current metrics from most recent
|
|
latest = self._training_metrics_history[-1]
|
|
metrics['loss'] = latest.get('loss', metrics['loss'])
|
|
metrics['accuracy'] = latest.get('accuracy', metrics['accuracy'])
|
|
|
|
# Get metrics from orchestrator trainer if available
|
|
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer_trainer'):
|
|
trainer = self.orchestrator.primary_transformer_trainer
|
|
if trainer and hasattr(trainer, 'training_history'):
|
|
history = trainer.training_history
|
|
if history.get('train_loss'):
|
|
metrics['loss'] = history['train_loss'][-1] if history['train_loss'] else metrics['loss']
|
|
if history.get('train_accuracy'):
|
|
metrics['accuracy'] = history['train_accuracy'][-1] if history['train_accuracy'] else metrics['accuracy']
|
|
|
|
return jsonify({
|
|
'success': True,
|
|
'metrics': metrics
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error getting training metrics: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': str(e)
|
|
}), 500
|
|
|
|
|
|
|
|
@self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
|
|
def train_manual():
|
|
"""Manually trigger training on current candle with specified action"""
|
|
try:
|
|
data = request.get_json()
|
|
inference_id = data.get('inference_id')
|
|
action = data.get('action', 'HOLD')
|
|
|
|
if not self.training_adapter:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': 'Training adapter not available'
|
|
})
|
|
|
|
# Get active inference session
|
|
if not hasattr(self.training_adapter, 'inference_sessions'):
|
|
return jsonify({
|
|
'success': False,
|
|
'error': 'No active inference sessions'
|
|
})
|
|
|
|
session = self.training_adapter.inference_sessions.get(inference_id)
|
|
if not session:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': 'Inference session not found'
|
|
})
|
|
|
|
# Set pending action for training
|
|
session['pending_action'] = action
|
|
|
|
# Get session parameters
|
|
symbol = session.get('symbol', 'ETH/USDT')
|
|
timeframe = session.get('timeframe', '1m')
|
|
data_provider = session.get('data_provider')
|
|
|
|
# Call training method
|
|
train_result = self.training_adapter._train_on_new_candle(
|
|
session, symbol, timeframe, data_provider
|
|
)
|
|
|
|
if train_result.get('success'):
|
|
return jsonify({
|
|
'success': True,
|
|
'action': action,
|
|
'metrics': {
|
|
'loss': train_result.get('loss', 0.0),
|
|
'accuracy': train_result.get('accuracy', 0.0),
|
|
'training_steps': train_result.get('training_steps', 0)
|
|
}
|
|
})
|
|
else:
|
|
return jsonify({
|
|
'success': False,
|
|
'error': train_result.get('error', 'Training failed')
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in manual training: {e}")
|
|
return jsonify({
|
|
'success': False,
|
|
'error': str(e)
|
|
})
|
|
|
|
# WebSocket removed - using HTTP polling only
|
|
|
|
def _serialize_prediction(self, prediction: Dict) -> Dict:
|
|
"""Convert PyTorch tensors in prediction dict to JSON-serializable Python types"""
|
|
try:
|
|
import torch
|
|
serialized = {}
|
|
for key, value in prediction.items():
|
|
if isinstance(value, torch.Tensor):
|
|
if value.numel() == 1: # Scalar tensor
|
|
serialized[key] = value.item()
|
|
else: # Multi-element tensor
|
|
serialized[key] = value.detach().cpu().tolist()
|
|
elif isinstance(value, dict):
|
|
serialized[key] = self._serialize_prediction(value) # Recursive
|
|
elif isinstance(value, (list, tuple)):
|
|
serialized[key] = [
|
|
v.item() if isinstance(v, torch.Tensor) and v.numel() == 1 else
|
|
(v.detach().cpu().tolist() if isinstance(v, torch.Tensor) else v)
|
|
for v in value
|
|
]
|
|
else:
|
|
serialized[key] = value
|
|
return serialized
|
|
except Exception as e:
|
|
logger.warning(f"Error serializing prediction: {e}")
|
|
# Fallback: return as-is (might fail JSON serialization but won't crash)
|
|
return prediction
|
|
|
|
# WebSocket code removed - using HTTP polling only
|
|
|
|
def _get_live_transformer_prediction(self, symbol: str = 'ETH/USDT'):
|
|
"""
|
|
Generate live transformer prediction with next_candles for ghost candle display
|
|
"""
|
|
try:
|
|
if not self.orchestrator:
|
|
logger.debug("No orchestrator - cannot generate predictions")
|
|
return None
|
|
|
|
if not hasattr(self.orchestrator, 'primary_transformer'):
|
|
logger.debug("Orchestrator has no primary_transformer - enable training first")
|
|
return None
|
|
|
|
transformer = self.orchestrator.primary_transformer
|
|
if not transformer:
|
|
logger.debug("primary_transformer is None - model not loaded yet")
|
|
return None
|
|
|
|
transformer.eval()
|
|
|
|
# Get recent market data
|
|
price_data_1s = self.data_provider.get_ohlcv(symbol, '1s', limit=200) if self.data_provider else None
|
|
price_data_1m = self.data_provider.get_ohlcv(symbol, '1m', limit=150) if self.data_provider else None
|
|
price_data_1h = self.data_provider.get_ohlcv(symbol, '1h', limit=24) if self.data_provider else None
|
|
price_data_1d = self.data_provider.get_ohlcv(symbol, '1d', limit=14) if self.data_provider else None
|
|
btc_data_1m = self.data_provider.get_ohlcv('BTC/USDT', '1m', limit=150) if self.data_provider else None
|
|
|
|
if not price_data_1m or len(price_data_1m) < 10:
|
|
return None
|
|
|
|
import torch
|
|
import numpy as np
|
|
device = next(transformer.parameters()).device
|
|
|
|
def ohlcv_to_tensor(data, limit=None):
|
|
if not data:
|
|
return None
|
|
data = data[-limit:] if limit and len(data) > limit else data
|
|
arr = np.array([[d['open'], d['high'], d['low'], d['close'], d['volume']] for d in data], dtype=np.float32)
|
|
return torch.from_numpy(arr).unsqueeze(0).to(device)
|
|
|
|
inputs = {
|
|
'price_data_1s': ohlcv_to_tensor(price_data_1s, 200),
|
|
'price_data_1m': ohlcv_to_tensor(price_data_1m, 150),
|
|
'price_data_1h': ohlcv_to_tensor(price_data_1h, 24),
|
|
'price_data_1d': ohlcv_to_tensor(price_data_1d, 14),
|
|
'btc_data_1m': ohlcv_to_tensor(btc_data_1m, 150)
|
|
}
|
|
|
|
# Forward pass
|
|
with torch.no_grad():
|
|
outputs = transformer(**inputs)
|
|
|
|
# Extract next_candles
|
|
next_candles = outputs.get('next_candles', {})
|
|
if not next_candles:
|
|
return None
|
|
|
|
# Convert to JSON-serializable format
|
|
predicted_candle = {}
|
|
for tf, candle_tensor in next_candles.items():
|
|
if candle_tensor is not None:
|
|
candle_values = candle_tensor.squeeze(0).cpu().numpy().tolist()
|
|
predicted_candle[tf] = candle_values
|
|
|
|
current_price = price_data_1m[-1]['close']
|
|
predicted_1m_close = predicted_candle.get('1m', [0,0,0,current_price,0])[3]
|
|
price_change = (predicted_1m_close - current_price) / current_price
|
|
|
|
if price_change > 0.001:
|
|
action = 'BUY'
|
|
elif price_change < -0.001:
|
|
action = 'SELL'
|
|
else:
|
|
action = 'HOLD'
|
|
|
|
confidence = 0.7
|
|
if 'confidence' in outputs:
|
|
conf_tensor = outputs['confidence']
|
|
confidence = float(conf_tensor.squeeze(0).cpu().numpy()[0])
|
|
|
|
prediction = {
|
|
'timestamp': datetime.now(timezone.utc).isoformat(),
|
|
'symbol': symbol,
|
|
'action': action,
|
|
'confidence': confidence,
|
|
'predicted_price': predicted_1m_close,
|
|
'current_price': current_price,
|
|
'price_change': price_change,
|
|
'predicted_candle': predicted_candle, # This is what frontend needs!
|
|
'primary_timeframe': '1m', # The main timeframe the model is predicting for
|
|
'type': 'transformer_prediction'
|
|
}
|
|
|
|
# Store for tracking
|
|
self.orchestrator.store_transformer_prediction(symbol, prediction)
|
|
|
|
logger.debug(f"Generated transformer prediction with {len(predicted_candle)} timeframes for ghost candles")
|
|
|
|
return prediction
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error generating live transformer prediction: {e}", exc_info=True)
|
|
return None
|
|
|
|
def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list,
|
|
actual: list, errors: dict, direction_correct: bool, accuracy: float):
|
|
"""
|
|
Incrementally train model on validated prediction
|
|
|
|
This implements online learning where each validated prediction becomes
|
|
a training sample, with loss weighting based on prediction accuracy.
|
|
|
|
Returns:
|
|
Dict with training metrics (loss, accuracy, steps)
|
|
"""
|
|
try:
|
|
if not self.training_adapter:
|
|
logger.warning("Training adapter not available for incremental training")
|
|
return None
|
|
|
|
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
|
|
logger.warning("Transformer model not available for incremental training")
|
|
return None
|
|
|
|
# Get the transformer trainer
|
|
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
|
|
if not trainer:
|
|
logger.warning("Transformer trainer not available")
|
|
return None
|
|
|
|
# Calculate sample weight based on accuracy
|
|
# Low accuracy predictions get higher weight (we need to learn from mistakes)
|
|
# High accuracy predictions get lower weight (model already knows this)
|
|
if accuracy < 50:
|
|
sample_weight = 3.0 # Learn hard from bad predictions
|
|
elif accuracy < 70:
|
|
sample_weight = 2.0 # Moderate learning
|
|
elif accuracy < 85:
|
|
sample_weight = 1.0 # Normal learning
|
|
else:
|
|
sample_weight = 0.5 # Light touch-up for good predictions
|
|
|
|
# Also weight by direction correctness
|
|
if not direction_correct:
|
|
sample_weight *= 1.5 # Wrong direction is critical - learn more
|
|
|
|
logger.info(f"[{timeframe}] Incremental training: accuracy={accuracy:.1f}%, weight={sample_weight:.1f}x")
|
|
|
|
# Create training sample from validated prediction
|
|
# We need to fetch the market state at that timestamp
|
|
symbol = 'ETH/USDT' # TODO: Get from active trading pair
|
|
|
|
# Ensure actual_candle has volume (frontend sends [O, H, L, C, V])
|
|
actual_candle = list(actual) if isinstance(actual, (list, tuple)) else actual
|
|
if len(actual_candle) == 4:
|
|
# If only 4 values, add volume from predicted (fallback)
|
|
actual_candle.append(predicted[4] if len(predicted) > 4 else 0.0)
|
|
|
|
training_sample = {
|
|
'symbol': symbol,
|
|
'timestamp': timestamp,
|
|
'predicted_candle': predicted, # [O, H, L, C, V]
|
|
'actual_candle': actual_candle, # [O, H, L, C, V] - ensure 5 values
|
|
'errors': errors,
|
|
'accuracy': accuracy,
|
|
'direction_correct': direction_correct,
|
|
'sample_weight': sample_weight
|
|
}
|
|
|
|
# Get market state at that timestamp
|
|
try:
|
|
market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe)
|
|
if not market_state or 'timeframes' not in market_state:
|
|
logger.warning(f"Could not fetch market state for {symbol} at {timestamp}")
|
|
return None
|
|
training_sample['market_state'] = market_state
|
|
except Exception as e:
|
|
logger.warning(f"Could not fetch market state: {e}")
|
|
return None
|
|
|
|
# Convert to transformer batch format
|
|
batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe)
|
|
if not batch:
|
|
logger.warning("Could not convert validated prediction to training batch")
|
|
return None
|
|
|
|
# Train on this batch with sample weighting
|
|
# CRITICAL: Use training lock to prevent concurrent access
|
|
import torch
|
|
import threading
|
|
|
|
# Try to acquire training lock with timeout
|
|
if hasattr(self.training_adapter, '_training_lock'):
|
|
lock_acquired = self.training_adapter._training_lock.acquire(timeout=5.0)
|
|
if not lock_acquired:
|
|
logger.warning("Could not acquire training lock within 5 seconds - skipping incremental training")
|
|
return None
|
|
else:
|
|
lock_acquired = False
|
|
|
|
try:
|
|
with torch.enable_grad():
|
|
trainer.model.train()
|
|
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
|
|
|
|
if result:
|
|
loss = result.get('total_loss', 0)
|
|
candle_accuracy = result.get('candle_accuracy', 0)
|
|
|
|
logger.info(f"[{timeframe}] ✓ Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
|
|
|
|
# Save checkpoint periodically (every 10 incremental steps)
|
|
if not hasattr(self, '_incremental_training_steps'):
|
|
self._incremental_training_steps = 0
|
|
|
|
self._incremental_training_steps += 1
|
|
|
|
# Track metrics for display
|
|
if not hasattr(self, '_training_metrics_history'):
|
|
self._training_metrics_history = []
|
|
|
|
self._training_metrics_history.append({
|
|
'step': self._incremental_training_steps,
|
|
'loss': loss,
|
|
'accuracy': candle_accuracy,
|
|
'timeframe': timeframe,
|
|
'timestamp': timestamp
|
|
})
|
|
|
|
# Keep only last 100 metrics
|
|
if len(self._training_metrics_history) > 100:
|
|
self._training_metrics_history.pop(0)
|
|
|
|
if self._incremental_training_steps % 10 == 0:
|
|
logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
|
|
trainer.save_checkpoint(
|
|
filepath=None, # Auto-generate path
|
|
metadata={
|
|
'training_type': 'incremental_online',
|
|
'steps': self._incremental_training_steps,
|
|
'last_accuracy': accuracy
|
|
}
|
|
)
|
|
|
|
# Return metrics for display
|
|
return {
|
|
'loss': loss,
|
|
'accuracy': candle_accuracy,
|
|
'steps': self._incremental_training_steps,
|
|
'sample_weight': sample_weight
|
|
}
|
|
else:
|
|
logger.warning("Training step returned no result")
|
|
return None
|
|
finally:
|
|
# CRITICAL: Always release the lock
|
|
if lock_acquired and hasattr(self.training_adapter, '_training_lock'):
|
|
self.training_adapter._training_lock.release()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in incremental training: {e}", exc_info=True)
|
|
# Ensure lock is released even on error
|
|
if 'lock_acquired' in locals() and lock_acquired and hasattr(self.training_adapter, '_training_lock'):
|
|
try:
|
|
self.training_adapter._training_lock.release()
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
|
|
"""Fetch market state at a specific timestamp for training"""
|
|
try:
|
|
from datetime import datetime, timezone
|
|
import pandas as pd
|
|
|
|
# Parse timestamp - ensure it's timezone-aware
|
|
if isinstance(timestamp, str):
|
|
ts = pd.Timestamp(timestamp)
|
|
if ts.tz is None:
|
|
ts = ts.tz_localize('UTC')
|
|
else:
|
|
ts = pd.Timestamp(timestamp)
|
|
if ts.tz is None:
|
|
ts = ts.tz_localize('UTC')
|
|
|
|
# Use data provider's method to get market state at that time
|
|
# This ensures we get the proper format with all required timeframes
|
|
if self.data_provider and hasattr(self.data_provider, 'get_market_state_at_time'):
|
|
try:
|
|
# Convert to datetime if needed
|
|
if isinstance(ts, pd.Timestamp):
|
|
dt = ts.to_pydatetime()
|
|
else:
|
|
dt = ts
|
|
|
|
# Get market state with context window (need enough candles for training)
|
|
market_state = self.data_provider.get_market_state_at_time(
|
|
symbol=symbol,
|
|
timestamp=dt,
|
|
context_window_minutes=600 # Get 600 minutes of context for 1m candles
|
|
)
|
|
|
|
if market_state and 'timeframes' in market_state:
|
|
logger.debug(f"Fetched market state with {len(market_state.get('timeframes', {}))} timeframes")
|
|
return market_state
|
|
else:
|
|
logger.warning("Market state returned empty or invalid format")
|
|
except Exception as e:
|
|
logger.warning(f"Could not use data provider method: {e}")
|
|
|
|
# Fallback: manually fetch data for each timeframe
|
|
# Ensure dt is defined even when the data-provider fast path above was skipped or failed
dt = ts.to_pydatetime() if isinstance(ts, pd.Timestamp) else ts
market_state = {'timeframes': {}, 'secondary_timeframes': {}}
|
|
|
|
# REQUIRED timeframes for transformer: 1m, 1h, 1d (1s is optional)
|
|
# Need at least 50 candles, preferably 600
|
|
required_tfs = ['1m', '1h', '1d']
|
|
optional_tfs = ['1s']
|
|
|
|
            for tf in required_tfs + optional_tfs:
                try:
                    # Fetch enough candles (600 for training, but accept less)
                    df = None
                    if self.data_provider:
                        df = self.data_provider.get_data_for_annotation(
                            symbol=symbol,
                            timeframe=tf,
                            end_time=dt,
                            limit=600,
                            direction='before'
                        )

                    # Fallback to regular historical data if annotation method fails
                    if df is None or df.empty:
                        if self.data_provider:
                            df = self.data_provider.get_historical_data(symbol, tf, limit=600, refresh=False)

                    if df is not None and not df.empty:
                        # Filter to data before the target timestamp
                        df_before = df[df.index < ts]
                        if df_before.empty:
                            # If no data before timestamp, use all available data
                            df_before = df

                        # Take last 600 candles (or all if fewer)
                        recent = df_before.tail(600)

                        if len(recent) >= 50:  # Minimum required
                            market_state['timeframes'][tf] = {
                                'open': recent['open'].tolist(),
                                'high': recent['high'].tolist(),
                                'low': recent['low'].tolist(),
                                'close': recent['close'].tolist(),
                                'volume': recent['volume'].tolist()
                            }
                            logger.debug(f"Fetched {len(recent)} candles for {tf} timeframe")
                        elif tf in required_tfs:
                            logger.warning(f"Required timeframe {tf} has only {len(recent)} candles (need at least 50)")
                        else:
                            logger.debug(f"Optional timeframe {tf} has only {len(recent)} candles, skipping")
                except Exception as e:
                    logger.warning(f"Could not fetch {tf} data: {e}")
            # Validate we have required timeframes
            missing_required = [tf for tf in required_tfs if tf not in market_state['timeframes']]
            if missing_required:
                logger.warning(f"Missing required timeframes: {missing_required}")
                return {}

            return market_state

        except Exception as e:
            logger.error(f"Error fetching market state: {e}", exc_info=True)
            return {}

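    # Shape of the dict built by _fetch_market_state_at_timestamp (illustrative only;
    # list lengths depend on how many candles were actually fetched, 50-600 per timeframe):
    #
    #   {
    #       'timeframes': {
    #           '1m': {'open': [...], 'high': [...], 'low': [...], 'close': [...], 'volume': [...]},
    #           '1h': {...},
    #           '1d': {...}
    #       },
    #       'secondary_timeframes': {}
    #   }
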
    def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
        """
        Get a live prediction from the model using trainer inference.

        Inference inputs/outputs are cached (via the InferenceFrameReference system in
        real_training_adapter) for later training once the actual candle arrives.
        This allows us to:
        1. Compare predicted vs actual candle values
        2. Calculate loss
        3. Do backpropagation with the correct outputs

        Returns:
            Dict with prediction results, including predicted_candle for ghost candle display
        """
        try:
            if not self.orchestrator:
                return None

            # Get trainer from orchestrator
            trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
            if not trainer or not trainer.model:
                logger.debug("No transformer trainer available for live prediction")
                return None

            # Get market data using training adapter's method (reuses existing logic)
            if not hasattr(self.training_adapter, '_get_realtime_market_data'):
                logger.warning("Training adapter missing _get_realtime_market_data method")
                return None

            market_data, norm_params = self.training_adapter._get_realtime_market_data(symbol, self.data_provider)
            if not market_data:
                logger.debug(f"No market data available for {symbol} {timeframe}")
                return None
            # Make prediction with model (torch is already imported at module level)
            timestamp = datetime.now(timezone.utc)

            with torch.no_grad():
                trainer.model.eval()
                outputs = trainer.model(**market_data)

            # Extract action prediction
            action_probs = outputs.get('action_probs')
            if action_probs is None:
                logger.debug("No action_probs in model output")
                return None

            action_idx = torch.argmax(action_probs, dim=-1).item()
            confidence = action_probs[0, action_idx].item()

            # Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
            actions = ['HOLD', 'BUY', 'SELL']
            action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
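            # Worked example of the mapping above (hypothetical probabilities):
            #   action_probs = [[0.10, 0.65, 0.25]]  -> argmax index 1 -> 'BUY'
            #   confidence = action_probs[0, 1] = 0.65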
            # Extract predicted candles and denormalize
            predicted_candles_raw = {}
            if 'next_candles' in outputs:
                for tf, tensor in outputs['next_candles'].items():
                    predicted_candles_raw[tf] = tensor.detach().cpu().numpy().tolist()

            # Denormalize predicted candles
            predicted_candles_denorm = {}
            if predicted_candles_raw and norm_params:
                for tf, raw_candle in predicted_candles_raw.items():
                    if tf in norm_params:
                        params = norm_params[tf]
                        price_min = params['price_min']
                        price_max = params['price_max']
                        vol_min = params['volume_min']
                        vol_max = params['volume_max']

                        # raw_candle is a [1, 5] nested list: [[open, high, low, close, volume]]
                        candle_values = raw_candle[0]

                        denorm_candle = [
                            candle_values[0] * (price_max - price_min) + price_min,  # Open
                            candle_values[1] * (price_max - price_min) + price_min,  # High
                            candle_values[2] * (price_max - price_min) + price_min,  # Low
                            candle_values[3] * (price_max - price_min) + price_min,  # Close
                            candle_values[4] * (vol_max - vol_min) + vol_min         # Volume
                        ]
                        predicted_candles_denorm[tf] = denorm_candle
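            # Sanity check on the min-max inversion above (hypothetical numbers): with
            # price_min=3000.0 and price_max=3100.0, a normalized close of 0.42 maps back to
            # 0.42 * (3100.0 - 3000.0) + 3000.0 = 3042.0; volume is inverted the same way
            # using volume_min/volume_max.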
            # Get predicted price from candle close
            predicted_price = None
            if timeframe in predicted_candles_denorm:
                predicted_price = predicted_candles_denorm[timeframe][3]  # Close
            elif '1m' in predicted_candles_denorm:
                predicted_price = predicted_candles_denorm['1m'][3]
            elif '1s' in predicted_candles_denorm:
                predicted_price = predicted_candles_denorm['1s'][3]

            # NOTE: Caching is now handled by InferenceFrameReference system in real_training_adapter
            # This provides more efficient reference-based storage without copying 600 candles

            # Return prediction result (same format as before for compatibility)
            return {
                'symbol': symbol,
                'timeframe': timeframe,
                'timestamp': timestamp.isoformat(),
                'action': action,
                'confidence': confidence,
                'predicted_price': predicted_price,
                'predicted_candle': predicted_candles_denorm,
                'prediction_steps': prediction_steps
            }

        except Exception as e:
            logger.error(f"Error getting live prediction: {e}", exc_info=True)
            return None

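    # Example of the dict returned by _get_live_prediction (all values hypothetical):
    #   {
    #       'symbol': 'ETH/USDT', 'timeframe': '1m',
    #       'timestamp': '2024-01-15T10:30:00+00:00',
    #       'action': 'BUY', 'confidence': 0.65,
    #       'predicted_price': 3042.0,
    #       'predicted_candle': {'1m': [3040.1, 3044.9, 3038.2, 3042.0, 1250.0]},
    #       'prediction_steps': 1
    #   }
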
    # REMOVED: Unused prediction caching methods
    # Now using InferenceFrameReference system for unified prediction storage and training

    def run(self, host='0.0.0.0', port=8051, debug=False):
        """Run the application - binds to all interfaces by default"""
        logger.info(f"Starting Annotation Dashboard on http://{host}:{port}")
        logger.info(f"Access locally at: http://localhost:{port}")
        logger.info(f"Access from network at: http://<your-ip>:{port}")

        # WebSocket removed - using HTTP polling only
        # Start Flask server
        self.server.run(host=host, port=port, debug=debug, use_reloader=False)


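# Illustrative usage from another script (host/port values are examples only):
#   dashboard = AnnotationDashboard()
#   dashboard.run(host='127.0.0.1', port=8080)

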
def main():
    """Main entry point"""
    logger.info("=" * 80)
    logger.info("ANNOTATE Application Starting")
    logger.info(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info("=" * 80)

    # Print logging channel configuration
    from utils.logging_config import print_channel_status
    print_channel_status()

    dashboard = AnnotationDashboard()
    dashboard.run(debug=True)


if __name__ == '__main__':
    main()