cleanup
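
Simplify inference data storage in TradingOrchestrator: drop the per-model
deque history, the hand-rolled SQLite store, the JSON disk files, and the
per-model inference log files; keep only the last inference per model in
memory, persist every inference through the shared DatabaseManager, and
train each model on its last unevaluated inference.

A minimal sketch of the new flow (only the names below appear in this diff;
the surrounding orchestrator state is assumed):

    # store: keep the last record in memory, queue a background DB write
    self.last_inference[model_name] = inference_record
    asyncio.create_task(self._save_to_database_manager_async(model_name, inference_record))

    # train: evaluate each model's last inference exactly once
    for model_name, rec in self.last_inference.items():
        if rec and not rec.get('outcome_evaluated', False):
            await self._evaluate_and_train_on_record(rec, current_price)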
@@ -214,13 +214,8 @@ class TradingOrchestrator:
         # Training tracking
         self.last_trained_symbols: Dict[str, datetime] = {}
 
-        # INFERENCE DATA STORAGE - Per-model storage with memory optimization
-        self.inference_history: Dict[str, deque] = {}  # {model_name: deque of last 5 inference records}
-        self.max_memory_inferences = 5  # Keep only last 5 inferences in memory per model
-        self.max_disk_files_per_model = 200  # Cap disk files per model
-
-        # Initialize inference history for each model (will be populated as models make predictions)
-        # We'll create entries dynamically as models are used
+        # SIMPLIFIED INFERENCE DATA STORAGE - Single last inference per model
+        self.last_inference: Dict[str, Dict] = {}  # {model_name: last_inference_record}
 
         # Initialize inference logger
         self.inference_logger = get_inference_logger()
@@ -240,10 +235,16 @@ class TradingOrchestrator:
         logger.info(f"Primary symbol: {self.symbol}, Reference symbols: {self.ref_symbols}")
         logger.info("Universal Data Adapter integrated for centralized data flow")
 
-        # Start centralized data collection for all models and dashboard
-        logger.info("Starting centralized data collection...")
-        self.data_provider.start_centralized_data_collection()
-        logger.info("Centralized data collection started - all models and dashboard will receive data")
+        # Start data collection if available
+        logger.info("Starting data collection...")
+        if hasattr(self.data_provider, 'start_centralized_data_collection'):
+            self.data_provider.start_centralized_data_collection()
+            logger.info("Centralized data collection started - all models and dashboard will receive data")
+        elif hasattr(self.data_provider, 'start_training_data_collection'):
+            self.data_provider.start_training_data_collection()
+            logger.info("Training data collection started")
+        else:
+            logger.info("Data provider does not require explicit data collection startup")
 
         # Data provider is already initialized and optimized
 
@@ -683,13 +684,10 @@ class TradingOrchestrator:
             self.sensitivity_learning_queue = []
             self.perfect_move_buffer = []
 
-            # Clear inference history (but keep recent for training)
-            for model_name in list(self.inference_history.keys()):
-                # Keep only the last inference for each model to maintain training capability
-                if len(self.inference_history[model_name]) > 1:
-                    last_inference = self.inference_history[model_name][-1]
-                    self.inference_history[model_name].clear()
-                    self.inference_history[model_name].append(last_inference)
+            # Clear any outcome evaluation flags for last inferences
+            for model_name in self.last_inference:
+                if self.last_inference[model_name]:
+                    self.last_inference[model_name]['outcome_evaluated'] = False
 
             # Clear fusion training data
             self.fusion_training_data = []
@@ -1114,10 +1112,10 @@ class TradingOrchestrator:
             if model.name not in self.model_performance:
                 self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
 
-            # Initialize inference history for this model
-            if model.name not in self.inference_history:
-                self.inference_history[model.name] = deque(maxlen=self.max_memory_inferences)
-                logger.debug(f"Initialized inference history for {model.name}")
+            # Initialize last inference storage for this model
+            if model.name not in self.last_inference:
+                self.last_inference[model.name] = None
+                logger.debug(f"Initialized last inference storage for {model.name}")
 
             logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
             self._normalize_weights()
@@ -1320,12 +1318,7 @@ class TradingOrchestrator:
                 logger.error(f"Error getting prediction from {model_name}: {e}")
                 continue
 
-        # Debug: Log inference history status (only if low record count)
-        total_records = sum(len(history) for history in self.inference_history.values())
-        if total_records < 10:  # Only log when we have few records
-            logger.debug(f"Total inference records across all models: {total_records}")
-            for model_name, history in self.inference_history.items():
-                logger.debug(f"  {model_name}: {len(history)} records")
+
 
         # Trigger training based on previous inference data
         await self._trigger_model_training(symbol)
@@ -1392,17 +1385,15 @@ class TradingOrchestrator:
             return {}
 
     async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime, symbol: str = None):
-        """Store inference data per-model with async file operations and memory optimization"""
+        """Store last inference in memory and all inferences to database for future training"""
         try:
-            # Only log first few inference records to avoid spam
-            if len(self.inference_history.get(model_name, [])) < 3:
-                logger.debug(f"Storing inference data for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
+            logger.debug(f"Storing inference for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
 
             # Extract symbol from prediction if not provided
             if symbol is None:
                 symbol = getattr(prediction, 'symbol', 'ETH/USDT')  # Default to ETH/USDT if not available
 
-            # Create comprehensive inference record
+            # Create inference record - store only what's needed for training
            inference_record = {
                 'timestamp': timestamp.isoformat(),
                 'symbol': symbol,
@@ -1414,227 +1405,153 @@ class TradingOrchestrator:
                     'probabilities': prediction.probabilities,
                     'timeframe': prediction.timeframe
                 },
-                'metadata': prediction.metadata or {}
+                'metadata': prediction.metadata or {},
+                'training_outcome': None,  # Will be set when training occurs
+                'outcome_evaluated': False
             }
 
-            # Store in memory (only last 5 per model)
-            if model_name not in self.inference_history:
-                self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
+            # Store only the last inference per model (for immediate training)
+            self.last_inference[model_name] = inference_record
 
-            self.inference_history[model_name].append(inference_record)
+            # Also save to database using database manager for future training and analysis
+            asyncio.create_task(self._save_to_database_manager_async(model_name, inference_record))
 
-            # Async file storage (don't wait for completion)
-            asyncio.create_task(self._save_inference_to_disk_async(model_name, inference_record))
-
-            logger.debug(f"Stored inference data for {model_name} (memory: {len(self.inference_history[model_name])}/5)")
+            logger.debug(f"Stored last inference for {model_name} and queued database save")
 
         except Exception as e:
             logger.error(f"Error storing inference data for {model_name}: {e}")
 
-    async def _save_inference_to_disk_async(self, model_name: str, inference_record: Dict):
-        """Async save inference record to SQLite database and model-specific log"""
-        try:
-            # Use SQLite for comprehensive storage
-            await self._save_to_sqlite_db(model_name, inference_record)
-
-            # Also save key metrics to model-specific log for debugging
-            await self._save_to_model_log(model_name, inference_record)
-
-        except Exception as e:
-            logger.error(f"Error saving inference to disk for {model_name}: {e}")
-
-    async def _save_to_sqlite_db(self, model_name: str, inference_record: Dict):
-        """Save inference record to SQLite database"""
-        import sqlite3
+    async def _save_to_database_manager_async(self, model_name: str, inference_record: Dict):
+        """Save inference record using DatabaseManager for future training"""
+        import hashlib
         import asyncio
 
         def save_to_db():
             try:
-                # Create database directory
-                db_dir = Path("training_data/inference_db")
-                db_dir.mkdir(parents=True, exist_ok=True)
-
-                # Connect to SQLite database
-                db_path = db_dir / "inference_history.db"
-                conn = sqlite3.connect(str(db_path))
-                cursor = conn.cursor()
-
-                # Create table if it doesn't exist
-                cursor.execute('''
-                    CREATE TABLE IF NOT EXISTS inference_records (
-                        id INTEGER PRIMARY KEY AUTOINCREMENT,
-                        model_name TEXT NOT NULL,
-                        symbol TEXT NOT NULL,
-                        timestamp TEXT NOT NULL,
-                        action TEXT NOT NULL,
-                        confidence REAL NOT NULL,
-                        probabilities TEXT,
-                        timeframe TEXT,
-                        metadata TEXT,
-                        created_at DATETIME DEFAULT CURRENT_TIMESTAMP
-                    )
-                ''')
-
-                # Create index for faster queries
-                cursor.execute('''
-                    CREATE INDEX IF NOT EXISTS idx_model_timestamp
-                    ON inference_records(model_name, timestamp)
-                ''')
-
                 # Extract data from inference record
                 prediction = inference_record.get('prediction', {})
-                probabilities_str = str(prediction.get('probabilities', {}))
-                metadata_str = str(inference_record.get('metadata', {}))
+                symbol = inference_record.get('symbol', 'ETH/USDT')
+                timestamp_str = inference_record.get('timestamp', '')
 
-                # Insert record
-                cursor.execute('''
-                    INSERT INTO inference_records
-                    (model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata)
-                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
-                ''', (
-                    model_name,
-                    inference_record.get('symbol', 'ETH/USDT'),
-                    inference_record.get('timestamp', ''),
-                    prediction.get('action', 'HOLD'),
-                    prediction.get('confidence', 0.0),
-                    probabilities_str,
-                    prediction.get('timeframe', '1m'),
-                    metadata_str
-                ))
+                # Parse timestamp
+                if isinstance(timestamp_str, str):
+                    timestamp = datetime.fromisoformat(timestamp_str)
+                else:
+                    timestamp = timestamp_str
 
-                # Clean up old records (keep only last 1000 per model)
-                cursor.execute('''
-                    DELETE FROM inference_records
-                    WHERE model_name = ? AND id NOT IN (
-                        SELECT id FROM inference_records
-                        WHERE model_name = ?
-                        ORDER BY timestamp DESC
-                        LIMIT 1000
-                    )
-                ''', (model_name, model_name))
+                # Create hash of input features for deduplication
+                model_input = inference_record.get('model_input')
+                input_features_hash = "unknown"
+                input_features_array = None
 
-                conn.commit()
-                conn.close()
+                if model_input is not None:
+                    # Convert to numpy array if possible
+                    try:
+                        if hasattr(model_input, 'numpy'):  # PyTorch tensor
+                            input_features_array = model_input.detach().cpu().numpy()
+                        elif isinstance(model_input, np.ndarray):
+                            input_features_array = model_input
+                        elif isinstance(model_input, (list, tuple)):
+                            input_features_array = np.array(model_input)
+
+                        # Create hash of the input features
+                        if input_features_array is not None:
+                            input_features_hash = hashlib.md5(input_features_array.tobytes()).hexdigest()[:16]
+                    except Exception as e:
+                        logger.debug(f"Could not process input features for hashing: {e}")
+
+                # Create InferenceRecord using the database manager's structure
+                from utils.database_manager import InferenceRecord
+
+                db_record = InferenceRecord(
+                    model_name=model_name,
+                    timestamp=timestamp,
+                    symbol=symbol,
+                    action=prediction.get('action', 'HOLD'),
+                    confidence=prediction.get('confidence', 0.0),
+                    probabilities=prediction.get('probabilities', {}),
+                    input_features_hash=input_features_hash,
+                    processing_time_ms=0.0,  # We don't track this in orchestrator
+                    memory_usage_mb=0.0,  # We don't track this in orchestrator
+                    input_features=input_features_array,
+                    checkpoint_id=None,
+                    metadata=inference_record.get('metadata', {})
+                )
+
+                # Log using database manager
+                success = self.db_manager.log_inference(db_record)
+
+                if success:
+                    logger.debug(f"Saved inference to database for {model_name}")
+                else:
+                    logger.warning(f"Failed to save inference to database for {model_name}")
 
             except Exception as e:
-                logger.error(f"Error saving to SQLite database: {e}")
+                logger.error(f"Error saving to database manager: {e}")
 
         # Run database operation in thread pool to avoid blocking
         await asyncio.get_event_loop().run_in_executor(None, save_to_db)
 
-    async def _save_to_model_log(self, model_name: str, inference_record: Dict):
-        """Save key inference metrics to model-specific log file for debugging"""
-        import asyncio
-
-        def save_to_log():
-            try:
-                # Create logs directory
-                logs_dir = Path("logs/model_inference")
-                logs_dir.mkdir(parents=True, exist_ok=True)
-
-                # Create model-specific log file
-                log_file = logs_dir / f"{model_name}_inference.log"
-
-                # Extract key metrics
-                prediction = inference_record.get('prediction', {})
-                timestamp = inference_record.get('timestamp', '')
-                symbol = inference_record.get('symbol', 'N/A')
-
-                # Format log entry with key metrics
-                log_entry = (
-                    f"{timestamp} | "
-                    f"Symbol: {symbol} | "
-                    f"Action: {prediction.get('action', 'N/A'):4} | "
-                    f"Confidence: {prediction.get('confidence', 0.0):6.3f} | "
-                    f"Timeframe: {prediction.get('timeframe', 'N/A'):3} | "
-                    f"Probs: BUY={prediction.get('probabilities', {}).get('BUY', 0.0):5.3f} "
-                    f"SELL={prediction.get('probabilities', {}).get('SELL', 0.0):5.3f} "
-                    f"HOLD={prediction.get('probabilities', {}).get('HOLD', 0.0):5.3f}\n"
-                )
-
-                # Append to log file
-                with open(log_file, 'a', encoding='utf-8') as f:
-                    f.write(log_entry)
-
-                # Keep log files manageable (rotate when > 10MB)
-                if log_file.stat().st_size > 10 * 1024 * 1024:  # 10MB
-                    self._rotate_log_file(log_file)
-
-            except Exception as e:
-                logger.error(f"Error saving to model log: {e}")
-
-        # Run log operation in thread pool to avoid blocking
-        await asyncio.get_event_loop().run_in_executor(None, save_to_log)
-
-    def _rotate_log_file(self, log_file: Path):
-        """Rotate log file when it gets too large"""
-        try:
-            # Keep last 1000 lines
-            with open(log_file, 'r', encoding='utf-8') as f:
-                lines = f.readlines()
-
-            # Write back only the last 1000 lines
-            with open(log_file, 'w', encoding='utf-8') as f:
-                f.writelines(lines[-1000:])
-
-            logger.debug(f"Rotated log file {log_file.name} (kept last 1000 lines)")
-
-        except Exception as e:
-            logger.error(f"Error rotating log file {log_file}: {e}")
-
-    def get_inference_records_from_db(self, model_name: str = None, limit: int = 100) -> List[Dict]:
-        """Get inference records from SQLite database"""
-        import sqlite3
-
-        try:
-            # Connect to database
-            db_path = Path("training_data/inference_db/inference_history.db")
-            if not db_path.exists():
-                return []
-
-            conn = sqlite3.connect(str(db_path))
-            cursor = conn.cursor()
-
-            # Query records
-            if model_name:
-                cursor.execute('''
-                    SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
-                    FROM inference_records
-                    WHERE model_name = ?
-                    ORDER BY timestamp DESC
-                    LIMIT ?
-                ''', (model_name, limit))
-            else:
-                cursor.execute('''
-                    SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
-                    FROM inference_records
-                    ORDER BY timestamp DESC
-                    LIMIT ?
-                ''', (limit,))
-
-            records = []
-            for row in cursor.fetchall():
-                record = {
-                    'model_name': row[0],
-                    'symbol': row[1],
-                    'timestamp': row[2],
-                    'prediction': {
-                        'action': row[3],
-                        'confidence': row[4],
-                        'probabilities': eval(row[5]) if row[5] else {},
-                        'timeframe': row[6]
-                    },
-                    'metadata': eval(row[7]) if row[7] else {}
-                }
-                records.append(record)
-
-            conn.close()
-            return records
-
-        except Exception as e:
-            logger.error(f"Error querying SQLite database: {e}")
-            return []
+    def get_last_inference_status(self) -> Dict[str, Any]:
+        """Get status of last inferences for all models"""
+        status = {}
+        for model_name, inference in self.last_inference.items():
+            if inference:
+                status[model_name] = {
+                    'timestamp': inference.get('timestamp'),
+                    'symbol': inference.get('symbol'),
+                    'action': inference.get('prediction', {}).get('action'),
+                    'confidence': inference.get('prediction', {}).get('confidence'),
+                    'outcome_evaluated': inference.get('outcome_evaluated', False),
+                    'training_outcome': inference.get('training_outcome')
+                }
+            else:
+                status[model_name] = None
+        return status
+
+    def get_training_data_from_db(self, model_name: str, symbol: str = None, hours_back: int = 24, limit: int = 1000) -> List[Dict]:
+        """Get inference records for training from database manager"""
+        try:
+            # Use database manager's method specifically for training data
+            db_records = self.db_manager.get_inference_records_for_training(
+                model_name=model_name,
+                symbol=symbol,
+                hours_back=hours_back,
+                limit=limit
+            )
+
+            # Convert to our format
+            records = []
+            for db_record in db_records:
+                try:
+                    record = {
+                        'model_name': db_record.model_name,
+                        'symbol': db_record.symbol,
+                        'timestamp': db_record.timestamp.isoformat(),
+                        'prediction': {
+                            'action': db_record.action,
+                            'confidence': db_record.confidence,
+                            'probabilities': db_record.probabilities,
+                            'timeframe': '1m'
+                        },
+                        'metadata': db_record.metadata or {},
+                        'model_input': db_record.input_features,  # Full input features for training
+                        'input_features_hash': db_record.input_features_hash
+                    }
+                    records.append(record)
+                except Exception as e:
+                    logger.warning(f"Skipping malformed training record: {e}")
+                    continue
+
+            logger.info(f"Retrieved {len(records)} training records for {model_name}")
+            return records
+
+        except Exception as e:
+            logger.error(f"Error getting training data from database: {e}")
+            return []
 
     def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
@@ -1763,197 +1680,58 @@ class TradingOrchestrator:
                 'outcome_evaluated': False
             }
 
-            # Store in memory (inference history) - keyed by model_name
-            if model_name not in self.inference_history:
-                self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
+            # Store only the last inference per model (for immediate training)
+            self.last_inference[model_name] = inference_record
 
-            self.inference_history[model_name].append(inference_record)
-            logger.debug(f"Stored inference data for {model_name} on {symbol}")
+            # Also save to database using database manager for future training (run in background)
+            asyncio.create_task(self._save_to_database_manager_async(model_name, inference_record))
 
-            # Persistent storage to disk (for long-term training data)
-            self._save_inference_to_disk(inference_record)
+            logger.debug(f"Stored last inference for {model_name} on {symbol} and queued database save")
 
         except Exception as e:
             logger.error(f"Error storing inference data: {e}")
 
-    def _save_inference_to_disk(self, inference_record: Dict):
-        """Save inference record to persistent storage"""
-        try:
-            # Create inference data directory
-            inference_dir = Path("training_data/inference_history")
-            inference_dir.mkdir(parents=True, exist_ok=True)
-
-            # Create filename with timestamp and model name
-            timestamp_str = inference_record['timestamp'].strftime('%Y%m%d_%H%M%S')
-            filename = f"{inference_record['symbol']}_{inference_record['model_name']}_{timestamp_str}.json"
-            filepath = inference_dir / filename
-
-            # Convert numpy arrays to lists for JSON serialization
-            serializable_record = self._make_json_serializable(inference_record)
-
-            # Save to JSON file
-            with open(filepath, 'w') as f:
-                json.dump(serializable_record, f, indent=2)
-
-            logger.debug(f"Saved inference record to disk: {filepath}")
-
-        except Exception as e:
-            logger.error(f"Error saving inference to disk: {e}")
-
     def _make_json_serializable(self, obj):
         """Convert object to JSON-serializable format"""
         if isinstance(obj, dict):
             return {k: self._make_json_serializable(v) for k, v in obj.items()}
         elif isinstance(obj, list):
             return [self._make_json_serializable(item) for item in obj]
         elif isinstance(obj, np.ndarray):
             return obj.tolist()
         elif isinstance(obj, (np.integer, np.floating)):
             return obj.item()
         elif isinstance(obj, datetime):
             return obj.isoformat()
         else:
             return obj
 
-    def load_inference_history_from_disk(self, symbol: str, days_back: int = 7) -> List[Dict]:
-        """Load inference history from SQLite database for training replay"""
-        try:
-            import sqlite3
-
-            # Connect to database
-            db_path = Path("training_data/inference_db/inference_history.db")
-            if not db_path.exists():
-                return []
-
-            conn = sqlite3.connect(str(db_path))
-            cursor = conn.cursor()
-
-            # Get records for the symbol from the last N days
-            cutoff_date = (datetime.now() - timedelta(days=days_back)).isoformat()
-
-            cursor.execute('''
-                SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
-                FROM inference_records
-                WHERE symbol = ? AND timestamp >= ?
-                ORDER BY timestamp ASC
-            ''', (symbol, cutoff_date))
-
-            inference_records = []
-            for row in cursor.fetchall():
-                record = {
-                    'model_name': row[0],
-                    'symbol': row[1],
-                    'timestamp': row[2],
-                    'prediction': {
-                        'action': row[3],
-                        'confidence': row[4],
-                        'probabilities': eval(row[5]) if row[5] else {},
-                        'timeframe': row[6]
-                    },
-                    'metadata': eval(row[7]) if row[7] else {}
-                }
-                inference_records.append(record)
-
-            conn.close()
-            logger.info(f"Loaded {len(inference_records)} inference records for {symbol} from SQLite database")
-
-            return inference_records
-
-        except Exception as e:
-            logger.error(f"Error loading inference history from database: {e}")
-            return []
-
-    async def load_model_inference_history(self, model_name: str, limit: int = 50) -> List[Dict]:
-        """Load inference history for a specific model from SQLite database"""
-        try:
-            # Use the SQLite database method
-            records = self.get_inference_records_from_db(model_name, limit)
-            logger.info(f"Loaded {len(records)} inference records for {model_name} from database")
-            return records
-
-        except Exception as e:
-            logger.error(f"Error loading model inference history for {model_name}: {e}")
-            return []
-
     def get_model_training_data(self, model_name: str, symbol: str = None) -> List[Dict]:
         """Get training data for a specific model"""
         try:
-            training_data = []
-
-            # Get from memory first
-            if symbol:
-                symbols_to_check = [symbol]
-            else:
-                symbols_to_check = self.symbols
-
-            for sym in symbols_to_check:
-                if sym in self.inference_history:
-                    for record in self.inference_history[sym]:
-                        if record['model_name'] == model_name:
-                            training_data.append(record)
-
-            # Also load from disk for more comprehensive training data
-            for sym in symbols_to_check:
-                disk_records = self.load_inference_history_from_disk(sym)
-                for record in disk_records:
-                    if record['model_name'] == model_name:
-                        training_data.append(record)
-
-            # Remove duplicates and sort by timestamp
-            seen_timestamps = set()
-            unique_data = []
-            for record in training_data:
-                timestamp_key = f"{record['timestamp']}_{record['symbol']}"
-                if timestamp_key not in seen_timestamps:
-                    seen_timestamps.add(timestamp_key)
-                    unique_data.append(record)
-
-            unique_data.sort(key=lambda x: x['timestamp'])
-            logger.info(f"Retrieved {len(unique_data)} training records for {model_name}")
-
-            return unique_data
+            # Use database manager to get training data
+            training_data = self.get_training_data_from_db(model_name, symbol)
+
+            logger.info(f"Retrieved {len(training_data)} training records for {model_name}")
+            return training_data
 
         except Exception as e:
             logger.error(f"Error getting model training data: {e}")
             return []
 
     async def _trigger_model_training(self, symbol: str):
-        """Trigger training for models based on previous inference data"""
+        """Trigger training for models based on their last inference"""
         try:
             if not self.training_enabled:
                 logger.debug("Training disabled, skipping model training")
                 return
 
-            # Check if we have any inference history for any model
-            if not self.inference_history:
-                logger.debug("No inference history available for training")
+            # Check if we have any last inferences for any model
+            if not self.last_inference:
+                logger.debug("No inference data available for training")
                 return
 
-            # Get recent inference records from all models (not symbol-based)
-            all_recent_records = []
-            for model_name, model_records in self.inference_history.items():
-                all_recent_records.extend(list(model_records))
-
-            # Only log if we have few records (for debugging)
-            if len(all_recent_records) < 5:
-                logger.debug(f"Total inference records for training: {len(all_recent_records)}")
-                for model_name, model_records in self.inference_history.items():
-                    logger.debug(f"  Model {model_name} has {len(model_records)} inference records")
-
-            if len(all_recent_records) < 2:
-                logger.debug("Not enough inference records for training")
-                return  # Need at least 2 records to compare
-
             # Get current price for outcome evaluation
             current_price = self.data_provider.get_current_price(symbol)
             if current_price is None:
                 return
 
-            # Train on the most recent inference record (last prediction made)
-            if all_recent_records:
-                # Get the most recent record for training
-                most_recent_record = max(all_recent_records, key=lambda x: datetime.fromisoformat(x['timestamp']) if isinstance(x['timestamp'], str) else x['timestamp'])
-                await self._evaluate_and_train_on_record(most_recent_record, current_price)
+            # Train each model based on its last inference
+            for model_name, last_inference_record in self.last_inference.items():
+                if last_inference_record and not last_inference_record.get('outcome_evaluated', False):
+                    await self._evaluate_and_train_on_record(last_inference_record, current_price)
 
         except Exception as e:
             logger.error(f"Error triggering model training for {symbol}: {e}")
@@ -2011,6 +1789,16 @@ class TradingOrchestrator:
             # Train the specific model based on sophisticated outcome
             await self._train_model_on_outcome(record, was_correct, price_change_pct, reward)
 
+            # Mark this inference as evaluated to prevent re-training
+            if model_name in self.last_inference and self.last_inference[model_name] == record:
+                self.last_inference[model_name]['outcome_evaluated'] = True
+                self.last_inference[model_name]['training_outcome'] = {
+                    'was_correct': was_correct,
+                    'reward': reward,
+                    'price_change_pct': price_change_pct,
+                    'evaluated_at': datetime.now().isoformat()
+                }
+
             logger.debug(f"Evaluated {model_name} prediction: {'✓' if was_correct else '✗'} "
                          f"({prediction['action']}, {price_change_pct:.2f}% change, "
                          f"confidence: {prediction_confidence:.3f}, reward: {reward:.3f})")
@@ -2215,7 +2003,7 @@ class TradingOrchestrator:
                 )
                 predictions.append(prediction)
 
-                # Store prediction in SQLite database for training
+                # Store prediction in database for training
                 logger.debug(f"Added CNN prediction to database: {prediction}")
 
         # Note: Inference data will be stored in main prediction loop to avoid duplication