wip wip wip
This commit is contained in:
299
ANNOTATE/core/training_data_fetcher.py
Normal file
299
ANNOTATE/core/training_data_fetcher.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Training Data Fetcher - Dynamic OHLCV data retrieval for model training
|
||||
|
||||
Fetches ±5 minutes of OHLCV data around annotated events from cache/database
|
||||
instead of storing it in JSON files. This allows efficient training on optimal timing.
|
||||
"""
|
||||
|
||||
import logging
from collections import Counter
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Any, Tuple

import pandas as pd
import numpy as np
import pytz
|
||||
|
||||
# Module-level logger; handlers/levels are configured by the host application.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainingDataFetcher:
    """
    Fetches training data dynamically from cache/database for annotated events.

    Key Features:
    - Fetches ±5 minutes of OHLCV data around entry/exit points
    - Generates training labels for optimal timing detection
    - Supports multiple timeframes (1s, 1m, 1h, 1d)
    - Efficient memory usage (no JSON storage)
    - Real-time data from cache/database
    """

    # Per-candle timing label values (see _generate_timing_labels).
    LABEL_NO_ACTION = 0  # avoid trading
    LABEL_ENTRY = 1      # optimal entry timing (within 1 minute of entry)
    LABEL_HOLD = 2       # between entry and exit
    LABEL_EXIT = 3       # optimal exit timing (within 1 minute of exit)

    # Timeframes labeled when present in market_state as 'ohlcv_<tf>'.
    TIMEFRAMES = ('1s', '1m', '1h', '1d')

    def __init__(self, data_provider):
        """
        Initialize training data fetcher.

        Args:
            data_provider: DataProvider instance; must expose
                get_market_state_at_time(symbol, timestamp, context_window_minutes).
        """
        self.data_provider = data_provider
        logger.info("TrainingDataFetcher initialized")

    def fetch_training_data_for_annotation(self, annotation: Dict,
                                           context_window_minutes: int = 5) -> Dict[str, Any]:
        """
        Fetch complete training data for an annotation.

        Args:
            annotation: Annotation metadata (from annotations_db.json)
            context_window_minutes: Minutes before/after entry to include

        Returns:
            Dict with market_state, training_labels and expected_outcome,
            or an empty dict if fetching fails.
        """
        try:
            # Parse ISO-8601 timestamps; normalize a trailing 'Z' so
            # datetime.fromisoformat accepts the string.
            entry_time = datetime.fromisoformat(annotation['entry']['timestamp'].replace('Z', '+00:00'))
            exit_time = datetime.fromisoformat(annotation['exit']['timestamp'].replace('Z', '+00:00'))

            symbol = annotation['symbol']
            direction = annotation['direction']

            logger.info("Fetching training data for %s at %s (±%dmin)",
                        symbol, entry_time, context_window_minutes)

            # OHLCV context for all timeframes around the entry time.
            market_state = self._fetch_market_state_at_time(
                symbol=symbol,
                timestamp=entry_time,
                context_window_minutes=context_window_minutes
            )

            # Per-candle labels for optimal timing detection.
            training_labels = self._generate_timing_labels(
                market_state=market_state,
                entry_time=entry_time,
                exit_time=exit_time,
                direction=direction
            )

            expected_outcome = {
                "direction": direction,
                "profit_loss_pct": annotation['profit_loss_pct'],
                "entry_price": annotation['entry']['price'],
                "exit_price": annotation['exit']['price'],
                "holding_period_seconds": (exit_time - entry_time).total_seconds()
            }

            return {
                "test_case_id": f"annotation_{annotation['annotation_id']}",
                "symbol": symbol,
                "timestamp": annotation['entry']['timestamp'],
                "action": "BUY" if direction == "LONG" else "SELL",
                "market_state": market_state,
                "training_labels": training_labels,
                "expected_outcome": expected_outcome,
                "annotation_metadata": {
                    "annotator": "manual",
                    "confidence": 1.0,
                    "notes": annotation.get('notes', ''),
                    "created_at": annotation.get('created_at'),
                    "timeframe": annotation.get('timeframe', '1m')
                }
            }

        except Exception:
            # Broad catch is deliberate: one malformed annotation must not
            # abort a batch. logger.exception records the traceback through
            # the logging system (was traceback.print_exc() to stderr).
            logger.exception("Error fetching training data for annotation")
            return {}

    def _fetch_market_state_at_time(self, symbol: str, timestamp: datetime,
                                    context_window_minutes: int) -> Dict[str, Any]:
        """
        Fetch market state at a specific time from cache/database.

        Args:
            symbol: Trading symbol
            timestamp: Target timestamp
            context_window_minutes: Minutes before/after to include

        Returns:
            Dict with OHLCV data per timeframe; empty dict on failure.
        """
        try:
            # Delegate the time-travel query to the data provider.
            market_state = self.data_provider.get_market_state_at_time(
                symbol=symbol,
                timestamp=timestamp,
                context_window_minutes=context_window_minutes
            )

            # NOTE(review): len() counts top-level keys; assumes the provider
            # returns one 'ohlcv_<tf>' entry per timeframe — confirm.
            logger.info("Fetched market state with %d timeframes", len(market_state))
            return market_state

        except Exception:
            logger.exception("Error fetching market state")
            return {}

    def _generate_timing_labels(self, market_state: Dict, entry_time: datetime,
                                exit_time: datetime, direction: str) -> Dict[str, Any]:
        """
        Generate training labels for optimal timing detection.

        Labels help the model learn WHEN to enter (LABEL_ENTRY), WHEN to hold
        (LABEL_HOLD), WHEN to exit (LABEL_EXIT), and WHEN NOT to trade
        (LABEL_NO_ACTION).

        Args:
            market_state: OHLCV data keyed 'ohlcv_<tf>', each entry carrying a
                'timestamps' list of '%Y-%m-%d %H:%M:%S' strings
            entry_time: Entry timestamp
            exit_time: Exit timestamp
            direction: Trade direction (LONG/SHORT)

        Returns:
            Dict with a label list plus entry/exit indices per timeframe.
        """
        # Normalize to UTC once, outside the per-candle loop (the original
        # re-checked tzinfo on every iteration).
        if entry_time.tzinfo is None:
            entry_time = entry_time.replace(tzinfo=timezone.utc)
        if exit_time.tzinfo is None:
            exit_time = exit_time.replace(tzinfo=timezone.utc)

        labels: Dict[str, Any] = {
            'direction': direction,
            'entry_timestamp': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
            'exit_timestamp': exit_time.strftime('%Y-%m-%d %H:%M:%S')
        }

        for tf in self.TIMEFRAMES:
            tf_key = f'ohlcv_{tf}'
            if tf_key not in market_state or 'timestamps' not in market_state[tf_key]:
                continue

            timestamps = market_state[tf_key]['timestamps']
            label_list: List[int] = []
            entry_idx = -1
            exit_idx = -1

            for i, ts_str in enumerate(timestamps):
                try:
                    ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
                    # Candle timestamps are naive; treat them as UTC.
                    if ts.tzinfo is None:
                        ts = ts.replace(tzinfo=timezone.utc)

                    # Entry/exit proximity (±60 s) takes priority over the
                    # hold band between the two.
                    if abs((ts - entry_time).total_seconds()) < 60:
                        label = self.LABEL_ENTRY
                        entry_idx = i
                    elif abs((ts - exit_time).total_seconds()) < 60:
                        label = self.LABEL_EXIT
                        exit_idx = i
                    elif entry_time < ts < exit_time:
                        label = self.LABEL_HOLD
                    else:
                        label = self.LABEL_NO_ACTION

                    label_list.append(label)

                except Exception as e:
                    # Unparseable candle timestamps default to NO_ACTION.
                    logger.error("Error parsing timestamp %s: %s", ts_str, e)
                    label_list.append(self.LABEL_NO_ACTION)

            labels[f'labels_{tf}'] = label_list
            labels[f'entry_index_{tf}'] = entry_idx
            labels[f'exit_index_{tf}'] = exit_idx

            # Log label distribution for this timeframe.
            counts = Counter(label_list)
            logger.info("Generated %s labels: %d NO_ACTION, %d ENTRY, %d HOLD, %d EXIT",
                        tf, counts[0], counts[1], counts[2], counts[3])

        return labels

    def fetch_training_batch(self, annotations: List[Dict],
                             context_window_minutes: int = 5) -> List[Dict[str, Any]]:
        """
        Fetch training data for multiple annotations.

        Args:
            annotations: List of annotation metadata
            context_window_minutes: Minutes before/after entry to include

        Returns:
            List of training data dictionaries. Failed annotations are
            skipped, so the result may be shorter than the input.
        """
        training_data: List[Dict[str, Any]] = []

        logger.info("Fetching training batch for %d annotations", len(annotations))

        for annotation in annotations:
            try:
                sample = self.fetch_training_data_for_annotation(
                    annotation, context_window_minutes
                )
                if sample:
                    training_data.append(sample)
                else:
                    logger.warning("Failed to fetch training data for annotation %s",
                                   annotation.get('annotation_id'))
            except Exception as e:
                # Defensive: fetch_training_data_for_annotation catches its own
                # errors, but keep the batch alive if anything else raises.
                logger.error("Error processing annotation %s: %s",
                             annotation.get('annotation_id'), e)

        logger.info("Successfully fetched training data for %d/%d annotations",
                    len(training_data), len(annotations))
        return training_data

    def get_training_statistics(self, training_data: List[Dict]) -> Dict[str, Any]:
        """
        Summarize a batch of training samples.

        Args:
            training_data: List of training data samples

        Returns:
            Dict with per-symbol/per-direction counts, average P&L and the
            timeframes seen; empty dict for empty input.
        """
        if not training_data:
            return {}

        stats: Dict[str, Any] = {
            'total_samples': len(training_data),
            'symbols': {},
            'directions': {'LONG': 0, 'SHORT': 0},
            'avg_profit_loss': 0.0,
            'timeframes_available': set()
        }

        total_pnl = 0.0

        for sample in training_data:
            outcome = sample.get('expected_outcome', {})
            symbol = sample.get('symbol', 'UNKNOWN')
            direction = outcome.get('direction', 'UNKNOWN')

            stats['symbols'][symbol] = stats['symbols'].get(symbol, 0) + 1

            # Only LONG/SHORT are counted; unknown directions are ignored.
            if direction in stats['directions']:
                stats['directions'][direction] += 1

            total_pnl += outcome.get('profit_loss_pct', 0.0)

            # Record which ohlcv_* timeframes this sample carries.
            for key in sample.get('market_state', {}):
                if key.startswith('ohlcv_'):
                    stats['timeframes_available'].add(key[len('ohlcv_'):])

        stats['avg_profit_loss'] = total_pnl / len(training_data)
        # Convert to list so the stats dict is JSON-serializable.
        stats['timeframes_available'] = list(stats['timeframes_available'])

        return stats
|
||||
Reference in New Issue
Block a user