206 lines
7.6 KiB
Python
206 lines
7.6 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Prediction Database - Simple SQLite database for tracking model predictions
|
|
"""
|
|
|
|
import contextlib
import json
import logging
import sqlite3
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Any, Optional
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class PredictionDatabase:
    """SQLite-backed store for model predictions and their outcomes.

    Two tables are kept in step:

    * ``predictions`` -- one row per prediction. The outcome columns
      (``outcome_timestamp``, ``actual_price_change``, ``reward``,
      ``is_correct``) stay NULL until :meth:`resolve_prediction` fills them.
    * ``model_performance`` -- per-model running totals (prediction count,
      correct count, summed reward), updated alongside ``predictions``.

    Each public method opens a short-lived connection via :meth:`_connect`,
    so no database handle is held between calls.
    """

    # Largest absolute price change that still counts as "no movement" when
    # a HOLD prediction is scored.
    HOLD_TOLERANCE = 0.001

    def __init__(self, db_path: str = "data/predictions.db"):
        """Open the database at *db_path*, creating file and schema if needed.

        Args:
            db_path: Path of the SQLite file. Missing parent directories are
                created.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._initialize_database()
        logger.info("PredictionDatabase initialized: %s", self.db_path)

    @contextlib.contextmanager
    def _connect(self):
        """Yield a connection that commits on success and is always closed.

        Using ``sqlite3.Connection`` directly as a context manager only
        commits (or rolls back on error); it does not close the connection.
        This wrapper adds the close so no handle outlives the call.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on normal exit, rollback on exception
                yield conn
        finally:
            conn.close()

    def _initialize_database(self) -> None:
        """Create the ``predictions`` and ``model_performance`` tables if absent."""
        with self._connect() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS predictions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    symbol TEXT NOT NULL,
                    prediction_type TEXT NOT NULL,
                    confidence REAL NOT NULL,
                    timestamp TEXT NOT NULL,
                    price_at_prediction REAL NOT NULL,

                    -- Outcome fields (NULL until the prediction is resolved)
                    outcome_timestamp TEXT,
                    actual_price_change REAL,
                    reward REAL,
                    is_correct INTEGER,

                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
                """
            )
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS model_performance (
                    model_name TEXT PRIMARY KEY,
                    total_predictions INTEGER DEFAULT 0,
                    correct_predictions INTEGER DEFAULT 0,
                    total_reward REAL DEFAULT 0.0,
                    last_updated TEXT
                )
                """
            )

    def store_prediction(self, model_name: str, symbol: str, prediction_type: str,
                         confidence: float, price_at_prediction: float) -> int:
        """Record a new, unresolved prediction and bump the model's count.

        Args:
            model_name: Model that produced the prediction.
            symbol: Instrument the prediction refers to.
            prediction_type: Predicted action; "BUY", "SELL" and "HOLD" are
                the values :meth:`_is_prediction_correct` can score.
            confidence: Model confidence for this prediction.
            price_at_prediction: Price when the prediction was made.

        Returns:
            The new prediction's row id, to pass to :meth:`resolve_prediction`.
        """
        # NOTE: naive local time (no UTC offset), same format that
        # resolve_prediction writes.
        timestamp = datetime.now().isoformat()
        with self._connect() as conn:
            prediction_id = conn.execute(
                """
                INSERT INTO predictions (
                    model_name, symbol, prediction_type, confidence,
                    timestamp, price_at_prediction
                ) VALUES (?, ?, ?, ?, ?, ?)
                """,
                (model_name, symbol, prediction_type, confidence,
                 timestamp, price_at_prediction),
            ).lastrowid

            # Make sure the model has a totals row (column defaults start it
            # at zero), then increment it in place. Unlike INSERT OR REPLACE,
            # this never deletes and re-inserts the row, so columns not named
            # here are left untouched.
            conn.execute(
                "INSERT OR IGNORE INTO model_performance (model_name) VALUES (?)",
                (model_name,),
            )
            conn.execute(
                """
                UPDATE model_performance SET
                    total_predictions = COALESCE(total_predictions, 0) + 1,
                    last_updated = ?
                WHERE model_name = ?
                """,
                (timestamp, model_name),
            )
        return prediction_id

    def resolve_prediction(self, prediction_id: int, actual_price_change: float,
                           reward: float) -> bool:
        """Attach an outcome to a pending prediction and update model totals.

        Args:
            prediction_id: Row id returned by :meth:`store_prediction`.
            actual_price_change: Observed price change used for scoring.
            reward: Reward credited to the model for this prediction.

        Returns:
            True if this call resolved the prediction; False if the id does
            not exist or the prediction was already resolved.
        """
        with self._connect() as conn:
            row = conn.execute(
                """
                SELECT model_name, prediction_type FROM predictions
                WHERE id = ? AND outcome_timestamp IS NULL
                """,
                (prediction_id,),
            ).fetchone()
            if row is None:
                return False

            model_name, prediction_type = row
            is_correct = int(self._is_prediction_correct(prediction_type, actual_price_change))
            outcome_timestamp = datetime.now().isoformat()

            # Re-check "unresolved" inside the UPDATE itself. With the default
            # isolation level, sqlite3 opens a transaction only before the
            # UPDATE, so another connection may resolve the same id after the
            # SELECT above. Only the writer that actually flips the row may
            # add to the totals; otherwise they would be double-counted.
            updated = conn.execute(
                """
                UPDATE predictions SET
                    outcome_timestamp = ?, actual_price_change = ?,
                    reward = ?, is_correct = ?
                WHERE id = ? AND outcome_timestamp IS NULL
                """,
                (outcome_timestamp, actual_price_change, reward, is_correct,
                 prediction_id),
            )
            if updated.rowcount != 1:
                return False

            conn.execute(
                """
                UPDATE model_performance SET
                    correct_predictions = COALESCE(correct_predictions, 0) + ?,
                    total_reward = COALESCE(total_reward, 0.0) + ?,
                    last_updated = ?
                WHERE model_name = ?
                """,
                (is_correct, reward, outcome_timestamp, model_name),
            )
        return True

    def _is_prediction_correct(self, prediction_type: str, price_change: float) -> bool:
        """Score a prediction against the observed price change.

        BUY is correct on a strict rise, SELL on a strict fall, and HOLD when
        the absolute change is below ``HOLD_TOLERANCE``. Any other
        ``prediction_type`` (comparison is case-sensitive) scores False.
        """
        if prediction_type == "BUY":
            return price_change > 0
        if prediction_type == "SELL":
            return price_change < 0
        if prediction_type == "HOLD":
            return abs(price_change) < self.HOLD_TOLERANCE
        return False

    @staticmethod
    def _build_stats(model_name: str, total: int, correct: int,
                     reward: float) -> Dict[str, Any]:
        """Shape one model's totals into the dict returned by the stats getters."""
        return {
            "model_name": model_name,
            "total_predictions": total,
            "correct_predictions": correct,
            # Models with no predictions report 0.0 instead of dividing by zero.
            "accuracy": correct / total if total > 0 else 0.0,
            "total_reward": reward,
        }

    def get_model_stats(self, model_name: str) -> Dict[str, Any]:
        """Return performance totals and accuracy for one model.

        A model with no recorded predictions yields all-zero stats, with the
        same keys as a known model.
        """
        with self._connect() as conn:
            row = conn.execute(
                """
                SELECT total_predictions, correct_predictions, total_reward
                FROM model_performance WHERE model_name = ?
                """,
                (model_name,),
            ).fetchone()
        if row is None:
            return self._build_stats(model_name, 0, 0, 0.0)
        return self._build_stats(model_name, *row)

    def get_all_model_stats(self) -> List[Dict[str, Any]]:
        """Return stats for every model, most predictions first.

        Ties are broken by model name so the order is deterministic.
        """
        with self._connect() as conn:
            rows = conn.execute(
                """
                SELECT model_name, total_predictions, correct_predictions, total_reward
                FROM model_performance
                ORDER BY total_predictions DESC, model_name
                """
            ).fetchall()
        return [self._build_stats(*row) for row in rows]
|
|
|
|
# Shared PredictionDatabase, created lazily by get_prediction_db().
_prediction_db = None


def get_prediction_db() -> PredictionDatabase:
    """Return the module-wide PredictionDatabase, building it on first call.

    The instance uses the default database path and is cached in the
    module-level ``_prediction_db``; every later call returns that object.
    NOTE: no lock guards the first call, so concurrent first calls from
    several threads could each construct an instance.
    """
    global _prediction_db
    if _prediction_db is not None:
        return _prediction_db
    _prediction_db = PredictionDatabase()
    return _prediction_db
|