training WIP

This commit is contained in:
Dobromir Popov
2025-12-10 01:40:03 +02:00
parent 9c59b3e0c6
commit 199235962b
3 changed files with 355 additions and 33 deletions

View File

@@ -4090,20 +4090,37 @@ class RealTrainingAdapter:
return None
# Override the future candle target with actual candle data
actual = prediction_sample['actual_candle'] # [O, H, L, C]
actual = prediction_sample['actual_candle'] # [O, H, L, C, V] or [O, H, L, C]
# Create target tensor for the specific timeframe
import torch
device = batch['prices_1m'].device if 'prices_1m' in batch else torch.device('cpu')
# Get device from any available tensor in batch
device = torch.device('cpu')
for key in ['price_data_1m', 'price_data_1h', 'price_data_1d', 'prices_1m']:
if key in batch and batch[key] is not None:
device = batch[key].device
break
# Target candle: [O, H, L, C, V] - we don't have actual volume, use predicted
target_candle = [
actual[0], # Open
actual[1], # High
actual[2], # Low
actual[3], # Close
prediction_sample['predicted_candle'][4] # Volume (from prediction)
]
# Target candle: [O, H, L, C, V]
# Use actual volume if available, otherwise use predicted volume
if len(actual) >= 5:
target_candle = [
float(actual[0]), # Open
float(actual[1]), # High
float(actual[2]), # Low
float(actual[3]), # Close
float(actual[4]) # Volume (from actual)
]
else:
# Fallback: use predicted volume if actual doesn't have it
predicted = prediction_sample.get('predicted_candle', [0, 0, 0, 0, 0])
target_candle = [
float(actual[0]), # Open
float(actual[1]), # High
float(actual[2]), # Low
float(actual[3]), # Close
float(predicted[4] if len(predicted) > 4 else 0.0) # Volume (from prediction)
]
# Add to batch based on timeframe
if timeframe == '1s':

View File

@@ -2826,6 +2826,57 @@ class AnnotationDashboard:
'error': str(e)
}), 500
@self.server.route('/api/training-metrics', methods=['GET'])
def get_training_metrics():
    """Return the most recent training metrics as a JSON payload.

    Sources are consulted in order, each one overriding the values set by
    the previous one when it has data:

    1. ``self.training_adapter.realtime_training_metrics``
    2. ``self._incremental_training_steps`` / ``self._training_metrics_history``
    3. ``self.orchestrator.primary_transformer_trainer.training_history``

    Responds with ``{'success': True, 'metrics': {...}}`` on success, or
    ``{'success': False, 'error': ...}`` and HTTP 500 if anything raises.
    """
    try:
        snapshot = {
            'loss': 0.0,
            'accuracy': 0.0,
            'steps': 0,
            'recent_history': [],
        }

        # 1) Realtime metrics kept by the training adapter.
        adapter = self.training_adapter
        if adapter and hasattr(adapter, 'realtime_training_metrics'):
            realtime = adapter.realtime_training_metrics
            snapshot.update(
                loss=realtime.get('last_loss', 0.0),
                accuracy=realtime.get('last_accuracy', 0.0),
                steps=realtime.get('total_steps', 0),
            )

        # 2) Incremental (online) training counters kept on the dashboard.
        if hasattr(self, '_incremental_training_steps'):
            snapshot['incremental_steps'] = self._incremental_training_steps

        if hasattr(self, '_training_metrics_history') and self._training_metrics_history:
            step_log = self._training_metrics_history
            # Only the last 10 entries are sent to the client.
            snapshot['recent_history'] = step_log[-10:]
            newest = step_log[-1]
            snapshot['loss'] = newest.get('loss', snapshot['loss'])
            snapshot['accuracy'] = newest.get('accuracy', snapshot['accuracy'])

        # 3) History recorded by the orchestrator's transformer trainer.
        if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer_trainer'):
            trainer = self.orchestrator.primary_transformer_trainer
            if trainer and hasattr(trainer, 'training_history'):
                history = trainer.training_history
                for metric_key, history_key in (
                    ('loss', 'train_loss'),
                    ('accuracy', 'train_accuracy'),
                ):
                    series = history.get(history_key)
                    if series:
                        snapshot[metric_key] = series[-1]

        return jsonify({'success': True, 'metrics': snapshot})
    except Exception as e:
        logger.error(f"Error getting training metrics: {e}")
        return jsonify({'success': False, 'error': str(e)}), 500
@self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
def train_manual():
"""Manually trigger training on current candle with specified action"""
@@ -3074,11 +3125,17 @@ class AnnotationDashboard:
# We need to fetch the market state at that timestamp
symbol = 'ETH/USDT' # TODO: Get from active trading pair
# Ensure actual_candle has volume (frontend sends [O, H, L, C, V])
actual_candle = list(actual) if isinstance(actual, (list, tuple)) else actual
if len(actual_candle) == 4:
# If only 4 values, add volume from predicted (fallback)
actual_candle.append(predicted[4] if len(predicted) > 4 else 0.0)
training_sample = {
'symbol': symbol,
'timestamp': timestamp,
'predicted_candle': predicted, # [O, H, L, C, V]
'actual_candle': actual, # [O, H, L, C]
'actual_candle': actual_candle, # [O, H, L, C, V] - ensure 5 values
'errors': errors,
'accuracy': accuracy,
'direction_correct': direction_correct,
@@ -3088,22 +3145,38 @@ class AnnotationDashboard:
# Get market state at that timestamp
try:
market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe)
if not market_state or 'timeframes' not in market_state:
logger.warning(f"Could not fetch market state for {symbol} at {timestamp}")
return None
training_sample['market_state'] = market_state
except Exception as e:
logger.warning(f"Could not fetch market state: {e}")
return
return None
# Convert to transformer batch format
batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe)
if not batch:
logger.warning("Could not convert validated prediction to training batch")
return
return None
# Train on this batch with sample weighting
# CRITICAL: Use training lock to prevent concurrent access
import torch
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
import threading
# Try to acquire training lock with timeout
if hasattr(self.training_adapter, '_training_lock'):
lock_acquired = self.training_adapter._training_lock.acquire(timeout=5.0)
if not lock_acquired:
logger.warning("Could not acquire training lock within 5 seconds - skipping incremental training")
return None
else:
lock_acquired = False
try:
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
if result:
loss = result.get('total_loss', 0)
@@ -3151,46 +3224,126 @@ class AnnotationDashboard:
'steps': self._incremental_training_steps,
'sample_weight': sample_weight
}
else:
logger.warning("Training step returned no result")
return None
finally:
# CRITICAL: Always release the lock
if lock_acquired and hasattr(self.training_adapter, '_training_lock'):
self.training_adapter._training_lock.release()
except Exception as e:
logger.error(f"Error in incremental training: {e}", exc_info=True)
# Ensure lock is released even on error
if 'lock_acquired' in locals() and lock_acquired and hasattr(self.training_adapter, '_training_lock'):
try:
self.training_adapter._training_lock.release()
except:
pass
return None
def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
"""Fetch market state at a specific timestamp for training"""
# Overview: converts `timestamp` to a pandas Timestamp (localized to UTC when
# it carries no timezone), first asks the data provider for a ready-made
# market state, and otherwise builds one per timeframe from historical
# candles. Returns the market-state dict, or {} when a required timeframe is
# missing or an unexpected error occurs.
# NOTE(review): this listing appears to contain both the removed and the added
# lines of the change (e.g. two datetime imports, two `for tf in ...` loops,
# two `df = ...` fetches) - confirm against the committed file before relying
# on the exact statement order shown here.
try:
from datetime import datetime
from datetime import datetime, timezone
import pandas as pd
# Parse timestamp
ts = pd.Timestamp(timestamp)
# Parse timestamp - ensure it's timezone-aware
# NOTE(review): both branches of the isinstance check below perform the same
# conversion and UTC localization.
if isinstance(timestamp, str):
ts = pd.Timestamp(timestamp)
if ts.tz is None:
ts = ts.tz_localize('UTC')
else:
ts = pd.Timestamp(timestamp)
if ts.tz is None:
ts = ts.tz_localize('UTC')
# Get historical data for multiple timeframes
# Use data provider's method to get market state at that time
# This ensures we get the proper format with all required timeframes
if self.data_provider and hasattr(self.data_provider, 'get_market_state_at_time'):
try:
# Convert to datetime if needed
if isinstance(ts, pd.Timestamp):
dt = ts.to_pydatetime()
else:
dt = ts
# Get market state with context window (need enough candles for training)
market_state = self.data_provider.get_market_state_at_time(
symbol=symbol,
timestamp=dt,
context_window_minutes=600 # Get 600 minutes of context for 1m candles
)
# A usable result short-circuits the manual per-timeframe fallback below.
if market_state and 'timeframes' in market_state:
logger.debug(f"Fetched market state with {len(market_state.get('timeframes', {}))} timeframes")
return market_state
else:
logger.warning("Market state returned empty or invalid format")
except Exception as e:
logger.warning(f"Could not use data provider method: {e}")
# Fallback: manually fetch data for each timeframe
market_state = {'timeframes': {}, 'secondary_timeframes': {}}
for tf in ['1s', '1m', '1h']:
# REQUIRED timeframes for transformer: 1m, 1h, 1d (1s is optional)
# Need at least 50 candles, preferably 600
required_tfs = ['1m', '1h', '1d']
optional_tfs = ['1s']
for tf in required_tfs + optional_tfs:
try:
df = self.data_provider.get_historical_data(symbol, tf, limit=200)
# Fetch enough candles (600 for training, but accept less)
# NOTE(review): `dt` appears to be assigned only inside the data-provider
# branch above; if that branch is skipped it would be unbound here and the
# resulting error would be swallowed by the per-timeframe except - confirm.
df = self.data_loader.get_data(
symbol=symbol,
timeframe=tf,
end_time=dt,
limit=600,
direction='before'
) if self.data_loader else None
# Fallback to data provider if data_loader not available
if df is None or df.empty:
if self.data_provider:
df = self.data_provider.get_historical_data(symbol, tf, limit=600, refresh=False)
if df is not None and not df.empty:
# Find data up to (but not including) the target timestamp
# Filter to data before the target timestamp
# NOTE(review): `ts` is timezone-aware at this point; this comparison
# presumably needs `df.index` to be timezone-aware as well - confirm.
df_before = df[df.index < ts]
if not df_before.empty:
recent = df_before.tail(200)
if df_before.empty:
# If no data before timestamp, use all available data
df_before = df
# Take last 600 candles (or all if less)
recent = df_before.tail(600)
if len(recent) >= 50: # Minimum required
market_state['timeframes'][tf] = {
'timestamps': self._format_timestamps_utc(recent.index),
'open': recent['open'].tolist(),
'high': recent['high'].tolist(),
'low': recent['low'].tolist(),
'close': recent['close'].tolist(),
'volume': recent['volume'].tolist()
}
logger.debug(f"Fetched {len(recent)} candles for {tf} timeframe")
else:
if tf in required_tfs:
logger.warning(f"Required timeframe {tf} has only {len(recent)} candles (need at least 50)")
else:
logger.debug(f"Optional timeframe {tf} has only {len(recent)} candles, skipping")
except Exception as e:
logger.warning(f"Could not fetch {tf} data: {e}")
# Validate we have required timeframes
missing_required = [tf for tf in required_tfs if tf not in market_state['timeframes']]
if missing_required:
logger.warning(f"Missing required timeframes: {missing_required}")
# An empty dict signals "no usable market state" to the caller.
return {}
return market_state
except Exception as e:
logger.error(f"Error fetching market state: {e}")
logger.error(f"Error fetching market state: {e}", exc_info=True)
return {}
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):

View File

@@ -85,15 +85,15 @@
<div id="inference-buttons-container">
<button class="btn btn-success btn-sm w-100" id="start-inference-btn">
<i class="fas fa-play"></i>
Start Live Inference (No Training)
Start Paper Trading (No Training)
</button>
<button class="btn btn-info btn-sm w-100 mt-1" id="start-inference-pivot-btn">
<i class="fas fa-chart-line"></i>
Live Inference + Pivot Training
Paper Trading + Pivot Training
</button>
<button class="btn btn-primary btn-sm w-100 mt-1" id="start-inference-candle-btn">
<i class="fas fa-graduation-cap"></i>
Live Inference + Per-Candle Training
Paper Trading + Online Learning
</button>
</div>
<button class="btn btn-danger btn-sm w-100 mt-1" id="stop-inference-btn" style="display: none;">
@@ -144,6 +144,24 @@
</div>
</div>
<!-- Online Learning Metrics -->
<div id="online-learning-metrics" style="display: none;">
<div class="alert alert-info py-2 px-2 mb-2">
<strong class="small">
<i class="fas fa-sync-alt"></i>
Online Learning
</strong>
<div class="small mt-1">
<div>Incremental Steps: <span id="incremental-steps" class="fw-bold text-primary">0</span></div>
<div>Current Loss: <span id="online-loss" class="fw-bold text-warning">--</span></div>
<div>Current Accuracy: <span id="online-accuracy" class="fw-bold text-success">--</span></div>
<div class="mt-1 pt-1 border-top" style="font-size: 0.7rem; color: #666;">
<div>Last Training: <span id="last-training-time">--</span></div>
</div>
</div>
</div>
</div>
<!-- Inference Status -->
<div id="inference-status" style="display: none;">
<div class="alert alert-success py-2 px-2 mb-2">
@@ -152,7 +170,7 @@
<div class="spinner-border spinner-border-sm me-2" role="status">
<span class="visually-hidden">Running...</span>
</div>
<strong class="small">🔴 LIVE</strong>
<strong class="small">LIVE</strong>
</div>
<!-- Model Performance -->
<div class="small text-end">
@@ -161,8 +179,27 @@
</div>
</div>
<!-- Position & PnL Status -->
<div class="mb-2 p-2" style="background-color: rgba(0,0,0,0.1); border-radius: 4px;">
<!-- Trading Status Warning -->
<div class="mb-2 p-2" style="background-color: rgba(255, 193, 7, 0.2); border: 1px solid rgba(255, 193, 7, 0.5); border-radius: 4px;" id="trading-inactive-warning">
<div class="small text-center">
<i class="fas fa-exclamation-triangle text-warning"></i>
<strong>PREDICTIONS ONLY</strong>
<div style="font-size: 0.7rem; color: #666; margin-top: 4px;">
No trading session active.<br>
Click a button above to start paper trading.
</div>
</div>
</div>
<!-- Position & PnL Status (shown when trading active) -->
<div class="mb-2 p-2" style="background-color: rgba(0,0,0,0.1); border-radius: 4px; display: none;" id="trading-active-status">
<!-- DEMO MODE Badge -->
<div class="text-center mb-2 pb-2 border-bottom" style="border-color: rgba(255,255,255,0.2) !important;">
<span class="badge bg-info" style="font-size: 0.65rem; letter-spacing: 0.5px;">
📊 PAPER TRADING (DEMO)
</span>
</div>
<div class="small">
<div class="d-flex justify-content-between">
<span>Position:</span>
@@ -742,6 +779,10 @@
document.getElementById('inference-status').style.display = 'block';
document.getElementById('inference-controls').style.display = 'block';
// Show trading active status, hide warning
document.getElementById('trading-inactive-warning').style.display = 'none';
document.getElementById('trading-active-status').style.display = 'block';
// Show manual training button if in manual mode
if (trainingMode === 'manual') {
document.getElementById('manual-train-btn').style.display = 'block';
@@ -770,6 +811,13 @@
// Start polling for signals
startSignalPolling();
startTrainingMetricsPolling(); // Start training metrics polling
// Show online learning metrics panel
const onlineMetricsPanel = document.getElementById('online-learning-metrics');
if (onlineMetricsPanel) {
onlineMetricsPanel.style.display = 'block';
}
// Start chart auto-update
if (window.appState && window.appState.chartManager) {
@@ -822,6 +870,10 @@
document.getElementById('manual-train-btn').style.display = 'none';
document.getElementById('inference-status').style.display = 'none';
document.getElementById('inference-controls').style.display = 'none';
// Show warning, hide trading status
document.getElementById('trading-inactive-warning').style.display = 'block';
document.getElementById('trading-active-status').style.display = 'none';
// Hide live mode banner
const banner = document.getElementById('live-mode-banner');
@@ -831,6 +883,13 @@
// Stop polling
stopSignalPolling();
stopTrainingMetricsPolling(); // Stop training metrics polling
// Hide online learning metrics panel
const onlineMetricsPanel = document.getElementById('online-learning-metrics');
if (onlineMetricsPanel) {
onlineMetricsPanel.style.display = 'none';
}
// Stop chart auto-update and remove metrics overlay
if (window.appState && window.appState.chartManager) {
@@ -1298,6 +1357,99 @@
historyDiv.innerHTML = html;
}
// Training metrics polling
let trainingMetricsInterval = null;
/**
 * Begin polling /api/training-metrics every 2 seconds and push each
 * successful payload into the dashboard. Any poller that is already
 * running is cancelled first, so at most one timer is active.
 */
function startTrainingMetricsPolling() {
    if (trainingMetricsInterval) {
        clearInterval(trainingMetricsInterval);
    }

    // One polling round: fetch, decode, and render if the server reports success.
    const pollOnce = function () {
        fetch('/api/training-metrics')
            .then(function (response) {
                return response.json();
            })
            .then(function (payload) {
                if (payload.success && payload.metrics) {
                    updateTrainingMetricsDisplay(payload.metrics);
                }
            })
            .catch(function (err) {
                // Polling failures are non-fatal; the next tick simply retries.
                console.debug('[Training Metrics] Polling error:', err);
            });
    };

    trainingMetricsInterval = setInterval(pollOnce, 2000); // Poll every 2 seconds
}
/**
 * Cancel the training-metrics poller, if one is running, and clear the
 * stored timer handle so a later start begins from a clean state.
 */
function stopTrainingMetricsPolling() {
    if (!trainingMetricsInterval) {
        return; // Nothing is running.
    }
    clearInterval(trainingMetricsInterval);
    trainingMetricsInterval = null;
}
/**
 * Render a training-metrics payload into the dashboard.
 *
 * Updates the live banner (`live-accuracy`, `live-loss`), the optional
 * `metric-accuracy` and `training-loss` elements and, once at least one
 * incremental step has been reported, the online-learning panel.
 * Elements that are not present on the page are skipped.
 *
 * Values are formatted defensively: anything that is not a finite number
 * (for example a JSON `null`) is rendered as `--` instead of throwing from
 * `toFixed`, which previously aborted the rest of the update.
 *
 * @param {Object} metrics - The `metrics` object returned by
 *     /api/training-metrics (loss, accuracy, incremental_steps,
 *     recent_history).
 */
function updateTrainingMetricsDisplay(metrics) {
    const isNumber = (value) => typeof value === 'number' && Number.isFinite(value);
    const formatLoss = (value) => (isNumber(value) ? value.toFixed(4) : '--');
    const formatAccuracy = (value) =>
        (isNumber(value) ? (value * 100).toFixed(1) + '%' : '--');

    // Write text into an element only if that element exists on this page.
    const setText = (elementId, text) => {
        const element = document.getElementById(elementId);
        if (element) {
            element.textContent = text;
        }
    };

    // A field that is absent from the payload leaves its elements untouched.
    const hasLoss = metrics.loss !== undefined;
    const hasAccuracy = metrics.accuracy !== undefined;

    if (hasAccuracy) {
        const accuracyText = formatAccuracy(metrics.accuracy);
        setText('live-accuracy', accuracyText);
        setText('metric-accuracy', accuracyText);
    }

    if (hasLoss) {
        // The live banner shows a zero loss as "--" (no data yet);
        // the training-status element shows the numeric value.
        setText('live-loss', metrics.loss ? formatLoss(metrics.loss) : '--');
        setText('training-loss', formatLoss(metrics.loss));
    }

    // Online learning panel: only once incremental training has actually run.
    if (metrics.incremental_steps !== undefined && metrics.incremental_steps > 0) {
        const onlineMetricsPanel = document.getElementById('online-learning-metrics');
        if (onlineMetricsPanel) {
            onlineMetricsPanel.style.display = 'block';
        }

        setText('incremental-steps', metrics.incremental_steps);
        if (hasLoss) {
            setText('online-loss', formatLoss(metrics.loss));
        }
        if (hasAccuracy) {
            setText('online-accuracy', formatAccuracy(metrics.accuracy));
        }

        const history = metrics.recent_history;
        if (history && history.length > 0) {
            const lastStep = history[history.length - 1];
            if (lastStep && lastStep.timestamp) {
                const trainedAt = new Date(lastStep.timestamp);
                // Skip unparseable timestamps rather than printing "Invalid Date".
                if (!Number.isNaN(trainedAt.getTime())) {
                    setText('last-training-time', trainedAt.toLocaleTimeString());
                }
            }
        }

        console.log(`[Online Learning] ${metrics.incremental_steps} incremental training steps completed`);
    }
}
function startSignalPolling() {
signalPollInterval = setInterval(function () {
// Poll for signals