inference data storage
This commit is contained in:
@ -181,7 +181,8 @@
|
||||
|
||||
## Model Inference Data Validation and Storage
|
||||
|
||||
- [ ] 5. Implement comprehensive inference data validation system
|
||||
- [x] 5. Implement comprehensive inference data validation system
|
||||
|
||||
- Create InferenceDataValidator class for input validation
|
||||
- Validate complete OHLCV dataframes for all required timeframes
|
||||
- Check input data dimensions against model requirements
|
||||
|
@ -28,6 +28,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
from .config import get_config
|
||||
from .data_provider import DataProvider
|
||||
@ -1311,6 +1312,86 @@ class TradingOrchestrator:
|
||||
else:
|
||||
return obj
|
||||
|
||||
def load_inference_history_from_disk(self, symbol: str, days_back: int = 7) -> List[Dict]:
    """Load persisted inference records for *symbol* for training replay.

    Scans ``training_data/inference_history`` for JSON files named
    ``<symbol>_<YYYYmmdd>_<HHMMSS>.json``, keeps those whose filename
    timestamp falls within the last ``days_back`` days, and returns the
    parsed records sorted by their 'timestamp' field.

    Args:
        symbol: Trading symbol used as the filename prefix.
        days_back: Age cutoff in days; older files are ignored.

    Returns:
        List of inference record dicts (possibly empty). Files that fail
        to parse are logged and skipped; a top-level failure returns [].
    """
    try:
        inference_dir = Path("training_data/inference_history")
        if not inference_dir.exists():
            return []

        # Only files newer than the cutoff are loaded.
        cutoff_date = datetime.now() - timedelta(days=days_back)
        inference_records = []

        for filepath in inference_dir.glob(f"{symbol}_*.json"):
            try:
                # Filename layout: <symbol>_<YYYYmmdd>_<HHMMSS>.json.
                # Rebuild the timestamp from the LAST two underscore
                # parts so symbols containing '_' still parse.
                filename_parts = filepath.stem.split('_')
                if len(filename_parts) >= 3:
                    timestamp_str = f"{filename_parts[-2]}_{filename_parts[-1]}"
                    file_timestamp = datetime.strptime(timestamp_str, '%Y%m%d_%H%M%S')

                    if file_timestamp >= cutoff_date:
                        with open(filepath, 'r') as f:
                            record = json.load(f)
                        inference_records.append(record)

            except Exception as e:
                # One unreadable/misnamed file must not abort the whole load.
                logger.warning(f"Error loading inference file {filepath}: {e}")
                continue

        # Sort chronologically. Use .get so a single record missing the
        # 'timestamp' key cannot raise and discard every loaded record
        # (a bare x['timestamp'] previously fell through to the outer
        # handler, which returned []).
        inference_records.sort(key=lambda r: r.get('timestamp', ''))
        logger.info(f"Loaded {len(inference_records)} inference records for {symbol} from disk")

        return inference_records

    except Exception as e:
        logger.error(f"Error loading inference history from disk: {e}")
        return []
|
||||
|
||||
def get_model_training_data(self, model_name: str, symbol: str = None) -> List[Dict]:
    """Collect all inference records produced by *model_name*.

    Merges in-memory history (``self.inference_history``) with records
    persisted on disk, de-duplicates by (timestamp, symbol), and returns
    them sorted chronologically.

    Args:
        model_name: Value matched against each record's 'model_name' field.
        symbol: Restrict the scan to one symbol; when None, every symbol
            in ``self.symbols`` is scanned.

    Returns:
        Chronologically sorted list of unique record dicts (possibly
        empty). A top-level failure is logged and returns [].
    """
    try:
        symbols_to_check = [symbol] if symbol else self.symbols

        # In-memory records first, then disk, so the dedup pass below
        # keeps the in-memory copy of any inference present in both.
        training_data: List[Dict] = []
        for sym in symbols_to_check:
            for record in self.inference_history.get(sym, []):
                # .get: one malformed record (e.g. a disk JSON missing
                # 'model_name') must not abort the whole query — a bare
                # record['model_name'] previously raised KeyError into
                # the outer handler, which returned [].
                if record.get('model_name') == model_name:
                    training_data.append(record)

        for sym in symbols_to_check:
            for record in self.load_inference_history_from_disk(sym):
                if record.get('model_name') == model_name:
                    training_data.append(record)

        # De-duplicate on (timestamp, symbol); every record here belongs
        # to a single model, so that pair identifies one inference.
        seen_timestamps = set()
        unique_data = []
        for record in training_data:
            timestamp_key = f"{record.get('timestamp')}_{record.get('symbol')}"
            if timestamp_key not in seen_timestamps:
                seen_timestamps.add(timestamp_key)
                unique_data.append(record)

        unique_data.sort(key=lambda r: r.get('timestamp', ''))
        logger.info(f"Retrieved {len(unique_data)} training records for {model_name}")

        return unique_data

    except Exception as e:
        logger.error(f"Error getting model training data: {e}")
        return []
|
||||
|
||||
async def _trigger_model_training(self, symbol: str):
|
||||
"""Trigger training for models based on previous inference data"""
|
||||
try:
|
||||
|
Reference in New Issue
Block a user