training on predictins - WIP?
This commit is contained in:
@@ -2791,6 +2791,41 @@ class AnnotationDashboard:
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@self.server.route('/api/train-validated-prediction', methods=['POST'])
|
||||||
|
def train_validated_prediction():
|
||||||
|
"""Train model on a validated prediction (online learning)"""
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
|
||||||
|
timeframe = data.get('timeframe')
|
||||||
|
timestamp = data.get('timestamp')
|
||||||
|
predicted = data.get('predicted')
|
||||||
|
actual = data.get('actual')
|
||||||
|
errors = data.get('errors')
|
||||||
|
direction_correct = data.get('direction_correct')
|
||||||
|
accuracy = data.get('accuracy')
|
||||||
|
|
||||||
|
logger.info(f"[ONLINE LEARNING] Received validation for {timeframe}: accuracy={accuracy:.1f}%, direction={'✓' if direction_correct else '✗'}")
|
||||||
|
|
||||||
|
# Trigger training and get metrics
|
||||||
|
metrics = self._train_on_validated_prediction(
|
||||||
|
timeframe, timestamp, predicted, actual,
|
||||||
|
errors, direction_correct, accuracy
|
||||||
|
)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'message': 'Training triggered',
|
||||||
|
'metrics': metrics or {}
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error training on validated prediction: {e}", exc_info=True)
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}), 500
|
||||||
|
|
||||||
@self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
|
@self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
|
||||||
def train_manual():
|
def train_manual():
|
||||||
"""Manually trigger training on current candle with specified action"""
|
"""Manually trigger training on current candle with specified action"""
|
||||||
@@ -2998,21 +3033,24 @@ class AnnotationDashboard:
|
|||||||
|
|
||||||
This implements online learning where each validated prediction becomes
|
This implements online learning where each validated prediction becomes
|
||||||
a training sample, with loss weighting based on prediction accuracy.
|
a training sample, with loss weighting based on prediction accuracy.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with training metrics (loss, accuracy, steps)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if not self.training_adapter:
|
if not self.training_adapter:
|
||||||
logger.warning("Training adapter not available for incremental training")
|
logger.warning("Training adapter not available for incremental training")
|
||||||
return
|
return None
|
||||||
|
|
||||||
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
|
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
|
||||||
logger.warning("Transformer model not available for incremental training")
|
logger.warning("Transformer model not available for incremental training")
|
||||||
return
|
return None
|
||||||
|
|
||||||
# Get the transformer trainer
|
# Get the transformer trainer
|
||||||
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
|
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
|
||||||
if not trainer:
|
if not trainer:
|
||||||
logger.warning("Transformer trainer not available")
|
logger.warning("Transformer trainer not available")
|
||||||
return
|
return None
|
||||||
|
|
||||||
# Calculate sample weight based on accuracy
|
# Calculate sample weight based on accuracy
|
||||||
# Low accuracy predictions get higher weight (we need to learn from mistakes)
|
# Low accuracy predictions get higher weight (we need to learn from mistakes)
|
||||||
@@ -3062,6 +3100,7 @@ class AnnotationDashboard:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Train on this batch with sample weighting
|
# Train on this batch with sample weighting
|
||||||
|
import torch
|
||||||
with torch.enable_grad():
|
with torch.enable_grad():
|
||||||
trainer.model.train()
|
trainer.model.train()
|
||||||
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
|
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
|
||||||
@@ -3070,7 +3109,7 @@ class AnnotationDashboard:
|
|||||||
loss = result.get('total_loss', 0)
|
loss = result.get('total_loss', 0)
|
||||||
candle_accuracy = result.get('candle_accuracy', 0)
|
candle_accuracy = result.get('candle_accuracy', 0)
|
||||||
|
|
||||||
logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
|
logger.info(f"[{timeframe}] ✓ Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
|
||||||
|
|
||||||
# Save checkpoint periodically (every 10 incremental steps)
|
# Save checkpoint periodically (every 10 incremental steps)
|
||||||
if not hasattr(self, '_incremental_training_steps'):
|
if not hasattr(self, '_incremental_training_steps'):
|
||||||
@@ -3078,6 +3117,22 @@ class AnnotationDashboard:
|
|||||||
|
|
||||||
self._incremental_training_steps += 1
|
self._incremental_training_steps += 1
|
||||||
|
|
||||||
|
# Track metrics for display
|
||||||
|
if not hasattr(self, '_training_metrics_history'):
|
||||||
|
self._training_metrics_history = []
|
||||||
|
|
||||||
|
self._training_metrics_history.append({
|
||||||
|
'step': self._incremental_training_steps,
|
||||||
|
'loss': loss,
|
||||||
|
'accuracy': candle_accuracy,
|
||||||
|
'timeframe': timeframe,
|
||||||
|
'timestamp': timestamp
|
||||||
|
})
|
||||||
|
|
||||||
|
# Keep only last 100 metrics
|
||||||
|
if len(self._training_metrics_history) > 100:
|
||||||
|
self._training_metrics_history.pop(0)
|
||||||
|
|
||||||
if self._incremental_training_steps % 10 == 0:
|
if self._incremental_training_steps % 10 == 0:
|
||||||
logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
|
logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
|
||||||
trainer.save_checkpoint(
|
trainer.save_checkpoint(
|
||||||
@@ -3088,9 +3143,18 @@ class AnnotationDashboard:
|
|||||||
'last_accuracy': accuracy
|
'last_accuracy': accuracy
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Return metrics for display
|
||||||
|
return {
|
||||||
|
'loss': loss,
|
||||||
|
'accuracy': candle_accuracy,
|
||||||
|
'steps': self._incremental_training_steps,
|
||||||
|
'sample_weight': sample_weight
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in incremental training: {e}", exc_info=True)
|
logger.error(f"Error in incremental training: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
|
def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
|
||||||
"""Fetch market state at a specific timestamp for training"""
|
"""Fetch market state at a specific timestamp for training"""
|
||||||
|
|||||||
@@ -2800,20 +2800,34 @@ class ChartManager {
|
|||||||
predicted: prediction.candle, // [O, H, L, C, V]
|
predicted: prediction.candle, // [O, H, L, C, V]
|
||||||
actual: prediction.accuracy.actualCandle, // [O, H, L, C, V]
|
actual: prediction.accuracy.actualCandle, // [O, H, L, C, V]
|
||||||
errors: prediction.accuracy.errors, // {open, high, low, close, volume}
|
errors: prediction.accuracy.errors, // {open, high, low, close, volume}
|
||||||
pctErrors: prediction.accuracy.pctErrors, // {open, high, low, close, volume}
|
direction_correct: prediction.accuracy.directionCorrect,
|
||||||
directionCorrect: prediction.accuracy.directionCorrect,
|
|
||||||
accuracy: prediction.accuracy.accuracy
|
accuracy: prediction.accuracy.accuracy
|
||||||
};
|
};
|
||||||
|
|
||||||
console.log('[Prediction Metrics for Training]', metrics);
|
console.log('[Prediction Metrics] Triggering online learning:', metrics);
|
||||||
|
|
||||||
// Send to backend via WebSocket for incremental training
|
// Send to backend for incremental training (online learning)
|
||||||
if (window.socket && window.socket.connected) {
|
fetch('/api/train-validated-prediction', {
|
||||||
window.socket.emit('prediction_accuracy', metrics);
|
method: 'POST',
|
||||||
console.log(`[${timeframe}] Sent prediction accuracy to backend for training`);
|
headers: { 'Content-Type': 'application/json' },
|
||||||
} else {
|
body: JSON.stringify(metrics)
|
||||||
console.warn('[Training] WebSocket not connected - metrics not sent to backend');
|
})
|
||||||
}
|
.then(response => response.json())
|
||||||
|
.then(data => {
|
||||||
|
if (data.success) {
|
||||||
|
console.log(`[${timeframe}] ✓ Online learning triggered - model updated from validated prediction`);
|
||||||
|
|
||||||
|
// Update metrics display if available
|
||||||
|
if (window.updateMetricsDisplay) {
|
||||||
|
window.updateMetricsDisplay(data.metrics);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.warn(`[${timeframe}] Training failed:`, data.error);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch(error => {
|
||||||
|
console.warn('[Training] Error sending metrics:', error);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
Reference in New Issue
Block a user