# Source: gogo2/core/timeframe_inference_coordinator.py
# Snapshot: 2025-08-23 16:27:05 +03:00 (497 lines, 21 KiB, Python)
"""
Timeframe-Aware Inference Coordinator
This module coordinates model inference across multiple timeframes with proper scheduling.
It ensures that models know which timeframe they are predicting on and handles the
complex scheduling requirements for multi-timeframe predictions.
Key Features:
- Timeframe-aware model inference
- Hourly multi-timeframe inference (4 predictions per hour)
- Frequent inference at 1-5 second intervals
- Prediction context management
- Integration with enhanced reward calculator
"""
import asyncio
import inspect
import logging
import threading
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Optional, Any, Callable

from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
@dataclass
class InferenceContext:
    """Context information passed to a model for a single inference.

    Attributes:
        symbol: Trading symbol the inference is run for.
        timeframe: Timeframe whose schedule triggered the inference.
        timestamp: Time at which the inference was scheduled.
        target_timeframe: Timeframe the model is predicting for.
        is_hourly_inference: True when the inference is part of the hourly
            multi-timeframe sweep.
        inference_type: Kind of trigger: "regular", "continuous", "hourly"
            or "boundary".
    """

    symbol: str
    timeframe: TimeFrame
    timestamp: datetime
    target_timeframe: TimeFrame  # Which timeframe we're predicting for
    is_hourly_inference: bool = False
    inference_type: str = "regular"  # "regular", "continuous", "hourly" or "boundary"
@dataclass
class InferenceSchedule:
    """Schedule configuration for the different inference types.

    Attributes:
        continuous_interval_seconds: Interval between two continuous
            inferences for the same symbol (5 seconds by default).
        hourly_timeframes: Timeframes swept by the hourly multi-timeframe
            inference. ``None`` (the default) selects 1s, 1m, 1h and 1d.
    """

    continuous_interval_seconds: float = 5.0  # Continuous inference every 5 seconds
    hourly_timeframes: Optional[List[TimeFrame]] = None  # Timeframes for hourly inference

    def __post_init__(self):
        # A mutable list cannot be a plain dataclass default, so the default
        # set of timeframes is filled in here when none was supplied.
        if self.hourly_timeframes is None:
            self.hourly_timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
                                      TimeFrame.HOURS_1, TimeFrame.DAYS_1]
class TimeframeInferenceCoordinator:
    """
    Coordinates timeframe-aware model inference with proper scheduling.

    This coordinator:
    1. Runs continuous inference on the 1s timeframe for every symbol, every
       ``schedule.continuous_interval_seconds`` seconds (5 s by default)
    2. Runs an hourly multi-timeframe sweep (one inference per timeframe in
       ``schedule.hourly_timeframes``) at the top of every hour
    3. Triggers one boundary inference whenever a new minute, hour or day starts
    4. Passes an InferenceContext so models know which timeframe they predict on
    5. Feeds predictions to the enhanced reward calculator and periodically
       evaluates them to trigger training
    """

    # Timeframes that get an extra inference each time a new bucket starts.
    _BOUNDARY_TIMEFRAMES = (TimeFrame.MINUTES_1, TimeFrame.HOURS_1, TimeFrame.DAYS_1)

    # Wake-up period of the hourly/boundary scheduler. It must stay well below
    # one minute, otherwise minute boundaries are only noticed late.
    scheduler_poll_interval_seconds: float = 1.0

    def __init__(self,
                 reward_calculator: EnhancedRewardCalculator,
                 data_provider: Any = None,
                 symbols: Optional[List[str]] = None):
        """
        Initialize the timeframe inference coordinator.

        Args:
            reward_calculator: Enhanced reward calculator instance
            data_provider: Data provider for market data (optional)
            symbols: List of symbols to coordinate inference for; an empty or
                missing list selects ['ETH/USDT', 'BTC/USDT']
        """
        self.reward_calculator = reward_calculator
        self.data_provider = data_provider
        self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']

        # Inference schedule configuration
        self.schedule = InferenceSchedule()

        # Model registry - stores inference functions for different models
        self.model_inference_functions: Dict[str, Callable] = {}

        # Tracking inference state
        self.last_continuous_inference: Dict[str, datetime] = {}
        self.last_hourly_inference: Dict[str, datetime] = {}
        self.next_hourly_inference: Dict[str, datetime] = {}
        # Start of the most recent minute/hour/day bucket already handled for
        # each symbol, so that every boundary fires exactly once.
        self._last_boundary: Dict[str, Dict[TimeFrame, datetime]] = {}

        # Active inference tasks
        self.inference_tasks: List[asyncio.Task] = []
        self.running = False
        # Event loop the coordination tasks run on (set by start_coordination);
        # used to schedule forced inference from other threads.
        self._loop: Optional[asyncio.AbstractEventLoop] = None

        # Thread safety
        self.lock = threading.RLock()

        # Performance metrics
        self.inference_stats = {
            'continuous_inferences': 0,
            'hourly_inferences': 0,
            'boundary_inferences': 0,
            'failed_inferences': 0,
            'average_inference_time_ms': 0.0
        }
        # Number of durations folded into 'average_inference_time_ms'
        self._timing_samples = 0

        self._initialize_schedules()

        logger.info(f"TimeframeInferenceCoordinator initialized for symbols: {self.symbols}")
        logger.info(f"Continuous inference interval: {self.schedule.continuous_interval_seconds}s")
        logger.info(f"Hourly timeframes: {[tf.value for tf in self.schedule.hourly_timeframes]}")

    @staticmethod
    def _boundary_start(timestamp: datetime, timeframe: TimeFrame) -> datetime:
        """Return the start of the minute/hour/day bucket that contains ``timestamp``."""
        if timeframe == TimeFrame.MINUTES_1:
            return timestamp.replace(second=0, microsecond=0)
        if timeframe == TimeFrame.HOURS_1:
            return timestamp.replace(minute=0, second=0, microsecond=0)
        if timeframe == TimeFrame.DAYS_1:
            return timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
        raise ValueError(f"Unsupported boundary timeframe: {timeframe}")

    def _initialize_schedules(self):
        """Initialize inference schedules for all symbols."""
        current_time = datetime.now()
        # Next hourly inference happens at the top of the next hour
        next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)

        for symbol in self.symbols:
            self.last_continuous_inference[symbol] = current_time
            self.last_hourly_inference[symbol] = current_time
            self.next_hourly_inference[symbol] = next_hour
            # The buckets we start in count as already handled, so starting up
            # mid-minute does not fire spurious boundary inferences.
            self._last_boundary[symbol] = {
                timeframe: self._boundary_start(current_time, timeframe)
                for timeframe in self._BOUNDARY_TIMEFRAMES
            }

    def register_model_inference_function(self, model_name: str, inference_func: Callable):
        """
        Register a model's inference function.

        Args:
            model_name: Name of the model
            inference_func: Callable that takes an InferenceContext and returns
                a prediction dict or None. It may be a coroutine function or a
                plain function.
        """
        self.model_inference_functions[model_name] = inference_func
        logger.info(f"Registered inference function for model: {model_name}")

    async def start_coordination(self):
        """Start the continuous, hourly/boundary and reward-evaluation tasks."""
        if self.running:
            logger.warning("Inference coordination already running")
            return

        self.running = True
        self._loop = asyncio.get_running_loop()
        logger.info("Starting timeframe inference coordination")

        # Start continuous inference tasks for each symbol
        for symbol in self.symbols:
            task = asyncio.create_task(self._continuous_inference_loop(symbol))
            self.inference_tasks.append(task)

        # Start hourly inference scheduler
        task = asyncio.create_task(self._hourly_inference_scheduler())
        self.inference_tasks.append(task)

        # Start reward evaluation loop
        task = asyncio.create_task(self._reward_evaluation_loop())
        self.inference_tasks.append(task)

        logger.info(f"Started {len(self.inference_tasks)} inference coordination tasks")

    async def stop_coordination(self):
        """Cancel every coordination task and wait for them to finish."""
        if not self.running:
            return

        self.running = False
        logger.info("Stopping timeframe inference coordination")

        for task in self.inference_tasks:
            task.cancel()

        # return_exceptions=True so CancelledError does not propagate from here
        await asyncio.gather(*self.inference_tasks, return_exceptions=True)

        self.inference_tasks.clear()
        logger.info("Inference coordination stopped")

    async def _continuous_inference_loop(self, symbol: str):
        """
        Continuous inference loop for a specific symbol.

        Args:
            symbol: Trading symbol to run inference for
        """
        logger.info(f"Starting continuous inference loop for {symbol}")

        while self.running:
            try:
                current_time = datetime.now()

                last_inference = self.last_continuous_inference[symbol]
                time_since_last = (current_time - last_inference).total_seconds()

                if time_since_last >= self.schedule.continuous_interval_seconds:
                    # Continuous inference always runs on the primary (1s) timeframe
                    context = InferenceContext(
                        symbol=symbol,
                        timeframe=TimeFrame.SECONDS_1,
                        timestamp=current_time,
                        target_timeframe=TimeFrame.SECONDS_1,
                        is_hourly_inference=False,
                        inference_type="continuous"
                    )

                    await self._execute_inference(context)
                    self.last_continuous_inference[symbol] = current_time
                    self.inference_stats['continuous_inferences'] += 1

                # Short sleep to avoid busy waiting
                await asyncio.sleep(0.1)

            except Exception as e:
                logger.error(f"Error in continuous inference loop for {symbol}: {e}")
                await asyncio.sleep(1.0)  # Wait longer on error

    async def _hourly_inference_scheduler(self):
        """Scheduler for hourly multi-timeframe inference and timeframe-boundary triggers."""
        logger.info("Starting hourly inference scheduler")

        while self.running:
            try:
                current_time = datetime.now()

                for symbol in self.symbols:
                    if current_time >= self.next_hourly_inference[symbol]:
                        await self._execute_hourly_inference(symbol, current_time)

                        next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
                        self.next_hourly_inference[symbol] = next_hour
                        self.last_hourly_inference[symbol] = current_time

                    # Trigger at each new timeframe boundary: 1m, 1h, 1d
                    await self._execute_pending_boundaries(symbol, current_time)

                await asyncio.sleep(self.scheduler_poll_interval_seconds)

            except Exception as e:
                logger.error(f"Error in hourly inference scheduler: {e}")
                await asyncio.sleep(60)  # Wait longer on error

    async def _execute_pending_boundaries(self, symbol: str, current_time: datetime):
        """Fire one boundary inference for each 1m/1h/1d bucket entered since the last check."""
        handled = self._last_boundary.setdefault(symbol, {})
        for timeframe in self._BOUNDARY_TIMEFRAMES:
            bucket_start = self._boundary_start(current_time, timeframe)
            # Any change of bucket (including a backwards clock jump) counts as
            # a new boundary; an unchanged bucket has already been handled.
            if handled.get(timeframe) == bucket_start:
                continue
            handled[timeframe] = bucket_start
            await self._execute_boundary_inference(symbol, current_time, timeframe)

    async def _execute_boundary_inference(self, symbol: str, timestamp: datetime, timeframe: TimeFrame):
        """Execute an inference at the start of a new timeframe bucket."""
        try:
            context = InferenceContext(
                symbol=symbol,
                timeframe=timeframe,
                timestamp=timestamp,
                target_timeframe=timeframe,
                is_hourly_inference=False,
                inference_type="boundary"
            )
            await self._execute_inference(context)
            self.inference_stats['boundary_inferences'] += 1
        except Exception as e:
            logger.debug(f"Boundary inference error for {symbol} {timeframe.value}: {e}")

    async def _execute_hourly_inference(self, symbol: str, timestamp: datetime):
        """
        Execute hourly multi-timeframe inference for a symbol.

        Args:
            symbol: Trading symbol
            timestamp: Current timestamp
        """
        logger.info(f"Executing hourly multi-timeframe inference for {symbol}")

        for timeframe in self.schedule.hourly_timeframes:
            context = InferenceContext(
                symbol=symbol,
                timeframe=timeframe,
                timestamp=timestamp,
                target_timeframe=timeframe,
                is_hourly_inference=True,
                inference_type="hourly"
            )

            await self._execute_inference(context)
            self.inference_stats['hourly_inferences'] += 1

            # Small delay between timeframe inferences
            await asyncio.sleep(0.5)

    async def _execute_inference(self, context: InferenceContext):
        """
        Run every registered model for ``context`` and record its predictions.

        A failure in one model is logged and counted in 'failed_inferences'
        without preventing the remaining models from running.

        Args:
            context: Inference context containing all necessary information
        """
        start_time = time.perf_counter()

        try:
            # Snapshot the registry: models may be registered from another
            # thread while this coroutine is iterating.
            for model_name, inference_func in list(self.model_inference_functions.items()):
                try:
                    prediction = inference_func(context)
                    # Accept both coroutine functions and plain callables
                    if inspect.isawaitable(prediction):
                        prediction = await prediction

                    if prediction is None:
                        continue

                    # Explicit None check: `a or b` raises on array-like state
                    # vectors whose truth value is ambiguous.
                    state_vector = prediction.get('model_state')
                    if state_vector is None:
                        state_vector = prediction.get('model_features')

                    prediction_id = self.reward_calculator.add_prediction(
                        symbol=context.symbol,
                        timeframe=context.target_timeframe,
                        predicted_price=prediction.get('predicted_price', 0.0),
                        predicted_direction=prediction.get('direction', 0),
                        confidence=prediction.get('confidence', 0.0),
                        current_price=prediction.get('current_price', 0.0),
                        model_name=model_name,
                        predicted_return=prediction.get('predicted_return'),
                        state_vector=state_vector
                    )

                    logger.debug(f"Added prediction {prediction_id} from {model_name} "
                                 f"for {context.symbol} {context.target_timeframe.value}")

                except Exception as e:
                    logger.error(f"Error running inference for model {model_name}: {e}")
                    self.inference_stats['failed_inferences'] += 1

            inference_time_ms = (time.perf_counter() - start_time) * 1000
            self._update_inference_timing(inference_time_ms)

        except Exception as e:
            logger.error(f"Error executing inference for context {context}: {e}")
            self.inference_stats['failed_inferences'] += 1

    def _update_inference_timing(self, inference_time_ms: float):
        """Fold one inference duration into the running average."""
        # A dedicated sample counter: the per-type counters are only bumped
        # after timing is recorded and do not include boundary inferences.
        self._timing_samples += 1
        current_avg = self.inference_stats['average_inference_time_ms']
        self.inference_stats['average_inference_time_ms'] = (
            current_avg + (inference_time_ms - current_avg) / self._timing_samples
        )

    async def _reward_evaluation_loop(self):
        """Periodically refresh prices and evaluate pending predictions."""
        logger.info("Starting reward evaluation loop")

        while self.running:
            try:
                if self.data_provider:
                    # _update_price_cache is a coroutine; it calls
                    # data_provider.get_current_price without awaiting it.
                    await self._update_price_cache()

                for symbol in self.symbols:
                    evaluation_results = self.reward_calculator.evaluate_predictions(symbol)

                    if symbol in evaluation_results and evaluation_results[symbol]:
                        logger.debug(f"Evaluated {len(evaluation_results[symbol])} predictions for {symbol}")
                        await self._trigger_model_training(symbol, evaluation_results[symbol])

                await asyncio.sleep(10)  # Evaluate every 10 seconds

            except Exception as e:
                logger.error(f"Error in reward evaluation loop: {e}")
                await asyncio.sleep(30)  # Wait longer on error

    async def _update_price_cache(self):
        """Push the data provider's current prices into the reward calculator (best effort)."""
        try:
            for symbol in self.symbols:
                if hasattr(self.data_provider, 'get_current_price'):
                    current_price = self.data_provider.get_current_price(symbol)
                    if current_price:
                        self.reward_calculator.update_price(symbol, current_price)
        except Exception as e:
            logger.debug(f"Error updating price cache: {e}")

    async def _trigger_model_training(self, symbol: str, evaluation_results: List[Any]):
        """
        Trigger model training based on evaluation results.

        Args:
            symbol: Trading symbol
            evaluation_results: List of (prediction, reward) tuples
        """
        try:
            # Group by (model, timeframe). A tuple key keeps model names that
            # contain underscores intact (a joined "model_tf" string did not).
            training_groups: Dict[tuple, List[Any]] = {}
            for prediction_record, reward in evaluation_results:
                key = (prediction_record.model_name, prediction_record.timeframe)
                training_groups.setdefault(key, []).append((prediction_record, reward))

            for (model_name, timeframe), training_data in training_groups.items():
                logger.info(f"Triggering training for {model_name} on {symbol} {timeframe.value} "
                            f"with {len(training_data)} samples")

                # Placeholder hook - see _call_model_training
                await self._call_model_training(model_name, symbol, timeframe, training_data)

        except Exception as e:
            logger.error(f"Error triggering model training: {e}")

    async def _call_model_training(self, model_name: str, symbol: str,
                                   timeframe: TimeFrame, training_data: List[Any]):
        """
        Call the model-specific training function (placeholder: only logs).

        Args:
            model_name: Name of the model to train
            symbol: Trading symbol
            timeframe: Timeframe for training
            training_data: List of (prediction, reward) tuples
        """
        # TODO: dispatch to the model's actual training interface
        logger.debug(f"Training call for {model_name}: {len(training_data)} samples")

    def get_inference_statistics(self) -> Dict[str, Any]:
        """Return a snapshot of inference counters, schedule and accuracy summary."""
        with self.lock:
            stats = self.inference_stats.copy()

            # Scheduling information (copied so callers cannot mutate state)
            stats['symbols'] = list(self.symbols)
            stats['continuous_interval_seconds'] = self.schedule.continuous_interval_seconds
            stats['hourly_timeframes'] = [tf.value for tf in self.schedule.hourly_timeframes]
            stats['next_hourly_inferences'] = {
                symbol: timestamp.isoformat()
                for symbol, timestamp in self.next_hourly_inference.items()
            }

            stats['accuracy_summary'] = self.reward_calculator.get_accuracy_summary()

            return stats

    def _discard_task(self, task: asyncio.Task):
        """Done-callback: drop a finished ad-hoc task from ``inference_tasks``."""
        if task in self.inference_tasks:
            self.inference_tasks.remove(task)

    def force_hourly_inference(self, symbol: Optional[str] = None):
        """
        Schedule an immediate hourly multi-timeframe inference.

        May be called from the coordinator's event loop or from another thread.

        Args:
            symbol: Specific symbol (None for all symbols)
        """
        if not self.running:
            logger.warning("Cannot force inference - coordinator not running")
            return

        symbols_to_process = [symbol] if symbol else list(self.symbols)

        async def _force_inference():
            current_time = datetime.now()
            for sym in symbols_to_process:
                await self._execute_hourly_inference(sym, current_time)

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread (e.g. a UI thread): hand the coroutine to
            # the loop the coordinator was started on.
            if self._loop is None or self._loop.is_closed():
                logger.warning("Cannot force inference - coordinator event loop unavailable")
                return
            asyncio.run_coroutine_threadsafe(_force_inference(), self._loop)
            return

        task = asyncio.create_task(_force_inference())
        # Keep a strong reference so the task is not garbage-collected mid-run
        # and is cancelled by stop_coordination(); drop it once finished.
        self.inference_tasks.append(task)
        task.add_done_callback(self._discard_task)

    def get_prediction_history(self, symbol: str, timeframe: TimeFrame,
                               max_samples: int = 50) -> List[Any]:
        """
        Get prediction history for training.

        Args:
            symbol: Trading symbol
            timeframe: Specific timeframe
            max_samples: Maximum samples to return

        Returns:
            Whatever ``reward_calculator.get_training_data`` returns for these
            arguments.
        """
        return self.reward_calculator.get_training_data(symbol, timeframe, max_samples)