logging channels; training steps storage

This commit is contained in:
Dobromir Popov
2025-11-22 12:47:43 +02:00
parent 7a219b5ebc
commit e404658dc7
9 changed files with 938 additions and 37 deletions

View File

@@ -138,9 +138,25 @@ class RealTrainingAdapter:
self.data_provider = data_provider
self.training_sessions: Dict[str, TrainingSession] = {}
# Real-time training tracking
self.realtime_training_metrics = {
'total_steps': 0,
'total_loss': 0.0,
'total_accuracy': 0.0,
'best_loss': float('inf'),
'best_accuracy': 0.0,
'last_checkpoint_step': 0,
'checkpoint_frequency': 100, # Save every N steps
'losses': [], # Rolling window
'accuracies': [] # Rolling window
}
# Import real training systems
self._import_training_systems()
# Load best realtime checkpoint if available
self._load_best_realtime_checkpoint()
logger.info("RealTrainingAdapter initialized - NO SIMULATION, REAL TRAINING ONLY")
def _import_training_systems(self):
@@ -2410,6 +2426,11 @@ class RealTrainingAdapter:
'train_every_candle': train_every_candle,
'timeframe': timeframe,
'data_provider': data_provider,
'metrics': {
'accuracy': 0.0,
'loss': 0.0,
'steps': 0
},
'last_candle_time': None
}
@@ -2711,6 +2732,13 @@ class RealTrainingAdapter:
model_name = session['model_name']
if model_name == 'Transformer':
self._train_transformer_on_sample(training_sample)
# Update session metrics with latest realtime metrics
if len(self.realtime_training_metrics['losses']) > 0:
session['metrics']['loss'] = sum(self.realtime_training_metrics['losses']) / len(self.realtime_training_metrics['losses'])
session['metrics']['accuracy'] = sum(self.realtime_training_metrics['accuracies']) / len(self.realtime_training_metrics['accuracies'])
session['metrics']['steps'] = self.realtime_training_metrics['total_steps']
logger.info(f"Trained on candle: {symbol} {timeframe} @ {completed_candle.name} (change: {price_change:+.2%})")
except Exception as e:
@@ -2740,7 +2768,7 @@ class RealTrainingAdapter:
return {}
def _train_transformer_on_sample(self, training_sample: Dict):
"""Train transformer on a single sample"""
"""Train transformer on a single sample with checkpoint saving"""
try:
if not self.orchestrator:
return
@@ -2760,12 +2788,249 @@ class RealTrainingAdapter:
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False)
if result:
logger.info(f"Per-candle training: Loss={result.get('total_loss', 0):.4f}")
loss = result.get('total_loss', 0)
accuracy = result.get('accuracy', 0)
# Update metrics tracking
self.realtime_training_metrics['total_steps'] += 1
self.realtime_training_metrics['total_loss'] += loss
self.realtime_training_metrics['total_accuracy'] += accuracy
# Maintain rolling window (last 100 steps)
self.realtime_training_metrics['losses'].append(loss)
self.realtime_training_metrics['accuracies'].append(accuracy)
if len(self.realtime_training_metrics['losses']) > 100:
self.realtime_training_metrics['losses'].pop(0)
self.realtime_training_metrics['accuracies'].pop(0)
# Calculate rolling average
avg_loss = sum(self.realtime_training_metrics['losses']) / len(self.realtime_training_metrics['losses'])
avg_accuracy = sum(self.realtime_training_metrics['accuracies']) / len(self.realtime_training_metrics['accuracies'])
logger.info(f"Per-candle training: Loss={loss:.4f} (avg: {avg_loss:.4f}), Acc={accuracy:.2%} (avg: {avg_accuracy:.2%})")
# Check if model improved (save checkpoint)
improved = False
if avg_loss < self.realtime_training_metrics['best_loss']:
self.realtime_training_metrics['best_loss'] = avg_loss
improved = True
logger.info(f" NEW BEST LOSS: {avg_loss:.4f}")
if avg_accuracy > self.realtime_training_metrics['best_accuracy']:
self.realtime_training_metrics['best_accuracy'] = avg_accuracy
improved = True
logger.info(f" NEW BEST ACCURACY: {avg_accuracy:.2%}")
# Save checkpoint if improved or every N steps
steps_since_checkpoint = self.realtime_training_metrics['total_steps'] - self.realtime_training_metrics['last_checkpoint_step']
if improved or steps_since_checkpoint >= self.realtime_training_metrics['checkpoint_frequency']:
self._save_realtime_checkpoint(
trainer=trainer,
step=self.realtime_training_metrics['total_steps'],
loss=avg_loss,
accuracy=avg_accuracy,
improved=improved
)
self.realtime_training_metrics['last_checkpoint_step'] = self.realtime_training_metrics['total_steps']
except Exception as e:
logger.warning(f"Error training transformer on sample: {e}")
def _save_realtime_checkpoint(self, trainer, step: int, loss: float, accuracy: float, improved: bool = False):
"""
Save checkpoint during real-time training
Args:
trainer: Model trainer instance
step: Current training step
loss: Current average loss
accuracy: Current average accuracy
improved: Whether this is an improvement checkpoint
"""
try:
import torch
import os
from datetime import datetime
checkpoint_dir = "models/checkpoints/transformer/realtime"
os.makedirs(checkpoint_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
checkpoint_type = "BEST" if improved else "periodic"
checkpoint_path = os.path.join(checkpoint_dir, f"realtime_{checkpoint_type}_step{step}_{timestamp}.pt")
# Save checkpoint
torch.save({
'step': step,
'model_state_dict': trainer.model.state_dict(),
'optimizer_state_dict': trainer.optimizer.state_dict(),
'scheduler_state_dict': trainer.scheduler.state_dict() if hasattr(trainer, 'scheduler') else None,
'loss': loss,
'accuracy': accuracy,
'learning_rate': trainer.scheduler.get_last_lr()[0] if hasattr(trainer, 'scheduler') else trainer.optimizer.param_groups[0]['lr'],
'training_type': 'realtime_per_candle',
'metrics': {
'total_steps': self.realtime_training_metrics['total_steps'],
'best_loss': self.realtime_training_metrics['best_loss'],
'best_accuracy': self.realtime_training_metrics['best_accuracy'],
'rolling_losses': self.realtime_training_metrics['losses'][-10:], # Last 10
'rolling_accuracies': self.realtime_training_metrics['accuracies'][-10:]
}
}, checkpoint_path)
logger.info(f" SAVED REALTIME CHECKPOINT: {checkpoint_path}")
logger.info(f" Step: {step}, Loss: {loss:.4f}, Acc: {accuracy:.2%}, Improved: {improved}")
# Save metadata to database
try:
from utils.database_manager import get_database_manager
db_manager = get_database_manager()
checkpoint_id = f"realtime_step{step}_{timestamp}"
from utils.database_manager import CheckpointMetadata
metadata = CheckpointMetadata(
checkpoint_id=checkpoint_id,
model_name="transformer_realtime",
model_type="transformer",
timestamp=datetime.now(),
performance_metrics={
'loss': float(loss),
'accuracy': float(accuracy),
'step': step,
'best_loss': float(self.realtime_training_metrics['best_loss']),
'best_accuracy': float(self.realtime_training_metrics['best_accuracy'])
},
training_metadata={
'training_type': 'realtime_per_candle',
'total_steps': self.realtime_training_metrics['total_steps'],
'checkpoint_type': checkpoint_type
},
file_path=checkpoint_path,
file_size_mb=os.path.getsize(checkpoint_path) / (1024 * 1024),
is_active=True
)
if db_manager.save_checkpoint_metadata(metadata):
logger.info(f" Saved checkpoint metadata to database: {checkpoint_id}")
except Exception as meta_error:
logger.warning(f" Could not save checkpoint metadata: {meta_error}")
# Cleanup: Keep only best 10 checkpoints
if improved:
self._cleanup_realtime_checkpoints(checkpoint_dir, keep_best=10)
except Exception as e:
logger.error(f"Error saving realtime checkpoint: {e}")
def _cleanup_realtime_checkpoints(self, checkpoint_dir: str, keep_best: int = 10):
"""Keep only the best N realtime checkpoints"""
try:
if not os.path.exists(checkpoint_dir):
return
import torch
checkpoints = []
for filename in os.listdir(checkpoint_dir):
if filename.endswith('.pt') and filename.startswith('realtime_'):
filepath = os.path.join(checkpoint_dir, filename)
try:
checkpoint = torch.load(filepath, map_location='cpu')
checkpoints.append({
'path': filepath,
'loss': checkpoint.get('loss', float('inf')),
'accuracy': checkpoint.get('accuracy', 0),
'step': checkpoint.get('step', 0),
'is_best': 'BEST' in filename
})
except Exception as e:
logger.debug(f"Could not load checkpoint {filename}: {e}")
# Sort by accuracy (higher is better), then by loss (lower is better)
checkpoints.sort(key=lambda x: (x['accuracy'], -x['loss']), reverse=True)
# Keep best N checkpoints
for checkpoint in checkpoints[keep_best:]:
try:
os.remove(checkpoint['path'])
logger.debug(f"Removed old realtime checkpoint: {os.path.basename(checkpoint['path'])}")
except Exception as e:
logger.warning(f"Could not remove checkpoint: {e}")
except Exception as e:
logger.error(f"Error cleaning up realtime checkpoints: {e}")
def _load_best_realtime_checkpoint(self):
"""Load the best realtime checkpoint on startup to resume training"""
try:
import torch
import os
checkpoint_dir = "models/checkpoints/transformer/realtime"
if not os.path.exists(checkpoint_dir):
logger.info("No realtime checkpoints found, starting fresh")
return
# Find best checkpoint
checkpoints = []
for filename in os.listdir(checkpoint_dir):
if filename.endswith('.pt') and filename.startswith('realtime_'):
filepath = os.path.join(checkpoint_dir, filename)
try:
checkpoint = torch.load(filepath, map_location='cpu')
checkpoints.append({
'path': filepath,
'loss': checkpoint.get('loss', float('inf')),
'accuracy': checkpoint.get('accuracy', 0),
'step': checkpoint.get('step', 0),
'checkpoint': checkpoint
})
except Exception as e:
logger.debug(f"Could not load checkpoint {filename}: {e}")
if not checkpoints:
logger.info("No valid realtime checkpoints found")
return
# Sort by accuracy, then by loss
checkpoints.sort(key=lambda x: (x['accuracy'], -x['loss']), reverse=True)
best = checkpoints[0]
# Restore metrics from checkpoint
if 'metrics' in best['checkpoint']:
saved_metrics = best['checkpoint']['metrics']
self.realtime_training_metrics['total_steps'] = saved_metrics.get('total_steps', 0)
self.realtime_training_metrics['best_loss'] = saved_metrics.get('best_loss', float('inf'))
self.realtime_training_metrics['best_accuracy'] = saved_metrics.get('best_accuracy', 0.0)
self.realtime_training_metrics['losses'] = saved_metrics.get('rolling_losses', [])
self.realtime_training_metrics['accuracies'] = saved_metrics.get('rolling_accuracies', [])
self.realtime_training_metrics['last_checkpoint_step'] = best['step']
# Load model weights if orchestrator is available
if self.orchestrator and hasattr(self.orchestrator, 'primary_transformer_trainer'):
trainer = self.orchestrator.primary_transformer_trainer
if trainer and trainer.model:
trainer.model.load_state_dict(best['checkpoint']['model_state_dict'])
trainer.optimizer.load_state_dict(best['checkpoint']['optimizer_state_dict'])
if 'scheduler_state_dict' in best['checkpoint'] and best['checkpoint']['scheduler_state_dict']:
trainer.scheduler.load_state_dict(best['checkpoint']['scheduler_state_dict'])
logger.info(f"RESUMED REALTIME TRAINING from checkpoint:")
logger.info(f" Step: {best['step']}, Loss: {best['loss']:.4f}, Acc: {best['accuracy']:.2%}")
logger.info(f" Path: {os.path.basename(best['path'])}")
else:
logger.info(f"Found realtime checkpoint but trainer not available yet")
else:
logger.info(f"Found realtime checkpoint but orchestrator not available yet")
except Exception as e:
logger.warning(f"Error loading realtime checkpoint: {e}")
logger.info("Starting realtime training from scratch")
def _get_sleep_time_for_timeframe(self, timeframe: str) -> float:
"""Get appropriate sleep time based on timeframe"""
timeframe_seconds = {

View File

@@ -25,6 +25,7 @@ import threading
import uuid
import time
import torch
from utils.logging_config import get_channel_logger, LogChannel
# Import core components from main system
try:
@@ -98,6 +99,11 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
logger.info(f"Logging to: {log_file}")
# Create channel-specific loggers
pivot_logger = get_channel_logger(__name__, LogChannel.PIVOTS)
api_logger = get_channel_logger(__name__, LogChannel.API)
webui_logger = get_channel_logger(__name__, LogChannel.WEBUI)
class BacktestRunner:
"""Runs backtest candle-by-candle with model predictions and tracks PnL"""
@@ -941,7 +947,7 @@ class AnnotationDashboard:
ts_str, idx = last_info['low']
pivot_map[ts_str]['lows'][idx]['is_last'] = True
logger.info(f"Found {len(pivot_map)} pivot candles for {symbol} {timeframe} (from {len(df)} candles)")
pivot_logger.info(f"Found {len(pivot_map)} pivot candles for {symbol} {timeframe} (from {len(df)} candles)")
return pivot_map
except Exception as e:
@@ -1067,7 +1073,7 @@ class AnnotationDashboard:
'error': {'code': 'INVALID_REQUEST', 'message': 'Missing timeframe'}
})
logger.info(f" Recalculating pivots for {symbol} {timeframe} using backend data")
pivot_logger.info(f"Recalculating pivots for {symbol} {timeframe} using backend data")
if not self.data_loader:
return jsonify({
@@ -1094,7 +1100,7 @@ class AnnotationDashboard:
# Recalculate pivot markers
pivot_markers = self._get_pivot_markers_for_timeframe(symbol, timeframe, df)
logger.info(f" Recalculated {len(pivot_markers)} pivot candles")
pivot_logger.info(f"Recalculated {len(pivot_markers)} pivot candles")
return jsonify({
'success': True,
@@ -1120,11 +1126,11 @@ class AnnotationDashboard:
limit = data.get('limit', 2500) # Default 2500 candles for training
direction = data.get('direction', 'latest') # 'latest', 'before', or 'after'
logger.info(f"Chart data request: {symbol} {timeframes} direction={direction} limit={limit}")
webui_logger.info(f"Chart data request: {symbol} {timeframes} direction={direction} limit={limit}")
if start_time_str:
logger.info(f" start_time: {start_time_str}")
webui_logger.info(f" start_time: {start_time_str}")
if end_time_str:
logger.info(f" end_time: {end_time_str}")
webui_logger.info(f" end_time: {end_time_str}")
if not self.data_loader:
return jsonify({
@@ -1156,7 +1162,7 @@ class AnnotationDashboard:
)
if df is not None and not df.empty:
logger.info(f" {timeframe}: {len(df)} candles ({df.index[0]} to {df.index[-1]})")
webui_logger.info(f" {timeframe}: {len(df)} candles ({df.index[0]} to {df.index[-1]})")
# Get pivot points for this timeframe (only if we have enough context)
pivot_markers = {}
@@ -2386,6 +2392,10 @@ def main():
logger.info(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
logger.info("=" * 80)
# Print logging channel configuration
from utils.logging_config import print_channel_status
print_channel_status()
dashboard = AnnotationDashboard()
dashboard.run(debug=True)

View File

@@ -12,6 +12,9 @@ class ChartManager {
this.updateTimers = {}; // Track auto-update timers
this.autoUpdateEnabled = false; // Auto-update state
this.liveMetricsOverlay = null; // Live metrics display overlay
this.lastPredictionUpdate = {}; // Track last prediction update per timeframe
this.predictionUpdateThrottle = 500; // Min ms between prediction updates
this.lastPredictionHash = null; // Track if predictions actually changed
console.log('ChartManager initialized with timeframes:', timeframes);
}
@@ -172,6 +175,14 @@ class ChartManager {
});
});
// Merge pivot markers
if (newData.pivot_markers) {
if (!chart.data.pivot_markers) {
chart.data.pivot_markers = {};
}
Object.assign(chart.data.pivot_markers, newData.pivot_markers);
}
// 2. Update existing candles in place if they exist in new data
// Iterate backwards to optimize for recent updates
let updatesCount = 0;
@@ -212,7 +223,12 @@ class ChartManager {
if (updatesCount > 0 || remainingTimestamps.length > 0) {
console.log(`[${timeframe}] Chart update: ${updatesCount} updated, ${remainingTimestamps.length} new candles`);
this.recalculatePivots(timeframe, chart.data);
// Only recalculate pivots if we have NEW candles (not just updates to existing ones)
// This prevents unnecessary pivot recalculation on every live candle update
if (remainingTimestamps.length > 0) {
this.recalculatePivots(timeframe, chart.data);
}
this.updateSingleChart(timeframe, chart.data);
window.liveUpdateCount = (window.liveUpdateCount || 0) + 1;
@@ -1774,25 +1790,30 @@ class ChartManager {
});
}
// Update chart layout with new pivots
Plotly.relayout(chart.plotId, {
// Batch update: Use Plotly.update to combine layout and trace updates
// This reduces flickering by doing both operations in one call
const layoutUpdate = {
shapes: shapes,
annotations: annotations
});
};
const traceUpdate = pivotDots.x.length > 0 ? {
x: [pivotDots.x],
y: [pivotDots.y],
text: [pivotDots.text],
'marker.color': [pivotDots.marker.color],
'marker.size': [pivotDots.marker.size],
'marker.symbol': [pivotDots.marker.symbol]
} : {};
// Update pivot dots trace
// Use Plotly.update to batch both operations
if (pivotDots.x.length > 0) {
Plotly.restyle(chart.plotId, {
x: [pivotDots.x],
y: [pivotDots.y],
text: [pivotDots.text],
'marker.color': [pivotDots.marker.color],
'marker.size': [pivotDots.marker.size],
'marker.symbol': [pivotDots.marker.symbol]
}, [2]); // Trace index 2 is pivot dots
Plotly.update(chart.plotId, traceUpdate, layoutUpdate, [2]); // Trace index 2 is pivot dots
} else {
Plotly.relayout(chart.plotId, layoutUpdate);
}
console.log(`🎨 Redrawn ${timeframe} chart with updated pivots`);
console.log(`Redrawn ${timeframe} chart with updated pivots`);
}
/**
@@ -1803,6 +1824,8 @@ class ChartManager {
if (!chart) return;
const plotId = chart.plotId;
const plotElement = document.getElementById(plotId);
if (!plotElement) return;
// Create volume colors
const volumeColors = data.close.map((close, i) => {
@@ -1810,18 +1833,34 @@ class ChartManager {
return close >= data.open[i] ? '#10b981' : '#ef4444';
});
// Update traces
const update = {
x: [data.timestamps, data.timestamps],
open: [data.open],
high: [data.high],
low: [data.low],
close: [data.close],
y: [undefined, data.volume],
'marker.color': [undefined, volumeColors]
// Use Plotly.react for smoother, non-flickering updates
// It only updates what changed, unlike restyle which can cause flicker
const currentData = plotElement.data;
// Update only the first two traces (candlestick and volume)
// Keep other traces (pivots, predictions) intact
const updatedTraces = [...currentData];
// Update candlestick trace (trace 0)
updatedTraces[0] = {
...updatedTraces[0],
x: data.timestamps,
open: data.open,
high: data.high,
low: data.low,
close: data.close
};
// Update volume trace (trace 1)
updatedTraces[1] = {
...updatedTraces[1],
x: data.timestamps,
y: data.volume,
marker: { ...updatedTraces[1].marker, color: volumeColors }
};
Plotly.restyle(plotId, update, [0, 1]);
// Use react instead of restyle - it's smarter about what to update
Plotly.react(plotId, updatedTraces, plotElement.layout, plotElement.config);
console.log(`Updated ${timeframe} chart with ${data.timestamps.length} candles`);
}
@@ -1882,7 +1921,36 @@ class ChartManager {
// This ensures predictions appear on the chart the user is watching (e.g., '1s')
const timeframe = window.appState?.currentTimeframes?.[0] || '1m';
const chart = this.charts[timeframe];
if (!chart) return;
if (!chart) {
console.warn(`[updatePredictions] Chart not found for timeframe: ${timeframe}`);
return;
}
// Throttle prediction updates to avoid flickering
const now = Date.now();
const lastUpdate = this.lastPredictionUpdate[timeframe] || 0;
// Create a simple hash of prediction data to detect actual changes
const predictionHash = JSON.stringify({
action: predictions.transformer?.action,
confidence: predictions.transformer?.confidence,
predicted_price: predictions.transformer?.predicted_price,
timestamp: predictions.transformer?.timestamp
});
// Skip update if:
// 1. Too soon since last update (throttle)
// 2. Predictions haven't actually changed
if (now - lastUpdate < this.predictionUpdateThrottle && predictionHash === this.lastPredictionHash) {
console.debug(`[updatePredictions] Skipping update (throttled or unchanged)`);
return;
}
this.lastPredictionUpdate[timeframe] = now;
this.lastPredictionHash = predictionHash;
console.log(`[updatePredictions] Timeframe: ${timeframe}, Predictions:`, predictions);
const plotId = chart.plotId;
const plotElement = document.getElementById(plotId);
@@ -1918,7 +1986,9 @@ class ChartManager {
// Handle Predicted Candles
if (predictions.transformer.predicted_candle) {
console.log(`[updatePredictions] predicted_candle data:`, predictions.transformer.predicted_candle);
const candleData = predictions.transformer.predicted_candle[timeframe];
console.log(`[updatePredictions] candleData for ${timeframe}:`, candleData);
if (candleData) {
// Get the prediction timestamp from the model (when inference was made)
const predictionTimestamp = predictions.transformer.timestamp || new Date().toISOString();
@@ -2005,8 +2075,8 @@ class ChartManager {
// trendVector contains: angle_degrees, steepness, direction, price_delta
// We visualize this as a ray from current price
// Need current candle close and timestamp
const timeframe = '1m'; // Default to 1m for now
// Use the active timeframe from app state
const timeframe = window.appState?.currentTimeframes?.[0] || '1m';
const chart = this.charts[timeframe];
if (!chart || !chart.data) return;

View File

@@ -144,6 +144,7 @@
<strong class="small">🔴 LIVE</strong>
</div>
<div class="small">
<div>Timeframe: <span id="active-timeframe" class="fw-bold text-primary">--</span></div>
<div>Signal: <span id="latest-signal" class="fw-bold">--</span></div>
<div>Confidence: <span id="latest-confidence">--</span></div>
<div class="text-muted" style="font-size: 0.7rem;">Predicting <span id="active-steps">1</span> step(s) ahead</div>
@@ -572,6 +573,9 @@
document.getElementById('inference-status').style.display = 'block';
document.getElementById('inference-controls').style.display = 'block';
// Display active timeframe
document.getElementById('active-timeframe').textContent = timeframe.toUpperCase();
// Clear prediction history and reset PnL tracker
predictionHistory = [];
pnlTracker = {
@@ -1038,6 +1042,9 @@
}
const latest = data.signals[0];
console.log('[Signal Polling] Latest signal:', latest);
console.log('[Signal Polling] predicted_candle:', latest.predicted_candle);
document.getElementById('latest-signal').textContent = latest.action;
document.getElementById('latest-confidence').textContent =
(latest.confidence * 100).toFixed(1) + '%';