#!/usr/bin/env python3
"""
Prediction Database - Simple SQLite database for tracking model predictions
"""

import json
import logging
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional

logger = logging.getLogger(__name__)


class PredictionDatabase:
    """Simple database for tracking model predictions and outcomes.

    Two tables are maintained:

    * ``predictions`` - one row per stored prediction; the outcome columns
      (``outcome_timestamp``, ``actual_price_change``, ``reward``,
      ``is_correct``) stay NULL until the prediction is resolved.
    * ``model_performance`` - one row per model holding running totals
      (prediction count, correct count, accumulated reward).

    Every operation opens its own connection and closes it before returning.
    """

    # A "HOLD" prediction counts as correct only while the absolute value of
    # the observed price change is strictly below this tolerance.
    HOLD_TOLERANCE = 0.001

    def __init__(self, db_path: str = "data/predictions.db"):
        """Create the database file (and its parent directory) if needed.

        Args:
            db_path: Location of the SQLite file. Missing parent directories
                are created.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._initialize_database()
        logger.info("PredictionDatabase initialized: %s", self.db_path)

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that is committed/rolled back and then closed.

        ``sqlite3.Connection`` used as a context manager only ends the
        transaction (commit on success, rollback on exception); it does not
        close the connection. This wrapper adds the missing ``close()`` so
        handles are not left open after each call.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _initialize_database(self):
        """Create the ``predictions`` and ``model_performance`` tables if absent."""
        with self._connect() as conn:
            cursor = conn.cursor()

            # Predictions table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS predictions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    symbol TEXT NOT NULL,
                    prediction_type TEXT NOT NULL,
                    confidence REAL NOT NULL,
                    timestamp TEXT NOT NULL,
                    price_at_prediction REAL NOT NULL,

                    -- Outcome fields
                    outcome_timestamp TEXT,
                    actual_price_change REAL,
                    reward REAL,
                    is_correct INTEGER,

                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Performance summary table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS model_performance (
                    model_name TEXT PRIMARY KEY,
                    total_predictions INTEGER DEFAULT 0,
                    correct_predictions INTEGER DEFAULT 0,
                    total_reward REAL DEFAULT 0.0,
                    last_updated TEXT
                )
            """)

    def store_prediction(self, model_name: str, symbol: str, prediction_type: str,
                         confidence: float, price_at_prediction: float) -> int:
        """Store a new prediction and bump the model's prediction counter.

        Args:
            model_name: Name of the model that produced the prediction.
            symbol: Symbol the prediction refers to.
            prediction_type: Prediction label; "BUY", "SELL" and "HOLD" are
                the values recognised when the prediction is later resolved.
            confidence: Confidence value stored alongside the prediction.
            price_at_prediction: Price recorded at prediction time.

        Returns:
            The row id of the newly inserted prediction.
        """
        # Naive local time, serialised as ISO-8601 text.
        timestamp = datetime.now().isoformat()

        with self._connect() as conn:
            cursor = conn.cursor()

            cursor.execute("""
                INSERT INTO predictions (
                    model_name, symbol, prediction_type, confidence,
                    timestamp, price_at_prediction
                ) VALUES (?, ?, ?, ?, ?, ?)
            """, (model_name, symbol, prediction_type, confidence,
                  timestamp, price_at_prediction))

            prediction_id = cursor.lastrowid

            # Update performance count: the row is replaced wholesale, so the
            # existing counters are read back through sub-selects (defaulting
            # to zero for a model seen for the first time).
            cursor.execute("""
                INSERT OR REPLACE INTO model_performance (
                    model_name, total_predictions, correct_predictions,
                    total_reward, last_updated
                ) VALUES (
                    ?,
                    COALESCE((SELECT total_predictions FROM model_performance WHERE model_name = ?), 0) + 1,
                    COALESCE((SELECT correct_predictions FROM model_performance WHERE model_name = ?), 0),
                    COALESCE((SELECT total_reward FROM model_performance WHERE model_name = ?), 0.0),
                    ?
                )
            """, (model_name, model_name, model_name, model_name, timestamp))

        return prediction_id

    def resolve_prediction(self, prediction_id: int, actual_price_change: float,
                           reward: float) -> bool:
        """Resolve a prediction with its outcome.

        Args:
            prediction_id: Id returned by ``store_prediction``.
            actual_price_change: Observed price change used to judge
                correctness.
            reward: Reward to record and add to the model's running total.

        Returns:
            True if the prediction was resolved; False if no unresolved
            prediction with that id exists (unknown id or already resolved).
        """
        with self._connect() as conn:
            cursor = conn.cursor()

            # Get original prediction; only unresolved rows qualify.
            cursor.execute("""
                SELECT model_name, prediction_type
                FROM predictions
                WHERE id = ? AND outcome_timestamp IS NULL
            """, (prediction_id,))

            result = cursor.fetchone()
            if not result:
                return False

            model_name, prediction_type = result

            # Determine correctness
            is_correct = self._is_prediction_correct(prediction_type, actual_price_change)

            # Update prediction
            outcome_timestamp = datetime.now().isoformat()
            cursor.execute("""
                UPDATE predictions
                SET outcome_timestamp = ?, actual_price_change = ?,
                    reward = ?, is_correct = ?
                WHERE id = ?
            """, (outcome_timestamp, actual_price_change, reward,
                  int(is_correct), prediction_id))

            # Update performance
            cursor.execute("""
                UPDATE model_performance
                SET correct_predictions = correct_predictions + ?,
                    total_reward = total_reward + ?,
                    last_updated = ?
                WHERE model_name = ?
            """, (int(is_correct), reward, outcome_timestamp, model_name))

        return True

    def _is_prediction_correct(self, prediction_type: str, price_change: float) -> bool:
        """Check if a prediction was correct.

        "BUY" is correct on a positive change, "SELL" on a negative change,
        and "HOLD" when the absolute change is below ``HOLD_TOLERANCE``.
        Any other prediction type is treated as incorrect.
        """
        if prediction_type == "BUY":
            return price_change > 0
        if prediction_type == "SELL":
            return price_change < 0
        if prediction_type == "HOLD":
            return abs(price_change) < PredictionDatabase.HOLD_TOLERANCE
        return False

    @staticmethod
    def _build_stats(model_name: str, total: int, correct: int,
                     reward: float) -> Dict[str, Any]:
        """Assemble the stats dictionary shared by the stats getters."""
        accuracy = (correct / total) if total > 0 else 0.0
        return {
            "model_name": model_name,
            "total_predictions": total,
            "correct_predictions": correct,
            "accuracy": accuracy,
            "total_reward": reward,
        }

    def get_model_stats(self, model_name: str) -> Dict[str, Any]:
        """Get model performance statistics.

        Returns:
            A dict with keys ``model_name``, ``total_predictions``,
            ``correct_predictions``, ``accuracy`` and ``total_reward``.
            A model with no recorded predictions yields zeroed values with
            the same set of keys.
        """
        with self._connect() as conn:
            cursor = conn.cursor()

            cursor.execute("""
                SELECT total_predictions, correct_predictions, total_reward
                FROM model_performance
                WHERE model_name = ?
            """, (model_name,))

            result = cursor.fetchone()

        if not result:
            return PredictionDatabase._build_stats(model_name, 0, 0, 0.0)

        total, correct, reward = result
        return PredictionDatabase._build_stats(model_name, total, correct, reward)

    def get_all_model_stats(self) -> List[Dict[str, Any]]:
        """Get stats for all models, ordered by prediction count (descending)."""
        with self._connect() as conn:
            cursor = conn.cursor()

            cursor.execute("""
                SELECT model_name, total_predictions, correct_predictions, total_reward
                FROM model_performance
                ORDER BY total_predictions DESC
            """)

            rows = cursor.fetchall()

        return [
            PredictionDatabase._build_stats(model_name, total, correct, reward)
            for model_name, total, correct, reward in rows
        ]


# Global instance
_prediction_db = None


def get_prediction_db() -> PredictionDatabase:
    """Get the global prediction database, creating it on first use."""
    global _prediction_db
    if _prediction_db is None:
        _prediction_db = PredictionDatabase()
    return _prediction_db