prediction database

This commit is contained in:
Dobromir Popov
2025-09-02 19:25:42 +03:00
parent 226a6aa047
commit fe6763c4ba
5 changed files with 523 additions and 8 deletions

View File

@@ -9,6 +9,7 @@ This system implements effective online learning with:
- Continuous validation and adaptation
- Multi-timeframe feature engineering
- Real market microstructure analysis
- PREDICTION TRACKING: Store each prediction and track outcomes
"""
import numpy as np
@@ -26,16 +27,26 @@ import torch
import torch.nn as nn
import torch.optim as optim
# Import prediction tracking
from core.prediction_database import get_prediction_db
logger = logging.getLogger(__name__)
class EnhancedRealtimeTrainingSystem:
"""Enhanced real-time training system with proper online learning"""
"""Enhanced real-time training system with prediction tracking and database storage"""
def __init__(self, orchestrator, data_provider, dashboard=None):
self.orchestrator = orchestrator
self.data_provider = data_provider
self.dashboard = dashboard
# Prediction tracking database
self.prediction_db = get_prediction_db()
# Active predictions waiting for resolution
self.active_predictions = {} # {prediction_id: {"timestamp": ..., "price": ..., "model": ...}}
self.prediction_resolution_time = 300 # 5 minutes to resolve predictions
# Training configuration
self.training_config = {
'dqn_training_interval': 5, # Train DQN every 5 seconds
@@ -162,13 +173,185 @@ class EnhancedRealtimeTrainingSystem:
validation_thread = threading.Thread(target=self._validation_worker, daemon=True)
validation_thread.start()
logger.info("Enhanced real-time training system started")
# Start prediction resolution worker
prediction_thread = threading.Thread(target=self._prediction_resolution_worker, daemon=True)
prediction_thread.start()
logger.info("Enhanced real-time training system started with prediction tracking")
def stop_training(self):
    """Request shutdown by clearing the ``is_training`` flag."""
    # The worker loops are written as ``while self.is_training`` and
    # therefore exit once they next re-check the flag.
    self.is_training = False
    logger.info("Enhanced real-time training system stopped")
def store_model_prediction(self, model_name: str, symbol: str, prediction_type: str,
                           confidence: float, current_price: float) -> int:
    """Persist a model prediction and register it for later resolution.

    The prediction is written to the prediction database and a matching
    entry (stamped with the current time) is kept in
    ``active_predictions`` under the id the database returned.

    Returns:
        The id returned by the database, or ``-1`` if anything failed.
    """
    try:
        new_id = self.prediction_db.store_prediction(
            model_name=model_name,
            symbol=symbol,
            prediction_type=prediction_type,
            confidence=confidence,
            price_at_prediction=current_price,
        )

        # Remember what was predicted, and when, so the outcome can be
        # scored once the prediction is old enough.
        self.active_predictions[new_id] = dict(
            model_name=model_name,
            symbol=symbol,
            prediction_type=prediction_type,
            confidence=confidence,
            timestamp=time.time(),
            price_at_prediction=current_price,
        )

        logger.info(f"Stored prediction {new_id}: {model_name} -> {prediction_type} for {symbol} (conf: {confidence:.3f})")
        return new_id
    except Exception as err:
        logger.error(f"Error storing prediction: {err}")
        return -1
def resolve_predictions(self):
    """Resolve every active prediction that has reached its resolution age.

    A prediction becomes eligible once ``prediction_resolution_time``
    seconds have passed since it was stored. For each eligible prediction
    the current price is fetched, the fractional price change since the
    prediction is computed, a reward is derived from it and the outcome is
    written to the prediction database.

    A prediction stays in ``active_predictions`` (and is retried on the
    next pass) when no positive current price is available or the database
    reports failure. A prediction whose stored ``price_at_prediction`` is
    missing or not positive can never be scored and is dropped.

    Each prediction is processed in isolation so that a single failing
    entry cannot abort the pass, block the entries after it, or discard
    the count of predictions that were already resolved.

    Returns:
        int: Number of predictions resolved during this pass.
    """
    try:
        now = time.time()
        # Snapshot the items: entries are removed during the loop and the
        # dict may be modified elsewhere while this pass is running.
        pending = list(self.active_predictions.items())
    except Exception as e:
        logger.error(f"Error resolving predictions: {e}")
        return 0

    resolved_count = 0
    for prediction_id, pred_data in pending:
        try:
            if now - pred_data["timestamp"] < self.prediction_resolution_time:
                continue  # not old enough to resolve yet

            entry_price = pred_data["price_at_prediction"]
            # ``not entry_price > 0`` also rejects NaN.
            if entry_price is None or not entry_price > 0:
                logger.warning(
                    f"Dropping prediction {prediction_id}: "
                    f"invalid price_at_prediction {entry_price!r}"
                )
                self.active_predictions.pop(prediction_id, None)
                continue

            current_price = self._get_current_price(pred_data["symbol"])
            if current_price <= 0:
                continue  # price unavailable; retry on the next pass

            # Fractional change (0.01 == +1%).
            price_change_pct = (current_price - entry_price) / entry_price

            reward = self._calculate_prediction_reward(
                pred_data["prediction_type"],
                price_change_pct,
                pred_data["confidence"]
            )

            success = self.prediction_db.resolve_prediction(
                prediction_id=prediction_id,
                actual_price_change=price_change_pct,
                reward=reward
            )
            if not success:
                continue  # keep it active so it is retried

            # ``.3%`` scales the fraction by 100 before appending '%'.
            logger.info(f"Resolved prediction {prediction_id}: {pred_data['model_name']} -> "
                        f"price change {price_change_pct:.3%}, reward {reward:.3f}")

            # pop() instead of del: tolerate the entry having been
            # removed elsewhere in the meantime.
            self.active_predictions.pop(prediction_id, None)
            resolved_count += 1
        except Exception as e:
            logger.error(f"Error resolving prediction {prediction_id}: {e}")

    return resolved_count
def _get_current_price(self, symbol: str) -> float:
    """Return the latest known price for ``symbol``, or ``0.0`` if unavailable.

    Sources are tried in order, each independently of the others:

    1. ``data_provider.get_latest_data(symbol)['close']``
    2. ``orchestrator._get_current_price(symbol)``

    A source that is absent, raises, returns ``None`` or yields a value
    that is not a positive finite number is skipped so the next source
    still gets a chance. When no source produces a usable price the method
    returns ``0.0``. It deliberately does not substitute a hard-coded
    price: ``resolve_predictions`` only scores a prediction when the price
    is positive, and a made-up constant would be scored as if it were a
    real market price.
    """

    def from_data_provider():
        provider = self.data_provider
        if provider and hasattr(provider, 'get_latest_data'):
            latest = provider.get_latest_data(symbol)
            # Explicit None check instead of truthiness: some container
            # types (e.g. array-likes) raise when evaluated as a bool.
            if latest is not None and 'close' in latest:
                return latest['close']
        return None

    def from_orchestrator():
        orchestrator = self.orchestrator
        if orchestrator and hasattr(orchestrator, '_get_current_price'):
            return orchestrator._get_current_price(symbol)
        return None

    for fetch in (from_data_provider, from_orchestrator):
        try:
            raw = fetch()
            if raw is None:
                continue
            price = float(raw)
        except Exception as e:
            logger.debug(f"Error getting current price for {symbol}: {e}")
            continue
        # Rejects zero, negatives, NaN and infinity in one comparison.
        if 0 < price < float('inf'):
            return price

    return 0.0
def _calculate_prediction_reward(self, prediction_type: str, price_change_pct: float, confidence: float) -> float:
    """Score a prediction against the price move that actually happened.

    Args:
        prediction_type: ``"BUY"``, ``"SELL"`` or ``"HOLD"``. Any other
            value scores ``0.0``.
        price_change_pct: Fractional price change since the prediction
            (``0.01`` means +1%); it is scaled by 100 internally.
        confidence: Model confidence. Clamped to ``[0, 1]`` so the
            multiplier stays within its intended 0.5 to 2.0 range.

    Returns:
        float: Reward limited to ``[-10.0, 10.0]``. ``0.0`` is returned
        when an input is not a finite number.
    """
    import math  # function-scope import keeps this block self-contained

    try:
        change = float(price_change_pct)
        conf = float(confidence)
    except (TypeError, ValueError) as e:
        logger.error(f"Error calculating prediction reward: {e}")
        return 0.0

    if not (math.isfinite(change) and math.isfinite(conf)):
        # NaN is not ordered, so the min()/max() clamp below would turn a
        # NaN reward into +10; treat non-finite input as neutral instead.
        logger.warning(
            f"Non-finite input for prediction reward: "
            f"price_change_pct={price_change_pct!r}, confidence={confidence!r}"
        )
        return 0.0

    move = change * 100  # price move in percent

    if prediction_type == "BUY":
        base_reward = move  # positive if price went up
    elif prediction_type == "SELL":
        base_reward = -move  # positive if price went down
    elif prediction_type == "HOLD":
        base_reward = max(0.0, 1 - abs(move))  # positive only for small moves
    else:
        base_reward = 0.0

    # The multiplier scales wins and losses alike: a confident wrong call
    # is penalised more than a hesitant one.
    conf = max(0.0, min(1.0, conf))
    confidence_multiplier = 0.5 + (conf * 1.5)  # range: 0.5 to 2.0

    final_reward = base_reward * confidence_multiplier

    # Float bounds keep the return type a float even when clamped.
    return max(-10.0, min(10.0, final_reward))
def get_model_performance_stats(self) -> Dict[str, Any]:
    """Return per-model statistics enriched with active-prediction counts.

    The per-model entries come from ``prediction_db.get_all_model_stats()``;
    each entry gets an ``"active_predictions"`` key holding the number of
    not-yet-resolved predictions tracked for that model.

    Returns:
        A dict with the keys ``"models"`` (the enriched entries),
        ``"total_active_predictions"`` and ``"last_updated"`` (ISO-8601
        timestamp of this call), or an empty dict if anything failed.
    """
    from collections import Counter  # function-scope: block stays self-contained

    try:
        stats = self.prediction_db.get_all_model_stats()

        # Work on a snapshot, as resolve_predictions does: entries are
        # added and removed elsewhere, and iterating the live dict while
        # it changes size raises RuntimeError. The snapshot also keeps the
        # per-model counts and the total consistent with each other.
        active = list(self.active_predictions.values())
        active_by_model = Counter(pred_data["model_name"] for pred_data in active)

        for stat in stats:
            # A Counter yields 0 for models with no active predictions.
            stat["active_predictions"] = active_by_model[stat["model_name"]]

        return {
            "models": stats,
            "total_active_predictions": len(active),
            "last_updated": datetime.now().isoformat()
        }
    except Exception as e:
        logger.error(f"Error getting performance stats: {e}")
        return {}
def _prediction_resolution_worker(self):
    """Background loop that periodically resolves matured predictions.

    Runs for as long as ``is_training`` is true. After a normal pass it
    waits 30 seconds; after a failed pass it logs the error and waits
    60 seconds before trying again.
    """
    while self.is_training:
        pause_seconds = 30  # regular polling interval
        try:
            n_resolved = self.resolve_predictions()
            if n_resolved > 0:
                logger.info(f"Resolved {n_resolved} predictions")
        except Exception as err:
            logger.error(f"Error in prediction resolution worker: {err}")
            pause_seconds = 60  # back off after a failure
        time.sleep(pause_seconds)
def _data_collection_worker(self):
"""Collect and preprocess real-time market data"""
while self.is_training: