training WIP

Dobromir Popov
2025-12-10 01:40:03 +02:00
parent 9c59b3e0c6
commit 199235962b
3 changed files with 355 additions and 33 deletions


@@ -2826,6 +2826,57 @@ class AnnotationDashboard:
                    'error': str(e)
                }), 500
        @self.server.route('/api/training-metrics', methods=['GET'])
        def get_training_metrics():
            """Get current training metrics for display (loss, accuracy, etc.)."""
            try:
                metrics = {
                    'loss': 0.0,
                    'accuracy': 0.0,
                    'steps': 0,
                    'recent_history': []
                }

                # Get metrics from the training adapter if available
                if self.training_adapter and hasattr(self.training_adapter, 'realtime_training_metrics'):
                    rt_metrics = self.training_adapter.realtime_training_metrics
                    metrics['loss'] = rt_metrics.get('last_loss', 0.0)
                    metrics['accuracy'] = rt_metrics.get('last_accuracy', 0.0)
                    metrics['steps'] = rt_metrics.get('total_steps', 0)

                # Get incremental training metrics
                if hasattr(self, '_incremental_training_steps'):
                    metrics['incremental_steps'] = self._incremental_training_steps
                if hasattr(self, '_training_metrics_history') and self._training_metrics_history:
                    # Keep the last 10 metrics for display
                    metrics['recent_history'] = self._training_metrics_history[-10:]
                    # Update current metrics from the most recent entry
                    latest = self._training_metrics_history[-1]
                    metrics['loss'] = latest.get('loss', metrics['loss'])
                    metrics['accuracy'] = latest.get('accuracy', metrics['accuracy'])

                # Get metrics from the orchestrator trainer if available
                if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer_trainer'):
                    trainer = self.orchestrator.primary_transformer_trainer
                    if trainer and hasattr(trainer, 'training_history'):
                        history = trainer.training_history
                        if history.get('train_loss'):
                            metrics['loss'] = history['train_loss'][-1]
                        if history.get('train_accuracy'):
                            metrics['accuracy'] = history['train_accuracy'][-1]

                return jsonify({
                    'success': True,
                    'metrics': metrics
                })
            except Exception as e:
                logger.error(f"Error getting training metrics: {e}")
                return jsonify({
                    'success': False,
                    'error': str(e)
                }), 500
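        # A quick way to exercise this endpoint (hypothetical host/port and
        # values; the keys match the handler above):
        #   curl http://localhost:<port>/api/training-metrics
        #   -> {"success": true,
        #       "metrics": {"loss": 0.0123, "accuracy": 0.61, "steps": 420,
        #                   "incremental_steps": 12, "recent_history": [...]}}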
        @self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
        def train_manual():
            """Manually trigger training on the current candle with a specified action"""
@@ -3074,11 +3125,17 @@ class AnnotationDashboard:
            # We need to fetch the market state at that timestamp
            symbol = 'ETH/USDT'  # TODO: get from the active trading pair

            # Ensure actual_candle has volume (the frontend sends [O, H, L, C, V])
            actual_candle = list(actual) if isinstance(actual, (list, tuple)) else actual
            if len(actual_candle) == 4:
                # Only four values: fall back to the predicted volume
                actual_candle.append(predicted[4] if len(predicted) > 4 else 0.0)

            training_sample = {
                'symbol': symbol,
                'timestamp': timestamp,
                'predicted_candle': predicted,    # [O, H, L, C, V]
                'actual_candle': actual_candle,   # [O, H, L, C, V] - always five values
                'errors': errors,
                'accuracy': accuracy,
                'direction_correct': direction_correct,
@@ -3088,22 +3145,38 @@ class AnnotationDashboard:
            # Get the market state at that timestamp
            try:
                market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe)
                if not market_state or 'timeframes' not in market_state:
                    logger.warning(f"Could not fetch market state for {symbol} at {timestamp}")
                    return None
                training_sample['market_state'] = market_state
            except Exception as e:
                logger.warning(f"Could not fetch market state: {e}")
                return None

            # Convert to the transformer batch format
            batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe)
            if not batch:
                logger.warning("Could not convert validated prediction to training batch")
                return None

            # Train on this batch with sample weighting
            # CRITICAL: use the training lock to prevent concurrent access
            import torch

            # Try to acquire the training lock with a timeout
            if hasattr(self.training_adapter, '_training_lock'):
                lock_acquired = self.training_adapter._training_lock.acquire(timeout=5.0)
                if not lock_acquired:
                    logger.warning("Could not acquire training lock within 5 seconds - skipping incremental training")
                    return None
            else:
                lock_acquired = False

            try:
                # enable_grad() guards against an enclosing no_grad() context
                # left over from inference
                with torch.enable_grad():
                    trainer.model.train()
                    result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)

                if result:
                    loss = result.get('total_loss', 0)
@@ -3151,46 +3224,126 @@ class AnnotationDashboard:
                        'steps': self._incremental_training_steps,
                        'sample_weight': sample_weight
                    }
                else:
                    logger.warning("Training step returned no result")
                    return None
            finally:
                # CRITICAL: always release the lock
                if lock_acquired and hasattr(self.training_adapter, '_training_lock'):
                    self.training_adapter._training_lock.release()
                    lock_acquired = False  # prevent a double release in the outer handler
        except Exception as e:
            logger.error(f"Error in incremental training: {e}", exc_info=True)
            # Ensure the lock is released even on error
            if 'lock_acquired' in locals() and lock_acquired and hasattr(self.training_adapter, '_training_lock'):
                try:
                    self.training_adapter._training_lock.release()
                except Exception:
                    pass
            return None
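    # Sketch only, not part of this commit: the acquire/release bookkeeping
    # above could live in a context manager so the lock is released on every
    # exit path, including early returns. Assumes the adapter keeps exposing
    # the same `_training_lock` threading.Lock.
    #
    #     from contextlib import contextmanager
    #
    #     @contextmanager
    #     def _training_lock_guard(adapter, timeout=5.0):
    #         lock = getattr(adapter, '_training_lock', None)
    #         acquired = lock.acquire(timeout=timeout) if lock else True
    #         try:
    #             yield acquired  # False means the caller should skip training
    #         finally:
    #             if lock and acquired:
    #                 lock.release()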
    def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
        """Fetch the market state at a specific timestamp for training"""
        try:
            from datetime import datetime, timezone
            import pandas as pd

            # Parse the timestamp and ensure it is timezone-aware
            # (pd.Timestamp accepts both strings and datetimes)
            ts = pd.Timestamp(timestamp)
            if ts.tz is None:
                ts = ts.tz_localize('UTC')
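            # Note: the localization matters because pandas raises a TypeError
            # when a tz-aware Timestamp is compared against a tz-naive
            # DatetimeIndex, and the df.index < ts filter below relies on
            # exactly that comparison.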
            # Get historical data for multiple timeframes.
            # Prefer the data provider's get_market_state_at_time, which
            # returns the proper format with all required timeframes.

            # Convert to datetime up front; the fallback path below needs it too
            dt = ts.to_pydatetime()

            if self.data_provider and hasattr(self.data_provider, 'get_market_state_at_time'):
                try:
                    # Get the market state with a context window (enough candles for training)
                    market_state = self.data_provider.get_market_state_at_time(
                        symbol=symbol,
                        timestamp=dt,
                        context_window_minutes=600  # 600 minutes of context for 1m candles
                    )
                    if market_state and 'timeframes' in market_state:
                        logger.debug(f"Fetched market state with {len(market_state['timeframes'])} timeframes")
                        return market_state
                    logger.warning("Market state returned empty or invalid format")
                except Exception as e:
                    logger.warning(f"Could not use data provider method: {e}")

            # Fallback: manually fetch data for each timeframe.
            market_state = {'timeframes': {}, 'secondary_timeframes': {}}

            # REQUIRED timeframes for the transformer: 1m, 1h, 1d (1s is optional).
            # At least 50 candles are needed, preferably 600.
            required_tfs = ['1m', '1h', '1d']
            optional_tfs = ['1s']

            for tf in required_tfs + optional_tfs:
                try:
                    # Fetch enough candles (600 for training, but accept fewer)
                    df = self.data_loader.get_data(
                        symbol=symbol,
                        timeframe=tf,
                        end_time=dt,
                        limit=600,
                        direction='before'
                    ) if self.data_loader else None

                    # Fall back to the data provider if the data loader is unavailable
                    if (df is None or df.empty) and self.data_provider:
                        df = self.data_provider.get_historical_data(symbol, tf, limit=600, refresh=False)

                    if df is not None and not df.empty:
                        # Filter to data strictly before the target timestamp
                        # (excluding the target candle itself avoids lookahead)
                        df_before = df[df.index < ts]
                        if df_before.empty:
                            # No data before the timestamp: use all available data
                            df_before = df

                        # Take the last 600 candles (or all, if fewer)
                        recent = df_before.tail(600)
                        if len(recent) >= 50:  # Minimum required
                            market_state['timeframes'][tf] = {
                                'timestamps': self._format_timestamps_utc(recent.index),
                                'open': recent['open'].tolist(),
                                'high': recent['high'].tolist(),
                                'low': recent['low'].tolist(),
                                'close': recent['close'].tolist(),
                                'volume': recent['volume'].tolist()
                            }
                            logger.debug(f"Fetched {len(recent)} candles for {tf} timeframe")
                        elif tf in required_tfs:
                            logger.warning(f"Required timeframe {tf} has only {len(recent)} candles (need at least 50)")
                        else:
                            logger.debug(f"Optional timeframe {tf} has only {len(recent)} candles, skipping")
                except Exception as e:
                    logger.warning(f"Could not fetch {tf} data: {e}")

            # Validate that all required timeframes are present
            missing_required = [tf for tf in required_tfs if tf not in market_state['timeframes']]
            if missing_required:
                logger.warning(f"Missing required timeframes: {missing_required}")
                return {}

            return market_state

        except Exception as e:
            logger.error(f"Error fetching market state: {e}", exc_info=True)
            return {}
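    # For reference, a successful return value from the fallback path above
    # has this shape (lists abridged):
    #
    #     {
    #         'timeframes': {
    #             '1m': {'timestamps': [...], 'open': [...], 'high': [...],
    #                    'low': [...], 'close': [...], 'volume': [...]},
    #             '1h': {...},
    #             '1d': {...}
    #         },
    #         'secondary_timeframes': {}
    #     }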
    def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):