Refactoring: inference training triggered by real data events

Dobromir Popov
2025-12-09 11:59:15 +02:00
parent 1c1ebf6d7e
commit 992d6de25b
9 changed files with 1970 additions and 224 deletions


@@ -0,0 +1,376 @@
"""
Event-Driven Inference Training System
This system provides:
1. Reference-based inference frame storage (no 600-candle copies)
2. Subscription system for candle completion and pivot events
3. Flexible training methods (backprop for Transformer, others for different models)
4. Integration with DuckDB for efficient data retrieval
Architecture:
- Inference frames stored as references (timestamp ranges) in DuckDB
- Training adapter subscribes to data provider events
- Time-based triggers: candle completion (known result time)
- Event-based triggers: pivot points (L2L, L2H, etc. - unknown timing)
"""
import logging
import threading
from datetime import datetime, timezone, timedelta
from typing import Dict, List, Optional, Callable, Tuple, Any
from dataclasses import dataclass, field
from enum import Enum
import uuid
logger = logging.getLogger(__name__)
class TrainingTriggerType(Enum):
"""Types of training triggers"""
CANDLE_COMPLETION = "candle_completion" # Time-based: next candle closes
PIVOT_EVENT = "pivot_event" # Event-based: pivot detected (L2L, L2H, etc.)
@dataclass
class InferenceFrameReference:
"""
Reference to inference data stored in DuckDB.
No copying - just store timestamp ranges and query when needed.
"""
inference_id: str # Unique ID for this inference
symbol: str
timeframe: str
prediction_timestamp: datetime # When prediction was made
# Reference to data in DuckDB (timestamp range)
data_range_start: datetime # Start of 600-candle window
data_range_end: datetime # End of 600-candle window
# Defaulted fields must follow the required ones in a dataclass
target_timestamp: Optional[datetime] = None # When result will be available (for candles)
# Normalization parameters (small, can be stored)
norm_params: Dict[str, Dict[str, float]] = field(default_factory=dict)
# Prediction metadata
predicted_action: Optional[str] = None
predicted_candle: Optional[Dict[str, List[float]]] = None
confidence: float = 0.0
# Training status
trained: bool = False
training_timestamp: Optional[datetime] = None
@dataclass
class PivotEvent:
"""Pivot point event for training"""
symbol: str
timeframe: str
timestamp: datetime
pivot_type: str # 'L2L', 'L2H', 'L3L', 'L3H', etc.
price: float
level: int # Pivot level (2, 3, 4, etc.)
strength: float
@dataclass
class CandleCompletionEvent:
"""Candle completion event for training"""
symbol: str
timeframe: str
timestamp: datetime # When candle closed
ohlcv: Dict[str, float] # {'open', 'high', 'low', 'close', 'volume'}
class TrainingEventSubscriber:
"""
Subscriber interface for training events.
Training adapters implement this to receive callbacks.
"""
def on_candle_completion(self, event: CandleCompletionEvent, inference_ref: Optional[InferenceFrameReference]) -> None:
"""
Called when a candle completes.
Args:
event: Candle completion event with actual OHLCV
inference_ref: Reference to inference frame if available (for this candle)
"""
raise NotImplementedError
def on_pivot_event(self, event: PivotEvent, inference_refs: List[InferenceFrameReference]) -> None:
"""
Called when a pivot point is detected.
Args:
event: Pivot event (L2L, L2H, etc.)
inference_refs: List of inference frames that predicted this pivot
"""
raise NotImplementedError
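# Illustrative subscriber (not part of the production path; the RealTrainingAdapter
# in the second file implements this interface itself). Useful for smoke-testing
# the event wiring without touching any model:
class LoggingTrainingSubscriber(TrainingEventSubscriber):
    """Logs training events instead of training on them."""
    def on_candle_completion(self, event: CandleCompletionEvent,
                             inference_ref: Optional[InferenceFrameReference]) -> None:
        logger.info(f"Candle closed: {event.symbol} {event.timeframe} @ {event.timestamp} "
                    f"(inference frame: {inference_ref.inference_id if inference_ref else 'none'})")
    def on_pivot_event(self, event: PivotEvent,
                       inference_refs: List[InferenceFrameReference]) -> None:
        logger.info(f"Pivot {event.pivot_type} @ {event.price}: "
                    f"{len(inference_refs)} matching inference frames")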
class InferenceTrainingCoordinator:
"""
Coordinates inference frame storage and training event distribution.
NOTE: This should be integrated into TradingOrchestrator to reduce duplication.
The orchestrator already manages models, training, and predictions, so it's the
natural place for inference-training coordination.
Responsibilities:
1. Store inference frame references (not copies)
2. Register training subscriptions (candle/pivot events)
3. Match inference frames to actual results
4. Trigger training callbacks
"""
def __init__(self, data_provider, duckdb_storage=None):
"""
Initialize coordinator
Args:
data_provider: DataProvider instance for event subscriptions
duckdb_storage: DuckDBStorage instance for data retrieval
"""
self.data_provider = data_provider
self.duckdb_storage = duckdb_storage
# Store inference frame references (by inference_id)
self.inference_frames: Dict[str, InferenceFrameReference] = {}
# Index by target timestamp for candle matching
self.candle_inferences: Dict[Tuple[str, str, datetime], List[str]] = {} # (symbol, timeframe, timestamp) -> [inference_ids]
# Index by pivot type for pivot matching
self.pivot_subscriptions: Dict[Tuple[str, str, str], List[str]] = {} # (symbol, timeframe, pivot_type) -> [inference_ids]
# Training subscribers
self.training_subscribers: List[TrainingEventSubscriber] = []
# Thread safety
self.lock = threading.RLock()
logger.info("InferenceTrainingCoordinator initialized")
def register_inference_frame(self, inference_ref: InferenceFrameReference) -> None:
"""
Register an inference frame reference (stored in DuckDB, not copied).
Args:
inference_ref: Reference to inference data
"""
with self.lock:
self.inference_frames[inference_ref.inference_id] = inference_ref
# Index by target timestamp for candle matching
if inference_ref.target_timestamp:
key = (inference_ref.symbol, inference_ref.timeframe, inference_ref.target_timestamp)
if key not in self.candle_inferences:
self.candle_inferences[key] = []
self.candle_inferences[key].append(inference_ref.inference_id)
logger.debug(f"Registered inference frame: {inference_ref.inference_id} for {inference_ref.symbol} {inference_ref.timeframe}")
def subscribe_to_candle_completion(self, subscriber: TrainingEventSubscriber,
symbol: str, timeframe: str) -> None:
"""
Subscribe to candle completion events for a symbol/timeframe.
Args:
subscriber: Training subscriber
symbol: Trading symbol
timeframe: Timeframe (1m, 5m, etc.)
"""
with self.lock:
if subscriber not in self.training_subscribers:
self.training_subscribers.append(subscriber)
# Register with data provider for candle completion callbacks
if hasattr(self.data_provider, 'subscribe_candle_completion'):
self.data_provider.subscribe_candle_completion(
callback=lambda event: self._handle_candle_completion(event),
symbol=symbol,
timeframe=timeframe
)
logger.info(f"Subscribed to candle completion: {symbol} {timeframe}")
def subscribe_to_pivot_events(self, subscriber: TrainingEventSubscriber,
symbol: str, timeframe: str,
pivot_types: List[str]) -> None:
"""
Subscribe to pivot events (L2L, L2H, etc.).
Args:
subscriber: Training subscriber
symbol: Trading symbol
timeframe: Timeframe
pivot_types: List of pivot types to subscribe to (e.g., ['L2L', 'L2H', 'L3L'])
"""
with self.lock:
if subscriber not in self.training_subscribers:
self.training_subscribers.append(subscriber)
# Register pivot subscriptions
for pivot_type in pivot_types:
key = (symbol, timeframe, pivot_type)
if key not in self.pivot_subscriptions:
self.pivot_subscriptions[key] = []
# Store subscriber reference (we'll match inference frames later)
# Register with data provider for pivot callbacks
if hasattr(self.data_provider, 'subscribe_pivot_events'):
self.data_provider.subscribe_pivot_events(
callback=lambda event: self._handle_pivot_event(event),
symbol=symbol,
timeframe=timeframe,
pivot_types=pivot_types
)
logger.info(f"Subscribed to pivot events: {symbol} {timeframe} {pivot_types}")
def _handle_pivot_event(self, event: PivotEvent) -> None:
"""Handle pivot event from data provider and trigger training"""
with self.lock:
# Find matching inference frames (predictions made before this pivot)
# Look for predictions within a reasonable window (e.g., last 5 minutes)
window_start = event.timestamp - timedelta(minutes=5)
matching_refs = []
for inference_ref in self.inference_frames.values():
if (inference_ref.symbol == event.symbol and
inference_ref.timeframe == event.timeframe and
inference_ref.prediction_timestamp >= window_start and
not inference_ref.trained):
matching_refs.append(inference_ref)
# Notify subscribers
for subscriber in self.training_subscribers:
try:
subscriber.on_pivot_event(event, matching_refs)
# Mark as trained
for ref in matching_refs:
ref.trained = True
ref.training_timestamp = datetime.now(timezone.utc)
except Exception as e:
logger.error(f"Error in pivot event callback: {e}", exc_info=True)
def _handle_candle_completion(self, event: CandleCompletionEvent) -> None:
"""Handle candle completion event and trigger training"""
with self.lock:
# Find matching inference frames
key = (event.symbol, event.timeframe, event.timestamp)
inference_ids = self.candle_inferences.get(key, [])
# Get inference references
inference_refs = [self.inference_frames[iid] for iid in inference_ids
if iid in self.inference_frames and not self.inference_frames[iid].trained]
# Notify subscribers
for subscriber in self.training_subscribers:
for inference_ref in inference_refs:
try:
subscriber.on_candle_completion(event, inference_ref)
# Mark as trained
inference_ref.trained = True
inference_ref.training_timestamp = datetime.now(timezone.utc)
except Exception as e:
logger.error(f"Error in candle completion callback: {e}", exc_info=True)
def get_inference_data(self, inference_ref: InferenceFrameReference) -> Optional[Dict]:
"""
Retrieve inference data from DuckDB using reference.
Data is retrieved on-demand when training is triggered - nothing is copied at
inference time. NOTE: the current implementation fetches the latest 600 candles
per timeframe via the data provider; the stored timestamp range
(data_range_start/end) is the key to use for exact reproducibility
(see the query sketch after this class).
Args:
inference_ref: Reference to inference frame
Returns:
Dict with model inputs (price_data_1m, price_data_1h, etc.) or None
"""
if not self.data_provider:
logger.warning("Data provider not available for inference data retrieval")
return None
try:
import torch
import numpy as np
# Query data provider for OHLCV data (it uses DuckDB internally)
# This is efficient - DuckDB handles the query
model_inputs = {}
# Use norm_params from reference if available, otherwise calculate
norm_params = inference_ref.norm_params.copy() if inference_ref.norm_params else {}
for tf in ['1s', '1m', '1h', '1d']:
# Get 600 candles - data_provider queries DuckDB efficiently
df = self.data_provider.get_historical_data(
symbol=inference_ref.symbol,
timeframe=tf,
limit=600
)
if df is not None and len(df) >= 600:
# Take last 600 candles
df = df.tail(600)
# Extract OHLCV arrays
opens = df['open'].values.astype(np.float32)
highs = df['high'].values.astype(np.float32)
lows = df['low'].values.astype(np.float32)
closes = df['close'].values.astype(np.float32)
volumes = df['volume'].values.astype(np.float32)
# Stack OHLCV [seq_len, 5]
ohlcv = np.stack([opens, highs, lows, closes, volumes], axis=-1)
# Calculate normalization params if not stored
if tf not in norm_params:
price_min = np.min(ohlcv[:, :4])
price_max = np.max(ohlcv[:, :4])
volume_min = np.min(ohlcv[:, 4])
volume_max = np.max(ohlcv[:, 4])
if price_max == price_min:
price_max += 1.0
if volume_max == volume_min:
volume_max += 1.0
norm_params[tf] = {
'price_min': float(price_min),
'price_max': float(price_max),
'volume_min': float(volume_min),
'volume_max': float(volume_max)
}
# Normalize using params
params = norm_params[tf]
price_min = params['price_min']
price_max = params['price_max']
vol_min = params['volume_min']
vol_max = params['volume_max']
ohlcv[:, :4] = (ohlcv[:, :4] - price_min) / (price_max - price_min)
ohlcv[:, 4] = (ohlcv[:, 4] - vol_min) / (vol_max - vol_min)
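# After scaling, prices and volume lie in [0, 1] relative to this 600-candle window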
# Convert to tensor [1, seq_len, 5]
candles_tensor = torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0)
model_inputs[f'price_data_{tf}'] = candles_tensor
# Store norm_params in reference for future use
inference_ref.norm_params = norm_params
# Add placeholder data for other inputs
device = next(iter(model_inputs.values())).device if model_inputs else torch.device('cpu')
model_inputs['tech_data'] = torch.zeros(1, 40, dtype=torch.float32, device=device)
model_inputs['market_data'] = torch.zeros(1, 30, dtype=torch.float32, device=device)
model_inputs['cob_data'] = torch.zeros(1, 600, 100, dtype=torch.float32, device=device)
return model_inputs
except Exception as e:
logger.error(f"Error retrieving inference data: {e}", exc_info=True)
return None
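# Sketch of the range-based DuckDB query this reference scheme is designed for.
# Assumptions (not confirmed by this commit): a raw `duckdb` connection and an
# `ohlcv` table keyed by symbol/timeframe/timestamp. The production path above
# goes through data_provider.get_historical_data() instead:
def query_inference_range(duckdb_conn, ref: InferenceFrameReference, timeframe: str):
    """Fetch exactly the candle window the model saw, using the stored timestamp range."""
    return duckdb_conn.execute(
        """
        SELECT timestamp, open, high, low, close, volume
        FROM ohlcv
        WHERE symbol = ? AND timeframe = ?
          AND timestamp BETWEEN ? AND ?
        ORDER BY timestamp
        """,
        [ref.symbol, timeframe, ref.data_range_start, ref.data_range_end],
    ).fetch_df()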


@@ -166,6 +166,16 @@ class RealTrainingAdapter:
import threading
self._training_lock = threading.Lock()
# Use orchestrator's inference training coordinator (if available)
# This reduces duplication and centralizes coordination logic
if orchestrator and hasattr(orchestrator, 'inference_training_coordinator'):
self.training_coordinator = orchestrator.inference_training_coordinator
if self.training_coordinator:
# Subscribe to training events
self._subscribe_to_training_events()
else:
self.training_coordinator = None
# Real-time training tracking
self.realtime_training_metrics = {
'total_steps': 0,
@@ -187,6 +197,279 @@ class RealTrainingAdapter:
logger.info("RealTrainingAdapter initialized - NO SIMULATION, REAL TRAINING ONLY")
# Implement TrainingEventSubscriber interface
def on_candle_completion(self, event, inference_ref):
"""
Called when a candle completes - train on stored inference frame with actual result.
This uses the reference-based system: inference data is retrieved from DuckDB
using the reference, not copied.
"""
if not inference_ref or not self.training_coordinator:
return
try:
# Retrieve inference data from DuckDB using reference
model_inputs = self.training_coordinator.get_inference_data(inference_ref)
if not model_inputs:
logger.warning(f"Could not retrieve inference data for {inference_ref.inference_id}")
return
# Create training batch with actual candle
batch = self._create_training_batch_from_inference(
model_inputs, event.ohlcv, inference_ref
)
if not batch:
return
# Train model (backprop for Transformer)
self._train_on_inference_batch(batch, inference_ref)
except Exception as e:
logger.error(f"Error in candle completion training: {e}", exc_info=True)
def on_pivot_event(self, event, inference_refs):
"""
Called when a pivot point is detected - train on matching inference frames.
This handles event-based training where we don't know when the pivot will occur.
"""
if not inference_refs or not self.training_coordinator:
return
try:
for inference_ref in inference_refs:
# Retrieve inference data
model_inputs = self.training_coordinator.get_inference_data(inference_ref)
if not model_inputs:
continue
# Create training batch with pivot result
batch = self._create_pivot_training_batch(model_inputs, event, inference_ref)
if not batch:
continue
# Train model
self._train_on_inference_batch(batch, inference_ref)
except Exception as e:
logger.error(f"Error in pivot event training: {e}", exc_info=True)
def _create_training_batch_from_inference(self, model_inputs: Dict, actual_ohlcv: Dict,
inference_ref) -> Optional[Dict]:
"""Create training batch from inference inputs and actual candle result"""
try:
import torch
# Copy model inputs
batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in model_inputs.items()}
# Get device
device = next(iter(batch.values())).device if batch else torch.device('cpu')
# Normalize actual candle using stored params
timeframe = inference_ref.timeframe
if timeframe in inference_ref.norm_params:
params = inference_ref.norm_params[timeframe]
price_min = params['price_min']
price_max = params['price_max']
vol_min = params['volume_min']
vol_max = params['volume_max']
# Normalize actual OHLCV
normalized_candle = [
(actual_ohlcv['open'] - price_min) / (price_max - price_min),
(actual_ohlcv['high'] - price_min) / (price_max - price_min),
(actual_ohlcv['low'] - price_min) / (price_max - price_min),
(actual_ohlcv['close'] - price_min) / (price_max - price_min),
(actual_ohlcv['volume'] - vol_min) / (vol_max - vol_min) if vol_max > vol_min else 0.0
]
# Add target candle to batch
target_key = f'future_candle_{timeframe}'
batch[target_key] = torch.tensor([normalized_candle], dtype=torch.float32, device=device)
# Add action target (determine from price movement)
price_change = (actual_ohlcv['close'] - actual_ohlcv['open']) / actual_ohlcv['open']
if price_change > 0.0005: # 0.05% up
action = 1 # BUY
elif price_change < -0.0005: # 0.05% down
action = 2 # SELL
else:
action = 0 # HOLD
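# Example: open=3000.0, close=3002.0 -> +0.067% change, above the 0.05% threshold -> action = 1 (BUY)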
batch['actions'] = torch.tensor([[action]], dtype=torch.long, device=device)
return batch
return None
except Exception as e:
logger.error(f"Error creating training batch from inference: {e}", exc_info=True)
return None
def _create_pivot_training_batch(self, model_inputs: Dict, pivot_event, inference_ref) -> Optional[Dict]:
"""Create training batch from inference inputs and pivot event"""
try:
import torch
# Copy model inputs
batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in model_inputs.items()}
# Get device
device = next(iter(batch.values())).device if batch else torch.device('cpu')
# Determine action from pivot type
# L2L, L3L, etc. -> BUY (support levels)
# L2H, L3H, etc. -> SELL (resistance levels)
if pivot_event.pivot_type.endswith('L'):
action = 1 # BUY
elif pivot_event.pivot_type.endswith('H'):
action = 2 # SELL
else:
action = 0 # HOLD
batch['actions'] = torch.tensor([[action]], dtype=torch.long, device=device)
# For pivot training, we don't have a target candle, so we use the pivot price
# as a reference point for training
# This is a simplified approach - could be enhanced with pivot-based targets
# (one possible shape is sketched after this method)
return batch
except Exception as e:
logger.error(f"Error creating pivot training batch: {e}", exc_info=True)
return None
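# One possible pivot-based target, as noted in the method above: normalize the pivot
# price with the stored params and attach it as a scalar target. Hypothetical helper --
# consuming it would require a matching price-target head/loss in the trainer:
def pivot_price_target(pivot_price: float, params: Dict[str, float]) -> float:
    """Map a raw pivot price into the [0, 1] range the model inputs were scaled to."""
    price_min, price_max = params['price_min'], params['price_max']
    if price_max == price_min:
        return 0.5  # degenerate window; avoid division by zero
    return (pivot_price - price_min) / (price_max - price_min)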
def _train_on_inference_batch(self, batch: Dict, inference_ref) -> None:
"""Train model on inference batch (uses stored inference frame)"""
try:
if not self.orchestrator:
return
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if not trainer:
return
# Train with lock protection
import torch
with self._training_lock:
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False)
if result:
loss = result.get('total_loss', 0)
accuracy = result.get('accuracy', 0)
# Update metrics
self.realtime_training_metrics['total_steps'] += 1
self.realtime_training_metrics['total_loss'] += loss
self.realtime_training_metrics['total_accuracy'] += accuracy
self.realtime_training_metrics['losses'].append(loss)
self.realtime_training_metrics['accuracies'].append(accuracy)
if len(self.realtime_training_metrics['losses']) > 100:
self.realtime_training_metrics['losses'].pop(0)
self.realtime_training_metrics['accuracies'].pop(0)
logger.info(f"Trained on inference frame {inference_ref.inference_id}: Loss={loss:.4f}, Acc={accuracy:.2%}")
except Exception as e:
logger.error(f"Error training on inference batch: {e}", exc_info=True)
def _register_inference_frame(self, session: Dict, symbol: str, timeframe: str,
prediction: Dict, data_provider, norm_params: Dict = None) -> None:
"""
Register inference frame reference with coordinator.
Stores reference (timestamp range) instead of copying 600 candles.
This method stores norm_params in the reference for efficient retrieval.
When training is triggered, data is retrieved from DuckDB using the reference.
Args:
session: Inference session
symbol: Trading symbol
timeframe: Timeframe
prediction: Prediction dict from model
data_provider: Data provider instance
norm_params: Normalization parameters (optional, will be calculated if not provided)
"""
if not self.training_coordinator:
return
try:
from ANNOTATE.core.inference_training_system import InferenceFrameReference
from datetime import datetime, timezone, timedelta
import uuid
# Get current time and calculate data range
current_time = datetime.now(timezone.utc)
data_range_end = current_time
# Calculate start time for 600 candles (approximate)
timeframe_seconds = {'1s': 1, '1m': 60, '5m': 300, '15m': 900, '1h': 3600, '1d': 86400}.get(timeframe, 60)
data_range_start = current_time - timedelta(seconds=600 * timeframe_seconds)
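# e.g. for '1m': 600 * 60s = 36,000s, so the window spans roughly the last 10 hours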
# Use provided norm_params or calculate if not available
if not norm_params:
norm_params = {}
# Calculate target timestamp (next candle close time)
# For 1m timeframe, next candle closes in 1 minute
target_timestamp = current_time + timedelta(seconds=timeframe_seconds)
# Create inference frame reference
inference_ref = InferenceFrameReference(
inference_id=str(uuid.uuid4()),
symbol=symbol,
timeframe=timeframe,
prediction_timestamp=current_time,
target_timestamp=target_timestamp,
data_range_start=data_range_start,
data_range_end=data_range_end,
norm_params=norm_params, # Stored for efficient retrieval
predicted_action=prediction.get('action'),
predicted_candle=prediction.get('predicted_candle'),
confidence=prediction.get('confidence', 0.0)
)
# Register with coordinator
self.training_coordinator.register_inference_frame(inference_ref)
logger.debug(f"Registered inference frame: {inference_ref.inference_id} for {symbol} {timeframe} (target: {target_timestamp})")
except Exception as e:
logger.warning(f"Could not register inference frame: {e}", exc_info=True)
def _subscribe_to_training_events(self):
"""Subscribe to training events via orchestrator's coordinator"""
if not self.training_coordinator:
return
try:
# Subscribe to candle completion for primary symbol/timeframe
primary_symbol = getattr(self.orchestrator, 'symbol', 'ETH/USDT')
primary_timeframe = '1m' # Default timeframe
self.training_coordinator.subscribe_to_candle_completion(
self, symbol=primary_symbol, timeframe=primary_timeframe
)
# Subscribe to pivot events (L2L, L2H, L3L, L3H)
self.training_coordinator.subscribe_to_pivot_events(
self, symbol=primary_symbol, timeframe=primary_timeframe,
pivot_types=['L2L', 'L2H', 'L3L', 'L3H']
)
logger.info(f"Subscribed to training events: {primary_symbol} {primary_timeframe}")
except Exception as e:
logger.warning(f"Could not subscribe to training events: {e}")
def _import_training_systems(self):
"""Import real training system implementations"""
try:
@@ -3056,7 +3339,10 @@ class RealTrainingAdapter:
'total_pnl': 0.0,
'win_count': 0,
'loss_count': 0,
'total_trades': 0
'total_trades': 0,
# Inference input cache: stores input data frames for later training
# Key: candle_timestamp (str), Value: {'model_inputs': Dict, 'norm_params': Dict, 'predicted_candle': Dict}
'inference_input_cache': {}
}
training_mode = "per-candle" if train_every_candle else ("pivot-based" if enable_live_training else "inference-only")
@@ -3128,8 +3414,177 @@ class RealTrainingAdapter:
all_signals.sort(key=lambda x: x.get('timestamp', ''), reverse=True)
return all_signals[:limit]
def _make_realtime_prediction(self, model_name: str, symbol: str, data_provider) -> Dict:
"""Make a prediction using the specified model"""
def _make_realtime_prediction_with_cache(self, model_name: str, symbol: str, data_provider, session: Dict) -> Tuple[Dict, bool]:
"""
DEPRECATED: Use _make_realtime_prediction + _register_inference_frame instead.
This method is kept for backward compatibility but should be removed.
"""
# _make_realtime_prediction now returns (prediction, market_info); discard market_info here
prediction, _ = self._make_realtime_prediction(model_name, symbol, data_provider)
return prediction, False
"""
Make a prediction and store input data frame for later training
Returns:
Tuple of (prediction_dict, stored_inputs: bool)
"""
try:
if model_name == 'Transformer' and self.orchestrator:
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if trainer and trainer.model:
# Get recent market data
market_data, norm_params = self._get_realtime_market_data(symbol, data_provider)
if not market_data:
return None, False
# Get current candle timestamp for cache key
timeframe = session.get('timeframe', '1m')
df_current = data_provider.get_historical_data(symbol, timeframe, limit=1)
if df_current is not None and len(df_current) > 0:
current_timestamp = str(df_current.index[-1])
# Store input data frame for later training (convert tensors to CPU for storage)
import torch
cached_inputs = {
'model_inputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
for k, v in market_data.items()},
'norm_params': norm_params,
'timestamp': current_timestamp,
'symbol': symbol,
'timeframe': timeframe
}
# Store in session cache (keep last 50 to prevent memory bloat)
cache = session.get('inference_input_cache', {})
cache[current_timestamp] = cached_inputs
# Keep only last 50 entries
if len(cache) > 50:
# Remove oldest entries
sorted_keys = sorted(cache.keys())
for key in sorted_keys[:-50]:
del cache[key]
session['inference_input_cache'] = cache
logger.debug(f"Stored inference inputs for {symbol} {timeframe} @ {current_timestamp}")
# Make prediction
import torch
with torch.no_grad():
trainer.model.eval()
outputs = trainer.model(**market_data)
# Extract action
action_probs = outputs.get('action_probs')
if action_probs is not None:
# Handle different tensor shapes: [batch, 3] or [3]
if action_probs.dim() == 1:
# Shape [3] - single prediction
action_idx = torch.argmax(action_probs, dim=0).item()
confidence = action_probs[action_idx].item()
else:
# Shape [batch, 3] - take first batch item
action_idx = torch.argmax(action_probs[0], dim=0).item()
confidence = action_probs[0, action_idx].item()
# Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
actions = ['HOLD', 'BUY', 'SELL']
action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
# Handle predicted candles - DENORMALIZE them
predicted_candles_raw = {}
if 'next_candles' in outputs:
for tf, tensor in outputs['next_candles'].items():
predicted_candles_raw[tf] = tensor.detach().cpu().numpy().tolist()
# Denormalize if we have params
predicted_candles_denorm = {}
if predicted_candles_raw and norm_params:
for tf, raw_candle in predicted_candles_raw.items():
# raw_candle is [1, 5] list
if tf in norm_params:
params = norm_params[tf]
price_min = params['price_min']
price_max = params['price_max']
vol_min = params['volume_min']
vol_max = params['volume_max']
# Denormalize [Open, High, Low, Close, Volume]
# Note: raw_candle[0] is the list of 5 values
candle_values = raw_candle[0]
# Ensure all values are Python floats (not numpy scalars or tensors)
def to_float(v):
if hasattr(v, 'item'):
return float(v.item())
return float(v)
denorm_candle = [
to_float(candle_values[0] * (price_max - price_min) + price_min), # Open
to_float(candle_values[1] * (price_max - price_min) + price_min), # High
to_float(candle_values[2] * (price_max - price_min) + price_min), # Low
to_float(candle_values[3] * (price_max - price_min) + price_min), # Close
to_float(candle_values[4] * (vol_max - vol_min) + vol_min) # Volume
]
predicted_candles_denorm[tf] = denorm_candle
# Calculate predicted price from candle close (ensure Python float)
predicted_price = None
if '1m' in predicted_candles_denorm:
close_val = predicted_candles_denorm['1m'][3]
predicted_price = float(close_val.item() if hasattr(close_val, 'item') else close_val)
elif '1s' in predicted_candles_denorm:
close_val = predicted_candles_denorm['1s'][3]
predicted_price = float(close_val.item() if hasattr(close_val, 'item') else close_val)
elif outputs.get('price_prediction') is not None:
# Fallback to price_prediction head if available (normalized)
# This would need separate denormalization based on reference price
pass
result_dict = {
'action': action,
'confidence': confidence,
'predicted_price': predicted_price,
'predicted_candle': predicted_candles_denorm
}
# Include trend vector if available
if 'trend_vector' in outputs:
result_dict['trend_vector'] = outputs['trend_vector']
# DEBUG: Log if we have predicted candles
if predicted_candles_denorm:
logger.info(f"Generated prediction with {len(predicted_candles_denorm)} timeframe candles: {list(predicted_candles_denorm.keys())}")
else:
logger.warning("No predicted candles in model output!")
return result_dict, True
return None, False
except Exception as e:
logger.debug(f"Error making realtime prediction: {e}")
import traceback
logger.debug(traceback.format_exc())
return None, False
def _make_realtime_prediction(self, model_name: str, symbol: str, data_provider) -> Tuple[Dict, Dict]:
"""
Make a prediction and return both prediction and market data for reference storage.
Returns:
Tuple of (prediction_dict, market_data_dict with norm_params)
"""
# Get market data (needed for reference storage)
market_data, norm_params = self._get_realtime_market_data(symbol, data_provider)
if not market_data:
return None, None
# Make prediction (original logic)
prediction = self._make_realtime_prediction_internal(model_name, symbol, data_provider, market_data, norm_params)
return prediction, {'market_data': market_data, 'norm_params': norm_params}
def _make_realtime_prediction_internal(self, model_name: str, symbol: str, data_provider,
market_data: Dict, norm_params: Dict) -> Dict:
"""Make a prediction using the specified model (backward compatibility)"""
try:
if model_name == 'Transformer' and self.orchestrator:
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
@@ -3223,12 +3678,6 @@ class RealTrainingAdapter:
if 'trend_vector' in outputs:
result_dict['trend_vector'] = outputs['trend_vector']
# DEBUG: Log if we have predicted candles
if predicted_candles_denorm:
logger.info(f"🔮 Generated prediction with {len(predicted_candles_denorm)} timeframe candles: {list(predicted_candles_denorm.keys())}")
else:
logger.warning("⚠️ No predicted candles in model output!")
return result_dict
return None
@@ -3370,13 +3819,120 @@ class RealTrainingAdapter:
# Get the completed candle (second to last) and next candle
completed_candle = df.iloc[-2]
next_candle = df.iloc[-1]
completed_timestamp = str(completed_candle.name)
# Get action from session (set by app's training strategy)
action_label = session.get('pending_action')
if not action_label:
return {'success': False, 'error': 'No pending_action in session'}
# Fetch market state for training
# CRITICAL: Try to use stored inference input data frame if available
# This ensures we train on exactly what the model saw during inference
cache = session.get('inference_input_cache', {})
stored_inputs = cache.get(completed_timestamp)
if stored_inputs:
# Use stored input data frame from inference
logger.info(f"Using stored inference inputs for training on {symbol} {timeframe} @ {completed_timestamp}")
# Get actual candle data for target
actual_candle = [
float(next_candle['open']),
float(next_candle['high']),
float(next_candle['low']),
float(next_candle['close']),
float(next_candle['volume'])
]
# Create training batch from stored inputs
import torch
# Get device from orchestrator
device = getattr(self.orchestrator, 'device', torch.device('cpu'))
if hasattr(self.orchestrator, 'primary_transformer_trainer') and self.orchestrator.primary_transformer_trainer:
# nn.Module has no .device attribute; derive the device from the model parameters
device = next(self.orchestrator.primary_transformer_trainer.model.parameters()).device
# Move stored inputs back to device (they were stored on CPU)
batch = {}
for k, v in stored_inputs['model_inputs'].items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(device)
else:
batch[k] = v
# Add actual candle as target (normalize using stored params)
norm_params = stored_inputs['norm_params']
if timeframe in norm_params:
params = norm_params[timeframe]
price_min = params['price_min']
price_max = params['price_max']
vol_min = params['volume_min']
vol_max = params['volume_max']
# Normalize actual candle
normalized_candle = [
(actual_candle[0] - price_min) / (price_max - price_min), # Open
(actual_candle[1] - price_min) / (price_max - price_min), # High
(actual_candle[2] - price_min) / (price_max - price_min), # Low
(actual_candle[3] - price_min) / (price_max - price_min), # Close
(actual_candle[4] - vol_min) / (vol_max - vol_min) if vol_max > vol_min else 0.0 # Volume
]
# Add target candle to batch
target_key = f'future_candle_{timeframe}'
batch[target_key] = torch.tensor([normalized_candle], dtype=torch.float32, device=device)
# Add action target
action_map = {'HOLD': 0, 'BUY': 1, 'SELL': 2}
batch['actions'] = torch.tensor([[action_map.get(action_label, 0)]], dtype=torch.long, device=device)
# Train directly on batch
model_name = session['model_name']
if model_name == 'Transformer':
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if trainer:
with self._training_lock:
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False)
if result:
loss = result.get('total_loss', 0)
accuracy = result.get('accuracy', 0)
# Update metrics
self.realtime_training_metrics['total_steps'] += 1
self.realtime_training_metrics['total_loss'] += loss
self.realtime_training_metrics['total_accuracy'] += accuracy
self.realtime_training_metrics['losses'].append(loss)
self.realtime_training_metrics['accuracies'].append(accuracy)
if len(self.realtime_training_metrics['losses']) > 100:
self.realtime_training_metrics['losses'].pop(0)
self.realtime_training_metrics['accuracies'].pop(0)
session['metrics']['loss'] = sum(self.realtime_training_metrics['losses']) / len(self.realtime_training_metrics['losses'])
session['metrics']['accuracy'] = sum(self.realtime_training_metrics['accuracies']) / len(self.realtime_training_metrics['accuracies'])
session['metrics']['steps'] = self.realtime_training_metrics['total_steps']
# Remove from cache after training
if completed_timestamp in cache:
del cache[completed_timestamp]
logger.info(f"Trained on stored inference inputs: {symbol} {timeframe} @ {completed_timestamp} action={action_label} (Loss: {loss:.4f}, Acc: {accuracy:.2%})")
return {
'success': True,
'loss': session['metrics']['loss'],
'accuracy': session['metrics']['accuracy'],
'training_steps': session['metrics']['steps'],
'used_stored_inputs': True
}
# Fall through to regular training if stored inputs failed
logger.warning("Failed to use stored inputs; falling back to fresh data")
# Fallback: Fetch fresh market state for training (original behavior)
market_state = self._fetch_market_state_for_candle(symbol, completed_candle.name, data_provider)
# Calculate price change
@@ -3411,7 +3967,8 @@ class RealTrainingAdapter:
'success': True,
'loss': session['metrics']['loss'],
'accuracy': session['metrics']['accuracy'],
'training_steps': session['metrics']['steps']
'training_steps': session['metrics']['steps'],
'used_stored_inputs': False
}
return {'success': False, 'error': f'Unsupported model: {model_name}'}
@@ -3939,6 +4496,14 @@ class RealTrainingAdapter:
# Make prediction using the model (also returns the market data used, including norm_params)
prediction, market_info = self._make_realtime_prediction(model_name, symbol, data_provider)
# Register inference frame reference for later training when actual candle arrives
# This stores a reference (timestamp range) instead of copying 600 candles
# The reference allows us to retrieve the exact data from DuckDB when training
if prediction and self.training_coordinator:
# Reuse the norm_params captured at inference time instead of refetching market data
norm_params = market_info['norm_params'] if market_info else {}
self._register_inference_frame(session, symbol, timeframe, prediction, data_provider, norm_params)
if prediction:
# Store signal
signal = {