training on predictins - WIP?
This commit is contained in:
@@ -2791,6 +2791,41 @@ class AnnotationDashboard:
|
||||
}
|
||||
})
|
||||
|
||||
@self.server.route('/api/train-validated-prediction', methods=['POST'])
|
||||
def train_validated_prediction():
|
||||
"""Train model on a validated prediction (online learning)"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
|
||||
timeframe = data.get('timeframe')
|
||||
timestamp = data.get('timestamp')
|
||||
predicted = data.get('predicted')
|
||||
actual = data.get('actual')
|
||||
errors = data.get('errors')
|
||||
direction_correct = data.get('direction_correct')
|
||||
accuracy = data.get('accuracy')
|
||||
|
||||
logger.info(f"[ONLINE LEARNING] Received validation for {timeframe}: accuracy={accuracy:.1f}%, direction={'✓' if direction_correct else '✗'}")
|
||||
|
||||
# Trigger training and get metrics
|
||||
metrics = self._train_on_validated_prediction(
|
||||
timeframe, timestamp, predicted, actual,
|
||||
errors, direction_correct, accuracy
|
||||
)
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'message': 'Training triggered',
|
||||
'metrics': metrics or {}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training on validated prediction: {e}", exc_info=True)
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}), 500
|
||||
|
||||
@self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
|
||||
def train_manual():
|
||||
"""Manually trigger training on current candle with specified action"""
|
||||
@@ -2998,21 +3033,24 @@ class AnnotationDashboard:
|
||||
|
||||
This implements online learning where each validated prediction becomes
|
||||
a training sample, with loss weighting based on prediction accuracy.
|
||||
|
||||
Returns:
|
||||
Dict with training metrics (loss, accuracy, steps)
|
||||
"""
|
||||
try:
|
||||
if not self.training_adapter:
|
||||
logger.warning("Training adapter not available for incremental training")
|
||||
return
|
||||
return None
|
||||
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
|
||||
logger.warning("Transformer model not available for incremental training")
|
||||
return
|
||||
return None
|
||||
|
||||
# Get the transformer trainer
|
||||
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
|
||||
if not trainer:
|
||||
logger.warning("Transformer trainer not available")
|
||||
return
|
||||
return None
|
||||
|
||||
# Calculate sample weight based on accuracy
|
||||
# Low accuracy predictions get higher weight (we need to learn from mistakes)
|
||||
@@ -3062,6 +3100,7 @@ class AnnotationDashboard:
|
||||
return
|
||||
|
||||
# Train on this batch with sample weighting
|
||||
import torch
|
||||
with torch.enable_grad():
|
||||
trainer.model.train()
|
||||
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
|
||||
@@ -3070,7 +3109,7 @@ class AnnotationDashboard:
|
||||
loss = result.get('total_loss', 0)
|
||||
candle_accuracy = result.get('candle_accuracy', 0)
|
||||
|
||||
logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
|
||||
logger.info(f"[{timeframe}] ✓ Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
|
||||
|
||||
# Save checkpoint periodically (every 10 incremental steps)
|
||||
if not hasattr(self, '_incremental_training_steps'):
|
||||
@@ -3078,6 +3117,22 @@ class AnnotationDashboard:
|
||||
|
||||
self._incremental_training_steps += 1
|
||||
|
||||
# Track metrics for display
|
||||
if not hasattr(self, '_training_metrics_history'):
|
||||
self._training_metrics_history = []
|
||||
|
||||
self._training_metrics_history.append({
|
||||
'step': self._incremental_training_steps,
|
||||
'loss': loss,
|
||||
'accuracy': candle_accuracy,
|
||||
'timeframe': timeframe,
|
||||
'timestamp': timestamp
|
||||
})
|
||||
|
||||
# Keep only last 100 metrics
|
||||
if len(self._training_metrics_history) > 100:
|
||||
self._training_metrics_history.pop(0)
|
||||
|
||||
if self._incremental_training_steps % 10 == 0:
|
||||
logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
|
||||
trainer.save_checkpoint(
|
||||
@@ -3089,8 +3144,17 @@ class AnnotationDashboard:
|
||||
}
|
||||
)
|
||||
|
||||
# Return metrics for display
|
||||
return {
|
||||
'loss': loss,
|
||||
'accuracy': candle_accuracy,
|
||||
'steps': self._incremental_training_steps,
|
||||
'sample_weight': sample_weight
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in incremental training: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
|
||||
"""Fetch market state at a specific timestamp for training"""
|
||||
|
||||
@@ -2800,20 +2800,34 @@ class ChartManager {
|
||||
predicted: prediction.candle, // [O, H, L, C, V]
|
||||
actual: prediction.accuracy.actualCandle, // [O, H, L, C, V]
|
||||
errors: prediction.accuracy.errors, // {open, high, low, close, volume}
|
||||
pctErrors: prediction.accuracy.pctErrors, // {open, high, low, close, volume}
|
||||
directionCorrect: prediction.accuracy.directionCorrect,
|
||||
direction_correct: prediction.accuracy.directionCorrect,
|
||||
accuracy: prediction.accuracy.accuracy
|
||||
};
|
||||
|
||||
console.log('[Prediction Metrics for Training]', metrics);
|
||||
console.log('[Prediction Metrics] Triggering online learning:', metrics);
|
||||
|
||||
// Send to backend via WebSocket for incremental training
|
||||
if (window.socket && window.socket.connected) {
|
||||
window.socket.emit('prediction_accuracy', metrics);
|
||||
console.log(`[${timeframe}] Sent prediction accuracy to backend for training`);
|
||||
} else {
|
||||
console.warn('[Training] WebSocket not connected - metrics not sent to backend');
|
||||
// Send to backend for incremental training (online learning)
|
||||
fetch('/api/train-validated-prediction', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(metrics)
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.success) {
|
||||
console.log(`[${timeframe}] ✓ Online learning triggered - model updated from validated prediction`);
|
||||
|
||||
// Update metrics display if available
|
||||
if (window.updateMetricsDisplay) {
|
||||
window.updateMetricsDisplay(data.metrics);
|
||||
}
|
||||
} else {
|
||||
console.warn(`[${timeframe}] Training failed:`, data.error);
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
console.warn('[Training] Error sending metrics:', error);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user