"""
Event-Driven Inference Training System

This system provides:
1. Reference-based inference frame storage (no 600-candle copies)
2. Subscription system for candle completion and pivot events
3. Flexible training methods (backprop for Transformer, others for different models)
4. Integration with DuckDB for efficient data retrieval

Architecture:
- Inference frames stored as references (timestamp ranges) in DuckDB
- Training adapter subscribes to data provider events
- Time-based triggers: candle completion (known result time)
- Event-based triggers: pivot points (L2L, L2H, etc. - unknown timing)
"""

import logging
import threading
from datetime import datetime, timezone, timedelta
from typing import Dict, List, Optional, Callable, Tuple, Any, Set
from dataclasses import dataclass, field
from enum import Enum
import uuid

logger = logging.getLogger(__name__)


class TrainingTriggerType(Enum):
    """Types of training triggers"""
    CANDLE_COMPLETION = "candle_completion"  # Time-based: next candle closes
    PIVOT_EVENT = "pivot_event"  # Event-based: pivot detected (L2L, L2H, etc.)


@dataclass
class InferenceFrameReference:
    """
    Reference to inference data stored in DuckDB.

    No copying - just store timestamp ranges and query when needed.
    """
    inference_id: str  # Unique ID for this inference
    symbol: str
    timeframe: str
    prediction_timestamp: datetime  # When prediction was made

    # Reference to data in DuckDB (timestamp range)
    data_range_start: datetime  # Start of 600-candle window
    data_range_end: datetime  # End of 600-candle window

    # When result will be available (for candles). Declared after the required
    # fields: a dataclass may not place a non-default field after a defaulted
    # one (the previous ordering raised TypeError when the class was created).
    target_timestamp: Optional[datetime] = None

    # Normalization parameters (small, can be stored); keyed by timeframe
    norm_params: Dict[str, Dict[str, float]] = field(default_factory=dict)

    # Prediction metadata
    predicted_action: Optional[str] = None
    predicted_candle: Optional[Dict[str, List[float]]] = None
    confidence: float = 0.0

    # Training status
    trained: bool = False
    training_timestamp: Optional[datetime] = None


@dataclass
class PivotEvent:
    """Pivot point event for training"""
    symbol: str
    timeframe: str
    timestamp: datetime
    pivot_type: str  # 'L2L', 'L2H', 'L3L', 'L3H', etc.
    price: float
    level: int  # Pivot level (2, 3, 4, etc.)
    strength: float


@dataclass
class CandleCompletionEvent:
    """Candle completion event for training"""
    symbol: str
    timeframe: str
    timestamp: datetime  # When candle closed
    ohlcv: Dict[str, float]  # {'open', 'high', 'low', 'close', 'volume'}


class TrainingEventSubscriber:
    """
    Subscriber interface for training events.

    Training adapters implement this to receive callbacks. Both methods raise
    NotImplementedError by default; override the ones you subscribe to.
    """

    def on_candle_completion(self, event: CandleCompletionEvent,
                             inference_ref: Optional[InferenceFrameReference]) -> None:
        """
        Called when a candle completes.

        Args:
            event: Candle completion event with actual OHLCV
            inference_ref: Reference to inference frame if available (for this candle)
        """
        raise NotImplementedError

    def on_pivot_event(self, event: PivotEvent,
                       inference_refs: List[InferenceFrameReference]) -> None:
        """
        Called when a pivot point is detected.

        Args:
            event: Pivot event (L2L, L2H, etc.)
            inference_refs: List of inference frames that predicted this pivot
        """
        raise NotImplementedError


class InferenceTrainingCoordinator:
    """
    Coordinates inference frame storage and training event distribution.

    NOTE: This should be integrated into TradingOrchestrator to reduce duplication.
    The orchestrator already manages models, training, and predictions, so it's
    the natural place for inference-training coordination.

    Responsibilities:
    1. Store inference frame references (not copies)
    2. Register training subscriptions (candle/pivot events)
    3. Match inference frames to actual results
    4. Trigger training callbacks

    NOTE(review): subscriber callbacks run while ``self.lock`` is held, so a
    slow training callback blocks ``register_inference_frame`` from other
    threads. Registered frames are also never removed, so
    ``inference_frames`` grows for the lifetime of the coordinator.
    """

    def __init__(self, data_provider, duckdb_storage=None,
                 pivot_match_window: timedelta = timedelta(minutes=5)):
        """
        Initialize coordinator

        Args:
            data_provider: DataProvider instance for event subscriptions
            duckdb_storage: DuckDBStorage instance for data retrieval
            pivot_match_window: How long before a pivot's timestamp a prediction
                may have been made and still be matched to that pivot
                (default: 5 minutes, the previously hard-coded value).
        """
        self.data_provider = data_provider
        self.duckdb_storage = duckdb_storage
        self.pivot_match_window = pivot_match_window

        # Store inference frame references (by inference_id)
        self.inference_frames: Dict[str, InferenceFrameReference] = {}

        # Index by target timestamp for candle matching
        # (symbol, timeframe, timestamp) -> [inference_ids]
        self.candle_inferences: Dict[Tuple[str, str, datetime], List[str]] = {}

        # Registered pivot streams: (symbol, timeframe, pivot_type) -> [inference_ids]
        # (the id lists are currently never populated; the keys record which
        # pivot types have already been registered with the data provider)
        self.pivot_subscriptions: Dict[Tuple[str, str, str], List[str]] = {}

        # (symbol, timeframe) candle streams already registered with the data
        # provider. Registering a stream twice would invoke the handler twice per
        # candle, and each invocation notifies every subscriber.
        self._candle_provider_keys: Set[Tuple[str, str]] = set()

        # Training subscribers
        self.training_subscribers: List[TrainingEventSubscriber] = []

        # Thread safety
        self.lock = threading.RLock()

        logger.info("InferenceTrainingCoordinator initialized")

    def register_inference_frame(self, inference_ref: InferenceFrameReference) -> None:
        """
        Register an inference frame reference (stored in DuckDB, not copied).

        Frames with a ``target_timestamp`` are also indexed under
        (symbol, timeframe, target_timestamp) for candle-completion matching.
        Re-registering the same ``inference_id`` replaces the stored reference
        and does not add a duplicate index entry.

        Args:
            inference_ref: Reference to inference data
        """
        with self.lock:
            self.inference_frames[inference_ref.inference_id] = inference_ref

            # Index by target timestamp for candle matching
            if inference_ref.target_timestamp is not None:
                key = (inference_ref.symbol, inference_ref.timeframe, inference_ref.target_timestamp)
                ids = self.candle_inferences.setdefault(key, [])
                # A duplicate id would make the candle handler train on the
                # same frame once per duplicate entry.
                if inference_ref.inference_id not in ids:
                    ids.append(inference_ref.inference_id)

        logger.debug(f"Registered inference frame: {inference_ref.inference_id} for {inference_ref.symbol} {inference_ref.timeframe}")

    def subscribe_to_candle_completion(self, subscriber: TrainingEventSubscriber,
                                       symbol: str, timeframe: str) -> None:
        """
        Subscribe to candle completion events for a symbol/timeframe.

        The subscriber is added once. The data-provider callback for a given
        (symbol, timeframe) is registered only on the first subscription to that
        stream, provided the data provider exposes ``subscribe_candle_completion``.

        Args:
            subscriber: Training subscriber
            symbol: Trading symbol
            timeframe: Timeframe (1m, 5m, etc.)
        """
        with self.lock:
            if subscriber not in self.training_subscribers:
                self.training_subscribers.append(subscriber)

            # Register with data provider for candle completion callbacks
            key = (symbol, timeframe)
            if (key not in self._candle_provider_keys
                    and hasattr(self.data_provider, 'subscribe_candle_completion')):
                self.data_provider.subscribe_candle_completion(
                    callback=self._handle_candle_completion,
                    symbol=symbol,
                    timeframe=timeframe
                )
                # Recorded only after a successful registration so a failed
                # attempt can be retried.
                self._candle_provider_keys.add(key)

        logger.info(f"Subscribed to candle completion: {symbol} {timeframe}")

    def subscribe_to_pivot_events(self, subscriber: TrainingEventSubscriber,
                                  symbol: str, timeframe: str,
                                  pivot_types: List[str]) -> None:
        """
        Subscribe to pivot events (L2L, L2H, etc.).

        The subscriber is added once. Only pivot types not previously registered
        for this symbol/timeframe are passed to the data provider's
        ``subscribe_pivot_events`` (when it exists), so each pivot is delivered
        to the handler once.

        Args:
            subscriber: Training subscriber
            symbol: Trading symbol
            timeframe: Timeframe
            pivot_types: List of pivot types to subscribe to (e.g., ['L2L', 'L2H', 'L3L'])
        """
        with self.lock:
            if subscriber not in self.training_subscribers:
                self.training_subscribers.append(subscriber)

            # dict.fromkeys de-duplicates while keeping the caller's order.
            new_types = [
                pivot_type for pivot_type in dict.fromkeys(pivot_types)
                if (symbol, timeframe, pivot_type) not in self.pivot_subscriptions
            ]

            # Register with data provider for pivot callbacks
            if new_types and hasattr(self.data_provider, 'subscribe_pivot_events'):
                self.data_provider.subscribe_pivot_events(
                    callback=self._handle_pivot_event,
                    symbol=symbol,
                    timeframe=timeframe,
                    pivot_types=new_types
                )

            # Register pivot subscriptions (inference frames are matched later)
            for pivot_type in pivot_types:
                self.pivot_subscriptions.setdefault((symbol, timeframe, pivot_type), [])

        logger.info(f"Subscribed to pivot events: {symbol} {timeframe} {pivot_types}")

    def _handle_pivot_event(self, event: PivotEvent) -> None:
        """
        Handle pivot event from data provider and trigger training.

        Matches every untrained frame for the event's symbol/timeframe whose
        prediction_timestamp falls in
        [event.timestamp - pivot_match_window, event.timestamp]. Each subscriber
        is called with that list, which may be empty. After a subscriber returns
        without raising, the matched frames are marked trained. Exceptions from a
        subscriber are logged and do not stop the remaining subscribers.
        """
        with self.lock:
            # Find matching inference frames: predictions made before this
            # pivot, within the configured window.
            window_start = event.timestamp - self.pivot_match_window
            matching_refs = [
                ref for ref in self.inference_frames.values()
                if ref.symbol == event.symbol
                and ref.timeframe == event.timeframe
                and window_start <= ref.prediction_timestamp <= event.timestamp
                and not ref.trained
            ]

            # Notify subscribers
            for subscriber in self.training_subscribers:
                try:
                    subscriber.on_pivot_event(event, matching_refs)
                except Exception as e:
                    logger.error(f"Error in pivot event callback: {e}", exc_info=True)
                    continue

                # Mark as trained
                trained_at = datetime.now(timezone.utc)
                for ref in matching_refs:
                    ref.trained = True
                    ref.training_timestamp = trained_at

    def _handle_candle_completion(self, event: CandleCompletionEvent) -> None:
        """
        Handle candle completion event and trigger training.

        Looks up frames indexed under (event.symbol, event.timeframe,
        event.timestamp), which requires an exact timestamp match. Every
        subscriber is called once per frame that was untrained at the time of
        the lookup. A frame is marked trained after a successful callback.
        Subscribers are not called when no frame matches.
        """
        with self.lock:
            # Find matching inference frames
            key = (event.symbol, event.timeframe, event.timestamp)
            inference_refs = [
                self.inference_frames[iid]
                for iid in self.candle_inferences.get(key, [])
                if iid in self.inference_frames and not self.inference_frames[iid].trained
            ]

            # Notify subscribers
            for subscriber in self.training_subscribers:
                for inference_ref in inference_refs:
                    try:
                        subscriber.on_candle_completion(event, inference_ref)
                    except Exception as e:
                        logger.error(f"Error in candle completion callback: {e}", exc_info=True)
                        continue

                    # Mark as trained
                    inference_ref.trained = True
                    inference_ref.training_timestamp = datetime.now(timezone.utc)

    def get_inference_data(self, inference_ref: InferenceFrameReference) -> Optional[Dict]:
        """
        Build model inputs for an inference frame from the data provider.

        For each timeframe in ('1s', '1m', '1h', '1d') this requests 600 candles
        via ``data_provider.get_historical_data`` (original comments state it
        queries DuckDB internally). It min-max normalizes price columns and the
        volume column separately, and stores the result as a [1, 600, 5]
        float32 tensor under ``price_data_<tf>``. Timeframes for which fewer
        than 600 candles come back are skipped.

        Normalization uses ``inference_ref.norm_params[tf]`` when present.
        Otherwise parameters are computed from the fetched window and written
        back to ``inference_ref.norm_params``.

        NOTE(review): the reference's ``data_range_start``/``data_range_end``
        are NOT applied here. The provider call only passes symbol, timeframe
        and ``limit=600``, so it presumably returns the most recent candles at
        call time, not the window seen at inference time. Confirm whether
        ``get_historical_data`` supports an end-time bound and use
        ``inference_ref.data_range_end`` if so.

        Args:
            inference_ref: Reference to inference frame

        Returns:
            Dict of model inputs: the ``price_data_<tf>`` tensors plus
            zero-filled placeholders 'tech_data' [1, 40], 'market_data' [1, 30]
            and 'cob_data' [1, 600, 100]. Returns None if the data provider is
            missing or any error occurs (the error is logged).
        """
        if not self.data_provider:
            logger.warning("Data provider not available for inference data retrieval")
            return None

        try:
            import torch
            import numpy as np

            model_inputs = {}

            # Use norm_params from reference if available, otherwise calculate
            norm_params = inference_ref.norm_params.copy() if inference_ref.norm_params else {}

            for tf in ['1s', '1m', '1h', '1d']:
                df = self.data_provider.get_historical_data(
                    symbol=inference_ref.symbol,
                    timeframe=tf,
                    limit=600
                )

                if df is None or len(df) < 600:
                    logger.debug(f"Skipping {tf} for {inference_ref.inference_id}: fewer than 600 candles available")
                    continue

                # Take last 600 candles and stack OHLCV -> [seq_len, 5]
                df = df.tail(600)
                ohlcv = np.stack(
                    [df[col].values.astype(np.float32)
                     for col in ('open', 'high', 'low', 'close', 'volume')],
                    axis=-1
                )

                # Calculate normalization params if not stored
                if tf not in norm_params:
                    price_min = np.min(ohlcv[:, :4])
                    price_max = np.max(ohlcv[:, :4])
                    volume_min = np.min(ohlcv[:, 4])
                    volume_max = np.max(ohlcv[:, 4])

                    if price_max == price_min:
                        price_max += 1.0
                    if volume_max == volume_min:
                        volume_max += 1.0

                    norm_params[tf] = {
                        'price_min': float(price_min),
                        'price_max': float(price_max),
                        'volume_min': float(volume_min),
                        'volume_max': float(volume_max)
                    }

                # Normalize using params. Stored params may carry a zero range;
                # fall back to 1.0 (same as the computed path) instead of
                # dividing by zero.
                params = norm_params[tf]
                price_range = (params['price_max'] - params['price_min']) or 1.0
                vol_range = (params['volume_max'] - params['volume_min']) or 1.0
                ohlcv[:, :4] = (ohlcv[:, :4] - params['price_min']) / price_range
                ohlcv[:, 4] = (ohlcv[:, 4] - params['volume_min']) / vol_range

                # Convert to tensor [1, seq_len, 5]
                model_inputs[f'price_data_{tf}'] = torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0)

            # Store norm_params in reference for future use
            inference_ref.norm_params = norm_params

            # Add placeholder data for other inputs
            device = next(iter(model_inputs.values())).device if model_inputs else torch.device('cpu')
            model_inputs['tech_data'] = torch.zeros(1, 40, dtype=torch.float32, device=device)
            model_inputs['market_data'] = torch.zeros(1, 30, dtype=torch.float32, device=device)
            model_inputs['cob_data'] = torch.zeros(1, 600, 100, dtype=torch.float32, device=device)

            return model_inputs

        except Exception as e:
            logger.error(f"Error retrieving inference data: {e}", exc_info=True)
            return None