stored
This commit is contained in:
@@ -7,10 +7,11 @@ Handles storage, retrieval, and test case generation from manual trade annotatio
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
import logging
|
||||
import pytz
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -205,49 +206,20 @@ class AnnotationManager:
|
||||
}
|
||||
}
|
||||
|
||||
# Populate market state with ±5 minutes of data for negative examples
|
||||
# Populate market state with ±5 minutes of data for training
|
||||
if data_provider:
|
||||
try:
|
||||
entry_time = datetime.fromisoformat(annotation.entry['timestamp'].replace('Z', '+00:00'))
|
||||
exit_time = datetime.fromisoformat(annotation.exit['timestamp'].replace('Z', '+00:00'))
|
||||
|
||||
# Calculate time window: ±5 minutes around entry
|
||||
time_window_before = timedelta(minutes=5)
|
||||
time_window_after = timedelta(minutes=5)
|
||||
logger.info(f"Fetching market state for {annotation.symbol} at {entry_time} (±5min around entry)")
|
||||
|
||||
start_time = entry_time - time_window_before
|
||||
end_time = entry_time + time_window_after
|
||||
|
||||
logger.info(f"Fetching market data from {start_time} to {end_time} (±5min around entry)")
|
||||
|
||||
# Fetch OHLCV data for all timeframes
|
||||
timeframes = ['1s', '1m', '1h', '1d']
|
||||
market_state = {}
|
||||
|
||||
for tf in timeframes:
|
||||
# Get data for the time window
|
||||
df = data_provider.get_historical_data(
|
||||
symbol=annotation.symbol,
|
||||
timeframe=tf,
|
||||
limit=1000 # Get enough data to cover ±5 minutes
|
||||
)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
# Filter to time window
|
||||
df_window = df[(df.index >= start_time) & (df.index <= end_time)]
|
||||
|
||||
if not df_window.empty:
|
||||
# Convert to list format
|
||||
market_state[f'ohlcv_{tf}'] = {
|
||||
'timestamps': df_window.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': df_window['open'].tolist(),
|
||||
'high': df_window['high'].tolist(),
|
||||
'low': df_window['low'].tolist(),
|
||||
'close': df_window['close'].tolist(),
|
||||
'volume': df_window['volume'].tolist()
|
||||
}
|
||||
|
||||
logger.info(f" {tf}: {len(df_window)} candles in ±5min window")
|
||||
# Use the new data provider method to get market state at the entry time
|
||||
market_state = data_provider.get_market_state_at_time(
|
||||
symbol=annotation.symbol,
|
||||
timestamp=entry_time,
|
||||
context_window_minutes=5
|
||||
)
|
||||
|
||||
# Add training labels for each timestamp
|
||||
# This helps model learn WHERE to signal and WHERE NOT to signal
|
||||
@@ -330,6 +302,9 @@ class AnnotationManager:
|
||||
for ts_str in timestamps:
|
||||
try:
|
||||
ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
|
||||
# Make timezone-aware to match entry_time
|
||||
if ts.tzinfo is None:
|
||||
ts = pytz.UTC.localize(ts)
|
||||
|
||||
# Determine label based on position relative to entry/exit
|
||||
if abs((ts - entry_time).total_seconds()) < 60: # Within 1 minute of entry
|
||||
|
||||
Reference in New Issue
Block a user