377 lines
16 KiB
Python
377 lines
16 KiB
Python
"""
|
|
Event-Driven Inference Training System
|
|
|
|
This system provides:
|
|
1. Reference-based inference frame storage (no 600-candle copies)
|
|
2. Subscription system for candle completion and pivot events
|
|
3. Flexible training methods (backprop for Transformer, others for different models)
|
|
4. Integration with DuckDB for efficient data retrieval
|
|
|
|
Architecture:
|
|
- Inference frames stored as references (timestamp ranges) in DuckDB
|
|
- Training adapter subscribes to data provider events
|
|
- Time-based triggers: candle completion (known result time)
|
|
- Event-based triggers: pivot points (L2L, L2H, etc. - unknown timing)
|
|
"""
|
|
|
|
import logging
|
|
import threading
|
|
from datetime import datetime, timezone, timedelta
|
|
from typing import Dict, List, Optional, Callable, Tuple, Any
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
import uuid
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TrainingTriggerType(Enum):
    """Enumerates the kinds of events that can trigger training."""

    # Time-based trigger: the next candle closes.
    CANDLE_COMPLETION = "candle_completion"

    # Event-based trigger: a pivot (L2L, L2H, etc.) is detected.
    PIVOT_EVENT = "pivot_event"
|
|
|
|
|
|
@dataclass
class InferenceFrameReference:
    """
    Reference to inference data stored in DuckDB.

    No copying - only the timestamp range of the input window is kept, and the
    candles themselves are queried when needed.

    Field order is part of the positional constructor signature and is kept
    as declared. Because ``target_timestamp`` has a default, every field after
    it must also have one (a dataclass raises ``TypeError`` at class creation
    otherwise), so the data range fields default to ``None``.
    """
    inference_id: str  # Unique ID for this inference
    symbol: str
    timeframe: str
    prediction_timestamp: datetime  # When prediction was made
    target_timestamp: Optional[datetime] = None  # When result will be available (for candles)

    # Reference to data in DuckDB (timestamp range)
    data_range_start: Optional[datetime] = None  # Start of 600-candle window
    data_range_end: Optional[datetime] = None  # End of 600-candle window

    # Normalization parameters (small, can be stored)
    norm_params: Dict[str, Dict[str, float]] = field(default_factory=dict)

    # Prediction metadata
    predicted_action: Optional[str] = None
    predicted_candle: Optional[Dict[str, List[float]]] = None
    confidence: float = 0.0

    # Training status
    trained: bool = False
    training_timestamp: Optional[datetime] = None
|
|
|
|
|
|
@dataclass
class PivotEvent:
    """Describes a detected pivot point that can be used as a training signal."""
    symbol: str
    timeframe: str
    timestamp: datetime
    # Pivot label such as 'L2L', 'L2H', 'L3L', 'L3H'.
    pivot_type: str
    price: float
    # Pivot level (2, 3, 4, etc.).
    level: int
    strength: float
|
|
|
|
|
|
@dataclass
class CandleCompletionEvent:
    """Describes a candle that has closed, for use as a training signal."""
    symbol: str
    timeframe: str
    # Time at which the candle closed.
    timestamp: datetime
    # Candle values keyed by 'open', 'high', 'low', 'close', 'volume'.
    ohlcv: Dict[str, float]
|
|
|
|
|
|
class TrainingEventSubscriber:
    """
    Callback interface for training events.

    Training adapters subclass this and override both hooks; the base
    implementations only raise ``NotImplementedError``.
    """

    def on_candle_completion(
        self,
        event: CandleCompletionEvent,
        inference_ref: Optional[InferenceFrameReference],
    ) -> None:
        """
        Hook invoked when a candle completes.

        Args:
            event: Candle completion event carrying the actual OHLCV
            inference_ref: Inference frame reference for this candle, if one is available

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    def on_pivot_event(
        self,
        event: PivotEvent,
        inference_refs: List[InferenceFrameReference],
    ) -> None:
        """
        Hook invoked when a pivot point is detected.

        Args:
            event: Pivot event (L2L, L2H, etc.)
            inference_refs: Inference frames that predicted this pivot

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError
|
|
|
|
|
|
class InferenceTrainingCoordinator:
    """
    Coordinates inference frame storage and training event distribution.

    NOTE: This should be integrated into TradingOrchestrator to reduce duplication.
    The orchestrator already manages models, training, and predictions, so it's the
    natural place for inference-training coordination.

    Responsibilities:
    1. Store inference frame references (not copies)
    2. Register training subscriptions (candle/pivot events)
    3. Match inference frames to actual results
    4. Trigger training callbacks

    NOTE: Every registered subscriber is notified of every handled event. The
    symbol/timeframe given when subscribing only selects which data-provider
    feeds are opened. Subscriber callbacks run while the coordinator lock is held.

    NOTE(review): ``inference_frames`` is never pruned here, so it grows with
    every registered frame - confirm whether an eviction policy is needed.
    """

    # Timeframes assembled into model inputs by get_inference_data().
    INPUT_TIMEFRAMES = ('1s', '1m', '1h', '1d')
    # Number of candles required (and used) per timeframe.
    WINDOW_SIZE = 600
    # How long before a pivot a prediction may have been made and still match it.
    PIVOT_MATCH_WINDOW = timedelta(minutes=5)

    def __init__(self, data_provider, duckdb_storage=None):
        """
        Initialize coordinator

        Args:
            data_provider: DataProvider instance for event subscriptions
            duckdb_storage: DuckDBStorage instance for data retrieval
        """
        self.data_provider = data_provider
        self.duckdb_storage = duckdb_storage

        # Store inference frame references (by inference_id)
        self.inference_frames: Dict[str, InferenceFrameReference] = {}

        # Index by target timestamp for candle matching
        # (symbol, timeframe, timestamp) -> [inference_ids]
        self.candle_inferences: Dict[Tuple[str, str, datetime], List[str]] = {}

        # Index by pivot type for pivot matching
        # (symbol, timeframe, pivot_type) -> [inference_ids]
        self.pivot_subscriptions: Dict[Tuple[str, str, str], List[str]] = {}

        # Training subscribers
        self.training_subscribers: List[TrainingEventSubscriber] = []

        # Feeds already registered with the data provider. The handlers fan
        # out to all subscribers, so each feed must be registered only once;
        # otherwise every event would be handled once per subscription.
        self._candle_feeds = set()  # {(symbol, timeframe)}
        self._pivot_feeds = set()  # {(symbol, timeframe, pivot_type)}

        # Thread safety
        self.lock = threading.RLock()

        logger.info("InferenceTrainingCoordinator initialized")

    def register_inference_frame(self, inference_ref: "InferenceFrameReference") -> None:
        """
        Register an inference frame reference (stored in DuckDB, not copied).

        Registering the same ``inference_id`` again replaces the stored
        reference without duplicating its entry in the candle index.

        Args:
            inference_ref: Reference to inference data
        """
        with self.lock:
            self.inference_frames[inference_ref.inference_id] = inference_ref

            # Index by target timestamp for candle matching
            if inference_ref.target_timestamp is not None:
                key = (inference_ref.symbol, inference_ref.timeframe, inference_ref.target_timestamp)
                indexed_ids = self.candle_inferences.setdefault(key, [])
                if inference_ref.inference_id not in indexed_ids:
                    indexed_ids.append(inference_ref.inference_id)

        logger.debug(
            "Registered inference frame: %s for %s %s",
            inference_ref.inference_id, inference_ref.symbol, inference_ref.timeframe,
        )

    def subscribe_to_candle_completion(self, subscriber: "TrainingEventSubscriber",
                                       symbol: str, timeframe: str) -> None:
        """
        Subscribe to candle completion events for a symbol/timeframe.

        The data provider is asked for a given (symbol, timeframe) feed only
        the first time it is requested; later calls just add the subscriber.

        Args:
            subscriber: Training subscriber
            symbol: Trading symbol
            timeframe: Timeframe (1m, 5m, etc.)
        """
        register = getattr(self.data_provider, 'subscribe_candle_completion', None)
        feed = (symbol, timeframe)

        with self.lock:
            if subscriber not in self.training_subscribers:
                self.training_subscribers.append(subscriber)

            # Claim the feed under the lock so concurrent callers cannot both register it.
            needs_registration = register is not None and feed not in self._candle_feeds
            if needs_registration:
                self._candle_feeds.add(feed)

        # Register with data provider for candle completion callbacks
        if needs_registration:
            try:
                register(
                    callback=self._handle_candle_completion,
                    symbol=symbol,
                    timeframe=timeframe
                )
            except Exception:
                # Release the claim so a later call can retry the registration.
                with self.lock:
                    self._candle_feeds.discard(feed)
                raise

        logger.info("Subscribed to candle completion: %s %s", symbol, timeframe)

    def subscribe_to_pivot_events(self, subscriber: "TrainingEventSubscriber",
                                  symbol: str, timeframe: str,
                                  pivot_types: List[str]) -> None:
        """
        Subscribe to pivot events (L2L, L2H, etc.).

        Only pivot types not yet registered for this symbol/timeframe are
        forwarded to the data provider; later calls just add the subscriber.

        Args:
            subscriber: Training subscriber
            symbol: Trading symbol
            timeframe: Timeframe
            pivot_types: List of pivot types to subscribe to (e.g., ['L2L', 'L2H', 'L3L'])
        """
        register = getattr(self.data_provider, 'subscribe_pivot_events', None)
        new_types = []

        with self.lock:
            if subscriber not in self.training_subscribers:
                self.training_subscribers.append(subscriber)

            # Register pivot subscriptions
            for pivot_type in pivot_types:
                key = (symbol, timeframe, pivot_type)
                self.pivot_subscriptions.setdefault(key, [])
                if register is not None and key not in self._pivot_feeds:
                    self._pivot_feeds.add(key)
                    new_types.append(pivot_type)

        # Register with data provider for pivot callbacks
        if new_types:
            try:
                register(
                    callback=self._handle_pivot_event,
                    symbol=symbol,
                    timeframe=timeframe,
                    pivot_types=new_types
                )
            except Exception:
                # Release the claims so a later call can retry the registration.
                with self.lock:
                    self._pivot_feeds.difference_update(
                        (symbol, timeframe, pivot_type) for pivot_type in new_types
                    )
                raise

        logger.info("Subscribed to pivot events: %s %s %s", symbol, timeframe, pivot_types)

    def _handle_pivot_event(self, event: "PivotEvent") -> None:
        """
        Handle pivot event from data provider and trigger training.

        Untrained frames for the event's symbol/timeframe whose prediction was
        made within ``PIVOT_MATCH_WINDOW`` before the pivot (and not after it)
        are passed to every subscriber. They are marked trained once a
        subscriber callback returns without raising.
        """
        with self.lock:
            # Only predictions made before the pivot, within the window, can have predicted it.
            window_start = event.timestamp - self.PIVOT_MATCH_WINDOW
            matching_refs = [
                ref for ref in self.inference_frames.values()
                if ref.symbol == event.symbol
                and ref.timeframe == event.timeframe
                and window_start <= ref.prediction_timestamp <= event.timestamp
                and not ref.trained
            ]

            # Notify subscribers (snapshot: a callback may subscribe re-entrantly)
            for subscriber in list(self.training_subscribers):
                try:
                    subscriber.on_pivot_event(event, matching_refs)
                except Exception as e:
                    logger.error("Error in pivot event callback: %s", e, exc_info=True)
                else:
                    # Mark as trained
                    trained_at = datetime.now(timezone.utc)
                    for ref in matching_refs:
                        ref.trained = True
                        ref.training_timestamp = trained_at

    def _handle_candle_completion(self, event: "CandleCompletionEvent") -> None:
        """
        Handle candle completion event and trigger training.

        Untrained frames whose (symbol, timeframe, target_timestamp) equals the
        event's (symbol, timeframe, timestamp) are passed, one at a time, to
        every subscriber. A frame is marked trained after a callback for it
        returns without raising.
        """
        with self.lock:
            # Find matching inference frames
            key = (event.symbol, event.timeframe, event.timestamp)
            inference_refs = [
                self.inference_frames[iid]
                for iid in self.candle_inferences.get(key, ())
                if iid in self.inference_frames and not self.inference_frames[iid].trained
            ]

            # Notify subscribers (snapshot: a callback may subscribe re-entrantly)
            for subscriber in list(self.training_subscribers):
                for inference_ref in inference_refs:
                    try:
                        subscriber.on_candle_completion(event, inference_ref)
                    except Exception as e:
                        logger.error("Error in candle completion callback: %s", e, exc_info=True)
                    else:
                        # Mark as trained
                        inference_ref.trained = True
                        inference_ref.training_timestamp = datetime.now(timezone.utc)

    @staticmethod
    def _safe_span(low: float, high: float) -> float:
        """Return ``high - low``, or 1.0 when that is zero, so it is safe as a divisor."""
        span = high - low
        return span if span != 0 else 1.0

    def get_inference_data(self, inference_ref: "InferenceFrameReference") -> Optional[Dict]:
        """
        Build model inputs for an inference frame via the data provider.

        For each timeframe in ``INPUT_TIMEFRAMES`` that has at least
        ``WINDOW_SIZE`` candles, the OHLCV values are min-max normalized and
        stored under ``price_data_<timeframe>``. Normalization parameters from
        the reference are reused when present; otherwise they are computed and
        written back to ``inference_ref.norm_params``. Zero-filled placeholders
        are added for ``tech_data``, ``market_data`` and ``cob_data``.

        NOTE(review): the query fetches the most recent candles and does not
        use ``data_range_start`` / ``data_range_end`` from the reference -
        confirm whether ``get_historical_data`` can be bounded to that window.

        Args:
            inference_ref: Reference to inference frame

        Returns:
            Dict with model inputs (price_data_1m, price_data_1h, etc.), or None
            when no data provider is set or any error occurs (errors are logged).
        """
        if not self.data_provider:
            logger.warning("Data provider not available for inference data retrieval")
            return None

        try:
            # Imported lazily so the module can be loaded without these packages.
            import torch
            import numpy as np

            model_inputs = {}

            # Use norm_params from reference if available, otherwise calculate
            norm_params = inference_ref.norm_params.copy() if inference_ref.norm_params else {}

            for tf in self.INPUT_TIMEFRAMES:
                df = self.data_provider.get_historical_data(
                    symbol=inference_ref.symbol,
                    timeframe=tf,
                    limit=self.WINDOW_SIZE
                )

                # Timeframes without a full window are left out of the inputs.
                if df is None or len(df) < self.WINDOW_SIZE:
                    continue

                df = df.tail(self.WINDOW_SIZE)

                # Stack OHLCV [seq_len, 5]
                ohlcv = np.stack(
                    [df[column].values.astype(np.float32)
                     for column in ('open', 'high', 'low', 'close', 'volume')],
                    axis=-1,
                )

                # Calculate normalization params if not stored
                if tf not in norm_params:
                    price_min = np.min(ohlcv[:, :4])
                    price_max = np.max(ohlcv[:, :4])
                    volume_min = np.min(ohlcv[:, 4])
                    volume_max = np.max(ohlcv[:, 4])

                    # Widen a flat range so the stored params never describe a zero span.
                    if price_max == price_min:
                        price_max += 1.0
                    if volume_max == volume_min:
                        volume_max += 1.0

                    norm_params[tf] = {
                        'price_min': float(price_min),
                        'price_max': float(price_max),
                        'volume_min': float(volume_min),
                        'volume_max': float(volume_max)
                    }

                # Normalize using params. Params supplied on the reference are
                # not guaranteed to have max != min, so guard the divisor.
                params = norm_params[tf]
                price_min = params['price_min']
                vol_min = params['volume_min']
                price_span = self._safe_span(price_min, params['price_max'])
                vol_span = self._safe_span(vol_min, params['volume_max'])

                ohlcv[:, :4] = (ohlcv[:, :4] - price_min) / price_span
                ohlcv[:, 4] = (ohlcv[:, 4] - vol_min) / vol_span

                # Convert to tensor [1, seq_len, 5]
                model_inputs[f'price_data_{tf}'] = torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0)

            # Store norm_params in reference for future use
            inference_ref.norm_params = norm_params

            # Add placeholder data for other inputs
            device = next(iter(model_inputs.values())).device if model_inputs else torch.device('cpu')
            model_inputs['tech_data'] = torch.zeros(1, 40, dtype=torch.float32, device=device)
            model_inputs['market_data'] = torch.zeros(1, 30, dtype=torch.float32, device=device)
            model_inputs['cob_data'] = torch.zeros(1, 600, 100, dtype=torch.float32, device=device)

            return model_inputs

        except Exception as e:
            logger.error("Error retrieving inference data: %s", e, exc_info=True)
            return None
|