immediate training improvements
This commit is contained in:
@@ -176,6 +176,17 @@ class CleanTradingDashboard:
|
||||
'BTC/USDT': {}
|
||||
}
|
||||
|
||||
# Confidence calibration tracking
|
||||
self.confidence_calibration: Dict[str, Dict] = {
|
||||
'cob_liquidity_imbalance': {
|
||||
'total_predictions': 0,
|
||||
'correct_predictions': 0,
|
||||
'accuracy_by_confidence': {}, # Track accuracy by confidence ranges
|
||||
'confidence_adjustment': 1.0, # Multiplier for future confidence levels
|
||||
'last_calibration': None
|
||||
}
|
||||
}
|
||||
|
||||
# Initialize timezone
|
||||
timezone_name = self.config.get('system', {}).get('timezone', 'Europe/Sofia')
|
||||
self.timezone = pytz.timezone(timezone_name)
|
||||
@@ -3822,6 +3833,9 @@ class CleanTradingDashboard:
|
||||
if signal['executed']:
|
||||
self._train_all_models_on_signal(signal)
|
||||
|
||||
# Immediate price feedback training (always runs if enabled, regardless of execution)
|
||||
self._immediate_price_feedback_training(signal)
|
||||
|
||||
# Log signal processing
|
||||
status = "EXECUTED" if signal['executed'] else ("BLOCKED" if signal['blocked'] else "PENDING")
|
||||
logger.info(f"[{status}] {signal['action']} signal for {signal['symbol']} "
|
||||
@@ -3830,6 +3844,248 @@ class CleanTradingDashboard:
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing dashboard signal: {e}")
|
||||
|
||||
# immediate price feedback training
|
||||
# ToDo: review/revise
|
||||
def _immediate_price_feedback_training(self, signal: Dict):
|
||||
"""Immediate training fine-tuning based on current price feedback - rewards profitable predictions"""
|
||||
try:
|
||||
# Check if any model training is enabled - immediate training is part of core training
|
||||
training_enabled = (
|
||||
getattr(self, 'dqn_training_enabled', True) or
|
||||
getattr(self, 'cnn_training_enabled', True) or
|
||||
(hasattr(self.orchestrator, 'cob_rl_agent') and self.orchestrator.cob_rl_agent is not None) or
|
||||
(hasattr(self.orchestrator, 'model_manager') and self.orchestrator.model_manager is not None)
|
||||
)
|
||||
|
||||
if not training_enabled:
|
||||
return
|
||||
|
||||
symbol = signal.get('symbol', 'ETH/USDT')
|
||||
signal_price = signal.get('price', 0)
|
||||
predicted_action = signal.get('action', 'HOLD')
|
||||
signal_confidence = signal.get('confidence', 0.5)
|
||||
signal_timestamp = signal.get('timestamp')
|
||||
|
||||
if signal_price == 0 or predicted_action == 'HOLD':
|
||||
return
|
||||
|
||||
# Get current price for immediate feedback
|
||||
current_price = self._get_current_price(symbol)
|
||||
if current_price == 0:
|
||||
return
|
||||
|
||||
# Calculate immediate price movement since signal generation
|
||||
price_change_pct = (current_price - signal_price) / signal_price
|
||||
price_change_abs = abs(price_change_pct)
|
||||
|
||||
# Determine if prediction was correct
|
||||
predicted_direction = 1 if predicted_action == 'BUY' else -1
|
||||
actual_direction = 1 if price_change_pct > 0 else -1
|
||||
prediction_correct = predicted_direction == actual_direction
|
||||
|
||||
# Calculate reward based on prediction accuracy and price movement
|
||||
base_reward = price_change_abs * 1000 # Scale by price movement
|
||||
|
||||
if prediction_correct:
|
||||
# Reward correct predictions
|
||||
reward = base_reward
|
||||
confidence_bonus = signal_confidence * base_reward * 0.5 # Bonus for high confidence correct predictions
|
||||
reward += confidence_bonus
|
||||
else:
|
||||
# Punish incorrect predictions
|
||||
reward = -base_reward
|
||||
confidence_penalty = (1 - signal_confidence) * base_reward * 0.3 # Less penalty for low confidence wrong predictions
|
||||
reward -= confidence_penalty
|
||||
|
||||
# Scale reward by time elapsed (more recent = higher weight)
|
||||
time_elapsed = (datetime.now() - signal_timestamp).total_seconds() if signal_timestamp else 0
|
||||
time_weight = max(0.1, 1.0 - (time_elapsed / 300)) # Decay over 5 minutes
|
||||
final_reward = reward * time_weight
|
||||
|
||||
# Create immediate training data
|
||||
training_data = {
|
||||
'symbol': symbol,
|
||||
'signal_price': signal_price,
|
||||
'current_price': current_price,
|
||||
'price_change_pct': price_change_pct,
|
||||
'predicted_action': predicted_action,
|
||||
'actual_direction': 'UP' if actual_direction > 0 else 'DOWN',
|
||||
'prediction_correct': prediction_correct,
|
||||
'signal_confidence': signal_confidence,
|
||||
'reward': final_reward,
|
||||
'time_elapsed': time_elapsed,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
# Train models immediately with price feedback
|
||||
self._train_models_on_immediate_feedback(signal, training_data, final_reward)
|
||||
|
||||
# Update confidence calibration
|
||||
self._update_confidence_calibration(signal, prediction_correct, price_change_abs)
|
||||
|
||||
logger.debug(f"💰 IMMEDIATE TRAINING: {symbol} {predicted_action} signal - "
|
||||
f"Price: {signal_price:.2f} → {current_price:.2f} ({price_change_pct:+.2%}) - "
|
||||
f"{'✅' if prediction_correct else '❌'} Correct - Reward: {final_reward:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in immediate price feedback training: {e}")
|
||||
|
||||
def _train_models_on_immediate_feedback(self, signal: Dict, training_data: Dict, reward: float):
|
||||
"""Train models immediately on price feedback"""
|
||||
try:
|
||||
symbol = signal.get('symbol', 'ETH/USDT')
|
||||
action = 0 if signal.get('action') == 'BUY' else 1
|
||||
|
||||
# Train COB RL model immediately if COB RL training is enabled
|
||||
if (self.orchestrator and hasattr(self.orchestrator, 'cob_rl_agent') and
|
||||
self.orchestrator.cob_rl_agent and hasattr(self.orchestrator, 'model_manager')):
|
||||
try:
|
||||
# Get COB features for immediate training
|
||||
cob_features = self._get_cob_features_for_training(symbol, signal.get('price', 0))
|
||||
if cob_features:
|
||||
# Store immediate experience
|
||||
if hasattr(self.orchestrator.cob_rl_agent, 'remember'):
|
||||
self.orchestrator.cob_rl_agent.remember(
|
||||
cob_features, action, reward, cob_features, done=False # Not done for immediate feedback
|
||||
)
|
||||
|
||||
# Immediate training if enough samples
|
||||
if hasattr(self.orchestrator.cob_rl_agent, 'memory') and len(self.orchestrator.cob_rl_agent.memory) > 16:
|
||||
if hasattr(self.orchestrator.cob_rl_agent, 'replay'):
|
||||
loss = self.orchestrator.cob_rl_agent.replay(batch_size=8) # Smaller batch for immediate training
|
||||
if loss is not None:
|
||||
logger.debug(f"COB RL immediate training - loss: {loss:.4f}, reward: {reward:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error training COB RL on immediate feedback: {e}")
|
||||
|
||||
# Train DQN model immediately if DQN training is enabled
|
||||
if (self.orchestrator and hasattr(self.orchestrator, 'rl_agent') and
|
||||
self.orchestrator.rl_agent and getattr(self, 'dqn_training_enabled', True)):
|
||||
try:
|
||||
# Create immediate DQN experience
|
||||
state = self._get_rl_state_for_training(symbol, signal.get('price', 0))
|
||||
if state:
|
||||
if hasattr(self.orchestrator.rl_agent, 'remember'):
|
||||
self.orchestrator.rl_agent.remember(state, action, reward, state, done=False)
|
||||
|
||||
# Immediate training
|
||||
if hasattr(self.orchestrator.rl_agent, 'replay') and hasattr(self.orchestrator.rl_agent, 'memory'):
|
||||
if len(self.orchestrator.rl_agent.memory) > 16:
|
||||
loss = self.orchestrator.rl_agent.replay(batch_size=8)
|
||||
if loss is not None:
|
||||
logger.debug(f"DQN immediate training - loss: {loss:.4f}, reward: {reward:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error training DQN on immediate feedback: {e}")
|
||||
|
||||
# Train CNN model immediately if CNN training is enabled
|
||||
if (self.orchestrator and hasattr(self.orchestrator, 'cnn_model') and
|
||||
self.orchestrator.cnn_model and getattr(self, 'cnn_training_enabled', True)):
|
||||
try:
|
||||
# Create immediate CNN training data
|
||||
cnn_features = self._create_cnn_cob_features(symbol, {
|
||||
'current_snapshot': {'price': signal.get('price', 0), 'imbalance': 0},
|
||||
'history': self.cob_data_history.get(symbol, [])[-10:],
|
||||
'timestamp': datetime.now()
|
||||
})
|
||||
|
||||
if cnn_features:
|
||||
# For CNN, we can update internal training data or use model-specific training
|
||||
if hasattr(self.orchestrator.cnn_model, 'update_training_data'):
|
||||
self.orchestrator.cnn_model.update_training_data(cnn_features, action, reward)
|
||||
|
||||
logger.debug(f"CNN immediate training data updated - action: {action}, reward: {reward:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error training CNN on immediate feedback: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in immediate model training: {e}")
|
||||
|
||||
def _update_confidence_calibration(self, signal: Dict, prediction_correct: bool, price_change_abs: float):
|
||||
"""Update confidence calibration based on prediction accuracy"""
|
||||
try:
|
||||
signal_type = signal.get('type', 'unknown')
|
||||
signal_confidence = signal.get('confidence', 0.5)
|
||||
|
||||
if signal_type not in self.confidence_calibration:
|
||||
return
|
||||
|
||||
calibration = self.confidence_calibration[signal_type]
|
||||
|
||||
# Track total predictions and accuracy
|
||||
calibration['total_predictions'] += 1
|
||||
if prediction_correct:
|
||||
calibration['correct_predictions'] += 1
|
||||
|
||||
# Track accuracy by confidence ranges
|
||||
confidence_range = f"{int(signal_confidence * 10) / 10:.1f}" # 0.0-1.0 in 0.1 increments
|
||||
|
||||
if confidence_range not in calibration['accuracy_by_confidence']:
|
||||
calibration['accuracy_by_confidence'][confidence_range] = {
|
||||
'total': 0,
|
||||
'correct': 0,
|
||||
'avg_price_change': 0.0
|
||||
}
|
||||
|
||||
range_stats = calibration['accuracy_by_confidence'][confidence_range]
|
||||
range_stats['total'] += 1
|
||||
if prediction_correct:
|
||||
range_stats['correct'] += 1
|
||||
range_stats['avg_price_change'] = (
|
||||
(range_stats['avg_price_change'] * (range_stats['total'] - 1)) + price_change_abs
|
||||
) / range_stats['total']
|
||||
|
||||
# Update confidence adjustment every 50 predictions
|
||||
if calibration['total_predictions'] % 50 == 0:
|
||||
self._recalibrate_confidence_levels(signal_type)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error updating confidence calibration: {e}")
|
||||
|
||||
def _recalibrate_confidence_levels(self, signal_type: str):
|
||||
"""Recalibrate confidence levels based on historical performance"""
|
||||
try:
|
||||
calibration = self.confidence_calibration[signal_type]
|
||||
accuracy_by_confidence = calibration['accuracy_by_confidence']
|
||||
|
||||
# Calculate expected vs actual accuracy for each confidence range
|
||||
total_adjustment = 0.0
|
||||
valid_ranges = 0
|
||||
|
||||
for conf_range, stats in accuracy_by_confidence.items():
|
||||
if stats['total'] >= 5: # Need at least 5 predictions for reliable calibration
|
||||
expected_accuracy = float(conf_range) # Confidence should match accuracy
|
||||
actual_accuracy = stats['correct'] / stats['total']
|
||||
adjustment = actual_accuracy / expected_accuracy if expected_accuracy > 0 else 1.0
|
||||
total_adjustment += adjustment
|
||||
valid_ranges += 1
|
||||
|
||||
if valid_ranges > 0:
|
||||
calibration['confidence_adjustment'] = total_adjustment / valid_ranges
|
||||
calibration['last_calibration'] = datetime.now()
|
||||
|
||||
logger.info(f"🔧 CONFIDENCE CALIBRATION: {signal_type} adjustment = {calibration['confidence_adjustment']:.3f} "
|
||||
f"(based on {valid_ranges} confidence ranges)")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error recalibrating confidence levels: {e}")
|
||||
|
||||
def _get_calibrated_confidence(self, signal_type: str, raw_confidence: float) -> float:
|
||||
"""Get calibrated confidence level based on historical performance"""
|
||||
try:
|
||||
if signal_type in self.confidence_calibration:
|
||||
adjustment = self.confidence_calibration[signal_type]['confidence_adjustment']
|
||||
calibrated = raw_confidence * adjustment
|
||||
return max(0.0, min(1.0, calibrated)) # Clamp to [0,1]
|
||||
return raw_confidence
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting calibrated confidence: {e}")
|
||||
return raw_confidence
|
||||
|
||||
# This function is used to train all models on a signal
|
||||
# ToDo: review this function and make sure it is correct
|
||||
def _train_all_models_on_signal(self, signal: Dict):
|
||||
"""Train ALL models on executed trade signal - Comprehensive training system"""
|
||||
try:
|
||||
@@ -5311,7 +5567,10 @@ class CleanTradingDashboard:
|
||||
# Generate signal if imbalance exceeds threshold
|
||||
if abs_imbalance > threshold:
|
||||
# Calculate more realistic confidence (never exactly 1.0)
|
||||
final_confidence = min(0.95, base_confidence + confidence_boost)
|
||||
raw_confidence = min(0.95, base_confidence + confidence_boost)
|
||||
|
||||
# Apply confidence calibration based on historical performance
|
||||
final_confidence = self._get_calibrated_confidence('cob_liquidity_imbalance', raw_confidence)
|
||||
|
||||
signal = {
|
||||
'timestamp': datetime.now(),
|
||||
@@ -5354,6 +5613,7 @@ class CleanTradingDashboard:
|
||||
'history': self.cob_data_history[symbol][-15:], # Last 15 seconds
|
||||
'bucketed_data': self.cob_bucketed_data[symbol],
|
||||
'cumulative_imbalance': cumulative_imbalance,
|
||||
'cob_imbalance_ma': self.cob_imbalance_ma.get(symbol, {}), # ✅ ADD MOVING AVERAGES
|
||||
'timestamp': cob_snapshot['timestamp'],
|
||||
'stats': cob_snapshot.get('stats', {}),
|
||||
'bids': cob_snapshot.get('bids', []),
|
||||
|
@@ -374,18 +374,16 @@ class DashboardComponentManager:
|
||||
html.Div(imbalance_stats_display),
|
||||
|
||||
# COB Imbalance Moving Averages
|
||||
ma_display = []
|
||||
if imbalance_ma_data:
|
||||
ma_display.append(html.H6("Imbalance MAs", className="mt-3 mb-2 small text-muted text-uppercase"))
|
||||
for timeframe, ma_value in imbalance_ma_data.items():
|
||||
ma_color = "text-success" if ma_value > 0 else "text-danger"
|
||||
ma_text = f"MA {timeframe}: {ma_value:.3f}"
|
||||
ma_display.append(html.Div([
|
||||
html.Div([
|
||||
html.H6("Imbalance MAs", className="mt-3 mb-2 small text-muted text-uppercase"),
|
||||
*[
|
||||
html.Div([
|
||||
html.Strong(f"{timeframe}: ", className="small"),
|
||||
html.Span(ma_text, className=f"small {ma_color}")
|
||||
], className="mb-1"))
|
||||
|
||||
html.Div(ma_display),
|
||||
html.Span(f"MA {timeframe}: {ma_value:.3f}", className=f"small {'text-success' if ma_value > 0 else 'text-danger'}")
|
||||
], className="mb-1")
|
||||
for timeframe, ma_value in (imbalance_ma_data or {}).items()
|
||||
]
|
||||
]) if imbalance_ma_data else html.Div(),
|
||||
|
||||
html.Hr(className="my-2"),
|
||||
|
||||
|
Reference in New Issue
Block a user