"""
Timeframe-Aware Inference Coordinator

This module coordinates model inference across multiple timeframes with proper scheduling.
It ensures that models know which timeframe they are predicting on and handles the
complex scheduling requirements for multi-timeframe predictions.

Key Features:
- Timeframe-aware model inference
- Hourly multi-timeframe inference (4 predictions per hour)
- Frequent inference at 1-5 second intervals
- Prediction context management
- Integration with enhanced reward calculator
"""

import asyncio
import logging
import threading
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple

from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame

logger = logging.getLogger(__name__)


@dataclass
class InferenceContext:
    """Context information handed to a model for a single inference call."""

    # Trading symbol the inference is for.
    symbol: str
    # Timeframe the inference is running on.
    timeframe: TimeFrame
    # Moment the inference was scheduled.
    timestamp: datetime
    # Which timeframe we're predicting for.
    target_timeframe: TimeFrame
    # True only for inferences issued by the hourly multi-timeframe sweep.
    is_hourly_inference: bool = False
    # One of "regular", "hourly", "continuous" or "boundary".
    inference_type: str = "regular"


@dataclass
class InferenceSchedule:
    """Schedule configuration for different inference types."""

    # Continuous inference every 5 seconds by default.
    continuous_interval_seconds: float = 5.0
    # Timeframes swept by the hourly inference. None selects the default set
    # in __post_init__, so instances never share one mutable default list.
    hourly_timeframes: Optional[List[TimeFrame]] = None

    def __post_init__(self):
        """Fill in the default hourly timeframes when none were supplied."""
        if self.hourly_timeframes is None:
            self.hourly_timeframes = [
                TimeFrame.SECONDS_1,
                TimeFrame.MINUTES_1,
                TimeFrame.HOURS_1,
                TimeFrame.DAYS_1,
            ]


class TimeframeInferenceCoordinator:
    """
    Coordinates timeframe-aware model inference with proper scheduling

    This coordinator:
    1. Manages continuous inference every 1-5 seconds on main timeframe
    2. Schedules hourly multi-timeframe inference (4 predictions per hour)
    3. Ensures models know which timeframe they're predicting on
    4. Integrates with enhanced reward calculator for training
    5. Handles prediction context and metadata
    """

    def __init__(self,
                 reward_calculator: EnhancedRewardCalculator,
                 data_provider: Any = None,
                 symbols: Optional[List[str]] = None):
        """
        Initialize the timeframe inference coordinator

        Args:
            reward_calculator: Enhanced reward calculator instance
            data_provider: Data provider for market data
            symbols: List of symbols to coordinate inference for
                (defaults to ETH/USDT and BTC/USDT when empty or None)
        """
        self.reward_calculator = reward_calculator
        self.data_provider = data_provider
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']

        # Inference schedule configuration
        self.schedule = InferenceSchedule()

        # Model registry - stores inference functions for different models
        self.model_inference_functions: Dict[str, Callable] = {}

        # Tracking inference state
        self.last_continuous_inference: Dict[str, datetime] = {}
        self.last_hourly_inference: Dict[str, datetime] = {}
        self.next_hourly_inference: Dict[str, datetime] = {}

        # Active inference tasks
        self.inference_tasks: List[asyncio.Task] = []
        self.running = False

        # Strong references to fire-and-forget tasks (forced inference). The
        # event loop only keeps weak references, so an unreferenced task can
        # be garbage-collected before it finishes.
        self._background_tasks: set = set()

        # Thread safety
        self.lock = threading.RLock()

        # Performance metrics
        self.inference_stats = {
            'continuous_inferences': 0,
            'hourly_inferences': 0,
            'failed_inferences': 0,
            'average_inference_time_ms': 0.0
        }
        # Number of timing samples folded into 'average_inference_time_ms'.
        self._timing_samples = 0

        self._initialize_schedules()

        logger.info(f"TimeframeInferenceCoordinator initialized for symbols: {self.symbols}")
        logger.info(f"Continuous inference interval: {self.schedule.continuous_interval_seconds}s")
        logger.info(f"Hourly timeframes: {[tf.value for tf in self.schedule.hourly_timeframes]}")

    def _initialize_schedules(self):
        """Initialize inference schedules for all symbols"""
        current_time = datetime.now()
        # Next hourly inference happens at the top of the next hour.
        next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)

        for symbol in self.symbols:
            self.last_continuous_inference[symbol] = current_time
            self.last_hourly_inference[symbol] = current_time
            self.next_hourly_inference[symbol] = next_hour

    def register_model_inference_function(self, model_name: str, inference_func: Callable):
        """
        Register a model's inference function

        Args:
            model_name: Name of the model
            inference_func: Async function that takes InferenceContext and returns prediction
        """
        self.model_inference_functions[model_name] = inference_func
        logger.info(f"Registered inference function for model: {model_name}")

    async def start_coordination(self):
        """Start the inference coordination system"""
        if self.running:
            logger.warning("Inference coordination already running")
            return

        self.running = True
        logger.info("Starting timeframe inference coordination")

        # One continuous inference task per symbol
        for symbol in self.symbols:
            self.inference_tasks.append(
                asyncio.create_task(self._continuous_inference_loop(symbol))
            )

        # Hourly inference scheduler
        self.inference_tasks.append(
            asyncio.create_task(self._hourly_inference_scheduler())
        )

        # Reward evaluation loop
        self.inference_tasks.append(
            asyncio.create_task(self._reward_evaluation_loop())
        )

        logger.info(f"Started {len(self.inference_tasks)} inference coordination tasks")

    async def stop_coordination(self):
        """Stop the inference coordination system"""
        if not self.running:
            return

        self.running = False
        logger.info("Stopping timeframe inference coordination")

        # Snapshot first: done-callbacks mutate _background_tasks while we wait.
        pending = self.inference_tasks + list(self._background_tasks)
        for task in pending:
            task.cancel()

        # Wait for tasks to complete; cancellations come back as results.
        await asyncio.gather(*pending, return_exceptions=True)
        self.inference_tasks.clear()
        self._background_tasks.clear()

        logger.info("Inference coordination stopped")

    async def _continuous_inference_loop(self, symbol: str):
        """
        Continuous inference loop for a specific symbol

        Args:
            symbol: Trading symbol to run inference for
        """
        logger.info(f"Starting continuous inference loop for {symbol}")

        while self.running:
            try:
                current_time = datetime.now()

                # Check if it's time for continuous inference
                last_inference = self.last_continuous_inference[symbol]
                time_since_last = (current_time - last_inference).total_seconds()

                if time_since_last >= self.schedule.continuous_interval_seconds:
                    # Run continuous inference on primary timeframe (1s)
                    context = InferenceContext(
                        symbol=symbol,
                        timeframe=TimeFrame.SECONDS_1,
                        timestamp=current_time,
                        target_timeframe=TimeFrame.SECONDS_1,
                        is_hourly_inference=False,
                        inference_type="continuous"
                    )

                    await self._execute_inference(context)
                    self.last_continuous_inference[symbol] = current_time
                    self.inference_stats['continuous_inferences'] += 1

                # Sleep for a short interval to avoid busy waiting
                await asyncio.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in continuous inference loop for {symbol}: {e}")
                await asyncio.sleep(1.0)  # Wait longer on error

    @staticmethod
    def _boundary_marks(moment: datetime) -> List[Tuple[TimeFrame, datetime]]:
        """Return the start of the minute, hour and day that contain *moment*."""
        minute_start = moment.replace(second=0, microsecond=0)
        hour_start = minute_start.replace(minute=0)
        day_start = hour_start.replace(hour=0)
        return [
            (TimeFrame.MINUTES_1, minute_start),
            (TimeFrame.HOURS_1, hour_start),
            (TimeFrame.DAYS_1, day_start),
        ]

    @staticmethod
    def _scheduler_sleep_seconds(now: datetime,
                                 max_sleep: float = 30.0,
                                 min_sleep: float = 0.05) -> float:
        """
        Seconds the scheduler should sleep before its next check

        Wakes at the next minute boundary when that is closer than
        *max_sleep*; *min_sleep* prevents a tight spin when the wake-up lands
        a hair before the boundary.

        Args:
            now: Current wall-clock time
            max_sleep: Longest allowed sleep in seconds
            min_sleep: Shortest allowed sleep in seconds

        Returns:
            Sleep duration in seconds, within [min_sleep, max_sleep]
        """
        to_next_minute = 60.0 - (now.second + now.microsecond / 1_000_000)
        return max(min_sleep, min(max_sleep, to_next_minute))

    async def _hourly_inference_scheduler(self):
        """Scheduler for hourly multi-timeframe inference and timeframe-boundary triggers"""
        logger.info("Starting hourly inference scheduler")

        # Last boundary already handled per (symbol, timeframe). Comparing
        # truncated timestamps detects a crossed boundary even when the loop
        # wakes up a little late, unlike testing for `second == 0` exactly.
        handled: Dict[Tuple[str, TimeFrame], datetime] = {}

        while self.running:
            try:
                current_time = datetime.now()

                for symbol in self.symbols:
                    # Hourly multi-timeframe sweep
                    if current_time >= self.next_hourly_inference[symbol]:
                        await self._execute_hourly_inference(symbol, current_time)

                        # Schedule next hourly inference
                        next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
                        self.next_hourly_inference[symbol] = next_hour
                        self.last_hourly_inference[symbol] = current_time

                    # Trigger at each new timeframe boundary: 1m, 1h, 1d
                    for timeframe, mark in self._boundary_marks(current_time):
                        key = (symbol, timeframe)
                        # The first observation only records the current
                        # bucket, so nothing fires spuriously on start-up.
                        previous = handled.setdefault(key, mark)
                        if previous != mark:
                            handled[key] = mark
                            await self._execute_boundary_inference(symbol, current_time, timeframe)

                # Check at least every 30 seconds, and right at minute boundaries
                await asyncio.sleep(self._scheduler_sleep_seconds(datetime.now()))

            except Exception as e:
                logger.error(f"Error in hourly inference scheduler: {e}")
                await asyncio.sleep(60)  # Wait longer on error

    async def _execute_boundary_inference(self, symbol: str, timestamp: datetime, timeframe: TimeFrame):
        """
        Execute an inference triggered by a timeframe boundary

        Args:
            symbol: Trading symbol
            timestamp: Time at which the boundary was detected
            timeframe: Timeframe whose boundary was crossed
        """
        try:
            context = InferenceContext(
                symbol=symbol,
                timeframe=timeframe,
                timestamp=timestamp,
                target_timeframe=timeframe,
                is_hourly_inference=False,
                inference_type="boundary"
            )
            await self._execute_inference(context)
        except Exception as e:
            logger.debug(f"Boundary inference error for {symbol} {timeframe.value}: {e}")

    async def _execute_hourly_inference(self, symbol: str, timestamp: datetime):
        """
        Execute hourly multi-timeframe inference for a symbol

        Args:
            symbol: Trading symbol
            timestamp: Current timestamp
        """
        logger.info(f"Executing hourly multi-timeframe inference for {symbol}")

        # Run inference for each timeframe
        for timeframe in self.schedule.hourly_timeframes:
            context = InferenceContext(
                symbol=symbol,
                timeframe=timeframe,
                timestamp=timestamp,
                target_timeframe=timeframe,
                is_hourly_inference=True,
                inference_type="hourly"
            )

            await self._execute_inference(context)
            self.inference_stats['hourly_inferences'] += 1

            # Small delay between timeframe inferences
            await asyncio.sleep(0.5)

    async def _execute_inference(self, context: InferenceContext):
        """
        Execute inference for a specific context

        Runs every registered model. A failure in one model is logged and
        counted but does not stop the remaining models.

        Args:
            context: Inference context containing all necessary information
        """
        # perf_counter is monotonic, so the elapsed time cannot be distorted
        # by wall-clock adjustments.
        start_time = time.perf_counter()

        try:
            # Run inference for all registered models
            for model_name, inference_func in self.model_inference_functions.items():
                try:
                    prediction = await inference_func(context)

                    if prediction is not None:
                        # Add prediction to reward calculator
                        prediction_id = self.reward_calculator.add_prediction(
                            symbol=context.symbol,
                            timeframe=context.target_timeframe,
                            predicted_price=prediction.get('predicted_price', 0.0),
                            predicted_direction=prediction.get('direction', 0),
                            confidence=prediction.get('confidence', 0.0),
                            current_price=prediction.get('current_price', 0.0),
                            model_name=model_name
                        )

                        logger.debug(f"Added prediction {prediction_id} from {model_name} "
                                     f"for {context.symbol} {context.target_timeframe.value}")

                except Exception as e:
                    logger.error(f"Error running inference for model {model_name}: {e}")
                    self.inference_stats['failed_inferences'] += 1

            # Update inference timing stats
            inference_time_ms = (time.perf_counter() - start_time) * 1000
            self._update_inference_timing(inference_time_ms)

        except Exception as e:
            logger.error(f"Error executing inference for context {context}: {e}")
            self.inference_stats['failed_inferences'] += 1

    def _update_inference_timing(self, inference_time_ms: float):
        """
        Fold one timing sample into the running average

        The average keeps its own sample count. The per-type inference
        counters are incremented only after this method runs and do not cover
        boundary inferences, so they cannot serve as the divisor.

        Args:
            inference_time_ms: Duration of the latest inference pass in milliseconds
        """
        self._timing_samples += 1
        current_avg = self.inference_stats['average_inference_time_ms']
        # Incremental mean: avg_n = avg_{n-1} + (x_n - avg_{n-1}) / n
        self.inference_stats['average_inference_time_ms'] = (
            current_avg + (inference_time_ms - current_avg) / self._timing_samples
        )

    async def _reward_evaluation_loop(self):
        """Continuous loop for evaluating prediction rewards"""
        logger.info("Starting reward evaluation loop")

        while self.running:
            try:
                # Update price cache if data provider available. The helper
                # is a coroutine; the provider call inside it is synchronous.
                if self.data_provider:
                    await self._update_price_cache()

                # Evaluate predictions and get training data
                for symbol in self.symbols:
                    evaluation_results = self.reward_calculator.evaluate_predictions(symbol)

                    if symbol in evaluation_results and evaluation_results[symbol]:
                        logger.debug(f"Evaluated {len(evaluation_results[symbol])} predictions for {symbol}")

                        # Hand newly evaluated predictions to the training hook
                        await self._trigger_model_training(symbol, evaluation_results[symbol])

                # Sleep for evaluation interval
                await asyncio.sleep(10)  # Evaluate every 10 seconds

            except Exception as e:
                logger.error(f"Error in reward evaluation loop: {e}")
                await asyncio.sleep(30)  # Wait longer on error

    async def _update_price_cache(self):
        """Update price cache with current market prices"""
        try:
            for symbol in self.symbols:
                # Get current price from data provider
                if hasattr(self.data_provider, 'get_current_price'):
                    current_price = self.data_provider.get_current_price(symbol)
                    # Falsy prices (None, 0) are treated as "no data".
                    if current_price:
                        self.reward_calculator.update_price(symbol, current_price)

        except Exception as e:
            logger.debug(f"Error updating price cache: {e}")

    @staticmethod
    def _group_evaluations(evaluation_results: List[Any]) -> Dict[Tuple[str, Any], List[Any]]:
        """
        Group (prediction, reward) pairs by (model_name, timeframe)

        A tuple key is used instead of a joined string so that model names
        containing underscores are never split incorrectly.

        Args:
            evaluation_results: List of (prediction, reward) tuples

        Returns:
            Mapping of (model_name, timeframe) to its (prediction, reward) pairs,
            in first-seen order
        """
        groups: Dict[Tuple[str, Any], List[Any]] = {}
        for prediction_record, reward in evaluation_results:
            key = (prediction_record.model_name, prediction_record.timeframe)
            groups.setdefault(key, []).append((prediction_record, reward))
        return groups

    async def _trigger_model_training(self, symbol: str, evaluation_results: List[Any]):
        """
        Trigger model training based on evaluation results

        Args:
            symbol: Trading symbol
            evaluation_results: List of (prediction, reward) tuples
        """
        try:
            # Group by model and timeframe for targeted training
            training_groups = self._group_evaluations(evaluation_results)

            # Trigger training for each group
            for (model_name, timeframe), training_data in training_groups.items():
                logger.info(f"Triggering training for {model_name} on {symbol} {timeframe.value} "
                            f"with {len(training_data)} samples")

                # Placeholder hook - the actual training calls live in
                # _call_model_training.
                await self._call_model_training(model_name, symbol, timeframe, training_data)

        except Exception as e:
            logger.error(f"Error triggering model training: {e}")

    async def _call_model_training(self, model_name: str, symbol: str,
                                   timeframe: TimeFrame, training_data: List[Any]):
        """
        Call model-specific training function

        Args:
            model_name: Name of the model to train
            symbol: Trading symbol
            timeframe: Timeframe for training
            training_data: List of (prediction, reward) tuples
        """
        # This is a placeholder for model-specific training calls
        # You'll need to implement this based on your specific model interfaces
        logger.debug(f"Training call for {model_name}: {len(training_data)} samples")

    def get_inference_statistics(self) -> Dict[str, Any]:
        """
        Get inference coordination statistics

        Returns:
            Copy of the performance counters, extended with scheduling
            information and the reward calculator's accuracy summary
        """
        with self.lock:
            stats = self.inference_stats.copy()

        # Add scheduling information (copied so callers cannot mutate our list)
        stats['symbols'] = list(self.symbols)
        stats['continuous_interval_seconds'] = self.schedule.continuous_interval_seconds
        stats['hourly_timeframes'] = [tf.value for tf in self.schedule.hourly_timeframes]
        stats['next_hourly_inferences'] = {
            symbol: timestamp.isoformat()
            for symbol, timestamp in self.next_hourly_inference.items()
        }

        # Add accuracy summary from reward calculator
        stats['accuracy_summary'] = self.reward_calculator.get_accuracy_summary()

        return stats

    def _on_background_task_done(self, task: asyncio.Task):
        """Drop a finished background task and log any error it raised."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.error(f"Forced hourly inference failed: {error}")

    def force_hourly_inference(self, symbol: Optional[str] = None):
        """
        Force immediate hourly inference for symbol(s)

        Must be called from within the running event loop; the inference runs
        as a background task.

        Args:
            symbol: Specific symbol (None for all symbols)

        Raises:
            RuntimeError: If no event loop is running in the calling thread
        """
        if not self.running:
            logger.warning("Cannot force inference - coordinator not running")
            return

        symbols_to_process = [symbol] if symbol else list(self.symbols)

        async def _force_inference():
            current_time = datetime.now()
            for sym in symbols_to_process:
                await self._execute_hourly_inference(sym, current_time)

        coro = _force_inference()
        try:
            task = asyncio.create_task(coro)
        except RuntimeError:
            # No running loop: close the coroutine so it does not trigger a
            # "never awaited" warning, then surface the original error.
            coro.close()
            raise

        # Keep a strong reference until the task finishes.
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    def get_prediction_history(self, symbol: str, timeframe: TimeFrame,
                               max_samples: int = 50) -> List[Any]:
        """
        Get prediction history for training

        Args:
            symbol: Trading symbol
            timeframe: Specific timeframe
            max_samples: Maximum samples to return

        Returns:
            List of training samples
        """
        return self.reward_calculator.get_training_data(symbol, timeframe, max_samples)