prediction database
This commit is contained in:
@@ -9,6 +9,7 @@ This system implements effective online learning with:
|
||||
- Continuous validation and adaptation
|
||||
- Multi-timeframe feature engineering
|
||||
- Real market microstructure analysis
|
||||
- PREDICTION TRACKING: Store each prediction and track outcomes
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
@@ -26,16 +27,26 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
# Import prediction tracking
|
||||
from core.prediction_database import get_prediction_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EnhancedRealtimeTrainingSystem:
|
||||
"""Enhanced real-time training system with proper online learning"""
|
||||
"""Enhanced real-time training system with prediction tracking and database storage"""
|
||||
|
||||
def __init__(self, orchestrator, data_provider, dashboard=None):
|
||||
self.orchestrator = orchestrator
|
||||
self.data_provider = data_provider
|
||||
self.dashboard = dashboard
|
||||
|
||||
# Prediction tracking database
|
||||
self.prediction_db = get_prediction_db()
|
||||
|
||||
# Active predictions waiting for resolution
|
||||
self.active_predictions = {} # {prediction_id: {"timestamp": ..., "price": ..., "model": ...}}
|
||||
self.prediction_resolution_time = 300 # 5 minutes to resolve predictions
|
||||
|
||||
# Training configuration
|
||||
self.training_config = {
|
||||
'dqn_training_interval': 5, # Train DQN every 5 seconds
|
||||
@@ -162,13 +173,185 @@ class EnhancedRealtimeTrainingSystem:
|
||||
validation_thread = threading.Thread(target=self._validation_worker, daemon=True)
|
||||
validation_thread.start()
|
||||
|
||||
logger.info("Enhanced real-time training system started")
|
||||
# Start prediction resolution worker
|
||||
prediction_thread = threading.Thread(target=self._prediction_resolution_worker, daemon=True)
|
||||
prediction_thread.start()
|
||||
|
||||
logger.info("Enhanced real-time training system started with prediction tracking")
|
||||
|
||||
def stop_training(self):
|
||||
"""Stop the training system"""
|
||||
self.is_training = False
|
||||
logger.info("Enhanced real-time training system stopped")
|
||||
|
||||
def store_model_prediction(self, model_name: str, symbol: str, prediction_type: str,
|
||||
confidence: float, current_price: float) -> int:
|
||||
"""Store a model prediction in the database for tracking"""
|
||||
try:
|
||||
prediction_id = self.prediction_db.store_prediction(
|
||||
model_name=model_name,
|
||||
symbol=symbol,
|
||||
prediction_type=prediction_type,
|
||||
confidence=confidence,
|
||||
price_at_prediction=current_price
|
||||
)
|
||||
|
||||
# Track active prediction for later resolution
|
||||
self.active_predictions[prediction_id] = {
|
||||
"model_name": model_name,
|
||||
"symbol": symbol,
|
||||
"prediction_type": prediction_type,
|
||||
"confidence": confidence,
|
||||
"timestamp": time.time(),
|
||||
"price_at_prediction": current_price
|
||||
}
|
||||
|
||||
logger.info(f"Stored prediction {prediction_id}: {model_name} -> {prediction_type} for {symbol} (conf: {confidence:.3f})")
|
||||
return prediction_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing prediction: {e}")
|
||||
return -1
|
||||
|
||||
def resolve_predictions(self):
|
||||
"""Resolve active predictions based on price movement"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
resolved_predictions = []
|
||||
|
||||
for prediction_id, pred_data in list(self.active_predictions.items()):
|
||||
# Check if prediction is old enough to resolve
|
||||
age = current_time - pred_data["timestamp"]
|
||||
if age >= self.prediction_resolution_time:
|
||||
|
||||
# Get current price for the symbol
|
||||
symbol = pred_data["symbol"]
|
||||
current_price = self._get_current_price(symbol)
|
||||
|
||||
if current_price > 0:
|
||||
# Calculate price change
|
||||
price_change_pct = (current_price - pred_data["price_at_prediction"]) / pred_data["price_at_prediction"]
|
||||
|
||||
# Calculate reward based on prediction correctness
|
||||
reward = self._calculate_prediction_reward(
|
||||
pred_data["prediction_type"],
|
||||
price_change_pct,
|
||||
pred_data["confidence"]
|
||||
)
|
||||
|
||||
# Resolve the prediction
|
||||
success = self.prediction_db.resolve_prediction(
|
||||
prediction_id=prediction_id,
|
||||
actual_price_change=price_change_pct,
|
||||
reward=reward
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"Resolved prediction {prediction_id}: {pred_data['model_name']} -> "
|
||||
f"price change {price_change_pct:.3f}%, reward {reward:.3f}")
|
||||
resolved_predictions.append(prediction_id)
|
||||
|
||||
# Remove from active predictions
|
||||
del self.active_predictions[prediction_id]
|
||||
|
||||
return len(resolved_predictions)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error resolving predictions: {e}")
|
||||
return 0
|
||||
|
||||
def _get_current_price(self, symbol: str) -> float:
|
||||
"""Get current price for a symbol"""
|
||||
try:
|
||||
# Try to get from data provider
|
||||
if self.data_provider and hasattr(self.data_provider, 'get_latest_data'):
|
||||
latest = self.data_provider.get_latest_data(symbol)
|
||||
if latest and 'close' in latest:
|
||||
return float(latest['close'])
|
||||
|
||||
# Try to get from orchestrator
|
||||
if self.orchestrator and hasattr(self.orchestrator, '_get_current_price'):
|
||||
return float(self.orchestrator._get_current_price(symbol))
|
||||
|
||||
# Fallback values
|
||||
fallback_prices = {'ETH/USDT': 4300.0, 'BTC/USDT': 111000.0}
|
||||
return fallback_prices.get(symbol, 1000.0)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting current price for {symbol}: {e}")
|
||||
return 0.0
|
||||
|
||||
def _calculate_prediction_reward(self, prediction_type: str, price_change_pct: float, confidence: float) -> float:
|
||||
"""Calculate reward for a prediction based on outcome"""
|
||||
try:
|
||||
# Base reward calculation
|
||||
if prediction_type == "BUY":
|
||||
base_reward = price_change_pct * 100 # Positive if price went up
|
||||
elif prediction_type == "SELL":
|
||||
base_reward = -price_change_pct * 100 # Positive if price went down
|
||||
elif prediction_type == "HOLD":
|
||||
base_reward = max(0, 1 - abs(price_change_pct) * 100) # Positive if small movement
|
||||
else:
|
||||
base_reward = 0
|
||||
|
||||
# Confidence adjustment - reward high confidence correct predictions more
|
||||
confidence_multiplier = 0.5 + (confidence * 1.5) # Range: 0.5 to 2.0
|
||||
|
||||
# Final reward calculation
|
||||
final_reward = base_reward * confidence_multiplier
|
||||
|
||||
# Normalize to reasonable range [-10, 10]
|
||||
final_reward = max(-10, min(10, final_reward))
|
||||
|
||||
return final_reward
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating prediction reward: {e}")
|
||||
return 0.0
|
||||
|
||||
def get_model_performance_stats(self) -> Dict[str, Any]:
|
||||
"""Get performance statistics for all models"""
|
||||
try:
|
||||
stats = self.prediction_db.get_all_model_stats()
|
||||
|
||||
# Add active predictions count
|
||||
active_by_model = {}
|
||||
for pred_data in self.active_predictions.values():
|
||||
model = pred_data["model_name"]
|
||||
active_by_model[model] = active_by_model.get(model, 0) + 1
|
||||
|
||||
# Enhance stats with active predictions
|
||||
for stat in stats:
|
||||
model_name = stat["model_name"]
|
||||
stat["active_predictions"] = active_by_model.get(model_name, 0)
|
||||
|
||||
return {
|
||||
"models": stats,
|
||||
"total_active_predictions": len(self.active_predictions),
|
||||
"last_updated": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting performance stats: {e}")
|
||||
return {}
|
||||
|
||||
def _prediction_resolution_worker(self):
|
||||
"""Worker thread to resolve active predictions"""
|
||||
while self.is_training:
|
||||
try:
|
||||
# Resolve predictions every 30 seconds
|
||||
resolved_count = self.resolve_predictions()
|
||||
if resolved_count > 0:
|
||||
logger.info(f"Resolved {resolved_count} predictions")
|
||||
|
||||
time.sleep(30)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in prediction resolution worker: {e}")
|
||||
time.sleep(60)
|
||||
|
||||
|
||||
|
||||
def _data_collection_worker(self):
|
||||
"""Collect and preprocess real-time market data"""
|
||||
while self.is_training:
|
||||
|
Reference in New Issue
Block a user