prediction database

This commit is contained in:
Dobromir Popov
2025-09-02 19:25:42 +03:00
parent 226a6aa047
commit fe6763c4ba
5 changed files with 523 additions and 8 deletions

View File

@@ -1112,11 +1112,76 @@ class TradingOrchestrator:
return predictions
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
"""Get predictions from CNN model for all timeframes with enhanced COB features"""
"""Get CNN predictions for multiple timeframes"""
predictions = []
try:
for timeframe in self.config.timeframes:
# Get predictions for different timeframes
timeframes = ['1m', '5m', '1h']
for timeframe in timeframes:
try:
# Get features from data provider
features = self.data_provider.get_cnn_features_for_inference(symbol, timeframe, window_size=60)
if features is not None and len(features) > 0:
# Get prediction from model
prediction_result = await model.predict(features)
if prediction_result:
prediction = Prediction(
model_name=f"CNN_{timeframe}",
symbol=symbol,
signal=prediction_result.get('signal', 'HOLD'),
confidence=prediction_result.get('confidence', 0.0),
reasoning=f"CNN {timeframe} prediction",
features=features[:10].tolist() if len(features) > 10 else features.tolist(),
metadata={'timeframe': timeframe}
)
predictions.append(prediction)
# Store prediction in database for tracking
if (hasattr(self, 'enhanced_training_system') and
self.enhanced_training_system and
hasattr(self.enhanced_training_system, 'store_model_prediction')):
current_price = self._get_current_price_safe(symbol)
if current_price > 0:
prediction_id = self.enhanced_training_system.store_model_prediction(
model_name=f"CNN_{timeframe}",
symbol=symbol,
prediction_type=prediction.signal,
confidence=prediction.confidence,
current_price=current_price
)
logger.debug(f"Stored CNN prediction {prediction_id} for {symbol} {timeframe}")
except Exception as e:
logger.debug(f"Error getting CNN prediction for {symbol} {timeframe}: {e}")
continue
except Exception as e:
logger.error(f"Error in CNN predictions for {symbol}: {e}")
return predictions
def _get_current_price_safe(self, symbol: str) -> float:
"""Safely get current price for a symbol"""
try:
# Try to get from data provider
if hasattr(self.data_provider, 'get_latest_data'):
latest = self.data_provider.get_latest_data(symbol)
if latest and 'close' in latest:
return float(latest['close'])
# Fallback values
fallback_prices = {'ETH/USDT': 4300.0, 'BTC/USDT': 111000.0}
return fallback_prices.get(symbol, 1000.0)
except Exception as e:
logger.debug(f"Error getting current price for {symbol}: {e}")
return 0.0
# Get standard feature matrix for this timeframe
feature_matrix = self.data_provider.get_feature_matrix(
symbol=symbol,
@@ -1259,11 +1324,58 @@ class TradingOrchestrator:
action_idx, confidence = result
raw_q_values = None
else:
logger.error(f"Unexpected return format from act_with_confidence: {len(result)} values")
logger.warning(f"Unexpected result format from RL model: {result}")
return None
elif hasattr(model.model, 'act'):
action_idx = model.model.act(state, explore=False)
confidence = 0.7 # Default confidence for basic act method
else:
# Fallback to standard act method
action_idx = model.model.act(state)
confidence = 0.6 # Default confidence
raw_q_values = None
# Convert action index to action name
action_names = ['BUY', 'SELL', 'HOLD']
if 0 <= action_idx < len(action_names):
action = action_names[action_idx]
else:
logger.warning(f"Invalid action index from RL model: {action_idx}")
return None
# Store prediction in database for tracking
if (hasattr(self, 'enhanced_training_system') and
self.enhanced_training_system and
hasattr(self.enhanced_training_system, 'store_model_prediction')):
current_price = self._get_current_price_safe(symbol)
if current_price > 0:
prediction_id = self.enhanced_training_system.store_model_prediction(
model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN",
symbol=symbol,
prediction_type=action,
confidence=confidence,
current_price=current_price
)
logger.debug(f"Stored DQN prediction {prediction_id} for {symbol}")
# Create prediction object
prediction = Prediction(
model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN",
symbol=symbol,
signal=action,
confidence=confidence,
reasoning=f"DQN agent prediction with Q-values: {raw_q_values}",
features=state.tolist() if isinstance(state, np.ndarray) else [],
metadata={
'action_idx': action_idx,
'q_values': raw_q_values.tolist() if raw_q_values is not None else None,
'state_size': len(state) if state is not None else 0
}
)
return prediction
except Exception as e:
logger.error(f"Error getting RL prediction for {symbol}: {e}")
return None
raw_q_values = None # No raw q_values from simple act
else:
logger.error(f"RL model {model.name} has no act method")

205
core/prediction_database.py Normal file
View File

@@ -0,0 +1,205 @@
#!/usr/bin/env python3
"""
Prediction Database - Simple SQLite database for tracking model predictions
"""
import sqlite3
import logging
import json
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from pathlib import Path
logger = logging.getLogger(__name__)
class PredictionDatabase:
"""Simple database for tracking model predictions and outcomes"""
def __init__(self, db_path: str = "data/predictions.db"):
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._initialize_database()
logger.info(f"PredictionDatabase initialized: {self.db_path}")
def _initialize_database(self):
"""Initialize SQLite database"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
# Predictions table
cursor.execute("""
CREATE TABLE IF NOT EXISTS predictions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT NOT NULL,
symbol TEXT NOT NULL,
prediction_type TEXT NOT NULL,
confidence REAL NOT NULL,
timestamp TEXT NOT NULL,
price_at_prediction REAL NOT NULL,
-- Outcome fields
outcome_timestamp TEXT,
actual_price_change REAL,
reward REAL,
is_correct INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
# Performance summary table
cursor.execute("""
CREATE TABLE IF NOT EXISTS model_performance (
model_name TEXT PRIMARY KEY,
total_predictions INTEGER DEFAULT 0,
correct_predictions INTEGER DEFAULT 0,
total_reward REAL DEFAULT 0.0,
last_updated TEXT
)
""")
conn.commit()
def store_prediction(self, model_name: str, symbol: str, prediction_type: str,
confidence: float, price_at_prediction: float) -> int:
"""Store a new prediction"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
timestamp = datetime.now().isoformat()
cursor.execute("""
INSERT INTO predictions (
model_name, symbol, prediction_type, confidence,
timestamp, price_at_prediction
) VALUES (?, ?, ?, ?, ?, ?)
""", (model_name, symbol, prediction_type, confidence,
timestamp, price_at_prediction))
prediction_id = cursor.lastrowid
# Update performance count
cursor.execute("""
INSERT OR REPLACE INTO model_performance (
model_name, total_predictions, correct_predictions, total_reward, last_updated
) VALUES (
?,
COALESCE((SELECT total_predictions FROM model_performance WHERE model_name = ?), 0) + 1,
COALESCE((SELECT correct_predictions FROM model_performance WHERE model_name = ?), 0),
COALESCE((SELECT total_reward FROM model_performance WHERE model_name = ?), 0.0),
?
)
""", (model_name, model_name, model_name, model_name, timestamp))
conn.commit()
return prediction_id
def resolve_prediction(self, prediction_id: int, actual_price_change: float, reward: float) -> bool:
"""Resolve a prediction with outcome"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
# Get original prediction
cursor.execute("""
SELECT model_name, prediction_type FROM predictions
WHERE id = ? AND outcome_timestamp IS NULL
""", (prediction_id,))
result = cursor.fetchone()
if not result:
return False
model_name, prediction_type = result
# Determine correctness
is_correct = self._is_prediction_correct(prediction_type, actual_price_change)
# Update prediction
outcome_timestamp = datetime.now().isoformat()
cursor.execute("""
UPDATE predictions SET
outcome_timestamp = ?, actual_price_change = ?,
reward = ?, is_correct = ?
WHERE id = ?
""", (outcome_timestamp, actual_price_change, reward, int(is_correct), prediction_id))
# Update performance
cursor.execute("""
UPDATE model_performance SET
correct_predictions = correct_predictions + ?,
total_reward = total_reward + ?,
last_updated = ?
WHERE model_name = ?
""", (int(is_correct), reward, outcome_timestamp, model_name))
conn.commit()
return True
def _is_prediction_correct(self, prediction_type: str, price_change: float) -> bool:
"""Check if prediction was correct"""
if prediction_type == "BUY":
return price_change > 0
elif prediction_type == "SELL":
return price_change < 0
elif prediction_type == "HOLD":
return abs(price_change) < 0.001
return False
def get_model_stats(self, model_name: str) -> Dict[str, Any]:
"""Get model performance statistics"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT total_predictions, correct_predictions, total_reward
FROM model_performance WHERE model_name = ?
""", (model_name,))
result = cursor.fetchone()
if not result:
return {"model_name": model_name, "total_predictions": 0, "accuracy": 0.0, "total_reward": 0.0}
total, correct, reward = result
accuracy = (correct / total) if total > 0 else 0.0
return {
"model_name": model_name,
"total_predictions": total,
"correct_predictions": correct,
"accuracy": accuracy,
"total_reward": reward
}
def get_all_model_stats(self) -> List[Dict[str, Any]]:
"""Get stats for all models"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT model_name, total_predictions, correct_predictions, total_reward
FROM model_performance ORDER BY total_predictions DESC
""")
stats = []
for row in cursor.fetchall():
model_name, total, correct, reward = row
accuracy = (correct / total) if total > 0 else 0.0
stats.append({
"model_name": model_name,
"total_predictions": total,
"correct_predictions": correct,
"accuracy": accuracy,
"total_reward": reward
})
return stats
# Global instance
_prediction_db = None
def get_prediction_db() -> PredictionDatabase:
"""Get global prediction database"""
global _prediction_db
if _prediction_db is None:
_prediction_db = PredictionDatabase()
return _prediction_db