training WIP
This commit is contained in:
@@ -2826,6 +2826,57 @@ class AnnotationDashboard:
|
||||
'error': str(e)
|
||||
}), 500
|
||||
|
||||
@self.server.route('/api/training-metrics', methods=['GET'])
def get_training_metrics():
    """Return the current training metrics (loss, accuracy, steps) as JSON.

    Metrics are collected from up to three sources, each later source
    overriding 'loss' / 'accuracy' from the earlier ones when it has data:

    1. ``self.training_adapter.realtime_training_metrics``
       (keys 'last_loss', 'last_accuracy', 'total_steps').
    2. ``self._incremental_training_steps`` and
       ``self._training_metrics_history`` (last 10 entries are returned
       as 'recent_history'; the newest entry supplies loss/accuracy).
    3. ``self.orchestrator.primary_transformer_trainer.training_history``
       (last element of 'train_loss' / 'train_accuracy').

    Returns:
        A JSON response ``{'success': True, 'metrics': {...}}``, or
        ``({'success': False, 'error': <message>}, 500)`` on any exception.
    """
    try:
        metrics = {
            'loss': 0.0,
            'accuracy': 0.0,
            'steps': 0,
            'recent_history': [],
        }

        # Source 1: realtime metrics kept by the training adapter.
        # getattr + truthiness check also covers the attribute being None,
        # which would otherwise raise and fail the whole request.
        adapter = self.training_adapter
        rt_metrics = getattr(adapter, 'realtime_training_metrics', None) if adapter else None
        if rt_metrics:
            metrics['loss'] = rt_metrics.get('last_loss', 0.0)
            metrics['accuracy'] = rt_metrics.get('last_accuracy', 0.0)
            metrics['steps'] = rt_metrics.get('total_steps', 0)

        # Source 2: incremental training counters/history stored on the dashboard.
        if hasattr(self, '_incremental_training_steps'):
            metrics['incremental_steps'] = self._incremental_training_steps

        metrics_history = getattr(self, '_training_metrics_history', None)
        if metrics_history:
            # Only the last 10 entries are sent to the client for display.
            metrics['recent_history'] = metrics_history[-10:]
            latest = metrics_history[-1]
            metrics['loss'] = latest.get('loss', metrics['loss'])
            metrics['accuracy'] = latest.get('accuracy', metrics['accuracy'])

        # Source 3: the orchestrator's transformer trainer history (highest priority).
        orchestrator = self.orchestrator
        trainer = getattr(orchestrator, 'primary_transformer_trainer', None) if orchestrator else None
        trainer_history = getattr(trainer, 'training_history', None) if trainer else None
        if trainer_history:
            train_loss = trainer_history.get('train_loss')
            if train_loss:
                metrics['loss'] = train_loss[-1]
            train_accuracy = trainer_history.get('train_accuracy')
            if train_accuracy:
                metrics['accuracy'] = train_accuracy[-1]

        return jsonify({
            'success': True,
            'metrics': metrics
        })

    except Exception as e:
        # Route boundary: log with traceback and report the failure to the client.
        logger.error(f"Error getting training metrics: {e}", exc_info=True)
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
|
||||
|
||||
@self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
|
||||
def train_manual():
|
||||
"""Manually trigger training on current candle with specified action"""
|
||||
@@ -3074,11 +3125,17 @@ class AnnotationDashboard:
|
||||
# We need to fetch the market state at that timestamp
|
||||
symbol = 'ETH/USDT' # TODO: Get from active trading pair
|
||||
|
||||
# Ensure actual_candle has volume (frontend sends [O, H, L, C, V])
|
||||
actual_candle = list(actual) if isinstance(actual, (list, tuple)) else actual
|
||||
if len(actual_candle) == 4:
|
||||
# If only 4 values, add volume from predicted (fallback)
|
||||
actual_candle.append(predicted[4] if len(predicted) > 4 else 0.0)
|
||||
|
||||
training_sample = {
|
||||
'symbol': symbol,
|
||||
'timestamp': timestamp,
|
||||
'predicted_candle': predicted, # [O, H, L, C, V]
|
||||
'actual_candle': actual, # [O, H, L, C]
|
||||
'actual_candle': actual_candle, # [O, H, L, C, V] - ensure 5 values
|
||||
'errors': errors,
|
||||
'accuracy': accuracy,
|
||||
'direction_correct': direction_correct,
|
||||
@@ -3088,22 +3145,38 @@ class AnnotationDashboard:
|
||||
# Get market state at that timestamp
|
||||
try:
|
||||
market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe)
|
||||
if not market_state or 'timeframes' not in market_state:
|
||||
logger.warning(f"Could not fetch market state for {symbol} at {timestamp}")
|
||||
return None
|
||||
training_sample['market_state'] = market_state
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch market state: {e}")
|
||||
return
|
||||
return None
|
||||
|
||||
# Convert to transformer batch format
|
||||
batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe)
|
||||
if not batch:
|
||||
logger.warning("Could not convert validated prediction to training batch")
|
||||
return
|
||||
return None
|
||||
|
||||
# Train on this batch with sample weighting
|
||||
# CRITICAL: Use training lock to prevent concurrent access
|
||||
import torch
|
||||
with torch.enable_grad():
|
||||
trainer.model.train()
|
||||
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
|
||||
import threading
|
||||
|
||||
# Try to acquire training lock with timeout
|
||||
if hasattr(self.training_adapter, '_training_lock'):
|
||||
lock_acquired = self.training_adapter._training_lock.acquire(timeout=5.0)
|
||||
if not lock_acquired:
|
||||
logger.warning("Could not acquire training lock within 5 seconds - skipping incremental training")
|
||||
return None
|
||||
else:
|
||||
lock_acquired = False
|
||||
|
||||
try:
|
||||
with torch.enable_grad():
|
||||
trainer.model.train()
|
||||
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
|
||||
|
||||
if result:
|
||||
loss = result.get('total_loss', 0)
|
||||
@@ -3151,46 +3224,126 @@ class AnnotationDashboard:
|
||||
'steps': self._incremental_training_steps,
|
||||
'sample_weight': sample_weight
|
||||
}
|
||||
else:
|
||||
logger.warning("Training step returned no result")
|
||||
return None
|
||||
finally:
|
||||
# CRITICAL: Always release the lock
|
||||
if lock_acquired and hasattr(self.training_adapter, '_training_lock'):
|
||||
self.training_adapter._training_lock.release()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in incremental training: {e}", exc_info=True)
|
||||
# Ensure lock is released even on error
|
||||
if 'lock_acquired' in locals() and lock_acquired and hasattr(self.training_adapter, '_training_lock'):
|
||||
try:
|
||||
self.training_adapter._training_lock.release()
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
    """Fetch the market state at a specific timestamp for training.

    First tries ``self.data_provider.get_market_state_at_time``. If that
    method is missing, raises, or returns a state without a 'timeframes'
    key, falls back to building per-timeframe OHLCV series from
    ``self.data_loader.get_data`` and, when that yields nothing,
    ``self.data_provider.get_historical_data``.

    Args:
        symbol: Trading pair, passed through to the data sources.
        timestamp: Target time; parsed with ``pd.Timestamp`` and localized
            to UTC when it carries no timezone.
        timeframe: Not used by this method; kept for the existing signature.

    Returns:
        The provider's market state, or a dict with 'timeframes' and
        'secondary_timeframes' built by the fallback. Returns ``{}`` when a
        required timeframe ('1m', '1h', '1d') is missing or on any error.
    """
    try:
        import pandas as pd

        # Parse timestamp - ensure it's timezone-aware
        ts = pd.Timestamp(timestamp)
        if ts.tz is None:
            ts = ts.tz_localize('UTC')

        # Computed up front because BOTH the provider call and the fallback
        # loop below need it (previously it was only bound inside the
        # provider branch, leaving the fallback with an undefined name).
        dt = ts.to_pydatetime() if isinstance(ts, pd.Timestamp) else ts

        # Use data provider's method to get market state at that time
        # This ensures we get the proper format with all required timeframes
        if self.data_provider and hasattr(self.data_provider, 'get_market_state_at_time'):
            try:
                # Get market state with context window (need enough candles for training)
                market_state = self.data_provider.get_market_state_at_time(
                    symbol=symbol,
                    timestamp=dt,
                    context_window_minutes=600  # Get 600 minutes of context for 1m candles
                )

                if market_state and 'timeframes' in market_state:
                    logger.debug(f"Fetched market state with {len(market_state.get('timeframes', {}))} timeframes")
                    return market_state
                else:
                    logger.warning("Market state returned empty or invalid format")
            except Exception as e:
                # Best effort: fall through to the manual per-timeframe fetch.
                logger.warning(f"Could not use data provider method: {e}")

        # Fallback: manually fetch data for each timeframe
        market_state = {'timeframes': {}, 'secondary_timeframes': {}}

        # REQUIRED timeframes for transformer: 1m, 1h, 1d (1s is optional)
        # Need at least 50 candles, preferably 600
        required_tfs = ['1m', '1h', '1d']
        optional_tfs = ['1s']

        for tf in required_tfs + optional_tfs:
            try:
                # Fetch enough candles (600 for training, but accept less)
                df = self.data_loader.get_data(
                    symbol=symbol,
                    timeframe=tf,
                    end_time=dt,
                    limit=600,
                    direction='before'
                ) if self.data_loader else None

                # Fallback to data provider if data_loader not available
                if df is None or df.empty:
                    if self.data_provider:
                        df = self.data_provider.get_historical_data(symbol, tf, limit=600, refresh=False)

                if df is not None and not df.empty:
                    # Filter to data before the target timestamp
                    # NOTE(review): assumes df.index is a tz-aware DatetimeIndex; a naive
                    # index would make this comparison raise and skip the timeframe - confirm.
                    df_before = df[df.index < ts]

                    if df_before.empty:
                        # If no data before timestamp, use all available data
                        # NOTE(review): these candles are at/after the target time - confirm
                        # this is acceptable for training samples.
                        df_before = df

                    # Take last 600 candles (or all if less)
                    recent = df_before.tail(600)

                    if len(recent) >= 50:  # Minimum required
                        market_state['timeframes'][tf] = {
                            'timestamps': self._format_timestamps_utc(recent.index),
                            'open': recent['open'].tolist(),
                            'high': recent['high'].tolist(),
                            'low': recent['low'].tolist(),
                            'close': recent['close'].tolist(),
                            'volume': recent['volume'].tolist()
                        }
                        logger.debug(f"Fetched {len(recent)} candles for {tf} timeframe")
                    else:
                        if tf in required_tfs:
                            logger.warning(f"Required timeframe {tf} has only {len(recent)} candles (need at least 50)")
                        else:
                            logger.debug(f"Optional timeframe {tf} has only {len(recent)} candles, skipping")
            except Exception as e:
                # One failing timeframe must not abort the others.
                logger.warning(f"Could not fetch {tf} data: {e}")

        # Validate we have required timeframes
        missing_required = [tf for tf in required_tfs if tf not in market_state['timeframes']]
        if missing_required:
            logger.warning(f"Missing required timeframes: {missing_required}")
            return {}

        return market_state

    except Exception as e:
        logger.error(f"Error fetching market state: {e}", exc_info=True)
        return {}
|
||||
|
||||
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
|
||||
|
||||
Reference in New Issue
Block a user