prediction database
This commit is contained in:
@@ -1112,11 +1112,76 @@ class TradingOrchestrator:
|
||||
return predictions
|
||||
|
||||
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
|
||||
"""Get predictions from CNN model for all timeframes with enhanced COB features"""
|
||||
"""Get CNN predictions for multiple timeframes"""
|
||||
predictions = []
|
||||
|
||||
try:
|
||||
for timeframe in self.config.timeframes:
|
||||
# Get predictions for different timeframes
|
||||
timeframes = ['1m', '5m', '1h']
|
||||
|
||||
for timeframe in timeframes:
|
||||
try:
|
||||
# Get features from data provider
|
||||
features = self.data_provider.get_cnn_features_for_inference(symbol, timeframe, window_size=60)
|
||||
|
||||
if features is not None and len(features) > 0:
|
||||
# Get prediction from model
|
||||
prediction_result = await model.predict(features)
|
||||
|
||||
if prediction_result:
|
||||
prediction = Prediction(
|
||||
model_name=f"CNN_{timeframe}",
|
||||
symbol=symbol,
|
||||
signal=prediction_result.get('signal', 'HOLD'),
|
||||
confidence=prediction_result.get('confidence', 0.0),
|
||||
reasoning=f"CNN {timeframe} prediction",
|
||||
features=features[:10].tolist() if len(features) > 10 else features.tolist(),
|
||||
metadata={'timeframe': timeframe}
|
||||
)
|
||||
predictions.append(prediction)
|
||||
|
||||
# Store prediction in database for tracking
|
||||
if (hasattr(self, 'enhanced_training_system') and
|
||||
self.enhanced_training_system and
|
||||
hasattr(self.enhanced_training_system, 'store_model_prediction')):
|
||||
|
||||
current_price = self._get_current_price_safe(symbol)
|
||||
if current_price > 0:
|
||||
prediction_id = self.enhanced_training_system.store_model_prediction(
|
||||
model_name=f"CNN_{timeframe}",
|
||||
symbol=symbol,
|
||||
prediction_type=prediction.signal,
|
||||
confidence=prediction.confidence,
|
||||
current_price=current_price
|
||||
)
|
||||
logger.debug(f"Stored CNN prediction {prediction_id} for {symbol} {timeframe}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting CNN prediction for {symbol} {timeframe}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in CNN predictions for {symbol}: {e}")
|
||||
|
||||
return predictions
|
||||
|
||||
def _get_current_price_safe(self, symbol: str) -> float:
|
||||
"""Safely get current price for a symbol"""
|
||||
try:
|
||||
# Try to get from data provider
|
||||
if hasattr(self.data_provider, 'get_latest_data'):
|
||||
latest = self.data_provider.get_latest_data(symbol)
|
||||
if latest and 'close' in latest:
|
||||
return float(latest['close'])
|
||||
|
||||
# Fallback values
|
||||
fallback_prices = {'ETH/USDT': 4300.0, 'BTC/USDT': 111000.0}
|
||||
return fallback_prices.get(symbol, 1000.0)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting current price for {symbol}: {e}")
|
||||
return 0.0
|
||||
|
||||
# Get standard feature matrix for this timeframe
|
||||
feature_matrix = self.data_provider.get_feature_matrix(
|
||||
symbol=symbol,
|
||||
@@ -1259,11 +1324,58 @@ class TradingOrchestrator:
|
||||
action_idx, confidence = result
|
||||
raw_q_values = None
|
||||
else:
|
||||
logger.error(f"Unexpected return format from act_with_confidence: {len(result)} values")
|
||||
logger.warning(f"Unexpected result format from RL model: {result}")
|
||||
return None
|
||||
elif hasattr(model.model, 'act'):
|
||||
action_idx = model.model.act(state, explore=False)
|
||||
confidence = 0.7 # Default confidence for basic act method
|
||||
else:
|
||||
# Fallback to standard act method
|
||||
action_idx = model.model.act(state)
|
||||
confidence = 0.6 # Default confidence
|
||||
raw_q_values = None
|
||||
|
||||
# Convert action index to action name
|
||||
action_names = ['BUY', 'SELL', 'HOLD']
|
||||
if 0 <= action_idx < len(action_names):
|
||||
action = action_names[action_idx]
|
||||
else:
|
||||
logger.warning(f"Invalid action index from RL model: {action_idx}")
|
||||
return None
|
||||
|
||||
# Store prediction in database for tracking
|
||||
if (hasattr(self, 'enhanced_training_system') and
|
||||
self.enhanced_training_system and
|
||||
hasattr(self.enhanced_training_system, 'store_model_prediction')):
|
||||
|
||||
current_price = self._get_current_price_safe(symbol)
|
||||
if current_price > 0:
|
||||
prediction_id = self.enhanced_training_system.store_model_prediction(
|
||||
model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN",
|
||||
symbol=symbol,
|
||||
prediction_type=action,
|
||||
confidence=confidence,
|
||||
current_price=current_price
|
||||
)
|
||||
logger.debug(f"Stored DQN prediction {prediction_id} for {symbol}")
|
||||
|
||||
# Create prediction object
|
||||
prediction = Prediction(
|
||||
model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN",
|
||||
symbol=symbol,
|
||||
signal=action,
|
||||
confidence=confidence,
|
||||
reasoning=f"DQN agent prediction with Q-values: {raw_q_values}",
|
||||
features=state.tolist() if isinstance(state, np.ndarray) else [],
|
||||
metadata={
|
||||
'action_idx': action_idx,
|
||||
'q_values': raw_q_values.tolist() if raw_q_values is not None else None,
|
||||
'state_size': len(state) if state is not None else 0
|
||||
}
|
||||
)
|
||||
|
||||
return prediction
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting RL prediction for {symbol}: {e}")
|
||||
return None
|
||||
raw_q_values = None # No raw q_values from simple act
|
||||
else:
|
||||
logger.error(f"RL model {model.name} has no act method")
|
||||
|
205
core/prediction_database.py
Normal file
205
core/prediction_database.py
Normal file
@@ -0,0 +1,205 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Prediction Database - Simple SQLite database for tracking model predictions
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import logging
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PredictionDatabase:
|
||||
"""Simple database for tracking model predictions and outcomes"""
|
||||
|
||||
def __init__(self, db_path: str = "data/predictions.db"):
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._initialize_database()
|
||||
logger.info(f"PredictionDatabase initialized: {self.db_path}")
|
||||
|
||||
def _initialize_database(self):
|
||||
"""Initialize SQLite database"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Predictions table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS predictions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
model_name TEXT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
prediction_type TEXT NOT NULL,
|
||||
confidence REAL NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
price_at_prediction REAL NOT NULL,
|
||||
|
||||
-- Outcome fields
|
||||
outcome_timestamp TEXT,
|
||||
actual_price_change REAL,
|
||||
reward REAL,
|
||||
is_correct INTEGER,
|
||||
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
|
||||
# Performance summary table
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS model_performance (
|
||||
model_name TEXT PRIMARY KEY,
|
||||
total_predictions INTEGER DEFAULT 0,
|
||||
correct_predictions INTEGER DEFAULT 0,
|
||||
total_reward REAL DEFAULT 0.0,
|
||||
last_updated TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
|
||||
def store_prediction(self, model_name: str, symbol: str, prediction_type: str,
|
||||
confidence: float, price_at_prediction: float) -> int:
|
||||
"""Store a new prediction"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
timestamp = datetime.now().isoformat()
|
||||
|
||||
cursor.execute("""
|
||||
INSERT INTO predictions (
|
||||
model_name, symbol, prediction_type, confidence,
|
||||
timestamp, price_at_prediction
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
""", (model_name, symbol, prediction_type, confidence,
|
||||
timestamp, price_at_prediction))
|
||||
|
||||
prediction_id = cursor.lastrowid
|
||||
|
||||
# Update performance count
|
||||
cursor.execute("""
|
||||
INSERT OR REPLACE INTO model_performance (
|
||||
model_name, total_predictions, correct_predictions, total_reward, last_updated
|
||||
) VALUES (
|
||||
?,
|
||||
COALESCE((SELECT total_predictions FROM model_performance WHERE model_name = ?), 0) + 1,
|
||||
COALESCE((SELECT correct_predictions FROM model_performance WHERE model_name = ?), 0),
|
||||
COALESCE((SELECT total_reward FROM model_performance WHERE model_name = ?), 0.0),
|
||||
?
|
||||
)
|
||||
""", (model_name, model_name, model_name, model_name, timestamp))
|
||||
|
||||
conn.commit()
|
||||
return prediction_id
|
||||
|
||||
def resolve_prediction(self, prediction_id: int, actual_price_change: float, reward: float) -> bool:
|
||||
"""Resolve a prediction with outcome"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get original prediction
|
||||
cursor.execute("""
|
||||
SELECT model_name, prediction_type FROM predictions
|
||||
WHERE id = ? AND outcome_timestamp IS NULL
|
||||
""", (prediction_id,))
|
||||
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
return False
|
||||
|
||||
model_name, prediction_type = result
|
||||
|
||||
# Determine correctness
|
||||
is_correct = self._is_prediction_correct(prediction_type, actual_price_change)
|
||||
|
||||
# Update prediction
|
||||
outcome_timestamp = datetime.now().isoformat()
|
||||
cursor.execute("""
|
||||
UPDATE predictions SET
|
||||
outcome_timestamp = ?, actual_price_change = ?,
|
||||
reward = ?, is_correct = ?
|
||||
WHERE id = ?
|
||||
""", (outcome_timestamp, actual_price_change, reward, int(is_correct), prediction_id))
|
||||
|
||||
# Update performance
|
||||
cursor.execute("""
|
||||
UPDATE model_performance SET
|
||||
correct_predictions = correct_predictions + ?,
|
||||
total_reward = total_reward + ?,
|
||||
last_updated = ?
|
||||
WHERE model_name = ?
|
||||
""", (int(is_correct), reward, outcome_timestamp, model_name))
|
||||
|
||||
conn.commit()
|
||||
return True
|
||||
|
||||
def _is_prediction_correct(self, prediction_type: str, price_change: float) -> bool:
|
||||
"""Check if prediction was correct"""
|
||||
if prediction_type == "BUY":
|
||||
return price_change > 0
|
||||
elif prediction_type == "SELL":
|
||||
return price_change < 0
|
||||
elif prediction_type == "HOLD":
|
||||
return abs(price_change) < 0.001
|
||||
return False
|
||||
|
||||
def get_model_stats(self, model_name: str) -> Dict[str, Any]:
|
||||
"""Get model performance statistics"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT total_predictions, correct_predictions, total_reward
|
||||
FROM model_performance WHERE model_name = ?
|
||||
""", (model_name,))
|
||||
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
return {"model_name": model_name, "total_predictions": 0, "accuracy": 0.0, "total_reward": 0.0}
|
||||
|
||||
total, correct, reward = result
|
||||
accuracy = (correct / total) if total > 0 else 0.0
|
||||
|
||||
return {
|
||||
"model_name": model_name,
|
||||
"total_predictions": total,
|
||||
"correct_predictions": correct,
|
||||
"accuracy": accuracy,
|
||||
"total_reward": reward
|
||||
}
|
||||
|
||||
def get_all_model_stats(self) -> List[Dict[str, Any]]:
|
||||
"""Get stats for all models"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT model_name, total_predictions, correct_predictions, total_reward
|
||||
FROM model_performance ORDER BY total_predictions DESC
|
||||
""")
|
||||
|
||||
stats = []
|
||||
for row in cursor.fetchall():
|
||||
model_name, total, correct, reward = row
|
||||
accuracy = (correct / total) if total > 0 else 0.0
|
||||
stats.append({
|
||||
"model_name": model_name,
|
||||
"total_predictions": total,
|
||||
"correct_predictions": correct,
|
||||
"accuracy": accuracy,
|
||||
"total_reward": reward
|
||||
})
|
||||
|
||||
return stats
|
||||
|
||||
# Global instance
|
||||
_prediction_db = None
|
||||
|
||||
def get_prediction_db() -> PredictionDatabase:
|
||||
"""Get global prediction database"""
|
||||
global _prediction_db
|
||||
if _prediction_db is None:
|
||||
_prediction_db = PredictionDatabase()
|
||||
return _prediction_db
|
Reference in New Issue
Block a user