prediction database
This commit is contained in:
205
core/prediction_database.py
Normal file
205
core/prediction_database.py
Normal file
@@ -0,0 +1,205 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Prediction Database - Simple SQLite database for tracking model predictions
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PredictionDatabase:
|
||||
"""Simple database for tracking model predictions and outcomes"""
|
||||
|
||||
def __init__(self, db_path: str = "data/predictions.db"):
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._initialize_database()
|
||||
logger.info(f"PredictionDatabase initialized: {self.db_path}")
|
||||
|
||||
def _initialize_database(self):
|
||||
"""Initialize SQLite database"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Predictions table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS predictions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
model_name TEXT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
prediction_type TEXT NOT NULL,
|
||||
confidence REAL NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
price_at_prediction REAL NOT NULL,
|
||||
|
||||
-- Outcome fields
|
||||
outcome_timestamp TEXT,
|
||||
actual_price_change REAL,
|
||||
reward REAL,
|
||||
is_correct INTEGER,
|
||||
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Performance summary table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS model_performance (
|
||||
model_name TEXT PRIMARY KEY,
|
||||
total_predictions INTEGER DEFAULT 0,
|
||||
correct_predictions INTEGER DEFAULT 0,
|
||||
total_reward REAL DEFAULT 0.0,
|
||||
last_updated TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
|
||||
def store_prediction(self, model_name: str, symbol: str, prediction_type: str,
|
||||
confidence: float, price_at_prediction: float) -> int:
|
||||
"""Store a new prediction"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
timestamp = datetime.now().isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO predictions (
|
||||
model_name, symbol, prediction_type, confidence,
|
||||
timestamp, price_at_prediction
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (model_name, symbol, prediction_type, confidence,
|
||||
timestamp, price_at_prediction))
|
||||
|
||||
prediction_id = cursor.lastrowid
|
||||
|
||||
# Update performance count
|
||||
cursor.execute("""
|
||||
INSERT OR REPLACE INTO model_performance (
|
||||
model_name, total_predictions, correct_predictions, total_reward, last_updated
|
||||
) VALUES (
|
||||
?,
|
||||
COALESCE((SELECT total_predictions FROM model_performance WHERE model_name = ?), 0) + 1,
|
||||
COALESCE((SELECT correct_predictions FROM model_performance WHERE model_name = ?), 0),
|
||||
COALESCE((SELECT total_reward FROM model_performance WHERE model_name = ?), 0.0),
|
||||
?
|
||||
)
|
||||
""", (model_name, model_name, model_name, model_name, timestamp))
|
||||
|
||||
conn.commit()
|
||||
return prediction_id
|
||||
|
||||
def resolve_prediction(self, prediction_id: int, actual_price_change: float, reward: float) -> bool:
|
||||
"""Resolve a prediction with outcome"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get original prediction
|
||||
cursor.execute("""
|
||||
SELECT model_name, prediction_type FROM predictions
|
||||
WHERE id = ? AND outcome_timestamp IS NULL
|
||||
""", (prediction_id,))
|
||||
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
return False
|
||||
|
||||
model_name, prediction_type = result
|
||||
|
||||
# Determine correctness
|
||||
is_correct = self._is_prediction_correct(prediction_type, actual_price_change)
|
||||
|
||||
# Update prediction
|
||||
outcome_timestamp = datetime.now().isoformat()
|
||||
cursor.execute("""
|
||||
UPDATE predictions SET
|
||||
outcome_timestamp = ?, actual_price_change = ?,
|
||||
reward = ?, is_correct = ?
|
||||
WHERE id = ?
|
||||
""", (outcome_timestamp, actual_price_change, reward, int(is_correct), prediction_id))
|
||||
|
||||
# Update performance
|
||||
cursor.execute("""
|
||||
UPDATE model_performance SET
|
||||
correct_predictions = correct_predictions + ?,
|
||||
total_reward = total_reward + ?,
|
||||
last_updated = ?
|
||||
WHERE model_name = ?
|
||||
""", (int(is_correct), reward, outcome_timestamp, model_name))
|
||||
|
||||
conn.commit()
|
||||
return True
|
||||
|
||||
def _is_prediction_correct(self, prediction_type: str, price_change: float) -> bool:
|
||||
"""Check if prediction was correct"""
|
||||
if prediction_type == "BUY":
|
||||
return price_change > 0
|
||||
elif prediction_type == "SELL":
|
||||
return price_change < 0
|
||||
elif prediction_type == "HOLD":
|
||||
return abs(price_change) < 0.001
|
||||
return False
|
||||
|
||||
def get_model_stats(self, model_name: str) -> Dict[str, Any]:
|
||||
"""Get model performance statistics"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT total_predictions, correct_predictions, total_reward
|
||||
FROM model_performance WHERE model_name = ?
|
||||
""", (model_name,))
|
||||
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
return {"model_name": model_name, "total_predictions": 0, "accuracy": 0.0, "total_reward": 0.0}
|
||||
|
||||
total, correct, reward = result
|
||||
accuracy = (correct / total) if total > 0 else 0.0
|
||||
|
||||
return {
|
||||
"model_name": model_name,
|
||||
"total_predictions": total,
|
||||
"correct_predictions": correct,
|
||||
"accuracy": accuracy,
|
||||
"total_reward": reward
|
||||
}
|
||||
|
||||
def get_all_model_stats(self) -> List[Dict[str, Any]]:
|
||||
"""Get stats for all models"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT model_name, total_predictions, correct_predictions, total_reward
|
||||
FROM model_performance ORDER BY total_predictions DESC
|
||||
""")
|
||||
|
||||
stats = []
|
||||
for row in cursor.fetchall():
|
||||
model_name, total, correct, reward = row
|
||||
accuracy = (correct / total) if total > 0 else 0.0
|
||||
stats.append({
|
||||
"model_name": model_name,
|
||||
"total_predictions": total,
|
||||
"correct_predictions": correct,
|
||||
"accuracy": accuracy,
|
||||
"total_reward": reward
|
||||
})
|
||||
|
||||
return stats
|
||||
|
||||
# Global instance
|
||||
_prediction_db = None
|
||||
|
||||
def get_prediction_db() -> PredictionDatabase:
|
||||
"""Get global prediction database"""
|
||||
global _prediction_db
|
||||
if _prediction_db is None:
|
||||
_prediction_db = PredictionDatabase()
|
||||
return _prediction_db
|
Reference in New Issue
Block a user