cleanup
This commit is contained in:
@ -1,276 +0,0 @@
|
||||
"""
|
||||
CNN Dashboard Integration
|
||||
|
||||
This module integrates the EnhancedCNN model with the dashboard, providing real-time
|
||||
training and visualization of model predictions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
import os
|
||||
import json
|
||||
|
||||
from .enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
from .data_models import BaseDataInput, ModelOutput, create_model_output
|
||||
from utils.training_integration import get_training_integration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CNNDashboardIntegration:
    """
    Integrates the EnhancedCNN model with the dashboard

    This class:
    1. Loads and initializes the CNN model
    2. Processes real-time data for model inference
    3. Manages continuous training of the model
    4. Provides visualization data for the dashboard

    Continuous training runs on a daemon thread that is started with
    ``start_training_thread`` and stopped with ``stop_training_thread``.
    """

    def __init__(self, data_provider=None, checkpoint_dir: str = "models/enhanced_cnn"):
        """
        Initialize the CNN dashboard integration

        Args:
            data_provider: Data provider instance. When set, every model output
                produced by ``process_data`` is handed to its
                ``store_model_output`` method.
            checkpoint_dir: Directory to save checkpoints to (created if missing)
        """
        self.data_provider = data_provider
        self.checkpoint_dir = checkpoint_dir
        self.cnn_adapter = None
        self.training_thread = None
        self.training_active = False
        self.training_interval = 60  # Train every 60 seconds
        self.training_samples = []
        self.max_training_samples = 1000
        self.last_training_time = 0
        self.last_predictions = {}
        self.performance_metrics = {}
        self.model_name = "enhanced_cnn_v1"
        # Stop signal of the most recently started training thread; a fresh
        # event is created on every start_training_thread() call.
        self._stop_event = None

        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Initialize CNN adapter
        self._initialize_cnn_adapter()

        logger.info(f"CNNDashboardIntegration initialized with checkpoint_dir: {checkpoint_dir}")

    def _initialize_cnn_adapter(self):
        """
        Create the CNN adapter and load its best checkpoint.

        Any failure is logged and leaves ``self.cnn_adapter`` as ``None``;
        the public methods check for that and degrade to no-ops.
        """
        try:
            # Import here to avoid circular imports
            from .enhanced_cnn_adapter import EnhancedCNNAdapter

            # Create CNN adapter
            self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=self.checkpoint_dir)

            # Load best checkpoint if available
            self.cnn_adapter.load_best_checkpoint()

            logger.info("CNN adapter initialized successfully")

        except Exception as e:
            logger.error(f"Error initializing CNN adapter: {e}")
            self.cnn_adapter = None

    def start_training_thread(self):
        """Start the training thread (no-op if one is already running)"""
        if self.training_thread is not None and self.training_thread.is_alive():
            logger.info("Training thread already running")
            return

        # Each thread gets its own stop event. A previous loop that outlived
        # stop_training_thread()'s join timeout keeps its (already set) event
        # and therefore still terminates, even though training_active is
        # switched back on here.
        stop_event = threading.Event()
        self._stop_event = stop_event
        self.training_active = True
        self.training_thread = threading.Thread(
            target=self._training_loop, args=(stop_event,), daemon=True
        )
        self.training_thread.start()

        logger.info("CNN training thread started")

    def stop_training_thread(self):
        """
        Stop the training thread

        Signals the loop to exit and waits up to 5 seconds for it. If a
        training step is still in progress after that, the thread is left to
        finish on its own; its stop event stays set so it cannot resume.
        """
        self.training_active = False

        # getattr: tolerate instances whose __init__ was bypassed
        stop_event = getattr(self, "_stop_event", None)
        if stop_event is not None:
            stop_event.set()

        if self.training_thread is not None:
            self.training_thread.join(timeout=5)
            if self.training_thread.is_alive():
                logger.warning("CNN training thread still busy after 5s; it will exit after the current step")
            self.training_thread = None

        logger.info("CNN training thread stopped")

    def _training_loop(self, stop_event: Optional[threading.Event] = None):
        """
        Training loop for continuous model training

        Trains for one epoch whenever ``training_interval`` seconds have
        passed since the last run and at least 10 samples were recorded.

        Args:
            stop_event: Event that ends the loop when set. When omitted, the
                loop is governed by ``self.training_active`` alone.
        """
        if stop_event is None:
            # Never set by anyone: keeps the direct-call behaviour unchanged.
            stop_event = threading.Event()

        while self.training_active and not stop_event.is_set():
            try:
                # Check if it's time to train
                current_time = time.time()
                interval_elapsed = current_time - self.last_training_time >= self.training_interval
                if interval_elapsed and len(self.training_samples) >= 10:
                    logger.info(f"Training CNN model with {len(self.training_samples)} samples")

                    # Train model
                    if self.cnn_adapter is not None:
                        metrics = self.cnn_adapter.train(epochs=1)

                        # Update performance metrics
                        self.performance_metrics = {
                            'loss': metrics.get('loss', 0.0),
                            'accuracy': metrics.get('accuracy', 0.0),
                            'samples': metrics.get('samples', 0),
                            'last_training': datetime.now().isoformat()
                        }

                        # Log training metrics
                        logger.info(f"CNN training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")

                    # Update last training time
                    self.last_training_time = current_time

                # Pause to avoid high CPU usage; returns early on stop
                stop_event.wait(1)

            except Exception as e:
                logger.error(f"Error in CNN training loop: {e}")
                stop_event.wait(5)  # Pause longer on error

    def process_data(self, symbol: str, base_data: "BaseDataInput") -> Optional["ModelOutput"]:
        """
        Process data for model inference and training

        Args:
            symbol: Trading symbol
            base_data: Standardized input data

        Returns:
            Optional[ModelOutput]: Model output, or None if processing failed
        """
        try:
            if self.cnn_adapter is None:
                logger.warning("CNN adapter not initialized")
                return None

            # Make prediction
            model_output = self.cnn_adapter.predict(base_data)

            # Store prediction
            self.last_predictions[symbol] = model_output

            # Store model output in data provider
            if self.data_provider is not None:
                self.data_provider.store_model_output(model_output)

            return model_output

        except Exception as e:
            logger.error(f"Error processing data for CNN model: {e}")
            return None

    def add_training_sample(self, base_data: "BaseDataInput", actual_action: str, reward: float):
        """
        Add a training sample

        The sample is forwarded to the CNN adapter and a
        ``(symbol, action, reward)`` summary is kept locally, capped at
        ``max_training_samples`` entries (oldest dropped first).

        Args:
            base_data: Standardized input data
            actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
            reward: Reward received for the action
        """
        try:
            if self.cnn_adapter is None:
                logger.warning("CNN adapter not initialized")
                return

            # Add training sample to CNN adapter
            self.cnn_adapter.add_training_sample(base_data, actual_action, reward)

            # Add to local training samples
            self.training_samples.append((base_data.symbol, actual_action, reward))

            # Limit training samples
            if len(self.training_samples) > self.max_training_samples:
                self.training_samples = self.training_samples[-self.max_training_samples:]

            logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")

        except Exception as e:
            logger.error(f"Error adding training sample: {e}")

    def get_performance_metrics(self) -> Dict[str, Any]:
        """
        Get performance metrics

        Returns:
            Dict[str, Any]: Copy of the latest training metrics, extended with
            the local sample count, the model name and, per symbol, the last
            predicted action and confidence
        """
        metrics = self.performance_metrics.copy()

        # Add additional metrics
        metrics['training_samples'] = len(self.training_samples)
        metrics['model_name'] = self.model_name

        # Add last prediction metrics
        for symbol, prediction in self.last_predictions.items():
            metrics[f'{symbol}_last_action'] = prediction.predictions.get('action', 'UNKNOWN')
            metrics[f'{symbol}_last_confidence'] = prediction.confidence

        return metrics

    def get_visualization_data(self, symbol: str) -> Dict[str, Any]:
        """
        Get visualization data for the dashboard

        Args:
            symbol: Trading symbol

        Returns:
            Dict[str, Any]: Visualization data: model name, symbol, current
            timestamp, performance metrics, the last prediction for the symbol
            (when one exists) and a summary of its local training samples
        """
        data = {
            'model_name': self.model_name,
            'symbol': symbol,
            'timestamp': datetime.now().isoformat(),
            'performance_metrics': self.get_performance_metrics()
        }

        # Add last prediction
        if symbol in self.last_predictions:
            prediction = self.last_predictions[symbol]
            data['last_prediction'] = {
                'action': prediction.predictions.get('action', 'UNKNOWN'),
                'confidence': prediction.confidence,
                'timestamp': prediction.timestamp.isoformat(),
                'buy_probability': prediction.predictions.get('buy_probability', 0.0),
                'sell_probability': prediction.predictions.get('sell_probability', 0.0),
                'hold_probability': prediction.predictions.get('hold_probability', 0.0)
            }

        # Add training samples summary; samples are (symbol, action, reward)
        symbol_samples = [s for s in self.training_samples if s[0] == symbol]
        data['training_samples'] = {
            'total': len(symbol_samples),
            'buy': sum(1 for s in symbol_samples if s[1] == 'BUY'),
            'sell': sum(1 for s in symbol_samples if s[1] == 'SELL'),
            'hold': sum(1 for s in symbol_samples if s[1] == 'HOLD'),
            'avg_reward': sum(s[2] for s in symbol_samples) / len(symbol_samples) if symbol_samples else 0.0
        }

        return data
|
||||
|
||||
# Process-wide CNNDashboardIntegration singleton, built lazily on first access
_cnn_dashboard_integration = None


def get_cnn_dashboard_integration(data_provider=None) -> CNNDashboardIntegration:
    """
    Return the shared CNN dashboard integration, creating it on first use

    Args:
        data_provider: Data provider instance. Only consulted when the shared
            instance does not exist yet; later calls return the existing
            instance unchanged and ignore this argument.

    Returns:
        CNNDashboardIntegration: Global CNN dashboard integration instance
    """
    global _cnn_dashboard_integration

    instance = _cnn_dashboard_integration
    if instance is None:
        instance = CNNDashboardIntegration(data_provider=data_provider)
        _cnn_dashboard_integration = instance

    return instance
|
@ -101,9 +101,20 @@ class COBIntegration:
|
||||
|
||||
# Initialize COB provider as fallback
|
||||
try:
|
||||
# Create default exchange configs
|
||||
exchange_configs = {
|
||||
'binance': {
|
||||
'name': 'binance',
|
||||
'enabled': True,
|
||||
'websocket_url': 'wss://stream.binance.com:9443/ws/',
|
||||
'rest_api_url': 'https://api.binance.com/api/v3/',
|
||||
'rate_limits': {'requests_per_minute': 1200}
|
||||
}
|
||||
}
|
||||
|
||||
self.cob_provider = MultiExchangeCOBProvider(
|
||||
symbols=self.symbols,
|
||||
bucket_size_bps=1.0 # 1 basis point granularity
|
||||
exchange_configs=exchange_configs
|
||||
)
|
||||
|
||||
# Register callbacks
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -214,13 +214,8 @@ class TradingOrchestrator:
|
||||
# Training tracking
|
||||
self.last_trained_symbols: Dict[str, datetime] = {}
|
||||
|
||||
# INFERENCE DATA STORAGE - Per-model storage with memory optimization
|
||||
self.inference_history: Dict[str, deque] = {} # {model_name: deque of last 5 inference records}
|
||||
self.max_memory_inferences = 5 # Keep only last 5 inferences in memory per model
|
||||
self.max_disk_files_per_model = 200 # Cap disk files per model
|
||||
|
||||
# Initialize inference history for each model (will be populated as models make predictions)
|
||||
# We'll create entries dynamically as models are used
|
||||
# SIMPLIFIED INFERENCE DATA STORAGE - Single last inference per model
|
||||
self.last_inference: Dict[str, Dict] = {} # {model_name: last_inference_record}
|
||||
|
||||
# Initialize inference logger
|
||||
self.inference_logger = get_inference_logger()
|
||||
@ -240,10 +235,16 @@ class TradingOrchestrator:
|
||||
logger.info(f"Primary symbol: {self.symbol}, Reference symbols: {self.ref_symbols}")
|
||||
logger.info("Universal Data Adapter integrated for centralized data flow")
|
||||
|
||||
# Start centralized data collection for all models and dashboard
|
||||
logger.info("Starting centralized data collection...")
|
||||
self.data_provider.start_centralized_data_collection()
|
||||
logger.info("Centralized data collection started - all models and dashboard will receive data")
|
||||
# Start data collection if available
|
||||
logger.info("Starting data collection...")
|
||||
if hasattr(self.data_provider, 'start_centralized_data_collection'):
|
||||
self.data_provider.start_centralized_data_collection()
|
||||
logger.info("Centralized data collection started - all models and dashboard will receive data")
|
||||
elif hasattr(self.data_provider, 'start_training_data_collection'):
|
||||
self.data_provider.start_training_data_collection()
|
||||
logger.info("Training data collection started")
|
||||
else:
|
||||
logger.info("Data provider does not require explicit data collection startup")
|
||||
|
||||
# Data provider is already initialized and optimized
|
||||
|
||||
@ -683,13 +684,10 @@ class TradingOrchestrator:
|
||||
self.sensitivity_learning_queue = []
|
||||
self.perfect_move_buffer = []
|
||||
|
||||
# Clear inference history (but keep recent for training)
|
||||
for model_name in list(self.inference_history.keys()):
|
||||
# Keep only the last inference for each model to maintain training capability
|
||||
if len(self.inference_history[model_name]) > 1:
|
||||
last_inference = self.inference_history[model_name][-1]
|
||||
self.inference_history[model_name].clear()
|
||||
self.inference_history[model_name].append(last_inference)
|
||||
# Clear any outcome evaluation flags for last inferences
|
||||
for model_name in self.last_inference:
|
||||
if self.last_inference[model_name]:
|
||||
self.last_inference[model_name]['outcome_evaluated'] = False
|
||||
|
||||
# Clear fusion training data
|
||||
self.fusion_training_data = []
|
||||
@ -1114,10 +1112,10 @@ class TradingOrchestrator:
|
||||
if model.name not in self.model_performance:
|
||||
self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
|
||||
|
||||
# Initialize inference history for this model
|
||||
if model.name not in self.inference_history:
|
||||
self.inference_history[model.name] = deque(maxlen=self.max_memory_inferences)
|
||||
logger.debug(f"Initialized inference history for {model.name}")
|
||||
# Initialize last inference storage for this model
|
||||
if model.name not in self.last_inference:
|
||||
self.last_inference[model.name] = None
|
||||
logger.debug(f"Initialized last inference storage for {model.name}")
|
||||
|
||||
logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
|
||||
self._normalize_weights()
|
||||
@ -1320,12 +1318,7 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error getting prediction from {model_name}: {e}")
|
||||
continue
|
||||
|
||||
# Debug: Log inference history status (only if low record count)
|
||||
total_records = sum(len(history) for history in self.inference_history.values())
|
||||
if total_records < 10: # Only log when we have few records
|
||||
logger.debug(f"Total inference records across all models: {total_records}")
|
||||
for model_name, history in self.inference_history.items():
|
||||
logger.debug(f" {model_name}: {len(history)} records")
|
||||
|
||||
|
||||
# Trigger training based on previous inference data
|
||||
await self._trigger_model_training(symbol)
|
||||
@ -1392,17 +1385,15 @@ class TradingOrchestrator:
|
||||
return {}
|
||||
|
||||
async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime, symbol: str = None):
|
||||
"""Store inference data per-model with async file operations and memory optimization"""
|
||||
"""Store last inference in memory and all inferences to database for future training"""
|
||||
try:
|
||||
# Only log first few inference records to avoid spam
|
||||
if len(self.inference_history.get(model_name, [])) < 3:
|
||||
logger.debug(f"Storing inference data for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
|
||||
logger.debug(f"Storing inference for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
|
||||
|
||||
# Extract symbol from prediction if not provided
|
||||
if symbol is None:
|
||||
symbol = getattr(prediction, 'symbol', 'ETH/USDT') # Default to ETH/USDT if not available
|
||||
|
||||
# Create comprehensive inference record
|
||||
# Create inference record - store only what's needed for training
|
||||
inference_record = {
|
||||
'timestamp': timestamp.isoformat(),
|
||||
'symbol': symbol,
|
||||
@ -1414,227 +1405,153 @@ class TradingOrchestrator:
|
||||
'probabilities': prediction.probabilities,
|
||||
'timeframe': prediction.timeframe
|
||||
},
|
||||
'metadata': prediction.metadata or {}
|
||||
'metadata': prediction.metadata or {},
|
||||
'training_outcome': None, # Will be set when training occurs
|
||||
'outcome_evaluated': False
|
||||
}
|
||||
|
||||
# Store in memory (only last 5 per model)
|
||||
if model_name not in self.inference_history:
|
||||
self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
|
||||
# Store only the last inference per model (for immediate training)
|
||||
self.last_inference[model_name] = inference_record
|
||||
|
||||
self.inference_history[model_name].append(inference_record)
|
||||
# Also save to database using database manager for future training and analysis
|
||||
asyncio.create_task(self._save_to_database_manager_async(model_name, inference_record))
|
||||
|
||||
# Async file storage (don't wait for completion)
|
||||
asyncio.create_task(self._save_inference_to_disk_async(model_name, inference_record))
|
||||
|
||||
logger.debug(f"Stored inference data for {model_name} (memory: {len(self.inference_history[model_name])}/5)")
|
||||
logger.debug(f"Stored last inference for {model_name} and queued database save")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing inference data for {model_name}: {e}")
|
||||
|
||||
async def _save_inference_to_disk_async(self, model_name: str, inference_record: Dict):
|
||||
"""Async save inference record to SQLite database and model-specific log"""
|
||||
try:
|
||||
# Use SQLite for comprehensive storage
|
||||
await self._save_to_sqlite_db(model_name, inference_record)
|
||||
|
||||
# Also save key metrics to model-specific log for debugging
|
||||
await self._save_to_model_log(model_name, inference_record)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving inference to disk for {model_name}: {e}")
|
||||
|
||||
async def _save_to_sqlite_db(self, model_name: str, inference_record: Dict):
|
||||
"""Save inference record to SQLite database"""
|
||||
import sqlite3
|
||||
async def _save_to_database_manager_async(self, model_name: str, inference_record: Dict):
|
||||
"""Save inference record using DatabaseManager for future training"""
|
||||
import hashlib
|
||||
import asyncio
|
||||
|
||||
def save_to_db():
|
||||
try:
|
||||
# Create database directory
|
||||
db_dir = Path("training_data/inference_db")
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Connect to SQLite database
|
||||
db_path = db_dir / "inference_history.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create table if it doesn't exist
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS inference_records (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
model_name TEXT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
confidence REAL NOT NULL,
|
||||
probabilities TEXT,
|
||||
timeframe TEXT,
|
||||
metadata TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# Create index for faster queries
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_model_timestamp
|
||||
ON inference_records(model_name, timestamp)
|
||||
''')
|
||||
|
||||
# Extract data from inference record
|
||||
prediction = inference_record.get('prediction', {})
|
||||
probabilities_str = str(prediction.get('probabilities', {}))
|
||||
metadata_str = str(inference_record.get('metadata', {}))
|
||||
symbol = inference_record.get('symbol', 'ETH/USDT')
|
||||
timestamp_str = inference_record.get('timestamp', '')
|
||||
|
||||
# Insert record
|
||||
cursor.execute('''
|
||||
INSERT INTO inference_records
|
||||
(model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
model_name,
|
||||
inference_record.get('symbol', 'ETH/USDT'),
|
||||
inference_record.get('timestamp', ''),
|
||||
prediction.get('action', 'HOLD'),
|
||||
prediction.get('confidence', 0.0),
|
||||
probabilities_str,
|
||||
prediction.get('timeframe', '1m'),
|
||||
metadata_str
|
||||
))
|
||||
# Parse timestamp
|
||||
if isinstance(timestamp_str, str):
|
||||
timestamp = datetime.fromisoformat(timestamp_str)
|
||||
else:
|
||||
timestamp = timestamp_str
|
||||
|
||||
# Clean up old records (keep only last 1000 per model)
|
||||
cursor.execute('''
|
||||
DELETE FROM inference_records
|
||||
WHERE model_name = ? AND id NOT IN (
|
||||
SELECT id FROM inference_records
|
||||
WHERE model_name = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1000
|
||||
)
|
||||
''', (model_name, model_name))
|
||||
# Create hash of input features for deduplication
|
||||
model_input = inference_record.get('model_input')
|
||||
input_features_hash = "unknown"
|
||||
input_features_array = None
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
if model_input is not None:
|
||||
# Convert to numpy array if possible
|
||||
try:
|
||||
if hasattr(model_input, 'numpy'): # PyTorch tensor
|
||||
input_features_array = model_input.detach().cpu().numpy()
|
||||
elif isinstance(model_input, np.ndarray):
|
||||
input_features_array = model_input
|
||||
elif isinstance(model_input, (list, tuple)):
|
||||
input_features_array = np.array(model_input)
|
||||
|
||||
# Create hash of the input features
|
||||
if input_features_array is not None:
|
||||
input_features_hash = hashlib.md5(input_features_array.tobytes()).hexdigest()[:16]
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not process input features for hashing: {e}")
|
||||
|
||||
# Create InferenceRecord using the database manager's structure
|
||||
from utils.database_manager import InferenceRecord
|
||||
|
||||
db_record = InferenceRecord(
|
||||
model_name=model_name,
|
||||
timestamp=timestamp,
|
||||
symbol=symbol,
|
||||
action=prediction.get('action', 'HOLD'),
|
||||
confidence=prediction.get('confidence', 0.0),
|
||||
probabilities=prediction.get('probabilities', {}),
|
||||
input_features_hash=input_features_hash,
|
||||
processing_time_ms=0.0, # We don't track this in orchestrator
|
||||
memory_usage_mb=0.0, # We don't track this in orchestrator
|
||||
input_features=input_features_array,
|
||||
checkpoint_id=None,
|
||||
metadata=inference_record.get('metadata', {})
|
||||
)
|
||||
|
||||
# Log using database manager
|
||||
success = self.db_manager.log_inference(db_record)
|
||||
|
||||
if success:
|
||||
logger.debug(f"Saved inference to database for {model_name}")
|
||||
else:
|
||||
logger.warning(f"Failed to save inference to database for {model_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving to SQLite database: {e}")
|
||||
logger.error(f"Error saving to database manager: {e}")
|
||||
|
||||
# Run database operation in thread pool to avoid blocking
|
||||
await asyncio.get_event_loop().run_in_executor(None, save_to_db)
|
||||
|
||||
async def _save_to_model_log(self, model_name: str, inference_record: Dict):
|
||||
"""Save key inference metrics to model-specific log file for debugging"""
|
||||
import asyncio
|
||||
|
||||
def save_to_log():
|
||||
try:
|
||||
# Create logs directory
|
||||
logs_dir = Path("logs/model_inference")
|
||||
logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create model-specific log file
|
||||
log_file = logs_dir / f"{model_name}_inference.log"
|
||||
|
||||
# Extract key metrics
|
||||
prediction = inference_record.get('prediction', {})
|
||||
timestamp = inference_record.get('timestamp', '')
|
||||
symbol = inference_record.get('symbol', 'N/A')
|
||||
|
||||
# Format log entry with key metrics
|
||||
log_entry = (
|
||||
f"{timestamp} | "
|
||||
f"Symbol: {symbol} | "
|
||||
f"Action: {prediction.get('action', 'N/A'):4} | "
|
||||
f"Confidence: {prediction.get('confidence', 0.0):6.3f} | "
|
||||
f"Timeframe: {prediction.get('timeframe', 'N/A'):3} | "
|
||||
f"Probs: BUY={prediction.get('probabilities', {}).get('BUY', 0.0):5.3f} "
|
||||
f"SELL={prediction.get('probabilities', {}).get('SELL', 0.0):5.3f} "
|
||||
f"HOLD={prediction.get('probabilities', {}).get('HOLD', 0.0):5.3f}\n"
|
||||
)
|
||||
|
||||
# Append to log file
|
||||
with open(log_file, 'a', encoding='utf-8') as f:
|
||||
f.write(log_entry)
|
||||
|
||||
# Keep log files manageable (rotate when > 10MB)
|
||||
if log_file.stat().st_size > 10 * 1024 * 1024: # 10MB
|
||||
self._rotate_log_file(log_file)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving to model log: {e}")
|
||||
|
||||
# Run log operation in thread pool to avoid blocking
|
||||
await asyncio.get_event_loop().run_in_executor(None, save_to_log)
|
||||
|
||||
|
||||
def _rotate_log_file(self, log_file: Path):
|
||||
"""Rotate log file when it gets too large"""
|
||||
try:
|
||||
# Keep last 1000 lines
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Write back only the last 1000 lines
|
||||
with open(log_file, 'w', encoding='utf-8') as f:
|
||||
f.writelines(lines[-1000:])
|
||||
|
||||
logger.debug(f"Rotated log file {log_file.name} (kept last 1000 lines)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error rotating log file {log_file}: {e}")
|
||||
|
||||
def get_inference_records_from_db(self, model_name: str = None, limit: int = 100) -> List[Dict]:
|
||||
"""Get inference records from SQLite database"""
|
||||
import sqlite3
|
||||
|
||||
try:
|
||||
# Connect to database
|
||||
db_path = Path("training_data/inference_db/inference_history.db")
|
||||
if not db_path.exists():
|
||||
return []
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Query records
|
||||
if model_name:
|
||||
cursor.execute('''
|
||||
SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
|
||||
FROM inference_records
|
||||
WHERE model_name = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
''', (model_name, limit))
|
||||
else:
|
||||
cursor.execute('''
|
||||
SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
|
||||
FROM inference_records
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
''', (limit,))
|
||||
|
||||
records = []
|
||||
for row in cursor.fetchall():
|
||||
record = {
|
||||
'model_name': row[0],
|
||||
'symbol': row[1],
|
||||
'timestamp': row[2],
|
||||
'prediction': {
|
||||
'action': row[3],
|
||||
'confidence': row[4],
|
||||
'probabilities': eval(row[5]) if row[5] else {},
|
||||
'timeframe': row[6]
|
||||
},
|
||||
'metadata': eval(row[7]) if row[7] else {}
|
||||
|
||||
def get_last_inference_status(self) -> Dict[str, Any]:
|
||||
"""Get status of last inferences for all models"""
|
||||
status = {}
|
||||
for model_name, inference in self.last_inference.items():
|
||||
if inference:
|
||||
status[model_name] = {
|
||||
'timestamp': inference.get('timestamp'),
|
||||
'symbol': inference.get('symbol'),
|
||||
'action': inference.get('prediction', {}).get('action'),
|
||||
'confidence': inference.get('prediction', {}).get('confidence'),
|
||||
'outcome_evaluated': inference.get('outcome_evaluated', False),
|
||||
'training_outcome': inference.get('training_outcome')
|
||||
}
|
||||
records.append(record)
|
||||
else:
|
||||
status[model_name] = None
|
||||
return status
|
||||
|
||||
def get_training_data_from_db(self, model_name: str, symbol: str = None, hours_back: int = 24, limit: int = 1000) -> List[Dict]:
|
||||
"""Get inference records for training from database manager"""
|
||||
try:
|
||||
# Use database manager's method specifically for training data
|
||||
db_records = self.db_manager.get_inference_records_for_training(
|
||||
model_name=model_name,
|
||||
symbol=symbol,
|
||||
hours_back=hours_back,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
conn.close()
|
||||
# Convert to our format
|
||||
records = []
|
||||
for db_record in db_records:
|
||||
try:
|
||||
record = {
|
||||
'model_name': db_record.model_name,
|
||||
'symbol': db_record.symbol,
|
||||
'timestamp': db_record.timestamp.isoformat(),
|
||||
'prediction': {
|
||||
'action': db_record.action,
|
||||
'confidence': db_record.confidence,
|
||||
'probabilities': db_record.probabilities,
|
||||
'timeframe': '1m'
|
||||
},
|
||||
'metadata': db_record.metadata or {},
|
||||
'model_input': db_record.input_features, # Full input features for training
|
||||
'input_features_hash': db_record.input_features_hash
|
||||
}
|
||||
records.append(record)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping malformed training record: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Retrieved {len(records)} training records for {model_name}")
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error querying SQLite database: {e}")
|
||||
logger.error(f"Error getting training data from database: {e}")
|
||||
return []
|
||||
|
||||
|
||||
|
||||
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
|
||||
@ -1763,197 +1680,58 @@ class TradingOrchestrator:
|
||||
'outcome_evaluated': False
|
||||
}
|
||||
|
||||
# Store in memory (inference history) - keyed by model_name
|
||||
if model_name not in self.inference_history:
|
||||
self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
|
||||
# Store only the last inference per model (for immediate training)
|
||||
self.last_inference[model_name] = inference_record
|
||||
|
||||
self.inference_history[model_name].append(inference_record)
|
||||
logger.debug(f"Stored inference data for {model_name} on {symbol}")
|
||||
# Also save to database using database manager for future training (run in background)
|
||||
asyncio.create_task(self._save_to_database_manager_async(model_name, inference_record))
|
||||
|
||||
# Persistent storage to disk (for long-term training data)
|
||||
self._save_inference_to_disk(inference_record)
|
||||
logger.debug(f"Stored last inference for {model_name} on {symbol} and queued database save")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing inference data: {e}")
|
||||
|
||||
def _save_inference_to_disk(self, inference_record: Dict):
|
||||
"""Save inference record to persistent storage"""
|
||||
try:
|
||||
# Create inference data directory
|
||||
inference_dir = Path("training_data/inference_history")
|
||||
inference_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create filename with timestamp and model name
|
||||
timestamp_str = inference_record['timestamp'].strftime('%Y%m%d_%H%M%S')
|
||||
filename = f"{inference_record['symbol']}_{inference_record['model_name']}_{timestamp_str}.json"
|
||||
filepath = inference_dir / filename
|
||||
|
||||
# Convert numpy arrays to lists for JSON serialization
|
||||
serializable_record = self._make_json_serializable(inference_record)
|
||||
|
||||
# Save to JSON file
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(serializable_record, f, indent=2)
|
||||
|
||||
logger.debug(f"Saved inference record to disk: {filepath}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving inference to disk: {e}")
|
||||
|
||||
|
||||
|
||||
def _make_json_serializable(self, obj):
    """Recursively convert ``obj`` into a structure ``json.dump`` accepts.

    Conversions applied:
      * dict          -> dict with every value converted (keys untouched)
      * list / tuple  -> list with every item converted
      * numpy array   -> nested Python lists (``ndarray.tolist``)
      * numpy integer / floating / bool scalar -> the matching Python scalar
      * datetime      -> ISO-8601 string

    Anything else is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) value.

    Returns:
        The converted value.
    """
    if isinstance(obj, dict):
        return {k: self._make_json_serializable(v) for k, v in obj.items()}
    # Tuples are walked as well: json would emit them as arrays anyway, but
    # numpy values nested inside them would otherwise slip through and fail.
    if isinstance(obj, (list, tuple)):
        return [self._make_json_serializable(item) for item in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # np.bool_ is neither np.integer nor np.floating, so it is listed explicitly.
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    if isinstance(obj, datetime):
        return obj.isoformat()
    return obj
|
||||
|
||||
def load_inference_history_from_disk(self, symbol: str, days_back: int = 7) -> List[Dict]:
    """Load recent inference records for ``symbol`` from the SQLite database.

    Reads ``training_data/inference_db/inference_history.db`` and returns the
    rows of ``inference_records`` for ``symbol`` whose timestamp is not older
    than ``days_back`` days, oldest first.

    Args:
        symbol: Symbol to filter on.
        days_back: Size of the look-back window in days.

    Returns:
        A list of dicts with keys 'model_name', 'symbol', 'timestamp',
        'prediction' (action / confidence / probabilities / timeframe) and
        'metadata'. An empty list is returned when the database file does not
        exist or when any error occurs (the error is logged).
    """
    import ast
    import sqlite3
    from contextlib import closing

    def _decode_mapping(raw):
        """Parse a stored mapping column without executing its content."""
        if not raw:
            return {}
        # SECURITY: this column used to be passed to eval(), which executes
        # whatever text is stored in the database. Only JSON and Python
        # literals are accepted now.
        try:
            return json.loads(raw)
        except (TypeError, ValueError):
            pass
        try:
            return ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError):
            logger.warning(f"Ignoring unparsable stored value in inference record: {raw!r}")
            return {}

    try:
        db_path = Path("training_data/inference_db/inference_history.db")
        if not db_path.exists():
            return []

        # Timestamps are compared as ISO-8601 text, so the cutoff is too.
        cutoff_date = (datetime.now() - timedelta(days=days_back)).isoformat()

        # closing() guarantees the connection is released even when the query
        # or the row decoding raises.
        with closing(sqlite3.connect(str(db_path))) as conn:
            rows = conn.execute('''
                SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
                FROM inference_records
                WHERE symbol = ? AND timestamp >= ?
                ORDER BY timestamp ASC
            ''', (symbol, cutoff_date)).fetchall()

        inference_records = [
            {
                'model_name': row[0],
                'symbol': row[1],
                'timestamp': row[2],
                'prediction': {
                    'action': row[3],
                    'confidence': row[4],
                    'probabilities': _decode_mapping(row[5]),
                    'timeframe': row[6],
                },
                'metadata': _decode_mapping(row[7]),
            }
            for row in rows
        ]

        logger.info(f"Loaded {len(inference_records)} inference records for {symbol} from SQLite database")
        return inference_records

    except Exception as e:
        logger.error(f"Error loading inference history from database: {e}")
        return []
|
||||
|
||||
async def load_model_inference_history(self, model_name: str, limit: int = 50) -> List[Dict]:
    """Fetch stored inference records for one model.

    Delegates to ``self.get_inference_records_from_db`` and passes ``limit``
    straight through. If anything raises, the error is logged and an empty
    list is returned instead.

    Args:
        model_name: Name of the model whose records are requested.
        limit: Forwarded unchanged to the database lookup.

    Returns:
        Whatever the database lookup returned, or ``[]`` on failure.
    """
    try:
        history = self.get_inference_records_from_db(model_name, limit)
        record_count = len(history)
        logger.info(f"Loaded {record_count} inference records for {model_name} from database")
    except Exception as e:
        logger.error(f"Error loading model inference history for {model_name}: {e}")
        return []
    return history
|
||||
|
||||
|
||||
def get_model_training_data(self, model_name: str, symbol: str = None) -> List[Dict]:
    """Collect de-duplicated, time-ordered training records for one model.

    Records are gathered from three places, in this order:
      1. ``self.get_training_data_from_db(model_name, symbol)``
      2. the in-memory ``self.inference_history``
      3. ``self.load_inference_history_from_disk`` for every symbol checked

    Args:
        model_name: Model whose records are wanted.
        symbol: Restrict to this symbol; when falsy, every entry of
            ``self.symbols`` is checked.

    Returns:
        Records unique per (timestamp, symbol), sorted by timestamp ascending.
        ``[]`` is returned if anything raises (the error is logged).
    """
    def _timestamp_of(record):
        # Database rows carry ISO strings while live records carry datetime
        # objects; normalise so they can be compared and de-duplicated.
        value = record['timestamp']
        return datetime.fromisoformat(value) if isinstance(value, str) else value

    try:
        symbols_to_check = [symbol] if symbol else self.symbols

        # NOTE(review): assumes get_training_data_from_db returns an iterable
        # of record dicts (or None) - confirm against its definition.
        training_data = list(self.get_training_data_from_db(model_name, symbol) or [])

        # inference_history is keyed by model name elsewhere in this class;
        # symbol keys are still probed so older symbol-keyed data is found.
        for key in dict.fromkeys((model_name, *symbols_to_check)):
            for record in self.inference_history.get(key, ()):
                if record['model_name'] == model_name and record['symbol'] in symbols_to_check:
                    training_data.append(record)

        # Persistent history adds records that are no longer held in memory.
        for sym in symbols_to_check:
            for record in self.load_inference_history_from_disk(sym):
                if record['model_name'] == model_name:
                    training_data.append(record)

        # Keep the first occurrence of every (timestamp, symbol) pair.
        seen = set()
        unique_data = []
        for record in training_data:
            identity = (_timestamp_of(record), record['symbol'])
            if identity not in seen:
                seen.add(identity)
                unique_data.append(record)

        unique_data.sort(key=_timestamp_of)
        logger.info(f"Retrieved {len(unique_data)} training records for {model_name}")
        return unique_data

    except Exception as e:
        logger.error(f"Error getting model training data: {e}")
        return []
|
||||
|
||||
async def _trigger_model_training(self, symbol: str):
|
||||
# Coroutine: evaluates stored inference records against the current price of
# `symbol` and hands them to self._evaluate_and_train_on_record. Returns nothing.
# NOTE(review): this span is a rendered diff - removed and added lines are
# interleaved and indentation is stripped, so the two docstrings and the two
# "no data" guards below appear to come from different revisions. Confirm the
# real nesting against the actual file before relying on these notes.
"""Trigger training for models based on previous inference data"""
|
||||
"""Trigger training for models based on their last inference"""
|
||||
try:
|
||||
# Nothing is done while the training_enabled flag is off.
if not self.training_enabled:
|
||||
logger.debug("Training disabled, skipping model training")
|
||||
return
|
||||
|
||||
# Check if we have any inference history for any model
|
||||
if not self.inference_history:
|
||||
logger.debug("No inference history available for training")
|
||||
# Check if we have any last inferences for any model
|
||||
if not self.last_inference:
|
||||
logger.debug("No inference data available for training")
|
||||
return
|
||||
|
||||
# Get recent inference records from all models (not symbol-based)
|
||||
# Every model's records in self.inference_history are flattened into one list.
all_recent_records = []
|
||||
for model_name, model_records in self.inference_history.items():
|
||||
all_recent_records.extend(list(model_records))
|
||||
|
||||
# Only log if we have few records (for debugging)
|
||||
if len(all_recent_records) < 5:
|
||||
logger.debug(f"Total inference records for training: {len(all_recent_records)}")
|
||||
for model_name, model_records in self.inference_history.items():
|
||||
logger.debug(f" Model {model_name} has {len(model_records)} inference records")
|
||||
|
||||
if len(all_recent_records) < 2:
|
||||
logger.debug("Not enough inference records for training")
|
||||
return # Need at least 2 records to compare
|
||||
|
||||
# Get current price for outcome evaluation
|
||||
current_price = self.data_provider.get_current_price(symbol)
|
||||
# A None price means there is nothing to evaluate against, so stop here.
if current_price is None:
|
||||
return
|
||||
|
||||
# Train on the most recent inference record (last prediction made)
|
||||
if all_recent_records:
|
||||
# Get the most recent record for training
|
||||
# String timestamps are parsed with fromisoformat before max() compares them;
# any other timestamp value is compared as-is.
most_recent_record = max(all_recent_records, key=lambda x: datetime.fromisoformat(x['timestamp']) if isinstance(x['timestamp'], str) else x['timestamp'])
|
||||
await self._evaluate_and_train_on_record(most_recent_record, current_price)
|
||||
# Train each model based on its last inference
|
||||
for model_name, last_inference_record in self.last_inference.items():
|
||||
# Empty records and records already flagged 'outcome_evaluated' are skipped.
if last_inference_record and not last_inference_record.get('outcome_evaluated', False):
|
||||
await self._evaluate_and_train_on_record(last_inference_record, current_price)
|
||||
|
||||
# Any failure is logged and swallowed; nothing is re-raised to the caller.
except Exception as e:
|
||||
logger.error(f"Error triggering model training for {symbol}: {e}")
|
||||
@ -2011,6 +1789,16 @@ class TradingOrchestrator:
|
||||
# Train the specific model based on sophisticated outcome
|
||||
await self._train_model_on_outcome(record, was_correct, price_change_pct, reward)
|
||||
|
||||
# Mark this inference as evaluated to prevent re-training
|
||||
if model_name in self.last_inference and self.last_inference[model_name] == record:
|
||||
self.last_inference[model_name]['outcome_evaluated'] = True
|
||||
self.last_inference[model_name]['training_outcome'] = {
|
||||
'was_correct': was_correct,
|
||||
'reward': reward,
|
||||
'price_change_pct': price_change_pct,
|
||||
'evaluated_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
logger.debug(f"Evaluated {model_name} prediction: {'✓' if was_correct else '✗'} "
|
||||
f"({prediction['action']}, {price_change_pct:.2f}% change, "
|
||||
f"confidence: {prediction_confidence:.3f}, reward: {reward:.3f})")
|
||||
@ -2215,7 +2003,7 @@ class TradingOrchestrator:
|
||||
)
|
||||
predictions.append(prediction)
|
||||
|
||||
# Store prediction in SQLite database for training
|
||||
# Store prediction in database for training
|
||||
logger.debug(f"Added CNN prediction to database: {prediction}")
|
||||
|
||||
# Note: Inference data will be stored in main prediction loop to avoid duplication
|
||||
|
Reference in New Issue
Block a user