training on predictions - WIP?
@@ -2791,6 +2791,41 @@ class AnnotationDashboard:
                 }
             })
 
+        @self.server.route('/api/train-validated-prediction', methods=['POST'])
+        def train_validated_prediction():
+            """Train model on a validated prediction (online learning)"""
+            try:
+                data = request.get_json()
+
+                timeframe = data.get('timeframe')
+                timestamp = data.get('timestamp')
+                predicted = data.get('predicted')
+                actual = data.get('actual')
+                errors = data.get('errors')
+                direction_correct = data.get('direction_correct')
+                accuracy = data.get('accuracy')
+
+                logger.info(f"[ONLINE LEARNING] Received validation for {timeframe}: accuracy={accuracy:.1f}%, direction={'✓' if direction_correct else '✗'}")
+
+                # Trigger training and get metrics
+                metrics = self._train_on_validated_prediction(
+                    timeframe, timestamp, predicted, actual,
+                    errors, direction_correct, accuracy
+                )
+
+                return jsonify({
+                    'success': True,
+                    'message': 'Training triggered',
+                    'metrics': metrics or {}
+                })
+
+            except Exception as e:
+                logger.error(f"Error training on validated prediction: {e}", exc_info=True)
+                return jsonify({
+                    'success': False,
+                    'error': str(e)
+                }), 500
+
         @self.server.route('/api/realtime-inference/train-manual', methods=['POST'])
         def train_manual():
             """Manually trigger training on current candle with specified action"""
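A minimal sketch of how a client might exercise the new endpoint once a prediction has been validated; the port and payload values here are illustrative assumptions, not part of this commit:

```python
import requests

# Hypothetical host/port; point this at wherever AnnotationDashboard serves.
resp = requests.post(
    'http://localhost:8050/api/train-validated-prediction',
    json={
        'timeframe': '1m',
        'timestamp': '2024-01-01T00:00:00Z',
        'predicted': {'close': 101.2},   # shape of predicted/actual is assumed
        'actual': {'close': 100.9},
        'errors': {'close': 0.3},
        'direction_correct': True,
        'accuracy': 87.5,
    },
)
print(resp.json())  # e.g. {'success': True, 'message': 'Training triggered', 'metrics': {...}}
```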
@@ -2998,21 +3033,24 @@ class AnnotationDashboard:
 
         This implements online learning where each validated prediction becomes
         a training sample, with loss weighting based on prediction accuracy.
+
+        Returns:
+            Dict with training metrics (loss, accuracy, steps)
         """
         try:
             if not self.training_adapter:
                 logger.warning("Training adapter not available for incremental training")
-                return
+                return None
 
             if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
                 logger.warning("Transformer model not available for incremental training")
-                return
+                return None
 
             # Get the transformer trainer
             trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
             if not trainer:
                 logger.warning("Transformer trainer not available")
-                return
+                return None
 
             # Calculate sample weight based on accuracy
             # Low accuracy predictions get higher weight (we need to learn from mistakes)
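The actual weight computation sits in context lines elided from this diff; one plausible reading of the two comments above is a monotonically decreasing map from accuracy to weight, sketched here with assumed bounds:

```python
def compute_sample_weight(accuracy: float,
                          min_weight: float = 0.5,
                          max_weight: float = 2.0) -> float:
    """Map prediction accuracy (0-100%) to a loss weight.

    Low-accuracy predictions land near max_weight so the model learns
    more from its mistakes; near-perfect ones contribute only
    min_weight. The linear ramp and the 0.5/2.0 bounds are assumptions
    for illustration, not the commit's actual formula.
    """
    frac = max(0.0, min(accuracy, 100.0)) / 100.0  # clamp to [0, 1]
    return max_weight - (max_weight - min_weight) * frac
```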
@@ -3062,6 +3100,7 @@ class AnnotationDashboard:
                 return
 
             # Train on this batch with sample weighting
             import torch
             with torch.enable_grad():
                 trainer.model.train()
+                result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
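The trainer's internals are not part of this diff; a hypothetical sketch of how `train_step` could honor the new `sample_weight` argument is to scale the mean batch loss before backpropagation (the `model`, `criterion`, and `optimizer` attributes and batch keys are assumed):

```python
def train_step(self, batch, accumulate_gradients=False, sample_weight=1.0):
    """Hypothetical: scale the batch loss by sample_weight before backprop."""
    outputs = self.model(batch['inputs'])
    loss = self.criterion(outputs, batch['targets'])  # mean loss over the batch
    weighted_loss = loss * sample_weight              # emphasize hard samples
    weighted_loss.backward()
    if not accumulate_gradients:
        self.optimizer.step()
        self.optimizer.zero_grad()
    return {'total_loss': weighted_loss.item()}
```

Scaling the loss is equivalent to scaling its gradients, so a weight above 1.0 makes a single validated sample pull the parameters harder than an ordinary batch.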
@@ -3070,7 +3109,7 @@ class AnnotationDashboard:
             loss = result.get('total_loss', 0)
             candle_accuracy = result.get('candle_accuracy', 0)
 
-            logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
+            logger.info(f"[{timeframe}] ✓ Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
 
             # Save checkpoint periodically (every 10 incremental steps)
             if not hasattr(self, '_incremental_training_steps'):
@@ -3078,6 +3117,22 @@ class AnnotationDashboard:
 
             self._incremental_training_steps += 1
 
+            # Track metrics for display
+            if not hasattr(self, '_training_metrics_history'):
+                self._training_metrics_history = []
+
+            self._training_metrics_history.append({
+                'step': self._incremental_training_steps,
+                'loss': loss,
+                'accuracy': candle_accuracy,
+                'timeframe': timeframe,
+                'timestamp': timestamp
+            })
+
+            # Keep only last 100 metrics
+            if len(self._training_metrics_history) > 100:
+                self._training_metrics_history.pop(0)
+
             if self._incremental_training_steps % 10 == 0:
                 logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
                 trainer.save_checkpoint(
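The append/len/pop(0) pattern above works, but a bounded deque expresses the same "keep the last 100" intent more directly; a small standalone sketch:

```python
from collections import deque

# maxlen=100 drops the oldest entry automatically on append,
# replacing the explicit length check and pop(0).
history = deque(maxlen=100)
for step in range(150):
    history.append({'step': step, 'loss': 0.0})
assert len(history) == 100  # only the most recent 100 survive
```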
@@ -3088,9 +3143,18 @@ class AnnotationDashboard:
                         'last_accuracy': accuracy
                     }
                 )
 
+            # Return metrics for display
+            return {
+                'loss': loss,
+                'accuracy': candle_accuracy,
+                'steps': self._incremental_training_steps,
+                'sample_weight': sample_weight
+            }
+
         except Exception as e:
             logger.error(f"Error in incremental training: {e}", exc_info=True)
+            return None
 
     def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
         """Fetch market state at a specific timestamp for training"""