Refactoring: trigger inference training on real data
@@ -166,6 +166,16 @@ class RealTrainingAdapter:
        import threading
        self._training_lock = threading.Lock()

        # Use the orchestrator's inference training coordinator (if available).
        # This reduces duplication and centralizes coordination logic.
        if orchestrator and hasattr(orchestrator, 'inference_training_coordinator'):
            self.training_coordinator = orchestrator.inference_training_coordinator
            if self.training_coordinator:
                # Subscribe to training events
                self._subscribe_to_training_events()
        else:
            self.training_coordinator = None

        # Real-time training tracking
        self.realtime_training_metrics = {
            'total_steps': 0,
@@ -187,6 +197,279 @@ class RealTrainingAdapter:

        logger.info("RealTrainingAdapter initialized - NO SIMULATION, REAL TRAINING ONLY")

    # Implement the TrainingEventSubscriber interface
    def on_candle_completion(self, event, inference_ref):
        """
        Called when a candle completes - train on the stored inference frame with the actual result.

        This uses the reference-based system: inference data is retrieved from DuckDB
        using the reference, not copied.
        """
        if not inference_ref or not self.training_coordinator:
            return

        try:
            # Retrieve inference data from DuckDB using the reference
            model_inputs = self.training_coordinator.get_inference_data(inference_ref)
            if not model_inputs:
                logger.warning(f"Could not retrieve inference data for {inference_ref.inference_id}")
                return

            # Create a training batch with the actual candle
            batch = self._create_training_batch_from_inference(
                model_inputs, event.ohlcv, inference_ref
            )

            if not batch:
                return

            # Train the model (backprop for the Transformer)
            self._train_on_inference_batch(batch, inference_ref)

        except Exception as e:
            logger.error(f"Error in candle completion training: {e}", exc_info=True)

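    # For orientation, a minimal sketch of the reference object used by these
    # callbacks, inferred from the InferenceFrameReference(...) constructor call
    # in _register_inference_frame below. The real dataclass lives in
    # ANNOTATE.core.inference_training_system and may carry additional fields:
    #
    #     from dataclasses import dataclass, field
    #     from datetime import datetime
    #     from typing import Dict, Optional
    #
    #     @dataclass
    #     class InferenceFrameReference:
    #         inference_id: str
    #         symbol: str
    #         timeframe: str
    #         prediction_timestamp: datetime   # when the prediction was made
    #         target_timestamp: datetime       # when the target candle closes
    #         data_range_start: datetime       # start of the ~600-candle input window
    #         data_range_end: datetime         # end of the input window
    #         norm_params: Dict = field(default_factory=dict)  # per-timeframe min/max
    #         predicted_action: Optional[str] = None
    #         predicted_candle: Optional[Dict] = None
    #         confidence: float = 0.0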
    def on_pivot_event(self, event, inference_refs):
        """
        Called when a pivot point is detected - train on matching inference frames.

        This handles event-based training, where we don't know in advance when the pivot will occur.
        """
        if not inference_refs or not self.training_coordinator:
            return

        try:
            for inference_ref in inference_refs:
                # Retrieve inference data
                model_inputs = self.training_coordinator.get_inference_data(inference_ref)
                if not model_inputs:
                    continue

                # Create a training batch with the pivot result
                batch = self._create_pivot_training_batch(model_inputs, event, inference_ref)

                if not batch:
                    continue

                # Train the model
                self._train_on_inference_batch(batch, inference_ref)

        except Exception as e:
            logger.error(f"Error in pivot event training: {e}", exc_info=True)

    def _create_training_batch_from_inference(self, model_inputs: Dict, actual_ohlcv: Dict,
                                              inference_ref) -> Optional[Dict]:
        """Create a training batch from inference inputs and the actual candle result"""
        try:
            import torch

            # Copy model inputs
            batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
                     for k, v in model_inputs.items()}

            # Get device
            device = next(iter(batch.values())).device if batch else torch.device('cpu')

            # Normalize the actual candle using the stored params
            timeframe = inference_ref.timeframe
            if timeframe in inference_ref.norm_params:
                params = inference_ref.norm_params[timeframe]
                price_min = params['price_min']
                price_max = params['price_max']
                vol_min = params['volume_min']
                vol_max = params['volume_max']

                # Normalize actual OHLCV
                normalized_candle = [
                    (actual_ohlcv['open'] - price_min) / (price_max - price_min),
                    (actual_ohlcv['high'] - price_min) / (price_max - price_min),
                    (actual_ohlcv['low'] - price_min) / (price_max - price_min),
                    (actual_ohlcv['close'] - price_min) / (price_max - price_min),
                    (actual_ohlcv['volume'] - vol_min) / (vol_max - vol_min) if vol_max > vol_min else 0.0
                ]

                # Add the target candle to the batch
                target_key = f'future_candle_{timeframe}'
                batch[target_key] = torch.tensor([normalized_candle], dtype=torch.float32, device=device)

                # Add the action target (determined from the price movement)
                price_change = (actual_ohlcv['close'] - actual_ohlcv['open']) / actual_ohlcv['open']
                if price_change > 0.0005:  # more than 0.05% up
                    action = 1  # BUY
                elif price_change < -0.0005:  # more than 0.05% down
                    action = 2  # SELL
                else:
                    action = 0  # HOLD

                batch['actions'] = torch.tensor([[action]], dtype=torch.long, device=device)

                return batch

            # No stored normalization params for this timeframe - cannot build a target
            return None

        except Exception as e:
            logger.error(f"Error creating training batch from inference: {e}", exc_info=True)
            return None

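    # The batch above min-max normalizes the actual candle with the same params
    # captured at inference time; predictions are later denormalized with the
    # inverse transform. A self-contained sketch of that round trip (helper
    # names are illustrative, not part of this codebase):
    #
    #     def normalize(value: float, vmin: float, vmax: float) -> float:
    #         return (value - vmin) / (vmax - vmin) if vmax > vmin else 0.0
    #
    #     def denormalize(value: float, vmin: float, vmax: float) -> float:
    #         return value * (vmax - vmin) + vmin
    #
    #     params = {'price_min': 3000.0, 'price_max': 3100.0}
    #     n = normalize(3050.0, params['price_min'], params['price_max'])  # 0.5
    #     assert denormalize(n, params['price_min'], params['price_max']) == 3050.0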
    def _create_pivot_training_batch(self, model_inputs: Dict, pivot_event, inference_ref) -> Optional[Dict]:
        """Create a training batch from inference inputs and a pivot event"""
        try:
            import torch

            # Copy model inputs
            batch = {k: v.clone() if isinstance(v, torch.Tensor) else v
                     for k, v in model_inputs.items()}

            # Get device
            device = next(iter(batch.values())).device if batch else torch.device('cpu')

            # Determine the action from the pivot type:
            # L2L, L3L, etc. -> BUY (support levels)
            # L2H, L3H, etc. -> SELL (resistance levels)
            if pivot_event.pivot_type.endswith('L'):
                action = 1  # BUY
            elif pivot_event.pivot_type.endswith('H'):
                action = 2  # SELL
            else:
                action = 0  # HOLD

            batch['actions'] = torch.tensor([[action]], dtype=torch.long, device=device)

            # For pivot training we don't have a target candle, so the pivot price
            # serves as the reference point for training.
            # This is a simplified approach - it could be enhanced with pivot-based targets.

            return batch

        except Exception as e:
            logger.error(f"Error creating pivot training batch: {e}", exc_info=True)
            return None

    def _train_on_inference_batch(self, batch: Dict, inference_ref) -> None:
        """Train the model on an inference batch (uses the stored inference frame)"""
        try:
            if not self.orchestrator:
                return

            trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
            if not trainer:
                return

            # Train with lock protection
            import torch
            with self._training_lock:
                with torch.enable_grad():
                    trainer.model.train()
                    result = trainer.train_step(batch, accumulate_gradients=False)

            if result:
                loss = result.get('total_loss', 0)
                accuracy = result.get('accuracy', 0)

                # Update metrics
                self.realtime_training_metrics['total_steps'] += 1
                self.realtime_training_metrics['total_loss'] += loss
                self.realtime_training_metrics['total_accuracy'] += accuracy

                self.realtime_training_metrics['losses'].append(loss)
                self.realtime_training_metrics['accuracies'].append(accuracy)
                # Keep a rolling window of the last 100 steps
                if len(self.realtime_training_metrics['losses']) > 100:
                    self.realtime_training_metrics['losses'].pop(0)
                    self.realtime_training_metrics['accuracies'].pop(0)

                logger.info(f"Trained on inference frame {inference_ref.inference_id}: Loss={loss:.4f}, Acc={accuracy:.2%}")

        except Exception as e:
            logger.error(f"Error training on inference batch: {e}", exc_info=True)

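    # Design note: the manual pop(0) pruning of the 100-step window above could
    # equivalently use a bounded deque (a sketch, not what this commit does):
    #
    #     from collections import deque
    #
    #     losses = deque(maxlen=100)      # oldest entry is dropped automatically
    #     accuracies = deque(maxlen=100)
    #     losses.append(0.42)             # no explicit pruning needed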
    def _register_inference_frame(self, session: Dict, symbol: str, timeframe: str,
                                  prediction: Dict, data_provider, norm_params: Dict = None) -> None:
        """
        Register an inference frame reference with the coordinator.
        Stores a reference (timestamp range) instead of copying 600 candles.

        This method stores norm_params in the reference for efficient retrieval.
        When training is triggered, the data is retrieved from DuckDB using the reference.

        Args:
            session: Inference session
            symbol: Trading symbol
            timeframe: Timeframe
            prediction: Prediction dict from the model
            data_provider: Data provider instance
            norm_params: Normalization parameters (optional; defaults to an empty dict if not provided)
        """
        if not self.training_coordinator:
            return

        try:
            from ANNOTATE.core.inference_training_system import InferenceFrameReference
            from datetime import datetime, timezone, timedelta
            import uuid

            # Get the current time and calculate the data range
            current_time = datetime.now(timezone.utc)
            data_range_end = current_time

            # Calculate the start time for 600 candles (approximate)
            timeframe_seconds = {'1s': 1, '1m': 60, '5m': 300, '15m': 900, '1h': 3600, '1d': 86400}.get(timeframe, 60)
            data_range_start = current_time - timedelta(seconds=600 * timeframe_seconds)

            # Use the provided norm_params, or fall back to an empty dict
            if not norm_params:
                norm_params = {}

            # Calculate the target timestamp (next candle close time);
            # for the 1m timeframe, the next candle closes in 1 minute
            target_timestamp = current_time + timedelta(seconds=timeframe_seconds)

            # Create the inference frame reference
            inference_ref = InferenceFrameReference(
                inference_id=str(uuid.uuid4()),
                symbol=symbol,
                timeframe=timeframe,
                prediction_timestamp=current_time,
                target_timestamp=target_timestamp,
                data_range_start=data_range_start,
                data_range_end=data_range_end,
                norm_params=norm_params,  # Stored for efficient retrieval
                predicted_action=prediction.get('action'),
                predicted_candle=prediction.get('predicted_candle'),
                confidence=prediction.get('confidence', 0.0)
            )

            # Register with the coordinator
            self.training_coordinator.register_inference_frame(inference_ref)

            logger.debug(f"Registered inference frame: {inference_ref.inference_id} for {symbol} {timeframe} (target: {target_timestamp})")

        except Exception as e:
            logger.warning(f"Could not register inference frame: {e}", exc_info=True)

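    # Worked example of the lookback computed above: for the '5m' timeframe,
    # timeframe_seconds is 300, so the reference window spans
    # 600 * 300 s = 180,000 s = 50 hours before data_range_end:
    #
    #     from datetime import datetime, timedelta, timezone
    #
    #     now = datetime.now(timezone.utc)
    #     start = now - timedelta(seconds=600 * 300)
    #     assert now - start == timedelta(hours=50)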
    def _subscribe_to_training_events(self):
        """Subscribe to training events via the orchestrator's coordinator"""
        if not self.training_coordinator:
            return

        try:
            # Subscribe to candle completion for the primary symbol/timeframe
            primary_symbol = getattr(self.orchestrator, 'symbol', 'ETH/USDT')
            primary_timeframe = '1m'  # Default timeframe

            self.training_coordinator.subscribe_to_candle_completion(
                self, symbol=primary_symbol, timeframe=primary_timeframe
            )

            # Subscribe to pivot events (L2L, L2H, L3L, L3H)
            self.training_coordinator.subscribe_to_pivot_events(
                self, symbol=primary_symbol, timeframe=primary_timeframe,
                pivot_types=['L2L', 'L2H', 'L3L', 'L3H']
            )

            logger.info(f"Subscribed to training events: {primary_symbol} {primary_timeframe}")
        except Exception as e:
            logger.warning(f"Could not subscribe to training events: {e}")

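    # The "TrainingEventSubscriber interface" implemented earlier in this diff
    # reduces to the two callbacks registered here. A minimal sketch of the
    # assumed shape (the real interface is defined alongside the coordinator
    # and may differ):
    #
    #     from typing import Protocol
    #
    #     class TrainingEventSubscriber(Protocol):
    #         def on_candle_completion(self, event, inference_ref) -> None: ...
    #         def on_pivot_event(self, event, inference_refs) -> None: ...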
    def _import_training_systems(self):
        """Import real training system implementations"""
        try:
@@ -3056,7 +3339,10 @@ class RealTrainingAdapter:
            'total_pnl': 0.0,
            'win_count': 0,
            'loss_count': 0,
-           'total_trades': 0
            'total_trades': 0,
            # Inference input cache: stores input data frames for later training.
            # Key: candle_timestamp (str); Value: {'model_inputs': Dict, 'norm_params': Dict, 'predicted_candle': Dict}
            'inference_input_cache': {}
        }

        training_mode = "per-candle" if train_every_candle else ("pivot-based" if enable_live_training else "inference-only")
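        # Illustrative shape of one inference_input_cache entry, matching the
        # cached_inputs dict built in _make_realtime_prediction_with_cache below
        # (values are placeholders):
        #
        #     session['inference_input_cache']['2024-01-01 00:01:00'] = {
        #         'model_inputs': {...},   # CPU tensors captured at inference time
        #         'norm_params': {'1m': {'price_min': 3000.0, 'price_max': 3100.0,
        #                                'volume_min': 0.0, 'volume_max': 1_000_000.0}},
        #         'timestamp': '2024-01-01 00:01:00',
        #         'symbol': 'ETH/USDT',
        #         'timeframe': '1m',
        #     }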
@@ -3128,8 +3414,177 @@ class RealTrainingAdapter:
        all_signals.sort(key=lambda x: x.get('timestamp', ''), reverse=True)
        return all_signals[:limit]

-    def _make_realtime_prediction(self, model_name: str, symbol: str, data_provider) -> Dict:
-        """Make a prediction using the specified model"""
    def _make_realtime_prediction_with_cache(self, model_name: str, symbol: str, data_provider, session: Dict) -> Tuple[Dict, bool]:
        """
        DEPRECATED: Use _make_realtime_prediction + _register_inference_frame instead.
        This method is kept for backward compatibility but should be removed.
        """
        # Just delegate to the regular prediction method, which now returns a
        # (prediction, market_context) tuple; discard the context here
        prediction, _ = self._make_realtime_prediction(model_name, symbol, data_provider)
        return prediction, False

        # NOTE: everything below the return above is unreachable - it is the old
        # cache-based implementation, left in place by this commit
        """
        Make a prediction and store the input data frame for later training

        Returns:
            Tuple of (prediction_dict, stored_inputs: bool)
        """
        try:
            if model_name == 'Transformer' and self.orchestrator:
                trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
                if trainer and trainer.model:
                    # Get recent market data
                    market_data, norm_params = self._get_realtime_market_data(symbol, data_provider)
                    if not market_data:
                        return None, False

                    # Get the current candle timestamp for the cache key
                    timeframe = session.get('timeframe', '1m')
                    df_current = data_provider.get_historical_data(symbol, timeframe, limit=1)
                    if df_current is not None and len(df_current) > 0:
                        current_timestamp = str(df_current.index[-1])

                        # Store the input data frame for later training (tensors moved to CPU for storage)
                        import torch
                        cached_inputs = {
                            'model_inputs': {k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
                                             for k, v in market_data.items()},
                            'norm_params': norm_params,
                            'timestamp': current_timestamp,
                            'symbol': symbol,
                            'timeframe': timeframe
                        }

                        # Store in the session cache (keep the last 50 entries to prevent memory bloat)
                        cache = session.get('inference_input_cache', {})
                        cache[current_timestamp] = cached_inputs
                        if len(cache) > 50:
                            # Remove the oldest entries
                            sorted_keys = sorted(cache.keys())
                            for key in sorted_keys[:-50]:
                                del cache[key]
                        session['inference_input_cache'] = cache
                        logger.debug(f"Stored inference inputs for {symbol} {timeframe} @ {current_timestamp}")

                    # Make the prediction
                    import torch
                    with torch.no_grad():
                        trainer.model.eval()
                        outputs = trainer.model(**market_data)

                    # Extract the action
                    action_probs = outputs.get('action_probs')
                    if action_probs is not None:
                        # Handle different tensor shapes: [batch, 3] or [3]
                        if action_probs.dim() == 1:
                            # Shape [3] - single prediction
                            action_idx = torch.argmax(action_probs, dim=0).item()
                            confidence = action_probs[action_idx].item()
                        else:
                            # Shape [batch, 3] - take the first batch item
                            action_idx = torch.argmax(action_probs[0], dim=0).item()
                            confidence = action_probs[0, action_idx].item()

                        # Map to an action string (must match training: 0=HOLD, 1=BUY, 2=SELL)
                        actions = ['HOLD', 'BUY', 'SELL']
                        action = actions[action_idx] if action_idx < len(actions) else 'HOLD'

                        # Handle predicted candles - DENORMALIZE them
                        predicted_candles_raw = {}
                        if 'next_candles' in outputs:
                            for tf, tensor in outputs['next_candles'].items():
                                predicted_candles_raw[tf] = tensor.detach().cpu().numpy().tolist()

                        # Denormalize if we have params
                        predicted_candles_denorm = {}
                        if predicted_candles_raw and norm_params:
                            for tf, raw_candle in predicted_candles_raw.items():
                                # raw_candle is a [1, 5] list
                                if tf in norm_params:
                                    params = norm_params[tf]
                                    price_min = params['price_min']
                                    price_max = params['price_max']
                                    vol_min = params['volume_min']
                                    vol_max = params['volume_max']

                                    # Denormalize [Open, High, Low, Close, Volume];
                                    # note: raw_candle[0] is the list of 5 values
                                    candle_values = raw_candle[0]

                                    # Ensure all values are Python floats (not numpy scalars or tensors)
                                    def to_float(v):
                                        if hasattr(v, 'item'):
                                            return float(v.item())
                                        return float(v)

                                    denorm_candle = [
                                        to_float(candle_values[0] * (price_max - price_min) + price_min),  # Open
                                        to_float(candle_values[1] * (price_max - price_min) + price_min),  # High
                                        to_float(candle_values[2] * (price_max - price_min) + price_min),  # Low
                                        to_float(candle_values[3] * (price_max - price_min) + price_min),  # Close
                                        to_float(candle_values[4] * (vol_max - vol_min) + vol_min)         # Volume
                                    ]
                                    predicted_candles_denorm[tf] = denorm_candle

                        # Calculate the predicted price from the candle close (ensure a Python float)
                        predicted_price = None
                        if '1m' in predicted_candles_denorm:
                            close_val = predicted_candles_denorm['1m'][3]
                            predicted_price = float(close_val.item() if hasattr(close_val, 'item') else close_val)
                        elif '1s' in predicted_candles_denorm:
                            close_val = predicted_candles_denorm['1s'][3]
                            predicted_price = float(close_val.item() if hasattr(close_val, 'item') else close_val)
                        elif outputs.get('price_prediction') is not None:
                            # Fallback to the price_prediction head if available (normalized);
                            # this would need separate denormalization based on a reference price
                            pass

                        result_dict = {
                            'action': action,
                            'confidence': confidence,
                            'predicted_price': predicted_price,
                            'predicted_candle': predicted_candles_denorm
                        }

                        # Include the trend vector if available
                        if 'trend_vector' in outputs:
                            result_dict['trend_vector'] = outputs['trend_vector']

                        # DEBUG: Log whether we have predicted candles
                        if predicted_candles_denorm:
                            logger.info(f"Generated prediction with {len(predicted_candles_denorm)} timeframe candles: {list(predicted_candles_denorm.keys())}")
                        else:
                            logger.warning("No predicted candles in model output!")

                        return result_dict, True

            return None, False
        except Exception as e:
            logger.debug(f"Error making realtime prediction: {e}")
            import traceback
            logger.debug(traceback.format_exc())
            return None, False

    def _make_realtime_prediction(self, model_name: str, symbol: str, data_provider) -> Tuple[Dict, Dict]:
        """
        Make a prediction and return both the prediction and the market data for reference storage.

        Returns:
            Tuple of (prediction_dict, market_data_dict with norm_params)
        """
        # Get market data (needed for reference storage)
        market_data, norm_params = self._get_realtime_market_data(symbol, data_provider)
        if not market_data:
            return None, None

        # Make the prediction (original logic)
        prediction = self._make_realtime_prediction_internal(model_name, symbol, data_provider, market_data, norm_params)

        return prediction, {'market_data': market_data, 'norm_params': norm_params}

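    # Usage sketch for the tuple-returning API above (the call site in the last
    # hunk of this diff re-fetches norm_params; the market_context returned here
    # already carries them):
    #
    #     prediction, market_context = adapter._make_realtime_prediction(
    #         'Transformer', 'ETH/USDT', data_provider)
    #     if prediction and adapter.training_coordinator:
    #         adapter._register_inference_frame(
    #             session, 'ETH/USDT', '1m', prediction, data_provider,
    #             market_context['norm_params'])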
    def _make_realtime_prediction_internal(self, model_name: str, symbol: str, data_provider,
                                           market_data: Dict, norm_params: Dict) -> Dict:
        """Make a prediction using the specified model (backward compatibility)"""
        try:
            if model_name == 'Transformer' and self.orchestrator:
                trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
@@ -3223,12 +3678,6 @@ class RealTrainingAdapter:
                        if 'trend_vector' in outputs:
                            result_dict['trend_vector'] = outputs['trend_vector']

-                       # DEBUG: Log if we have predicted candles
-                       if predicted_candles_denorm:
-                           logger.info(f"🔮 Generated prediction with {len(predicted_candles_denorm)} timeframe candles: {list(predicted_candles_denorm.keys())}")
-                       else:
-                           logger.warning("⚠️ No predicted candles in model output!")
-
                        return result_dict

            return None
@@ -3370,13 +3819,120 @@ class RealTrainingAdapter:
        # Get the completed candle (second to last) and the next candle
        completed_candle = df.iloc[-2]
        next_candle = df.iloc[-1]
        completed_timestamp = str(completed_candle.name)

        # Get the action from the session (set by the app's training strategy)
        action_label = session.get('pending_action')
        if not action_label:
            return {'success': False, 'error': 'No pending_action in session'}

        # Fetch the market state for training.
        # CRITICAL: Try to use the stored inference input data frame if available -
        # this ensures we train on exactly what the model saw during inference.
        cache = session.get('inference_input_cache', {})
        stored_inputs = cache.get(completed_timestamp)

        if stored_inputs:
            # Use the stored input data frame from inference
            logger.info(f"Using stored inference inputs for training on {symbol} {timeframe} @ {completed_timestamp}")

            # Get the actual candle data for the target
            actual_candle = [
                float(next_candle['open']),
                float(next_candle['high']),
                float(next_candle['low']),
                float(next_candle['close']),
                float(next_candle['volume'])
            ]

            # Create a training batch from the stored inputs
            import torch
            # Get the device from the orchestrator
            device = getattr(self.orchestrator, 'device', torch.device('cpu'))
            if hasattr(self.orchestrator, 'primary_transformer_trainer') and self.orchestrator.primary_transformer_trainer:
                if hasattr(self.orchestrator.primary_transformer_trainer.model, 'device'):
                    device = next(self.orchestrator.primary_transformer_trainer.model.parameters()).device

            # Move the stored inputs back to the device (they were stored on CPU)
            batch = {}
            for k, v in stored_inputs['model_inputs'].items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(device)
                else:
                    batch[k] = v

            # Add the actual candle as the target (normalized using the stored params)
            norm_params = stored_inputs['norm_params']
            if timeframe in norm_params:
                params = norm_params[timeframe]
                price_min = params['price_min']
                price_max = params['price_max']
                vol_min = params['volume_min']
                vol_max = params['volume_max']

                # Normalize the actual candle
                normalized_candle = [
                    (actual_candle[0] - price_min) / (price_max - price_min),  # Open
                    (actual_candle[1] - price_min) / (price_max - price_min),  # High
                    (actual_candle[2] - price_min) / (price_max - price_min),  # Low
                    (actual_candle[3] - price_min) / (price_max - price_min),  # Close
                    (actual_candle[4] - vol_min) / (vol_max - vol_min) if vol_max > vol_min else 0.0  # Volume
                ]

                # Add the target candle to the batch
                target_key = f'future_candle_{timeframe}'
                batch[target_key] = torch.tensor([normalized_candle], dtype=torch.float32, device=device)

            # Add the action target
            action_map = {'HOLD': 0, 'BUY': 1, 'SELL': 2}
            batch['actions'] = torch.tensor([[action_map.get(action_label, 0)]], dtype=torch.long, device=device)

            # Train directly on the batch
            model_name = session['model_name']
            if model_name == 'Transformer':
                trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
                if trainer:
                    with self._training_lock:
                        with torch.enable_grad():
                            trainer.model.train()
                            result = trainer.train_step(batch, accumulate_gradients=False)

                    if result:
                        loss = result.get('total_loss', 0)
                        accuracy = result.get('accuracy', 0)

                        # Update metrics
                        self.realtime_training_metrics['total_steps'] += 1
                        self.realtime_training_metrics['total_loss'] += loss
                        self.realtime_training_metrics['total_accuracy'] += accuracy

                        self.realtime_training_metrics['losses'].append(loss)
                        self.realtime_training_metrics['accuracies'].append(accuracy)
                        if len(self.realtime_training_metrics['losses']) > 100:
                            self.realtime_training_metrics['losses'].pop(0)
                            self.realtime_training_metrics['accuracies'].pop(0)

                        session['metrics']['loss'] = sum(self.realtime_training_metrics['losses']) / len(self.realtime_training_metrics['losses'])
                        session['metrics']['accuracy'] = sum(self.realtime_training_metrics['accuracies']) / len(self.realtime_training_metrics['accuracies'])
                        session['metrics']['steps'] = self.realtime_training_metrics['total_steps']

                        # Remove the entry from the cache after training
                        if completed_timestamp in cache:
                            del cache[completed_timestamp]

                        logger.info(f"Trained on stored inference inputs: {symbol} {timeframe} @ {completed_timestamp} action={action_label} (Loss: {loss:.4f}, Acc: {accuracy:.2%})")

                        return {
                            'success': True,
                            'loss': session['metrics']['loss'],
                            'accuracy': session['metrics']['accuracy'],
                            'training_steps': session['metrics']['steps'],
                            'used_stored_inputs': True
                        }

            # Fall through to regular training if the stored inputs could not be used
            logger.warning("Failed to use stored inputs, falling back to fresh data")

        # Fallback: fetch a fresh market state for training (original behavior)
        market_state = self._fetch_market_state_for_candle(symbol, completed_candle.name, data_provider)

        # Calculate price change
@@ -3411,7 +3967,8 @@ class RealTrainingAdapter:
                'success': True,
                'loss': session['metrics']['loss'],
                'accuracy': session['metrics']['accuracy'],
-               'training_steps': session['metrics']['steps']
                'training_steps': session['metrics']['steps'],
                'used_stored_inputs': False
            }

        return {'success': False, 'error': f'Unsupported model: {model_name}'}
@@ -3939,6 +4496,14 @@ class RealTrainingAdapter:
            # Make a prediction using the model; _make_realtime_prediction
            # now returns a (prediction, market_context) tuple
            prediction, _market_context = self._make_realtime_prediction(model_name, symbol, data_provider)

            # Register an inference frame reference for later training, when the actual candle arrives.
            # This stores a reference (timestamp range) instead of copying 600 candles;
            # the reference lets us retrieve the exact data from DuckDB when training.
            if prediction and self.training_coordinator:
                # Get norm_params for storage in the reference
                _, norm_params = self._get_realtime_market_data(symbol, data_provider)
                self._register_inference_frame(session, symbol, timeframe, prediction, data_provider, norm_params)

            if prediction:
                # Store signal
                signal = {