# Source: gogo2/ANNOTATE/core/inference_training_system.py
# Snapshot: 2025-12-09 11:59:15 +02:00 (377 lines, 16 KiB, Python)
"""
Event-Driven Inference Training System
This system provides:
1. Reference-based inference frame storage (no 600-candle copies)
2. Subscription system for candle completion and pivot events
3. Flexible training methods (backprop for Transformer, others for different models)
4. Integration with DuckDB for efficient data retrieval
Architecture:
- Inference frames are kept as lightweight references (timestamp ranges) to data stored in DuckDB
- Training adapter subscribes to data provider events
- Time-based triggers: candle completion (known result time)
- Event-based triggers: pivot points (L2L, L2H, etc. - unknown timing)
"""
import logging
import threading
from datetime import datetime, timezone, timedelta
from typing import Dict, List, Optional, Callable, Tuple, Any
from dataclasses import dataclass, field
from enum import Enum
import uuid
# Module-level logger used by the classes below.
logger = logging.getLogger(__name__)
class TrainingTriggerType(Enum):
    """Enumerates what can cause a training step to fire."""

    # Time-based trigger: the next candle closes, so the result time is known up front.
    CANDLE_COMPLETION = "candle_completion"
    # Event-based trigger: a pivot (L2L, L2H, ...) is detected; its timing is not known up front.
    PIVOT_EVENT = "pivot_event"
@dataclass
class InferenceFrameReference:
    """
    Reference to inference data stored in DuckDB.

    No copying - just store timestamp ranges and query when needed.

    Dataclass fields without a default must come before fields with one, so the
    required ``data_range_start`` / ``data_range_end`` are declared ahead of the
    optional ``target_timestamp``. Construct instances with keyword arguments to
    stay independent of field order.
    """
    inference_id: str  # Unique ID for this inference
    symbol: str
    timeframe: str
    prediction_timestamp: datetime  # When the prediction was made

    # Reference to data in DuckDB (timestamp range)
    data_range_start: datetime  # Start of 600-candle window
    data_range_end: datetime  # End of 600-candle window

    # When the result will be available (for candle triggers). The coordinator
    # indexes frames by (symbol, timeframe, target_timestamp) when this is set.
    target_timestamp: Optional[datetime] = None

    # Normalization parameters (small, can be stored), keyed by timeframe
    norm_params: Dict[str, Dict[str, float]] = field(default_factory=dict)

    # Prediction metadata
    predicted_action: Optional[str] = None
    predicted_candle: Optional[Dict[str, List[float]]] = None
    confidence: float = 0.0

    # Training status: set once a subscriber has trained on this frame
    trained: bool = False
    training_timestamp: Optional[datetime] = None
@dataclass
class PivotEvent:
    """A detected pivot point that can trigger training of earlier predictions."""

    symbol: str
    timeframe: str
    # Timestamp of the pivot.
    timestamp: datetime
    # Pivot label such as 'L2L', 'L2H', 'L3L', 'L3H'.
    pivot_type: str
    price: float
    # Numeric pivot level (e.g. 2, 3, 4).
    level: int
    strength: float
@dataclass
class CandleCompletionEvent:
    """Closed-candle notification carrying the candle's actual OHLCV values."""

    symbol: str
    timeframe: str
    # Close time of the candle.
    timestamp: datetime
    # Keys: 'open', 'high', 'low', 'close', 'volume'.
    ohlcv: Dict[str, float]
class TrainingEventSubscriber:
    """
    Callback interface for objects that train on coordinator events.

    Subclasses override both hooks; the base implementations raise
    NotImplementedError.
    """

    def on_candle_completion(self, event: CandleCompletionEvent, inference_ref: Optional[InferenceFrameReference]) -> None:
        """
        Hook invoked once a candle has closed.

        Args:
            event: The completed candle, including its actual OHLCV values.
            inference_ref: The inference frame that targeted this candle, if any.
        """
        raise NotImplementedError

    def on_pivot_event(self, event: PivotEvent, inference_refs: List[InferenceFrameReference]) -> None:
        """
        Hook invoked once a pivot point has been detected.

        Args:
            event: The detected pivot (L2L, L2H, etc.).
            inference_refs: Inference frames matched to this pivot.
        """
        raise NotImplementedError
class InferenceTrainingCoordinator:
    """
    Coordinates inference frame storage and training event distribution.

    NOTE: This should be integrated into TradingOrchestrator to reduce duplication.
    The orchestrator already manages models, training, and predictions, so it's the
    natural place for inference-training coordination.

    Responsibilities:
    1. Store inference frame references (not copies)
    2. Register training subscriptions (candle/pivot events)
    3. Match inference frames to actual results
    4. Trigger training callbacks

    Threading: both event handlers run subscriber callbacks while holding
    ``self.lock`` (a re-entrant lock). A callback may therefore call back into the
    coordinator from the same thread, while other threads wait until it returns.

    NOTE(review): registered frames are never removed from ``inference_frames`` or
    ``candle_inferences`` in this class, so both grow with every registration —
    confirm whether pruning happens elsewhere.
    """

    # How long before a pivot a prediction may have been made and still be
    # matched to that pivot.
    PIVOT_MATCH_WINDOW = timedelta(minutes=5)

    def __init__(self, data_provider, duckdb_storage=None):
        """
        Initialize coordinator

        Args:
            data_provider: DataProvider instance for event subscriptions and
                historical data queries
            duckdb_storage: DuckDBStorage instance for data retrieval (currently
                only stored by this class)
        """
        self.data_provider = data_provider
        self.duckdb_storage = duckdb_storage

        # Store inference frame references (by inference_id)
        self.inference_frames: Dict[str, InferenceFrameReference] = {}

        # Index by target timestamp for candle matching:
        # (symbol, timeframe, timestamp) -> [inference_ids]
        self.candle_inferences: Dict[Tuple[str, str, datetime], List[str]] = {}

        # Requested pivot subscriptions: (symbol, timeframe, pivot_type) -> [inference_ids]
        self.pivot_subscriptions: Dict[Tuple[str, str, str], List[str]] = {}

        # Training subscribers; each handled event is delivered to all of them
        self.training_subscribers: List[TrainingEventSubscriber] = []

        # Data-provider streams already hooked up. Repeat subscriptions for the
        # same stream must not register another callback, or events could be
        # delivered to the handlers more than once.
        self._candle_streams: set = set()  # {(symbol, timeframe)}
        self._pivot_streams: set = set()   # {(symbol, timeframe, pivot_type)}

        # Thread safety (re-entrant so callbacks can call back into the coordinator)
        self.lock = threading.RLock()

        logger.info("InferenceTrainingCoordinator initialized")

    def register_inference_frame(self, inference_ref: InferenceFrameReference) -> None:
        """
        Register an inference frame reference (stored in DuckDB, not copied).

        Frames with a ``target_timestamp`` are also indexed under
        (symbol, timeframe, target_timestamp) so the matching candle completion
        event can find them. Re-registering the same ``inference_id`` replaces the
        stored reference without indexing the id twice under the same key.

        Args:
            inference_ref: Reference to inference data
        """
        with self.lock:
            self.inference_frames[inference_ref.inference_id] = inference_ref

            if inference_ref.target_timestamp is not None:
                key = (inference_ref.symbol, inference_ref.timeframe, inference_ref.target_timestamp)
                ids = self.candle_inferences.setdefault(key, [])
                # A duplicate id would make the candle handler pass the same frame
                # to each subscriber twice for a single event.
                if inference_ref.inference_id not in ids:
                    ids.append(inference_ref.inference_id)

        logger.debug(f"Registered inference frame: {inference_ref.inference_id} for {inference_ref.symbol} {inference_ref.timeframe}")

    def subscribe_to_candle_completion(self, subscriber: TrainingEventSubscriber,
                                       symbol: str, timeframe: str) -> None:
        """
        Subscribe to candle completion events for a symbol/timeframe.

        The data-provider callback is registered at most once per
        (symbol, timeframe). If the data provider has no
        ``subscribe_candle_completion`` method, the subscriber is still recorded,
        but no provider callback is registered.

        Args:
            subscriber: Training subscriber
            symbol: Trading symbol
            timeframe: Timeframe (1m, 5m, etc.)
        """
        with self.lock:
            if subscriber not in self.training_subscribers:
                self.training_subscribers.append(subscriber)

            stream = (symbol, timeframe)
            if stream not in self._candle_streams and hasattr(self.data_provider, 'subscribe_candle_completion'):
                self.data_provider.subscribe_candle_completion(
                    callback=self._handle_candle_completion,
                    symbol=symbol,
                    timeframe=timeframe,
                )
                # Only mark the stream as registered once the provider accepted it.
                self._candle_streams.add(stream)

        logger.info(f"Subscribed to candle completion: {symbol} {timeframe}")

    def subscribe_to_pivot_events(self, subscriber: TrainingEventSubscriber,
                                  symbol: str, timeframe: str,
                                  pivot_types: List[str]) -> None:
        """
        Subscribe to pivot events (L2L, L2H, etc.).

        The data provider is only asked for pivot types that are not already
        registered for this symbol/timeframe. Registering an already-hooked type
        again could deliver the same pivot event to the handler more than once.

        Args:
            subscriber: Training subscriber
            symbol: Trading symbol
            timeframe: Timeframe
            pivot_types: List of pivot types to subscribe to (e.g., ['L2L', 'L2H', 'L3L'])
        """
        with self.lock:
            if subscriber not in self.training_subscribers:
                self.training_subscribers.append(subscriber)

            # Record the requested pivot subscriptions
            for pivot_type in pivot_types:
                self.pivot_subscriptions.setdefault((symbol, timeframe, pivot_type), [])

            # dict.fromkeys de-duplicates while preserving the caller's order.
            new_types = [
                pivot_type for pivot_type in dict.fromkeys(pivot_types)
                if (symbol, timeframe, pivot_type) not in self._pivot_streams
            ]
            if new_types and hasattr(self.data_provider, 'subscribe_pivot_events'):
                self.data_provider.subscribe_pivot_events(
                    callback=self._handle_pivot_event,
                    symbol=symbol,
                    timeframe=timeframe,
                    pivot_types=new_types,
                )
                self._pivot_streams.update((symbol, timeframe, pivot_type) for pivot_type in new_types)

        logger.info(f"Subscribed to pivot events: {symbol} {timeframe} {pivot_types}")

    @staticmethod
    def _mark_trained(refs: List[InferenceFrameReference]) -> None:
        """Flag the given references as trained, stamped with the current UTC time."""
        now = datetime.now(timezone.utc)
        for ref in refs:
            ref.trained = True
            ref.training_timestamp = now

    def _handle_pivot_event(self, event: PivotEvent) -> None:
        """
        Handle pivot event from data provider and trigger training.

        Finds untrained frames for the same symbol/timeframe whose prediction was
        made within ``PIVOT_MATCH_WINDOW`` before the pivot (inclusive of the pivot
        timestamp). The matches are passed to every subscriber and marked trained
        after each subscriber that returns without raising. Subscribers are called
        even when nothing matches.
        """
        with self.lock:
            window_start = event.timestamp - self.PIVOT_MATCH_WINDOW
            # Predictions made after the pivot cannot have predicted it, so the
            # window is closed at the pivot timestamp.
            # NOTE(review): comparing a naive datetime with an aware one raises
            # TypeError — confirm providers and predictions use the same convention.
            matching_refs = [
                ref for ref in self.inference_frames.values()
                if ref.symbol == event.symbol
                and ref.timeframe == event.timeframe
                and window_start <= ref.prediction_timestamp <= event.timestamp
                and not ref.trained
            ]

            for subscriber in self.training_subscribers:
                try:
                    subscriber.on_pivot_event(event, matching_refs)
                except Exception as e:
                    logger.error(f"Error in pivot event callback: {e}", exc_info=True)
                    continue
                self._mark_trained(matching_refs)

    def _handle_candle_completion(self, event: CandleCompletionEvent) -> None:
        """
        Handle candle completion event and trigger training.

        Looks up the frames indexed under (symbol, timeframe, event.timestamp) that
        are not yet trained. Each frame is passed to every subscriber and marked
        trained after each call that returns without raising.
        """
        with self.lock:
            key = (event.symbol, event.timeframe, event.timestamp)
            inference_refs = [
                self.inference_frames[iid]
                for iid in self.candle_inferences.get(key, [])
                if iid in self.inference_frames and not self.inference_frames[iid].trained
            ]

            for subscriber in self.training_subscribers:
                for inference_ref in inference_refs:
                    try:
                        subscriber.on_candle_completion(event, inference_ref)
                    except Exception as e:
                        logger.error(f"Error in candle completion callback: {e}", exc_info=True)
                        continue
                    self._mark_trained([inference_ref])

    @staticmethod
    def _compute_norm_params(ohlcv) -> Dict[str, float]:
        """
        Return min/max normalization params for a [seq_len, 5] OHLCV array.

        Columns 0-3 (prices) share one range and column 4 (volume) has its own.
        For a flat series, the max is bumped by 1.0 so the range is never zero.
        """
        prices = ohlcv[:, :4]
        volumes = ohlcv[:, 4]
        price_min, price_max = float(prices.min()), float(prices.max())
        volume_min, volume_max = float(volumes.min()), float(volumes.max())
        if price_max == price_min:
            price_max += 1.0
        if volume_max == volume_min:
            volume_max += 1.0
        return {
            'price_min': price_min,
            'price_max': price_max,
            'volume_min': volume_min,
            'volume_max': volume_max,
        }

    @staticmethod
    def _safe_range(low: float, high: float) -> float:
        """Return ``high - low``, or 1.0 when that span is zero (avoids division by zero)."""
        span = high - low
        return span if span != 0 else 1.0

    def get_inference_data(self, inference_ref: InferenceFrameReference) -> Optional[Dict]:
        """
        Retrieve model inputs for an inference frame via the data provider.

        For each timeframe in ('1s', '1m', '1h', '1d'), the data provider is asked
        for 600 candles. Timeframes that return fewer than 600 rows are skipped.
        The rest are min/max normalized with the frame's stored ``norm_params``;
        params are computed and stored for any timeframe that has none. Each is
        returned as ``price_data_<tf>``, a float32 tensor of shape [1, 600, 5]
        (open, high, low, close, volume). Zero-filled placeholders are added for
        ``tech_data`` [1, 40], ``market_data`` [1, 30] and ``cob_data`` [1, 600, 100].

        NOTE(review): the query passes only symbol/timeframe/limit. It does not use
        ``data_range_start`` / ``data_range_end``, so it presumably returns the
        provider's most recent candles rather than the window the prediction saw.
        Confirm whether get_historical_data supports a time range.

        Args:
            inference_ref: Reference to inference frame. Its ``norm_params`` is
                replaced with the (possibly extended) params on success.

        Returns:
            Dict with model inputs (price_data_1m, price_data_1h, etc.), or None
            if the data provider is missing or an error occurs.
        """
        if not self.data_provider:
            logger.warning("Data provider not available for inference data retrieval")
            return None

        try:
            import torch
            import numpy as np

            model_inputs = {}
            # Work on a copy so that a failure part-way through leaves the
            # reference's params untouched (they are only written back at the end).
            norm_params = inference_ref.norm_params.copy() if inference_ref.norm_params else {}

            for tf in ['1s', '1m', '1h', '1d']:
                df = self.data_provider.get_historical_data(
                    symbol=inference_ref.symbol,
                    timeframe=tf,
                    limit=600
                )
                if df is None or len(df) < 600:
                    logger.debug(f"Skipping {tf} for {inference_ref.inference_id}: fewer than 600 candles")
                    continue

                df = df.tail(600)
                # Stack OHLCV columns into a [seq_len, 5] float32 array
                ohlcv = np.stack(
                    [df[col].values.astype(np.float32) for col in ('open', 'high', 'low', 'close', 'volume')],
                    axis=-1,
                )

                if tf not in norm_params:
                    norm_params[tf] = self._compute_norm_params(ohlcv)

                params = norm_params[tf]
                price_min = params['price_min']
                vol_min = params['volume_min']
                # _safe_range also guards caller-supplied params with a zero range.
                ohlcv[:, :4] = (ohlcv[:, :4] - price_min) / self._safe_range(price_min, params['price_max'])
                ohlcv[:, 4] = (ohlcv[:, 4] - vol_min) / self._safe_range(vol_min, params['volume_max'])

                # Add batch dimension -> [1, seq_len, 5]
                model_inputs[f'price_data_{tf}'] = torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0)

            # Store norm_params in reference for future use
            inference_ref.norm_params = norm_params

            # Placeholders for inputs not derived from OHLCV, on the same device
            device = next(iter(model_inputs.values())).device if model_inputs else torch.device('cpu')
            model_inputs['tech_data'] = torch.zeros(1, 40, dtype=torch.float32, device=device)
            model_inputs['market_data'] = torch.zeros(1, 30, dtype=torch.float32, device=device)
            model_inputs['cob_data'] = torch.zeros(1, 600, 100, dtype=torch.float32, device=device)

            return model_inputs

        except Exception as e:
            logger.error(f"Error retrieving inference data: {e}", exc_info=True)
            return None