training on predictions - WIP

This commit is contained in:
Dobromir Popov
2025-12-10 01:24:05 +02:00
parent e84eed2839
commit 9c59b3e0c6
2 changed files with 92 additions and 14 deletions

View File

@@ -2791,6 +2791,41 @@ class AnnotationDashboard:
}
})
@self.server.route('/api/train-validated-prediction', methods=['POST'])
def train_validated_prediction():
    """Train model on a validated prediction (online learning).

    Expects a JSON payload with keys: timeframe, timestamp, predicted,
    actual, errors, direction_correct, accuracy. Delegates the actual
    training step to self._train_on_validated_prediction.

    Returns:
        JSON {success, message, metrics} on success,
        JSON {success, error} with HTTP 400 (bad payload) or 500
        (training error) on failure.
    """
    try:
        # silent=True returns None instead of raising on a non-JSON body,
        # so we can report a clean 400 instead of an opaque 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({
                'success': False,
                'error': 'Missing or invalid JSON payload'
            }), 400
        timeframe = data.get('timeframe')
        timestamp = data.get('timestamp')
        predicted = data.get('predicted')
        actual = data.get('actual')
        errors = data.get('errors')
        direction_correct = data.get('direction_correct')
        accuracy = data.get('accuracy')
        # Guard: accuracy may be absent (None) — formatting None with
        # ':.1f' would raise TypeError. Also give the direction flag a
        # meaningful label (original had two identical empty branches).
        accuracy_str = f"{accuracy:.1f}%" if accuracy is not None else "n/a"
        logger.info(
            f"[ONLINE LEARNING] Received validation for {timeframe}: "
            f"accuracy={accuracy_str}, "
            f"direction={'correct' if direction_correct else 'wrong'}"
        )
        # Trigger incremental training and get metrics for the UI.
        metrics = self._train_on_validated_prediction(
            timeframe, timestamp, predicted, actual,
            errors, direction_correct, accuracy
        )
        return jsonify({
            'success': True,
            'message': 'Training triggered',
            'metrics': metrics or {}
        })
    except Exception as e:
        logger.error(f"Error training on validated prediction: {e}", exc_info=True)
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
@self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
def train_manual():
"""Manually trigger training on current candle with specified action"""
@@ -2998,21 +3033,24 @@ class AnnotationDashboard:
This implements online learning where each validated prediction becomes
a training sample, with loss weighting based on prediction accuracy.
Returns:
Dict with training metrics (loss, accuracy, steps)
"""
try:
if not self.training_adapter:
logger.warning("Training adapter not available for incremental training")
return
return None
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
logger.warning("Transformer model not available for incremental training")
return
return None
# Get the transformer trainer
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
if not trainer:
logger.warning("Transformer trainer not available")
return
return None
# Calculate sample weight based on accuracy
# Low accuracy predictions get higher weight (we need to learn from mistakes)
@@ -3062,6 +3100,7 @@ class AnnotationDashboard:
return
# Train on this batch with sample weighting
import torch
with torch.enable_grad():
trainer.model.train()
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
@@ -3070,7 +3109,7 @@ class AnnotationDashboard:
loss = result.get('total_loss', 0)
candle_accuracy = result.get('candle_accuracy', 0)
logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
# Save checkpoint periodically (every 10 incremental steps)
if not hasattr(self, '_incremental_training_steps'):
@@ -3078,6 +3117,22 @@ class AnnotationDashboard:
self._incremental_training_steps += 1
# Track metrics for display
if not hasattr(self, '_training_metrics_history'):
self._training_metrics_history = []
self._training_metrics_history.append({
'step': self._incremental_training_steps,
'loss': loss,
'accuracy': candle_accuracy,
'timeframe': timeframe,
'timestamp': timestamp
})
# Keep only last 100 metrics
if len(self._training_metrics_history) > 100:
self._training_metrics_history.pop(0)
if self._incremental_training_steps % 10 == 0:
logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
trainer.save_checkpoint(
@@ -3089,8 +3144,17 @@ class AnnotationDashboard:
}
)
# Return metrics for display
return {
'loss': loss,
'accuracy': candle_accuracy,
'steps': self._incremental_training_steps,
'sample_weight': sample_weight
}
except Exception as e:
logger.error(f"Error in incremental training: {e}", exc_info=True)
return None
def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
"""Fetch market state at a specific timestamp for training"""

View File

@@ -2800,20 +2800,34 @@ class ChartManager {
predicted: prediction.candle, // [O, H, L, C, V]
actual: prediction.accuracy.actualCandle, // [O, H, L, C, V]
errors: prediction.accuracy.errors, // {open, high, low, close, volume}
pctErrors: prediction.accuracy.pctErrors, // {open, high, low, close, volume}
directionCorrect: prediction.accuracy.directionCorrect,
direction_correct: prediction.accuracy.directionCorrect,
accuracy: prediction.accuracy.accuracy
};
console.log('[Prediction Metrics for Training]', metrics);
console.log('[Prediction Metrics] Triggering online learning:', metrics);
// Send to backend via WebSocket for incremental training
if (window.socket && window.socket.connected) {
window.socket.emit('prediction_accuracy', metrics);
console.log(`[${timeframe}] Sent prediction accuracy to backend for training`);
} else {
console.warn('[Training] WebSocket not connected - metrics not sent to backend');
// Send to backend for incremental training (online learning)
fetch('/api/train-validated-prediction', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(metrics)
})
.then(response => response.json())
.then(data => {
if (data.success) {
console.log(`[${timeframe}] ✓ Online learning triggered - model updated from validated prediction`);
// Update metrics display if available
if (window.updateMetricsDisplay) {
window.updateMetricsDisplay(data.metrics);
}
} else {
console.warn(`[${timeframe}] Training failed:`, data.error);
}
})
.catch(error => {
console.warn('[Training] Error sending metrics:', error);
});
}
/**