IMPLEMENTED: WIP; realtime candle predictions training
This commit is contained in:
@@ -2369,6 +2369,55 @@ class AnnotationDashboard:
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling prediction request: {e}")
|
||||
emit('prediction_error', {'error': str(e)})
|
||||
|
||||
@self.socketio.on('prediction_accuracy')
|
||||
def handle_prediction_accuracy(data):
|
||||
"""
|
||||
Handle validated prediction accuracy - trigger incremental training
|
||||
|
||||
This is called when frontend validates a prediction against actual candle.
|
||||
We use this data to incrementally train the model for continuous improvement.
|
||||
"""
|
||||
from flask_socketio import emit
|
||||
try:
|
||||
timeframe = data.get('timeframe')
|
||||
timestamp = data.get('timestamp')
|
||||
predicted = data.get('predicted') # [O, H, L, C, V]
|
||||
actual = data.get('actual') # [O, H, L, C]
|
||||
errors = data.get('errors') # {open, high, low, close}
|
||||
pct_errors = data.get('pctErrors')
|
||||
direction_correct = data.get('directionCorrect')
|
||||
accuracy = data.get('accuracy')
|
||||
|
||||
if not all([timeframe, timestamp, predicted, actual]):
|
||||
logger.warning("Incomplete prediction accuracy data received")
|
||||
return
|
||||
|
||||
logger.info(f"[{timeframe}] Prediction validated: {accuracy:.1f}% accuracy, direction: {direction_correct}")
|
||||
logger.debug(f" Errors: O={pct_errors['open']:.2f}% H={pct_errors['high']:.2f}% L={pct_errors['low']:.2f}% C={pct_errors['close']:.2f}%")
|
||||
|
||||
# Trigger incremental training on this validated prediction
|
||||
self._train_on_validated_prediction(
|
||||
timeframe=timeframe,
|
||||
timestamp=timestamp,
|
||||
predicted=predicted,
|
||||
actual=actual,
|
||||
errors=errors,
|
||||
direction_correct=direction_correct,
|
||||
accuracy=accuracy
|
||||
)
|
||||
|
||||
# Send confirmation back to frontend
|
||||
emit('training_update', {
|
||||
'status': 'training_triggered',
|
||||
'timestamp': timestamp,
|
||||
'accuracy': accuracy,
|
||||
'message': f'Incremental training triggered on validated prediction'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling prediction accuracy: {e}", exc_info=True)
|
||||
emit('training_error', {'error': str(e)})
|
||||
|
||||
def _start_live_update_thread(self):
|
||||
"""Start background thread for live updates"""
|
||||
@@ -2392,24 +2441,44 @@ class AnnotationDashboard:
|
||||
for timeframe in ['1s', '1m']:
|
||||
room = f"{symbol}_{timeframe}"
|
||||
|
||||
# Get latest candle
|
||||
# Get latest candles (need last 2 to determine confirmation status)
|
||||
try:
|
||||
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=1)
|
||||
candles = self.data_provider.get_ohlcv(symbol, timeframe, limit=2)
|
||||
if candles and len(candles) > 0:
|
||||
latest_candle = candles[-1]
|
||||
|
||||
# Emit chart update
|
||||
# Determine if candle is confirmed (closed)
|
||||
# For 1s: candle is confirmed when next candle starts (2s delay)
|
||||
# For others: candle is confirmed when next candle starts
|
||||
is_confirmed = len(candles) >= 2 # If we have 2 candles, the first is confirmed
|
||||
|
||||
# Format timestamp consistently
|
||||
timestamp = latest_candle.get('timestamp')
|
||||
if isinstance(timestamp, str):
|
||||
# Already formatted
|
||||
formatted_timestamp = timestamp
|
||||
else:
|
||||
# Convert to ISO string then format
|
||||
from datetime import datetime
|
||||
if isinstance(timestamp, datetime):
|
||||
formatted_timestamp = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
else:
|
||||
formatted_timestamp = str(timestamp)
|
||||
|
||||
# Emit chart update with full candle data
|
||||
self.socketio.emit('chart_update', {
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'candle': {
|
||||
'timestamp': latest_candle.get('timestamp'),
|
||||
'open': latest_candle.get('open'),
|
||||
'high': latest_candle.get('high'),
|
||||
'low': latest_candle.get('low'),
|
||||
'close': latest_candle.get('close'),
|
||||
'volume': latest_candle.get('volume')
|
||||
}
|
||||
'timestamp': formatted_timestamp,
|
||||
'open': float(latest_candle.get('open', 0)),
|
||||
'high': float(latest_candle.get('high', 0)),
|
||||
'low': float(latest_candle.get('low', 0)),
|
||||
'close': float(latest_candle.get('close', 0)),
|
||||
'volume': float(latest_candle.get('volume', 0))
|
||||
},
|
||||
'is_confirmed': is_confirmed, # True if this candle is closed/confirmed
|
||||
'has_previous': len(candles) >= 2 # True if we have previous candle for validation
|
||||
}, room=room)
|
||||
|
||||
# Get prediction if model is loaded
|
||||
@@ -2430,6 +2499,144 @@ class AnnotationDashboard:
|
||||
self._live_update_thread = threading.Thread(target=live_update_worker, daemon=True)
|
||||
self._live_update_thread.start()
|
||||
|
||||
def _train_on_validated_prediction(self, timeframe: str, timestamp: str, predicted: list,
|
||||
actual: list, errors: dict, direction_correct: bool, accuracy: float):
|
||||
"""
|
||||
Incrementally train model on validated prediction
|
||||
|
||||
This implements online learning where each validated prediction becomes
|
||||
a training sample, with loss weighting based on prediction accuracy.
|
||||
"""
|
||||
try:
|
||||
if not self.training_adapter:
|
||||
logger.warning("Training adapter not available for incremental training")
|
||||
return
|
||||
|
||||
if not self.orchestrator or not hasattr(self.orchestrator, 'primary_transformer'):
|
||||
logger.warning("Transformer model not available for incremental training")
|
||||
return
|
||||
|
||||
# Get the transformer trainer
|
||||
trainer = getattr(self.orchestrator, 'primary_transformer_trainer', None)
|
||||
if not trainer:
|
||||
logger.warning("Transformer trainer not available")
|
||||
return
|
||||
|
||||
# Calculate sample weight based on accuracy
|
||||
# Low accuracy predictions get higher weight (we need to learn from mistakes)
|
||||
# High accuracy predictions get lower weight (model already knows this)
|
||||
if accuracy < 50:
|
||||
sample_weight = 3.0 # Learn hard from bad predictions
|
||||
elif accuracy < 70:
|
||||
sample_weight = 2.0 # Moderate learning
|
||||
elif accuracy < 85:
|
||||
sample_weight = 1.0 # Normal learning
|
||||
else:
|
||||
sample_weight = 0.5 # Light touch-up for good predictions
|
||||
|
||||
# Also weight by direction correctness
|
||||
if not direction_correct:
|
||||
sample_weight *= 1.5 # Wrong direction is critical - learn more
|
||||
|
||||
logger.info(f"[{timeframe}] Incremental training: accuracy={accuracy:.1f}%, weight={sample_weight:.1f}x")
|
||||
|
||||
# Create training sample from validated prediction
|
||||
# We need to fetch the market state at that timestamp
|
||||
symbol = 'ETH/USDT' # TODO: Get from active trading pair
|
||||
|
||||
training_sample = {
|
||||
'symbol': symbol,
|
||||
'timestamp': timestamp,
|
||||
'predicted_candle': predicted, # [O, H, L, C, V]
|
||||
'actual_candle': actual, # [O, H, L, C]
|
||||
'errors': errors,
|
||||
'accuracy': accuracy,
|
||||
'direction_correct': direction_correct,
|
||||
'sample_weight': sample_weight
|
||||
}
|
||||
|
||||
# Get market state at that timestamp
|
||||
try:
|
||||
market_state = self._fetch_market_state_at_timestamp(symbol, timestamp, timeframe)
|
||||
training_sample['market_state'] = market_state
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch market state: {e}")
|
||||
return
|
||||
|
||||
# Convert to transformer batch format
|
||||
batch = self.training_adapter._convert_prediction_to_batch(training_sample, timeframe)
|
||||
if not batch:
|
||||
logger.warning("Could not convert validated prediction to training batch")
|
||||
return
|
||||
|
||||
# Train on this batch with sample weighting
|
||||
with torch.enable_grad():
|
||||
trainer.model.train()
|
||||
result = trainer.train_step(batch, accumulate_gradients=False, sample_weight=sample_weight)
|
||||
|
||||
if result:
|
||||
loss = result.get('total_loss', 0)
|
||||
candle_accuracy = result.get('candle_accuracy', 0)
|
||||
|
||||
logger.info(f"[{timeframe}] Trained on validated prediction: loss={loss:.4f}, new_acc={candle_accuracy:.2%}")
|
||||
|
||||
# Save checkpoint periodically (every 10 incremental steps)
|
||||
if not hasattr(self, '_incremental_training_steps'):
|
||||
self._incremental_training_steps = 0
|
||||
|
||||
self._incremental_training_steps += 1
|
||||
|
||||
if self._incremental_training_steps % 10 == 0:
|
||||
logger.info(f"Saving checkpoint after {self._incremental_training_steps} incremental training steps")
|
||||
trainer.save_checkpoint(
|
||||
filepath=None, # Auto-generate path
|
||||
metadata={
|
||||
'training_type': 'incremental_online',
|
||||
'steps': self._incremental_training_steps,
|
||||
'last_accuracy': accuracy
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in incremental training: {e}", exc_info=True)
|
||||
|
||||
def _fetch_market_state_at_timestamp(self, symbol: str, timestamp: str, timeframe: str) -> Dict:
|
||||
"""Fetch market state at a specific timestamp for training"""
|
||||
try:
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
|
||||
# Parse timestamp
|
||||
ts = pd.Timestamp(timestamp)
|
||||
|
||||
# Get historical data for multiple timeframes
|
||||
market_state = {'timeframes': {}, 'secondary_timeframes': {}}
|
||||
|
||||
for tf in ['1s', '1m', '1h']:
|
||||
try:
|
||||
df = self.data_provider.get_historical_data(symbol, tf, limit=200)
|
||||
if df is not None and not df.empty:
|
||||
# Find data up to (but not including) the target timestamp
|
||||
df_before = df[df.index < ts]
|
||||
if not df_before.empty:
|
||||
recent = df_before.tail(200)
|
||||
market_state['timeframes'][tf] = {
|
||||
'timestamps': recent.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': recent['open'].tolist(),
|
||||
'high': recent['high'].tolist(),
|
||||
'low': recent['low'].tolist(),
|
||||
'close': recent['close'].tolist(),
|
||||
'volume': recent['volume'].tolist()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch {tf} data: {e}")
|
||||
|
||||
return market_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching market state: {e}")
|
||||
return {}
|
||||
|
||||
def _get_live_prediction(self, symbol: str, timeframe: str, prediction_steps: int = 1):
|
||||
"""Get live prediction from model"""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user