prediction database

This commit is contained in:
Dobromir Popov
2025-09-02 19:25:42 +03:00
parent 226a6aa047
commit fe6763c4ba
5 changed files with 523 additions and 8 deletions

View File

@@ -9,6 +9,7 @@ This system implements effective online learning with:
- Continuous validation and adaptation
- Multi-timeframe feature engineering
- Real market microstructure analysis
- PREDICTION TRACKING: Store each prediction and track outcomes
"""
import numpy as np
@@ -26,16 +27,26 @@ import torch
import torch.nn as nn
import torch.optim as optim
# Import prediction tracking
from core.prediction_database import get_prediction_db
logger = logging.getLogger(__name__)
class EnhancedRealtimeTrainingSystem:
"""Enhanced real-time training system with proper online learning"""
"""Enhanced real-time training system with prediction tracking and database storage"""
def __init__(self, orchestrator, data_provider, dashboard=None):
self.orchestrator = orchestrator
self.data_provider = data_provider
self.dashboard = dashboard
# Prediction tracking database
self.prediction_db = get_prediction_db()
# Active predictions waiting for resolution
self.active_predictions = {} # {prediction_id: {"timestamp": ..., "price": ..., "model": ...}}
self.prediction_resolution_time = 300 # 5 minutes to resolve predictions
# Training configuration
self.training_config = {
'dqn_training_interval': 5, # Train DQN every 5 seconds
@@ -162,13 +173,185 @@ class EnhancedRealtimeTrainingSystem:
validation_thread = threading.Thread(target=self._validation_worker, daemon=True)
validation_thread.start()
logger.info("Enhanced real-time training system started")
# Start prediction resolution worker
prediction_thread = threading.Thread(target=self._prediction_resolution_worker, daemon=True)
prediction_thread.start()
logger.info("Enhanced real-time training system started with prediction tracking")
def stop_training(self):
"""Stop the training system"""
self.is_training = False
logger.info("Enhanced real-time training system stopped")
def store_model_prediction(self, model_name: str, symbol: str, prediction_type: str,
confidence: float, current_price: float) -> int:
"""Store a model prediction in the database for tracking"""
try:
prediction_id = self.prediction_db.store_prediction(
model_name=model_name,
symbol=symbol,
prediction_type=prediction_type,
confidence=confidence,
price_at_prediction=current_price
)
# Track active prediction for later resolution
self.active_predictions[prediction_id] = {
"model_name": model_name,
"symbol": symbol,
"prediction_type": prediction_type,
"confidence": confidence,
"timestamp": time.time(),
"price_at_prediction": current_price
}
logger.info(f"Stored prediction {prediction_id}: {model_name} -> {prediction_type} for {symbol} (conf: {confidence:.3f})")
return prediction_id
except Exception as e:
logger.error(f"Error storing prediction: {e}")
return -1
def resolve_predictions(self):
"""Resolve active predictions based on price movement"""
try:
current_time = time.time()
resolved_predictions = []
for prediction_id, pred_data in list(self.active_predictions.items()):
# Check if prediction is old enough to resolve
age = current_time - pred_data["timestamp"]
if age >= self.prediction_resolution_time:
# Get current price for the symbol
symbol = pred_data["symbol"]
current_price = self._get_current_price(symbol)
if current_price > 0:
# Calculate price change
price_change_pct = (current_price - pred_data["price_at_prediction"]) / pred_data["price_at_prediction"]
# Calculate reward based on prediction correctness
reward = self._calculate_prediction_reward(
pred_data["prediction_type"],
price_change_pct,
pred_data["confidence"]
)
# Resolve the prediction
success = self.prediction_db.resolve_prediction(
prediction_id=prediction_id,
actual_price_change=price_change_pct,
reward=reward
)
if success:
logger.info(f"Resolved prediction {prediction_id}: {pred_data['model_name']} -> "
f"price change {price_change_pct:.3f}%, reward {reward:.3f}")
resolved_predictions.append(prediction_id)
# Remove from active predictions
del self.active_predictions[prediction_id]
return len(resolved_predictions)
except Exception as e:
logger.error(f"Error resolving predictions: {e}")
return 0
def _get_current_price(self, symbol: str) -> float:
"""Get current price for a symbol"""
try:
# Try to get from data provider
if self.data_provider and hasattr(self.data_provider, 'get_latest_data'):
latest = self.data_provider.get_latest_data(symbol)
if latest and 'close' in latest:
return float(latest['close'])
# Try to get from orchestrator
if self.orchestrator and hasattr(self.orchestrator, '_get_current_price'):
return float(self.orchestrator._get_current_price(symbol))
# Fallback values
fallback_prices = {'ETH/USDT': 4300.0, 'BTC/USDT': 111000.0}
return fallback_prices.get(symbol, 1000.0)
except Exception as e:
logger.debug(f"Error getting current price for {symbol}: {e}")
return 0.0
def _calculate_prediction_reward(self, prediction_type: str, price_change_pct: float, confidence: float) -> float:
"""Calculate reward for a prediction based on outcome"""
try:
# Base reward calculation
if prediction_type == "BUY":
base_reward = price_change_pct * 100 # Positive if price went up
elif prediction_type == "SELL":
base_reward = -price_change_pct * 100 # Positive if price went down
elif prediction_type == "HOLD":
base_reward = max(0, 1 - abs(price_change_pct) * 100) # Positive if small movement
else:
base_reward = 0
# Confidence adjustment - reward high confidence correct predictions more
confidence_multiplier = 0.5 + (confidence * 1.5) # Range: 0.5 to 2.0
# Final reward calculation
final_reward = base_reward * confidence_multiplier
# Normalize to reasonable range [-10, 10]
final_reward = max(-10, min(10, final_reward))
return final_reward
except Exception as e:
logger.error(f"Error calculating prediction reward: {e}")
return 0.0
def get_model_performance_stats(self) -> Dict[str, Any]:
"""Get performance statistics for all models"""
try:
stats = self.prediction_db.get_all_model_stats()
# Add active predictions count
active_by_model = {}
for pred_data in self.active_predictions.values():
model = pred_data["model_name"]
active_by_model[model] = active_by_model.get(model, 0) + 1
# Enhance stats with active predictions
for stat in stats:
model_name = stat["model_name"]
stat["active_predictions"] = active_by_model.get(model_name, 0)
return {
"models": stats,
"total_active_predictions": len(self.active_predictions),
"last_updated": datetime.now().isoformat()
}
except Exception as e:
logger.error(f"Error getting performance stats: {e}")
return {}
def _prediction_resolution_worker(self):
"""Worker thread to resolve active predictions"""
while self.is_training:
try:
# Resolve predictions every 30 seconds
resolved_count = self.resolve_predictions()
if resolved_count > 0:
logger.info(f"Resolved {resolved_count} predictions")
time.sleep(30)
except Exception as e:
logger.error(f"Error in prediction resolution worker: {e}")
time.sleep(60)
def _data_collection_worker(self):
"""Collect and preprocess real-time market data"""
while self.is_training:

View File

@@ -1112,11 +1112,76 @@ class TradingOrchestrator:
return predictions
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
"""Get predictions from CNN model for all timeframes with enhanced COB features"""
"""Get CNN predictions for multiple timeframes"""
predictions = []
try:
for timeframe in self.config.timeframes:
# Get predictions for different timeframes
timeframes = ['1m', '5m', '1h']
for timeframe in timeframes:
try:
# Get features from data provider
features = self.data_provider.get_cnn_features_for_inference(symbol, timeframe, window_size=60)
if features is not None and len(features) > 0:
# Get prediction from model
prediction_result = await model.predict(features)
if prediction_result:
prediction = Prediction(
model_name=f"CNN_{timeframe}",
symbol=symbol,
signal=prediction_result.get('signal', 'HOLD'),
confidence=prediction_result.get('confidence', 0.0),
reasoning=f"CNN {timeframe} prediction",
features=features[:10].tolist() if len(features) > 10 else features.tolist(),
metadata={'timeframe': timeframe}
)
predictions.append(prediction)
# Store prediction in database for tracking
if (hasattr(self, 'enhanced_training_system') and
self.enhanced_training_system and
hasattr(self.enhanced_training_system, 'store_model_prediction')):
current_price = self._get_current_price_safe(symbol)
if current_price > 0:
prediction_id = self.enhanced_training_system.store_model_prediction(
model_name=f"CNN_{timeframe}",
symbol=symbol,
prediction_type=prediction.signal,
confidence=prediction.confidence,
current_price=current_price
)
logger.debug(f"Stored CNN prediction {prediction_id} for {symbol} {timeframe}")
except Exception as e:
logger.debug(f"Error getting CNN prediction for {symbol} {timeframe}: {e}")
continue
except Exception as e:
logger.error(f"Error in CNN predictions for {symbol}: {e}")
return predictions
def _get_current_price_safe(self, symbol: str) -> float:
"""Safely get current price for a symbol"""
try:
# Try to get from data provider
if hasattr(self.data_provider, 'get_latest_data'):
latest = self.data_provider.get_latest_data(symbol)
if latest and 'close' in latest:
return float(latest['close'])
# Fallback values
fallback_prices = {'ETH/USDT': 4300.0, 'BTC/USDT': 111000.0}
return fallback_prices.get(symbol, 1000.0)
except Exception as e:
logger.debug(f"Error getting current price for {symbol}: {e}")
return 0.0
# Get standard feature matrix for this timeframe
feature_matrix = self.data_provider.get_feature_matrix(
symbol=symbol,
@@ -1259,11 +1324,58 @@ class TradingOrchestrator:
action_idx, confidence = result
raw_q_values = None
else:
logger.error(f"Unexpected return format from act_with_confidence: {len(result)} values")
logger.warning(f"Unexpected result format from RL model: {result}")
return None
elif hasattr(model.model, 'act'):
action_idx = model.model.act(state, explore=False)
confidence = 0.7 # Default confidence for basic act method
else:
# Fallback to standard act method
action_idx = model.model.act(state)
confidence = 0.6 # Default confidence
raw_q_values = None
# Convert action index to action name
action_names = ['BUY', 'SELL', 'HOLD']
if 0 <= action_idx < len(action_names):
action = action_names[action_idx]
else:
logger.warning(f"Invalid action index from RL model: {action_idx}")
return None
# Store prediction in database for tracking
if (hasattr(self, 'enhanced_training_system') and
self.enhanced_training_system and
hasattr(self.enhanced_training_system, 'store_model_prediction')):
current_price = self._get_current_price_safe(symbol)
if current_price > 0:
prediction_id = self.enhanced_training_system.store_model_prediction(
model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN",
symbol=symbol,
prediction_type=action,
confidence=confidence,
current_price=current_price
)
logger.debug(f"Stored DQN prediction {prediction_id} for {symbol}")
# Create prediction object
prediction = Prediction(
model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN",
symbol=symbol,
signal=action,
confidence=confidence,
reasoning=f"DQN agent prediction with Q-values: {raw_q_values}",
features=state.tolist() if isinstance(state, np.ndarray) else [],
metadata={
'action_idx': action_idx,
'q_values': raw_q_values.tolist() if raw_q_values is not None else None,
'state_size': len(state) if state is not None else 0
}
)
return prediction
except Exception as e:
logger.error(f"Error getting RL prediction for {symbol}: {e}")
return None
raw_q_values = None # No raw q_values from simple act
else:
logger.error(f"RL model {model.name} has no act method")

205
core/prediction_database.py Normal file
View File

@@ -0,0 +1,205 @@
#!/usr/bin/env python3
"""
Prediction Database - Simple SQLite database for tracking model predictions
"""
import sqlite3
import logging
import json
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from pathlib import Path
logger = logging.getLogger(__name__)
class PredictionDatabase:
"""Simple database for tracking model predictions and outcomes"""
def __init__(self, db_path: str = "data/predictions.db"):
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._initialize_database()
logger.info(f"PredictionDatabase initialized: {self.db_path}")
def _initialize_database(self):
"""Initialize SQLite database"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
# Predictions table
cursor.execute("""
CREATE TABLE IF NOT EXISTS predictions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT NOT NULL,
symbol TEXT NOT NULL,
prediction_type TEXT NOT NULL,
confidence REAL NOT NULL,
timestamp TEXT NOT NULL,
price_at_prediction REAL NOT NULL,
-- Outcome fields
outcome_timestamp TEXT,
actual_price_change REAL,
reward REAL,
is_correct INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Performance summary table
cursor.execute("""
CREATE TABLE IF NOT EXISTS model_performance (
model_name TEXT PRIMARY KEY,
total_predictions INTEGER DEFAULT 0,
correct_predictions INTEGER DEFAULT 0,
total_reward REAL DEFAULT 0.0,
last_updated TEXT
)
""")
conn.commit()
def store_prediction(self, model_name: str, symbol: str, prediction_type: str,
confidence: float, price_at_prediction: float) -> int:
"""Store a new prediction"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
timestamp = datetime.now().isoformat()
cursor.execute("""
INSERT INTO predictions (
model_name, symbol, prediction_type, confidence,
timestamp, price_at_prediction
) VALUES (?, ?, ?, ?, ?, ?)
""", (model_name, symbol, prediction_type, confidence,
timestamp, price_at_prediction))
prediction_id = cursor.lastrowid
# Update performance count
cursor.execute("""
INSERT OR REPLACE INTO model_performance (
model_name, total_predictions, correct_predictions, total_reward, last_updated
) VALUES (
?,
COALESCE((SELECT total_predictions FROM model_performance WHERE model_name = ?), 0) + 1,
COALESCE((SELECT correct_predictions FROM model_performance WHERE model_name = ?), 0),
COALESCE((SELECT total_reward FROM model_performance WHERE model_name = ?), 0.0),
?
)
""", (model_name, model_name, model_name, model_name, timestamp))
conn.commit()
return prediction_id
def resolve_prediction(self, prediction_id: int, actual_price_change: float, reward: float) -> bool:
"""Resolve a prediction with outcome"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
# Get original prediction
cursor.execute("""
SELECT model_name, prediction_type FROM predictions
WHERE id = ? AND outcome_timestamp IS NULL
""", (prediction_id,))
result = cursor.fetchone()
if not result:
return False
model_name, prediction_type = result
# Determine correctness
is_correct = self._is_prediction_correct(prediction_type, actual_price_change)
# Update prediction
outcome_timestamp = datetime.now().isoformat()
cursor.execute("""
UPDATE predictions SET
outcome_timestamp = ?, actual_price_change = ?,
reward = ?, is_correct = ?
WHERE id = ?
""", (outcome_timestamp, actual_price_change, reward, int(is_correct), prediction_id))
# Update performance
cursor.execute("""
UPDATE model_performance SET
correct_predictions = correct_predictions + ?,
total_reward = total_reward + ?,
last_updated = ?
WHERE model_name = ?
""", (int(is_correct), reward, outcome_timestamp, model_name))
conn.commit()
return True
def _is_prediction_correct(self, prediction_type: str, price_change: float) -> bool:
"""Check if prediction was correct"""
if prediction_type == "BUY":
return price_change > 0
elif prediction_type == "SELL":
return price_change < 0
elif prediction_type == "HOLD":
return abs(price_change) < 0.001
return False
def get_model_stats(self, model_name: str) -> Dict[str, Any]:
"""Get model performance statistics"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT total_predictions, correct_predictions, total_reward
FROM model_performance WHERE model_name = ?
""", (model_name,))
result = cursor.fetchone()
if not result:
return {"model_name": model_name, "total_predictions": 0, "accuracy": 0.0, "total_reward": 0.0}
total, correct, reward = result
accuracy = (correct / total) if total > 0 else 0.0
return {
"model_name": model_name,
"total_predictions": total,
"correct_predictions": correct,
"accuracy": accuracy,
"total_reward": reward
}
def get_all_model_stats(self) -> List[Dict[str, Any]]:
"""Get stats for all models"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT model_name, total_predictions, correct_predictions, total_reward
FROM model_performance ORDER BY total_predictions DESC
""")
stats = []
for row in cursor.fetchall():
model_name, total, correct, reward = row
accuracy = (correct / total) if total > 0 else 0.0
stats.append({
"model_name": model_name,
"total_predictions": total,
"correct_predictions": correct,
"accuracy": accuracy,
"total_reward": reward
})
return stats
# Global instance
_prediction_db = None
def get_prediction_db() -> PredictionDatabase:
"""Get global prediction database"""
global _prediction_db
if _prediction_db is None:
_prediction_db = PredictionDatabase()
return _prediction_db

BIN
data/predictions.db Normal file

Binary file not shown.

View File

@@ -341,8 +341,23 @@ class CleanTradingDashboard:
'status': 'healthy',
'dashboard_running': True,
'orchestrator_active': hasattr(self, 'orchestrator'),
'enhanced_training_active': hasattr(self.orchestrator, 'enhanced_training_system') and self.orchestrator.enhanced_training_system is not None,
'timestamp': datetime.now().isoformat()
})
@self.app.server.route('/api/predictions/stats', methods=['GET'])
def get_prediction_stats():
"""Get model prediction statistics"""
try:
if (hasattr(self.orchestrator, 'enhanced_training_system') and
self.orchestrator.enhanced_training_system):
stats = self.orchestrator.enhanced_training_system.get_model_performance_stats()
return jsonify(stats)
else:
return jsonify({"error": "Training system not available"}), 503
except Exception as e:
logger.error(f"Error getting prediction stats: {e}")
return jsonify({"error": str(e)}), 500
def _get_ohlcv_data_with_indicators(self, symbol: str, timeframe: str, limit: int = 300):
"""Get OHLCV data with technical indicators from data stream monitor"""