This commit is contained in:
Dobromir Popov
2025-07-27 18:31:30 +03:00
parent a94b80c1f4
commit e2c495d83c
15 changed files with 3883 additions and 3335 deletions

View File

@ -1,276 +0,0 @@
"""
CNN Dashboard Integration
This module integrates the EnhancedCNN model with the dashboard, providing real-time
training and visualization of model predictions.
"""
import logging
import threading
import time
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
import os
import json
from .enhanced_cnn_adapter import EnhancedCNNAdapter
from .data_models import BaseDataInput, ModelOutput, create_model_output
from utils.training_integration import get_training_integration
logger = logging.getLogger(__name__)
class CNNDashboardIntegration:
"""
Integrates the EnhancedCNN model with the dashboard
This class:
1. Loads and initializes the CNN model
2. Processes real-time data for model inference
3. Manages continuous training of the model
4. Provides visualization data for the dashboard
"""
def __init__(self, data_provider=None, checkpoint_dir: str = "models/enhanced_cnn"):
"""
Initialize the CNN dashboard integration
Args:
data_provider: Data provider instance
checkpoint_dir: Directory to save checkpoints to
"""
self.data_provider = data_provider
self.checkpoint_dir = checkpoint_dir
self.cnn_adapter = None
self.training_thread = None
self.training_active = False
self.training_interval = 60 # Train every 60 seconds
self.training_samples = []
self.max_training_samples = 1000
self.last_training_time = 0
self.last_predictions = {}
self.performance_metrics = {}
self.model_name = "enhanced_cnn_v1"
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
# Initialize CNN adapter
self._initialize_cnn_adapter()
logger.info(f"CNNDashboardIntegration initialized with checkpoint_dir: {checkpoint_dir}")
def _initialize_cnn_adapter(self):
"""Initialize the CNN adapter"""
try:
# Import here to avoid circular imports
from .enhanced_cnn_adapter import EnhancedCNNAdapter
# Create CNN adapter
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir=self.checkpoint_dir)
# Load best checkpoint if available
self.cnn_adapter.load_best_checkpoint()
logger.info("CNN adapter initialized successfully")
except Exception as e:
logger.error(f"Error initializing CNN adapter: {e}")
self.cnn_adapter = None
def start_training_thread(self):
"""Start the training thread"""
if self.training_thread is not None and self.training_thread.is_alive():
logger.info("Training thread already running")
return
self.training_active = True
self.training_thread = threading.Thread(target=self._training_loop, daemon=True)
self.training_thread.start()
logger.info("CNN training thread started")
def stop_training_thread(self):
"""Stop the training thread"""
self.training_active = False
if self.training_thread is not None:
self.training_thread.join(timeout=5)
self.training_thread = None
logger.info("CNN training thread stopped")
def _training_loop(self):
"""Training loop for continuous model training"""
while self.training_active:
try:
# Check if it's time to train
current_time = time.time()
if current_time - self.last_training_time >= self.training_interval and len(self.training_samples) >= 10:
logger.info(f"Training CNN model with {len(self.training_samples)} samples")
# Train model
if self.cnn_adapter is not None:
metrics = self.cnn_adapter.train(epochs=1)
# Update performance metrics
self.performance_metrics = {
'loss': metrics.get('loss', 0.0),
'accuracy': metrics.get('accuracy', 0.0),
'samples': metrics.get('samples', 0),
'last_training': datetime.now().isoformat()
}
# Log training metrics
logger.info(f"CNN training metrics: loss={metrics.get('loss', 0.0):.4f}, accuracy={metrics.get('accuracy', 0.0):.4f}")
# Update last training time
self.last_training_time = current_time
# Sleep to avoid high CPU usage
time.sleep(1)
except Exception as e:
logger.error(f"Error in CNN training loop: {e}")
time.sleep(5) # Sleep longer on error
def process_data(self, symbol: str, base_data: BaseDataInput) -> Optional[ModelOutput]:
"""
Process data for model inference and training
Args:
symbol: Trading symbol
base_data: Standardized input data
Returns:
Optional[ModelOutput]: Model output, or None if processing failed
"""
try:
if self.cnn_adapter is None:
logger.warning("CNN adapter not initialized")
return None
# Make prediction
model_output = self.cnn_adapter.predict(base_data)
# Store prediction
self.last_predictions[symbol] = model_output
# Store model output in data provider
if self.data_provider is not None:
self.data_provider.store_model_output(model_output)
return model_output
except Exception as e:
logger.error(f"Error processing data for CNN model: {e}")
return None
def add_training_sample(self, base_data: BaseDataInput, actual_action: str, reward: float):
"""
Add a training sample
Args:
base_data: Standardized input data
actual_action: Actual action taken ('BUY', 'SELL', 'HOLD')
reward: Reward received for the action
"""
try:
if self.cnn_adapter is None:
logger.warning("CNN adapter not initialized")
return
# Add training sample to CNN adapter
self.cnn_adapter.add_training_sample(base_data, actual_action, reward)
# Add to local training samples
self.training_samples.append((base_data.symbol, actual_action, reward))
# Limit training samples
if len(self.training_samples) > self.max_training_samples:
self.training_samples = self.training_samples[-self.max_training_samples:]
logger.debug(f"Added training sample for {base_data.symbol}, action: {actual_action}, reward: {reward:.4f}")
except Exception as e:
logger.error(f"Error adding training sample: {e}")
def get_performance_metrics(self) -> Dict[str, Any]:
"""
Get performance metrics
Returns:
Dict[str, Any]: Performance metrics
"""
metrics = self.performance_metrics.copy()
# Add additional metrics
metrics['training_samples'] = len(self.training_samples)
metrics['model_name'] = self.model_name
# Add last prediction metrics
if self.last_predictions:
for symbol, prediction in self.last_predictions.items():
metrics[f'{symbol}_last_action'] = prediction.predictions.get('action', 'UNKNOWN')
metrics[f'{symbol}_last_confidence'] = prediction.confidence
return metrics
def get_visualization_data(self, symbol: str) -> Dict[str, Any]:
"""
Get visualization data for the dashboard
Args:
symbol: Trading symbol
Returns:
Dict[str, Any]: Visualization data
"""
data = {
'model_name': self.model_name,
'symbol': symbol,
'timestamp': datetime.now().isoformat(),
'performance_metrics': self.get_performance_metrics()
}
# Add last prediction
if symbol in self.last_predictions:
prediction = self.last_predictions[symbol]
data['last_prediction'] = {
'action': prediction.predictions.get('action', 'UNKNOWN'),
'confidence': prediction.confidence,
'timestamp': prediction.timestamp.isoformat(),
'buy_probability': prediction.predictions.get('buy_probability', 0.0),
'sell_probability': prediction.predictions.get('sell_probability', 0.0),
'hold_probability': prediction.predictions.get('hold_probability', 0.0)
}
# Add training samples summary
symbol_samples = [s for s in self.training_samples if s[0] == symbol]
data['training_samples'] = {
'total': len(symbol_samples),
'buy': len([s for s in symbol_samples if s[1] == 'BUY']),
'sell': len([s for s in symbol_samples if s[1] == 'SELL']),
'hold': len([s for s in symbol_samples if s[1] == 'HOLD']),
'avg_reward': sum(s[2] for s in symbol_samples) / len(symbol_samples) if symbol_samples else 0.0
}
return data
# Global CNN dashboard integration instance
_cnn_dashboard_integration = None
def get_cnn_dashboard_integration(data_provider=None) -> CNNDashboardIntegration:
"""
Get the global CNN dashboard integration instance
Args:
data_provider: Data provider instance
Returns:
CNNDashboardIntegration: Global CNN dashboard integration instance
"""
global _cnn_dashboard_integration
if _cnn_dashboard_integration is None:
_cnn_dashboard_integration = CNNDashboardIntegration(data_provider=data_provider)
return _cnn_dashboard_integration

View File

@ -101,9 +101,20 @@ class COBIntegration:
# Initialize COB provider as fallback
try:
# Create default exchange configs
exchange_configs = {
'binance': {
'name': 'binance',
'enabled': True,
'websocket_url': 'wss://stream.binance.com:9443/ws/',
'rest_api_url': 'https://api.binance.com/api/v3/',
'rate_limits': {'requests_per_minute': 1200}
}
}
self.cob_provider = MultiExchangeCOBProvider(
symbols=self.symbols,
bucket_size_bps=1.0 # 1 basis point granularity
exchange_configs=exchange_configs
)
# Register callbacks

File diff suppressed because it is too large Load Diff

View File

@ -214,13 +214,8 @@ class TradingOrchestrator:
# Training tracking
self.last_trained_symbols: Dict[str, datetime] = {}
# INFERENCE DATA STORAGE - Per-model storage with memory optimization
self.inference_history: Dict[str, deque] = {} # {model_name: deque of last 5 inference records}
self.max_memory_inferences = 5 # Keep only last 5 inferences in memory per model
self.max_disk_files_per_model = 200 # Cap disk files per model
# Initialize inference history for each model (will be populated as models make predictions)
# We'll create entries dynamically as models are used
# SIMPLIFIED INFERENCE DATA STORAGE - Single last inference per model
self.last_inference: Dict[str, Dict] = {} # {model_name: last_inference_record}
# Initialize inference logger
self.inference_logger = get_inference_logger()
@ -240,10 +235,16 @@ class TradingOrchestrator:
logger.info(f"Primary symbol: {self.symbol}, Reference symbols: {self.ref_symbols}")
logger.info("Universal Data Adapter integrated for centralized data flow")
# Start centralized data collection for all models and dashboard
logger.info("Starting centralized data collection...")
self.data_provider.start_centralized_data_collection()
logger.info("Centralized data collection started - all models and dashboard will receive data")
# Start data collection if available
logger.info("Starting data collection...")
if hasattr(self.data_provider, 'start_centralized_data_collection'):
self.data_provider.start_centralized_data_collection()
logger.info("Centralized data collection started - all models and dashboard will receive data")
elif hasattr(self.data_provider, 'start_training_data_collection'):
self.data_provider.start_training_data_collection()
logger.info("Training data collection started")
else:
logger.info("Data provider does not require explicit data collection startup")
# Data provider is already initialized and optimized
@ -683,13 +684,10 @@ class TradingOrchestrator:
self.sensitivity_learning_queue = []
self.perfect_move_buffer = []
# Clear inference history (but keep recent for training)
for model_name in list(self.inference_history.keys()):
# Keep only the last inference for each model to maintain training capability
if len(self.inference_history[model_name]) > 1:
last_inference = self.inference_history[model_name][-1]
self.inference_history[model_name].clear()
self.inference_history[model_name].append(last_inference)
# Clear any outcome evaluation flags for last inferences
for model_name in self.last_inference:
if self.last_inference[model_name]:
self.last_inference[model_name]['outcome_evaluated'] = False
# Clear fusion training data
self.fusion_training_data = []
@ -1114,10 +1112,10 @@ class TradingOrchestrator:
if model.name not in self.model_performance:
self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
# Initialize inference history for this model
if model.name not in self.inference_history:
self.inference_history[model.name] = deque(maxlen=self.max_memory_inferences)
logger.debug(f"Initialized inference history for {model.name}")
# Initialize last inference storage for this model
if model.name not in self.last_inference:
self.last_inference[model.name] = None
logger.debug(f"Initialized last inference storage for {model.name}")
logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
self._normalize_weights()
@ -1320,12 +1318,7 @@ class TradingOrchestrator:
logger.error(f"Error getting prediction from {model_name}: {e}")
continue
# Debug: Log inference history status (only if low record count)
total_records = sum(len(history) for history in self.inference_history.values())
if total_records < 10: # Only log when we have few records
logger.debug(f"Total inference records across all models: {total_records}")
for model_name, history in self.inference_history.items():
logger.debug(f" {model_name}: {len(history)} records")
# Trigger training based on previous inference data
await self._trigger_model_training(symbol)
@ -1392,17 +1385,15 @@ class TradingOrchestrator:
return {}
async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime, symbol: str = None):
"""Store inference data per-model with async file operations and memory optimization"""
"""Store last inference in memory and all inferences to database for future training"""
try:
# Only log first few inference records to avoid spam
if len(self.inference_history.get(model_name, [])) < 3:
logger.debug(f"Storing inference data for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
logger.debug(f"Storing inference for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
# Extract symbol from prediction if not provided
if symbol is None:
symbol = getattr(prediction, 'symbol', 'ETH/USDT') # Default to ETH/USDT if not available
# Create comprehensive inference record
# Create inference record - store only what's needed for training
inference_record = {
'timestamp': timestamp.isoformat(),
'symbol': symbol,
@ -1414,227 +1405,153 @@ class TradingOrchestrator:
'probabilities': prediction.probabilities,
'timeframe': prediction.timeframe
},
'metadata': prediction.metadata or {}
'metadata': prediction.metadata or {},
'training_outcome': None, # Will be set when training occurs
'outcome_evaluated': False
}
# Store in memory (only last 5 per model)
if model_name not in self.inference_history:
self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
# Store only the last inference per model (for immediate training)
self.last_inference[model_name] = inference_record
self.inference_history[model_name].append(inference_record)
# Also save to database using database manager for future training and analysis
asyncio.create_task(self._save_to_database_manager_async(model_name, inference_record))
# Async file storage (don't wait for completion)
asyncio.create_task(self._save_inference_to_disk_async(model_name, inference_record))
logger.debug(f"Stored inference data for {model_name} (memory: {len(self.inference_history[model_name])}/5)")
logger.debug(f"Stored last inference for {model_name} and queued database save")
except Exception as e:
logger.error(f"Error storing inference data for {model_name}: {e}")
async def _save_inference_to_disk_async(self, model_name: str, inference_record: Dict):
"""Async save inference record to SQLite database and model-specific log"""
try:
# Use SQLite for comprehensive storage
await self._save_to_sqlite_db(model_name, inference_record)
# Also save key metrics to model-specific log for debugging
await self._save_to_model_log(model_name, inference_record)
except Exception as e:
logger.error(f"Error saving inference to disk for {model_name}: {e}")
async def _save_to_sqlite_db(self, model_name: str, inference_record: Dict):
"""Save inference record to SQLite database"""
import sqlite3
async def _save_to_database_manager_async(self, model_name: str, inference_record: Dict):
"""Save inference record using DatabaseManager for future training"""
import hashlib
import asyncio
def save_to_db():
try:
# Create database directory
db_dir = Path("training_data/inference_db")
db_dir.mkdir(parents=True, exist_ok=True)
# Connect to SQLite database
db_path = db_dir / "inference_history.db"
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# Create table if it doesn't exist
cursor.execute('''
CREATE TABLE IF NOT EXISTS inference_records (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_name TEXT NOT NULL,
symbol TEXT NOT NULL,
timestamp TEXT NOT NULL,
action TEXT NOT NULL,
confidence REAL NOT NULL,
probabilities TEXT,
timeframe TEXT,
metadata TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
''')
# Create index for faster queries
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_model_timestamp
ON inference_records(model_name, timestamp)
''')
# Extract data from inference record
prediction = inference_record.get('prediction', {})
probabilities_str = str(prediction.get('probabilities', {}))
metadata_str = str(inference_record.get('metadata', {}))
symbol = inference_record.get('symbol', 'ETH/USDT')
timestamp_str = inference_record.get('timestamp', '')
# Insert record
cursor.execute('''
INSERT INTO inference_records
(model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (
model_name,
inference_record.get('symbol', 'ETH/USDT'),
inference_record.get('timestamp', ''),
prediction.get('action', 'HOLD'),
prediction.get('confidence', 0.0),
probabilities_str,
prediction.get('timeframe', '1m'),
metadata_str
))
# Parse timestamp
if isinstance(timestamp_str, str):
timestamp = datetime.fromisoformat(timestamp_str)
else:
timestamp = timestamp_str
# Clean up old records (keep only last 1000 per model)
cursor.execute('''
DELETE FROM inference_records
WHERE model_name = ? AND id NOT IN (
SELECT id FROM inference_records
WHERE model_name = ?
ORDER BY timestamp DESC
LIMIT 1000
)
''', (model_name, model_name))
# Create hash of input features for deduplication
model_input = inference_record.get('model_input')
input_features_hash = "unknown"
input_features_array = None
conn.commit()
conn.close()
if model_input is not None:
# Convert to numpy array if possible
try:
if hasattr(model_input, 'numpy'): # PyTorch tensor
input_features_array = model_input.detach().cpu().numpy()
elif isinstance(model_input, np.ndarray):
input_features_array = model_input
elif isinstance(model_input, (list, tuple)):
input_features_array = np.array(model_input)
# Create hash of the input features
if input_features_array is not None:
input_features_hash = hashlib.md5(input_features_array.tobytes()).hexdigest()[:16]
except Exception as e:
logger.debug(f"Could not process input features for hashing: {e}")
# Create InferenceRecord using the database manager's structure
from utils.database_manager import InferenceRecord
db_record = InferenceRecord(
model_name=model_name,
timestamp=timestamp,
symbol=symbol,
action=prediction.get('action', 'HOLD'),
confidence=prediction.get('confidence', 0.0),
probabilities=prediction.get('probabilities', {}),
input_features_hash=input_features_hash,
processing_time_ms=0.0, # We don't track this in orchestrator
memory_usage_mb=0.0, # We don't track this in orchestrator
input_features=input_features_array,
checkpoint_id=None,
metadata=inference_record.get('metadata', {})
)
# Log using database manager
success = self.db_manager.log_inference(db_record)
if success:
logger.debug(f"Saved inference to database for {model_name}")
else:
logger.warning(f"Failed to save inference to database for {model_name}")
except Exception as e:
logger.error(f"Error saving to SQLite database: {e}")
logger.error(f"Error saving to database manager: {e}")
# Run database operation in thread pool to avoid blocking
await asyncio.get_event_loop().run_in_executor(None, save_to_db)
async def _save_to_model_log(self, model_name: str, inference_record: Dict):
"""Save key inference metrics to model-specific log file for debugging"""
import asyncio
def save_to_log():
try:
# Create logs directory
logs_dir = Path("logs/model_inference")
logs_dir.mkdir(parents=True, exist_ok=True)
# Create model-specific log file
log_file = logs_dir / f"{model_name}_inference.log"
# Extract key metrics
prediction = inference_record.get('prediction', {})
timestamp = inference_record.get('timestamp', '')
symbol = inference_record.get('symbol', 'N/A')
# Format log entry with key metrics
log_entry = (
f"{timestamp} | "
f"Symbol: {symbol} | "
f"Action: {prediction.get('action', 'N/A'):4} | "
f"Confidence: {prediction.get('confidence', 0.0):6.3f} | "
f"Timeframe: {prediction.get('timeframe', 'N/A'):3} | "
f"Probs: BUY={prediction.get('probabilities', {}).get('BUY', 0.0):5.3f} "
f"SELL={prediction.get('probabilities', {}).get('SELL', 0.0):5.3f} "
f"HOLD={prediction.get('probabilities', {}).get('HOLD', 0.0):5.3f}\n"
)
# Append to log file
with open(log_file, 'a', encoding='utf-8') as f:
f.write(log_entry)
# Keep log files manageable (rotate when > 10MB)
if log_file.stat().st_size > 10 * 1024 * 1024: # 10MB
self._rotate_log_file(log_file)
except Exception as e:
logger.error(f"Error saving to model log: {e}")
# Run log operation in thread pool to avoid blocking
await asyncio.get_event_loop().run_in_executor(None, save_to_log)
def _rotate_log_file(self, log_file: Path):
"""Rotate log file when it gets too large"""
try:
# Keep last 1000 lines
with open(log_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
# Write back only the last 1000 lines
with open(log_file, 'w', encoding='utf-8') as f:
f.writelines(lines[-1000:])
logger.debug(f"Rotated log file {log_file.name} (kept last 1000 lines)")
except Exception as e:
logger.error(f"Error rotating log file {log_file}: {e}")
def get_inference_records_from_db(self, model_name: str = None, limit: int = 100) -> List[Dict]:
"""Get inference records from SQLite database"""
import sqlite3
try:
# Connect to database
db_path = Path("training_data/inference_db/inference_history.db")
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# Query records
if model_name:
cursor.execute('''
SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
FROM inference_records
WHERE model_name = ?
ORDER BY timestamp DESC
LIMIT ?
''', (model_name, limit))
else:
cursor.execute('''
SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
FROM inference_records
ORDER BY timestamp DESC
LIMIT ?
''', (limit,))
records = []
for row in cursor.fetchall():
record = {
'model_name': row[0],
'symbol': row[1],
'timestamp': row[2],
'prediction': {
'action': row[3],
'confidence': row[4],
'probabilities': eval(row[5]) if row[5] else {},
'timeframe': row[6]
},
'metadata': eval(row[7]) if row[7] else {}
def get_last_inference_status(self) -> Dict[str, Any]:
"""Get status of last inferences for all models"""
status = {}
for model_name, inference in self.last_inference.items():
if inference:
status[model_name] = {
'timestamp': inference.get('timestamp'),
'symbol': inference.get('symbol'),
'action': inference.get('prediction', {}).get('action'),
'confidence': inference.get('prediction', {}).get('confidence'),
'outcome_evaluated': inference.get('outcome_evaluated', False),
'training_outcome': inference.get('training_outcome')
}
records.append(record)
else:
status[model_name] = None
return status
def get_training_data_from_db(self, model_name: str, symbol: str = None, hours_back: int = 24, limit: int = 1000) -> List[Dict]:
"""Get inference records for training from database manager"""
try:
# Use database manager's method specifically for training data
db_records = self.db_manager.get_inference_records_for_training(
model_name=model_name,
symbol=symbol,
hours_back=hours_back,
limit=limit
)
conn.close()
# Convert to our format
records = []
for db_record in db_records:
try:
record = {
'model_name': db_record.model_name,
'symbol': db_record.symbol,
'timestamp': db_record.timestamp.isoformat(),
'prediction': {
'action': db_record.action,
'confidence': db_record.confidence,
'probabilities': db_record.probabilities,
'timeframe': '1m'
},
'metadata': db_record.metadata or {},
'model_input': db_record.input_features, # Full input features for training
'input_features_hash': db_record.input_features_hash
}
records.append(record)
except Exception as e:
logger.warning(f"Skipping malformed training record: {e}")
continue
logger.info(f"Retrieved {len(records)} training records for {model_name}")
return records
except Exception as e:
logger.error(f"Error querying SQLite database: {e}")
logger.error(f"Error getting training data from database: {e}")
return []
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
@ -1763,197 +1680,58 @@ class TradingOrchestrator:
'outcome_evaluated': False
}
# Store in memory (inference history) - keyed by model_name
if model_name not in self.inference_history:
self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
# Store only the last inference per model (for immediate training)
self.last_inference[model_name] = inference_record
self.inference_history[model_name].append(inference_record)
logger.debug(f"Stored inference data for {model_name} on {symbol}")
# Also save to database using database manager for future training (run in background)
asyncio.create_task(self._save_to_database_manager_async(model_name, inference_record))
# Persistent storage to disk (for long-term training data)
self._save_inference_to_disk(inference_record)
logger.debug(f"Stored last inference for {model_name} on {symbol} and queued database save")
except Exception as e:
logger.error(f"Error storing inference data: {e}")
def _save_inference_to_disk(self, inference_record: Dict):
"""Save inference record to persistent storage"""
try:
# Create inference data directory
inference_dir = Path("training_data/inference_history")
inference_dir.mkdir(parents=True, exist_ok=True)
# Create filename with timestamp and model name
timestamp_str = inference_record['timestamp'].strftime('%Y%m%d_%H%M%S')
filename = f"{inference_record['symbol']}_{inference_record['model_name']}_{timestamp_str}.json"
filepath = inference_dir / filename
# Convert numpy arrays to lists for JSON serialization
serializable_record = self._make_json_serializable(inference_record)
# Save to JSON file
with open(filepath, 'w') as f:
json.dump(serializable_record, f, indent=2)
logger.debug(f"Saved inference record to disk: {filepath}")
except Exception as e:
logger.error(f"Error saving inference to disk: {e}")
def _make_json_serializable(self, obj):
"""Convert object to JSON-serializable format"""
if isinstance(obj, dict):
return {k: self._make_json_serializable(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._make_json_serializable(item) for item in obj]
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, (np.integer, np.floating)):
return obj.item()
elif isinstance(obj, datetime):
return obj.isoformat()
else:
return obj
def load_inference_history_from_disk(self, symbol: str, days_back: int = 7) -> List[Dict]:
"""Load inference history from SQLite database for training replay"""
try:
import sqlite3
# Connect to database
db_path = Path("training_data/inference_db/inference_history.db")
if not db_path.exists():
return []
conn = sqlite3.connect(str(db_path))
cursor = conn.cursor()
# Get records for the symbol from the last N days
cutoff_date = (datetime.now() - timedelta(days=days_back)).isoformat()
cursor.execute('''
SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
FROM inference_records
WHERE symbol = ? AND timestamp >= ?
ORDER BY timestamp ASC
''', (symbol, cutoff_date))
inference_records = []
for row in cursor.fetchall():
record = {
'model_name': row[0],
'symbol': row[1],
'timestamp': row[2],
'prediction': {
'action': row[3],
'confidence': row[4],
'probabilities': eval(row[5]) if row[5] else {},
'timeframe': row[6]
},
'metadata': eval(row[7]) if row[7] else {}
}
inference_records.append(record)
conn.close()
logger.info(f"Loaded {len(inference_records)} inference records for {symbol} from SQLite database")
return inference_records
except Exception as e:
logger.error(f"Error loading inference history from database: {e}")
return []
async def load_model_inference_history(self, model_name: str, limit: int = 50) -> List[Dict]:
"""Load inference history for a specific model from SQLite database"""
try:
# Use the SQLite database method
records = self.get_inference_records_from_db(model_name, limit)
logger.info(f"Loaded {len(records)} inference records for {model_name} from database")
return records
except Exception as e:
logger.error(f"Error loading model inference history for {model_name}: {e}")
return []
def get_model_training_data(self, model_name: str, symbol: str = None) -> List[Dict]:
"""Get training data for a specific model"""
try:
training_data = []
# Get from memory first
if symbol:
symbols_to_check = [symbol]
else:
symbols_to_check = self.symbols
# Use database manager to get training data
training_data = self.get_training_data_from_db(model_name, symbol)
for sym in symbols_to_check:
if sym in self.inference_history:
for record in self.inference_history[sym]:
if record['model_name'] == model_name:
training_data.append(record)
# Also load from disk for more comprehensive training data
for sym in symbols_to_check:
disk_records = self.load_inference_history_from_disk(sym)
for record in disk_records:
if record['model_name'] == model_name:
training_data.append(record)
# Remove duplicates and sort by timestamp
seen_timestamps = set()
unique_data = []
for record in training_data:
timestamp_key = f"{record['timestamp']}_{record['symbol']}"
if timestamp_key not in seen_timestamps:
seen_timestamps.add(timestamp_key)
unique_data.append(record)
unique_data.sort(key=lambda x: x['timestamp'])
logger.info(f"Retrieved {len(unique_data)} training records for {model_name}")
return unique_data
logger.info(f"Retrieved {len(training_data)} training records for {model_name}")
return training_data
except Exception as e:
logger.error(f"Error getting model training data: {e}")
return []
async def _trigger_model_training(self, symbol: str):
"""Trigger training for models based on previous inference data"""
"""Trigger training for models based on their last inference"""
try:
if not self.training_enabled:
logger.debug("Training disabled, skipping model training")
return
# Check if we have any inference history for any model
if not self.inference_history:
logger.debug("No inference history available for training")
# Check if we have any last inferences for any model
if not self.last_inference:
logger.debug("No inference data available for training")
return
# Get recent inference records from all models (not symbol-based)
all_recent_records = []
for model_name, model_records in self.inference_history.items():
all_recent_records.extend(list(model_records))
# Only log if we have few records (for debugging)
if len(all_recent_records) < 5:
logger.debug(f"Total inference records for training: {len(all_recent_records)}")
for model_name, model_records in self.inference_history.items():
logger.debug(f" Model {model_name} has {len(model_records)} inference records")
if len(all_recent_records) < 2:
logger.debug("Not enough inference records for training")
return # Need at least 2 records to compare
# Get current price for outcome evaluation
current_price = self.data_provider.get_current_price(symbol)
if current_price is None:
return
# Train on the most recent inference record (last prediction made)
if all_recent_records:
# Get the most recent record for training
most_recent_record = max(all_recent_records, key=lambda x: datetime.fromisoformat(x['timestamp']) if isinstance(x['timestamp'], str) else x['timestamp'])
await self._evaluate_and_train_on_record(most_recent_record, current_price)
# Train each model based on its last inference
for model_name, last_inference_record in self.last_inference.items():
if last_inference_record and not last_inference_record.get('outcome_evaluated', False):
await self._evaluate_and_train_on_record(last_inference_record, current_price)
except Exception as e:
logger.error(f"Error triggering model training for {symbol}: {e}")
@ -2011,6 +1789,16 @@ class TradingOrchestrator:
# Train the specific model based on sophisticated outcome
await self._train_model_on_outcome(record, was_correct, price_change_pct, reward)
# Mark this inference as evaluated to prevent re-training
if model_name in self.last_inference and self.last_inference[model_name] == record:
self.last_inference[model_name]['outcome_evaluated'] = True
self.last_inference[model_name]['training_outcome'] = {
'was_correct': was_correct,
'reward': reward,
'price_change_pct': price_change_pct,
'evaluated_at': datetime.now().isoformat()
}
logger.debug(f"Evaluated {model_name} prediction: {'' if was_correct else ''} "
f"({prediction['action']}, {price_change_pct:.2f}% change, "
f"confidence: {prediction_confidence:.3f}, reward: {reward:.3f})")
@ -2215,7 +2003,7 @@ class TradingOrchestrator:
)
predictions.append(prediction)
# Store prediction in SQLite database for training
# Store prediction in database for training
logger.debug(f"Added CNN prediction to database: {prediction}")
# Note: Inference data will be stored in main prediction loop to avoid duplication