""" Training Data Fetcher - Dynamic OHLCV data retrieval for model training Fetches ±5 minutes of OHLCV data around annotated events from cache/database instead of storing it in JSON files. This allows efficient training on optimal timing. """ import logging from datetime import datetime, timedelta from typing import Dict, List, Optional, Any, Tuple import pandas as pd import numpy as np import pytz logger = logging.getLogger(__name__) class TrainingDataFetcher: """ Fetches training data dynamically from cache/database for annotated events. Key Features: - Fetches ±5 minutes of OHLCV data around entry/exit points - Generates training labels for optimal timing detection - Supports multiple timeframes (1s, 1m, 1h, 1d) - Efficient memory usage (no JSON storage) - Real-time data from cache/database """ def __init__(self, data_provider): """ Initialize training data fetcher Args: data_provider: DataProvider instance for fetching OHLCV data """ self.data_provider = data_provider logger.info("TrainingDataFetcher initialized") def fetch_training_data_for_annotation(self, annotation: Dict, context_window_minutes: int = 5) -> Dict[str, Any]: """ Fetch complete training data for an annotation Args: annotation: Annotation metadata (from annotations_db.json) context_window_minutes: Minutes before/after entry to include Returns: Dict with market_state, training_labels, and expected_outcome """ try: # Parse timestamps entry_time = datetime.fromisoformat(annotation['entry']['timestamp'].replace('Z', '+00:00')) exit_time = datetime.fromisoformat(annotation['exit']['timestamp'].replace('Z', '+00:00')) symbol = annotation['symbol'] direction = annotation['direction'] logger.info(f"Fetching training data for {symbol} at {entry_time} (±{context_window_minutes}min)") # Fetch OHLCV data for all timeframes around entry time market_state = self._fetch_market_state_at_time( symbol=symbol, timestamp=entry_time, context_window_minutes=context_window_minutes ) # Generate training labels for optimal timing detection training_labels = self._generate_timing_labels( market_state=market_state, entry_time=entry_time, exit_time=exit_time, direction=direction ) # Prepare expected outcome expected_outcome = { "direction": direction, "profit_loss_pct": annotation['profit_loss_pct'], "entry_price": annotation['entry']['price'], "exit_price": annotation['exit']['price'], "holding_period_seconds": (exit_time - entry_time).total_seconds() } return { "test_case_id": f"annotation_{annotation['annotation_id']}", "symbol": symbol, "timestamp": annotation['entry']['timestamp'], "action": "BUY" if direction == "LONG" else "SELL", "market_state": market_state, "training_labels": training_labels, "expected_outcome": expected_outcome, "annotation_metadata": { "annotator": "manual", "confidence": 1.0, "notes": annotation.get('notes', ''), "created_at": annotation.get('created_at'), "timeframe": annotation.get('timeframe', '1m') } } except Exception as e: logger.error(f"Error fetching training data for annotation: {e}") import traceback traceback.print_exc() return {} def _fetch_market_state_at_time(self, symbol: str, timestamp: datetime, context_window_minutes: int) -> Dict[str, Any]: """ Fetch market state at specific time from cache/database Args: symbol: Trading symbol timestamp: Target timestamp context_window_minutes: Minutes before/after to include Returns: Dict with OHLCV data for all timeframes """ try: # Use data provider's method to get market state market_state = self.data_provider.get_market_state_at_time( symbol=symbol, timestamp=timestamp, context_window_minutes=context_window_minutes ) logger.info(f"Fetched market state with {len(market_state)} timeframes") return market_state except Exception as e: logger.error(f"Error fetching market state: {e}") return {} def _generate_timing_labels(self, market_state: Dict, entry_time: datetime, exit_time: datetime, direction: str) -> Dict[str, Any]: """ Generate training labels for optimal timing detection Labels help model learn: - WHEN to enter (optimal entry timing) - WHEN to exit (optimal exit timing) - WHEN NOT to trade (avoid bad timing) Args: market_state: OHLCV data for all timeframes entry_time: Entry timestamp exit_time: Exit timestamp direction: Trade direction (LONG/SHORT) Returns: Dict with training labels for each timeframe """ labels = { 'direction': direction, 'entry_timestamp': entry_time.strftime('%Y-%m-%d %H:%M:%S'), 'exit_timestamp': exit_time.strftime('%Y-%m-%d %H:%M:%S') } # Generate labels for each timeframe timeframes = ['1s', '1m', '1h', '1d'] for tf in timeframes: tf_key = f'ohlcv_{tf}' if tf_key in market_state and 'timestamps' in market_state[tf_key]: timestamps = market_state[tf_key]['timestamps'] label_list = [] entry_idx = -1 exit_idx = -1 for i, ts_str in enumerate(timestamps): try: ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S') # Make timezone-aware if ts.tzinfo is None: ts = pytz.UTC.localize(ts) # Make entry_time and exit_time timezone-aware if needed if entry_time.tzinfo is None: entry_time = pytz.UTC.localize(entry_time) if exit_time.tzinfo is None: exit_time = pytz.UTC.localize(exit_time) # Determine label based on timing if abs((ts - entry_time).total_seconds()) < 60: # Within 1 minute of entry label = 1 # OPTIMAL ENTRY TIMING entry_idx = i elif abs((ts - exit_time).total_seconds()) < 60: # Within 1 minute of exit label = 3 # OPTIMAL EXIT TIMING exit_idx = i elif entry_time < ts < exit_time: # Between entry and exit label = 2 # HOLD POSITION else: # Before entry or after exit label = 0 # NO ACTION (avoid trading) label_list.append(label) except Exception as e: logger.error(f"Error parsing timestamp {ts_str}: {e}") label_list.append(0) labels[f'labels_{tf}'] = label_list labels[f'entry_index_{tf}'] = entry_idx labels[f'exit_index_{tf}'] = exit_idx # Log label distribution label_counts = {0: 0, 1: 0, 2: 0, 3: 0} for label in label_list: label_counts[label] += 1 logger.info(f"Generated {tf} labels: {label_counts[0]} NO_ACTION, " f"{label_counts[1]} ENTRY, {label_counts[2]} HOLD, {label_counts[3]} EXIT") return labels def fetch_training_batch(self, annotations: List[Dict], context_window_minutes: int = 5) -> List[Dict[str, Any]]: """ Fetch training data for multiple annotations Args: annotations: List of annotation metadata context_window_minutes: Minutes before/after entry to include Returns: List of training data dictionaries """ training_data = [] logger.info(f"Fetching training batch for {len(annotations)} annotations") for annotation in annotations: try: training_sample = self.fetch_training_data_for_annotation( annotation, context_window_minutes ) if training_sample: training_data.append(training_sample) else: logger.warning(f"Failed to fetch training data for annotation {annotation.get('annotation_id')}") except Exception as e: logger.error(f"Error processing annotation {annotation.get('annotation_id')}: {e}") logger.info(f"Successfully fetched training data for {len(training_data)}/{len(annotations)} annotations") return training_data def get_training_statistics(self, training_data: List[Dict]) -> Dict[str, Any]: """ Get statistics about training data Args: training_data: List of training data samples Returns: Dict with training statistics """ if not training_data: return {} stats = { 'total_samples': len(training_data), 'symbols': {}, 'directions': {'LONG': 0, 'SHORT': 0}, 'avg_profit_loss': 0.0, 'timeframes_available': set() } total_pnl = 0.0 for sample in training_data: symbol = sample.get('symbol', 'UNKNOWN') direction = sample.get('expected_outcome', {}).get('direction', 'UNKNOWN') pnl = sample.get('expected_outcome', {}).get('profit_loss_pct', 0.0) # Count symbols stats['symbols'][symbol] = stats['symbols'].get(symbol, 0) + 1 # Count directions if direction in stats['directions']: stats['directions'][direction] += 1 # Accumulate P&L total_pnl += pnl # Check available timeframes market_state = sample.get('market_state', {}) for key in market_state.keys(): if key.startswith('ohlcv_'): stats['timeframes_available'].add(key.replace('ohlcv_', '')) stats['avg_profit_loss'] = total_pnl / len(training_data) stats['timeframes_available'] = list(stats['timeframes_available']) return stats