Refactoring: inference training triggers on real data
ANNOTATE/core/inference_training_system.py (new file, 376 lines)
@@ -0,0 +1,376 @@
"""
Event-Driven Inference Training System

This system provides:
1. Reference-based inference frame storage (no 600-candle copies)
2. Subscription system for candle completion and pivot events
3. Flexible training methods (backprop for Transformer, others for different models)
4. Integration with DuckDB for efficient data retrieval

Architecture:
- Inference frames stored as references (timestamp ranges) in DuckDB
- Training adapter subscribes to data provider events
- Time-based triggers: candle completion (known result time)
- Event-based triggers: pivot points (L2L, L2H, etc. - unknown timing)
"""

import logging
import threading
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


class TrainingTriggerType(Enum):
    """Types of training triggers"""
    CANDLE_COMPLETION = "candle_completion"  # Time-based: next candle closes
    PIVOT_EVENT = "pivot_event"              # Event-based: pivot detected (L2L, L2H, etc.)


@dataclass
class InferenceFrameReference:
    """
    Reference to inference data stored in DuckDB.
    No copying - just store timestamp ranges and query when needed.
    """
    # Dataclass fields without defaults must precede fields with defaults,
    # so the required identifiers and the data range come first.
    inference_id: str               # Unique ID for this inference
    symbol: str
    timeframe: str
    prediction_timestamp: datetime  # When prediction was made

    # Reference to data in DuckDB (timestamp range)
    data_range_start: datetime      # Start of 600-candle window
    data_range_end: datetime        # End of 600-candle window

    # When the result will be available (for candle-completion training)
    target_timestamp: Optional[datetime] = None

    # Normalization parameters (small, can be stored)
    norm_params: Dict[str, Dict[str, float]] = field(default_factory=dict)

    # Prediction metadata
    predicted_action: Optional[str] = None
    predicted_candle: Optional[Dict[str, List[float]]] = None
    confidence: float = 0.0

    # Training status
    trained: bool = False
    training_timestamp: Optional[datetime] = None
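

# Illustrative sketch only (not part of the original module): how a caller
# might build a reference after an inference pass instead of copying the
# 600-candle window. The action, confidence, and one-minute target horizon
# are hypothetical placeholders.
def _example_build_reference(symbol: str, timeframe: str,
                             window_start: datetime,
                             window_end: datetime) -> InferenceFrameReference:
    now = datetime.now(timezone.utc)
    return InferenceFrameReference(
        inference_id=str(uuid.uuid4()),
        symbol=symbol,
        timeframe=timeframe,
        prediction_timestamp=now,
        data_range_start=window_start,  # start of 600-candle window in DuckDB
        data_range_end=window_end,      # end of 600-candle window in DuckDB
        target_timestamp=now + timedelta(minutes=1),  # assumes a 1m candle
        predicted_action="BUY",
        confidence=0.75,
    )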


@dataclass
class PivotEvent:
    """Pivot point event for training"""
    symbol: str
    timeframe: str
    timestamp: datetime
    pivot_type: str   # 'L2L', 'L2H', 'L3L', 'L3H', etc.
    price: float
    level: int        # Pivot level (2, 3, 4, etc.)
    strength: float


@dataclass
class CandleCompletionEvent:
    """Candle completion event for training"""
    symbol: str
    timeframe: str
    timestamp: datetime      # When candle closed
    ohlcv: Dict[str, float]  # {'open', 'high', 'low', 'close', 'volume'}


class TrainingEventSubscriber:
    """
    Subscriber interface for training events.
    Training adapters implement this to receive callbacks.
    """

    def on_candle_completion(self, event: CandleCompletionEvent,
                             inference_ref: Optional[InferenceFrameReference]) -> None:
        """
        Called when a candle completes.

        Args:
            event: Candle completion event with actual OHLCV
            inference_ref: Reference to inference frame if available (for this candle)
        """
        raise NotImplementedError

    def on_pivot_event(self, event: PivotEvent,
                       inference_refs: List[InferenceFrameReference]) -> None:
        """
        Called when a pivot point is detected.

        Args:
            event: Pivot event (L2L, L2H, etc.)
            inference_refs: List of inference frames that predicted this pivot
        """
        raise NotImplementedError
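

# Illustrative sketch only (an addition for clarity): a minimal subscriber
# that just logs the outcome it would train on. A real adapter would run its
# training step (e.g. a backprop pass for the Transformer) inside these
# callbacks.
class LoggingTrainingSubscriber(TrainingEventSubscriber):

    def on_candle_completion(self, event: CandleCompletionEvent,
                             inference_ref: Optional[InferenceFrameReference]) -> None:
        if inference_ref is None:
            return
        logger.info("Would train %s against candle close %.2f at %s",
                    inference_ref.inference_id, event.ohlcv['close'], event.timestamp)

    def on_pivot_event(self, event: PivotEvent,
                       inference_refs: List[InferenceFrameReference]) -> None:
        logger.info("Would train %d frame(s) against pivot %s @ %.2f",
                    len(inference_refs), event.pivot_type, event.price)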


class InferenceTrainingCoordinator:
    """
    Coordinates inference frame storage and training event distribution.

    NOTE: This should be integrated into TradingOrchestrator to reduce duplication.
    The orchestrator already manages models, training, and predictions, so it is the
    natural place for inference-training coordination.

    Responsibilities:
    1. Store inference frame references (not copies)
    2. Register training subscriptions (candle/pivot events)
    3. Match inference frames to actual results
    4. Trigger training callbacks
    """

    def __init__(self, data_provider, duckdb_storage=None):
        """
        Initialize the coordinator.

        Args:
            data_provider: DataProvider instance for event subscriptions
            duckdb_storage: DuckDBStorage instance for data retrieval
        """
        self.data_provider = data_provider
        self.duckdb_storage = duckdb_storage

        # Inference frame references, keyed by inference_id
        self.inference_frames: Dict[str, InferenceFrameReference] = {}

        # Index by target timestamp for candle matching:
        # (symbol, timeframe, timestamp) -> [inference_ids]
        self.candle_inferences: Dict[Tuple[str, str, datetime], List[str]] = {}

        # Index by pivot type for pivot matching:
        # (symbol, timeframe, pivot_type) -> [inference_ids]
        self.pivot_subscriptions: Dict[Tuple[str, str, str], List[str]] = {}

        # Training subscribers
        self.training_subscribers: List[TrainingEventSubscriber] = []

        # Thread safety
        self.lock = threading.RLock()

        logger.info("InferenceTrainingCoordinator initialized")

    def register_inference_frame(self, inference_ref: InferenceFrameReference) -> None:
        """
        Register an inference frame reference (data stays in DuckDB, nothing is copied).

        Args:
            inference_ref: Reference to inference data
        """
        with self.lock:
            self.inference_frames[inference_ref.inference_id] = inference_ref

            # Index by target timestamp for candle matching
            if inference_ref.target_timestamp:
                key = (inference_ref.symbol, inference_ref.timeframe, inference_ref.target_timestamp)
                self.candle_inferences.setdefault(key, []).append(inference_ref.inference_id)

            logger.debug(f"Registered inference frame: {inference_ref.inference_id} "
                         f"for {inference_ref.symbol} {inference_ref.timeframe}")

    def subscribe_to_candle_completion(self, subscriber: TrainingEventSubscriber,
                                       symbol: str, timeframe: str) -> None:
        """
        Subscribe to candle completion events for a symbol/timeframe.

        Args:
            subscriber: Training subscriber
            symbol: Trading symbol
            timeframe: Timeframe (1m, 5m, etc.)
        """
        with self.lock:
            if subscriber not in self.training_subscribers:
                self.training_subscribers.append(subscriber)

            # Register with the data provider for candle completion callbacks
            if hasattr(self.data_provider, 'subscribe_candle_completion'):
                self.data_provider.subscribe_candle_completion(
                    callback=self._handle_candle_completion,
                    symbol=symbol,
                    timeframe=timeframe
                )

            logger.info(f"Subscribed to candle completion: {symbol} {timeframe}")

    def subscribe_to_pivot_events(self, subscriber: TrainingEventSubscriber,
                                  symbol: str, timeframe: str,
                                  pivot_types: List[str]) -> None:
        """
        Subscribe to pivot events (L2L, L2H, etc.).

        Args:
            subscriber: Training subscriber
            symbol: Trading symbol
            timeframe: Timeframe
            pivot_types: List of pivot types to subscribe to (e.g., ['L2L', 'L2H', 'L3L'])
        """
        with self.lock:
            if subscriber not in self.training_subscribers:
                self.training_subscribers.append(subscriber)

            # Ensure a subscription slot exists per pivot type; matching
            # inference frames are found at event time in _handle_pivot_event
            for pivot_type in pivot_types:
                key = (symbol, timeframe, pivot_type)
                self.pivot_subscriptions.setdefault(key, [])

            # Register with the data provider for pivot callbacks
            if hasattr(self.data_provider, 'subscribe_pivot_events'):
                self.data_provider.subscribe_pivot_events(
                    callback=self._handle_pivot_event,
                    symbol=symbol,
                    timeframe=timeframe,
                    pivot_types=pivot_types
                )

            logger.info(f"Subscribed to pivot events: {symbol} {timeframe} {pivot_types}")

    def _handle_pivot_event(self, event: PivotEvent) -> None:
        """Handle a pivot event from the data provider and trigger training."""
        with self.lock:
            # Find matching inference frames: untrained predictions for this
            # symbol/timeframe made within a recent window (last 5 minutes)
            window_start = event.timestamp - timedelta(minutes=5)

            matching_refs = [
                ref for ref in self.inference_frames.values()
                if (ref.symbol == event.symbol and
                    ref.timeframe == event.timeframe and
                    ref.prediction_timestamp >= window_start and
                    not ref.trained)
            ]

            # Notify every subscriber before marking frames as trained, so one
            # subscriber's failure does not prevent the others from training
            notified = False
            for subscriber in self.training_subscribers:
                try:
                    subscriber.on_pivot_event(event, matching_refs)
                    notified = True
                except Exception as e:
                    logger.error(f"Error in pivot event callback: {e}", exc_info=True)

            if notified:
                for ref in matching_refs:
                    ref.trained = True
                    ref.training_timestamp = datetime.now(timezone.utc)

    def _handle_candle_completion(self, event: CandleCompletionEvent) -> None:
        """Handle a candle completion event and trigger training."""
        with self.lock:
            # Find inference frames whose target timestamp matches this candle
            key = (event.symbol, event.timeframe, event.timestamp)
            inference_ids = self.candle_inferences.get(key, [])

            inference_refs = [self.inference_frames[iid] for iid in inference_ids
                              if iid in self.inference_frames and not self.inference_frames[iid].trained]

            # Notify every subscriber for each frame, then mark the frame as
            # trained only if at least one subscriber handled it successfully
            for inference_ref in inference_refs:
                notified = False
                for subscriber in self.training_subscribers:
                    try:
                        subscriber.on_candle_completion(event, inference_ref)
                        notified = True
                    except Exception as e:
                        logger.error(f"Error in candle completion callback: {e}", exc_info=True)

                if notified:
                    inference_ref.trained = True
                    inference_ref.training_timestamp = datetime.now(timezone.utc)

    def get_inference_data(self, inference_ref: InferenceFrameReference) -> Optional[Dict]:
        """
        Retrieve inference data on demand using a stored reference.

        Data is fetched through the data provider (which queries DuckDB
        internally) only when training is triggered - nothing is copied at
        inference time.

        NOTE: this implementation currently pulls the latest 600 candles per
        timeframe rather than querying the exact [data_range_start,
        data_range_end] window stored in the reference; strict replay would
        need a range query against DuckDB.

        Args:
            inference_ref: Reference to inference frame

        Returns:
            Dict with model inputs (price_data_1m, price_data_1h, etc.) or None
        """
        if not self.data_provider:
            logger.warning("Data provider not available for inference data retrieval")
            return None

        try:
            import torch
            import numpy as np

            model_inputs = {}

            # Use norm_params from the reference if available, otherwise calculate
            norm_params = inference_ref.norm_params.copy() if inference_ref.norm_params else {}

            for tf in ['1s', '1m', '1h', '1d']:
                # Get 600 candles - the data provider queries DuckDB efficiently
                df = self.data_provider.get_historical_data(
                    symbol=inference_ref.symbol,
                    timeframe=tf,
                    limit=600
                )

                if df is not None and len(df) >= 600:
                    # Take the last 600 candles
                    df = df.tail(600)

                    # Extract OHLCV arrays
                    opens = df['open'].values.astype(np.float32)
                    highs = df['high'].values.astype(np.float32)
                    lows = df['low'].values.astype(np.float32)
                    closes = df['close'].values.astype(np.float32)
                    volumes = df['volume'].values.astype(np.float32)

                    # Stack OHLCV [seq_len, 5]
                    ohlcv = np.stack([opens, highs, lows, closes, volumes], axis=-1)

                    # Calculate normalization params if not stored
                    if tf not in norm_params:
                        price_min = np.min(ohlcv[:, :4])
                        price_max = np.max(ohlcv[:, :4])
                        volume_min = np.min(ohlcv[:, 4])
                        volume_max = np.max(ohlcv[:, 4])

                        # Guard against a zero range (flat price/volume window)
                        if price_max == price_min:
                            price_max += 1.0
                        if volume_max == volume_min:
                            volume_max += 1.0

                        norm_params[tf] = {
                            'price_min': float(price_min),
                            'price_max': float(price_max),
                            'volume_min': float(volume_min),
                            'volume_max': float(volume_max)
                        }

                    # Min-max normalize prices and volume to [0, 1]
                    params = norm_params[tf]
                    price_min = params['price_min']
                    price_max = params['price_max']
                    vol_min = params['volume_min']
                    vol_max = params['volume_max']

                    ohlcv[:, :4] = (ohlcv[:, :4] - price_min) / (price_max - price_min)
                    ohlcv[:, 4] = (ohlcv[:, 4] - vol_min) / (vol_max - vol_min)

                    # Convert to tensor [1, seq_len, 5]
                    candles_tensor = torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0)
                    model_inputs[f'price_data_{tf}'] = candles_tensor

            # Store norm_params in the reference for future use
            inference_ref.norm_params = norm_params

            # Add placeholder data for the other model inputs
            device = next(iter(model_inputs.values())).device if model_inputs else torch.device('cpu')
            model_inputs['tech_data'] = torch.zeros(1, 40, dtype=torch.float32, device=device)
            model_inputs['market_data'] = torch.zeros(1, 30, dtype=torch.float32, device=device)
            model_inputs['cob_data'] = torch.zeros(1, 600, 100, dtype=torch.float32, device=device)

            return model_inputs

        except Exception as e:
            logger.error(f"Error retrieving inference data: {e}", exc_info=True)
            return None
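

# Illustrative sketch only (an addition for clarity, not part of the original
# module): inverting the min-max normalization above so a training step can
# compare a denormalized prediction with the actual candle.
def denormalize_price(value: float, params: Dict[str, float]) -> float:
    """Map a normalized price in [0, 1] back to the original price scale."""
    return value * (params['price_max'] - params['price_min']) + params['price_min']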
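
# Illustrative sketch only: wiring the coordinator to a data provider.
# DataProvider/DuckDBStorage construction is project-specific, so they are
# passed in as opaque, pre-built objects; the symbol and timeframe below are
# hypothetical placeholders.
def _example_wire_training(data_provider, duckdb_storage=None) -> InferenceTrainingCoordinator:
    coordinator = InferenceTrainingCoordinator(data_provider, duckdb_storage)
    subscriber = LoggingTrainingSubscriber()

    # Time-based trigger: train when the next 1m candle closes
    coordinator.subscribe_to_candle_completion(subscriber, symbol='ETH/USDT', timeframe='1m')

    # Event-based trigger: train when level-2 pivots are detected
    coordinator.subscribe_to_pivot_events(subscriber, symbol='ETH/USDT', timeframe='1m',
                                          pivot_types=['L2L', 'L2H'])
    return coordinator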