sqlite for checkpoints, cleanup
This commit is contained in:
@ -201,6 +201,9 @@ class DataProvider:
|
||||
self.last_pivot_calculation: Dict[str, datetime] = {}
|
||||
self.pivot_calculation_interval = timedelta(minutes=5) # Recalculate every 5 minutes
|
||||
|
||||
# Auto-fix corrupted cache files on startup
|
||||
self._auto_fix_corrupted_cache()
|
||||
|
||||
# Load existing pivot bounds from cache
|
||||
self._load_all_pivot_bounds()
|
||||
|
||||
@ -1231,6 +1234,36 @@ class DataProvider:
|
||||
return symbol # Return first symbol for now - can be improved
|
||||
return None
|
||||
|
||||
# === CACHE MANAGEMENT ===
|
||||
|
||||
def _auto_fix_corrupted_cache(self):
    """Detect and delete corrupted cache files during startup.

    Asks the cache manager for a summary; when it reports corrupted files,
    they are cleaned up without confirmation and the number of entries the
    manager marked with "DELETED:" is logged. Any failure is logged as a
    warning and never propagated, so startup continues regardless.
    """
    try:
        from utils.cache_manager import get_cache_manager

        manager = get_cache_manager()
        summary = manager.get_cache_summary()

        # Healthy cache: nothing to do beyond a debug trace.
        if not (summary['corrupted_files'] > 0):
            logger.debug("Cache health check passed - no corrupted files found")
            return

        logger.warning(f"Found {summary['corrupted_files']} corrupted cache files, cleaning up...")

        # dry_run=False makes the manager actually delete the files.
        cleanup_report = manager.cleanup_corrupted_files(dry_run=False)

        # The report maps cache directory -> list of status strings; only
        # entries containing "DELETED:" count as successful removals.
        removed = sum(
            1
            for entries in cleanup_report.values()
            for entry in entries
            if "DELETED:" in entry
        )
        logger.info(f"Auto-cleaned {removed} corrupted cache files")

    except Exception as e:
        logger.warning(f"Cache auto-fix failed: {e}")
|
||||
|
||||
# === PIVOT BOUNDS CACHING ===
|
||||
|
||||
def _load_all_pivot_bounds(self):
|
||||
@ -1285,13 +1318,25 @@ class DataProvider:
|
||||
logger.info(f"Loaded {len(df)} 1m candles from cache for {symbol}")
|
||||
return df
|
||||
except Exception as parquet_e:
|
||||
# Handle corrupted Parquet file
|
||||
if "Parquet magic bytes not found" in str(parquet_e) or "corrupted" in str(parquet_e).lower():
|
||||
# Handle corrupted Parquet file - expanded error detection
|
||||
error_str = str(parquet_e).lower()
|
||||
corrupted_indicators = [
|
||||
"parquet magic bytes not found",
|
||||
"corrupted",
|
||||
"couldn't deserialize thrift",
|
||||
"don't know what type",
|
||||
"invalid parquet file",
|
||||
"unexpected end of file",
|
||||
"invalid metadata"
|
||||
]
|
||||
|
||||
if any(indicator in error_str for indicator in corrupted_indicators):
|
||||
logger.warning(f"Corrupted Parquet cache file for {symbol}, removing and returning None: {parquet_e}")
|
||||
try:
|
||||
cache_file.unlink() # Delete corrupted file
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"Deleted corrupted monthly cache file: {cache_file}")
|
||||
except Exception as delete_e:
|
||||
logger.error(f"Failed to delete corrupted monthly cache file: {delete_e}")
|
||||
return None
|
||||
else:
|
||||
raise parquet_e
|
||||
@ -1393,13 +1438,25 @@ class DataProvider:
|
||||
logger.debug(f"Loaded {len(df)} rows from cache for {symbol} {timeframe} (age: {cache_age/60:.1f}min)")
|
||||
return df
|
||||
except Exception as parquet_e:
|
||||
# Handle corrupted Parquet file
|
||||
if "Parquet magic bytes not found" in str(parquet_e) or "corrupted" in str(parquet_e).lower():
|
||||
# Handle corrupted Parquet file - expanded error detection
|
||||
error_str = str(parquet_e).lower()
|
||||
corrupted_indicators = [
|
||||
"parquet magic bytes not found",
|
||||
"corrupted",
|
||||
"couldn't deserialize thrift",
|
||||
"don't know what type",
|
||||
"invalid parquet file",
|
||||
"unexpected end of file",
|
||||
"invalid metadata"
|
||||
]
|
||||
|
||||
if any(indicator in error_str for indicator in corrupted_indicators):
|
||||
logger.warning(f"Corrupted Parquet cache file for {symbol} {timeframe}, removing and returning None: {parquet_e}")
|
||||
try:
|
||||
cache_file.unlink() # Delete corrupted file
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"Deleted corrupted cache file: {cache_file}")
|
||||
except Exception as delete_e:
|
||||
logger.error(f"Failed to delete corrupted cache file: {delete_e}")
|
||||
return None
|
||||
else:
|
||||
raise parquet_e
|
||||
|
@ -38,6 +38,11 @@ from NN.models.cob_rl_model import COBRLModelInterface # Specific import for COB
|
||||
from NN.models.model_interfaces import ModelInterface as NNModelInterface, CNNModelInterface as NNCNNModelInterface, RLAgentInterface as NNRLAgentInterface, ExtremaTrainerInterface as NNExtremaTrainerInterface # Import from new file
|
||||
from core.extrema_trainer import ExtremaTrainer # Import ExtremaTrainer for its interface
|
||||
|
||||
# Import new logging and database systems
|
||||
from utils.inference_logger import get_inference_logger, log_model_inference
|
||||
from utils.database_manager import get_database_manager
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
|
||||
# Import COB integration for real-time market microstructure data
|
||||
try:
|
||||
from .cob_integration import COBIntegration
|
||||
@ -213,6 +218,10 @@ class TradingOrchestrator:
|
||||
# Initialize inference history for each model (will be populated as models make predictions)
|
||||
# We'll create entries dynamically as models are used
|
||||
|
||||
# Initialize inference logger
|
||||
self.inference_logger = get_inference_logger()
|
||||
self.db_manager = get_database_manager()
|
||||
|
||||
# ENHANCED: Real-time Training System Integration
|
||||
self.enhanced_training_system = None # Will be set to EnhancedRealtimeTrainingSystem if available
|
||||
# Enable training by default - don't depend on external training system
|
||||
@ -232,6 +241,9 @@ class TradingOrchestrator:
|
||||
self.data_provider.start_centralized_data_collection()
|
||||
logger.info("Centralized data collection started - all models and dashboard will receive data")
|
||||
|
||||
# Initialize database cleanup task
|
||||
self._schedule_database_cleanup()
|
||||
|
||||
# CRITICAL: Initialize checkpoint manager for saving training progress
|
||||
self.checkpoint_manager = None
|
||||
self.training_iterations = 0 # Track training iterations for periodic saves
|
||||
@ -265,24 +277,23 @@ class TradingOrchestrator:
|
||||
self.rl_agent = DQNAgent(state_shape=state_size, n_actions=action_size)
|
||||
self.rl_agent.to(self.device) # Move DQN agent to the determined device
|
||||
|
||||
# Load best checkpoint and capture initial state
|
||||
# Load best checkpoint and capture initial state (using database metadata)
|
||||
checkpoint_loaded = False
|
||||
if hasattr(self.rl_agent, 'load_best_checkpoint'):
|
||||
try:
|
||||
self.rl_agent.load_best_checkpoint() # This loads the state into the model
|
||||
# Check if we have checkpoints available
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("dqn_agent")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
# Check if we have checkpoints available using database metadata (fast!)
|
||||
db_manager = get_database_manager()
|
||||
checkpoint_metadata = db_manager.get_best_checkpoint_metadata("dqn_agent")
|
||||
if checkpoint_metadata:
|
||||
self.model_states['dqn']['initial_loss'] = 0.412
|
||||
self.model_states['dqn']['current_loss'] = metadata.loss
|
||||
self.model_states['dqn']['best_loss'] = metadata.loss
|
||||
self.model_states['dqn']['current_loss'] = checkpoint_metadata.performance_metrics.get('loss', 0.0)
|
||||
self.model_states['dqn']['best_loss'] = checkpoint_metadata.performance_metrics.get('loss', 0.0)
|
||||
self.model_states['dqn']['checkpoint_loaded'] = True
|
||||
self.model_states['dqn']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
self.model_states['dqn']['checkpoint_filename'] = checkpoint_metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "unknown"
|
||||
logger.info(f"DQN checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
|
||||
logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading DQN checkpoint: {e}")
|
||||
|
||||
@ -307,21 +318,20 @@ class TradingOrchestrator:
|
||||
self.cnn_model = self.cnn_adapter.model # Keep reference for compatibility
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
|
||||
|
||||
# Load best checkpoint and capture initial state
|
||||
# Load best checkpoint and capture initial state (using database metadata)
|
||||
checkpoint_loaded = False
|
||||
try:
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("enhanced_cnn")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
db_manager = get_database_manager()
|
||||
checkpoint_metadata = db_manager.get_best_checkpoint_metadata("enhanced_cnn")
|
||||
if checkpoint_metadata:
|
||||
self.model_states['cnn']['initial_loss'] = 0.412
|
||||
self.model_states['cnn']['current_loss'] = metadata.loss or 0.0187
|
||||
self.model_states['cnn']['best_loss'] = metadata.loss or 0.0134
|
||||
self.model_states['cnn']['current_loss'] = checkpoint_metadata.performance_metrics.get('loss', 0.0187)
|
||||
self.model_states['cnn']['best_loss'] = checkpoint_metadata.performance_metrics.get('loss', 0.0134)
|
||||
self.model_states['cnn']['checkpoint_loaded'] = True
|
||||
self.model_states['cnn']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
self.model_states['cnn']['checkpoint_filename'] = checkpoint_metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "unknown"
|
||||
logger.info(f"CNN checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
|
||||
logger.info(f"CNN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading CNN checkpoint: {e}")
|
||||
|
||||
@ -399,23 +409,22 @@ class TradingOrchestrator:
|
||||
if hasattr(self.cob_rl_agent, 'to'):
|
||||
self.cob_rl_agent.to(self.device)
|
||||
|
||||
# Load best checkpoint and capture initial state
|
||||
# Load best checkpoint and capture initial state (using database metadata)
|
||||
checkpoint_loaded = False
|
||||
if hasattr(self.cob_rl_agent, 'load_model'):
|
||||
try:
|
||||
self.cob_rl_agent.load_model() # This loads the state into the model
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("cob_rl_model")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
self.model_states['cob_rl']['initial_loss'] = getattr(metadata, 'initial_loss', None)
|
||||
self.model_states['cob_rl']['current_loss'] = metadata.loss
|
||||
self.model_states['cob_rl']['best_loss'] = metadata.loss
|
||||
db_manager = get_database_manager()
|
||||
checkpoint_metadata = db_manager.get_best_checkpoint_metadata("cob_rl")
|
||||
if checkpoint_metadata:
|
||||
self.model_states['cob_rl']['initial_loss'] = checkpoint_metadata.training_metadata.get('initial_loss', None)
|
||||
self.model_states['cob_rl']['current_loss'] = checkpoint_metadata.performance_metrics.get('loss', 0.0)
|
||||
self.model_states['cob_rl']['best_loss'] = checkpoint_metadata.performance_metrics.get('loss', 0.0)
|
||||
self.model_states['cob_rl']['checkpoint_loaded'] = True
|
||||
self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
self.model_states['cob_rl']['checkpoint_filename'] = checkpoint_metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "unknown"
|
||||
logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
|
||||
logger.info(f"COB RL checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading COB RL checkpoint: {e}")
|
||||
|
||||
@ -1247,51 +1256,210 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error storing inference data for {model_name}: {e}")
|
||||
|
||||
async def _save_inference_to_disk_async(self, model_name: str, inference_record: Dict):
|
||||
"""Async save inference record to disk with file capping"""
|
||||
"""Async save inference record to SQLite database and model-specific log"""
|
||||
try:
|
||||
# Create model-specific directory
|
||||
model_dir = Path(f"training_data/inference_history/{model_name}")
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
# Use SQLite for comprehensive storage
|
||||
await self._save_to_sqlite_db(model_name, inference_record)
|
||||
|
||||
# Create filename with timestamp
|
||||
timestamp_str = datetime.fromisoformat(inference_record['timestamp']).strftime('%Y%m%d_%H%M%S_%f')[:-3]
|
||||
filename = f"inference_{timestamp_str}.json"
|
||||
filepath = model_dir / filename
|
||||
|
||||
# Convert to JSON-serializable format
|
||||
serializable_record = self._make_json_serializable(inference_record)
|
||||
|
||||
# Save to file
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(serializable_record, f, indent=2)
|
||||
|
||||
# Cap files per model (keep only latest 200)
|
||||
await self._cap_model_files(model_dir)
|
||||
|
||||
logger.debug(f"Saved inference record to disk: {filepath}")
|
||||
# Also save key metrics to model-specific log for debugging
|
||||
await self._save_to_model_log(model_name, inference_record)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving inference to disk for {model_name}: {e}")
|
||||
|
||||
async def _cap_model_files(self, model_dir: Path):
|
||||
"""Cap the number of files per model to max_disk_files_per_model"""
|
||||
async def _save_to_sqlite_db(self, model_name: str, inference_record: Dict):
|
||||
"""Save inference record to SQLite database"""
|
||||
import sqlite3
|
||||
import asyncio
|
||||
|
||||
def save_to_db():
|
||||
try:
|
||||
# Create database directory
|
||||
db_dir = Path("training_data/inference_db")
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Connect to SQLite database
|
||||
db_path = db_dir / "inference_history.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create table if it doesn't exist
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS inference_records (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
model_name TEXT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
confidence REAL NOT NULL,
|
||||
probabilities TEXT,
|
||||
timeframe TEXT,
|
||||
metadata TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# Create index for faster queries
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_model_timestamp
|
||||
ON inference_records(model_name, timestamp)
|
||||
''')
|
||||
|
||||
# Extract data from inference record
|
||||
prediction = inference_record.get('prediction', {})
|
||||
probabilities_str = str(prediction.get('probabilities', {}))
|
||||
metadata_str = str(inference_record.get('metadata', {}))
|
||||
|
||||
# Insert record
|
||||
cursor.execute('''
|
||||
INSERT INTO inference_records
|
||||
(model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
model_name,
|
||||
inference_record.get('symbol', 'ETH/USDT'),
|
||||
inference_record.get('timestamp', ''),
|
||||
prediction.get('action', 'HOLD'),
|
||||
prediction.get('confidence', 0.0),
|
||||
probabilities_str,
|
||||
prediction.get('timeframe', '1m'),
|
||||
metadata_str
|
||||
))
|
||||
|
||||
# Clean up old records (keep only last 1000 per model)
|
||||
cursor.execute('''
|
||||
DELETE FROM inference_records
|
||||
WHERE model_name = ? AND id NOT IN (
|
||||
SELECT id FROM inference_records
|
||||
WHERE model_name = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1000
|
||||
)
|
||||
''', (model_name, model_name))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving to SQLite database: {e}")
|
||||
|
||||
# Run database operation in thread pool to avoid blocking
|
||||
await asyncio.get_event_loop().run_in_executor(None, save_to_db)
|
||||
|
||||
async def _save_to_model_log(self, model_name: str, inference_record: Dict):
|
||||
"""Save key inference metrics to model-specific log file for debugging"""
|
||||
import asyncio
|
||||
|
||||
def save_to_log():
|
||||
try:
|
||||
# Create logs directory
|
||||
logs_dir = Path("logs/model_inference")
|
||||
logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create model-specific log file
|
||||
log_file = logs_dir / f"{model_name}_inference.log"
|
||||
|
||||
# Extract key metrics
|
||||
prediction = inference_record.get('prediction', {})
|
||||
timestamp = inference_record.get('timestamp', '')
|
||||
symbol = inference_record.get('symbol', 'N/A')
|
||||
|
||||
# Format log entry with key metrics
|
||||
log_entry = (
|
||||
f"{timestamp} | "
|
||||
f"Symbol: {symbol} | "
|
||||
f"Action: {prediction.get('action', 'N/A'):4} | "
|
||||
f"Confidence: {prediction.get('confidence', 0.0):6.3f} | "
|
||||
f"Timeframe: {prediction.get('timeframe', 'N/A'):3} | "
|
||||
f"Probs: BUY={prediction.get('probabilities', {}).get('BUY', 0.0):5.3f} "
|
||||
f"SELL={prediction.get('probabilities', {}).get('SELL', 0.0):5.3f} "
|
||||
f"HOLD={prediction.get('probabilities', {}).get('HOLD', 0.0):5.3f}\n"
|
||||
)
|
||||
|
||||
# Append to log file
|
||||
with open(log_file, 'a', encoding='utf-8') as f:
|
||||
f.write(log_entry)
|
||||
|
||||
# Keep log files manageable (rotate when > 10MB)
|
||||
if log_file.stat().st_size > 10 * 1024 * 1024: # 10MB
|
||||
self._rotate_log_file(log_file)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving to model log: {e}")
|
||||
|
||||
# Run log operation in thread pool to avoid blocking
|
||||
await asyncio.get_event_loop().run_in_executor(None, save_to_log)
|
||||
|
||||
def _rotate_log_file(self, log_file: Path):
    """Shrink an oversized log file in place to its last 1000 lines.

    Args:
        log_file: Path of the log file to truncate.

    Errors are logged and swallowed; the file is left as-is if reading fails.
    """
    from collections import deque

    try:
        # Keep last 1000 lines. A bounded deque streams the file instead of
        # materialising every line of a >10MB log in memory.
        with open(log_file, 'r', encoding='utf-8') as f:
            tail = deque(f, maxlen=1000)

        # Write back only the last 1000 lines
        with open(log_file, 'w', encoding='utf-8') as f:
            f.writelines(tail)

        logger.debug(f"Rotated log file {log_file.name} (kept last 1000 lines)")

    except Exception as e:
        logger.error(f"Error rotating log file {log_file}: {e}")
|
||||
|
||||
def get_inference_records_from_db(self, model_name: Optional[str] = None, limit: int = 100) -> List[Dict]:
    """Read the most recent inference records from the SQLite history database.

    Args:
        model_name: When truthy, only rows of this model are returned;
            otherwise rows of all models.
        limit: Maximum number of rows, newest ``timestamp`` first.

    Returns:
        A list of record dicts with ``model_name``, ``symbol``, ``timestamp``,
        a nested ``prediction`` dict (``action``, ``confidence``,
        ``probabilities``, ``timeframe``) and ``metadata``. Returns ``[]``
        when the database file does not exist or the query fails.
    """
    import ast
    import json
    import sqlite3
    from contextlib import closing

    def decode_field(raw):
        """Parse a stored dict column without executing it as code."""
        if not raw:
            return {}
        # Rows are written as Python repr text; literal_eval accepts only
        # literals, unlike eval() which would run whatever the column holds.
        try:
            return ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            pass
        # Also accept JSON text, in case rows were written that way.
        try:
            return json.loads(raw)
        except (ValueError, RecursionError):
            logger.warning("Could not parse stored inference field; using empty dict")
            return {}

    try:
        # Connect to database
        db_path = Path("training_data/inference_db/inference_history.db")
        if not db_path.exists():
            return []

        # closing() releases the connection even if the query raises.
        with closing(sqlite3.connect(str(db_path))) as conn:
            # Query records
            if model_name:
                cursor = conn.execute('''
                    SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
                    FROM inference_records
                    WHERE model_name = ?
                    ORDER BY timestamp DESC
                    LIMIT ?
                ''', (model_name, limit))
            else:
                cursor = conn.execute('''
                    SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
                    FROM inference_records
                    ORDER BY timestamp DESC
                    LIMIT ?
                ''', (limit,))
            rows = cursor.fetchall()

        return [
            {
                'model_name': row[0],
                'symbol': row[1],
                'timestamp': row[2],
                'prediction': {
                    'action': row[3],
                    'confidence': row[4],
                    'probabilities': decode_field(row[5]),
                    'timeframe': row[6]
                },
                'metadata': decode_field(row[7])
            }
            for row in rows
        ]

    except Exception as e:
        logger.error(f"Error querying SQLite database: {e}")
        return []
|
||||
|
||||
|
||||
|
||||
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
|
||||
"""Prepare standardized input data for CNN models with proper GPU device placement"""
|
||||
@ -1472,67 +1640,60 @@ class TradingOrchestrator:
|
||||
return obj
|
||||
|
||||
def load_inference_history_from_disk(self, symbol: str, days_back: int = 7) -> List[Dict]:
    """Load inference history for a symbol from the SQLite database.

    Despite the name (kept for existing callers), records are read from
    ``training_data/inference_db/inference_history.db``, not from JSON files.

    Args:
        symbol: Value matched against the ``symbol`` column.
        days_back: Only rows whose ``timestamp`` is at or after
            ``now - days_back`` days are returned. The comparison is done on
            ISO-8601 text, so it assumes timestamps were stored in that form
            -- TODO confirm against the writer's callers.

    Returns:
        Record dicts ordered by ascending ``timestamp``; ``[]`` when the
        database file does not exist or the query fails.
    """
    import ast
    import json
    import sqlite3
    from contextlib import closing

    def decode_field(raw):
        """Parse a stored dict column without executing it as code."""
        if not raw:
            return {}
        # Rows are written as Python repr text; literal_eval accepts only
        # literals, unlike eval() which would run whatever the column holds.
        try:
            return ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            pass
        # Also accept JSON text, in case rows were written that way.
        try:
            return json.loads(raw)
        except (ValueError, RecursionError):
            logger.warning("Could not parse stored inference field; using empty dict")
            return {}

    try:
        # Connect to database
        db_path = Path("training_data/inference_db/inference_history.db")
        if not db_path.exists():
            return []

        # Get records for the symbol from the last N days
        cutoff_date = (datetime.now() - timedelta(days=days_back)).isoformat()

        # closing() releases the connection even if the query raises.
        with closing(sqlite3.connect(str(db_path))) as conn:
            rows = conn.execute('''
                SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
                FROM inference_records
                WHERE symbol = ? AND timestamp >= ?
                ORDER BY timestamp ASC
            ''', (symbol, cutoff_date)).fetchall()

        inference_records = [
            {
                'model_name': row[0],
                'symbol': row[1],
                'timestamp': row[2],
                'prediction': {
                    'action': row[3],
                    'confidence': row[4],
                    'probabilities': decode_field(row[5]),
                    'timeframe': row[6]
                },
                'metadata': decode_field(row[7])
            }
            for row in rows
        ]

        logger.info(f"Loaded {len(inference_records)} inference records for {symbol} from SQLite database")

        return inference_records

    except Exception as e:
        logger.error(f"Error loading inference history from database: {e}")
        return []
|
||||
|
||||
async def load_model_inference_history(self, model_name: str, limit: int = 50) -> List[Dict]:
|
||||
"""Load inference history for a specific model from disk"""
|
||||
"""Load inference history for a specific model from SQLite database"""
|
||||
try:
|
||||
model_dir = Path(f"training_data/inference_history/{model_name}")
|
||||
if not model_dir.exists():
|
||||
return []
|
||||
|
||||
# Get all inference files
|
||||
files = list(model_dir.glob("inference_*.json"))
|
||||
files.sort(key=lambda x: x.stat().st_mtime, reverse=True) # Newest first
|
||||
|
||||
# Load up to 'limit' files
|
||||
inference_records = []
|
||||
for filepath in files[:limit]:
|
||||
try:
|
||||
with open(filepath, 'r') as f:
|
||||
record = json.load(f)
|
||||
inference_records.append(record)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading inference file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Loaded {len(inference_records)} inference records for {model_name}")
|
||||
return inference_records
|
||||
# Use the SQLite database method
|
||||
records = self.get_inference_records_from_db(model_name, limit)
|
||||
logger.info(f"Loaded {len(records)} inference records for {model_name} from database")
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model inference history for {model_name}: {e}")
|
||||
@ -3284,6 +3445,15 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error initializing checkpoint manager: {e}")
|
||||
self.checkpoint_manager = None
|
||||
|
||||
def _schedule_database_cleanup(self):
    """Run a one-off cleanup of old inference logs.

    Calls ``self.inference_logger.cleanup_old_logs`` with a 30-day retention
    and logs the outcome. Failures are logged and swallowed.

    NOTE(review): despite the name, nothing is scheduled here; the cleanup
    runs once, synchronously, each time this method is called.
    """
    retention_days = 30  # Clean up old inference records (keep 30 days)
    try:
        self.inference_logger.cleanup_old_logs(days_to_keep=retention_days)
        logger.info("Database cleanup completed")
    except Exception as e:
        logger.error(f"Database cleanup failed: {e}")
|
||||
|
||||
def _save_training_checkpoints(self, models_trained: List[str], performance_score: float):
|
||||
"""Save checkpoints for trained models if performance improved
|
||||
|
||||
@ -3419,4 +3589,45 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error saving checkpoint for {model_name}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _save_training_checkpoints: {e}")
|
||||
logger.error(f"Error in _save_training_checkpoints: {e}")
|
||||
def _schedule_database_cleanup(self):
    """Run a one-off cleanup of old inference logs.

    Calls ``self.inference_logger.cleanup_old_logs`` with a 30-day retention
    and logs the outcome. Failures are logged and swallowed.

    NOTE(review): a method with this exact name and body is also defined
    earlier in the same class; in Python the later definition replaces the
    earlier one, so one of the two copies is redundant.
    """
    retention_days = 30  # Clean up old inference records (keep 30 days)
    try:
        self.inference_logger.cleanup_old_logs(days_to_keep=retention_days)
        logger.info("Database cleanup completed")
    except Exception as e:
        logger.error(f"Database cleanup failed: {e}")
|
||||
|
||||
def log_model_inference(self, model_name: str, symbol: str, action: str,
                        confidence: float, probabilities: Dict[str, float],
                        input_features: Any, processing_time_ms: float,
                        checkpoint_id: Optional[str] = None,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
    """
    Centralized entry point for models to log their inferences.

    Forwards every argument unchanged, by keyword, to the module-level
    ``log_model_inference`` function and returns its result. This replaces
    scattered logger.info() calls throughout the codebase.
    """
    # Inside a method body the bare name resolves to the module-level
    # function of the same name, not to this method, so this does not recurse.
    logged = log_model_inference(
        model_name=model_name,
        symbol=symbol,
        action=action,
        confidence=confidence,
        probabilities=probabilities,
        input_features=input_features,
        processing_time_ms=processing_time_ms,
        checkpoint_id=checkpoint_id,
        metadata=metadata,
    )
    return logged
|
||||
|
||||
def get_model_inference_stats(self, model_name: str, hours: int = 24) -> Dict[str, Any]:
    """Return inference statistics for a model.

    Delegates to ``self.inference_logger.get_model_stats(model_name, hours)``
    and returns its result unchanged. ``hours`` is presumably the look-back
    window -- confirm against the inference logger.
    """
    stats = self.inference_logger.get_model_stats(model_name, hours)
    return stats
|
||||
|
||||
def get_checkpoint_metadata_fast(self, model_name: str) -> Optional[Any]:
    """
    Look up the best checkpoint's metadata without loading the model itself.

    Delegates to ``self.db_manager.get_best_checkpoint_metadata(model_name)``
    and returns its result unchanged, which avoids loading an entire
    checkpoint just to read its metadata.
    """
    best_metadata = self.db_manager.get_best_checkpoint_metadata(model_name)
    return best_metadata
|
Reference in New Issue
Block a user