Files
gogo2/ANNOTATE/core/training_data_fetcher.py
Dobromir Popov 0225f4df58 wip wip wip
2025-10-23 18:57:07 +03:00

300 lines
12 KiB
Python

"""
Training Data Fetcher - Dynamic OHLCV data retrieval for model training
Fetches ±5 minutes of OHLCV data around annotated events from cache/database
instead of storing it in JSON files. This allows efficient training on optimal timing.
"""
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
import pandas as pd
import numpy as np
import pytz
# Module-level logger used by TrainingDataFetcher for progress and error reporting.
logger = logging.getLogger(__name__)
class TrainingDataFetcher:
    """
    Fetches training data dynamically from cache/database for annotated events.

    Key Features:
    - Fetches ±N minutes (default 5) of OHLCV data around the annotated entry point
    - Generates training labels for optimal timing detection
    - Supports multiple timeframes (1s, 1m, 1h, 1d)
    - Efficient memory usage (no JSON storage)
    - Real-time data from cache/database

    Timing label values:
        0 = NO ACTION, 1 = ENTRY, 2 = HOLD, 3 = EXIT
    """

    # Timeframes for which timing labels are generated; market_state is expected
    # to hold them under the keys 'ohlcv_<tf>'.
    TIMEFRAMES = ('1s', '1m', '1h', '1d')

    # A candle whose timestamp is strictly closer than this to the entry (or exit)
    # time is labelled as the entry (or exit) candle.
    TIMING_TOLERANCE_SECONDS = 60

    # Primary format of candle timestamp strings in market_state and of the
    # entry/exit timestamps written into the labels dict.
    TIMESTAMP_FORMAT = '%Y-%m-%d %H:%M:%S'

    # Timing label values.
    LABEL_NO_ACTION = 0
    LABEL_ENTRY = 1
    LABEL_HOLD = 2
    LABEL_EXIT = 3

    def __init__(self, data_provider):
        """
        Initialize training data fetcher.

        Args:
            data_provider: DataProvider instance for fetching OHLCV data; it must
                expose ``get_market_state_at_time(symbol=..., timestamp=...,
                context_window_minutes=...)``.
        """
        self.data_provider = data_provider
        logger.info("TrainingDataFetcher initialized")

    def fetch_training_data_for_annotation(self, annotation: Dict,
                                           context_window_minutes: int = 5) -> Dict[str, Any]:
        """
        Fetch complete training data for an annotation.

        Args:
            annotation: Annotation metadata (from annotations_db.json). Keys read:
                'annotation_id', 'symbol', 'direction', 'profit_loss_pct',
                'entry' and 'exit' (each with 'timestamp' and 'price'), and
                optionally 'notes', 'created_at', 'timeframe'.
            context_window_minutes: Minutes before/after the entry time to include.

        Returns:
            Dict with test_case_id, symbol, timestamp, action, market_state,
            training_labels, expected_outcome and annotation_metadata; an empty
            dict if anything fails (the error is logged with its traceback).
        """
        try:
            # Annotation timestamps are ISO-8601; a trailing 'Z' is rewritten
            # because datetime.fromisoformat only accepts it from Python 3.11.
            entry_time = datetime.fromisoformat(annotation['entry']['timestamp'].replace('Z', '+00:00'))
            exit_time = datetime.fromisoformat(annotation['exit']['timestamp'].replace('Z', '+00:00'))
            symbol = annotation['symbol']
            direction = annotation['direction']

            logger.info(f"Fetching training data for {symbol} at {entry_time} "
                        f"(±{context_window_minutes}min)")

            # Fetch OHLCV data for all timeframes around the entry time only.
            market_state = self._fetch_market_state_at_time(
                symbol=symbol,
                timestamp=entry_time,
                context_window_minutes=context_window_minutes
            )

            # Generate training labels for optimal timing detection.
            training_labels = self._generate_timing_labels(
                market_state=market_state,
                entry_time=entry_time,
                exit_time=exit_time,
                direction=direction
            )

            # Normalise both ends so a naive/aware mix cannot break the subtraction.
            holding_period = (self._assume_utc_if_naive(exit_time)
                              - self._assume_utc_if_naive(entry_time))

            expected_outcome = {
                "direction": direction,
                "profit_loss_pct": annotation['profit_loss_pct'],
                "entry_price": annotation['entry']['price'],
                "exit_price": annotation['exit']['price'],
                "holding_period_seconds": holding_period.total_seconds()
            }

            return {
                "test_case_id": f"annotation_{annotation['annotation_id']}",
                "symbol": symbol,
                "timestamp": annotation['entry']['timestamp'],
                "action": "BUY" if direction == "LONG" else "SELL",
                "market_state": market_state,
                "training_labels": training_labels,
                "expected_outcome": expected_outcome,
                "annotation_metadata": {
                    "annotator": "manual",
                    "confidence": 1.0,
                    "notes": annotation.get('notes', ''),
                    "created_at": annotation.get('created_at'),
                    "timeframe": annotation.get('timeframe', '1m')
                }
            }

        except Exception as e:
            # Best effort: one bad annotation must not abort a whole batch.
            logger.error(f"Error fetching training data for annotation: {e}", exc_info=True)
            return {}

    def _fetch_market_state_at_time(self, symbol: str, timestamp: datetime,
                                    context_window_minutes: int) -> Dict[str, Any]:
        """
        Fetch market state at a specific time from cache/database.

        Args:
            symbol: Trading symbol
            timestamp: Target timestamp
            context_window_minutes: Minutes before/after to include

        Returns:
            Whatever the data provider returns (expected: a dict of OHLCV data
            keyed by 'ohlcv_<tf>'), or an empty dict if the provider raised or
            returned None.
        """
        try:
            market_state = self.data_provider.get_market_state_at_time(
                symbol=symbol,
                timestamp=timestamp,
                context_window_minutes=context_window_minutes
            )
        except Exception as e:
            logger.error(f"Error fetching market state: {e}")
            return {}

        if market_state is None:
            logger.warning(f"Data provider returned no market state for {symbol} at {timestamp}")
            return {}

        logger.info(f"Fetched market state with {len(market_state)} timeframes")
        return market_state

    def _generate_timing_labels(self, market_state: Dict, entry_time: datetime,
                                exit_time: datetime, direction: str) -> Dict[str, Any]:
        """
        Generate training labels for optimal timing detection.

        Labels help the model learn:
        - WHEN to enter (optimal entry timing)              -> 1 (ENTRY)
        - WHEN to hold (strictly between entry and exit)    -> 2 (HOLD)
        - WHEN to exit (optimal exit timing)                -> 3 (EXIT)
        - WHEN NOT to trade (before entry or after exit)    -> 0 (NO ACTION)

        A candle is ENTRY/EXIT when its timestamp is strictly within
        TIMING_TOLERANCE_SECONDS of the entry/exit time; ENTRY wins when both
        match. Naive datetimes (entry/exit and candle timestamps) are treated
        as UTC.

        Args:
            market_state: OHLCV data for all timeframes
            entry_time: Entry timestamp
            exit_time: Exit timestamp
            direction: Trade direction (LONG/SHORT)

        Returns:
            Dict with 'direction', 'entry_timestamp', 'exit_timestamp' and, for
            each timeframe in TIMEFRAMES present in market_state with a
            'timestamps' entry: 'labels_<tf>' (one label per timestamp) and
            'entry_index_<tf>' / 'exit_index_<tf>' (last index labelled
            ENTRY / EXIT, or -1 if none was).
        """
        labels = {
            'direction': direction,
            'entry_timestamp': entry_time.strftime(self.TIMESTAMP_FORMAT),
            'exit_timestamp': exit_time.strftime(self.TIMESTAMP_FORMAT)
        }

        if market_state is None:
            market_state = {}

        # Normalise once up front rather than per candle.
        entry_time = self._assume_utc_if_naive(entry_time)
        exit_time = self._assume_utc_if_naive(exit_time)

        for tf in self.TIMEFRAMES:
            tf_key = f'ohlcv_{tf}'
            if tf_key not in market_state or 'timestamps' not in market_state[tf_key]:
                continue

            label_list, entry_idx, exit_idx = self._label_timestamps(
                market_state[tf_key]['timestamps'], entry_time, exit_time
            )

            labels[f'labels_{tf}'] = label_list
            labels[f'entry_index_{tf}'] = entry_idx
            labels[f'exit_index_{tf}'] = exit_idx

            # Log label distribution
            logger.info(f"Generated {tf} labels: {label_list.count(self.LABEL_NO_ACTION)} NO_ACTION, "
                        f"{label_list.count(self.LABEL_ENTRY)} ENTRY, "
                        f"{label_list.count(self.LABEL_HOLD)} HOLD, "
                        f"{label_list.count(self.LABEL_EXIT)} EXIT")

        return labels

    def _label_timestamps(self, timestamps, entry_time: datetime,
                          exit_time: datetime) -> Tuple[List[int], int, int]:
        """
        Classify each candle timestamp relative to the trade window.

        entry_time and exit_time must already be timezone-aware.

        Returns:
            (labels, entry_idx, exit_idx). entry_idx/exit_idx are the LAST
            indices labelled ENTRY/EXIT, or -1 if none was.
        """
        tolerance = self.TIMING_TOLERANCE_SECONDS
        label_list: List[int] = []
        entry_idx = -1
        exit_idx = -1

        for i, raw_ts in enumerate(timestamps):
            try:
                ts = self._parse_candle_timestamp(raw_ts)
            except Exception as e:
                # Keep a placeholder so labels stay aligned with the OHLCV rows.
                logger.error(f"Error parsing timestamp {raw_ts}: {e}")
                label_list.append(self.LABEL_NO_ACTION)
                continue

            if abs((ts - entry_time).total_seconds()) < tolerance:
                label = self.LABEL_ENTRY
                entry_idx = i
            elif abs((ts - exit_time).total_seconds()) < tolerance:
                label = self.LABEL_EXIT
                exit_idx = i
            elif entry_time < ts < exit_time:
                label = self.LABEL_HOLD
            else:
                label = self.LABEL_NO_ACTION
            label_list.append(label)

        return label_list, entry_idx, exit_idx

    @classmethod
    def _parse_candle_timestamp(cls, raw_ts) -> datetime:
        """
        Convert a candle timestamp to a timezone-aware datetime (naive -> UTC).

        Accepts datetime objects (including pandas Timestamps), strings in
        TIMESTAMP_FORMAT, and, as a fallback, ISO-8601 strings. Raises
        TypeError/ValueError for anything else.
        """
        if isinstance(raw_ts, datetime):
            ts = raw_ts
        else:
            try:
                ts = datetime.strptime(raw_ts, cls.TIMESTAMP_FORMAT)
            except ValueError:
                ts = datetime.fromisoformat(raw_ts)
        return cls._assume_utc_if_naive(ts)

    @staticmethod
    def _assume_utc_if_naive(dt: datetime) -> datetime:
        """Attach UTC to a naive datetime; aware datetimes are returned unchanged."""
        return pytz.UTC.localize(dt) if dt.tzinfo is None else dt

    def fetch_training_batch(self, annotations: List[Dict],
                             context_window_minutes: int = 5) -> List[Dict[str, Any]]:
        """
        Fetch training data for multiple annotations.

        Annotations for which no data could be built are skipped and logged.

        Args:
            annotations: List of annotation metadata
            context_window_minutes: Minutes before/after entry to include

        Returns:
            List of training data dictionaries (only the successful ones)
        """
        training_data = []

        logger.info(f"Fetching training batch for {len(annotations)} annotations")

        for annotation in annotations:
            try:
                training_sample = self.fetch_training_data_for_annotation(
                    annotation, context_window_minutes
                )

                if training_sample:
                    training_data.append(training_sample)
                else:
                    logger.warning(f"Failed to fetch training data for annotation {annotation.get('annotation_id')}")

            except Exception as e:
                logger.error(f"Error processing annotation {annotation.get('annotation_id')}: {e}")

        logger.info(f"Successfully fetched training data for {len(training_data)}/{len(annotations)} annotations")
        return training_data

    def get_training_statistics(self, training_data: List[Dict]) -> Dict[str, Any]:
        """
        Get statistics about training data.

        Args:
            training_data: List of training data samples

        Returns:
            Dict with total_samples, per-symbol counts, LONG/SHORT counts,
            average profit_loss_pct (missing/null values count as 0) and a
            sorted list of available timeframes; an empty dict for no samples.
        """
        if not training_data:
            return {}

        stats = {
            'total_samples': len(training_data),
            'symbols': {},
            'directions': {'LONG': 0, 'SHORT': 0},
            'avg_profit_loss': 0.0,
            'timeframes_available': []
        }
        timeframes = set()
        total_pnl = 0.0
        prefix = 'ohlcv_'

        for sample in training_data:
            outcome = sample.get('expected_outcome') or {}
            symbol = sample.get('symbol', 'UNKNOWN')
            direction = outcome.get('direction', 'UNKNOWN')
            # A missing or null P&L counts as 0 instead of raising TypeError.
            pnl = outcome.get('profit_loss_pct') or 0.0

            # Count symbols
            stats['symbols'][symbol] = stats['symbols'].get(symbol, 0) + 1

            # Count directions (unknown directions are ignored)
            if direction in stats['directions']:
                stats['directions'][direction] += 1

            # Accumulate P&L
            total_pnl += pnl

            # Collect available timeframes from 'ohlcv_<tf>' keys
            market_state = sample.get('market_state') or {}
            timeframes.update(key[len(prefix):] for key in market_state if key.startswith(prefix))

        stats['avg_profit_loss'] = total_pnl / len(training_data)
        # Sorted for deterministic output (set iteration order is arbitrary).
        stats['timeframes_available'] = sorted(timeframes)

        return stats