Files
gogo2/ANNOTATE/core/real_training_adapter.py
Dobromir Popov fadfa8c741 wip
2025-12-10 00:45:41 +02:00

5062 lines
257 KiB
Python

"""
Real Training Adapter for ANNOTATE System
This adapter connects the ANNOTATE annotation system to the REAL training implementations.
NO SIMULATION - Uses actual model training from NN/training and core modules.
Integrates with:
- NN/training/enhanced_realtime_training.py
- NN/training/model_manager.py
- core/unified_training_manager.py
- core/orchestrator.py
"""
import logging
import uuid
import time
import threading
import os
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
import torch
import numpy as np
import pandas as pd
try:
import pytz
except ImportError:
pytz = None
logger = logging.getLogger(__name__)
def parse_timestamp_to_utc(timestamp_str) -> datetime:
    """
    Parse a timestamp in any supported format into a timezone-aware UTC datetime.

    Handles:
    - pandas Timestamp objects (anything exposing ``to_pydatetime``)
    - datetime objects
    - ISO format with timezone: '2025-10-27T14:00:00+00:00'
    - ISO format with Z: '2025-10-27T14:00:00Z' (also '2025-10-27 14:00:00Z')
    - ISO format without timezone: '2025-10-27T14:00:00'
    - Space-separated with seconds: '2025-10-27 14:00:00'
    - Space-separated without seconds: '2025-10-27 14:00'

    Naive inputs are assumed to already be UTC and are tagged as such.
    Inputs carrying a non-UTC offset are converted to the same instant in UTC,
    so the result always has ``tzinfo == timezone.utc``.

    Args:
        timestamp_str: Timestamp string, pandas Timestamp, or datetime object.
            Any other object is converted with ``str()`` first; surrounding
            whitespace is ignored.

    Returns:
        Timezone-aware datetime object in UTC.

    Raises:
        ValueError: If the timestamp is empty or cannot be parsed.
    """
    def _as_utc(dt: datetime) -> datetime:
        # Naive -> assume UTC; aware -> convert to UTC (same instant).
        if dt.tzinfo is None:
            return dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc)

    # pandas Timestamp (checked first: it is also a datetime subclass)
    if hasattr(timestamp_str, 'to_pydatetime'):
        return _as_utc(timestamp_str.to_pydatetime())
    if isinstance(timestamp_str, datetime):
        return _as_utc(timestamp_str)

    text = str(timestamp_str).strip()
    if not text:
        raise ValueError("Empty timestamp string")

    # datetime.fromisoformat() only accepts a trailing 'Z' from Python 3.11 on,
    # so normalise it to an explicit offset for every separator style.
    iso_text = text[:-1] + '+00:00' if text.endswith('Z') else text

    # ISO-8601 style strings ('T' or space separator, optional offset)
    if 'T' in iso_text or '+' in iso_text or ' ' in iso_text:
        for candidate in (iso_text, iso_text.replace(' ', 'T')):
            try:
                return _as_utc(datetime.fromisoformat(candidate))
            except ValueError:
                continue

    # Explicit formats as a last resort (strptime also tolerates
    # non-zero-padded fields that fromisoformat rejects)
    formats = (
        '%Y-%m-%d %H:%M:%S',  # With seconds
        '%Y-%m-%d %H:%M',  # Without seconds
        '%Y-%m-%dT%H:%M:%S',  # ISO without timezone
        '%Y-%m-%dT%H:%M',  # ISO without seconds or timezone
    )
    for fmt in formats:
        try:
            return _as_utc(datetime.strptime(text, fmt))
        except ValueError:
            continue

    raise ValueError(f"Could not parse timestamp: '{text}'")
@dataclass
class TrainingSession:
    """
    Book-keeping record for one real training run.

    The eight leading fields are required (and positional); the trailing
    fields are optional, default to None, and are filled in while the run
    progresses or when it finishes.
    """
    # -- identity ------------------------------------------------------------
    training_id: str
    model_name: str
    test_cases_count: int
    # -- progress ------------------------------------------------------------
    status: str  # one of 'running', 'completed', 'failed'
    current_epoch: int
    total_epochs: int
    current_loss: float
    start_time: float  # wall-clock start time (a time.time() value)
    # -- outcome -------------------------------------------------------------
    duration_seconds: Optional[float] = None
    final_loss: Optional[float] = None
    accuracy: Optional[float] = None
    error: Optional[str] = None  # error text for a failed run
    # -- resources and context -----------------------------------------------
    gpu_utilization: Optional[float] = None  # GPU utilization percentage
    cpu_utilization: Optional[float] = None  # CPU utilization percentage
    annotation_count: Optional[int] = None  # number of annotations used
    timeframe: Optional[str] = None  # primary timeframe, e.g. '1m' or '5m'
class RealTrainingAdapter:
"""
Adapter for REAL model training using annotations.
This class bridges the ANNOTATE system with the actual training implementations.
NO SIMULATION CODE - All training is real.
"""
def __init__(self, orchestrator=None, data_provider=None):
    """
    Wire the adapter to the live orchestrator and data provider.

    Also hooks into the orchestrator's inference/training coordinator when one
    is present, probes the optional training back-ends and loads the best
    realtime checkpoint.

    Args:
        orchestrator: TradingOrchestrator instance with real models (may be None)
        data_provider: DataProvider for fetching real market data (may be None)
    """
    self.orchestrator = orchestrator
    self.data_provider = data_provider
    self.training_sessions: Dict[str, TrainingSession] = {}

    # Batch training and per-candle training run on different threads and
    # would corrupt the computation graph if they touched the model at the
    # same time. Re-entrant so a thread may take the lock more than once.
    self._training_lock = threading.RLock()
    # Debugging aid: which thread currently holds the training lock
    self._training_lock_holder = None

    # Reuse the orchestrator's coordinator (centralised inference/training
    # coordination) instead of duplicating that logic here.
    coordinator = None
    if orchestrator and hasattr(orchestrator, 'inference_training_coordinator'):
        coordinator = orchestrator.inference_training_coordinator
    self.training_coordinator = coordinator
    if coordinator:
        self._subscribe_to_training_events()

    # Real-time training statistics
    self.realtime_training_metrics = dict(
        total_steps=0,
        total_loss=0.0,
        total_accuracy=0.0,
        best_loss=float('inf'),
        best_accuracy=0.0,
        last_checkpoint_step=0,
        checkpoint_frequency=100,  # save every N steps
        losses=[],  # rolling window of recent losses
        accuracies=[],  # rolling window of recent accuracies
    )

    self._import_training_systems()
    self._load_best_realtime_checkpoint()

    logger.info("RealTrainingAdapter initialized - NO SIMULATION, REAL TRAINING ONLY")
# Implement TrainingEventSubscriber interface
def on_candle_completion(self, event, inference_ref):
    """
    Train on a stored inference frame once its target candle has closed.

    Reference-based: only ``inference_ref`` is kept at inference time; the
    actual model inputs are fetched back through the training coordinator
    (from DuckDB, per the coordinator's design) and paired with the realised
    candle ``event.ohlcv``.

    Does nothing when there is no reference or no coordinator. Any error is
    logged and swallowed.
    """
    if not (inference_ref and self.training_coordinator):
        return

    try:
        inputs = self.training_coordinator.get_inference_data(inference_ref)
        if not inputs:
            logger.warning(f"Could not retrieve inference data for {inference_ref.inference_id}")
            return

        # Pair the stored inputs with the candle that actually printed
        training_batch = self._create_training_batch_from_inference(
            inputs, event.ohlcv, inference_ref
        )
        if training_batch:
            # Backprop step (Transformer)
            self._train_on_inference_batch(training_batch, inference_ref)
    except Exception as e:
        logger.error(f"Error in candle completion training: {e}", exc_info=True)
def on_pivot_event(self, event, inference_refs):
    """
    Train on every stored inference frame matched to a detected pivot.

    Event-based counterpart to candle-completion training: the pivot time is not
    known in advance, so training happens when the coordinator reports it.
    Each reference is processed independently. A failure while retrieving or
    training on one frame is logged, and the remaining frames are still used.

    Args:
        event: Pivot event, forwarded to ``_create_pivot_training_batch``.
        inference_refs: Inference frame references to train on.
    """
    if not inference_refs or not self.training_coordinator:
        return

    for inference_ref in inference_refs:
        try:
            model_inputs = self.training_coordinator.get_inference_data(inference_ref)
            if not model_inputs:
                continue

            # Label the stored inputs with the pivot outcome
            batch = self._create_pivot_training_batch(model_inputs, event, inference_ref)
            if not batch:
                continue

            self._train_on_inference_batch(batch, inference_ref)
        except Exception as e:
            # Per-frame isolation: one bad reference must not starve the rest
            ref_id = getattr(inference_ref, 'inference_id', inference_ref)
            logger.error(f"Error in pivot event training ({ref_id}): {e}", exc_info=True)
def _create_training_batch_from_inference(self, model_inputs: Dict, actual_ohlcv: Dict,
inference_ref) -> Optional[Dict]:
"""Create training batch from inference inputs and actual candle result"""
try:
import torch
# Copy model inputs
batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in model_inputs.items()}
# Get device
device = next(iter(batch.values())).device if batch else torch.device('cpu')
# Normalize actual candle using stored params
timeframe = inference_ref.timeframe
if timeframe in inference_ref.norm_params:
params = inference_ref.norm_params[timeframe]
price_min = params['price_min']
price_max = params['price_max']
vol_min = params['volume_min']
vol_max = params['volume_max']
# Normalize actual OHLCV
normalized_candle = [
(actual_ohlcv['open'] - price_min) / (price_max - price_min),
(actual_ohlcv['high'] - price_min) / (price_max - price_min),
(actual_ohlcv['low'] - price_min) / (price_max - price_min),
(actual_ohlcv['close'] - price_min) / (price_max - price_min),
(actual_ohlcv['volume'] - vol_min) / (vol_max - vol_min) if vol_max > vol_min else 0.0
]
# Add target candle to batch
target_key = f'future_candle_{timeframe}'
batch[target_key] = torch.tensor([normalized_candle], dtype=torch.float32, device=device)
# Add action target (determine from price movement)
price_change = (actual_ohlcv['close'] - actual_ohlcv['open']) / actual_ohlcv['open']
if price_change > 0.0005: # 0.05% up
action = 1 # BUY
elif price_change < -0.0005: # 0.05% down
action = 2 # SELL
else:
action = 0 # HOLD
batch['actions'] = torch.tensor([[action]], dtype=torch.long, device=device)
return batch
return None
except Exception as e:
logger.error(f"Error creating training batch from inference: {e}", exc_info=True)
return None
def _create_pivot_training_batch(self, model_inputs: Dict, pivot_event, inference_ref) -> Optional[Dict]:
"""Create training batch from inference inputs and pivot event"""
try:
import torch
# Copy model inputs
batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in model_inputs.items()}
# Get device
device = next(iter(batch.values())).device if batch else torch.device('cpu')
# Determine action from pivot type
# L2L, L3L, etc. -> BUY (support levels)
# L2H, L3H, etc. -> SELL (resistance levels)
if pivot_event.pivot_type.endswith('L'):
action = 1 # BUY
elif pivot_event.pivot_type.endswith('H'):
action = 2 # SELL
else:
action = 0 # HOLD
batch['actions'] = torch.tensor([[action]], dtype=torch.long, device=device)
# For pivot training, we don't have a target candle, so we use the pivot price
# as a reference point for training
# This is a simplified approach - could be enhanced with pivot-based targets
return batch
except Exception as e:
logger.error(f"Error creating pivot training batch: {e}", exc_info=True)
return None
def _train_on_inference_batch(self, batch: Dict, inference_ref) -> None:
"""Train model on inference batch (uses stored inference frame)"""
try:
if not self.orchestrator:
return
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if not trainer:
return
# Train with lock protection
import torch
with self._training_lock:
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False)
if result:
loss = result.get('total_loss', 0)
accuracy = result.get('accuracy', 0)
# Update metrics
self.realtime_training_metrics['total_steps'] += 1
self.realtime_training_metrics['total_loss'] += loss
self.realtime_training_metrics['total_accuracy'] += accuracy
self.realtime_training_metrics['losses'].append(loss)
self.realtime_training_metrics['accuracies'].append(accuracy)
if len(self.realtime_training_metrics['losses']) > 100:
self.realtime_training_metrics['losses'].pop(0)
self.realtime_training_metrics['accuracies'].pop(0)
logger.info(f"Trained on inference frame {inference_ref.inference_id}: Loss={loss:.4f}, Acc={accuracy:.2%}")
except Exception as e:
logger.error(f"Error training on inference batch: {e}", exc_info=True)
def _register_inference_frame(self, session: Dict, symbol: str, timeframe: str,
prediction: Dict, data_provider, norm_params: Dict = None) -> None:
"""
Register inference frame reference with coordinator.
Stores reference (timestamp range) instead of copying 600 candles.
This method stores norm_params in the reference for efficient retrieval.
When training is triggered, data is retrieved from DuckDB using the reference.
Args:
session: Inference session
symbol: Trading symbol
timeframe: Timeframe
prediction: Prediction dict from model
data_provider: Data provider instance
norm_params: Normalization parameters (optional, will be calculated if not provided)
"""
if not self.training_coordinator:
return
try:
from ANNOTATE.core.inference_training_system import InferenceFrameReference
from datetime import datetime, timezone, timedelta
import uuid
# Get current time and calculate data range
current_time = datetime.now(timezone.utc)
data_range_end = current_time
# Calculate start time for 600 candles (approximate)
timeframe_seconds = {'1s': 1, '1m': 60, '5m': 300, '15m': 900, '1h': 3600, '1d': 86400}.get(timeframe, 60)
data_range_start = current_time - timedelta(seconds=600 * timeframe_seconds)
# Use provided norm_params or calculate if not available
if not norm_params:
norm_params = {}
# Calculate target timestamp (next candle close time)
# For 1m timeframe, next candle closes in 1 minute
target_timestamp = current_time + timedelta(seconds=timeframe_seconds)
# Create inference frame reference
inference_ref = InferenceFrameReference(
inference_id=str(uuid.uuid4()),
symbol=symbol,
timeframe=timeframe,
prediction_timestamp=current_time,
target_timestamp=target_timestamp,
data_range_start=data_range_start,
data_range_end=data_range_end,
norm_params=norm_params, # Stored for efficient retrieval
predicted_action=prediction.get('action'),
predicted_candle=prediction.get('predicted_candle'),
confidence=prediction.get('confidence', 0.0)
)
# Register with coordinator
self.training_coordinator.register_inference_frame(inference_ref)
logger.debug(f"Registered inference frame: {inference_ref.inference_id} for {symbol} {timeframe} (target: {target_timestamp})")
except Exception as e:
logger.warning(f"Could not register inference frame: {e}", exc_info=True)
def _subscribe_to_training_events(self):
"""Subscribe to training events via orchestrator's coordinator"""
if not self.training_coordinator:
return
try:
# Subscribe to candle completion for primary symbol/timeframe
primary_symbol = getattr(self.orchestrator, 'symbol', 'ETH/USDT')
primary_timeframe = '1m' # Default timeframe
self.training_coordinator.subscribe_to_candle_completion(
self, symbol=primary_symbol, timeframe=primary_timeframe
)
# Subscribe to pivot events (L2L, L2H, L3L, L3H)
self.training_coordinator.subscribe_to_pivot_events(
self, symbol=primary_symbol, timeframe=primary_timeframe,
pivot_types=['L2L', 'L2H', 'L3L', 'L3H']
)
logger.info(f"Subscribed to training events: {primary_symbol} {primary_timeframe}")
except Exception as e:
logger.warning(f"Could not subscribe to training events: {e}")
def _import_training_systems(self):
"""Import real training system implementations"""
try:
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem
self.enhanced_training_available = True
logger.info("EnhancedRealtimeTrainingSystem available")
except ImportError as e:
self.enhanced_training_available = False
logger.warning(f"EnhancedRealtimeTrainingSystem not available: {e}")
try:
from NN.training.model_manager import ModelManager
self.model_manager_available = True
logger.info("ModelManager available")
except ImportError as e:
self.model_manager_available = False
logger.warning(f"ModelManager not available: {e}")
try:
from core.enhanced_rl_training_adapter import EnhancedRLTrainingAdapter
self.enhanced_rl_adapter_available = True
logger.info("EnhancedRLTrainingAdapter available")
except ImportError as e:
self.enhanced_rl_adapter_available = False
logger.warning(f"EnhancedRLTrainingAdapter not available: {e}")
def get_available_models(self) -> List[str]:
    """
    List the models that are actually loaded on the orchestrator.

    A model counts as available when the corresponding orchestrator attribute
    exists and is truthy. Names are returned in the fixed order CNN, DQN,
    Transformer, COB, Extrema. With no orchestrator, an error is logged and
    an empty list is returned.
    """
    orch = self.orchestrator
    if not orch:
        logger.error("Orchestrator not available")
        return []

    # (orchestrator attribute, model name reported by this method)
    model_slots = (
        ('cnn_model', 'CNN'),
        ('rl_agent', 'DQN'),
        ('primary_transformer', 'Transformer'),
        ('cob_rl_agent', 'COB'),
        ('extrema_trainer', 'Extrema'),
    )
    available = [name for attr, name in model_slots if getattr(orch, attr, None)]

    logger.info(f"Available models for training: {available}")
    return available
def start_training(self, model_name: str, test_cases: List[Dict],
                   annotation_count: Optional[int] = None,
                   timeframe: Optional[str] = None) -> str:
    """
    Register a new REAL training session and launch it on a daemon thread.

    Args:
        model_name: Name of model to train (CNN, DQN, Transformer, COB, Extrema)
        test_cases: List of test cases from annotations
        annotation_count: Number of annotations used; defaults to len(test_cases)
        timeframe: Primary timeframe for training (optional, e.g., '1m', '5m')

    Returns:
        training_id: Unique ID of the session. It is recorded in
        ``self.training_sessions`` with status 'running' and 10 epochs.

    Raises:
        Exception: If no orchestrator is configured.
    """
    if not self.orchestrator:
        raise Exception("Orchestrator not available - cannot train models")

    training_id = str(uuid.uuid4())
    case_count = len(test_cases)

    self.training_sessions[training_id] = TrainingSession(
        training_id=training_id,
        model_name=model_name,
        test_cases_count=case_count,
        status='running',
        current_epoch=0,
        total_epochs=10,  # reasonable for annotation-based training
        current_loss=0.0,
        start_time=time.time(),
        annotation_count=case_count if annotation_count is None else annotation_count,
        timeframe=timeframe,
    )

    logger.info(f"Starting REAL training session: {training_id} for {model_name} with {case_count} test cases")

    # The actual training happens off the caller's thread
    worker = threading.Thread(
        target=self._execute_real_training,
        args=(training_id, model_name, test_cases),
        daemon=True,
    )
    worker.start()
    return training_id
def _execute_real_training(self, training_id: str, model_name: str, test_cases: List[Dict]):
"""Execute REAL model training (runs in background thread)"""
session = self.training_sessions[training_id]
try:
logger.info(f"Executing REAL training for {model_name}")
logger.info(f" Training ID: {training_id}")
logger.info(f" Test cases: {len(test_cases)}")
# Clear previous predictions for clean visualization
# Get symbol from first test case
symbol = test_cases[0].get('symbol', 'ETH/USDT') if test_cases else 'ETH/USDT'
if (self.orchestrator and
hasattr(self.orchestrator, 'clear_predictions') and
hasattr(self.orchestrator, 'recent_transformer_predictions')):
self.orchestrator.clear_predictions(symbol)
logger.info(f" Cleared previous predictions for {symbol}")
else:
logger.info(f" Orchestrator not ready, skipping prediction clearing for {symbol}")
# Prepare training data from test cases
training_data = self._prepare_training_data(test_cases)
if not training_data:
raise Exception("No valid training data prepared from test cases")
logger.info(f" Prepared {len(training_data)} training samples")
# Route to appropriate REAL training method
if model_name in ["CNN", "StandardizedCNN"]:
logger.info(" Starting CNN training...")
self._train_cnn_real(session, training_data)
elif model_name == "DQN":
logger.info(" Starting DQN training...")
self._train_dqn_real(session, training_data)
elif model_name == "Transformer":
logger.info(" Starting Transformer training...")
self._train_transformer_real(session, training_data)
elif model_name == "COB":
logger.info(" Starting COB training...")
self._train_cob_real(session, training_data)
elif model_name == "Extrema":
logger.info(" Starting Extrema training...")
self._train_extrema_real(session, training_data)
else:
raise Exception(f"Unknown model type: {model_name}")
# Mark as completed
session.status = 'completed'
session.duration_seconds = time.time() - session.start_time
logger.info(f" REAL training completed: {training_id} in {session.duration_seconds:.2f}s")
logger.info(f" Final loss: {session.final_loss}")
logger.info(f" Accuracy: {session.accuracy}")
except Exception as e:
logger.error(f"REAL training failed: {e}", exc_info=True)
session.status = 'failed'
session.error = str(e)
session.duration_seconds = time.time() - session.start_time
logger.error(f"Training session {training_id} failed after {session.duration_seconds:.2f}s")
def _get_secondary_symbol(self, primary_symbol: str) -> str:
"""
Determine secondary symbol based on primary symbol
Rules:
- ETH/USDT -> BTC/USDT
- SOL/USDT -> BTC/USDT
- BTC/USDT -> ETH/USDT
Args:
primary_symbol: Primary trading symbol
Returns:
Secondary symbol for correlation analysis
"""
if 'BTC' in primary_symbol:
return 'ETH/USDT'
else:
return 'BTC/USDT'
def _fetch_market_state_for_test_case(self, test_case: Dict) -> Dict:
"""
Fetch market state dynamically for a test case from DuckDB storage
This fetches HISTORICAL data at the specific timestamp from the annotation,
not current/latest data.
Args:
test_case: Test case dictionary with timestamp, symbol, etc.
Returns:
Market state dictionary with OHLCV data for all timeframes
"""
try:
if not self.data_provider:
logger.warning("DataProvider not available, cannot fetch market state")
return {}
symbol = test_case.get('symbol', 'ETH/USDT')
timestamp_str = test_case.get('timestamp')
if not timestamp_str:
logger.warning("No timestamp in test case")
return {}
# Parse timestamp using unified parser
try:
timestamp = parse_timestamp_to_utc(timestamp_str)
except Exception as e:
logger.warning(f"Could not parse timestamp '{timestamp_str}': {e}")
return {}
# Get training config
training_config = test_case.get('training_config', {})
# REQUIRED: All 3 timeframes (1m, 1h, 1d) with 600 candles each
timeframes = training_config.get('timeframes', ['1m', '1h', '1d'])
# REQUIRED: 600 candles per timeframe for transformer model
candles_per_timeframe = training_config.get('candles_per_timeframe', 600) # 600 candles per timeframe
# REQUIRED: 1m, 1h, 1d (all with 600 candles each)
# OPTIONAL: 1s (if available, include with 600 candles)
required_timeframes = ['1m', '1h', '1d']
optional_timeframes = ['1s'] # Include if available
# Ensure required timeframes are in the list
missing_tfs = [tf for tf in required_timeframes if tf not in timeframes]
if missing_tfs:
logger.warning(f" Missing required timeframes: {missing_tfs}, adding them...")
timeframes = list(set(timeframes + required_timeframes))
# Note: 1s is optional, don't add it if not present
# Determine secondary symbol based on primary symbol
# ETH/SOL -> BTC, BTC -> ETH
secondary_symbol = self._get_secondary_symbol(symbol)
logger.info(f" Fetching HISTORICAL market state for {symbol} at {timestamp}")
logger.info(f" Primary symbol: {symbol} - Timeframes: {timeframes}")
logger.info(f" Secondary symbol: {secondary_symbol} - Timeframe: 1m")
logger.info(f" Candles per batch: {candles_per_timeframe}")
# Calculate time range based on candles needed
# Use timeframe-specific windows for better efficiency
from datetime import timedelta
# Calculate time window for each timeframe to get required candles
# Add 20% buffer to account for missing candles
buffer_multiplier = 1.2
time_windows = {
'1s': timedelta(seconds=int(candles_per_timeframe * buffer_multiplier)),
'1m': timedelta(minutes=int(candles_per_timeframe * buffer_multiplier)),
'1h': timedelta(hours=int(candles_per_timeframe * buffer_multiplier)),
'1d': timedelta(days=int(candles_per_timeframe * buffer_multiplier))
}
# For historical queries, we want data BEFORE the timestamp
# Use the largest window to ensure we have enough data for all timeframes
max_window = max(time_windows.values())
start_time = timestamp - max_window
end_time = timestamp
# REQUIRED: All required timeframes must have exactly 600 candles (no tolerance for missing data)
min_required_candles = candles_per_timeframe # Must have full 600 candles
required_timeframes = ['1m', '1h', '1d'] # All 3 timeframes are mandatory
optional_timeframes = ['1s'] # Include if available
# Fetch data for primary symbol (all timeframes) and secondary symbol (1m only)
market_state = {
'symbol': symbol,
'timestamp': timestamp_str,
'timeframes': {},
'secondary_symbol': secondary_symbol,
'secondary_timeframes': {}
}
# Try to get data from DuckDB storage first (historical data)
duckdb_storage = None
if hasattr(self.data_provider, 'duckdb_storage'):
duckdb_storage = self.data_provider.duckdb_storage
# Fetch primary symbol data (all timeframes)
# REQUIRED: Must fetch all 3 timeframes (1m, 1h, 1d) with 600 candles each
logger.info(f" Fetching primary symbol data: {symbol}")
logger.info(f" REQUIRED timeframes: {required_timeframes} (each with {candles_per_timeframe} candles)")
fetched_timeframes = {} # Track which timeframes we successfully fetched
for timeframe in timeframes:
# Fetch required timeframes (1m, 1h, 1d) and optional 1s if present
if timeframe not in required_timeframes and timeframe not in optional_timeframes:
continue
df = None
limit = candles_per_timeframe
# Use timeframe-specific window for better efficiency
tf_window = time_windows.get(timeframe, max_window)
tf_start_time = timestamp - tf_window
tf_end_time = timestamp
# Try DuckDB storage first (has historical data)
# Use 'before' direction to get data BEFORE the timestamp
if duckdb_storage:
try:
df = duckdb_storage.get_ohlcv_data(
symbol=symbol,
timeframe=timeframe,
start_time=tf_start_time,
end_time=tf_end_time,
limit=limit,
direction='before' # Get data BEFORE timestamp for historical training
)
if df is not None and not df.empty:
logger.debug(f" {timeframe}: {len(df)} candles from DuckDB (historical, before {timestamp})")
except Exception as e:
logger.debug(f" {timeframe}: DuckDB query failed: {e}")
# If DuckDB doesn't have enough data, try API with proper time range
if df is None or df.empty or len(df) < min_required_candles:
try:
# Try to fetch from API with historical time range
logger.info(f" {timeframe}: DuckDB insufficient ({len(df) if df is not None else 0} candles), fetching from API for timestamp {timestamp}...")
# Fetch historical data from API for the specific time range
api_df = self._fetch_historical_from_api(
symbol=symbol,
timeframe=timeframe,
start_time=tf_start_time,
end_time=tf_end_time,
limit=limit
)
if api_df is not None and not api_df.empty:
# Filter to data before timestamp (historical training needs data BEFORE the event)
try:
api_df = api_df[api_df.index <= tf_end_time]
# Take the most recent candles up to limit
api_df = api_df.tail(limit)
if len(api_df) >= min_required_candles:
df = api_df
logger.info(f" {timeframe}: {len(df)} candles from API (historical range: {tf_start_time} to {tf_end_time})")
# Store in DuckDB for future use
if duckdb_storage:
try:
duckdb_storage.store_ohlcv_data(symbol, timeframe, df)
logger.debug(f" {timeframe}: Stored {len(df)} candles in DuckDB for future use")
except Exception as e:
logger.debug(f" {timeframe}: Could not store in DuckDB: {e}")
else:
logger.warning(f" {timeframe}: API returned only {len(api_df)} candles after filtering (need {min_required_candles})")
except Exception as e:
logger.debug(f" {timeframe}: Could not filter API data: {e}")
# Use as-is if filtering fails
if len(api_df) >= min_required_candles:
df = api_df
logger.info(f" {timeframe}: {len(df)} candles from API (unfiltered)")
else:
logger.warning(f" {timeframe}: API fetch returned no data")
except Exception as e:
logger.warning(f" {timeframe}: API fetch failed: {e}")
import traceback
logger.debug(traceback.format_exc())
# Fallback to replay method
if df is None or df.empty or len(df) < min_required_candles:
try:
replay_data = self.data_provider.get_historical_data_replay(
symbol=symbol,
start_time=tf_start_time,
end_time=tf_end_time,
timeframes=[timeframe]
)
replay_df = replay_data.get(timeframe)
if replay_df is not None and not replay_df.empty and len(replay_df) >= min_required_candles:
df = replay_df
logger.info(f" {timeframe}: {len(df)} candles from replay")
except Exception as e:
logger.debug(f" {timeframe}: Replay failed: {e}")
# Validate data quality before storing
if df is not None and not df.empty:
# Check minimum candle count
if len(df) < min_required_candles:
logger.warning(f" {symbol} {timeframe}: Only {len(df)} candles (need {min_required_candles}), skipping")
continue
# Validate data quality - check for NaN values
if df.isnull().any().any():
logger.warning(f" {symbol} {timeframe}: Contains NaN values, cleaning...")
df = df.dropna()
if len(df) < min_required_candles:
logger.warning(f" {symbol} {timeframe}: After cleaning, only {len(df)} candles, skipping")
continue
# Ensure we have required columns
required_cols = ['open', 'high', 'low', 'close', 'volume']
if not all(col in df.columns for col in required_cols):
logger.warning(f" {symbol} {timeframe}: Missing required columns, skipping")
continue
# Convert to dict format
market_state['timeframes'][timeframe] = {
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': df['open'].tolist(),
'high': df['high'].tolist(),
'low': df['low'].tolist(),
'close': df['close'].tolist(),
'volume': df['volume'].tolist()
}
fetched_timeframes[timeframe] = len(df)
logger.info(f" {symbol} {timeframe}: {len(df)} candles [OK]")
else:
logger.warning(f" {symbol} {timeframe}: No quality data available (need {min_required_candles} candles)")
# CRITICAL: Validate we have all required timeframes (1s is optional, don't check it)
missing_required = [tf for tf in required_timeframes if tf not in fetched_timeframes]
if missing_required:
logger.error(f" FAILED: Missing required timeframes: {missing_required}")
logger.error(f" Fetched: {list(fetched_timeframes.keys())}")
logger.error(f" Cannot proceed without all required timeframes")
return {} # Return empty dict to signal failure
# Log optional timeframes status
if '1s' not in fetched_timeframes:
logger.debug(f" Optional timeframe 1s not available (this is OK - 1s historical data is often unavailable)")
# Fetch secondary symbol data (1m timeframe only)
logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)")
secondary_df = None
# Use 1m-specific window
tf_window = time_windows.get('1m', max_window)
tf_start_time = timestamp - tf_window
tf_end_time = timestamp
# Try DuckDB first with 'before' direction
if duckdb_storage:
try:
secondary_df = duckdb_storage.get_ohlcv_data(
symbol=secondary_symbol,
timeframe='1m',
start_time=tf_start_time,
end_time=tf_end_time,
limit=candles_per_timeframe,
direction='before' # Get data BEFORE timestamp
)
if secondary_df is not None and not secondary_df.empty:
logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from DuckDB")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: DuckDB query failed: {e}")
# If DuckDB doesn't have enough, try API with historical time range
if secondary_df is None or secondary_df.empty or len(secondary_df) < min_required_candles:
try:
logger.info(f" {secondary_symbol} 1m: DuckDB insufficient ({len(secondary_df) if secondary_df is not None else 0} candles), fetching from API for timestamp {timestamp}...")
# Fetch historical data from API for the specific time range
api_df = self._fetch_historical_from_api(
symbol=secondary_symbol,
timeframe='1m',
start_time=tf_start_time,
end_time=tf_end_time,
limit=candles_per_timeframe
)
if api_df is not None and not api_df.empty:
# Filter to data before timestamp
try:
api_df = api_df[api_df.index <= tf_end_time].tail(candles_per_timeframe)
if len(api_df) >= min_required_candles:
secondary_df = api_df
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from API (historical range)")
# Store in DuckDB for future use
if duckdb_storage:
try:
duckdb_storage.store_ohlcv_data(secondary_symbol, '1m', secondary_df)
logger.debug(f" {secondary_symbol} 1m: Stored in DuckDB for future use")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: Could not store in DuckDB: {e}")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: Could not filter API data: {e}")
if len(api_df) >= min_required_candles:
secondary_df = api_df
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from API (unfiltered)")
else:
logger.warning(f" {secondary_symbol} 1m: API fetch returned no data")
except Exception as e:
logger.warning(f" {secondary_symbol} 1m: API fetch failed: {e}")
import traceback
logger.debug(traceback.format_exc())
# Fallback to replay
if secondary_df is None or secondary_df.empty or len(secondary_df) < min_required_candles:
try:
replay_data = self.data_provider.get_historical_data_replay(
symbol=secondary_symbol,
start_time=tf_start_time,
end_time=tf_end_time,
timeframes=['1m']
)
replay_df = replay_data.get('1m')
if replay_df is not None and not replay_df.empty and len(replay_df) >= min_required_candles:
secondary_df = replay_df
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from replay")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: Replay failed: {e}")
# Validate and store secondary symbol data
if secondary_df is not None and not secondary_df.empty:
if len(secondary_df) < min_required_candles:
logger.warning(f" {secondary_symbol} 1m: Only {len(secondary_df)} candles (need {min_required_candles}), skipping")
elif secondary_df.isnull().any().any():
logger.warning(f" {secondary_symbol} 1m: Contains NaN values, skipping")
elif not all(col in secondary_df.columns for col in ['open', 'high', 'low', 'close', 'volume']):
logger.warning(f" {secondary_symbol} 1m: Missing required columns, skipping")
else:
# Store in the correct structure: secondary_timeframes[symbol][timeframe]
if secondary_symbol not in market_state['secondary_timeframes']:
market_state['secondary_timeframes'][secondary_symbol] = {}
market_state['secondary_timeframes'][secondary_symbol]['1m'] = {
'timestamps': secondary_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': secondary_df['open'].tolist(),
'high': secondary_df['high'].tolist(),
'low': secondary_df['low'].tolist(),
'close': secondary_df['close'].tolist(),
'volume': secondary_df['volume'].tolist()
}
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles [OK]")
else:
logger.warning(f" {secondary_symbol} 1m: No quality data available (need {min_required_candles} candles)")
# Verify we have data
if market_state['timeframes']:
total_primary = sum(len(tf_data.get('timestamps', [])) for tf_data in market_state['timeframes'].values())
total_secondary = sum(len(tf_data.get('timestamps', [])) for tf_data in market_state['secondary_timeframes'].values())
logger.info(f" [OK] Fetched {len(market_state['timeframes'])} primary timeframes ({total_primary} total candles)")
logger.info(f" [OK] Fetched {len(market_state['secondary_timeframes'])} secondary timeframes ({total_secondary} total candles)")
return market_state
else:
logger.warning(f" No market data fetched for any timeframe")
return {}
except Exception as e:
logger.error(f"Error fetching market state: {e}")
import traceback
logger.error(traceback.format_exc())
return {}
def _prepare_training_data(self, test_cases: List[Dict],
                           negative_samples_window: int = 15,
                           training_repetitions: int = 1) -> List[Dict]:
    """
    Build labelled training samples from annotated test cases.

    For every usable test case this emits, in order:
      * one ENTRY sample at the annotated timestamp,
      * HOLD samples every ``hold_interval`` 1m candles while the position is open
        (built by ``_create_hold_samples``),
      * one EXIT sample when ``expected_outcome`` has a truthy ``exit_price``,
      * NO_TRADE samples for a few candles before entry and after exit
        (built by ``_create_negative_samples``).

    Args:
        test_cases: Test cases from annotations. Each needs a non-empty
            ``expected_outcome`` dict. ``market_state`` is fetched via
            ``_fetch_market_state_for_test_case`` when absent or empty.
        negative_samples_window: Accepted for API compatibility only. The number of
            NO_TRADE candles is fixed at 5 before entry and 5 after exit.
        training_repetitions: Accepted for API compatibility only. Every sample is
            emitted once.

    Returns:
        Flat list of sample dicts. Test cases that cannot be prepared are skipped
        and counted in a warning instead of aborting the whole batch.
    """
    from collections import Counter

    hold_interval = 13   # one HOLD sample every N 1m candles while in position
    samples_before = 5   # NO_TRADE candles taken before the entry
    samples_after = 5    # NO_TRADE candles taken after the exit

    training_data = []
    skipped_cases = 0    # test cases that produced no samples at all
    failed_cases = 0     # test cases that raised part-way (may have produced some samples)

    logger.info(f"Preparing training data from {len(test_cases)} test cases...")
    logger.info(f" Negative sampling: {samples_before} candles before entry + {samples_after} candles after exit")
    logger.info(f" Each sample trained once (no artificial repetitions)")

    for i, test_case in enumerate(test_cases):
        try:
            expected_outcome = test_case.get('expected_outcome', {})
            if not expected_outcome:
                logger.warning(f" Skipping test case {test_case.get('test_case_id')}: missing expected_outcome")
                skipped_cases += 1
                continue

            # Use the stored market state when present, otherwise fetch it on demand
            market_state = test_case.get('market_state', {})
            if not market_state:
                logger.info(f" Fetching market state dynamically for test case {i+1}...")
                market_state = self._fetch_market_state_for_test_case(test_case)
                if not market_state:
                    logger.warning(f" Skipping test case {test_case.get('test_case_id')}: could not fetch market state")
                    skipped_cases += 1
                    continue

            logger.debug(f" Test case {i+1}: has_market_state={bool(market_state)}, has_expected_outcome={bool(expected_outcome)}")

            # ENTRY sample (where the model SHOULD enter the trade)
            entry_sample = {
                'market_state': market_state,
                'action': test_case.get('action'),
                'direction': expected_outcome.get('direction'),
                'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
                'entry_price': expected_outcome.get('entry_price'),
                'exit_price': expected_outcome.get('exit_price'),
                'timestamp': test_case.get('timestamp'),
                'label': 'ENTRY'
            }
            training_data.append(entry_sample)
            logger.info(f" Test case {i+1}: ENTRY sample - {entry_sample['direction']} @ {entry_sample['entry_price']}")

            # HOLD samples teach the model to keep the position open until the exit
            hold_samples = self._create_hold_samples(
                test_case=test_case,
                market_state=market_state,
                sample_interval=hold_interval
            )
            training_data.extend(hold_samples)
            if hold_samples:
                logger.info(f" Test case {i+1}: Added {len(hold_samples)} HOLD samples (every {hold_interval} candles during position)")

            # EXIT sample. Exit info lives in expected_outcome, not annotation_metadata.
            exit_price = expected_outcome.get('exit_price')
            if exit_price:
                # TODO: fetch the market state at exit time; the entry state is used as a proxy
                exit_sample = {
                    'market_state': market_state,
                    'action': 'CLOSE',
                    'direction': expected_outcome.get('direction'),
                    'profit_loss_pct': expected_outcome.get('profit_loss_pct'),
                    'entry_price': expected_outcome.get('entry_price'),
                    'exit_price': exit_price,
                    'timestamp': test_case.get('timestamp'),  # exit time is not stored separately
                    'label': 'EXIT',
                    'in_position': True  # the model is in a position when deciding to exit
                }
                training_data.append(exit_sample)
                # `or 0` keeps an explicit None P&L from breaking the format spec
                # (which previously aborted the case before its NO_TRADE samples).
                pnl_pct = expected_outcome.get('profit_loss_pct') or 0
                logger.info(f" Test case {i+1}: EXIT sample @ {exit_price} ({pnl_pct:.2f}%)")

            # NO_TRADE samples teach the model when NOT to enter
            negative_samples = self._create_negative_samples(
                market_state=market_state,
                entry_timestamp=test_case.get('timestamp'),
                exit_timestamp=None,  # derived from the holding period
                holding_period_seconds=expected_outcome.get('holding_period_seconds', 0),
                samples_before=samples_before,
                samples_after=samples_after
            )
            training_data.extend(negative_samples)
            if negative_samples:
                logger.info(f" Test case {i+1}: Added {len(negative_samples)} NO_TRADE samples ({samples_before} before entry + {samples_after} after exit)")
        except Exception as e:
            # Best effort: one broken test case must not abort the whole batch
            logger.error(f" Error preparing test case {i+1}: {e}")
            failed_cases += 1

    label_counts = Counter(s.get('label') for s in training_data)
    total_entry = label_counts['ENTRY']
    total_no_trade = label_counts['NO_TRADE']

    logger.info(f" Prepared {len(training_data)} training samples from {len(test_cases)} test cases")
    logger.info(f" ENTRY samples: {total_entry}")
    logger.info(f" HOLD samples: {label_counts['HOLD']}")
    logger.info(f" EXIT samples: {label_counts['EXIT']}")
    logger.info(f" NO_TRADE samples: {total_no_trade}")
    if total_entry > 0:
        logger.info(f" Ratio: 1:{total_no_trade/total_entry:.1f} (entry:no_trade)")
    # Count skipped test cases explicitly; comparing sample and test-case counts is meaningless
    if skipped_cases:
        logger.warning(f" Skipped {skipped_cases} test cases due to missing data")
    if failed_cases:
        logger.warning(f" {failed_cases} test cases failed during preparation")
    return training_data
def _create_hold_samples(self, test_case: Dict, market_state: Dict, sample_interval: int = 13) -> List[Dict]:
    """
    Create HOLD training samples at regular intervals while a position is open.

    These samples teach the model to keep the position (not exit early) until the
    annotated exit. Candidate candles are the 1m candles strictly between the entry
    time and ``entry + holding_period_seconds``. Every ``sample_interval``-th candidate
    becomes a sample whose market state comes from ``_create_market_state_snapshot``.
    Snapshots with fewer than 50 closes in any timeframe are skipped.

    Args:
        test_case: Test case with ``timestamp`` (entry time) and ``expected_outcome``
            (``holding_period_seconds``, ``entry_price``, ``exit_price``, ``direction``).
        market_state: Market state whose ``timeframes['1m']`` supplies candle times and closes.
        sample_interval: Take one sample every N in-position candles (default 13).
            Values below 1 are treated as 1.

    Returns:
        List of HOLD sample dicts. The list is empty when there is no holding period,
        no 1m data, or an unparseable entry timestamp. Errors are logged, not raised.
    """
    hold_samples = []
    try:
        from datetime import timedelta

        entry_timestamp = test_case.get('timestamp')
        expected_outcome = test_case.get('expected_outcome', {})

        holding_period_seconds = expected_outcome.get('holding_period_seconds', 0)
        # `not` also covers an explicit None, which would otherwise fail in timedelta()
        if not holding_period_seconds:
            logger.debug(" No holding period, skipping HOLD samples")
            return hold_samples

        try:
            entry_time = parse_timestamp_to_utc(entry_timestamp)
        except Exception as e:
            logger.warning(f"Could not parse entry timestamp '{entry_timestamp}': {e}")
            return hold_samples
        exit_time = entry_time + timedelta(seconds=holding_period_seconds)

        timeframes = market_state.get('timeframes', {})
        if '1m' not in timeframes:
            return hold_samples
        timestamps = timeframes['1m'].get('timestamps', [])
        closes_1m = timeframes['1m'].get('close', [])

        # Trade parameters are identical for every sample; read them once
        entry_price = expected_outcome.get('entry_price', 0)
        exit_price = expected_outcome.get('exit_price')
        direction = expected_outcome.get('direction')

        # All 1m candles strictly inside the holding window
        candles_in_position = []
        for idx, ts_str in enumerate(timestamps):
            try:
                ts = parse_timestamp_to_utc(ts_str)
            except Exception as e:
                logger.debug(f"Could not parse timestamp '{ts_str}': {e}")
                continue
            if entry_time < ts < exit_time:
                candles_in_position.append((idx, ts_str))

        step = max(1, int(sample_interval))
        for idx, ts_str in candles_in_position[::step]:
            hold_market_state = self._create_market_state_snapshot(market_state, idx)
            if any(len(tf_data.get('close', [])) < 50
                   for tf_data in hold_market_state.get('timeframes', {}).values()):
                logger.debug(f" Skipping HOLD sample at idx {idx} - insufficient data")
                continue

            # Unrealized P&L (percent) at this candle
            current_price = closes_1m[idx] if idx < len(closes_1m) else entry_price
            # Guard explicit None prices: `None > 0` raises TypeError in Python 3
            if (entry_price is not None and current_price is not None
                    and entry_price > 0 and current_price > 0):
                if direction == 'LONG':
                    current_pnl = (current_price - entry_price) / entry_price * 100
                else:  # SHORT (any non-LONG direction)
                    current_pnl = (entry_price - current_price) / entry_price * 100
            else:
                current_pnl = 0.0

            hold_samples.append({
                'market_state': hold_market_state,
                'action': 'HOLD',
                'direction': direction,
                'profit_loss_pct': current_pnl,  # current unrealized P&L
                'entry_price': entry_price,
                'exit_price': exit_price,
                'timestamp': ts_str,
                'label': 'HOLD',
                'in_position': True
            })

        logger.debug(f" Created {len(hold_samples)} HOLD samples (every {sample_interval} candles, {len(candles_in_position)} total candles in position)")
    except Exception as e:
        logger.error(f"Error creating HOLD samples: {e}")
        import traceback
        logger.error(traceback.format_exc())
    return hold_samples
def _create_negative_samples(self, market_state: Dict, entry_timestamp: str,
                             exit_timestamp: Optional[str], holding_period_seconds: int,
                             samples_before: int = 5, samples_after: int = 5) -> List[Dict]:
    """
    Build NO_TRADE samples from 1m candles just before the entry and just after the exit.

    These samples show the model situations where it should stay flat, which helps
    reduce false entry signals.

    The entry candle is the first 1m candle within 60 s of ``entry_timestamp``. If none
    matches, index 0 is used. The exit candle is the first candle within 60 s of
    ``entry + holding_period_seconds``. If none matches, it is estimated as one candle
    per minute, capped at the last candle. Only candles at index >= 599 (i.e. with 600
    candles of history) are used, and every snapshot timeframe must have >= 50 closes.

    Args:
        market_state: Market state whose ``timeframes['1m']`` provides the candle timeline.
        entry_timestamp: Entry signal timestamp.
        exit_timestamp: Not used; the exit time is derived from ``holding_period_seconds``.
        holding_period_seconds: Trade duration in seconds.
        samples_before: Candles to take immediately before the entry (default 5).
        samples_after: Candles to take immediately after the exit (default 5).

    Returns:
        List of NO_TRADE sample dicts (possibly empty). Errors are logged, not raised.
    """
    collected = []
    try:
        from datetime import timedelta

        frames = market_state.get('timeframes', {})
        if '1m' not in frames:
            logger.warning("No 1m timeframe in market state, cannot create negative samples")
            return collected

        candle_times = frames['1m'].get('timestamps', [])
        if not candle_times:
            return collected

        try:
            entry_dt = parse_timestamp_to_utc(entry_timestamp)
        except Exception as e:
            logger.warning(f"Could not parse entry timestamp '{entry_timestamp}': {e}")
            return collected
        exit_dt = entry_dt + timedelta(seconds=holding_period_seconds)

        # First candle within one minute of the entry / exit; stop once both are known
        entry_pos = exit_pos = None
        for pos, raw_ts in enumerate(candle_times):
            try:
                parsed = parse_timestamp_to_utc(raw_ts)
                if entry_pos is None and abs((parsed - entry_dt).total_seconds()) < 60:
                    entry_pos = pos
                if exit_pos is None and abs((parsed - exit_dt).total_seconds()) < 60:
                    exit_pos = pos
                if None not in (entry_pos, exit_pos):
                    break
            except Exception:
                continue

        if entry_pos is None:
            logger.debug(f"Could not find entry timestamp in market data - using first candle as entry")
            entry_pos = 0

        if exit_pos is None:
            # Estimate the exit as one candle per minute of holding time
            trade_candles = int(holding_period_seconds // 60)
            exit_pos = min(entry_pos + trade_candles, len(candle_times) - 1)
            logger.debug(f" Estimated exit index: {exit_pos} ({trade_candles} candles)")

        # A usable negative candle needs 600 candles of history at or before it
        min_history = 600
        lowest_usable = min_history - 1
        total = len(candle_times)
        candidates = [('before_entry', entry_pos - k) for k in range(1, samples_before + 1)]
        candidates += [('after_exit', exit_pos + k) for k in range(1, samples_after + 1)]
        usable = [(where, pos) for where, pos in candidates if lowest_usable <= pos < total]

        for _where, pos in usable:
            snapshot = self._create_market_state_snapshot(market_state, pos)
            short_history = any(
                len(tf_payload.get('close', [])) < 50
                for tf_payload in snapshot.get('timeframes', {}).values()
            )
            if short_history:
                logger.debug(f" Skipping negative sample at idx {pos} - insufficient data")
                continue
            collected.append({
                'market_state': snapshot,
                'action': 'HOLD',         # no action
                'direction': 'NONE',
                'profit_loss_pct': 0.0,
                'entry_price': None,
                'exit_price': None,
                'timestamp': candle_times[pos],
                'label': 'NO_TRADE',      # negative label
                'in_position': False      # flat, not in a position
            })

        logger.debug(f" Created {len(collected)} NO_TRADE samples ({samples_before} before entry + {samples_after} after exit)")
    except Exception as e:
        logger.error(f"Error creating negative samples: {e}")
    return collected
def _create_market_state_snapshot(self, market_state: Dict, candle_index: int) -> Dict:
    """
    Build a point-in-time view of ``market_state`` ending at one 1m candle.

    Used by the HOLD and NO_TRADE sample builders. For each timeframe the view keeps
    at most 600 candles ending at the target point. Shorter histories are kept as-is,
    and padding is expected to happen later during feature extraction.

    The target point is ``candle_index`` in the ``'1m'`` series. Other timeframes are
    cut at the *time* of that 1m candle rather than at the same positional index,
    because index i of a 1h series is not the same moment as index i of a 1m series.
    Positional alignment is used when the anchor time cannot be determined: no 1m
    series, index out of range, or unparseable timestamps. Series are assumed to be
    in ascending time order.
    NOTE(review): a higher-timeframe candle that opened at or before the anchor is
    included in full, even though it may still have been forming at that time.

    Args:
        market_state: Dict with optional ``'symbol'`` and ``'timeframes'`` mapping each
            timeframe to parallel lists (timestamps/open/high/low/close/volume).
        candle_index: Index of the target candle in the 1m series.

    Returns:
        Dict with ``'symbol'``, ``'timestamp'`` (the target 1m timestamp, or None) and
        ``'timeframes'``. A timeframe with no candle at the target point is omitted.
        ``secondary_timeframes`` are not copied.
    """
    # Training expects up to 600 candles ending at the target point
    required_candles = 600

    snapshot = {
        'symbol': market_state.get('symbol'),
        'timestamp': None,  # set from the 1m candle below
        'timeframes': {}
    }
    timeframes = market_state.get('timeframes', {})

    # Anchor time = open time of the target 1m candle
    anchor = None
    ref_timestamps = (timeframes.get('1m') or {}).get('timestamps', [])
    if 0 <= candle_index < len(ref_timestamps):
        try:
            anchor = pd.to_datetime(ref_timestamps[candle_index], utc=True)
        except (ValueError, TypeError, OverflowError):
            anchor = None

    for tf, tf_data in timeframes.items():
        timestamps = tf_data.get('timestamps', [])
        end_idx = None

        if tf != '1m' and anchor is not None:
            try:
                parsed = pd.to_datetime(timestamps, utc=True)
                end_idx = int((parsed <= anchor).sum())
            except (ValueError, TypeError, OverflowError):
                end_idx = None  # unparseable -> fall back to positional alignment
            else:
                if end_idx == 0:
                    continue  # no candle of this timeframe existed yet at the anchor time

        if end_idx is None:
            # Positional alignment (always used for the 1m series itself)
            if candle_index >= len(timestamps):
                continue
            end_idx = candle_index + 1

        start_idx = max(0, end_idx - required_candles)
        snapshot['timeframes'][tf] = {
            'timestamps': timestamps[start_idx:end_idx],
            'open': tf_data.get('open', [])[start_idx:end_idx],
            'high': tf_data.get('high', [])[start_idx:end_idx],
            'low': tf_data.get('low', [])[start_idx:end_idx],
            'close': tf_data.get('close', [])[start_idx:end_idx],
            'volume': tf_data.get('volume', [])[start_idx:end_idx]
        }
        if tf == '1m':
            snapshot['timestamp'] = timestamps[candle_index]

    return snapshot
def _convert_to_cnn_input(self, data: Dict) -> tuple:
    """
    Turn one annotation training sample into an ``(x, y)`` tensor pair for the CNN.

    ``x`` holds the last 60 close prices of the ``'1m'`` timeframe with shape
    ``[1, 60, 1]`` (batch, seq_len, features). Shorter histories are padded at the
    front with the most recent close. ``y`` is a one-element long tensor holding the
    class index 0=HOLD, 1=BUY, 2=SELL.

    Returns ``(None, None)`` when there is no usable 1m close data or on any error.
    """
    import torch
    import numpy as np

    try:
        frames = data.get('market_state', {}).get('timeframes', {})
        if '1m' not in frames:
            logger.warning("No 1m timeframe data available for CNN training")
            return None, None

        close_series = np.array(frames['1m'].get('close', []), dtype=np.float32)
        if len(close_series) == 0:
            logger.warning("No close price data available")
            return None, None

        window = 60
        missing = window - len(close_series)
        if missing > 0:
            # Front-pad with the latest close so the series reaches the full window
            filler = np.full(missing, close_series[-1], dtype=np.float32)
            close_series = np.concatenate((filler, close_series))
        else:
            close_series = close_series[-window:]

        # Close prices only for now; a fuller version would feed all of OHLCV
        x = torch.tensor(close_series, dtype=torch.float32).reshape(1, window, 1)

        action = data.get('action', 'HOLD')
        direction = data.get('direction', 'HOLD')
        # NOTE(review): direction is checked before action, so HOLD samples
        # (action='HOLD', direction='LONG') and EXIT samples (action='CLOSE')
        # receive BUY/SELL labels from their direction. Confirm this is intended.
        if 'LONG' == direction or 'BUY' == action:
            label = 1
        elif 'SHORT' == direction or 'SELL' == action:
            label = 2
        else:
            label = 0
        y = torch.tensor([label], dtype=torch.long)

        return x, y
    except Exception as e:
        logger.error(f"Error converting to CNN input: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None, None
def _train_cnn_real(self, session: TrainingSession, training_data: List[Dict]):
    """
    Train the orchestrator's CNN model on prepared annotation samples.

    Uses the first available training entry point, in this order:
      1. ``model.train_on_annotations(training_data)``, called once per epoch;
      2. ``model.trainer.train_step(batch_x, batch_y)`` on sequential (unshuffled)
         mini-batches of 5 samples converted by ``_convert_to_cnn_input``;
      3. ``model.train_step(x, y)``, one sample at a time (logged as a fallback).

    Each epoch writes ``session.current_epoch`` and ``session.current_loss``. In path 3
    they are only updated for epochs with at least one successful step. At the end,
    ``session.final_loss`` copies ``session.current_loss`` and ``session.accuracy`` is
    set to None.

    Raises:
        Exception: if the orchestrator has no CNN model, or the model exposes none
            of the supported training methods.
    """
    if not hasattr(self.orchestrator, 'cnn_model') or not self.orchestrator.cnn_model:
        raise Exception("CNN model not available in orchestrator")
    model = self.orchestrator.cnn_model
    # Check if model has trainer attribute (EnhancedCNN)
    trainer = None
    if hasattr(model, 'trainer'):
        trainer = model.trainer
    # Use the model's actual training method
    if hasattr(model, 'train_on_annotations'):
        # Path 1: the model trains on the raw annotation samples itself
        for epoch in range(session.total_epochs):
            loss = model.train_on_annotations(training_data)
            session.current_epoch = epoch + 1
            session.current_loss = loss if loss else 0.0
            logger.info(f"CNN Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
    elif trainer and hasattr(trainer, 'train_step'):
        # Path 2: use the trainer's train_step method (EnhancedCNN)
        logger.info(f"Training CNN using trainer.train_step() with {len(training_data)} samples")
        # Convert all samples once up front; unconvertible samples are dropped
        converted_samples = []
        for data in training_data:
            x, y = self._convert_to_cnn_input(data)
            if x is not None and y is not None:
                converted_samples.append((x, y))
        logger.info(f" Converted {len(converted_samples)} valid samples")
        # Small mini-batches; samples keep their original order (no shuffling)
        cnn_batch_size = 5
        for epoch in range(session.total_epochs):
            epoch_loss = 0.0
            num_batches = 0
            # Process in mini-batches
            for i in range(0, len(converted_samples), cnn_batch_size):
                batch_samples = converted_samples[i:i + cnn_batch_size]
                # Combine samples into one batch along dim 0.
                # NOTE(review): torch.cat runs outside the try below, so a shape
                # mismatch here aborts the whole training call.
                batch_x = torch.cat([x for x, y in batch_samples], dim=0)
                batch_y = torch.cat([y for x, y in batch_samples], dim=0)
                try:
                    # Call trainer's train_step with batch
                    loss_dict = trainer.train_step(batch_x, batch_y)
                    # The trainer may return a dict of losses or a scalar
                    if isinstance(loss_dict, dict):
                        loss = loss_dict.get('total_loss', loss_dict.get('main_loss', 0.0))
                    else:
                        loss = float(loss_dict) if loss_dict else 0.0
                    epoch_loss += loss
                    num_batches += 1
                except Exception as e:
                    # A failing batch is logged and skipped; training continues
                    logger.error(f"Error in CNN training step: {e}")
                    import traceback
                    logger.error(traceback.format_exc())
                    continue
            if num_batches > 0:
                session.current_epoch = epoch + 1
                session.current_loss = epoch_loss / num_batches
                logger.info(f"CNN Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}, Batches: {num_batches}")
            else:
                logger.warning(f"CNN Epoch {epoch + 1}/{session.total_epochs}: No valid batches processed")
                session.current_epoch = epoch + 1
                session.current_loss = 0.0
    elif hasattr(model, 'train_step'):
        # Path 3: standard per-sample train_step method (fallback)
        logger.warning("Using model.train_step() directly - may not work correctly")
        for epoch in range(session.total_epochs):
            epoch_loss = 0.0
            valid_samples = 0
            for data in training_data:
                x, y = self._convert_to_cnn_input(data)
                if x is None or y is None:
                    continue
                try:
                    loss = model.train_step(x, y)
                    epoch_loss += loss if loss else 0.0
                    valid_samples += 1
                except Exception as e:
                    logger.error(f"Error in CNN training step: {e}")
                    continue
            # Session progress is only updated when at least one sample trained
            if valid_samples > 0:
                session.current_epoch = epoch + 1
                session.current_loss = epoch_loss / valid_samples
                logger.info(f"CNN Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
    else:
        raise Exception("CNN model does not have train_on_annotations, trainer.train_step, or train_step method")
    # NOTE(review): if path 3 never trained a sample, this copies whatever
    # current_loss held before the call (its initial value is set elsewhere)
    session.final_loss = session.current_loss
    # Accuracy calculated from actual training metrics, not synthetic
    session.accuracy = None  # Will be set by training loop if available
def _train_dqn_real(self, session: TrainingSession, training_data: List[Dict]):
    """
    Train the orchestrator's RL (DQN) agent on prepared annotation samples.

    When the agent has ``remember``, every sample is pushed into its replay memory as
    a one-step terminal transition:
      * state from ``_build_state_from_data``,
      * action 1 for LONG and 0 otherwise,
      * reward ``profit_loss_pct / 100`` (0 when missing or zero),
      * next_state equal to state, and done=True.

    ``agent.replay()`` is then called once per epoch, and its return value (or 0.0
    when falsy) is recorded as the epoch loss. At the end, ``session.final_loss``
    copies ``session.current_loss`` and ``session.accuracy`` is set to None.

    Raises:
        Exception: if the orchestrator has no RL agent, or the agent has no ``replay``.
    """
    if not hasattr(self.orchestrator, 'rl_agent') or not self.orchestrator.rl_agent:
        raise Exception("DQN model not available in orchestrator")
    agent = self.orchestrator.rl_agent
    # Use EnhancedRLTrainingAdapter if available for better reward calculation
    if self.enhanced_rl_adapter_available and hasattr(self.orchestrator, 'enhanced_rl_adapter'):
        logger.info("Using EnhancedRLTrainingAdapter for DQN training")
        # The enhanced adapter will handle training through its async loop
        # For now, we'll use the traditional approach but with better state building
        # (this branch only logs; the training below runs either way)
    # Add experiences to replay buffer
    for data in training_data:
        # Calculate reward from profit/loss (percent -> fraction)
        reward = data['profit_loss_pct'] / 100.0 if data.get('profit_loss_pct') else 0.0
        # Add to memory if agent has remember method
        if hasattr(agent, 'remember'):
            # Try to build proper state representation
            state = self._build_state_from_data(data, agent)
            # NOTE(review): SHORT, NONE (NO_TRADE) and missing directions all map to 0
            action = 1 if data.get('direction') == 'LONG' else 0
            # Terminal one-step transition: next_state is the same state, done=True
            agent.remember(state, action, reward, state, True)
    # Train with replay
    if hasattr(agent, 'replay'):
        for epoch in range(session.total_epochs):
            loss = agent.replay()
            session.current_epoch = epoch + 1
            session.current_loss = loss if loss else 0.0
            logger.info(f"DQN Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")
    else:
        raise Exception("DQN agent does not have replay method")
    session.final_loss = session.current_loss
    # Accuracy calculated from actual training metrics, not synthetic
    session.accuracy = None  # Will be set by training loop if available
def _build_state_from_data(self, data: Dict, agent: Any) -> List[float]:
    """
    Build a fixed-length numeric state vector for the DQN agent from one sample.

    Features, in order and each only when present and not None: ``entry_price``,
    ``exit_price``, ``profit_loss_pct``. The vector is zero-padded or truncated to
    ``agent.state_size`` (100 when the agent has no ``state_size``).

    NOTE(review): ``market_state`` is not used yet. ``exit_price`` and
    ``profit_loss_pct`` describe the trade outcome, so they leak the target into the
    state. They are kept for compatibility; revisit before relying on this state.

    Args:
        data: Training sample dict.
        agent: RL agent; only its optional ``state_size`` attribute is read.

    Returns:
        List of floats of length ``state_size``. An all-zero state is returned when a
        feature cannot be converted to float.
    """
    state_size = agent.state_size if hasattr(agent, 'state_size') else 100
    try:
        features = []
        for key in ('entry_price', 'exit_price', 'profit_loss_pct'):
            value = data.get(key)
            # Samples such as NO_TRADE carry explicit None values. Skip them rather
            # than letting float(None) throw away every other feature.
            if value is not None:
                features.append(float(value))

        # Pad or truncate to match the agent's state size
        if len(features) < state_size:
            features.extend([0.0] * (state_size - len(features)))
        else:
            features = features[:state_size]
        return features
    except Exception as e:
        logger.error(f"Error building state from data: {e}")
        # Return zero state as fallback
        return [0.0] * state_size
def _fetch_historical_from_api(self, symbol: str, timeframe: str, start_time: datetime, end_time: datetime, limit: int) -> Optional[pd.DataFrame]:
    """
    Fetch historical OHLCV candles from exchange REST APIs for a time range.

    Binance is tried first (paginated ``/api/v3/klines`` with ``startTime``/``endTime``),
    then MEXC. Callers want the candles immediately *before* ``end_time``. So when the
    range spans more than ``limit`` candles, only its most recent ``limit`` candles are
    requested, not the oldest ones.

    Args:
        symbol: Trading symbol (e.g., 'ETH/USDT')
        timeframe: One of '1m', '5m', '15m', '30m', '1h', '4h', '1d'. '1s' is not
            fetched here and returns None.
        start_time: Start of the range (UTC)
        end_time: End of the range (UTC)
        limit: Maximum number of candles to return

    Returns:
        DataFrame indexed by UTC open time with float ``open/high/low/close/volume``
        columns, ascending, at most ``limit`` rows. None if no source returned data.
    """
    import pandas as pd
    import requests
    import time

    try:
        # 1s is not fetched here; the data provider builds 1s candles from ticks
        # (_generate_1s_candles_from_ticks), so let the caller handle it.
        if timeframe == '1s':
            logger.debug(f"1s timeframe requested - will try to generate from ticks or skip")
            return None
        if limit <= 0:
            return None

        # Candle width in milliseconds per supported timeframe (used to clamp the range)
        interval_ms = {
            '1m': 60 * 1000, '5m': 5 * 60 * 1000, '15m': 15 * 60 * 1000,
            '30m': 30 * 60 * 1000, '1h': 60 * 60 * 1000, '4h': 4 * 60 * 60 * 1000,
            '1d': 24 * 60 * 60 * 1000
        }

        # Try Binance first (supports historical queries with startTime/endTime)
        try:
            binance_symbol = symbol.replace('/', '').upper()
            binance_intervals = {
                '1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m',
                '1h': '1h', '4h': '4h', '1d': '1d'
            }
            binance_timeframe = binance_intervals.get(timeframe)
            if not binance_timeframe:
                logger.warning(f"Binance doesn't support timeframe {timeframe}")
                return None

            url = "https://api.binance.com/api/v3/klines"
            start_ms = int(start_time.timestamp() * 1000)
            end_ms = int(end_time.timestamp() * 1000)
            # Request only the tail of the range. [end - limit*width, end] holds at most
            # limit + 1 candle open times, so fetching limit + 1 and keeping the last
            # `limit` yields the candles closest to end_time.
            start_ms = max(start_ms, end_ms - limit * interval_ms[timeframe])
            wanted = limit + 1

            # Binance max is 1000 per request, so paginate if needed
            all_data = []
            current_start = start_ms
            max_per_request = 1000
            max_requests = 10  # Safety limit
            request_count = 0

            while current_start < end_ms and request_count < max_requests:
                batch_limit = min(max_per_request, wanted - len(all_data))
                params = {
                    'symbol': binance_symbol,
                    'interval': binance_timeframe,
                    'startTime': current_start,
                    'endTime': end_ms,
                    'limit': batch_limit
                }
                logger.debug(f"Fetching from Binance: {symbol} {timeframe} batch {request_count + 1} (start: {current_start}, end: {end_ms})")
                response = requests.get(url, params=params, timeout=10)
                request_count += 1  # count every request actually sent
                if response.status_code != 200:
                    logger.debug(f"Binance API returned {response.status_code} for {symbol} {timeframe}")
                    break
                data = response.json()
                if not data:
                    break
                all_data.extend(data)
                # Next page starts right after the last candle's close_time (index 6)
                current_start = data[-1][6] + 1
                # A short page means the range is exhausted; otherwise stop once we have enough
                if len(data) < batch_limit or len(all_data) >= wanted:
                    break
                # Small delay to avoid rate limiting
                time.sleep(0.1)

            if all_data:
                df = pd.DataFrame(all_data, columns=[
                    'timestamp', 'open', 'high', 'low', 'close', 'volume',
                    'close_time', 'quote_volume', 'trades', 'taker_buy_base',
                    'taker_buy_quote', 'ignore'
                ])
                df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
                for col in ['open', 'high', 'low', 'close', 'volume']:
                    df[col] = df[col].astype(float)
                # Keep only OHLCV columns, indexed by timestamp
                df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
                df = df.set_index('timestamp').sort_index()
                # Remove duplicates and keep the latest 'limit' candles
                df = df[~df.index.duplicated(keep='last')]
                df = df.tail(limit)
                logger.info(f"Binance API: Fetched {len(df)} candles for {symbol} {timeframe} (historical, {request_count} requests)")
                return df
            logger.warning(f"Binance API returned no data for {symbol} {timeframe}")
        except Exception as e:
            logger.debug(f"Binance fetch failed: {e}")

        # Fallback to MEXC
        try:
            mexc_symbol = symbol.replace('/', '').upper()
            # MEXC spells the hourly interval '60m'; '1h' is not a valid MEXC interval
            mexc_intervals = {
                '1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m',
                '1h': '60m', '4h': '4h', '1d': '1d'
            }
            mexc_timeframe = mexc_intervals.get(timeframe)
            if not mexc_timeframe:
                logger.warning(f"MEXC doesn't support timeframe {timeframe}")
                return None

            url = "https://api.mexc.com/api/v3/klines"
            end_ms = int(end_time.timestamp() * 1000)
            start_ms = max(int(start_time.timestamp() * 1000), end_ms - limit * interval_ms[timeframe])
            params = {
                'symbol': mexc_symbol,
                'interval': mexc_timeframe,
                'startTime': start_ms,
                'endTime': end_ms,
                # Extra headroom so the local range filter below still finds candles
                'limit': min(limit * 2, 1000)
            }
            logger.debug(f"Fetching from MEXC: {symbol} {timeframe}")
            response = requests.get(url, params=params, timeout=10)
            if response.status_code == 200:
                data = response.json()
                if data:
                    df = pd.DataFrame(data, columns=[
                        'timestamp', 'open', 'high', 'low', 'close', 'volume',
                        'close_time', 'quote_volume'
                    ])
                    df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
                    for col in ['open', 'high', 'low', 'close', 'volume']:
                        df[col] = df[col].astype(float)
                    df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
                    df = df.set_index('timestamp').sort_index()
                    # Filter to the requested time range and keep the latest candles
                    df = df[(df.index >= start_time) & (df.index <= end_time)]
                    df = df.tail(limit)
                    if len(df) > 0:
                        logger.info(f"MEXC API: Fetched {len(df)} candles for {symbol} {timeframe} (historical, filtered)")
                        return df
                    logger.warning(f"MEXC API: No candles in time range for {symbol} {timeframe}")
                else:
                    logger.warning(f"MEXC API returned empty data for {symbol} {timeframe}")
            else:
                logger.debug(f"MEXC API returned {response.status_code} for {symbol} {timeframe}")
        except Exception as e:
            logger.debug(f"MEXC fetch failed: {e}")

        return None
    except Exception as e:
        logger.error(f"Error fetching historical data from API: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return None
def _extract_timeframe_data(self, tf_data: Dict, target_seq_len: int = 600) -> Optional[torch.Tensor]:
"""
Extract and normalize OHLCV data from a single timeframe
Args:
tf_data: Timeframe data dictionary with 'open', 'high', 'low', 'close', 'volume'
target_seq_len: Target sequence length (default 600)
Returns:
Tensor of shape [1, seq_len, 5] or None if no data
"""
import torch
import numpy as np
try:
# Extract OHLCV arrays
opens = np.array(tf_data.get('open', []), dtype=np.float32)
highs = np.array(tf_data.get('high', []), dtype=np.float32)
lows = np.array(tf_data.get('low', []), dtype=np.float32)
closes = np.array(tf_data.get('close', []), dtype=np.float32)
volumes = np.array(tf_data.get('volume', []), dtype=np.float32)
if len(closes) == 0:
return None
# ALLOW PADDING: If we have fewer than target_seq_len, pad with the first available value
if len(closes) < target_seq_len:
logger.debug(f"Padding {target_seq_len - len(closes)} candles for timeframe (have {len(closes)}, need {target_seq_len})")
pad_len = target_seq_len - len(closes)
# Pad at the beginning with the first available value (edge padding)
opens = np.pad(opens, (pad_len, 0), mode='edge')
highs = np.pad(highs, (pad_len, 0), mode='edge')
lows = np.pad(lows, (pad_len, 0), mode='edge')
closes = np.pad(closes, (pad_len, 0), mode='edge')
volumes = np.pad(volumes, (pad_len, 0), mode='edge')
else:
# Take last target_seq_len candles if we have more than needed
opens = opens[-target_seq_len:]
highs = highs[-target_seq_len:]
lows = lows[-target_seq_len:]
closes = closes[-target_seq_len:]
volumes = volumes[-target_seq_len:]
# Validate we now have exactly target_seq_len
if len(closes) != target_seq_len:
logger.warning(f"Extraction failed: got {len(closes)} candles after padding, need {target_seq_len}")
return None
# Stack OHLCV [seq_len, 5]
ohlcv = np.stack([opens, highs, lows, closes, volumes], axis=-1)
# Normalize prices to [0, 1] range
price_min = np.min(ohlcv[:, :4]) # Min of OHLC
price_max = np.max(ohlcv[:, :4]) # Max of OHLC
if price_max > price_min:
ohlcv[:, :4] = (ohlcv[:, :4] - price_min) / (price_max - price_min)
# Normalize volume to [0, 1] range
volume_min = np.min(ohlcv[:, 4])
volume_max = np.max(ohlcv[:, 4])
if volume_max > volume_min:
ohlcv[:, 4] = (ohlcv[:, 4] - volume_min) / (volume_max - volume_min)
# Convert to tensor and add batch dimension [1, seq_len, 5]
# STORE normalization parameters for denormalization
# Return tuple: (tensor, normalization_params)
norm_params = {
'price_min': float(price_min),
'price_max': float(price_max),
'volume_min': float(volume_min),
'volume_max': float(volume_max)
}
return torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0), norm_params
except Exception as e:
logger.error(f"Error extracting timeframe data: {e}")
return None
def _extract_next_candle(self, tf_data: Dict, norm_params: Dict = None) -> Optional[torch.Tensor]:
"""
Extract the NEXT candle OHLCV (after the current sequence) as training target
This extracts the candle that comes immediately after the sequence used for input.
Normalized using the ORIGINAL price range (not the already-normalized sequence data).
Args:
tf_data: Timeframe data dictionary with 'open', 'high', 'low', 'close', 'volume'
norm_params: Normalization parameters dict with 'price_min', 'price_max', 'volume_min', 'volume_max'
Returns:
Tensor of shape [1, 5] representing next candle OHLCV, or None if not available
"""
import torch
import numpy as np
try:
# Extract OHLCV arrays - get the LAST value as the "next" candle
# In annotation context, the "current" sequence is historical, and we have the "next" candle
opens = np.array(tf_data.get('open', []), dtype=np.float32)
highs = np.array(tf_data.get('high', []), dtype=np.float32)
lows = np.array(tf_data.get('low', []), dtype=np.float32)
closes = np.array(tf_data.get('close', []), dtype=np.float32)
volumes = np.array(tf_data.get('volume', []), dtype=np.float32)
if len(closes) == 0:
return None
# Get the last candle as the "next" candle target
# This assumes the timeframe data includes one extra candle after the sequence
next_open = opens[-1]
next_high = highs[-1]
next_low = lows[-1]
next_close = closes[-1]
next_volume = volumes[-1]
# Create OHLCV array [5]
next_candle = np.array([next_open, next_high, next_low, next_close, next_volume], dtype=np.float32)
# CRITICAL FIX: Normalize using ORIGINAL price bounds, not already-normalized data
# The bug was: reference_data was already normalized to [0,1], so its min/max
# would be ~0 and ~1, which when used to normalize raw prices ($3000+) created
# astronomically large values (e.g., $3000 / 1.0 = still $3000 in "normalized" space!)
if norm_params is not None:
# Use ORIGINAL normalization bounds
price_min = norm_params.get('price_min', 0.0)
price_max = norm_params.get('price_max', 1.0)
if price_max > price_min:
next_candle[:4] = (next_candle[:4] - price_min) / (price_max - price_min)
volume_min = norm_params.get('volume_min', 0.0)
volume_max = norm_params.get('volume_max', 1.0)
if volume_max > volume_min:
next_candle[4] = (next_candle[4] - volume_min) / (volume_max - volume_min)
else:
# If no reference, normalize relative to current candle's close
if next_close > 0:
next_candle[:4] = next_candle[:4] / next_close
# Volume normalized to 0-1 range (simple min-max with self)
if next_volume > 0:
next_candle[4] = 1.0
# Return as [1, 5] tensor
return torch.tensor(next_candle, dtype=torch.float32).unsqueeze(0)
except Exception as e:
logger.error(f"Error extracting next candle: {e}")
return None
def _convert_annotation_to_transformer_batch(self, training_sample: Dict) -> Optional[Dict[str, Any]]:
    """
    Convert one annotation training sample into a multi-timeframe transformer batch.

    Reads ``training_sample['market_state']``:
      - ``timeframes``: raw OHLCV dicts keyed by timeframe. '1m', '1h' and '1d'
        are required, each with at least 50 closes. '1s' is optional and is
        dropped when it has fewer than 50 closes or no 'close' data.
      - ``secondary_timeframes['BTC/USDT']['1m']``: optional BTC reference.
      - ``pivot_markers``: optional 'support_levels' / 'resistance_levels'.
    Label fields read from the sample: 'label', 'direction', 'in_position',
    'entry_price', 'exit_price' (or ``expected_outcome['exit_price']``),
    'profit_loss_pct', 'timestamp', 'symbol', and 'test_case_id' (logging only).

    Every price window is padded/truncated to 600 candles and min-max
    normalized by ``_extract_timeframe_data``. The hand-crafted tech/market
    features are computed from the raw 1m candles.

    Returns:
        Dict with batch size 1:
          - 'price_data_1s' / '_1m' / '_1h' / '_1d': [1, 600, 5] ('1s' may be None)
          - 'btc_data_1m': [1, 600, 5] or None
          - 'cob_data': [1, 600, 100] zeros (placeholder)
          - 'tech_data': [1, 40]; 'market_data': [1, 30]; 'position_state': [1, 5]
          - 'actions': [1] long (0 = HOLD, 1 = BUY, 2 = SELL)
          - 'future_prices': [1, 1] exit-price change ratio vs. the last raw 1m close
          - 'trade_success': [1, 1] (1.0 if profit_loss_pct > 0)
          - 'trend_target': [1, 3] (angle, steepness, direction)
          - 'future_candle_1s' / '_1m' / '_1h' / '_1d': [1, 5] or None
          - 'norm_params': raw-unit normalization bounds per timeframe ('btc' for BTC)
          - 'price_data': alias of 'price_data_1m' (legacy)
          - 'metadata': 'current_price' (raw last 1m close),
            'current_price_normalized', 'timestamp', 'symbol'
        Returns None if required data is missing or conversion fails.
    """
    import math

    def _sign(value: float) -> float:
        """Return 1.0, -1.0 or 0.0 according to the sign of value."""
        return 1.0 if value > 0 else (-1.0 if value < 0 else 0.0)

    try:
        market_state = training_sample.get('market_state', {})
        timeframes = market_state.get('timeframes', {})
        secondary_timeframes = market_state.get('secondary_timeframes', {})

        # Every timeframe is padded/truncated to target_seq_len candles; a required
        # timeframe with fewer than min_required_candles closes rejects the sample.
        target_seq_len = 600
        min_required_candles = 50

        # ---- Validate required timeframes (1m, 1h, 1d) ----
        for tf_name in ('1m', '1h', '1d'):
            if tf_name not in timeframes:
                logger.warning(f"Required timeframe {tf_name} not found in data")
                return None
            tf_data = timeframes[tf_name]
            if not tf_data or 'close' not in tf_data:
                logger.warning(f"Required timeframe {tf_name} missing data")
                return None
            num_closes = len(tf_data['close'])
            if num_closes < min_required_candles:
                logger.warning(f"Required timeframe {tf_name} has only {num_closes} candles (need at least {min_required_candles})")
                return None
            if num_closes < target_seq_len:
                logger.debug(f"Timeframe {tf_name} has {num_closes} candles, will pad to {target_seq_len}")

        # ---- Validate optional 1s timeframe; drop it when unusable ----
        if '1s' in timeframes:
            tf_data = timeframes['1s']
            if tf_data and 'close' in tf_data:
                num_closes = len(tf_data['close'])
                if num_closes < min_required_candles:
                    logger.warning(f"Optional timeframe 1s has only {num_closes} candles (need at least {min_required_candles}), excluding it")
                    timeframes = {k: v for k, v in timeframes.items() if k != '1s'}
                elif num_closes < target_seq_len:
                    logger.debug(f"Timeframe 1s has {num_closes} candles, will pad to {target_seq_len}")
            else:
                logger.warning("Optional timeframe 1s has invalid data, excluding it")
                timeframes = {k: v for k, v in timeframes.items() if k != '1s'}

        # ---- Normalized price windows plus their raw-unit bounds ----
        norm_params_dict = {}

        price_data_1s = None
        result_1s = self._extract_timeframe_data(timeframes['1s'], target_seq_len) if '1s' in timeframes else None
        if result_1s:
            price_data_1s, norm_params_dict['1s'] = result_1s
            logger.debug(f"Included optional 1s timeframe with {price_data_1s.shape[1]} candles")
        else:
            logger.debug("Optional 1s timeframe not available (this is OK)")

        required_windows = {}
        for tf_name in ('1m', '1h', '1d'):
            result = self._extract_timeframe_data(timeframes[tf_name], target_seq_len)
            if not result:
                logger.warning(f"Missing or insufficient {tf_name} data for transformer batch (sample: {training_sample.get('test_case_id')})")
                return None
            required_windows[tf_name], norm_params_dict[tf_name] = result
        price_data_1m = required_windows['1m']
        price_data_1h = required_windows['1h']
        price_data_1d = required_windows['1d']

        btc_data_1m = None
        if 'BTC/USDT' in secondary_timeframes and '1m' in secondary_timeframes['BTC/USDT']:
            result_btc = self._extract_timeframe_data(secondary_timeframes['BTC/USDT']['1m'], target_seq_len)
            if result_btc:
                btc_data_1m, norm_params_dict['btc'] = result_btc

        # Defensive shape checks (padding should already guarantee target_seq_len).
        for tf_name, tf_tensor in (('1m', price_data_1m), ('1h', price_data_1h), ('1d', price_data_1d)):
            shape = tf_tensor.shape
            if len(shape) != 3 or shape[1] != target_seq_len:
                logger.warning(f"Timeframe {tf_name} has invalid shape {shape} (need [1, {target_seq_len}, 5])")
                return None
        if price_data_1s is not None:
            shape = price_data_1s.shape
            if len(shape) != 3 or shape[1] != target_seq_len:
                logger.warning(f"Optional timeframe 1s has invalid shape {shape}, removing it")
                price_data_1s = None

        # ---- Hand-crafted features from the raw (un-normalized) 1m candles ----
        ref_tf_data = timeframes['1m']
        closes = np.array(ref_tf_data.get('close', []), dtype=np.float32)
        volumes = np.array(ref_tf_data.get('volume', []), dtype=np.float32)
        highs = np.array(ref_tf_data.get('high', []), dtype=np.float32)
        lows = np.array(ref_tf_data.get('low', []), dtype=np.float32)
        last_close = closes[-1]  # validated above: at least min_required_candles closes

        # Order-book features are not available for annotations: zero placeholder.
        cob_data = torch.zeros(1, target_seq_len, 100, dtype=torch.float32)

        # tech_data: [price vs SMA-20, last return, 20-candle volatility, 0-padding...]
        tech_features = [0.0, 0.0, 0.0]
        if len(closes) >= 20:
            window = closes[-20:]
            tech_features[0] = closes[-1] / np.mean(window) - 1.0
            tech_features[2] = np.std(window) / np.mean(window)
        if len(closes) >= 2 and closes[-2] != 0:
            tech_features[1] = (closes[-1] - closes[-2]) / closes[-2]
        tech_features.extend([0.0] * (40 - len(tech_features)))
        tech_data = torch.tensor([tech_features[:40]], dtype=torch.float32)  # [1, 40]

        market_features = []
        # Last volume relative to its 20-candle mean; 1.0 when undefined.
        if len(volumes) >= 20:
            mean_volume = np.mean(volumes[-20:])
            market_features.append(volumes[-1] / mean_volume if mean_volume > 0 else 1.0)
        else:
            market_features.append(1.0)
        # 20-candle high-low range relative to the last close.
        if len(highs) >= 20 and len(lows) >= 20:
            market_features.append((np.max(highs[-20:]) - np.min(lows[-20:])) / last_close)
        else:
            market_features.append(0.0)
        # Floor pivots from the last candle, as signed distances from the close.
        if len(highs) >= 5 and len(lows) >= 5:
            pivot = (highs[-1] + lows[-1] + last_close) / 3.0
            r1 = 2 * pivot - lows[-1]
            s1 = 2 * pivot - highs[-1]
            market_features.extend([
                (last_close - pivot) / last_close,
                (last_close - r1) / last_close,
                (last_close - s1) / last_close,
            ])
        else:
            market_features.extend([0.0, 0.0, 0.0])
        # Count of annotated support/resistance levels within 2% of the last close.
        pivot_markers = market_state.get('pivot_markers', {})
        if pivot_markers:
            num_support = sum(1 for p in pivot_markers.get('support_levels', []) if abs(p - last_close) / last_close < 0.02)
            num_resistance = sum(1 for p in pivot_markers.get('resistance_levels', []) if abs(p - last_close) / last_close < 0.02)
            market_features.extend([float(num_support), float(num_resistance)])
        else:
            market_features.extend([0.0, 0.0])
        market_features.extend([0.0] * (30 - len(market_features)))
        market_data = torch.tensor([market_features[:30]], dtype=torch.float32)  # [1, 30]

        # ---- Action label: 0 = HOLD/NO_TRADE, 1 = BUY (LONG), 2 = SELL (SHORT) ----
        action_label = training_sample.get('label', 'TRADE')
        direction = training_sample.get('direction', 'NONE')
        in_position = training_sample.get('in_position', False)
        if action_label in ('NO_TRADE', 'HOLD'):
            action = 0
        elif action_label == 'EXIT':
            # Exiting trades the opposite side: sell to close a LONG, buy to close a SHORT.
            action = 2 if direction == 'LONG' else (1 if direction == 'SHORT' else 0)
        else:
            # 'ENTRY' and any other label follow the trade direction.
            action = 1 if direction == 'LONG' else (2 if direction == 'SHORT' else 0)
        actions = torch.tensor([action], dtype=torch.long)

        # ---- Position state (raw prices) ----
        entry_price = training_sample.get('entry_price') or 0.0
        current_price = last_close
        position_pnl = 0.0
        if in_position and entry_price > 0:
            if direction == 'LONG':
                position_pnl = (current_price - entry_price) / entry_price
            elif direction == 'SHORT':
                position_pnl = (entry_price - current_price) / entry_price
        # Placeholder: no entry timestamp is available here, so in-position HOLD
        # samples get a nominal 1 minute and everything else gets 0.
        time_in_position_minutes = 1.0 if (in_position and action_label == 'HOLD') else 0.0
        has_position = 1.0 if in_position else 0.0
        position_state = torch.tensor([
            has_position,                     # has_position
            position_pnl,                     # unrealized PnL ratio
            has_position,                     # position_size (1.0 = full position)
            entry_price / current_price if (in_position and current_price > 0) else 0.0,  # entry vs current
            time_in_position_minutes / 60.0,  # time in position, hours
        ], dtype=torch.float32).unsqueeze(0)  # [1, 5]

        # ---- Price-change target: exit price relative to the raw last close ----
        exit_price = training_sample.get('exit_price')
        if exit_price is None:
            # LivePivotTrainer nests the exit price under 'expected_outcome'.
            expected_outcome = training_sample.get('expected_outcome', {})
            if isinstance(expected_outcome, dict):
                exit_price = expected_outcome.get('exit_price')
        if exit_price and current_price > 0:
            future_price_ratio = (exit_price - current_price) / current_price
        else:
            future_price_ratio = 0.0
        future_prices = torch.tensor([[future_price_ratio]], dtype=torch.float32)  # [1, 1]

        profit_loss_pct = training_sample.get('profit_loss_pct') or 0.0
        trade_success = torch.tensor([[1.0 if profit_loss_pct > 0 else 0.0]], dtype=torch.float32)  # [1, 1]

        # ---- Trend target [angle, steepness, direction] on the 1m timeframe ----
        # Current and next close both come from 1m so they share one normalization range.
        # NOTE(review): _extract_next_candle returns the LAST candle of the same raw
        # arrays whose last candle _extract_timeframe_data keeps in the input window,
        # so unless the 1m data carries an extra candle beyond the model input the
        # "next" close equals the current close up to float rounding and this target
        # is ~0. Confirm against the data pipeline.
        current_price_norm = price_data_1m[0, -1, 3].item()
        future_candle_1m = self._extract_next_candle(timeframes['1m'], norm_params_dict['1m'])
        future_price_norm = future_candle_1m[0, 3].item() if future_candle_1m is not None else None
        trend_angle = trend_steepness = trend_direction = 0.0
        if future_price_norm is not None and current_price_norm > 0:
            # One candle ahead: angle = atan(percent change); steepness = |percent change| capped at 1.
            price_delta = future_price_norm - current_price_norm
            trend_angle = math.atan2(price_delta, current_price_norm / 100.0)
            trend_steepness = min(abs(price_delta / current_price_norm) * 100.0, 1.0)
            trend_direction = _sign(price_delta)
        elif price_data_1m.shape[1] >= 5:
            # Fallback: trend over the last 5 normalized 1m closes (4 candle steps).
            recent_closes = price_data_1m[0, -5:, 3]
            price_start = recent_closes[0].item()
            price_delta = recent_closes[-1].item() - price_start
            if price_start > 0:
                trend_angle = math.atan2(price_delta, 4.0 * price_start / 100.0)
                trend_steepness = min(abs(price_delta / price_start) * 100.0, 1.0)
                trend_direction = _sign(price_delta)
        trend_target = torch.tensor([[trend_angle, trend_steepness, trend_direction]], dtype=torch.float32)  # [1, 3]

        # ---- Next-candle OHLCV targets, normalized with each window's raw bounds ----
        future_candle_1s = None
        if price_data_1s is not None and '1s' in timeframes and '1s' in norm_params_dict:
            future_candle_1s = self._extract_next_candle(timeframes['1s'], norm_params_dict['1s'])
        future_candle_1h = self._extract_next_candle(timeframes['1h'], norm_params_dict['1h'])
        future_candle_1d = self._extract_next_candle(timeframes['1d'], norm_params_dict['1d'])

        return {
            # Multi-timeframe price data (inputs)
            'price_data_1s': price_data_1s,
            'price_data_1m': price_data_1m,
            'price_data_1h': price_data_1h,
            'price_data_1d': price_data_1d,
            'btc_data_1m': btc_data_1m,
            # Other features
            'cob_data': cob_data,
            'tech_data': tech_data,
            'market_data': market_data,
            'position_state': position_state,
            # Training targets: actions and prices
            'actions': actions,
            'future_prices': future_prices,
            'trade_success': trade_success,
            'trend_target': trend_target,
            # Training targets: next candle OHLCV per timeframe
            'future_candle_1s': future_candle_1s,
            'future_candle_1m': future_candle_1m,
            'future_candle_1h': future_candle_1h,
            'future_candle_1d': future_candle_1d,
            # Raw-unit bounds for denormalization ('1s', '1m', '1h', '1d', 'btc')
            'norm_params': norm_params_dict,
            # Legacy alias
            'price_data': price_data_1m,
            # Metadata for prediction visualization
            'metadata': {
                'current_price': float(current_price),
                'current_price_normalized': float(current_price_norm),
                'timestamp': training_sample.get('timestamp', datetime.now()),
                'symbol': training_sample.get('symbol', 'ETH/USDT')
            }
        }
    except Exception as e:
        logger.error(f"Error converting annotation to transformer batch: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return None
def _find_best_checkpoint(self, checkpoint_dir: str, metric: str = 'accuracy') -> Optional[str]:
"""Find the best checkpoint based on a metric"""
try:
if not os.path.exists(checkpoint_dir):
return None
checkpoints = []
for filename in os.listdir(checkpoint_dir):
if filename.endswith('.pt'):
filepath = os.path.join(checkpoint_dir, filename)
try:
checkpoint = torch.load(filepath, map_location='cpu')
checkpoints.append({
'path': filepath,
'metric_value': checkpoint.get(metric, 0),
'epoch': checkpoint.get('epoch', 0)
})
except Exception as e:
logger.debug(f"Could not load checkpoint {filename}: {e}")
if not checkpoints:
return None
# Sort by metric (higher is better for accuracy)
checkpoints.sort(key=lambda x: x['metric_value'], reverse=True)
return checkpoints[0]['path']
except Exception as e:
logger.error(f"Error finding best checkpoint: {e}")
return None
def _cleanup_old_checkpoints(self, checkpoint_dir: str, keep_best: int = 5, metric: str = 'accuracy'):
"""Keep only the best N checkpoints"""
try:
if not os.path.exists(checkpoint_dir):
return
import time
# Add small delay to ensure files are fully written
time.sleep(0.5)
checkpoints = []
for filename in os.listdir(checkpoint_dir):
if filename.endswith('.pt'):
filepath = os.path.join(checkpoint_dir, filename)
# Check if file exists and is not being written
if not os.path.exists(filepath):
continue
try:
checkpoint = torch.load(filepath, map_location='cpu')
checkpoints.append({
'path': filepath,
'metric_value': checkpoint.get(metric, 0),
'epoch': checkpoint.get('epoch', 0)
})
except Exception as e:
logger.debug(f"Could not load checkpoint {filename}: {e}")
# Sort by metric (higher is better)
checkpoints.sort(key=lambda x: x['metric_value'], reverse=True)
# Delete checkpoints beyond keep_best
for checkpoint in checkpoints[keep_best:]:
try:
# Double-check file still exists before deleting
if os.path.exists(checkpoint['path']):
os.remove(checkpoint['path'])
logger.debug(f"Removed old checkpoint: {checkpoint['path']}")
except Exception as e:
logger.debug(f"Could not remove checkpoint: {e}")
except Exception as e:
logger.error(f"Error cleaning up checkpoints: {e}")
def _train_transformer_real(self, session: TrainingSession, training_data: List[Dict]):
"""
Train Transformer model using orchestrator's existing training infrastructure
Uses the orchestrator's primary_transformer_trainer which already has
all the training logic implemented!
"""
if not hasattr(self.orchestrator, 'primary_transformer') or not self.orchestrator.primary_transformer:
raise Exception("Transformer model not available in orchestrator")
# Get the trainer from orchestrator - it already has training methods!
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if not trainer:
raise Exception("Transformer trainer not available in orchestrator")
logger.info(f"Using orchestrator's TradingTransformerTrainer")
logger.info(f" Trainer type: {type(trainer).__name__}")
# Import torch at function level (not inside try block)
import torch
import gc
# Initialize memory guard (50GB limit)
from utils.memory_guard import get_memory_guard, log_memory_usage
memory_guard = get_memory_guard(max_memory_gb=50.0, warning_threshold=0.85, auto_start=True)
# Register cleanup callback
def training_cleanup():
"""Cleanup callback for memory guard"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
memory_guard.register_cleanup_callback(training_cleanup)
log_memory_usage("Training start - ")
# Load best checkpoint if available to continue training
try:
checkpoint_dir = "models/checkpoints/transformer"
best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
if best_checkpoint_path and os.path.exists(best_checkpoint_path):
checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
trainer.model.load_state_dict(checkpoint['model_state_dict'])
trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
logger.info(f" Loaded checkpoint from epoch {checkpoint.get('epoch', 0)}")
logger.info(f" Previous best: Loss={checkpoint.get('loss', 0):.6f}, Accuracy={checkpoint.get('accuracy', 0):.2%}")
# CRITICAL: Delete checkpoint immediately to free memory
del checkpoint
gc.collect()
else:
logger.info(" No previous checkpoint found, starting fresh")
except Exception as e:
logger.warning(f" Failed to load checkpoint: {e}")
logger.info(" Starting with fresh model weights")
# Use the trainer's train_step method for individual samples
if hasattr(trainer, 'train_step'):
logger.info(" Using trainer.train_step() method")
logger.info(" Converting annotation data to transformer format...")
import torch
# OPTIMIZATION: Pre-convert batches ONCE and move to GPU immediately
# This eliminates CPU→GPU transfer bottleneck during training
logger.info(" Pre-converting batches and moving to GPU (one-time operation)...")
use_gpu = torch.cuda.is_available()
device = trainer.device if hasattr(trainer, 'device') else torch.device('cuda' if use_gpu else 'cpu')
if use_gpu:
logger.info(f" GPU available: {torch.cuda.get_device_name(0)}")
logger.info(f" GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
logger.info(f" Pre-moving batches to GPU for maximum efficiency")
# Log initial GPU status
try:
from utils.gpu_monitor import get_gpu_monitor
gpu_monitor = get_gpu_monitor()
gpu_monitor.log_gpu_status("Initial GPU status")
except Exception as e:
logger.debug(f"GPU monitor not available: {e}")
# Convert and move batches to GPU immediately
cached_batches = []
for i, data in enumerate(training_data):
batch = self._convert_annotation_to_transformer_batch(data)
if batch is not None:
# CRITICAL: Validate that ALL required timeframes are present
# REQUIRED: 1m, 1h, 1d (each with 600 candles)
# OPTIONAL: 1s (if available, include with 600 candles)
required_tf_keys = ['price_data_1m', 'price_data_1h', 'price_data_1d']
optional_tf_keys = ['price_data_1s']
missing_tfs = [tf for tf in required_tf_keys if batch.get(tf) is None]
if missing_tfs:
logger.warning(f" Skipping sample {i+1}: Missing required timeframes: {missing_tfs}")
continue
# Validate each required timeframe has correct shape [1, 600, 5]
for tf_key in required_tf_keys:
tf_data = batch.get(tf_key)
if tf_data is not None:
if not isinstance(tf_data, torch.Tensor):
logger.warning(f" Skipping sample {i+1}: {tf_key} is not a tensor")
missing_tfs.append(tf_key)
break
shape = tf_data.shape
if len(shape) != 3 or shape[1] != 600: # Must be [batch, seq_len, features] with seq_len == 600 (padded if needed)
logger.warning(f" Skipping sample {i+1}: {tf_key} has invalid shape {shape} (need [1, 600, 5])")
missing_tfs.append(tf_key)
break
else:
logger.warning(f" Skipping sample {i+1}: {tf_key} is None")
missing_tfs.append(tf_key)
break
if missing_tfs:
continue
# Validate optional 1s timeframe if present
if batch.get('price_data_1s') is not None:
tf_data = batch.get('price_data_1s')
if isinstance(tf_data, torch.Tensor):
shape = tf_data.shape
if len(shape) != 3 or shape[1] != 600:
logger.warning(f" Sample {i+1}: price_data_1s has invalid shape {shape}, removing it")
batch['price_data_1s'] = None
logger.debug(f" Sample {i+1}: All required timeframes present (1m, 1h, 1d), 1s={'present' if batch.get('price_data_1s') is not None else 'not available'}")
# Move batch to GPU immediately with pinned memory for faster transfer
if use_gpu:
batch_gpu = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
# Use pinned memory for faster CPU→GPU transfer
if v.device.type == 'cpu':
batch_gpu[k] = v.pin_memory().to(device, non_blocking=True)
else:
batch_gpu[k] = v.to(device, non_blocking=True)
else:
batch_gpu[k] = v
cached_batches.append(batch_gpu)
# Free CPU memory immediately
del batch
else:
cached_batches.append(batch)
# Show progress every 10 batches
if (i + 1) % 10 == 0 or i == 0:
logger.info(f" Processed {i + 1}/{len(training_data)} batches...")
else:
logger.warning(f" Failed to convert sample {i+1}")
# Clear training_data to free memory
training_data.clear()
del training_data
gc.collect()
# Synchronize GPU transfers
if use_gpu:
torch.cuda.synchronize()
logger.info(f" Converted {len(cached_batches)} batches, all moved to GPU")
# Helper function to combine multiple single-sample batches into a mini-batch
def _combine_transformer_batches(batch_list: List[Dict]) -> Dict:
"""Combine multiple single-sample batches into one mini-batch"""
if len(batch_list) == 1:
return batch_list[0]
combined = {}
# Get all keys from first batch
keys = batch_list[0].keys()
for key in keys:
# Collect tensors, filtering out None values
tensors = []
for b in batch_list:
if key in b and b[key] is not None and isinstance(b[key], torch.Tensor):
tensors.append(b[key])
if tensors:
# Concatenate along batch dimension (dim=0)
combined[key] = torch.cat(tensors, dim=0)
elif key in batch_list[0]:
# For non-tensor values (like norm_params dict), use first batch's value
# Or None if all batches have None for this key
first_value = batch_list[0].get(key)
if first_value is not None and not isinstance(first_value, torch.Tensor):
combined[key] = first_value
else:
# Check if all batches have None for this key
all_none = all(b.get(key) is None for b in batch_list)
if not all_none:
# Some batches have this key, use first non-None
for b in batch_list:
if b.get(key) is not None:
combined[key] = b[key]
break
else:
combined[key] = None
return combined
# Group batches into mini-batches for better GPU utilization
# DISABLED: Batches have inconsistent sequence lengths, process individually
# transformer_batch_size = 5
total_samples = len(cached_batches) # Store count before clearing
grouped_batches = []
# Process each batch individually to avoid shape mismatch errors
logger.info(f" Processing {len(cached_batches)} batches individually (no grouping due to variable sequence lengths)")
for batch in cached_batches:
grouped_batches.append(batch)
# Don't clear cached_batches yet - grouped_batches contains references to them
# We'll clear after training completes
# cached_batches.clear()
# del cached_batches
# gc.collect()
def batch_generator():
"""
Yield grouped mini-batches (already on GPU)
OPTIMIZATION: Batches are already on GPU and grouped for efficient processing.
Each mini-batch contains 5 samples for better GPU utilization.
CRITICAL FIX: Clone tensors for each epoch to avoid autograd version conflicts.
When the same tensor is used across multiple forward passes, operations like
.contiguous() and .view() modify the tensor's version number, breaking backprop.
"""
for batch in grouped_batches:
# CRITICAL: Clone all tensors to avoid version conflicts across epochs
# This prevents "modified by an inplace operation" errors during backward pass
batch_copy = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
# Clone tensor to create independent copy with fresh version number
batch_copy[k] = v.clone()
else:
batch_copy[k] = v
yield batch_copy
total_batches = len(grouped_batches)
if total_batches == 0:
logger.warning("No valid training batches available - likely due to missing market data")
logger.warning("This can happen when historical data for required timeframes (1m, 1h, 1d) is not available")
logger.warning("Training will be skipped for this session")
session.status = 'completed' # Mark as completed since we can't train without data
session.error = "No training data available"
return
logger.info(f" Ready to train on {total_batches} batches")
logger.info(f" Total samples: {total_samples}")
# Disable gradient accumulation since we're using proper batching now
if hasattr(trainer, 'set_gradient_accumulation_steps'):
trainer.set_gradient_accumulation_steps(0) # No accumulation needed with batching
logger.info(f" Gradient accumulation disabled (using proper batching instead)")
import gc
for epoch in range(session.total_epochs):
epoch_loss = 0.0
epoch_accuracy = 0.0
num_batches = 0
# Log GPU status at start of epoch
if use_gpu and torch.cuda.is_available():
# Use GPU monitor for detailed metrics
try:
from utils.gpu_monitor import get_gpu_monitor
gpu_monitor = get_gpu_monitor()
gpu_monitor.log_gpu_status(f"Epoch {epoch + 1}/{session.total_epochs}")
except Exception as e:
# Fallback to basic memory stats if monitor not available
logger.debug(f"GPU monitor not available: {e}")
mem_allocated = torch.cuda.memory_allocated(0) / 1024**3
mem_reserved = torch.cuda.memory_reserved(0) / 1024**3
logger.info(f" Epoch {epoch + 1}/{session.total_epochs} - GPU Memory: {mem_allocated:.2f}GB allocated, {mem_reserved:.2f}GB reserved")
# MEMORY FIX: Aggressive cleanup before epoch
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Reset gradient accumulation counter at start of epoch (not needed with batching, but safe to call)
if hasattr(trainer, 'reset_gradient_accumulation'):
trainer.reset_gradient_accumulation()
# Generate batches fresh for each epoch
for i, batch in enumerate(batch_generator()):
try:
# DEBUG: Check if batch has timeframe data
if epoch > 0 and i == 0:
has_1m = batch.get('price_data_1m') is not None
has_1h = batch.get('price_data_1h') is not None
has_1d = batch.get('price_data_1d') is not None
logger.debug(f"Epoch {epoch+1}, Batch 1: has_1m={has_1m}, has_1h={has_1h}, has_1d={has_1d}")
if has_1m:
logger.debug(f" price_data_1m shape: {batch['price_data_1m'].shape}")
# Store prediction before training (for visualization)
# Only store predictions on first epoch and every 10th batch to avoid clutter
if epoch == 0 and i % 10 == 0 and self.orchestrator:
# Get symbol from batch metadata or use default
symbol = batch.get('metadata', {}).get('symbol', 'ETH/USDT')
self._store_training_prediction(batch, trainer, symbol)
# CRITICAL: Acquire training lock to prevent concurrent model access
# This prevents "inplace operation" errors when per-candle training runs simultaneously
with self._training_lock:
# Call the trainer's train_step method with mini-batch
# Batch is already on GPU and contains multiple samples
result = trainer.train_step(batch, accumulate_gradients=False)
if result is not None:
# MEMORY FIX: Detach all tensor values to break computation graph
batch_loss = float(result.get('total_loss', 0.0))
batch_accuracy = float(result.get('accuracy', 0.0))
batch_candle_accuracy = float(result.get('candle_accuracy', 0.0))
batch_trend_accuracy = float(result.get('trend_accuracy', 0.0))
batch_action_accuracy = float(result.get('action_accuracy', 0.0))
batch_trend_loss = float(result.get('trend_loss', 0.0))
batch_candle_loss = float(result.get('candle_loss', 0.0))
batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
batch_candle_rmse = result.get('candle_rmse', {})
epoch_loss += batch_loss
epoch_accuracy += batch_accuracy
num_batches += 1
# Log first batch and every 5th batch
if (i + 1) == 1 or (i + 1) % 5 == 0:
# Format RMSE values (normalized space)
rmse_str = ""
if batch_candle_rmse:
rmse_str = f", RMSE: O={batch_candle_rmse.get('open', 0):.4f} H={batch_candle_rmse.get('high', 0):.4f} L={batch_candle_rmse.get('low', 0):.4f} C={batch_candle_rmse.get('close', 0):.4f}"
# Format denormalized RMSE (real prices)
denorm_str = ""
if batch_candle_loss_denorm:
denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
denorm_str = f", Real RMSE: {', '.join(denorm_values)}"
# Get GPU utilization during training
gpu_info = ""
if use_gpu and torch.cuda.is_available():
try:
from utils.gpu_monitor import get_gpu_monitor
gpu_monitor = get_gpu_monitor()
gpu_summary = gpu_monitor.get_summary_string()
if gpu_summary != "GPU monitoring not available":
gpu_info = f" | {gpu_summary}"
except Exception:
pass # GPU monitoring optional
logger.info(f" Batch {i + 1}/{total_batches}, Loss: {batch_loss:.6f}, "
f"Candle Acc: {batch_accuracy:.1%}, Trend Acc: {batch_trend_accuracy:.1%}, "
f"Action Acc: {batch_action_accuracy:.1%}{rmse_str}{denorm_str}{gpu_info}")
else:
logger.warning(f" Batch {i + 1} returned None result - skipping")
# MEMORY FIX: Explicit cleanup after EVERY batch
# Delete result dict to free memory
if 'result' in locals():
del result
# NOTE: Don't delete batch contents - batches are reused across epochs
# The batch dictionary is shared, so deleting keys corrupts it for next epoch
# Just clear the reference - Python GC will handle cleanup
if 'batch' in locals():
del batch
# Clear CUDA cache after every batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
# After each batch, cleanup (no accumulation needed with proper batching)
# Every batch triggers optimizer step
if True:
gc.collect()
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()
# CRITICAL: Check memory limit
memory_usage = memory_guard.check_memory(raise_on_limit=True)
except torch.cuda.OutOfMemoryError as oom_error:
logger.error(f" CUDA OOM in batch {i + 1}: {oom_error}")
# Aggressive memory cleanup on OOM
if 'batch' in locals():
del batch
if 'result' in locals():
del result
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Reset optimizer state
trainer.optimizer.zero_grad(set_to_none=True)
logger.warning(f" Skipping batch {i + 1} due to OOM, optimizer state reset")
continue
except Exception as e:
logger.error(f" Error in batch {i + 1}: {e}")
# Cleanup on error
if 'batch' in locals():
del batch
if 'result' in locals():
del result
gc.collect()
continue
except Exception as e:
logger.error(f" Error in batch {i + 1}: {e}")
import traceback
logger.error(traceback.format_exc())
# Clear CUDA cache after error
if torch.cuda.is_available():
torch.cuda.empty_cache()
continue
avg_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
avg_accuracy = epoch_accuracy / num_batches if num_batches > 0 else 0.0
session.current_epoch = epoch + 1
session.current_loss = avg_loss
# Save checkpoint after each epoch
try:
checkpoint_dir = "models/checkpoints/transformer"
os.makedirs(checkpoint_dir, exist_ok=True)
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
checkpoint_path = os.path.join(checkpoint_dir, f"transformer_epoch{epoch+1}_{timestamp}.pt")
torch.save({
'epoch': epoch + 1,
'model_state_dict': trainer.model.state_dict(),
'optimizer_state_dict': trainer.optimizer.state_dict(),
'scheduler_state_dict': trainer.scheduler.state_dict(),
'loss': avg_loss,
'accuracy': avg_accuracy,
'learning_rate': trainer.scheduler.get_last_lr()[0]
}, checkpoint_path)
logger.info(f" Saved checkpoint: {checkpoint_path}")
# Save metadata to database for easy retrieval
try:
from utils.database_manager import get_database_manager
db_manager = get_database_manager()
checkpoint_id = f"transformer_e{epoch+1}_{timestamp}"
# Create metadata object
from utils.database_manager import CheckpointMetadata
metadata = CheckpointMetadata(
checkpoint_id=checkpoint_id,
model_name="transformer",
model_type="transformer",
timestamp=datetime.now(),
performance_metrics={
'loss': float(avg_loss),
'accuracy': float(avg_accuracy),
'epoch': epoch + 1,
'learning_rate': float(trainer.scheduler.get_last_lr()[0])
},
training_metadata={
'num_samples': total_samples, # Use stored count, training_data was deleted
'num_batches': num_batches,
'training_id': session.training_id
},
file_path=checkpoint_path,
file_size_mb=os.path.getsize(checkpoint_path) / (1024 * 1024) if os.path.exists(checkpoint_path) else 0.0,
is_active=True
)
if db_manager.save_checkpoint_metadata(metadata):
logger.info(f" Saved checkpoint metadata to database: {checkpoint_id}")
except Exception as meta_error:
logger.warning(f" Could not save checkpoint metadata: {meta_error}")
# Keep only best 5 checkpoints based on accuracy
self._cleanup_old_checkpoints(checkpoint_dir, keep_best=5, metric='accuracy')
except Exception as e:
logger.warning(f" Failed to save checkpoint: {e}")
# MEMORY FIX: Aggressive epoch-level cleanup
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Check memory usage
log_memory_usage(f" Epoch {epoch + 1} end - ")
# Log GPU status at end of epoch
if use_gpu and torch.cuda.is_available():
try:
from utils.gpu_monitor import get_gpu_monitor
gpu_monitor = get_gpu_monitor()
gpu_monitor.log_gpu_status(f"Epoch {epoch + 1} end")
except Exception:
pass # GPU monitoring optional
logger.info(f" Epoch {epoch + 1}/{session.total_epochs}, Loss: {avg_loss:.6f}, Accuracy: {avg_accuracy:.2%} ({num_batches} batches)")
session.final_loss = session.current_loss
session.accuracy = avg_accuracy
# MEMORY FIX: Final cleanup
logger.info(" Final memory cleanup...")
# Clear grouped batches (cached_batches was already cleared earlier)
# Note: Don't delete batch contents as they may be referenced elsewhere
# Just clear the list reference - Python GC will handle cleanup
try:
if grouped_batches:
grouped_batches.clear()
del grouped_batches
except NameError:
# grouped_batches already cleaned up or doesn't exist
pass
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Final memory check
log_memory_usage("Training complete - ")
memory_guard.stop()
# Log best checkpoint info
try:
checkpoint_dir = "models/checkpoints/transformer"
best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
if best_checkpoint_path:
checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
logger.info(f" Best checkpoint: epoch {checkpoint.get('epoch', 0)}, accuracy: {checkpoint.get('accuracy', 0):.2%}")
except Exception as e:
logger.debug(f"Could not load best checkpoint info: {e}")
logger.info(f" Training complete: Loss = {session.final_loss:.6f}, Accuracy = {session.accuracy:.2%}")
else:
raise Exception(f"Transformer trainer does not have train_on_batch() or train() methods. Available methods: {[m for m in dir(trainer) if not m.startswith('_')]}")
def _train_cob_real(self, session: TrainingSession, training_data: List[Dict]):
    """
    Train the orchestrator's COB RL agent on annotated trades.

    Every annotation is turned into one terminal transition and pushed through
    ``agent.remember`` (only if the agent has that method). Then, if the agent
    has ``replay``, it is called once per epoch and its return value is
    reported as that epoch's loss.

    Args:
        session: Training session whose epoch/loss/accuracy fields are updated.
        training_data: Annotation dicts; ``profit_loss_pct`` and ``direction``
            are read from each one.

    Raises:
        Exception: If the orchestrator has no (truthy) ``cob_rl_agent``.
    """
    agent = getattr(self.orchestrator, 'cob_rl_agent', None)
    if not agent:
        raise Exception("COB RL model not available in orchestrator")

    # Same transition-building scheme as the DQN path
    for sample in training_data:
        pnl_pct = sample.get('profit_loss_pct')
        # Percentage P&L -> fractional reward; missing/zero P&L gives 0.0
        reward = pnl_pct / 100.0 if pnl_pct else 0.0
        if not hasattr(agent, 'remember'):
            continue
        # Placeholder all-zero state sized to agent.state_size (empty list if
        # the agent does not expose it); reused as the next state as well.
        placeholder_state = [0.0] * agent.state_size if hasattr(agent, 'state_size') else []
        # LONG -> 1, anything else (including SHORT) -> 0
        action = 1 if sample.get('direction') == 'LONG' else 0
        agent.remember(placeholder_state, action, reward, placeholder_state, True)

    if hasattr(agent, 'replay'):
        for epoch_idx in range(session.total_epochs):
            replay_loss = agent.replay()
            session.current_epoch = epoch_idx + 1
            # replay() may return a falsy value (e.g. None); report 0.0 then
            session.current_loss = replay_loss or 0.0
            logger.info(f"COB RL Epoch {epoch_idx + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")

    session.final_loss = session.current_loss
    # NOTE(review): hard-coded accuracy, not measured from the agent.
    session.accuracy = 0.85
def _train_extrema_real(self, session: TrainingSession, training_data: List[Dict]):
    """
    Placeholder "training" for the Extrema model.

    NOTE: despite the method name, no model training happens here yet. The
    orchestrator's ``extrema_trainer`` is only checked for availability; the
    per-epoch loss is a synthetic ``0.5 / epoch`` curve and the reported
    accuracy is a fixed 0.85. A warning is logged so these numbers are not
    mistaken for real training results.

    Args:
        session: Training session whose epoch/loss/accuracy fields are updated.
        training_data: Prepared training samples (currently unused).

    Raises:
        Exception: If the orchestrator has no (truthy) ``extrema_trainer``.
    """
    if not getattr(self.orchestrator, 'extrema_trainer', None):
        raise Exception("Extrema trainer not available in orchestrator")

    # TODO: Implement actual extrema training using self.orchestrator.extrema_trainer
    logger.warning("Extrema training is not implemented yet - reporting placeholder loss/accuracy")

    for epoch in range(session.total_epochs):
        session.current_epoch = epoch + 1
        session.current_loss = 0.5 / (epoch + 1)  # Placeholder, not a measured loss
        logger.info(f"Extrema Epoch {epoch + 1}/{session.total_epochs}, Loss: {session.current_loss:.4f}")

    session.final_loss = session.current_loss
    session.accuracy = 0.85  # Placeholder, not a measured accuracy
def get_training_progress(self, training_id: str) -> Dict:
    """
    Report the progress of one training session plus current GPU/CPU load.

    Unknown ids yield a ``{'status': 'not_found', ...}`` dict. The GPU and CPU
    readings are best-effort: any failure while collecting them is logged at
    debug level and leaves the corresponding value as None.
    """
    if training_id not in self.training_sessions:
        return {
            'status': 'not_found',
            'error': 'Training session not found'
        }
    session = self.training_sessions[training_id]

    gpu_util = None
    try:
        from utils.gpu_monitor import get_gpu_monitor
        metrics = get_gpu_monitor().get_gpu_utilization()
        if metrics:
            gpu_util = metrics.get('gpu_utilization_percent')
            if gpu_util is None:
                # No utilization reading: use memory usage % as a proxy
                # (a falsy memory reading leaves the value as None)
                gpu_util = metrics.get('memory_usage_percent') or None
    except Exception as e:
        logger.debug(f"Could not get GPU metrics: {e}")

    cpu_util = None
    try:
        import psutil
        # Samples CPU usage over a 0.1 s window
        cpu_util = psutil.cpu_percent(interval=0.1)
    except Exception as e:
        logger.debug(f"Could not get CPU metrics: {e}")

    return {
        'status': session.status,
        'model_name': session.model_name,
        'test_cases_count': session.test_cases_count,
        'current_epoch': session.current_epoch,
        'total_epochs': session.total_epochs,
        'current_loss': session.current_loss,
        'final_loss': session.final_loss,
        'accuracy': session.accuracy,
        'duration_seconds': session.duration_seconds,
        'error': session.error,
        'gpu_utilization': gpu_util,
        'cpu_utilization': cpu_util
    }
def get_active_training_session(self) -> Optional[Dict]:
    """
    Return a summary of the first training session whose status is 'running'.

    Used by the UI to pick up progress tracking again after a page reload.
    Sessions are scanned in ``self.training_sessions`` iteration order.

    Returns:
        Summary dict of the running session, or None if no session is running.
    """
    running = next(
        ((tid, sess) for tid, sess in self.training_sessions.items() if sess.status == 'running'),
        None,
    )
    if running is None:
        return None

    tid, sess = running
    return {
        'training_id': tid,
        'status': sess.status,
        'model_name': sess.model_name,
        'test_cases_count': sess.test_cases_count,
        'current_epoch': sess.current_epoch,
        'total_epochs': sess.total_epochs,
        'current_loss': sess.current_loss,
        'start_time': sess.start_time,
        'annotation_count': sess.annotation_count,
        'timeframe': sess.timeframe
    }
def get_all_training_sessions(self) -> List[Dict]:
    """
    Summarize every known training session (for debugging/monitoring).

    Returns:
        One summary dict per entry of ``self.training_sessions``, in its
        iteration order.
    """
    return [
        {
            'training_id': tid,
            'status': sess.status,
            'model_name': sess.model_name,
            'current_epoch': sess.current_epoch,
            'total_epochs': sess.total_epochs,
            'start_time': sess.start_time,
            'duration_seconds': sess.duration_seconds
        }
        for tid, sess in self.training_sessions.items()
    ]
# Real-time inference support
def start_realtime_inference(self, model_name: str, symbol: str, data_provider,
                             enable_live_training: bool = True,
                             train_every_candle: bool = False,
                             timeframe: str = '1m',
                             training_strategy = None) -> str:
    """
    Launch a background real-time inference session driven by the orchestrator.

    A new session dict is registered in ``self.inference_sessions`` (created
    on first use). If ``enable_live_training`` is set, live pivot training is
    started as well; a failure there is logged and does not abort the session.
    Finally, a daemon thread runs ``self._realtime_inference_loop``.

    Args:
        model_name: Name of model to use for inference
        symbol: Trading symbol
        data_provider: Data provider for market data
        enable_live_training: If True, automatically train (deprecated - use training_strategy)
        train_every_candle: If True, train on every candle (deprecated - use training_strategy)
        timeframe: Timeframe for candle-based training (default: 1m)
        training_strategy: TrainingStrategyManager for making training decisions

    Returns:
        inference_id: Unique ID for this inference session

    Raises:
        Exception: If no orchestrator is attached.
    """
    if not self.orchestrator:
        raise Exception("Orchestrator not available - cannot perform inference")

    inference_id = str(uuid.uuid4())

    # Session registry is created lazily on first use
    if not hasattr(self, 'inference_sessions'):
        self.inference_sessions = {}

    new_session = {
        'model_name': model_name,
        'symbol': symbol,
        'status': 'running',
        'start_time': time.time(),
        'signals': [],  # All signals (including rejected ones)
        'executed_trades': [],  # Only executed trades (open/close positions)
        'stop_flag': False,
        'live_training_enabled': enable_live_training,
        'train_every_candle': train_every_candle,
        'timeframe': timeframe,
        'data_provider': data_provider,
        'training_strategy': training_strategy,  # Strategy manager for training decisions
        'pending_action': None,  # Action to train on (set by strategy manager)
        'metrics': {
            'accuracy': 0.0,
            'loss': 0.0,
            'steps': 0
        },
        'last_candle_time': None,
        # Position tracking
        'position': None,  # {'type': 'long/short', 'entry_price': float, 'entry_time': str, 'entry_id': str}
        'total_pnl': 0.0,
        'win_count': 0,
        'loss_count': 0,
        'total_trades': 0,
        # Inference input cache: stores input data frames for later training
        # Key: candle_timestamp (str), Value: {'model_inputs': Dict, 'norm_params': Dict, 'predicted_candle': Dict}
        'inference_input_cache': {}
    }
    self.inference_sessions[inference_id] = new_session

    if train_every_candle:
        training_mode = "per-candle"
    elif enable_live_training:
        training_mode = "pivot-based"
    else:
        training_mode = "inference-only"
    logger.info(f"Starting REAL-TIME inference: {inference_id} with {model_name} on {symbol} ({training_mode})")

    if enable_live_training:
        try:
            from ANNOTATE.core.live_pivot_trainer import get_live_pivot_trainer
            pivot_trainer = get_live_pivot_trainer(
                orchestrator=self.orchestrator,
                data_provider=data_provider,
                training_adapter=self,
            )
            if not pivot_trainer:
                logger.warning("Could not initialize live pivot trainer")
            else:
                pivot_trainer.start(symbol=symbol)
                logger.info("Live pivot training ENABLED - will train on L2 peaks automatically")
        except Exception as e:
            logger.error(f"Failed to start live pivot training: {e}")

    # The loop runs in a daemon thread so it never blocks interpreter shutdown
    worker = threading.Thread(
        target=self._realtime_inference_loop,
        args=(inference_id, model_name, symbol, data_provider),
        daemon=True,
    )
    worker.start()

    return inference_id
def stop_realtime_inference(self, inference_id: str):
    """
    Signal a real-time inference session to stop.

    The session's ``stop_flag`` is set and its status becomes 'stopped'. The
    flag is presumably polled by the background inference loop. Unknown ids,
    or calls made before any session exists, are ignored. If the session had
    live training enabled, the live pivot trainer returned by
    ``get_live_pivot_trainer()`` is stopped too.
    """
    if not hasattr(self, 'inference_sessions'):
        return
    if inference_id not in self.inference_sessions:
        return

    session = self.inference_sessions[inference_id]
    session['stop_flag'] = True
    session['status'] = 'stopped'

    if session.get('live_training_enabled', False):
        # NOTE(review): get_live_pivot_trainer() is called without arguments and
        # presumably returns a shared instance; stopping it may also affect other
        # sessions that still use live training - confirm.
        try:
            from ANNOTATE.core.live_pivot_trainer import get_live_pivot_trainer
            pivot_trainer = get_live_pivot_trainer()
            if pivot_trainer:
                pivot_trainer.stop()
                logger.info("Live pivot training stopped")
        except Exception as e:
            logger.error(f"Error stopping live pivot training: {e}")

    logger.info(f"Stopped real-time inference: {inference_id}")
def get_latest_signals(self, limit: int = 50) -> List[Dict]:
    """
    Return the newest inference signals across all sessions.

    Signals from every session are merged and ordered newest-first by their
    'timestamp' value (missing timestamps sort as ''). At most ``limit`` are
    returned. If no inference session has been started yet, the result is an
    empty list.
    """
    if not hasattr(self, 'inference_sessions'):
        return []

    merged = [
        signal
        for sess in self.inference_sessions.values()
        for signal in sess.get('signals', [])
    ]
    newest_first = sorted(merged, key=lambda sig: sig.get('timestamp', ''), reverse=True)
    return newest_first[:limit]
def _make_realtime_prediction_with_cache(self, model_name: str, symbol: str, data_provider, session: Dict) -> Tuple[Dict, bool]:
    """
    DEPRECATED: Use _make_realtime_prediction + _register_inference_frame instead.

    Backward-compatibility shim that no longer caches any inputs. It delegates
    to ``_make_realtime_prediction`` and always reports that nothing was
    stored.

    Args:
        model_name: Name of the model to predict with.
        symbol: Trading symbol.
        data_provider: Data provider for market data.
        session: Inference session dict (unused; kept for interface compatibility).

    Returns:
        Tuple of (prediction_dict or None, False).
    """
    result = self._make_realtime_prediction(model_name, symbol, data_provider)
    # _make_realtime_prediction returns (prediction, market_info); unpack it so
    # callers get the prediction dict itself rather than the whole tuple.
    prediction = result[0] if isinstance(result, tuple) else result
    return prediction, False
def _make_realtime_prediction(self, model_name: str, symbol: str, data_provider) -> Tuple[Dict, Dict]:
    """
    Fetch fresh market inputs, predict on them, and return both.

    The inputs are returned next to the prediction so the caller can keep them
    as a reference frame.

    Returns:
        ``(prediction_dict, {'market_data': ..., 'norm_params': ...})``, or
        ``(None, None)`` when no market data could be obtained.
    """
    inputs, normalization = self._get_realtime_market_data(symbol, data_provider)
    if not inputs:
        return None, None

    prediction = self._make_realtime_prediction_internal(
        model_name, symbol, data_provider, inputs, normalization
    )
    reference = {'market_data': inputs, 'norm_params': normalization}
    return prediction, reference
def _make_realtime_prediction_internal(self, model_name: str, symbol: str, data_provider,
                                       market_data: Dict, norm_params: Dict) -> Dict:
    """
    Run one Transformer forward pass and convert the outputs into a prediction dict.

    Only ``model_name == 'Transformer'`` is supported; it uses the
    orchestrator's ``primary_transformer_trainer.model``. Any other model
    name, a missing trainer/model, empty ``market_data``, or a model output
    without ``action_probs`` yields None. Exceptions are logged at debug level
    and also yield None.

    The model's previous train/eval mode is restored after the forward pass,
    so inference does not leave a model that is being trained in eval mode.

    Args:
        model_name: Model to use (only 'Transformer' is handled).
        symbol: Trading symbol (unused here; kept for interface compatibility).
        data_provider: Data provider (unused here; kept for interface compatibility).
        market_data: Keyword inputs passed to the model's forward call.
        norm_params: Per-timeframe dicts with 'price_min', 'price_max',
            'volume_min' and 'volume_max', used to denormalize the predicted
            [Open, High, Low, Close, Volume] candles.

    Returns:
        Dict with 'action', 'confidence', 'predicted_price' and
        'predicted_candle' ('trend_vector' is added when the model outputs
        one), or None.
    """
    try:
        if model_name != 'Transformer' or not self.orchestrator:
            return None
        trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
        if not (trainer and trainer.model):
            return None
        if not market_data:
            return None

        import torch
        model = trainer.model
        was_training = getattr(model, 'training', False)
        try:
            with torch.no_grad():
                model.eval()
                outputs = model(**market_data)
        finally:
            # Restore the prior mode; eval() alone would silently disable
            # dropout etc. for any training that reuses this model.
            if was_training:
                model.train()

        action_probs = outputs.get('action_probs')
        if action_probs is None:
            return None

        # Accept [3] or [batch, 3]; for a batch, use the first item
        probs = action_probs if action_probs.dim() == 1 else action_probs[0]
        action_idx = int(torch.argmax(probs, dim=0).item())
        confidence = probs[action_idx].item()

        # Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
        actions = ['HOLD', 'BUY', 'SELL']
        action = actions[action_idx] if action_idx < len(actions) else 'HOLD'

        # Denormalize predicted candles back to real prices/volumes
        predicted_candles_denorm = {}
        if norm_params and 'next_candles' in outputs:
            for tf, tensor in outputs['next_candles'].items():
                if tf not in norm_params:
                    continue
                candle = tensor.detach()
                # Accept [5], [1, 5] or [batch, 5]: take the first item of each leading dim
                while candle.dim() > 1:
                    candle = candle[0]
                if candle.dim() != 1 or candle.numel() < 5:
                    continue
                # tolist() (not numpy()) also works for dtypes numpy lacks, e.g. bfloat16
                open_, high, low, close, volume = (float(x) for x in candle[:5].cpu().tolist())

                params = norm_params[tf]
                price_min = params['price_min']
                price_span = params['price_max'] - price_min
                vol_min = params['volume_min']
                vol_span = params['volume_max'] - vol_min
                predicted_candles_denorm[tf] = [
                    float(open_ * price_span + price_min),    # Open
                    float(high * price_span + price_min),     # High
                    float(low * price_span + price_min),      # Low
                    float(close * price_span + price_min),    # Close
                    float(volume * vol_span + vol_min),       # Volume
                ]

        # Predicted price = close of the 1m candle, else of the 1s candle.
        # (The normalized 'price_prediction' head is not used: it would need
        # separate denormalization based on a reference price.)
        predicted_price = None
        for tf in ('1m', '1s'):
            if tf in predicted_candles_denorm:
                predicted_price = float(predicted_candles_denorm[tf][3])
                break

        result_dict = {
            'action': action,
            'confidence': confidence,
            'predicted_price': predicted_price,
            'predicted_candle': predicted_candles_denorm
        }
        # Include trend vector if available
        if 'trend_vector' in outputs:
            result_dict['trend_vector'] = outputs['trend_vector']
        return result_dict

    except Exception as e:
        logger.debug(f"Error making realtime prediction: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return None
def _get_realtime_market_data(self, symbol: str, data_provider) -> Tuple[Dict, Dict]:
    """
    Build min-max normalized model inputs for a live prediction, plus the
    parameters needed to denormalize the model's outputs afterwards.

    For each of the '1s', '1m', '1h' and '1d' timeframes the last 600 candles
    are fetched. Shorter histories are left-padded by repeating the oldest
    candle, and the values are scaled to [0, 1]. The four price columns share
    one min/max; volume uses its own.

    Args:
        symbol: Trading symbol passed through to the data provider.
        data_provider: Object exposing
            ``get_historical_data(symbol, tf, limit=..., refresh=..., persist=...)``
            that returns a DataFrame with open/high/low/close/volume columns,
            or None / an empty frame.

    Returns:
        ``(model_inputs, norm_params)``.

        ``model_inputs`` maps ``'price_data_<tf>'`` to a float32 tensor of
        shape [1, 600, 5]. It also holds zero placeholders for
        ``'tech_data'`` [1, 40], ``'market_data'`` [1, 30] and
        ``'cob_data'`` [1, 600, 100].

        ``norm_params`` maps each timeframe to its
        ``price_min`` / ``price_max`` / ``volume_min`` / ``volume_max``.

        ``(None, None)`` is returned when no timeframe produced data, or on error.
    """
    seq_len = 600
    try:
        data = {}
        norm_params = {}

        for tf in ['1s', '1m', '1h', '1d']:
            # Refresh the fast timeframes so the newest candle is included, but
            # don't persist: high-frequency writes would lock the database.
            refresh = tf in ['1s', '1m']
            df = data_provider.get_historical_data(symbol, tf, limit=seq_len, refresh=refresh, persist=False)
            if df is None or df.empty:
                continue

            # [n, 5] float32 copy: Open, High, Low, Close, Volume
            ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values.astype(np.float32)
            if len(ohlcv) < seq_len:
                # Left-pad by repeating the oldest candle
                ohlcv = np.pad(ohlcv, ((seq_len - len(ohlcv), 0), (0, 0)), mode='edge')
            else:
                ohlcv = ohlcv[-seq_len:]

            price_min = float(np.min(ohlcv[:, :4]))
            price_max = float(np.max(ohlcv[:, :4]))
            volume_min = float(np.min(ohlcv[:, 4]))
            volume_max = float(np.max(ohlcv[:, 4]))
            # Avoid division by zero on flat series
            if price_max == price_min:
                price_max += 1.0
            if volume_max == volume_min:
                volume_max += 1.0

            # Kept so predictions can be denormalized later
            norm_params[tf] = {
                'price_min': price_min,
                'price_max': price_max,
                'volume_min': volume_min,
                'volume_max': volume_max
            }

            ohlcv[:, :4] = (ohlcv[:, :4] - price_min) / (price_max - price_min)
            ohlcv[:, 4] = (ohlcv[:, 4] - volume_min) / (volume_max - volume_min)
            data[f'price_data_{tf}'] = torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0)

        if not data:
            return None, None

        # Placeholder inputs with the shapes the model expects
        data['tech_data'] = torch.zeros(1, 40, dtype=torch.float32)
        data['market_data'] = torch.zeros(1, 30, dtype=torch.float32)
        data['cob_data'] = torch.zeros(1, 600, 100, dtype=torch.float32)

        device = getattr(self.orchestrator, 'device', None)
        if device is not None:
            data = {key: tensor.to(device) for key, tensor in data.items()}

        # Parenthesized on purpose: the unparenthesized form
        # `data, norm_params if data else (None, None)` builds a 2-tuple whose
        # second item is the conditional.
        return (data, norm_params)
    except Exception as e:
        logger.debug(f"Error getting realtime market data: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return None, None
def _train_on_new_candle(self, session: Dict, symbol: str, timeframe: str, data_provider) -> Dict:
    """
    Train the model on the most recently completed candle (no business logic).

    The action label comes from ``session['pending_action']``, set by the
    app's training strategy. If the inference step cached the model inputs
    for the completed candle in ``session['inference_input_cache']``, those
    exact inputs are trained on. Otherwise a fresh market state is fetched
    and passed to ``_train_transformer_on_sample``. Only the 'Transformer'
    model is supported.

    Args:
        session: Training session dict.
            Reads ``last_candle_time`` (optional), ``pending_action``,
            ``inference_input_cache`` (optional) and ``model_name``.
            Writes ``last_candle_time`` and ``metrics``.
        symbol: Trading symbol.
        timeframe: Timeframe for training.
        data_provider: Provides ``get_historical_data(symbol, timeframe, limit=2)``.

    Returns:
        On success: ``{'success': True, 'loss', 'accuracy', 'training_steps',
        'used_stored_inputs'}``.
        Otherwise: ``{'success': False, 'error': <reason>}``.
    """
    try:
        df = data_provider.get_historical_data(symbol, timeframe, limit=2)
        if df is None or len(df) < 2:
            return {'success': False, 'error': 'Insufficient data'}

        # Train at most once per candle; a new session may not have seen one yet.
        latest_candle_time = df.index[-1]
        last_candle_time = session.get('last_candle_time')
        if last_candle_time == latest_candle_time:
            return {'success': False, 'error': 'Same candle, no training needed'}
        logger.debug(f"New candle detected: {latest_candle_time} (last: {last_candle_time})")
        session['last_candle_time'] = latest_candle_time

        # The second-to-last candle has just closed; the last one is its successor.
        completed_candle = df.iloc[-2]
        next_candle = df.iloc[-1]
        completed_timestamp = str(completed_candle.name)

        action_label = session.get('pending_action')
        if not action_label:
            return {'success': False, 'error': 'No pending_action in session'}

        # Only the transformer is supported. Reject other models before any
        # (possibly API-backed) data fetching happens.
        model_name = session['model_name']
        if model_name != 'Transformer':
            return {'success': False, 'error': f'Unsupported model: {model_name}'}

        # Prefer the exact inputs the model saw at inference time for this candle.
        cache = session.get('inference_input_cache', {})
        stored_inputs = cache.get(completed_timestamp)
        if stored_inputs:
            logger.info(f"Using stored inference inputs for training on {symbol} {timeframe} @ {completed_timestamp}")
            trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
            if trainer:
                # Cached inputs must live on the same device as the model's weights.
                try:
                    device = next(trainer.model.parameters()).device
                except (AttributeError, StopIteration):
                    device = getattr(self.orchestrator, 'device', torch.device('cpu'))

                batch = {
                    key: value.to(device) if isinstance(value, torch.Tensor) else value
                    for key, value in stored_inputs['model_inputs'].items()
                }

                # Target = the candle that followed, normalized with the same
                # min/max that was used to build the cached inputs.
                params = (stored_inputs.get('norm_params') or {}).get(timeframe)
                if params:
                    price_min = params['price_min']
                    price_range = (params['price_max'] - price_min) or 1.0
                    vol_min = params['volume_min']
                    vol_max = params['volume_max']
                    normalized_candle = [
                        (float(next_candle['open']) - price_min) / price_range,
                        (float(next_candle['high']) - price_min) / price_range,
                        (float(next_candle['low']) - price_min) / price_range,
                        (float(next_candle['close']) - price_min) / price_range,
                        (float(next_candle['volume']) - vol_min) / (vol_max - vol_min) if vol_max > vol_min else 0.0,
                    ]
                    batch[f'future_candle_{timeframe}'] = torch.tensor(
                        [normalized_candle], dtype=torch.float32, device=device)

                action_map = {'HOLD': 0, 'BUY': 1, 'SELL': 2}
                batch['actions'] = torch.tensor(
                    [[action_map.get(action_label, 0)]], dtype=torch.long, device=device)

                with self._training_lock:
                    with torch.enable_grad():
                        trainer.model.train()
                        result = trainer.train_step(batch, accumulate_gradients=False)
                    if result:
                        loss = result.get('total_loss', 0)
                        accuracy = result.get('accuracy', 0)

                        metrics = self.realtime_training_metrics
                        metrics['total_steps'] += 1
                        metrics['total_loss'] += loss
                        metrics['total_accuracy'] += accuracy
                        # Rolling window of the last 100 steps
                        metrics['losses'].append(loss)
                        metrics['accuracies'].append(accuracy)
                        if len(metrics['losses']) > 100:
                            metrics['losses'].pop(0)
                            metrics['accuracies'].pop(0)

                        session['metrics']['loss'] = sum(metrics['losses']) / len(metrics['losses'])
                        session['metrics']['accuracy'] = sum(metrics['accuracies']) / len(metrics['accuracies'])
                        session['metrics']['steps'] = metrics['total_steps']

                        # Each cached frame is trained on once
                        cache.pop(completed_timestamp, None)

                        logger.info(f"Trained on stored inference inputs: {symbol} {timeframe} @ {completed_timestamp} action={action_label} (Loss: {loss:.4f}, Acc: {accuracy:.2%})")
                        return {
                            'success': True,
                            'loss': session['metrics']['loss'],
                            'accuracy': session['metrics']['accuracy'],
                            'training_steps': session['metrics']['steps'],
                            'used_stored_inputs': True
                        }
            logger.warning("Failed to use stored inputs, falling back to fresh data")

        # Fallback: build a training sample from freshly fetched market state
        market_state = self._fetch_market_state_for_candle(symbol, completed_candle.name, data_provider)
        price_change = (next_candle['close'] - completed_candle['close']) / completed_candle['close']
        training_sample = {
            'symbol': symbol,
            'timestamp': completed_candle.name,
            'market_state': market_state,
            'action': action_label,
            'entry_price': float(completed_candle['close']),
            'exit_price': float(next_candle['close']),
            'profit_loss_pct': price_change * 100,
            'direction': 'LONG' if action_label == 'BUY' else ('SHORT' if action_label == 'SELL' else 'HOLD')
        }
        self._train_transformer_on_sample(training_sample)

        metrics = self.realtime_training_metrics
        if metrics['losses']:
            session['metrics']['loss'] = sum(metrics['losses']) / len(metrics['losses'])
            session['metrics']['accuracy'] = sum(metrics['accuracies']) / len(metrics['accuracies'])
            session['metrics']['steps'] = metrics['total_steps']

        logger.info(f"Trained on candle: {symbol} {timeframe} @ {completed_candle.name} action={action_label} (change: {price_change:+.2%})")
        return {
            'success': True,
            'loss': session['metrics']['loss'],
            'accuracy': session['metrics']['accuracy'],
            'training_steps': session['metrics']['steps'],
            'used_stored_inputs': False
        }
    except Exception as e:
        logger.warning(f"Error training on new candle: {e}")
        return {'success': False, 'error': str(e)}
def _fetch_market_state_for_candle(self, symbol: str, timestamp, data_provider) -> Dict:
    """
    Assemble an OHLCV-only market state (no business logic) for model training.

    Each primary timeframe ('1s', '1m', '1h', '1d') needs at least 600 candles.
    When the first lookup returns fewer, a second lookup is made with
    ``refresh=True, persist=True``. Timeframes that are still short are left
    out. A 600-candle 'BTC/USDT' 1m series is added under
    ``secondary_timeframes`` as reference data when it is available.

    Args:
        symbol: Trading symbol.
        timestamp: Timestamp of the candle being trained on.
            NOTE(review): this argument is not used below. The most recent
            candles are always returned, which may include data after
            ``timestamp``; confirm this is intended.
        data_provider: Provides ``get_historical_data``.

    Returns:
        ``{'timeframes': {tf: series}, 'secondary_timeframes': {'BTC/USDT': {'1m': series}}}``.
        Each series holds parallel ``timestamps`` / ``open`` / ``high`` /
        ``low`` / ``close`` / ``volume`` lists.
        ``{}`` if an unexpected error occurs.
    """
    needed = 600  # training requires exactly 600 candles per timeframe

    def has_enough(frame):
        # True when the frame exists and covers the full window
        return frame is not None and not frame.empty and len(frame) >= needed

    def to_series(frame):
        # Newest `needed` rows as JSON-friendly parallel lists
        window = frame.tail(needed)
        series = {'timestamps': window.index.strftime('%Y-%m-%d %H:%M:%S').tolist()}
        for column in ('open', 'high', 'low', 'close', 'volume'):
            series[column] = window[column].tolist()
        return series

    try:
        state = {'timeframes': {}, 'secondary_timeframes': {}}

        for tf in ('1s', '1m', '1h', '1d'):
            frame = data_provider.get_historical_data(symbol, tf, limit=needed)
            if not has_enough(frame):
                logger.info(f"Fetching {needed} candles for {symbol} {tf} from API (insufficient cached data)")
                try:
                    frame = data_provider.get_historical_data(symbol, tf, limit=needed, refresh=True, persist=True)
                    fetched = len(frame) if frame is not None else 0
                    logger.info(f"Successfully cached {fetched} candles for {symbol} {tf}")
                except Exception as api_error:
                    logger.warning(f"Failed to fetch {symbol} {tf} from API: {api_error}")
                    continue

            if not has_enough(frame):
                available = len(frame) if frame is not None else 0
                logger.warning(f"Still insufficient data for {symbol} {tf} after API fetch: {available} < {needed}")
                continue

            state['timeframes'][tf] = to_series(frame)
            logger.debug(f"Prepared {needed} candles for {symbol} {tf}")

        # BTC reference series (1m) used as secondary market context
        btc_symbol, btc_tf = 'BTC/USDT', '1m'
        try:
            btc_frame = data_provider.get_historical_data(btc_symbol, btc_tf, limit=needed)
            if not has_enough(btc_frame):
                logger.info("Fetching BTC reference data for training")
                btc_frame = data_provider.get_historical_data(btc_symbol, btc_tf, limit=needed, refresh=True, persist=True)
            if has_enough(btc_frame):
                state['secondary_timeframes'][btc_symbol] = {btc_tf: to_series(btc_frame)}
        except Exception as btc_error:
            logger.warning(f"Failed to fetch BTC reference data: {btc_error}")

        return state
    except Exception as e:
        logger.warning(f"Error fetching market state for candle: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return {}
def _convert_prediction_to_batch(self, prediction_sample: Dict, timeframe: str):
    """
    Convert a validated prediction into a transformer training batch.

    The batch is built from the sample's market state via
    ``_convert_annotation_to_transformer_batch``. A
    ``future_candle_<timeframe>`` target made from the actual candle is then
    attached.

    Args:
        prediction_sample: Dict with:
            - ``market_state`` (must contain ``'timeframes'``)
            - ``predicted_candle`` ([O, H, L, C, V])
            - ``actual_candle`` ([O, H, L, C])
            - optionally ``symbol`` / ``timestamp``
        timeframe: Target timeframe. A target is only attached for '1s',
            '1m' and '1h'.

    Returns:
        Batch dict ready for ``trainer.train_step()``, or None when the sample
        has no market state or conversion fails.
    """
    try:
        market_state = prediction_sample.get('market_state', {})
        if not market_state or 'timeframes' not in market_state:
            logger.warning("No market state in prediction sample")
            return None

        predicted = prediction_sample['predicted_candle']
        annotation = {
            'symbol': prediction_sample.get('symbol', 'ETH/USDT'),
            'timestamp': prediction_sample.get('timestamp'),
            'action': 'BUY',  # placeholder, not used for candle prediction training
            'entry_price': float(predicted[0]),  # predicted Open
            'market_state': market_state
        }
        batch = self._convert_annotation_to_transformer_batch(annotation)
        if not batch:
            return None

        # Put the target on the same device as the batch inputs. Batches built
        # in this adapter use 'price_data_<tf>' keys; 'prices_1m' is kept for
        # compatibility, then any tensor in the batch, then CPU.
        device = torch.device('cpu')
        for reference in [batch.get('price_data_1m'), batch.get('prices_1m'), *batch.values()]:
            if isinstance(reference, torch.Tensor):
                device = reference.device
                break

        actual = prediction_sample['actual_candle']  # [O, H, L, C]
        # actual_candle has no volume, so the predicted volume is reused.
        # NOTE(review): values are used as given; confirm they are in the
        # scale the trainer expects for future_candle targets.
        target_candle = [actual[0], actual[1], actual[2], actual[3], predicted[4]]

        if timeframe in ('1s', '1m', '1h'):
            batch[f'future_candle_{timeframe}'] = torch.tensor([target_candle], dtype=torch.float32, device=device)

        logger.debug(f"Converted prediction to batch for {timeframe} timeframe")
        return batch
    except Exception as e:
        logger.error(f"Error converting prediction to batch: {e}", exc_info=True)
        return None
def _train_transformer_on_sample(self, training_sample: Dict):
    """
    Run one transformer training step on a single sample and checkpoint progress.

    The sample is converted with ``_convert_annotation_to_transformer_batch``.
    The step is skipped, with a warning, when:
      - conversion fails;
      - the batch lacks any of 'actions', 'price_data_1m', 'price_data_1h',
        'price_data_1d';
      - ``self._training_lock`` cannot be acquired within 5 seconds.

    After a step, the rolling (last 100 steps) loss/accuracy averages are
    updated. ``_save_realtime_checkpoint`` is called when either average sets
    a new best, or when ``checkpoint_frequency`` steps have passed since the
    last checkpoint. Errors are logged, not raised.

    Args:
        training_sample: Annotation-style dict accepted by
            ``_convert_annotation_to_transformer_batch``.
    """
    try:
        orchestrator = self.orchestrator
        if not orchestrator:
            return
        trainer = getattr(orchestrator, 'primary_transformer_trainer', None)
        if not trainer:
            return

        batch = self._convert_annotation_to_transformer_batch(training_sample)
        if not batch:
            logger.warning("Per-candle training failed: Could not convert sample to batch")
            return

        missing_keys = [
            key for key in ('actions', 'price_data_1m', 'price_data_1h', 'price_data_1d')
            if batch.get(key) is None
        ]
        if missing_keys:
            logger.warning(f"Per-candle training skipped: Missing required keys: {missing_keys}")
            return

        # Serialize model access with batch training: concurrent steps cause
        # "inplace operation" autograd errors. The bounded wait avoids deadlock.
        if not self._training_lock.acquire(timeout=5.0):
            logger.warning("Could not acquire training lock within 5 seconds - skipping this training step")
            return

        try:
            self._training_lock_holder = threading.current_thread().name
            with torch.enable_grad():
                trainer.model.train()
                result = trainer.train_step(batch, accumulate_gradients=False)
                if not result:
                    return

                step_loss = result.get('total_loss', 0)
                step_acc = result.get('accuracy', 0)

                stats = self.realtime_training_metrics
                stats['total_steps'] += 1
                stats['total_loss'] += step_loss
                stats['total_accuracy'] += step_acc

                # Rolling window of the last 100 steps
                stats['losses'].append(step_loss)
                stats['accuracies'].append(step_acc)
                if len(stats['losses']) > 100:
                    stats['losses'].pop(0)
                    stats['accuracies'].pop(0)

                avg_loss = sum(stats['losses']) / len(stats['losses'])
                avg_accuracy = sum(stats['accuracies']) / len(stats['accuracies'])
                logger.info(f"Per-candle training: Loss={step_loss:.4f} (avg: {avg_loss:.4f}), Acc={step_acc:.2%} (avg: {avg_accuracy:.2%})")

                improved = False
                if avg_loss < stats['best_loss']:
                    stats['best_loss'] = avg_loss
                    improved = True
                    logger.info(f" NEW BEST LOSS: {avg_loss:.4f}")
                if avg_accuracy > stats['best_accuracy']:
                    stats['best_accuracy'] = avg_accuracy
                    improved = True
                    logger.info(f" NEW BEST ACCURACY: {avg_accuracy:.2%}")

                checkpoint_due = (stats['total_steps'] - stats['last_checkpoint_step']
                                  >= stats['checkpoint_frequency'])
                if improved or checkpoint_due:
                    self._save_realtime_checkpoint(
                        trainer=trainer,
                        step=stats['total_steps'],
                        loss=avg_loss,
                        accuracy=avg_accuracy,
                        improved=improved
                    )
                    stats['last_checkpoint_step'] = stats['total_steps']
        finally:
            # Always release, even if the step raised
            self._training_lock_holder = None
            self._training_lock.release()
    except Exception as e:
        logger.warning(f"Error training transformer on sample: {e}")
def _save_realtime_checkpoint(self, trainer, step: int, loss: float, accuracy: float, improved: bool = False,
                              checkpoint_dir: str = "models/checkpoints/transformer/realtime"):
    """
    Save a checkpoint during real-time (per-candle) training.

    Writes ``realtime_<BEST|periodic>_step<step>_<UTC timestamp>.pt`` into
    ``checkpoint_dir``. The file contains model/optimizer/scheduler state,
    the given loss/accuracy, the current learning rate and a snapshot of
    ``self.realtime_training_metrics``. It is written under a temporary name
    and then renamed, so no partial ``.pt`` file is ever visible.

    Metadata is also recorded via ``utils.database_manager`` when that module
    is importable. Improvement checkpoints trigger pruning to the best 10.
    Errors are logged, not raised.

    Args:
        trainer: Trainer exposing ``model`` and ``optimizer``. ``scheduler``
            is optional and may be missing or None.
        step: Current training step.
        loss: Current average loss.
        accuracy: Current average accuracy.
        improved: Whether this is an improvement checkpoint.
        checkpoint_dir: Destination directory (created if missing).
    """
    try:
        os.makedirs(checkpoint_dir, exist_ok=True)

        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        checkpoint_type = "BEST" if improved else "periodic"
        checkpoint_path = os.path.join(checkpoint_dir, f"realtime_{checkpoint_type}_step{step}_{timestamp}.pt")

        # `hasattr` is not enough: trainers may define `scheduler = None`.
        # Without a scheduler, use the optimizer's current learning rate.
        scheduler = getattr(trainer, 'scheduler', None)
        if scheduler is not None:
            scheduler_state = scheduler.state_dict()
            learning_rate = scheduler.get_last_lr()[0]
        else:
            scheduler_state = None
            learning_rate = trainer.optimizer.param_groups[0]['lr']

        metrics = self.realtime_training_metrics
        payload = {
            'step': step,
            'model_state_dict': trainer.model.state_dict(),
            'optimizer_state_dict': trainer.optimizer.state_dict(),
            'scheduler_state_dict': scheduler_state,
            'loss': loss,
            'accuracy': accuracy,
            'learning_rate': learning_rate,
            'training_type': 'realtime_per_candle',
            'metrics': {
                'total_steps': metrics['total_steps'],
                'best_loss': metrics['best_loss'],
                'best_accuracy': metrics['best_accuracy'],
                'rolling_losses': metrics['losses'][-10:],  # Last 10
                'rolling_accuracies': metrics['accuracies'][-10:]
            }
        }

        # Write to a temp name, then rename. Checkpoint scanners only look at
        # '*.pt' files, so they never pick up a half-written checkpoint.
        tmp_path = checkpoint_path + '.tmp'
        try:
            torch.save(payload, tmp_path)
            os.replace(tmp_path, checkpoint_path)
        except Exception:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

        logger.info(f" SAVED REALTIME CHECKPOINT: {checkpoint_path}")
        logger.info(f" Step: {step}, Loss: {loss:.4f}, Acc: {accuracy:.2%}, Improved: {improved}")

        # Best-effort: record metadata in the database if that module is available
        try:
            from utils.database_manager import get_database_manager, CheckpointMetadata
            db_manager = get_database_manager()
            checkpoint_id = f"realtime_step{step}_{timestamp}"
            metadata = CheckpointMetadata(
                checkpoint_id=checkpoint_id,
                model_name="transformer_realtime",
                model_type="transformer",
                timestamp=datetime.now(),
                performance_metrics={
                    'loss': float(loss),
                    'accuracy': float(accuracy),
                    'step': step,
                    'best_loss': float(metrics['best_loss']),
                    'best_accuracy': float(metrics['best_accuracy'])
                },
                training_metadata={
                    'training_type': 'realtime_per_candle',
                    'total_steps': metrics['total_steps'],
                    'checkpoint_type': checkpoint_type
                },
                file_path=checkpoint_path,
                file_size_mb=os.path.getsize(checkpoint_path) / (1024 * 1024),
                is_active=True
            )
            if db_manager.save_checkpoint_metadata(metadata):
                logger.info(f" Saved checkpoint metadata to database: {checkpoint_id}")
        except Exception as meta_error:
            logger.warning(f" Could not save checkpoint metadata: {meta_error}")

        # Prune to the best 10 checkpoints whenever a new best is written
        if improved:
            self._cleanup_realtime_checkpoints(checkpoint_dir, keep_best=10)
    except Exception as e:
        logger.error(f"Error saving realtime checkpoint: {e}")
def _cleanup_realtime_checkpoints(self, checkpoint_dir: str, keep_best: int = 10):
    """
    Keep only the best ``keep_best`` realtime checkpoints in ``checkpoint_dir``.

    Only files named ``realtime_*.pt`` are considered. They are ranked by
    stored ``accuracy`` (higher first), then ``loss`` (lower first), and every
    file after the first ``keep_best`` is deleted.

    Files that cannot be loaded (corrupt, or still being written) are never
    ranked, so they are never deleted. Because of this, no "wait for writes
    to finish" delay is needed. This method can run while the caller holds
    the training lock, so avoiding the delay matters. Errors are logged, not
    raised.

    Args:
        checkpoint_dir: Directory holding the checkpoints. A missing or
            non-directory path is ignored.
        keep_best: Number of top-ranked checkpoints to keep.
    """
    try:
        if not os.path.isdir(checkpoint_dir):
            return

        ranked = []
        for filename in os.listdir(checkpoint_dir):
            if not (filename.endswith('.pt') and filename.startswith('realtime_')):
                continue
            filepath = os.path.join(checkpoint_dir, filename)
            try:
                checkpoint = torch.load(filepath, map_location='cpu')
                ranked.append({
                    'path': filepath,
                    'loss': checkpoint.get('loss', float('inf')),
                    'accuracy': checkpoint.get('accuracy', 0),
                })
            except Exception as e:
                logger.debug(f"Could not load checkpoint {filename}: {e}")

        # Accuracy descending, then loss ascending (stable for ties)
        ranked.sort(key=lambda item: (item['accuracy'], -item['loss']), reverse=True)

        for item in ranked[keep_best:]:
            try:
                # The file may have been removed concurrently
                if os.path.exists(item['path']):
                    os.remove(item['path'])
                    logger.debug(f"Removed old realtime checkpoint: {os.path.basename(item['path'])}")
            except Exception as e:
                logger.warning(f"Could not remove checkpoint: {e}")
    except Exception as e:
        logger.error(f"Error cleaning up realtime checkpoints: {e}")
def _load_best_realtime_checkpoint(self, checkpoint_dir: str = "models/checkpoints/transformer/realtime"):
    """
    Resume real-time training from the best saved realtime checkpoint.

    Scans ``checkpoint_dir`` for ``realtime_*.pt`` files and picks the best one:
      1. highest stored ``accuracy``;
      2. on ties, lowest ``loss``;
      3. on remaining ties, the first file listed.
    Only the current best checkpoint is kept in memory during the scan.

    ``self.realtime_training_metrics`` is restored from the checkpoint's
    ``metrics`` entry. When the orchestrator's ``primary_transformer_trainer``
    has a model, its model and optimizer state are loaded. Scheduler state is
    loaded only when both the checkpoint and the trainer have a scheduler.
    Errors are logged, not raised.

    Args:
        checkpoint_dir: Directory to scan.
    """
    try:
        if not os.path.exists(checkpoint_dir):
            logger.info("No realtime checkpoints found, starting fresh")
            return

        best = None
        best_rank = None
        for filename in os.listdir(checkpoint_dir):
            if not (filename.endswith('.pt') and filename.startswith('realtime_')):
                continue
            filepath = os.path.join(checkpoint_dir, filename)
            try:
                checkpoint = torch.load(filepath, map_location='cpu')
                candidate = {
                    'path': filepath,
                    'loss': checkpoint.get('loss', float('inf')),
                    'accuracy': checkpoint.get('accuracy', 0),
                    'step': checkpoint.get('step', 0),
                    'checkpoint': checkpoint
                }
            except Exception as e:
                logger.debug(f"Could not load checkpoint {filename}: {e}")
                continue
            rank = (candidate['accuracy'], -candidate['loss'])
            # Strict '>' keeps the earliest file on ties (same as a stable descending sort)
            if best is None or rank > best_rank:
                best, best_rank = candidate, rank

        if best is None:
            logger.info("No valid realtime checkpoints found")
            return

        checkpoint = best['checkpoint']

        # Restore training metrics
        if 'metrics' in checkpoint:
            saved_metrics = checkpoint['metrics']
            metrics = self.realtime_training_metrics
            metrics['total_steps'] = saved_metrics.get('total_steps', 0)
            metrics['best_loss'] = saved_metrics.get('best_loss', float('inf'))
            metrics['best_accuracy'] = saved_metrics.get('best_accuracy', 0.0)
            metrics['losses'] = saved_metrics.get('rolling_losses', [])
            metrics['accuracies'] = saved_metrics.get('rolling_accuracies', [])
            metrics['last_checkpoint_step'] = best['step']

        if not (self.orchestrator and hasattr(self.orchestrator, 'primary_transformer_trainer')):
            logger.info("Found realtime checkpoint but orchestrator not available yet")
            return
        trainer = self.orchestrator.primary_transformer_trainer
        if not (trainer and trainer.model):
            logger.info("Found realtime checkpoint but trainer not available yet")
            return

        trainer.model.load_state_dict(checkpoint['model_state_dict'])
        trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Trainers may have no scheduler (missing or None); skip its state then
        scheduler_state = checkpoint.get('scheduler_state_dict')
        scheduler = getattr(trainer, 'scheduler', None)
        if scheduler_state:
            if scheduler is not None:
                scheduler.load_state_dict(scheduler_state)
            else:
                logger.debug("Checkpoint has scheduler state but trainer has no scheduler; skipping it")

        logger.info("RESUMED REALTIME TRAINING from checkpoint:")
        logger.info(f" Step: {best['step']}, Loss: {best['loss']:.4f}, Acc: {best['accuracy']:.2%}")
        logger.info(f" Path: {os.path.basename(best['path'])}")
    except Exception as e:
        logger.warning(f"Error loading realtime checkpoint: {e}")
        logger.info("Starting realtime training from scratch")
def _get_sleep_time_for_timeframe(self, timeframe: str) -> float:
    """
    Return the polling interval, in seconds, used while waiting for a new candle.

    Known timeframes map to a fixed interval; for example, '1m' is checked
    every 5 seconds. Any other timeframe falls back to 5 seconds.

    Args:
        timeframe: Candle timeframe such as '1s', '1m', '1h' or '1d'.

    Returns:
        Poll interval in seconds.
    """
    known_timeframes = ('1s', '1m', '5m', '15m', '1h', '4h', '1d')
    poll_seconds = (1, 5, 30, 60, 300, 600, 3600)  # '1m' -> every 5s for a new 1m candle
    interval_by_timeframe = dict(zip(known_timeframes, poll_seconds))
    return interval_by_timeframe.get(timeframe, 5)
def _store_training_prediction(self, batch: Dict, trainer, symbol: str):
    """
    Run the model on a training batch without gradients and hand the result
    to the orchestrator for visualization.

    The argmax class is decoded with the same index mapping this adapter uses
    for training targets (0=HOLD, 1=BUY, 2=SELL).

    The stored predicted price is a placeholder: +1% of ``current_price`` for
    BUY, -1% otherwise. Nothing is stored unless
    ``batch['metadata']['current_price']`` is positive and the orchestrator
    has ``store_transformer_prediction``.

    The model's train/eval mode is restored afterwards, even if the forward
    pass fails. Errors are logged at debug level, not raised.

    Args:
        batch: Training batch. ``price_data_*``, ``tech_data`` and
            ``market_data`` are fed to the model. Optional ``metadata``
            supplies ``current_price`` and ``timestamp``.
        trainer: Trainer exposing ``model``.
        symbol: Trading symbol for the stored prediction.
    """
    try:
        model = trainer.model
        was_training = getattr(model, 'training', True)
        model.eval()
        try:
            with torch.no_grad():
                outputs = model(
                    price_data_1s=batch.get('price_data_1s'),
                    price_data_1m=batch.get('price_data_1m'),
                    price_data_1h=batch.get('price_data_1h'),
                    price_data_1d=batch.get('price_data_1d'),
                    tech_data=batch.get('tech_data'),
                    market_data=batch.get('market_data')
                )
        finally:
            # Restore the previous mode even if the forward pass raised
            model.train(was_training)

        action_probs = outputs.get('action_probs')
        if action_probs is None:
            return

        # Shape [3] for a single prediction; [batch, 3] otherwise (first row used)
        probs = action_probs if action_probs.dim() == 1 else action_probs[0]
        action_idx = int(torch.argmax(probs, dim=0).item())
        confidence = probs[action_idx].item()

        # Must match the training target encoding {'HOLD': 0, 'BUY': 1, 'SELL': 2}
        actions = ['HOLD', 'BUY', 'SELL']
        action = actions[action_idx] if action_idx < len(actions) else 'HOLD'

        metadata = batch.get('metadata') or {}
        current_price = metadata.get('current_price') or 0
        timestamp = metadata.get('timestamp', datetime.now())

        if current_price > 0 and hasattr(self.orchestrator, 'store_transformer_prediction'):
            self.orchestrator.store_transformer_prediction(symbol, {
                'timestamp': timestamp,
                'current_price': current_price,
                'predicted_price': current_price * (1.01 if action == 'BUY' else 0.99),
                'price_change': 1.0 if action == 'BUY' else -1.0,
                'confidence': confidence,
                'action': action,
                'horizon_minutes': 10,
                'source': 'training'
            })
            logger.debug(f"Stored training prediction: {action} @ {current_price} (conf: {confidence:.2f})")
    except Exception as e:
        logger.debug(f"Error storing training prediction: {e}")
def _realtime_inference_loop(self, inference_id: str, model_name: str, symbol: str, data_provider):
"""
Real-time inference loop with optional per-candle training
This runs in a background thread and continuously makes predictions.
Can optionally train on every new candle.
"""
session = self.inference_sessions[inference_id]
train_every_candle = session.get('train_every_candle', False)
timeframe = session.get('timeframe', '1m')
try:
while not session['stop_flag']:
try:
# Get current market data
current_price = data_provider.get_current_price(symbol)
if not current_price:
time.sleep(1)
continue
# Make prediction using the model - returns tuple (prediction_dict, market_data_dict)
prediction_result = self._make_realtime_prediction(model_name, symbol, data_provider)
# Unpack tuple: prediction is the dict, market_data_info contains norm_params
if prediction_result is None:
time.sleep(1)
continue
prediction, market_data_info = prediction_result
# Register inference frame reference for later training when actual candle arrives
# This stores a reference (timestamp range) instead of copying 600 candles
# The reference allows us to retrieve the exact data from DuckDB when training
if prediction and self.training_coordinator and market_data_info:
# Get norm_params from market_data_info
norm_params = market_data_info.get('norm_params', {})
self._register_inference_frame(session, symbol, timeframe, prediction, data_provider, norm_params)
if prediction:
# Store signal
signal = {
'timestamp': datetime.now(timezone.utc).isoformat(),
'symbol': symbol,
'model': model_name,
'action': prediction['action'],
'confidence': prediction['confidence'],
'price': current_price,
'predicted_price': prediction.get('predicted_price'),
'predicted_candle': prediction.get('predicted_candle')
}
# Store signal (all signals, including rejected ones)
session['signals'].append(signal)
# Keep only last 100 signals
if len(session['signals']) > 100:
session['signals'] = session['signals'][-100:]
# Execute trade logic (only if confidence is high enough and position logic allows)
executed_trade = self._execute_realtime_trade(session, signal, current_price)
if executed_trade:
logger.info(f"Live Trade EXECUTED: {executed_trade['action']} @ {executed_trade['price']:.2f} (conf: {signal['confidence']:.2f})")
# Send executed trade to frontend via WebSocket
if hasattr(self, 'socketio') and self.socketio:
self.socketio.emit('executed_trade', {
'trade': executed_trade,
'position_state': {
'has_position': session['position'] is not None,
'position_type': session['position']['type'] if session['position'] else None,
'entry_price': session['position']['entry_price'] if session['position'] else None,
'unrealized_pnl': self._calculate_unrealized_pnl(session, current_price) if session['position'] else 0.0
},
'session_metrics': {
'total_pnl': session['total_pnl'],
'total_trades': session['total_trades'],
'win_count': session['win_count'],
'loss_count': session['loss_count'],
'win_rate': (session['win_count'] / session['total_trades'] * 100) if session['total_trades'] > 0 else 0
}
})
else:
rejection_reason = self._get_rejection_reason(session, signal, current_price)
logger.info(f"Live Signal (NOT executed): {signal['action']} @ {signal['price']:.2f} (conf: {signal['confidence']:.2f}) - {rejection_reason}")
# Store prediction for visualization (INCLUDE predicted_candle for ghost candles!)
if self.orchestrator and hasattr(self.orchestrator, 'store_transformer_prediction'):
# Get denormalized predicted_price (should already be denormalized from _make_realtime_prediction_internal)
predicted_price = prediction.get('predicted_price')
# Always get actual current_price from latest candle to ensure it's denormalized
# This is more reliable than trusting get_current_price which might return normalized values
actual_current_price = current_price
try:
df_latest = data_provider.get_historical_data(symbol, timeframe, limit=1, refresh=False)
if df_latest is not None and not df_latest.empty:
actual_current_price = float(df_latest['close'].iloc[-1])
else:
# Try other timeframes
for tf in ['1m', '1h', '1d']:
if tf != timeframe:
df_tf = data_provider.get_historical_data(symbol, tf, limit=1, refresh=False)
if df_tf is not None and not df_tf.empty:
actual_current_price = float(df_tf['close'].iloc[-1])
break
except Exception as e:
logger.debug(f"Error getting actual price from candle: {e}")
# Fallback: if current_price looks normalized (< 1000 for ETH/USDT), try to denormalize
if current_price < 1000 and symbol == 'ETH/USDT': # ETH should be > 1000, normalized would be < 1
if market_data_info and 'norm_params' in market_data_info:
norm_params = market_data_info['norm_params']
if '1m' in norm_params:
params = norm_params['1m']
price_min = params['price_min']
price_max = params['price_max']
# Denormalize: price = normalized * (max - min) + min
actual_current_price = float(current_price * (price_max - price_min) + price_min)
prediction_data = {
'timestamp': datetime.now(timezone.utc).isoformat(),
'current_price': actual_current_price, # Use denormalized price
'predicted_price': predicted_price if predicted_price is not None else actual_current_price,
'price_change': 1.0 if prediction['action'] == 'BUY' else -1.0,
'confidence': prediction['confidence'],
'action': prediction['action'],
'horizon_minutes': 10,
'source': 'live_inference'
}
# Include REAL predicted_candle from model (for ghost candles)
if 'predicted_candle' in prediction and prediction['predicted_candle']:
# Ensure predicted_candle values are Python native types (not tensors)
predicted_candle_clean = {}
for tf, candle_data in prediction['predicted_candle'].items():
if isinstance(candle_data, (list, tuple)):
# Convert list/tuple elements to Python scalars
predicted_candle_clean[tf] = [
float(v.item() if hasattr(v, 'item') else v)
for v in candle_data
]
elif hasattr(candle_data, 'tolist'):
# Tensor array - convert to list
predicted_candle_clean[tf] = [float(v) for v in candle_data.tolist()]
else:
predicted_candle_clean[tf] = candle_data
prediction_data['predicted_candle'] = predicted_candle_clean
logger.info(f"📊 Storing prediction with ghost candles for {len(predicted_candle_clean)} timeframes: {list(predicted_candle_clean.keys())}")
# Use actual predicted price from candle close (ensure it's a Python float)
predicted_price_val = None
if '1m' in predicted_candle_clean:
close_val = predicted_candle_clean['1m'][3]
predicted_price_val = float(close_val.item() if hasattr(close_val, 'item') else close_val)
elif '1s' in predicted_candle_clean:
close_val = predicted_candle_clean['1s'][3]
predicted_price_val = float(close_val.item() if hasattr(close_val, 'item') else close_val)
if predicted_price_val is not None:
prediction_data['predicted_price'] = predicted_price_val
# Calculate price_change using denormalized prices
prediction_data['price_change'] = ((predicted_price_val - actual_current_price) / actual_current_price) * 100
else:
# Fallback: use predicted_price from prediction dict (should be denormalized)
fallback_predicted = prediction.get('predicted_price')
if fallback_predicted is not None:
prediction_data['predicted_price'] = fallback_predicted
prediction_data['price_change'] = ((fallback_predicted - actual_current_price) / actual_current_price) * 100
else:
prediction_data['predicted_price'] = actual_current_price
prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
else:
# Fallback to estimated price if no candle prediction
logger.warning(f"!!! No predicted_candle in prediction object - ghost candles will not appear!")
prediction_data['predicted_price'] = prediction.get('predicted_price', current_price * (1.01 if prediction['action'] == 'BUY' else 0.99))
prediction_data['price_change'] = 1.0 if prediction['action'] == 'BUY' else -1.0
# Include trend_vector if available (convert tensors to Python types and denormalize)
if 'trend_vector' in prediction:
trend_vec = prediction['trend_vector']
# Get normalization params for denormalization
norm_params_for_denorm = {}
if market_data_info and 'norm_params' in market_data_info:
norm_params_for_denorm = market_data_info['norm_params']
# Convert any tensors to Python native types and denormalize price values
if isinstance(trend_vec, dict):
serialized_trend = {}
for key, value in trend_vec.items():
if hasattr(value, 'numel'): # Tensor
if value.numel() == 1: # Scalar tensor
val = value.item()
# Denormalize price_delta if it's a price-related value
if key == 'price_delta' and norm_params_for_denorm:
val = self._denormalize_price_value(val, norm_params_for_denorm, '1m')
serialized_trend[key] = val
else: # Multi-element tensor
val_list = value.detach().cpu().tolist()
# Denormalize pivot_prices if it's a price array (can be nested)
if key == 'pivot_prices' and norm_params_for_denorm:
val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
serialized_trend[key] = val_list
elif hasattr(value, 'tolist'): # Other array-like
val_list = value.tolist()
if key == 'pivot_prices' and norm_params_for_denorm:
val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
serialized_trend[key] = val_list
elif isinstance(value, (list, tuple)):
# Recursively convert list/tuple of tensors
serialized_list = []
for v in value:
if hasattr(v, 'numel'):
if v.numel() == 1:
val = v.item()
if key == 'pivot_prices' and norm_params_for_denorm:
val = self._denormalize_price_value(val, norm_params_for_denorm, '1m')
serialized_list.append(val)
else:
val_list = v.detach().cpu().tolist()
if key == 'pivot_prices' and norm_params_for_denorm:
# Handle nested arrays (pivot_prices is [[p1, p2, p3, ...]])
val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
serialized_list.append(val_list)
elif hasattr(v, 'tolist'):
val_list = v.tolist()
if key == 'pivot_prices' and norm_params_for_denorm:
# Handle nested arrays
val_list = self._denormalize_nested_price_array(val_list, norm_params_for_denorm, '1m')
serialized_list.append(val_list)
elif isinstance(v, (list, tuple)):
# Nested list - handle pivot_prices structure
if key == 'pivot_prices' and norm_params_for_denorm:
nested_denorm = self._denormalize_nested_price_array(list(v), norm_params_for_denorm, '1m')
serialized_list.append(nested_denorm)
else:
serialized_list.append(list(v))
else:
serialized_list.append(v)
serialized_trend[key] = serialized_list
else:
# Denormalize price_delta if it's a scalar
if key == 'price_delta' and isinstance(value, (int, float)) and norm_params_for_denorm:
serialized_trend[key] = self._denormalize_price_value(value, norm_params_for_denorm, '1m')
else:
serialized_trend[key] = value
# Denormalize vector array if it contains price deltas
if 'vector' in serialized_trend and isinstance(serialized_trend['vector'], list) and norm_params_for_denorm:
vector = serialized_trend['vector']
if len(vector) > 0 and isinstance(vector[0], list) and len(vector[0]) > 0:
# vector is [[price_delta, time_delta]]
price_delta_norm = vector[0][0]
price_delta_denorm = self._denormalize_price_value(price_delta_norm, norm_params_for_denorm, '1m')
serialized_trend['vector'] = [[price_delta_denorm, vector[0][1]]]
prediction_data['trend_vector'] = serialized_trend
else:
prediction_data['trend_vector'] = trend_vec
if hasattr(self.orchestrator, 'store_transformer_prediction'):
self.orchestrator.store_transformer_prediction(symbol, prediction_data)
# Training decision using strategy manager
training_strategy = session.get('training_strategy')
if training_strategy and training_strategy.mode != 'none':
# Get pivot markers for training decision
pivot_markers = {}
if hasattr(training_strategy, 'dashboard') and training_strategy.dashboard:
try:
df = data_provider.get_historical_data(symbol, timeframe, limit=200)
if df is not None and len(df) >= 10:
pivot_markers = training_strategy.dashboard._get_pivot_markers_for_timeframe(symbol, timeframe, df)
except Exception as e:
logger.debug(f"Could not get pivot markers: {e}")
# Get current candle timestamp
df_current = data_provider.get_historical_data(symbol, timeframe, limit=1)
if df_current is not None and len(df_current) > 0:
current_timestamp = df_current.index[-1]
# Ask strategy manager if we should train
should_train, action_data = training_strategy.should_train_on_candle(
symbol, timeframe, current_timestamp, pivot_markers
)
if should_train and action_data:
# Set action in session for training
session['pending_action'] = action_data['action']
# Call pure training method
train_result = self._train_on_new_candle(session, symbol, timeframe, data_provider)
if train_result.get('success'):
logger.info(f"Training completed: {action_data['action']} (reason: {action_data.get('reason', 'unknown')})")
# Sleep based on timeframe
sleep_time = self._get_sleep_time_for_timeframe(timeframe)
time.sleep(sleep_time)
except Exception as e:
logger.error(f"Error in inference loop: {e}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
time.sleep(5)
logger.info(f"Inference loop stopped: {inference_id}")
except Exception as e:
logger.error(f"Fatal error in inference loop: {e}")
session['status'] = 'error'
session['error'] = str(e)
def _execute_realtime_trade(self, session: Dict, signal: Dict, current_price: float) -> Optional[Dict]:
"""
Execute trade based on signal, respecting position management rules
Rules:
1. Only execute if confidence >= 0.6
2. Only open new position if no position is currently open
3. Close position on opposite signal
4. Track all executed trades for visualization
Returns:
Dict with executed trade info, or None if signal was rejected
"""
action = signal['action']
confidence = signal['confidence']
timestamp = signal['timestamp']
# Rule 1: Confidence threshold
if confidence < 0.6:
return None # Rejected: low confidence
# Rule 2 & 3: Position management
position = session.get('position')
if action == 'BUY':
if position is None:
# Open long position
trade_id = str(uuid.uuid4())[:8]
session['position'] = {
'type': 'long',
'entry_price': current_price,
'entry_time': timestamp,
'entry_id': trade_id,
'signal_confidence': confidence
}
executed_trade = {
'trade_id': trade_id,
'action': 'OPEN_LONG',
'price': current_price,
'timestamp': timestamp,
'confidence': confidence
}
session['executed_trades'].append(executed_trade)
return executed_trade
elif position['type'] == 'short':
# Close short position
entry_price = position['entry_price']
pnl = entry_price - current_price # Short profit
pnl_pct = (pnl / entry_price) * 100
executed_trade = {
'trade_id': position['entry_id'],
'action': 'CLOSE_SHORT',
'price': current_price,
'timestamp': timestamp,
'confidence': confidence,
'entry_price': entry_price,
'entry_time': position['entry_time'],
'pnl': pnl,
'pnl_pct': pnl_pct
}
# Update session metrics
session['total_pnl'] += pnl
session['total_trades'] += 1
if pnl > 0:
session['win_count'] += 1
else:
session['loss_count'] += 1
session['position'] = None
session['executed_trades'].append(executed_trade)
logger.info(f"Position CLOSED: SHORT @ {current_price:.2f}, PnL=${pnl:.2f} ({pnl_pct:+.2f}%)")
return executed_trade
elif action == 'SELL':
if position is None:
# Open short position
trade_id = str(uuid.uuid4())[:8]
session['position'] = {
'type': 'short',
'entry_price': current_price,
'entry_time': timestamp,
'entry_id': trade_id,
'signal_confidence': confidence
}
executed_trade = {
'trade_id': trade_id,
'action': 'OPEN_SHORT',
'price': current_price,
'timestamp': timestamp,
'confidence': confidence
}
session['executed_trades'].append(executed_trade)
return executed_trade
elif position['type'] == 'long':
# Close long position
entry_price = position['entry_price']
pnl = current_price - entry_price # Long profit
pnl_pct = (pnl / entry_price) * 100
executed_trade = {
'trade_id': position['entry_id'],
'action': 'CLOSE_LONG',
'price': current_price,
'timestamp': timestamp,
'confidence': confidence,
'entry_price': entry_price,
'entry_time': position['entry_time'],
'pnl': pnl,
'pnl_pct': pnl_pct
}
# Update session metrics
session['total_pnl'] += pnl
session['total_trades'] += 1
if pnl > 0:
session['win_count'] += 1
else:
session['loss_count'] += 1
session['position'] = None
session['executed_trades'].append(executed_trade)
logger.info(f"Position CLOSED: LONG @ {current_price:.2f}, PnL=${pnl:.2f} ({pnl_pct:+.2f}%)")
return executed_trade
# HOLD or position already open in same direction
return None
def _get_rejection_reason(self, session: Dict, signal: Dict, current_price: float = 0.0) -> str:
"""Get reason why a signal was not executed"""
action = signal['action']
confidence = signal['confidence']
position = session.get('position')
if confidence < 0.6:
return f"Low confidence ({confidence:.2f} < 0.6)"
if action == 'HOLD':
return "HOLD signal (no trade)"
if position:
entry_price = position.get('entry_price', 0.0)
position_type = position.get('type', '').upper()
if action == 'BUY' and position['type'] == 'long':
# Calculate current PnL
unrealized_pnl = self._calculate_unrealized_pnl(session, current_price) if current_price > 0 else 0.0
pnl_sign = '+' if unrealized_pnl >= 0 else ''
return f"Already in LONG position (entry: ${entry_price:.2f}, PnL: {pnl_sign}{unrealized_pnl:.2f}%)"
elif action == 'SELL' and position['type'] == 'short':
# Calculate current PnL
unrealized_pnl = self._calculate_unrealized_pnl(session, current_price) if current_price > 0 else 0.0
pnl_sign = '+' if unrealized_pnl >= 0 else ''
return f"Already in SHORT position (entry: ${entry_price:.2f}, PnL: {pnl_sign}{unrealized_pnl:.2f}%)"
return "Unknown reason"
def _calculate_unrealized_pnl(self, session: Dict, current_price: float) -> float:
"""Calculate unrealized PnL for open position"""
position = session.get('position')
if not position or not current_price:
return 0.0
entry_price = position['entry_price']
if position['type'] == 'long':
return ((current_price - entry_price) / entry_price) * 100 # Percentage
else: # short
return ((entry_price - current_price) / entry_price) * 100 # Percentage
def _denormalize_price_value(self, normalized_value: float, norm_params: Dict, timeframe: str = '1m') -> float:
"""
Denormalize a single price value using normalization parameters
Args:
normalized_value: Normalized price value (0-1 range)
norm_params: Dictionary of normalization parameters by timeframe
timeframe: Timeframe to use for denormalization (default: '1m')
Returns:
Denormalized price value
"""
try:
if timeframe in norm_params:
params = norm_params[timeframe]
price_min = params.get('price_min', 0.0)
price_max = params.get('price_max', 1.0)
if price_max > price_min:
# Denormalize: price = normalized * (max - min) + min
return float(normalized_value * (price_max - price_min) + price_min)
# Fallback: return as-is if no params available
return float(normalized_value)
except Exception as e:
logger.debug(f"Error denormalizing price value: {e}")
return float(normalized_value)
def _denormalize_price_array(self, normalized_array: list, norm_params: Dict, timeframe: str = '1m') -> list:
"""
Denormalize an array of price values using normalization parameters
Args:
normalized_array: List of normalized price values (0-1 range)
norm_params: Dictionary of normalization parameters by timeframe
timeframe: Timeframe to use for denormalization (default: '1m')
Returns:
List of denormalized price values
"""
try:
if timeframe in norm_params:
params = norm_params[timeframe]
price_min = params.get('price_min', 0.0)
price_max = params.get('price_max', 1.0)
if price_max > price_min:
# Denormalize each value: price = normalized * (max - min) + min
return [float(v * (price_max - price_min) + price_min) if isinstance(v, (int, float)) else v
for v in normalized_array]
# Fallback: return as-is if no params available
return [float(v) if isinstance(v, (int, float)) else v for v in normalized_array]
except Exception as e:
logger.debug(f"Error denormalizing price array: {e}")
return [float(v) if isinstance(v, (int, float)) else v for v in normalized_array]
def _denormalize_nested_price_array(self, normalized_array: list, norm_params: Dict, timeframe: str = '1m') -> list:
"""
Denormalize a nested array of price values (e.g., [[p1, p2, p3], [p4, p5, p6]])
Args:
normalized_array: Nested list of normalized price values
norm_params: Dictionary of normalization parameters by timeframe
timeframe: Timeframe to use for denormalization (default: '1m')
Returns:
Nested list of denormalized price values
"""
try:
result = []
for item in normalized_array:
if isinstance(item, (list, tuple)):
# Recursively denormalize nested arrays
result.append(self._denormalize_price_array(list(item), norm_params, timeframe))
else:
# Single value - denormalize it
result.append(self._denormalize_price_value(item, norm_params, timeframe) if isinstance(item, (int, float)) else item)
return result
except Exception as e:
logger.debug(f"Error denormalizing nested price array: {e}")
return normalized_array