From fe6763c4baf350715d8693df9903a7ec40f61448 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Tue, 2 Sep 2025 19:25:42 +0300 Subject: [PATCH] prediction database --- NN/training/enhanced_realtime_training.py | 187 +++++++++++++++++++- core/orchestrator.py | 124 ++++++++++++- core/prediction_database.py | 205 ++++++++++++++++++++++ data/predictions.db | Bin 0 -> 20480 bytes web/clean_dashboard.py | 15 ++ 5 files changed, 523 insertions(+), 8 deletions(-) create mode 100644 core/prediction_database.py create mode 100644 data/predictions.db diff --git a/NN/training/enhanced_realtime_training.py b/NN/training/enhanced_realtime_training.py index f2a313b..895d13e 100644 --- a/NN/training/enhanced_realtime_training.py +++ b/NN/training/enhanced_realtime_training.py @@ -9,6 +9,7 @@ This system implements effective online learning with: - Continuous validation and adaptation - Multi-timeframe feature engineering - Real market microstructure analysis +- PREDICTION TRACKING: Store each prediction and track outcomes """ import numpy as np @@ -26,16 +27,26 @@ import torch import torch.nn as nn import torch.optim as optim +# Import prediction tracking +from core.prediction_database import get_prediction_db + logger = logging.getLogger(__name__) class EnhancedRealtimeTrainingSystem: - """Enhanced real-time training system with proper online learning""" + """Enhanced real-time training system with prediction tracking and database storage""" def __init__(self, orchestrator, data_provider, dashboard=None): self.orchestrator = orchestrator self.data_provider = data_provider self.dashboard = dashboard + # Prediction tracking database + self.prediction_db = get_prediction_db() + + # Active predictions waiting for resolution + self.active_predictions = {} # {prediction_id: {"timestamp": ..., "price": ..., "model": ...}} + self.prediction_resolution_time = 300 # 5 minutes to resolve predictions + # Training configuration self.training_config = { 'dqn_training_interval': 5, # Train DQN every 5 seconds @@ -162,13 +173,185 @@ class EnhancedRealtimeTrainingSystem: validation_thread = threading.Thread(target=self._validation_worker, daemon=True) validation_thread.start() - logger.info("Enhanced real-time training system started") + # Start prediction resolution worker + prediction_thread = threading.Thread(target=self._prediction_resolution_worker, daemon=True) + prediction_thread.start() + + logger.info("Enhanced real-time training system started with prediction tracking") def stop_training(self): """Stop the training system""" self.is_training = False logger.info("Enhanced real-time training system stopped") + def store_model_prediction(self, model_name: str, symbol: str, prediction_type: str, + confidence: float, current_price: float) -> int: + """Store a model prediction in the database for tracking""" + try: + prediction_id = self.prediction_db.store_prediction( + model_name=model_name, + symbol=symbol, + prediction_type=prediction_type, + confidence=confidence, + price_at_prediction=current_price + ) + + # Track active prediction for later resolution + self.active_predictions[prediction_id] = { + "model_name": model_name, + "symbol": symbol, + "prediction_type": prediction_type, + "confidence": confidence, + "timestamp": time.time(), + "price_at_prediction": current_price + } + + logger.info(f"Stored prediction {prediction_id}: {model_name} -> {prediction_type} for {symbol} (conf: {confidence:.3f})") + return prediction_id + + except Exception as e: + logger.error(f"Error storing prediction: {e}") + return -1 + + def resolve_predictions(self): + """Resolve active predictions based on price movement""" + try: + current_time = time.time() + resolved_predictions = [] + + for prediction_id, pred_data in list(self.active_predictions.items()): + # Check if prediction is old enough to resolve + age = current_time - pred_data["timestamp"] + if age >= self.prediction_resolution_time: + + # Get current price for the symbol + symbol = pred_data["symbol"] + current_price = self._get_current_price(symbol) + + if current_price > 0: + # Calculate price change + price_change_pct = (current_price - pred_data["price_at_prediction"]) / pred_data["price_at_prediction"] + + # Calculate reward based on prediction correctness + reward = self._calculate_prediction_reward( + pred_data["prediction_type"], + price_change_pct, + pred_data["confidence"] + ) + + # Resolve the prediction + success = self.prediction_db.resolve_prediction( + prediction_id=prediction_id, + actual_price_change=price_change_pct, + reward=reward + ) + + if success: + logger.info(f"Resolved prediction {prediction_id}: {pred_data['model_name']} -> " + f"price change {price_change_pct:.3f}%, reward {reward:.3f}") + resolved_predictions.append(prediction_id) + + # Remove from active predictions + del self.active_predictions[prediction_id] + + return len(resolved_predictions) + + except Exception as e: + logger.error(f"Error resolving predictions: {e}") + return 0 + + def _get_current_price(self, symbol: str) -> float: + """Get current price for a symbol""" + try: + # Try to get from data provider + if self.data_provider and hasattr(self.data_provider, 'get_latest_data'): + latest = self.data_provider.get_latest_data(symbol) + if latest and 'close' in latest: + return float(latest['close']) + + # Try to get from orchestrator + if self.orchestrator and hasattr(self.orchestrator, '_get_current_price'): + return float(self.orchestrator._get_current_price(symbol)) + + # Fallback values + fallback_prices = {'ETH/USDT': 4300.0, 'BTC/USDT': 111000.0} + return fallback_prices.get(symbol, 1000.0) + + except Exception as e: + logger.debug(f"Error getting current price for {symbol}: {e}") + return 0.0 + + def _calculate_prediction_reward(self, prediction_type: str, price_change_pct: float, confidence: float) -> float: + """Calculate reward for a prediction based on outcome""" + try: + # Base reward calculation + if prediction_type == "BUY": + base_reward = price_change_pct * 100 # Positive if price went up + elif prediction_type == "SELL": + base_reward = -price_change_pct * 100 # Positive if price went down + elif prediction_type == "HOLD": + base_reward = max(0, 1 - abs(price_change_pct) * 100) # Positive if small movement + else: + base_reward = 0 + + # Confidence adjustment - reward high confidence correct predictions more + confidence_multiplier = 0.5 + (confidence * 1.5) # Range: 0.5 to 2.0 + + # Final reward calculation + final_reward = base_reward * confidence_multiplier + + # Normalize to reasonable range [-10, 10] + final_reward = max(-10, min(10, final_reward)) + + return final_reward + + except Exception as e: + logger.error(f"Error calculating prediction reward: {e}") + return 0.0 + + def get_model_performance_stats(self) -> Dict[str, Any]: + """Get performance statistics for all models""" + try: + stats = self.prediction_db.get_all_model_stats() + + # Add active predictions count + active_by_model = {} + for pred_data in self.active_predictions.values(): + model = pred_data["model_name"] + active_by_model[model] = active_by_model.get(model, 0) + 1 + + # Enhance stats with active predictions + for stat in stats: + model_name = stat["model_name"] + stat["active_predictions"] = active_by_model.get(model_name, 0) + + return { + "models": stats, + "total_active_predictions": len(self.active_predictions), + "last_updated": datetime.now().isoformat() + } + + except Exception as e: + logger.error(f"Error getting performance stats: {e}") + return {} + + def _prediction_resolution_worker(self): + """Worker thread to resolve active predictions""" + while self.is_training: + try: + # Resolve predictions every 30 seconds + resolved_count = self.resolve_predictions() + if resolved_count > 0: + logger.info(f"Resolved {resolved_count} predictions") + + time.sleep(30) + + except Exception as e: + logger.error(f"Error in prediction resolution worker: {e}") + time.sleep(60) + + + def _data_collection_worker(self): """Collect and preprocess real-time market data""" while self.is_training: diff --git a/core/orchestrator.py b/core/orchestrator.py index 0252f9c..3864e7a 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -1112,11 +1112,76 @@ class TradingOrchestrator: return predictions async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]: - """Get predictions from CNN model for all timeframes with enhanced COB features""" + """Get CNN predictions for multiple timeframes""" predictions = [] try: - for timeframe in self.config.timeframes: + # Get predictions for different timeframes + timeframes = ['1m', '5m', '1h'] + + for timeframe in timeframes: + try: + # Get features from data provider + features = self.data_provider.get_cnn_features_for_inference(symbol, timeframe, window_size=60) + + if features is not None and len(features) > 0: + # Get prediction from model + prediction_result = await model.predict(features) + + if prediction_result: + prediction = Prediction( + model_name=f"CNN_{timeframe}", + symbol=symbol, + signal=prediction_result.get('signal', 'HOLD'), + confidence=prediction_result.get('confidence', 0.0), + reasoning=f"CNN {timeframe} prediction", + features=features[:10].tolist() if len(features) > 10 else features.tolist(), + metadata={'timeframe': timeframe} + ) + predictions.append(prediction) + + # Store prediction in database for tracking + if (hasattr(self, 'enhanced_training_system') and + self.enhanced_training_system and + hasattr(self.enhanced_training_system, 'store_model_prediction')): + + current_price = self._get_current_price_safe(symbol) + if current_price > 0: + prediction_id = self.enhanced_training_system.store_model_prediction( + model_name=f"CNN_{timeframe}", + symbol=symbol, + prediction_type=prediction.signal, + confidence=prediction.confidence, + current_price=current_price + ) + logger.debug(f"Stored CNN prediction {prediction_id} for {symbol} {timeframe}") + + except Exception as e: + logger.debug(f"Error getting CNN prediction for {symbol} {timeframe}: {e}") + continue + + except Exception as e: + logger.error(f"Error in CNN predictions for {symbol}: {e}") + + return predictions + + def _get_current_price_safe(self, symbol: str) -> float: + """Safely get current price for a symbol""" + try: + # Try to get from data provider + if hasattr(self.data_provider, 'get_latest_data'): + latest = self.data_provider.get_latest_data(symbol) + if latest and 'close' in latest: + return float(latest['close']) + + # Fallback values + fallback_prices = {'ETH/USDT': 4300.0, 'BTC/USDT': 111000.0} + return fallback_prices.get(symbol, 1000.0) + + except Exception as e: + logger.debug(f"Error getting current price for {symbol}: {e}") + return 0.0 + # Get standard feature matrix for this timeframe feature_matrix = self.data_provider.get_feature_matrix( symbol=symbol, @@ -1259,11 +1324,58 @@ class TradingOrchestrator: action_idx, confidence = result raw_q_values = None else: - logger.error(f"Unexpected return format from act_with_confidence: {len(result)} values") + logger.warning(f"Unexpected result format from RL model: {result}") return None - elif hasattr(model.model, 'act'): - action_idx = model.model.act(state, explore=False) - confidence = 0.7 # Default confidence for basic act method + else: + # Fallback to standard act method + action_idx = model.model.act(state) + confidence = 0.6 # Default confidence + raw_q_values = None + + # Convert action index to action name + action_names = ['BUY', 'SELL', 'HOLD'] + if 0 <= action_idx < len(action_names): + action = action_names[action_idx] + else: + logger.warning(f"Invalid action index from RL model: {action_idx}") + return None + + # Store prediction in database for tracking + if (hasattr(self, 'enhanced_training_system') and + self.enhanced_training_system and + hasattr(self.enhanced_training_system, 'store_model_prediction')): + + current_price = self._get_current_price_safe(symbol) + if current_price > 0: + prediction_id = self.enhanced_training_system.store_model_prediction( + model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN", + symbol=symbol, + prediction_type=action, + confidence=confidence, + current_price=current_price + ) + logger.debug(f"Stored DQN prediction {prediction_id} for {symbol}") + + # Create prediction object + prediction = Prediction( + model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN", + symbol=symbol, + signal=action, + confidence=confidence, + reasoning=f"DQN agent prediction with Q-values: {raw_q_values}", + features=state.tolist() if isinstance(state, np.ndarray) else [], + metadata={ + 'action_idx': action_idx, + 'q_values': raw_q_values.tolist() if raw_q_values is not None else None, + 'state_size': len(state) if state is not None else 0 + } + ) + + return prediction + + except Exception as e: + logger.error(f"Error getting RL prediction for {symbol}: {e}") + return None raw_q_values = None # No raw q_values from simple act else: logger.error(f"RL model {model.name} has no act method") diff --git a/core/prediction_database.py b/core/prediction_database.py new file mode 100644 index 0000000..7a33a5b --- /dev/null +++ b/core/prediction_database.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 +""" +Prediction Database - Simple SQLite database for tracking model predictions +""" + +import sqlite3 +import logging +import json +from datetime import datetime, timedelta +from typing import Dict, List, Any, Optional +from pathlib import Path + +logger = logging.getLogger(__name__) + +class PredictionDatabase: + """Simple database for tracking model predictions and outcomes""" + + def __init__(self, db_path: str = "data/predictions.db"): + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._initialize_database() + logger.info(f"PredictionDatabase initialized: {self.db_path}") + + def _initialize_database(self): + """Initialize SQLite database""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + + # Predictions table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS predictions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + model_name TEXT NOT NULL, + symbol TEXT NOT NULL, + prediction_type TEXT NOT NULL, + confidence REAL NOT NULL, + timestamp TEXT NOT NULL, + price_at_prediction REAL NOT NULL, + + -- Outcome fields + outcome_timestamp TEXT, + actual_price_change REAL, + reward REAL, + is_correct INTEGER, + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Performance summary table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS model_performance ( + model_name TEXT PRIMARY KEY, + total_predictions INTEGER DEFAULT 0, + correct_predictions INTEGER DEFAULT 0, + total_reward REAL DEFAULT 0.0, + last_updated TEXT + ) + """) + + conn.commit() + + def store_prediction(self, model_name: str, symbol: str, prediction_type: str, + confidence: float, price_at_prediction: float) -> int: + """Store a new prediction""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + + timestamp = datetime.now().isoformat() + + cursor.execute(""" + INSERT INTO predictions ( + model_name, symbol, prediction_type, confidence, + timestamp, price_at_prediction + ) VALUES (?, ?, ?, ?, ?, ?) + """, (model_name, symbol, prediction_type, confidence, + timestamp, price_at_prediction)) + + prediction_id = cursor.lastrowid + + # Update performance count + cursor.execute(""" + INSERT OR REPLACE INTO model_performance ( + model_name, total_predictions, correct_predictions, total_reward, last_updated + ) VALUES ( + ?, + COALESCE((SELECT total_predictions FROM model_performance WHERE model_name = ?), 0) + 1, + COALESCE((SELECT correct_predictions FROM model_performance WHERE model_name = ?), 0), + COALESCE((SELECT total_reward FROM model_performance WHERE model_name = ?), 0.0), + ? + ) + """, (model_name, model_name, model_name, model_name, timestamp)) + + conn.commit() + return prediction_id + + def resolve_prediction(self, prediction_id: int, actual_price_change: float, reward: float) -> bool: + """Resolve a prediction with outcome""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + + # Get original prediction + cursor.execute(""" + SELECT model_name, prediction_type FROM predictions + WHERE id = ? AND outcome_timestamp IS NULL + """, (prediction_id,)) + + result = cursor.fetchone() + if not result: + return False + + model_name, prediction_type = result + + # Determine correctness + is_correct = self._is_prediction_correct(prediction_type, actual_price_change) + + # Update prediction + outcome_timestamp = datetime.now().isoformat() + cursor.execute(""" + UPDATE predictions SET + outcome_timestamp = ?, actual_price_change = ?, + reward = ?, is_correct = ? + WHERE id = ? + """, (outcome_timestamp, actual_price_change, reward, int(is_correct), prediction_id)) + + # Update performance + cursor.execute(""" + UPDATE model_performance SET + correct_predictions = correct_predictions + ?, + total_reward = total_reward + ?, + last_updated = ? + WHERE model_name = ? + """, (int(is_correct), reward, outcome_timestamp, model_name)) + + conn.commit() + return True + + def _is_prediction_correct(self, prediction_type: str, price_change: float) -> bool: + """Check if prediction was correct""" + if prediction_type == "BUY": + return price_change > 0 + elif prediction_type == "SELL": + return price_change < 0 + elif prediction_type == "HOLD": + return abs(price_change) < 0.001 + return False + + def get_model_stats(self, model_name: str) -> Dict[str, Any]: + """Get model performance statistics""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + + cursor.execute(""" + SELECT total_predictions, correct_predictions, total_reward + FROM model_performance WHERE model_name = ? + """, (model_name,)) + + result = cursor.fetchone() + if not result: + return {"model_name": model_name, "total_predictions": 0, "accuracy": 0.0, "total_reward": 0.0} + + total, correct, reward = result + accuracy = (correct / total) if total > 0 else 0.0 + + return { + "model_name": model_name, + "total_predictions": total, + "correct_predictions": correct, + "accuracy": accuracy, + "total_reward": reward + } + + def get_all_model_stats(self) -> List[Dict[str, Any]]: + """Get stats for all models""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + + cursor.execute(""" + SELECT model_name, total_predictions, correct_predictions, total_reward + FROM model_performance ORDER BY total_predictions DESC + """) + + stats = [] + for row in cursor.fetchall(): + model_name, total, correct, reward = row + accuracy = (correct / total) if total > 0 else 0.0 + stats.append({ + "model_name": model_name, + "total_predictions": total, + "correct_predictions": correct, + "accuracy": accuracy, + "total_reward": reward + }) + + return stats + +# Global instance +_prediction_db = None + +def get_prediction_db() -> PredictionDatabase: + """Get global prediction database""" + global _prediction_db + if _prediction_db is None: + _prediction_db = PredictionDatabase() + return _prediction_db diff --git a/data/predictions.db b/data/predictions.db new file mode 100644 index 0000000000000000000000000000000000000000..4c6332405b41455f7d2dc154e79894820d7d2e8b GIT binary patch literal 20480 zcmeI(Jx|*}7zc1W38aChL29KejdT(V5rn)63P=?-;L@l`py47llVxnvT1w(zU)qum zl`1;0_jB}9bnMCp=-h>&AE4)$5GUqERYeE%KPAHPJ$F9;o;xWAX75$mWCl5~?3T(% zQJ4`#QFum(AP5oup5X663GoB>+@OeqgKK_9gvEC!@px2-O!tJi8c$63?ti}~9#s&4 z00bZa0SG_<0uZ>P0-v|T(PTC&9zSL3e$!}Ky3tfRhTZ=v-qwOABAYc?qB5b? zH$FIrO}?Q~)Go97pOpZVk`J76w>kIRdA1vxOwP}XjebhbyK;=fuE7&B_Axw?jo9epFUsTVQ8c;U&SR!HT8?c`0`1lI^8K%c|R4oN`r|Sp9<9- zN~-m8IT-02wf3#1H({=h8jew6N1gNeb#=IAwGT|)RrJz*btE%ehQriW=jIWZ8qc$T z?%mY)$h2|SN+v^gx=ge90y!{^rtSnHEU#J_9ad0H)mXQ`eB424Z`Jl2?@oh}zF8lv zG#%x9Nez-5j0ucUvkiA;a0^096?u534YDSu;<9yPKiDQAkP#^#S z2tWV=5P$##AOHafKmY;|xJ7}U7#C8hhfygcMa5W=@*TXgS*^