300 lines
12 KiB
Python
300 lines
12 KiB
Python
"""
|
|
Training Data Fetcher - dynamic OHLCV data retrieval for model training.

Fetches OHLCV data around annotated events (±5 minutes by default) from the
cache/database instead of storing it in JSON files, so models can be trained
on optimal entry/exit timing.
|
|
"""
|
|
|
|
import logging
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Any, Tuple

import numpy as np
import pandas as pd
import pytz
|
|
|
|
# Module-level logger used by every TrainingDataFetcher method.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TrainingDataFetcher:
    """
    Fetch training data on demand from a data provider for annotated trades.

    Key features:
    - Fetches OHLCV context around the entry point (±5 minutes by default)
      via ``data_provider.get_market_state_at_time``.
    - Generates per-candle labels for optimal entry/exit timing detection.
    - Labels the timeframes listed in ``TIMEFRAMES`` (1s, 1m, 1h, 1d).
    - Nothing is persisted to JSON; data is fetched when requested.
    """

    # Per-candle label values produced by _generate_timing_labels.
    LABEL_NO_ACTION = 0  # before entry / after exit: avoid trading
    LABEL_ENTRY = 1      # within the tolerance window of the entry time
    LABEL_HOLD = 2       # strictly between entry and exit
    LABEL_EXIT = 3       # within the tolerance window of the exit time

    # Timeframes to label; market_state is expected to key them as "ohlcv_<tf>".
    TIMEFRAMES: Tuple[str, ...] = ('1s', '1m', '1h', '1d')

    # Max |candle_time - event_time| (seconds) for a candle to count as ENTRY/EXIT.
    TIMING_TOLERANCE_SECONDS = 60.0

    def __init__(self, data_provider):
        """
        Initialize the training data fetcher.

        Args:
            data_provider: Object exposing ``get_market_state_at_time(symbol=...,
                timestamp=..., context_window_minutes=...)``; used for all OHLCV fetches.
        """
        self.data_provider = data_provider
        logger.info("TrainingDataFetcher initialized")

    def fetch_training_data_for_annotation(self, annotation: Dict,
                                           context_window_minutes: int = 5) -> Dict[str, Any]:
        """
        Build one training sample for a manually annotated trade.

        Args:
            annotation: Annotation record. Required keys: ``entry`` and ``exit``
                (each with an ISO-8601 ``timestamp`` and a ``price``), ``symbol``,
                ``direction``, ``profit_loss_pct`` and ``annotation_id``.
                Optional: ``notes``, ``created_at``, ``timeframe``.
            context_window_minutes: Minutes before/after entry to fetch.

        Returns:
            Dict with ``market_state``, ``training_labels``, ``expected_outcome``
            and metadata, or ``{}`` if anything fails (the error is logged with
            its traceback).
        """
        try:
            # Trailing 'Z' is rewritten because older fromisoformat() rejects it.
            entry_time = datetime.fromisoformat(annotation['entry']['timestamp'].replace('Z', '+00:00'))
            exit_time = datetime.fromisoformat(annotation['exit']['timestamp'].replace('Z', '+00:00'))

            symbol = annotation['symbol']
            direction = annotation['direction']

            logger.info(f"Fetching training data for {symbol} at {entry_time} (±{context_window_minutes}min)")

            # OHLCV data for all timeframes around the entry time.
            market_state = self._fetch_market_state_at_time(
                symbol=symbol,
                timestamp=entry_time,
                context_window_minutes=context_window_minutes
            )
            if not market_state:
                logger.warning(f"No market data available for {symbol} at {entry_time}; "
                               "sample will contain no OHLCV data or per-candle labels")

            training_labels = self._generate_timing_labels(
                market_state=market_state,
                entry_time=entry_time,
                exit_time=exit_time,
                direction=direction
            )

            expected_outcome = {
                "direction": direction,
                "profit_loss_pct": annotation['profit_loss_pct'],
                "entry_price": annotation['entry']['price'],
                "exit_price": annotation['exit']['price'],
                "holding_period_seconds": (exit_time - entry_time).total_seconds()
            }

            return {
                "test_case_id": f"annotation_{annotation['annotation_id']}",
                "symbol": symbol,
                "timestamp": annotation['entry']['timestamp'],
                # Any direction other than "LONG" maps to SELL.
                "action": "BUY" if direction == "LONG" else "SELL",
                "market_state": market_state,
                "training_labels": training_labels,
                "expected_outcome": expected_outcome,
                "annotation_metadata": {
                    "annotator": "manual",
                    "confidence": 1.0,
                    "notes": annotation.get('notes', ''),
                    "created_at": annotation.get('created_at'),
                    "timeframe": annotation.get('timeframe', '1m')
                }
            }

        except Exception as e:
            # Best-effort: a malformed annotation or provider failure yields {},
            # which fetch_training_batch treats as "skip", instead of aborting.
            logger.exception(f"Error fetching training data for annotation: {e}")
            return {}

    def _fetch_market_state_at_time(self, symbol: str, timestamp: datetime,
                                    context_window_minutes: int) -> Dict[str, Any]:
        """
        Fetch the market state around ``timestamp`` from the data provider.

        Args:
            symbol: Trading symbol.
            timestamp: Target timestamp.
            context_window_minutes: Minutes before/after to include.

        Returns:
            Whatever the provider returned (expected: dict of per-timeframe OHLCV
            data), or ``{}`` if the provider returned ``None`` or raised.
        """
        try:
            market_state = self.data_provider.get_market_state_at_time(
                symbol=symbol,
                timestamp=timestamp,
                context_window_minutes=context_window_minutes
            )

            if market_state is None:
                logger.warning(f"No market state returned for {symbol} at {timestamp}")
                return {}

            logger.info(f"Fetched market state with {len(market_state)} timeframes")
            return market_state

        except Exception as e:
            logger.error(f"Error fetching market state: {e}")
            return {}

    def _generate_timing_labels(self, market_state: Dict, entry_time: datetime,
                                exit_time: datetime, direction: str) -> Dict[str, Any]:
        """
        Generate per-candle labels for optimal timing detection.

        For every timeframe in ``TIMEFRAMES`` whose ``ohlcv_<tf>`` entry has a
        ``timestamps`` sequence, adds ``labels_<tf>`` (one label per candle, see
        the ``LABEL_*`` constants), ``entry_index_<tf>`` and ``exit_index_<tf>``
        (index of the nearest candle within the tolerance window, or -1).

        NOTE(review): with the fixed ``TIMING_TOLERANCE_SECONDS`` window, 1h/1d
        candles only get ENTRY/EXIT labels when their timestamp lies within 60 s
        of the event, so their indices are frequently -1.

        Args:
            market_state: OHLCV data for all timeframes (may be empty or None).
            entry_time: Entry timestamp (naive values are treated as UTC).
            exit_time: Exit timestamp (naive values are treated as UTC).
            direction: Trade direction, copied into the result unchanged.

        Returns:
            Dict with direction, entry/exit timestamps and the per-timeframe labels.
        """
        labels: Dict[str, Any] = {
            'direction': direction,
            'entry_timestamp': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
            'exit_timestamp': exit_time.strftime('%Y-%m-%d %H:%M:%S')
        }

        if market_state is None:
            market_state = {}

        for tf in self.TIMEFRAMES:
            tf_key = f'ohlcv_{tf}'
            if tf_key not in market_state or 'timestamps' not in market_state[tf_key]:
                continue

            label_list, entry_idx, exit_idx = self._label_timestamps(
                market_state[tf_key]['timestamps'], entry_time, exit_time
            )

            labels[f'labels_{tf}'] = label_list
            labels[f'entry_index_{tf}'] = entry_idx
            labels[f'exit_index_{tf}'] = exit_idx

            logger.info(f"Generated {tf} labels: {label_list.count(self.LABEL_NO_ACTION)} NO_ACTION, "
                        f"{label_list.count(self.LABEL_ENTRY)} ENTRY, "
                        f"{label_list.count(self.LABEL_HOLD)} HOLD, "
                        f"{label_list.count(self.LABEL_EXIT)} EXIT")

        return labels

    @staticmethod
    def _as_utc_aware(value: datetime) -> datetime:
        """Return *value* if timezone-aware, else the same wall-clock time tagged as UTC."""
        return value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)

    @classmethod
    def _parse_candle_timestamp(cls, raw: Any) -> datetime:
        """
        Convert one candle timestamp to an aware datetime.

        Accepts ``datetime`` instances (including subclasses such as
        ``pandas.Timestamp``) or ISO-8601 strings like ``'2024-01-01 12:00:00'``.
        Raises ``ValueError``/``TypeError`` for anything else.
        """
        if isinstance(raw, datetime):
            parsed = raw
        else:
            parsed = datetime.fromisoformat(str(raw).replace('Z', '+00:00'))
        return cls._as_utc_aware(parsed)

    @classmethod
    def _label_timestamps(cls, timestamps, entry_time: datetime, exit_time: datetime,
                          tolerance_seconds: Optional[float] = None) -> Tuple[List[int], int, int]:
        """
        Label each candle timestamp relative to a trade's entry and exit.

        Args:
            timestamps: Iterable of candle timestamps (see _parse_candle_timestamp).
            entry_time: Trade entry; naive values are treated as UTC.
            exit_time: Trade exit; naive values are treated as UTC.
            tolerance_seconds: ENTRY/EXIT window; defaults to TIMING_TOLERANCE_SECONDS.

        Returns:
            ``(labels, entry_index, exit_index)``. The indices point at the candle
            closest to the entry/exit time within the tolerance window (earliest
            on ties), or -1 if none. When the two windows overlap, the label is
            ENTRY but ``exit_index`` is still reported. Unparseable timestamps
            are logged and labelled NO_ACTION.
        """
        if tolerance_seconds is None:
            tolerance_seconds = cls.TIMING_TOLERANCE_SECONDS

        # Normalised once: comparing naive and aware datetimes raises TypeError.
        entry_time = cls._as_utc_aware(entry_time)
        exit_time = cls._as_utc_aware(exit_time)

        labels: List[int] = []
        entry_idx = exit_idx = -1
        best_entry_gap = best_exit_gap = float('inf')

        for i, raw_ts in enumerate(timestamps):
            try:
                ts = cls._parse_candle_timestamp(raw_ts)
            except Exception as e:
                logger.error(f"Error parsing timestamp {raw_ts}: {e}")
                labels.append(cls.LABEL_NO_ACTION)
                continue

            entry_gap = abs((ts - entry_time).total_seconds())
            exit_gap = abs((ts - exit_time).total_seconds())

            if entry_gap < tolerance_seconds:
                labels.append(cls.LABEL_ENTRY)
                if entry_gap < best_entry_gap:
                    best_entry_gap, entry_idx = entry_gap, i
            elif exit_gap < tolerance_seconds:
                labels.append(cls.LABEL_EXIT)
            elif entry_time < ts < exit_time:
                labels.append(cls.LABEL_HOLD)
            else:
                labels.append(cls.LABEL_NO_ACTION)

            # Tracked independently of the label so an overlapping entry window
            # cannot hide the exit candle.
            if exit_gap < tolerance_seconds and exit_gap < best_exit_gap:
                best_exit_gap, exit_idx = exit_gap, i

        return labels, entry_idx, exit_idx

    def fetch_training_batch(self, annotations: List[Dict],
                             context_window_minutes: int = 5) -> List[Dict[str, Any]]:
        """
        Fetch training data for multiple annotations.

        Annotations whose fetch returns an empty dict or raises are logged and
        skipped, so the result may be shorter than ``annotations``.

        Args:
            annotations: List of annotation records.
            context_window_minutes: Minutes before/after entry to include.

        Returns:
            List of successfully built training samples, in input order.
        """
        training_data = []

        logger.info(f"Fetching training batch for {len(annotations)} annotations")

        for annotation in annotations:
            try:
                training_sample = self.fetch_training_data_for_annotation(
                    annotation, context_window_minutes
                )

                if training_sample:
                    training_data.append(training_sample)
                else:
                    logger.warning(f"Failed to fetch training data for annotation {annotation.get('annotation_id')}")

            except Exception as e:
                logger.error(f"Error processing annotation {annotation.get('annotation_id')}: {e}")

        logger.info(f"Successfully fetched training data for {len(training_data)}/{len(annotations)} annotations")
        return training_data

    def get_training_statistics(self, training_data: List[Dict]) -> Dict[str, Any]:
        """
        Summarise a list of training samples.

        Args:
            training_data: Samples as produced by fetch_training_data_for_annotation.

        Returns:
            ``{}`` for empty input, otherwise a dict with ``total_samples``,
            per-symbol counts, LONG/SHORT counts (other directions are not
            counted), the mean ``profit_loss_pct`` (missing/null counts as 0)
            and a sorted list of timeframes seen as ``ohlcv_<tf>`` keys.
        """
        if not training_data:
            return {}

        symbols: Dict[str, int] = {}
        directions = {'LONG': 0, 'SHORT': 0}
        timeframes = set()
        total_pnl = 0.0
        prefix = 'ohlcv_'

        for sample in training_data:
            outcome = sample.get('expected_outcome') or {}
            symbol = sample.get('symbol', 'UNKNOWN')
            direction = outcome.get('direction', 'UNKNOWN')

            symbols[symbol] = symbols.get(symbol, 0) + 1
            if direction in directions:
                directions[direction] += 1

            # A null P&L must not break the whole summary.
            total_pnl += outcome.get('profit_loss_pct') or 0.0

            market_state = sample.get('market_state') or {}
            timeframes.update(key[len(prefix):] for key in market_state if key.startswith(prefix))

        return {
            'total_samples': len(training_data),
            'symbols': symbols,
            'directions': directions,
            'avg_profit_loss': total_pnl / len(training_data),
            # Sorted so the output is deterministic (set order is not).
            'timeframes_available': sorted(timeframes)
        }