Refactoring: inference real-data triggers

Dobromir Popov
2025-12-09 11:59:15 +02:00
parent 1c1ebf6d7e
commit 992d6de25b
9 changed files with 1970 additions and 224 deletions


@@ -166,6 +166,16 @@ class RealTrainingAdapter:
import threading
self._training_lock = threading.Lock()
# Use orchestrator's inference training coordinator (if available)
# This reduces duplication and centralizes coordination logic
if orchestrator and hasattr(orchestrator, 'inference_training_coordinator'):
self.training_coordinator = orchestrator.inference_training_coordinator
if self.training_coordinator:
# Subscribe to training events
self._subscribe_to_training_events()
else:
self.training_coordinator = None
# Real-time training tracking
self.realtime_training_metrics = {
'total_steps': 0,
@@ -187,6 +197,279 @@ class RealTrainingAdapter:
logger.info("RealTrainingAdapter initialized - NO SIMULATION, REAL TRAINING ONLY")
# Implement TrainingEventSubscriber interface
def on_candle_completion(self, event, inference_ref):
"""
Called when a candle completes - train on stored inference frame with actual result.
This uses the reference-based system: inference data is retrieved from DuckDB
using the reference, not copied.
"""
if not inference_ref or not self.training_coordinator:
return
try:
# Retrieve inference data from DuckDB using reference
model_inputs = self.training_coordinator.get_inference_data(inference_ref)
if not model_inputs:
logger.warning(f"Could not retrieve inference data for {inference_ref.inference_id}")
return
# Create training batch with actual candle
batch = self._create_training_batch_from_inference(
model_inputs, event.ohlcv, inference_ref
)
if not batch:
return
# Train model (backprop for Transformer)
self._train_on_inference_batch(batch, inference_ref)
except Exception as e:
logger.error(f"Error in candle completion training: {e}", exc_info=True)
def on_pivot_event(self, event, inference_refs):
"""
Called when a pivot point is detected - train on matching inference frames.
This handles event-based training where we don't know when the pivot will occur.
"""
if not inference_refs or not self.training_coordinator:
return
try:
for inference_ref in inference_refs:
# Retrieve inference data
model_inputs = self.training_coordinator.get_inference_data(inference_ref)
if not model_inputs:
continue
# Create training batch with pivot result
batch = self._create_pivot_training_batch(model_inputs, event, inference_ref)
if not batch:
continue
# Train model
self._train_on_inference_batch(batch, inference_ref)
except Exception as e:
logger.error(f"Error in pivot event training: {e}", exc_info=True)
def _create_training_batch_from_inference(self, model_inputs: Dict, actual_ohlcv: Dict,
inference_ref) -> Optional[Dict]:
"""Create training batch from inference inputs and actual candle result"""
try:
import torch
# Copy model inputs
batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in model_inputs.items()}
# Get device from the first tensor in the batch (non-tensor values may be present)
tensor_values = [v for v in batch.values() if isinstance(v, torch.Tensor)]
device = tensor_values[0].device if tensor_values else torch.device('cpu')
# Normalize actual candle using stored params
timeframe = inference_ref.timeframe
if timeframe in inference_ref.norm_params:
params = inference_ref.norm_params[timeframe]
price_min = params['price_min']
price_max = params['price_max']
vol_min = params['volume_min']
vol_max = params['volume_max']
# Normalize actual OHLCV (guard against a flat price window)
price_range = (price_max - price_min) or 1.0
normalized_candle = [
(actual_ohlcv['open'] - price_min) / price_range,
(actual_ohlcv['high'] - price_min) / price_range,
(actual_ohlcv['low'] - price_min) / price_range,
(actual_ohlcv['close'] - price_min) / price_range,
(actual_ohlcv['volume'] - vol_min) / (vol_max - vol_min) if vol_max > vol_min else 0.0
]
# Add target candle to batch
target_key = f'future_candle_{timeframe}'
batch[target_key] = torch.tensor([normalized_candle], dtype=torch.float32, device=device)
# Add action target (determine from price movement)
price_change = (actual_ohlcv['close'] - actual_ohlcv['open']) / actual_ohlcv['open']
if price_change > 0.0005: # 0.05% up
action = 1 # BUY
elif price_change < -0.0005: # 0.05% down
action = 2 # SELL
else:
action = 0 # HOLD
batch['actions'] = torch.tensor([[action]], dtype=torch.long, device=device)
return batch
return None
except Exception as e:
logger.error(f"Error creating training batch from inference: {e}", exc_info=True)
return None
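# Standalone sketch of the two transforms used above: min-max normalization of
# the actual candle, and the +/-0.05% threshold that maps price movement to an
# action label. Pure Python; values are made up.
def normalize_candle(ohlcv, price_min, price_max, vol_min, vol_max):
    price_range = (price_max - price_min) or 1.0  # guard against a flat window
    vol_range = vol_max - vol_min
    return [
        (ohlcv['open'] - price_min) / price_range,
        (ohlcv['high'] - price_min) / price_range,
        (ohlcv['low'] - price_min) / price_range,
        (ohlcv['close'] - price_min) / price_range,
        (ohlcv['volume'] - vol_min) / vol_range if vol_range > 0 else 0.0,
    ]

def action_from_price_change(open_p, close_p, threshold=0.0005):
    change = (close_p - open_p) / open_p
    if change > threshold:
        return 1  # BUY
    if change < -threshold:
        return 2  # SELL
    return 0  # HOLD

candle = {'open': 3010.0, 'high': 3025.0, 'low': 3005.0, 'close': 3020.0, 'volume': 12.5}
norm = normalize_candle(candle, 3000.0, 3100.0, 0.0, 50.0)
assert abs(norm[3] - 0.2) < 1e-9  # close: (3020 - 3000) / 100
assert action_from_price_change(3010.0, 3020.0) == 1  # +0.33% > 0.05% -> BUY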
def _create_pivot_training_batch(self, model_inputs: Dict, pivot_event, inference_ref) -> Optional[Dict]:
"""Create training batch from inference inputs and pivot event"""
try:
import torch
# Copy model inputs
batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in model_inputs.items()}
# Get device from the first tensor in the batch (non-tensor values may be present)
tensor_values = [v for v in batch.values() if isinstance(v, torch.Tensor)]
device = tensor_values[0].device if tensor_values else torch.device('cpu')
# Determine action from pivot type
# L2L, L3L, etc. -> BUY (support levels)
# L2H, L3H, etc. -> SELL (resistance levels)
if pivot_event.pivot_type.endswith('L'):
action = 1 # BUY
elif pivot_event.pivot_type.endswith('H'):
action = 2 # SELL
else:
action = 0 # HOLD
batch['actions'] = torch.tensor([[action]], dtype=torch.long, device=device)
# For pivot training, we don't have a target candle, so we use the pivot price
# as a reference point for training
# This is a simplified approach - could be enhanced with pivot-based targets
return batch
except Exception as e:
logger.error(f"Error creating pivot training batch: {e}", exc_info=True)
return None
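# Standalone sketch of the pivot-to-action rule above: pivot labels ending in
# 'L' are treated as support (BUY), labels ending in 'H' as resistance (SELL).
def pivot_to_action(pivot_type: str) -> int:
    if pivot_type.endswith('L'):
        return 1  # BUY at support
    if pivot_type.endswith('H'):
        return 2  # SELL at resistance
    return 0  # HOLD for anything unrecognized

assert [pivot_to_action(p) for p in ('L2L', 'L3H', 'X1')] == [1, 2, 0]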
def _train_on_inference_batch(self, batch: Dict, inference_ref) -> None:
"""Train model on inference batch (uses stored inference frame)"""
try:
if not self.orchestrator:
return
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if not trainer:
return
# Train with lock protection
import torch
with self._training_lock:
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False)
if result:
loss = result.get('total_loss', 0)
accuracy = result.get('accuracy', 0)
# Update metrics
self.realtime_training_metrics['total_steps'] += 1
self.realtime_training_metrics['total_loss'] += loss
self.realtime_training_metrics['total_accuracy'] += accuracy
self.realtime_training_metrics['losses'].append(loss)
self.realtime_training_metrics['accuracies'].append(accuracy)
if len(self.realtime_training_metrics['losses']) > 100:
self.realtime_training_metrics['losses'].pop(0)
self.realtime_training_metrics['accuracies'].pop(0)
logger.info(f"Trained on inference frame {inference_ref.inference_id}: Loss={loss:.4f}, Acc={accuracy:.2%}")
except Exception as e:
logger.error(f"Error training on inference batch: {e}", exc_info=True)
def _register_inference_frame(self, session: Dict, symbol: str, timeframe: str,
prediction: Dict, data_provider, norm_params: Dict = None) -> None:
"""
Register inference frame reference with coordinator.
Stores reference (timestamp range) instead of copying 600 candles.
This method stores norm_params in the reference for efficient retrieval.
When training is triggered, data is retrieved from DuckDB using the reference.
Args:
session: Inference session
symbol: Trading symbol
timeframe: Timeframe
prediction: Prediction dict from model
data_provider: Data provider instance
norm_params: Normalization parameters (optional, will be calculated if not provided)
"""
if not self.training_coordinator:
return
try:
from ANNOTATE.core.inference_training_system import InferenceFrameReference
from datetime import datetime, timezone, timedelta
import uuid
# Get current time and calculate data range
current_time = datetime.now(timezone.utc)
data_range_end = current_time
# Calculate start time for 600 candles (approximate)
timeframe_seconds = {'1s': 1, '1m': 60, '5m': 300, '15m': 900, '1h': 3600, '1d': 86400}.get(timeframe, 60)
data_range_start = current_time - timedelta(seconds=600 * timeframe_seconds)
# Use provided norm_params or calculate if not available
if not norm_params:
norm_params = {}
# Calculate target timestamp (next candle close time)
# For 1m timeframe, next candle closes in 1 minute
target_timestamp = current_time + timedelta(seconds=timeframe_seconds)
# Create inference frame reference
inference_ref = InferenceFrameReference(
inference_id=str(uuid.uuid4()),
symbol=symbol,
timeframe=timeframe,
prediction_timestamp=current_time,
target_timestamp=target_timestamp,
data_range_start=data_range_start,
data_range_end=data_range_end,
norm_params=norm_params, # Stored for efficient retrieval
predicted_action=prediction.get('action'),
predicted_candle=prediction.get('predicted_candle'),
confidence=prediction.get('confidence', 0.0)
)
# Register with coordinator
self.training_coordinator.register_inference_frame(inference_ref)
logger.debug(f"Registered inference frame: {inference_ref.inference_id} for {symbol} {timeframe} (target: {target_timestamp})")
except Exception as e:
logger.warning(f"Could not register inference frame: {e}", exc_info=True)
def _subscribe_to_training_events(self):
"""Subscribe to training events via orchestrator's coordinator"""
if not self.training_coordinator:
return
try:
# Subscribe to candle completion for primary symbol/timeframe
primary_symbol = getattr(self.orchestrator, 'symbol', 'ETH/USDT')
primary_timeframe = '1m' # Default timeframe
self.training_coordinator.subscribe_to_candle_completion(
self, symbol=primary_symbol, timeframe=primary_timeframe
)
# Subscribe to pivot events (L2L, L2H, L3L, L3H)
self.training_coordinator.subscribe_to_pivot_events(
self, symbol=primary_symbol, timeframe=primary_timeframe,
pivot_types=['L2L', 'L2H', 'L3L', 'L3H']
)
logger.info(f"Subscribed to training events: {primary_symbol} {primary_timeframe}")
except Exception as e:
logger.warning(f"Could not subscribe to training events: {e}")
def _import_training_systems(self):
"""Import real training system implementations"""
try:
@@ -3056,7 +3339,10 @@ class RealTrainingAdapter:
'total_pnl': 0.0,
'win_count': 0,
'loss_count': 0,
'total_trades': 0
'total_trades': 0,
# Inference input cache: stores input data frames for later training
# Key: candle_timestamp (str), Value: {'model_inputs': Dict, 'norm_params': Dict, 'predicted_candle': Dict}
'inference_input_cache': {}
}
training_mode = "per-candle" if train_every_candle else ("pivot-based" if enable_live_training else "inference-only")
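# Illustrative shape of one inference_input_cache entry described in the comment
# above; field names follow that comment, concrete values are made up.
entry = {
    'model_inputs': {},  # model input tensors, moved to CPU for storage
    'norm_params': {'1m': {'price_min': 3000.0, 'price_max': 3100.0,
                           'volume_min': 0.0, 'volume_max': 50.0}},
    'predicted_candle': {'1m': [3010.0, 3025.0, 3005.0, 3020.0, 12.5]},
}
inference_input_cache = {'2025-12-09 11:59:00': entry}
assert 'norm_params' in inference_input_cache['2025-12-09 11:59:00']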
@@ -3128,8 +3414,177 @@ class RealTrainingAdapter:
all_signals.sort(key=lambda x: x.get('timestamp', ''), reverse=True)
return all_signals[:limit]
def _make_realtime_prediction(self, model_name: str, symbol: str, data_provider) -> Dict:
"""Make a prediction using the specified model"""
def _make_realtime_prediction_with_cache(self, model_name: str, symbol: str, data_provider, session: Dict) -> Tuple[Dict, bool]:
"""
DEPRECATED: Use _make_realtime_prediction + _register_inference_frame instead.
This method is kept for backward compatibility but should be removed.
"""
# Just call the regular prediction method
prediction = self._make_realtime_prediction(model_name, symbol, data_provider)
return prediction, False
"""
Make a prediction and store input data frame for later training
Returns:
Tuple of (prediction_dict, stored_inputs: bool)
"""
try:
if model_name == 'Transformer' and self.orchestrator:
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if trainer and trainer.model:
# Get recent market data
market_data, norm_params = self._get_realtime_market_data(symbol, data_provider)
if not market_data:
return None, False
# Get current candle timestamp for cache key
timeframe = session.get('timeframe', '1m')
df_current = data_provider.get_historical_data(symbol, timeframe, limit=1)
if df_current is not None and len(df_current) > 0:
current_timestamp = str(df_current.index[-1])
# Store input data frame for later training (convert tensors to CPU for storage)
import torch
cached_inputs = {
'model_inputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
for k, v in market_data.items()},
'norm_params': norm_params,
'timestamp': current_timestamp,
'symbol': symbol,
'timeframe': timeframe
}
# Store in session cache (keep last 50 to prevent memory bloat)
cache = session.get('inference_input_cache', {})
cache[current_timestamp] = cached_inputs
# Keep only last 50 entries
if len(cache) > 50:
# Remove oldest entries
sorted_keys = sorted(cache.keys())
for key in sorted_keys[:-50]:
del cache[key]
session['inference_input_cache'] = cache
logger.debug(f"Stored inference inputs for {symbol} {timeframe} @ {current_timestamp}")
# Make prediction
import torch
with torch.no_grad():
trainer.model.eval()
outputs = trainer.model(**market_data)
# Extract action
action_probs = outputs.get('action_probs')
if action_probs is not None:
# Handle different tensor shapes: [batch, 3] or [3]
if action_probs.dim() == 1:
# Shape [3] - single prediction
action_idx = torch.argmax(action_probs, dim=0).item()
confidence = action_probs[action_idx].item()
else:
# Shape [batch, 3] - take first batch item
action_idx = torch.argmax(action_probs[0], dim=0).item()
confidence = action_probs[0, action_idx].item()
# Map to action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
actions = ['HOLD', 'BUY', 'SELL']
action = actions[action_idx] if action_idx < len(actions) else 'HOLD'
# Handle predicted candles - DENORMALIZE them
predicted_candles_raw = {}
if 'next_candles' in outputs:
for tf, tensor in outputs['next_candles'].items():
predicted_candles_raw[tf] = tensor.detach().cpu().numpy().tolist()
# Denormalize if we have params
predicted_candles_denorm = {}
if predicted_candles_raw and norm_params:
for tf, raw_candle in predicted_candles_raw.items():
# raw_candle is [1, 5] list
if tf in norm_params:
params = norm_params[tf]
price_min = params['price_min']
price_max = params['price_max']
vol_min = params['volume_min']
vol_max = params['volume_max']
# Denormalize [Open, High, Low, Close, Volume]
# Note: raw_candle[0] is the list of 5 values
candle_values = raw_candle[0]
# Ensure all values are Python floats (not numpy scalars or tensors)
def to_float(v):
if hasattr(v, 'item'):
return float(v.item())
return float(v)
denorm_candle = [
to_float(candle_values[0] * (price_max - price_min) + price_min), # Open
to_float(candle_values[1] * (price_max - price_min) + price_min), # High
to_float(candle_values[2] * (price_max - price_min) + price_min), # Low
to_float(candle_values[3] * (price_max - price_min) + price_min), # Close
to_float(candle_values[4] * (vol_max - vol_min) + vol_min) # Volume
]
predicted_candles_denorm[tf] = denorm_candle
# Calculate predicted price from candle close (values are already Python floats
# after to_float above, so no extra .item() handling is needed)
predicted_price = None
if '1m' in predicted_candles_denorm:
predicted_price = float(predicted_candles_denorm['1m'][3])
elif '1s' in predicted_candles_denorm:
predicted_price = float(predicted_candles_denorm['1s'][3])
elif outputs.get('price_prediction') is not None:
# Fallback to price_prediction head if available (normalized)
# This would need separate denormalization based on reference price
pass
result_dict = {
'action': action,
'confidence': confidence,
'predicted_price': predicted_price,
'predicted_candle': predicted_candles_denorm
}
# Include trend vector if available
if 'trend_vector' in outputs:
result_dict['trend_vector'] = outputs['trend_vector']
# DEBUG: Log if we have predicted candles
if predicted_candles_denorm:
logger.info(f"Generated prediction with {len(predicted_candles_denorm)} timeframe candles: {list(predicted_candles_denorm.keys())}")
else:
logger.warning("No predicted candles in model output!")
return result_dict, True
return None, False
except Exception as e:
logger.debug(f"Error making realtime prediction: {e}")
import traceback
logger.debug(traceback.format_exc())
return None, False
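# Standalone sketch of the "keep only the last 50 entries" eviction used in the
# legacy body above. Keys are timestamp strings in a fixed format, so the
# lexicographic sort matches chronological order.
def evict_oldest(cache, keep=50):
    if len(cache) > keep:
        for key in sorted(cache.keys())[:-keep]:
            del cache[key]
    return cache

cache = {f'2025-12-09 11:{m:02d}:00': object() for m in range(60)}
evict_oldest(cache)
assert len(cache) == 50 and min(cache) == '2025-12-09 11:10:00'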
def _make_realtime_prediction(self, model_name: str, symbol: str, data_provider) -> Tuple[Dict, Dict]:
"""
Make a prediction and return both prediction and market data for reference storage.
Returns:
Tuple of (prediction_dict, market_data_dict with norm_params)
"""
# Get market data (needed for reference storage)
market_data, norm_params = self._get_realtime_market_data(symbol, data_provider)
if not market_data:
return None, None
# Make prediction (original logic)
prediction = self._make_realtime_prediction_internal(model_name, symbol, data_provider, market_data, norm_params)
return prediction, {'market_data': market_data, 'norm_params': norm_params}
def _make_realtime_prediction_internal(self, model_name: str, symbol: str, data_provider,
market_data: Dict, norm_params: Dict) -> Dict:
"""Make a prediction using the specified model (backward compatibility)"""
try:
if model_name == 'Transformer' and self.orchestrator:
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
@@ -3223,12 +3678,6 @@ class RealTrainingAdapter:
if 'trend_vector' in outputs:
result_dict['trend_vector'] = outputs['trend_vector']
# DEBUG: Log if we have predicted candles
if predicted_candles_denorm:
logger.info(f"🔮 Generated prediction with {len(predicted_candles_denorm)} timeframe candles: {list(predicted_candles_denorm.keys())}")
else:
logger.warning("⚠️ No predicted candles in model output!")
return result_dict
return None
@@ -3370,13 +3819,120 @@ class RealTrainingAdapter:
# Get the completed candle (second to last) and next candle
completed_candle = df.iloc[-2]
next_candle = df.iloc[-1]
completed_timestamp = str(completed_candle.name)
# Get action from session (set by app's training strategy)
action_label = session.get('pending_action')
if not action_label:
return {'success': False, 'error': 'No pending_action in session'}
# Fetch market state for training
# CRITICAL: Try to use stored inference input data frame if available
# This ensures we train on exactly what the model saw during inference
cache = session.get('inference_input_cache', {})
stored_inputs = cache.get(completed_timestamp)
if stored_inputs:
# Use stored input data frame from inference
logger.info(f"Using stored inference inputs for training on {symbol} {timeframe} @ {completed_timestamp}")
# Get actual candle data for target
actual_candle = [
float(next_candle['open']),
float(next_candle['high']),
float(next_candle['low']),
float(next_candle['close']),
float(next_candle['volume'])
]
# Create training batch from stored inputs
import torch
# Get device from orchestrator
device = getattr(self.orchestrator, 'device', torch.device('cpu'))
if hasattr(self.orchestrator, 'primary_transformer_trainer') and self.orchestrator.primary_transformer_trainer:
if hasattr(self.orchestrator.primary_transformer_trainer.model, 'device'):
device = next(self.orchestrator.primary_transformer_trainer.model.parameters()).device
# Move stored inputs back to device (they were stored on CPU)
batch = {}
for k, v in stored_inputs['model_inputs'].items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(device)
else:
batch[k] = v
# Add actual candle as target (normalize using stored params)
norm_params = stored_inputs['norm_params']
if timeframe in norm_params:
params = norm_params[timeframe]
price_min = params['price_min']
price_max = params['price_max']
vol_min = params['volume_min']
vol_max = params['volume_max']
# Normalize actual candle (guard against a flat price window)
price_range = (price_max - price_min) or 1.0
normalized_candle = [
(actual_candle[0] - price_min) / price_range,  # Open
(actual_candle[1] - price_min) / price_range,  # High
(actual_candle[2] - price_min) / price_range,  # Low
(actual_candle[3] - price_min) / price_range,  # Close
(actual_candle[4] - vol_min) / (vol_max - vol_min) if vol_max > vol_min else 0.0  # Volume
]
# Add target candle to batch
target_key = f'future_candle_{timeframe}'
batch[target_key] = torch.tensor([normalized_candle], dtype=torch.float32, device=device)
# Add action target
action_map = {'HOLD': 0, 'BUY': 1, 'SELL': 2}
batch['actions'] = torch.tensor([[action_map.get(action_label, 0)]], dtype=torch.long, device=device)
# Train directly on batch
model_name = session['model_name']
if model_name == 'Transformer':
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if trainer:
with self._training_lock:
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False)
if result:
loss = result.get('total_loss', 0)
accuracy = result.get('accuracy', 0)
# Update metrics
self.realtime_training_metrics['total_steps'] += 1
self.realtime_training_metrics['total_loss'] += loss
self.realtime_training_metrics['total_accuracy'] += accuracy
self.realtime_training_metrics['losses'].append(loss)
self.realtime_training_metrics['accuracies'].append(accuracy)
if len(self.realtime_training_metrics['losses']) > 100:
self.realtime_training_metrics['losses'].pop(0)
self.realtime_training_metrics['accuracies'].pop(0)
session['metrics']['loss'] = sum(self.realtime_training_metrics['losses']) / len(self.realtime_training_metrics['losses'])
session['metrics']['accuracy'] = sum(self.realtime_training_metrics['accuracies']) / len(self.realtime_training_metrics['accuracies'])
session['metrics']['steps'] = self.realtime_training_metrics['total_steps']
# Remove from cache after training
if completed_timestamp in cache:
del cache[completed_timestamp]
logger.info(f"Trained on stored inference inputs: {symbol} {timeframe} @ {completed_timestamp} action={action_label} (Loss: {loss:.4f}, Acc: {accuracy:.2%})")
return {
'success': True,
'loss': session['metrics']['loss'],
'accuracy': session['metrics']['accuracy'],
'training_steps': session['metrics']['steps'],
'used_stored_inputs': True
}
# Fall through to regular training if stored inputs failed
logger.warning(f"Failed to use stored inputs, falling back to fresh data")
# Fallback: Fetch fresh market state for training (original behavior)
market_state = self._fetch_market_state_for_candle(symbol, completed_candle.name, data_provider)
# Calculate price change
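# Sketch of the CPU-cache round trip used above: inputs are stored on CPU at
# inference time and moved back to the training device before the train step.
# Requires torch; falls back to CPU when no GPU is available.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
stored = {'price_data_1m': torch.randn(1, 600, 5).cpu(), 'symbol': 'ETH/USDT'}
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
         for k, v in stored.items()}
assert batch['price_data_1m'].device.type == device.type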
@@ -3411,7 +3967,8 @@ class RealTrainingAdapter:
'success': True,
'loss': session['metrics']['loss'],
'accuracy': session['metrics']['accuracy'],
'training_steps': session['metrics']['steps']
'training_steps': session['metrics']['steps'],
'used_stored_inputs': False
}
return {'success': False, 'error': f'Unsupported model: {model_name}'}
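# Toy end-to-end shape of this commit's flow: (1) predict and register a frame
# reference, (2) on candle close, pop the reference and run one training step.
# All names are local stand-ins, not the real coordinator API.
frames = {}

def register(inference_id, ref):
    frames[inference_id] = ref

def on_close(inference_id, actual_close):
    ref = frames.pop(inference_id, None)  # a frame is consumed once trained on
    if ref is None:
        return None
    return actual_close - ref['predicted_close']  # stand-in for train_step loss

register('abc', {'predicted_close': 3020.0})
assert on_close('abc', 3022.5) == 2.5
assert on_close('abc', 3022.5) is None  # already consumed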
@@ -3939,6 +4496,14 @@ class RealTrainingAdapter:
# Make prediction using the model
prediction, market_info = self._make_realtime_prediction(model_name, symbol, data_provider)
# Register inference frame reference for later training when actual candle arrives
# This stores a reference (timestamp range) instead of copying 600 candles
# The reference allows us to retrieve the exact data from DuckDB when training
if prediction and self.training_coordinator:
# Reuse the norm_params returned alongside the prediction (avoids a second data fetch)
norm_params = market_info.get('norm_params') if market_info else None
self._register_inference_frame(session, symbol, timeframe, prediction, data_provider, norm_params)
if prediction:
# Store signal
signal = {