load market data for training/inference
This commit is contained in:
@@ -28,6 +28,73 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_timestamp_to_utc(timestamp_str: str) -> datetime:
    """
    Unified timestamp parser that handles all formats and ensures UTC timezone.

    Handles:
    - ISO format with timezone: '2025-10-27T14:00:00+00:00'
    - ISO format with Z: '2025-10-27T14:00:00Z'
    - ISO format without timezone: '2025-10-27T14:00:00'
    - Space-separated with seconds: '2025-10-27 14:00:00'
    - Space-separated without seconds: '2025-10-27 14:00'

    Args:
        timestamp_str: Timestamp string in various formats

    Returns:
        Timezone-aware datetime object in UTC. Naive inputs are assumed to
        already be in UTC; aware inputs with a non-UTC offset are converted.

    Raises:
        ValueError: If timestamp cannot be parsed
    """
    if not timestamp_str:
        raise ValueError("Empty timestamp string")

    def _ensure_utc(dt: datetime) -> datetime:
        # Every code path must return an aware UTC datetime. The previous
        # implementation returned fromisoformat()'s result directly, so ISO
        # strings without an offset (e.g. '2025-10-27T14:00:00') leaked out
        # naive — this helper closes that hole.
        if dt.tzinfo is None:
            return dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc)

    # Normalize once so a single fromisoformat() call covers all variants:
    # - 'Z' suffix (Zulu time = UTC) is not accepted by fromisoformat()
    #   before Python 3.11, so rewrite it as an explicit offset.
    # - Space separator is rewritten as 'T' for fromisoformat compatibility.
    normalized = timestamp_str
    if normalized.endswith('Z'):
        normalized = normalized[:-1] + '+00:00'
    normalized = normalized.replace(' ', 'T')

    try:
        return _ensure_utc(datetime.fromisoformat(normalized))
    except ValueError:
        pass

    # Explicit format parsing as a fallback for strings fromisoformat()
    # rejects (e.g. variants unsupported on older Python versions).
    formats = [
        '%Y-%m-%d %H:%M:%S',   # With seconds
        '%Y-%m-%d %H:%M',      # Without seconds
        '%Y-%m-%dT%H:%M:%S',   # ISO without timezone
        '%Y-%m-%dT%H:%M',      # ISO without seconds or timezone
    ]

    for fmt in formats:
        try:
            return _ensure_utc(datetime.strptime(timestamp_str, fmt))
        except ValueError:
            continue

    # If all parsing attempts fail
    raise ValueError(f"Could not parse timestamp: '{timestamp_str}'")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingSession:
|
||||
"""Real training session tracking"""
|
||||
@@ -214,7 +281,10 @@ class RealTrainingAdapter:
|
||||
|
||||
def _fetch_market_state_for_test_case(self, test_case: Dict) -> Dict:
|
||||
"""
|
||||
Fetch market state dynamically for a test case
|
||||
Fetch market state dynamically for a test case from DuckDB storage
|
||||
|
||||
This fetches HISTORICAL data at the specific timestamp from the annotation,
|
||||
not current/latest data.
|
||||
|
||||
Args:
|
||||
test_case: Test case dictionary with timestamp, symbol, etc.
|
||||
@@ -234,17 +304,32 @@ class RealTrainingAdapter:
|
||||
logger.warning("No timestamp in test case")
|
||||
return {}
|
||||
|
||||
# Parse timestamp
|
||||
from datetime import datetime
|
||||
timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
|
||||
# Parse timestamp using unified parser
|
||||
try:
|
||||
timestamp = parse_timestamp_to_utc(timestamp_str)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse timestamp '{timestamp_str}': {e}")
|
||||
return {}
|
||||
|
||||
# Get training config
|
||||
training_config = test_case.get('training_config', {})
|
||||
timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d'])
|
||||
context_window = training_config.get('context_window_minutes', 5)
|
||||
negative_samples_window = training_config.get('negative_samples_window', 15) # ±15 candles
|
||||
|
||||
logger.info(f" Fetching market state for {symbol} at {timestamp}")
|
||||
logger.info(f" Timeframes: {timeframes}, Context window: {context_window} minutes")
|
||||
# Calculate extended time range to include negative sampling window
|
||||
# For 1m timeframe: ±15 candles = ±15 minutes
|
||||
# Add buffer to ensure we have enough data
|
||||
extended_window_minutes = max(context_window, negative_samples_window + 10)
|
||||
|
||||
logger.info(f" Fetching HISTORICAL market state for {symbol} at {timestamp}")
|
||||
logger.info(f" Timeframes: {timeframes}, Extended window: ±{extended_window_minutes} minutes")
|
||||
logger.info(f" (Includes ±{negative_samples_window} candles for negative sampling)")
|
||||
|
||||
# Calculate time range for extended context window
|
||||
from datetime import timedelta
|
||||
start_time = timestamp - timedelta(minutes=extended_window_minutes)
|
||||
end_time = timestamp + timedelta(minutes=extended_window_minutes)
|
||||
|
||||
# Fetch data for each timeframe
|
||||
market_state = {
|
||||
@@ -253,14 +338,67 @@ class RealTrainingAdapter:
|
||||
'timeframes': {}
|
||||
}
|
||||
|
||||
# Try to get data from DuckDB storage first (historical data)
|
||||
duckdb_storage = None
|
||||
if hasattr(self.data_provider, 'duckdb_storage'):
|
||||
duckdb_storage = self.data_provider.duckdb_storage
|
||||
|
||||
for timeframe in timeframes:
|
||||
# Get historical data around the timestamp
|
||||
# For now, just get the latest data (we can improve this later)
|
||||
df = self.data_provider.get_historical_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
limit=100 # Get 100 candles for context
|
||||
)
|
||||
df = None
|
||||
|
||||
# Calculate appropriate limit based on timeframe and window
|
||||
# We want enough candles to cover the extended window plus negative samples
|
||||
if timeframe == '1s':
|
||||
limit = extended_window_minutes * 60 * 2 + 100 # 2x for safety + buffer
|
||||
elif timeframe == '1m':
|
||||
limit = extended_window_minutes * 2 + 50 # 2x for safety + buffer
|
||||
elif timeframe == '1h':
|
||||
limit = max(200, extended_window_minutes // 30) # At least 200 candles
|
||||
elif timeframe == '1d':
|
||||
limit = 200 # Fixed for daily
|
||||
else:
|
||||
limit = 300
|
||||
|
||||
# Try DuckDB storage first (has historical data)
|
||||
if duckdb_storage:
|
||||
try:
|
||||
df = duckdb_storage.get_ohlcv_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
limit=limit,
|
||||
direction='latest'
|
||||
)
|
||||
if df is not None and not df.empty:
|
||||
logger.debug(f" {timeframe}: {len(df)} candles from DuckDB (historical)")
|
||||
except Exception as e:
|
||||
logger.debug(f" {timeframe}: DuckDB query failed: {e}")
|
||||
|
||||
# Fallback to data_provider (might have cached data)
|
||||
if df is None or df.empty:
|
||||
try:
|
||||
# Use get_historical_data_replay for time-specific data
|
||||
replay_data = self.data_provider.get_historical_data_replay(
|
||||
symbol=symbol,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
timeframes=[timeframe]
|
||||
)
|
||||
df = replay_data.get(timeframe)
|
||||
if df is not None and not df.empty:
|
||||
logger.debug(f" {timeframe}: {len(df)} candles from replay")
|
||||
except Exception as e:
|
||||
logger.debug(f" {timeframe}: Replay failed: {e}")
|
||||
|
||||
# Last resort: get latest data (not ideal but better than nothing)
|
||||
if df is None or df.empty:
|
||||
logger.warning(f" {timeframe}: No historical data found, using latest data as fallback")
|
||||
df = self.data_provider.get_historical_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
limit=limit # Use calculated limit
|
||||
)
|
||||
|
||||
if df is not None and not df.empty:
|
||||
# Convert to dict format
|
||||
@@ -272,15 +410,15 @@ class RealTrainingAdapter:
|
||||
'close': df['close'].tolist(),
|
||||
'volume': df['volume'].tolist()
|
||||
}
|
||||
logger.debug(f" {timeframe}: {len(df)} candles")
|
||||
logger.debug(f" {timeframe}: {len(df)} candles stored")
|
||||
else:
|
||||
logger.warning(f" {timeframe}: No data")
|
||||
logger.warning(f" {timeframe}: No data available")
|
||||
|
||||
if market_state['timeframes']:
|
||||
logger.info(f" Fetched market state with {len(market_state['timeframes'])} timeframes")
|
||||
return market_state
|
||||
else:
|
||||
logger.warning(f" No market data fetched")
|
||||
logger.warning(f" No market data fetched for any timeframe")
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
@@ -441,23 +579,9 @@ class RealTrainingAdapter:
|
||||
logger.debug(" No holding period, skipping HOLD samples")
|
||||
return hold_samples
|
||||
|
||||
# Parse entry timestamp - handle multiple formats
|
||||
# Parse entry timestamp using unified parser
|
||||
try:
|
||||
if 'T' in entry_timestamp:
|
||||
entry_time = datetime.fromisoformat(entry_timestamp.replace('Z', '+00:00'))
|
||||
else:
|
||||
# Try with seconds first, then without
|
||||
try:
|
||||
entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M:%S')
|
||||
except ValueError:
|
||||
# Try without seconds
|
||||
entry_time = datetime.strptime(entry_timestamp, '%Y-%m-%d %H:%M')
|
||||
|
||||
# Make timezone-aware
|
||||
if pytz:
|
||||
entry_time = entry_time.replace(tzinfo=pytz.UTC)
|
||||
else:
|
||||
entry_time = entry_time.replace(tzinfo=timezone.utc)
|
||||
entry_time = parse_timestamp_to_utc(entry_timestamp)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse entry timestamp '{entry_timestamp}': {e}")
|
||||
return hold_samples
|
||||
@@ -473,18 +597,9 @@ class RealTrainingAdapter:
|
||||
|
||||
# Find all candles between entry and exit
|
||||
for idx, ts_str in enumerate(timestamps):
|
||||
# Parse timestamp and ensure it's timezone-aware
|
||||
# Parse timestamp using unified parser
|
||||
try:
|
||||
if 'T' in ts_str:
|
||||
ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00'))
|
||||
else:
|
||||
ts = datetime.fromisoformat(ts_str.replace(' ', 'T'))
|
||||
# Make timezone-aware if it's naive
|
||||
if ts.tzinfo is None:
|
||||
if pytz:
|
||||
ts = ts.replace(tzinfo=pytz.UTC)
|
||||
else:
|
||||
ts = ts.replace(tzinfo=timezone.utc)
|
||||
ts = parse_timestamp_to_utc(ts_str)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not parse timestamp '{ts_str}': {e}")
|
||||
continue
|
||||
@@ -550,23 +665,9 @@ class RealTrainingAdapter:
|
||||
# Find the index of the signal timestamp
|
||||
from datetime import datetime
|
||||
|
||||
# Parse signal timestamp - handle different formats
|
||||
# Parse signal timestamp using unified parser
|
||||
try:
|
||||
if 'T' in signal_timestamp:
|
||||
signal_time = datetime.fromisoformat(signal_timestamp.replace('Z', '+00:00'))
|
||||
else:
|
||||
# Try with seconds first, then without
|
||||
try:
|
||||
signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M:%S')
|
||||
except ValueError:
|
||||
# Try without seconds
|
||||
signal_time = datetime.strptime(signal_timestamp, '%Y-%m-%d %H:%M')
|
||||
|
||||
# Make timezone-aware
|
||||
if pytz:
|
||||
signal_time = signal_time.replace(tzinfo=pytz.UTC)
|
||||
else:
|
||||
signal_time = signal_time.replace(tzinfo=timezone.utc)
|
||||
signal_time = parse_timestamp_to_utc(signal_timestamp)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse signal timestamp '{signal_timestamp}': {e}")
|
||||
return negative_samples
|
||||
@@ -574,22 +675,8 @@ class RealTrainingAdapter:
|
||||
signal_index = None
|
||||
for idx, ts_str in enumerate(timestamps):
|
||||
try:
|
||||
# Parse timestamp from market data - handle multiple formats
|
||||
if 'T' in ts_str:
|
||||
ts = datetime.fromisoformat(ts_str.replace('Z', '+00:00'))
|
||||
else:
|
||||
# Try with seconds first, then without
|
||||
try:
|
||||
ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M:%S')
|
||||
except ValueError:
|
||||
ts = datetime.strptime(ts_str, '%Y-%m-%d %H:%M')
|
||||
|
||||
# Make timezone-aware if naive
|
||||
if ts.tzinfo is None:
|
||||
if pytz:
|
||||
ts = ts.replace(tzinfo=pytz.UTC)
|
||||
else:
|
||||
ts = ts.replace(tzinfo=timezone.utc)
|
||||
# Parse timestamp using unified parser
|
||||
ts = parse_timestamp_to_utc(ts_str)
|
||||
|
||||
# Match within 1 minute
|
||||
if abs((ts - signal_time).total_seconds()) < 60:
|
||||
@@ -1147,8 +1234,9 @@ class RealTrainingAdapter:
|
||||
future_prices = torch.tensor([future_price], dtype=torch.float32)
|
||||
|
||||
# Trade success (1.0 if profitable, 0.0 otherwise)
|
||||
# Shape must be [batch_size, 1] to match confidence head output
|
||||
profit_loss_pct = training_sample.get('profit_loss_pct', 0.0)
|
||||
trade_success = torch.tensor([1.0 if profit_loss_pct > 0 else 0.0], dtype=torch.float32)
|
||||
trade_success = torch.tensor([[1.0 if profit_loss_pct > 0 else 0.0]], dtype=torch.float32)
|
||||
|
||||
# Return batch dictionary
|
||||
batch = {
|
||||
|
||||
Reference in New Issue
Block a user