training on predictions - WIP?

Dobromir Popov
2025-12-10 01:24:05 +02:00
parent e84eed2839
commit 9c59b3e0c6
2 changed files with 92 additions and 14 deletions

View File

@@ -2791,6 +2791,41 @@ class AnnotationDashboard:
                }
            })
        @self.server.route('/api/train-validated-prediction', methods=['POST'])
        def train_validated_prediction():
            """Train model on a validated prediction (online learning)"""
            try:
                data = request.get_json()
                timeframe = data.get('timeframe')
                timestamp = data.get('timestamp')
                predicted = data.get('predicted')
                actual = data.get('actual')
                errors = data.get('errors')
                direction_correct = data.get('direction_correct')
                accuracy = data.get('accuracy')

                logger.info(f"[ONLINE LEARNING] Received validation for {timeframe}: accuracy={accuracy:.1f}%, direction={'✓' if direction_correct else '✗'}")

                # Trigger training and get metrics
                metrics = self._train_on_validated_prediction(
                    timeframe, timestamp, predicted, actual,
                    errors, direction_correct, accuracy
                )

                return jsonify({
                    'success': True,
                    'message': 'Training triggered',
                    'metrics': metrics or {}
                })
            except Exception as e:
                logger.error(f"Error training on validated prediction: {e}", exc_info=True)
                return jsonify({
                    'success': False,
                    'error': str(e)
                }), 500
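For reference, the request body this new route expects can be exercised from a short script. A minimal sketch follows; the host, port, and all example values are assumptions for illustration, not taken from this commit:

import requests  # third-party HTTP client, assumed available for this sketch

# Hypothetical example values -- field names mirror what train_validated_prediction() reads.
payload = {
    "timeframe": "1m",
    "timestamp": "2025-12-10T01:24:00Z",
    "predicted": [100.0, 101.0, 99.5, 100.5, 1200.0],  # [O, H, L, C, V]
    "actual": [100.2, 100.9, 99.6, 100.1, 1150.0],     # [O, H, L, C, V]
    "errors": {"open": 0.2, "high": 0.1, "low": 0.1, "close": 0.4, "volume": 50.0},
    "direction_correct": True,
    "accuracy": 87.5,
}

# Host and port are assumptions; use whatever address the annotation dashboard binds to.
resp = requests.post("http://localhost:8050/api/train-validated-prediction", json=payload)
print(resp.json())  # expected shape: {'success': True, 'message': 'Training triggered', 'metrics': {...}}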
        @self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
        def train_manual():
            """Manually trigger training on current candle with specified action"""
@@ -2998,21 +3033,24 @@ class AnnotationDashboard:
        This implements online learning where each validated prediction becomes
        a training sample, with loss weighting based on prediction accuracy.

        Returns:
            Dict with training metrics (loss, accuracy, steps)
        """
        try:
            if not self.training_adapter:
                logger.warning("Training adapter not available for incremental training")
                return None

            if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
                logger.warning("Transformer model not available for incremental training")
                return None

            # Get the transformer trainer
            trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
            if not trainer:
                logger.warning("Transformer trainer not available")
                return None

            # Calculate sample weight based on accuracy
            # Low accuracy predictions get higher weight (we need to learn from mistakes)
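The weight computation itself sits outside this hunk. As a rough illustration of the idea in the comment above only, not the commit's actual formula, an accuracy-to-weight mapping could look like:

# Illustrative sketch only -- the commit's real weighting code is not shown in this hunk.
# Maps accuracy (0-100%) to a sample weight in [1.0, 2.0]: a fully correct prediction
# gets weight 1.0, a fully wrong one gets weight 2.0, so mistakes are emphasized.
def accuracy_to_sample_weight(accuracy: float) -> float:
    accuracy = max(0.0, min(100.0, accuracy))  # clamp bad input
    return 1.0 + (1.0 - accuracy / 100.0)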
@@ -3062,6 +3100,7 @@ class AnnotationDashboard:
                return

            # Train on this batch with sample weighting
            import torch
            with torch.enable_grad():
                trainer.model.train()
                result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
@@ -3070,7 +3109,7 @@ class AnnotationDashboard:
            loss = result.get('total_loss', 0)
            candle_accuracy = result.get('candle_accuracy', 0)

            logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")

            # Save checkpoint periodically (every 10 incremental steps)
            if not hasattr(self, '_incremental_training_steps'):
@@ -3078,6 +3117,22 @@ class AnnotationDashboard:
            self._incremental_training_steps += 1

            # Track metrics for display
            if not hasattr(self, '_training_metrics_history'):
                self._training_metrics_history = []
            self._training_metrics_history.append({
                'step': self._incremental_training_steps,
                'loss': loss,
                'accuracy': candle_accuracy,
                'timeframe': timeframe,
                'timestamp': timestamp
            })
            # Keep only last 100 metrics
            if len(self._training_metrics_history) > 100:
                self._training_metrics_history.pop(0)

            if self._incremental_training_steps % 10 == 0:
                logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
                trainer.save_checkpoint(
@@ -3088,9 +3143,18 @@ class AnnotationDashboard:
                        'last_accuracy': accuracy
                    }
                )

            # Return metrics for display
            return {
                'loss': loss,
                'accuracy': candle_accuracy,
                'steps': self._incremental_training_steps,
                'sample_weight': sample_weight
            }

        except Exception as e:
            logger.error(f"Error in incremental training: {e}", exc_info=True)
            return None

    def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
        """Fetch market state at a specific timestamp for training"""

View File

@@ -2800,20 +2800,34 @@ class ChartManager {
            predicted: prediction.candle, // [O, H, L, C, V]
            actual: prediction.accuracy.actualCandle, // [O, H, L, C, V]
            errors: prediction.accuracy.errors, // {open, high, low, close, volume}
            direction_correct: prediction.accuracy.directionCorrect,
            accuracy: prediction.accuracy.accuracy
        };

        console.log('[Prediction Metrics] Triggering online learning:', metrics);

        // Send to backend for incremental training (online learning)
        fetch('/api/train-validated-prediction', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(metrics)
        })
        .then(response => response.json())
        .then(data => {
            if (data.success) {
                console.log(`[${timeframe}] ✓ Online learning triggered - model updated from validated prediction`);
                // Update metrics display if available
                if (window.updateMetricsDisplay) {
                    window.updateMetricsDisplay(data.metrics);
                }
            } else {
                console.warn(`[${timeframe}] Training failed:`, data.error);
            }
        })
        .catch(error => {
            console.warn('[Training] Error sending metrics:', error);
        });
    }

    /**