wip old MISC fix

Author: Dobromir Popov
Date: 2025-12-08 16:56:37 +02:00
parent 81e7e6bfe6
commit 03888b6200
5 changed files with 719 additions and 343 deletions


@@ -23,6 +23,7 @@ from pathlib import Path
import torch
import numpy as np
import pandas as pd
try:
import pytz
@@ -393,10 +394,23 @@ class RealTrainingAdapter:
# Get training config
training_config = test_case.get('training_config', {})
# REQUIRED: All 3 timeframes (1m, 1h, 1d) with 600 candles each
timeframes = training_config.get('timeframes', ['1m', '1h', '1d'])
# REQUIRED: 600 candles per timeframe for transformer model
candles_per_timeframe = training_config.get('candles_per_timeframe', 600)  # 600 candles per timeframe
# REQUIRED: 1m, 1h, 1d (all with 600 candles each)
# OPTIONAL: 1s (if available, include with 600 candles)
required_timeframes = ['1m', '1h', '1d']
optional_timeframes = ['1s'] # Include if available
# Ensure required timeframes are in the list
missing_tfs = [tf for tf in required_timeframes if tf not in timeframes]
if missing_tfs:
logger.warning(f" Missing required timeframes: {missing_tfs}, adding them...")
timeframes = list(set(timeframes + required_timeframes))
# Note: 1s is optional, don't add it if not present
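# Note: list(set(...)) does not preserve timeframe order. If a stable order matters downstream,
# a minimal ordered-merge sketch (assuming the canonical order below) would be:
#   canonical = ['1s', '1m', '1h', '1d']
#   timeframes = [tf for tf in canonical if tf in set(timeframes) | set(required_timeframes)]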
# Determine secondary symbol based on primary symbol
# ETH/SOL -> BTC, BTC -> ETH
@@ -408,22 +422,30 @@ class RealTrainingAdapter:
logger.info(f" Candles per batch: {candles_per_timeframe}")
# Calculate time range based on candles needed
# For 600 candles at 1m = 600 minutes = 10 hours
# Use timeframe-specific windows for better efficiency
from datetime import timedelta
# Calculate time window for each timeframe to get the required candles
# Add 20% buffer to account for missing candles
buffer_multiplier = 1.2
time_windows = {
'1s': timedelta(seconds=int(candles_per_timeframe * buffer_multiplier)),
'1m': timedelta(minutes=int(candles_per_timeframe * buffer_multiplier)),
'1h': timedelta(hours=int(candles_per_timeframe * buffer_multiplier)),
'1d': timedelta(days=int(candles_per_timeframe * buffer_multiplier))
}
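# Worked example with the defaults (600 candles, 1.2 buffer): 720 seconds (12 min) for 1s,
# 720 minutes (12 h) for 1m, 720 hours (30 days) for 1h, and 720 days (~2 years) for 1d.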
# For historical queries, we want data BEFORE the timestamp
# Use the largest window to ensure we have enough data for all timeframes
max_window = max(time_windows.values())
start_time = timestamp - max_window
end_time = timestamp
# REQUIRED: All required timeframes must have exactly 600 candles (no tolerance for missing data)
min_required_candles = candles_per_timeframe # Must have full 600 candles
required_timeframes = ['1m', '1h', '1d'] # All 3 timeframes are mandatory
optional_timeframes = ['1s'] # Include if available
# Fetch data for primary symbol (all timeframes) and secondary symbol (1m only)
market_state = {
'symbol': symbol,
@@ -439,53 +461,126 @@ class RealTrainingAdapter:
duckdb_storage = self.data_provider.duckdb_storage
# Fetch primary symbol data (all timeframes)
# REQUIRED: Must fetch all 3 timeframes (1m, 1h, 1d) with 600 candles each
logger.info(f" Fetching primary symbol data: {symbol}")
logger.info(f" REQUIRED timeframes: {required_timeframes} (each with {candles_per_timeframe} candles)")
fetched_timeframes = {} # Track which timeframes we successfully fetched
for timeframe in timeframes:
# Fetch required timeframes (1m, 1h, 1d) and optional 1s if present
if timeframe not in required_timeframes and timeframe not in optional_timeframes:
continue
df = None
limit = candles_per_timeframe
# Use timeframe-specific window for better efficiency
tf_window = time_windows.get(timeframe, max_window)
tf_start_time = timestamp - tf_window
tf_end_time = timestamp
# Try DuckDB storage first (has historical data)
# Use 'before' direction to get data BEFORE the timestamp
if duckdb_storage:
try:
df = duckdb_storage.get_ohlcv_data(
symbol=symbol,
timeframe=timeframe,
start_time=tf_start_time,
end_time=tf_end_time,
limit=limit,
direction='before'  # Get data BEFORE timestamp for historical training
)
if df is not None and not df.empty:
logger.debug(f" {timeframe}: {len(df)} candles from DuckDB (historical)")
logger.debug(f" {timeframe}: {len(df)} candles from DuckDB (historical, before {timestamp})")
except Exception as e:
logger.debug(f" {timeframe}: DuckDB query failed: {e}")
# If DuckDB doesn't have enough data, try API with proper time range
if df is None or df.empty or len(df) < min_required_candles:
try:
# Try to fetch from API with historical time range
logger.info(f" {timeframe}: DuckDB insufficient ({len(df) if df is not None else 0} candles), fetching from API for timestamp {timestamp}...")
# Fetch historical data from API for the specific time range
api_df = self._fetch_historical_from_api(
symbol=symbol,
timeframe=timeframe,
start_time=tf_start_time,
end_time=tf_end_time,
limit=limit
)
if api_df is not None and not api_df.empty:
# Filter to data before timestamp (historical training needs data BEFORE the event)
try:
api_df = api_df[api_df.index <= tf_end_time]
# Take the most recent candles up to limit
api_df = api_df.tail(limit)
if len(api_df) >= min_required_candles:
df = api_df
logger.info(f" {timeframe}: {len(df)} candles from API (historical range: {tf_start_time} to {tf_end_time})")
# Store in DuckDB for future use
if duckdb_storage:
try:
duckdb_storage.store_ohlcv_data(symbol, timeframe, df)
logger.debug(f" {timeframe}: Stored {len(df)} candles in DuckDB for future use")
except Exception as e:
logger.debug(f" {timeframe}: Could not store in DuckDB: {e}")
else:
logger.warning(f" {timeframe}: API returned only {len(api_df)} candles after filtering (need {min_required_candles})")
except Exception as e:
logger.debug(f" {timeframe}: Could not filter API data: {e}")
# Use as-is if filtering fails
if len(api_df) >= min_required_candles:
df = api_df
logger.info(f" {timeframe}: {len(df)} candles from API (unfiltered)")
else:
logger.warning(f" {timeframe}: API fetch returned no data")
except Exception as e:
logger.warning(f" {timeframe}: API fetch failed: {e}")
import traceback
logger.debug(traceback.format_exc())
# Fallback to replay method
if df is None or df.empty or len(df) < min_required_candles:
try:
# Use get_historical_data_replay for time-specific data
replay_data = self.data_provider.get_historical_data_replay(
symbol=symbol,
start_time=tf_start_time,
end_time=tf_end_time,
timeframes=[timeframe]
)
replay_df = replay_data.get(timeframe)
if replay_df is not None and not replay_df.empty and len(replay_df) >= min_required_candles:
df = replay_df
logger.info(f" {timeframe}: {len(df)} candles from replay")
except Exception as e:
logger.debug(f" {timeframe}: Replay failed: {e}")
# Last resort: get latest data (not ideal but better than nothing)
if df is None or df.empty:
logger.warning(f" {timeframe}: No historical data found, using latest data as fallback")
df = self.data_provider.get_historical_data(
symbol=symbol,
timeframe=timeframe,
limit=limit # Use calculated limit
)
# Validate data quality before storing
if df is not None and not df.empty:
# Check minimum candle count
if len(df) < min_required_candles:
logger.warning(f" {symbol} {timeframe}: Only {len(df)} candles (need {min_required_candles}), skipping")
continue
# Validate data quality - check for NaN values
if df.isnull().any().any():
logger.warning(f" {symbol} {timeframe}: Contains NaN values, cleaning...")
df = df.dropna()
if len(df) < min_required_candles:
logger.warning(f" {symbol} {timeframe}: After cleaning, only {len(df)} candles, skipping")
continue
# Ensure we have required columns
required_cols = ['open', 'high', 'low', 'close', 'volume']
if not all(col in df.columns for col in required_cols):
logger.warning(f" {symbol} {timeframe}: Missing required columns, skipping")
continue
# Convert to dict format
market_state['timeframes'][timeframe] = {
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
@@ -495,67 +590,124 @@ class RealTrainingAdapter:
'close': df['close'].tolist(),
'volume': df['volume'].tolist()
}
logger.info(f" {symbol} {timeframe}: {len(df)} candles")
fetched_timeframes[timeframe] = len(df)
logger.info(f" {symbol} {timeframe}: {len(df)} candles [OK]")
else:
logger.warning(f" {symbol} {timeframe}: No quality data available (need {min_required_candles} candles)")
# CRITICAL: Validate we have all required timeframes
missing_required = [tf for tf in required_timeframes if tf not in fetched_timeframes]
if missing_required:
logger.error(f" FAILED: Missing required timeframes: {missing_required}")
logger.error(f" Fetched: {list(fetched_timeframes.keys())}")
logger.error(f" Cannot proceed without all required timeframes")
return {} # Return empty dict to signal failure
# Fetch secondary symbol data (1m timeframe only)
logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)")
secondary_df = None
# Use 1m-specific window
tf_window = time_windows.get('1m', max_window)
tf_start_time = timestamp - tf_window
tf_end_time = timestamp
# Try DuckDB first with 'before' direction
if duckdb_storage:
try:
secondary_df = duckdb_storage.get_ohlcv_data(
symbol=secondary_symbol,
timeframe='1m',
start_time=tf_start_time,
end_time=tf_end_time,
limit=candles_per_timeframe,
direction='before'  # Get data BEFORE timestamp
)
if secondary_df is not None and not secondary_df.empty:
logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from DuckDB")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: DuckDB query failed: {e}")
# If DuckDB doesn't have enough, try API with historical time range
if secondary_df is None or secondary_df.empty or len(secondary_df) < min_required_candles:
try:
logger.info(f" {secondary_symbol} 1m: DuckDB insufficient ({len(secondary_df) if secondary_df is not None else 0} candles), fetching from API for timestamp {timestamp}...")
# Fetch historical data from API for the specific time range
api_df = self._fetch_historical_from_api(
symbol=secondary_symbol,
timeframe='1m',
start_time=tf_start_time,
end_time=tf_end_time,
limit=candles_per_timeframe
)
if api_df is not None and not api_df.empty:
# Filter to data before timestamp
try:
api_df = api_df[api_df.index <= tf_end_time].tail(candles_per_timeframe)
if len(api_df) >= min_required_candles:
secondary_df = api_df
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from API (historical range)")
# Store in DuckDB for future use
if duckdb_storage:
try:
duckdb_storage.store_ohlcv_data(secondary_symbol, '1m', secondary_df)
logger.debug(f" {secondary_symbol} 1m: Stored in DuckDB for future use")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: Could not store in DuckDB: {e}")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: Could not filter API data: {e}")
if len(api_df) >= min_required_candles:
secondary_df = api_df
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from API (unfiltered)")
else:
logger.warning(f" {secondary_symbol} 1m: API fetch returned no data")
except Exception as e:
logger.warning(f" {secondary_symbol} 1m: API fetch failed: {e}")
import traceback
logger.debug(traceback.format_exc())
# Fallback to replay
if secondary_df is None or secondary_df.empty or len(secondary_df) < min_required_candles:
try:
replay_data = self.data_provider.get_historical_data_replay(
symbol=secondary_symbol,
start_time=tf_start_time,
end_time=tf_end_time,
timeframes=['1m']
)
replay_df = replay_data.get('1m')
if replay_df is not None and not replay_df.empty and len(replay_df) >= min_required_candles:
secondary_df = replay_df
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from replay")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: Replay failed: {e}")
# Last resort: latest data
if secondary_df is None or secondary_df.empty:
logger.warning(f" {secondary_symbol} 1m: No historical data, using latest as fallback")
secondary_df = self.data_provider.get_historical_data(
symbol=secondary_symbol,
timeframe='1m',
limit=candles_per_timeframe
)
# Validate and store secondary symbol data
if secondary_df is not None and not secondary_df.empty:
if len(secondary_df) < min_required_candles:
logger.warning(f" {secondary_symbol} 1m: Only {len(secondary_df)} candles (need {min_required_candles}), skipping")
elif secondary_df.isnull().any().any():
logger.warning(f" {secondary_symbol} 1m: Contains NaN values, skipping")
elif not all(col in secondary_df.columns for col in ['open', 'high', 'low', 'close', 'volume']):
logger.warning(f" {secondary_symbol} 1m: Missing required columns, skipping")
else:
# Store in the correct structure: secondary_timeframes[symbol][timeframe]
if secondary_symbol not in market_state['secondary_timeframes']:
market_state['secondary_timeframes'][secondary_symbol] = {}
market_state['secondary_timeframes'][secondary_symbol]['1m'] = {
'timestamps': secondary_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': secondary_df['open'].tolist(),
'high': secondary_df['high'].tolist(),
'low': secondary_df['low'].tolist(),
'close': secondary_df['close'].tolist(),
'volume': secondary_df['volume'].tolist()
}
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles [OK]")
else:
logger.warning(f" {secondary_symbol} 1m: No data available")
logger.warning(f" {secondary_symbol} 1m: No quality data available (need {min_required_candles} candles)")
# Verify we have data
if market_state['timeframes']:
@@ -1190,6 +1342,200 @@ class RealTrainingAdapter:
state_size = agent.state_size if hasattr(agent, 'state_size') else 100
return [0.0] * state_size
def _fetch_historical_from_api(self, symbol: str, timeframe: str, start_time: datetime, end_time: datetime, limit: int) -> Optional[pd.DataFrame]:
"""
Fetch historical OHLCV data from exchange APIs for a specific time range
Args:
symbol: Trading symbol (e.g., 'ETH/USDT')
timeframe: Timeframe (e.g., '1m', '1h', '1d')
start_time: Start timestamp (UTC)
end_time: End timestamp (UTC)
limit: Maximum number of candles to fetch
Returns:
DataFrame with OHLCV data or None if fetch fails
"""
import pandas as pd
import requests
import time
from datetime import datetime, timezone
try:
# Handle 1s timeframe specially (not directly supported by most APIs)
if timeframe == '1s':
logger.debug(f"1s timeframe requested - will try to generate from ticks or skip")
# For 1s, we might need to generate from tick data or skip
# This is handled by the data provider's _generate_1s_candles_from_ticks
# For now, return None and let the caller handle it
return None
# Try Binance first (supports historical queries with startTime/endTime)
try:
binance_symbol = symbol.replace('/', '').upper()
# Convert timeframe for Binance
timeframe_map = {
'1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m',
'1h': '1h', '4h': '4h', '1d': '1d'
}
binance_timeframe = timeframe_map.get(timeframe)
if not binance_timeframe:
logger.warning(f"Binance doesn't support timeframe {timeframe}")
return None
# Binance API klines endpoint with startTime and endTime
url = "https://api.binance.com/api/v3/klines"
# Convert timestamps to milliseconds
start_ms = int(start_time.timestamp() * 1000)
end_ms = int(end_time.timestamp() * 1000)
# Binance max is 1000 per request, so paginate if needed
all_data = []
current_start = start_ms
max_per_request = 1000
max_requests = 10 # Safety limit
request_count = 0
while current_start < end_ms and request_count < max_requests:
params = {
'symbol': binance_symbol,
'interval': binance_timeframe,
'startTime': current_start,
'endTime': end_ms,
'limit': min(max_per_request, limit - len(all_data))
}
logger.debug(f"Fetching from Binance: {symbol} {timeframe} batch {request_count + 1} (start: {current_start}, end: {end_ms})")
response = requests.get(url, params=params, timeout=10)
if response.status_code == 200:
data = response.json()
if data:
all_data.extend(data)
# Update current_start to the last candle's close_time + 1ms
if len(data) > 0:
last_close_time = data[-1][6] # close_time is at index 6
current_start = last_close_time + 1
else:
break
# If we got less than requested, we've reached the end
if len(data) < max_per_request:
break
# If we have enough data, stop
if len(all_data) >= limit:
break
else:
break
else:
logger.debug(f"Binance API returned {response.status_code} for {symbol} {timeframe}")
break
request_count += 1
# Small delay to avoid rate limiting
time.sleep(0.1)
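# Note: with the default limit of 600 and Binance's 1,000-candle cap per request,
# a single request normally covers the whole range; the pagination loop and the
# max_requests=10 guard only matter for larger limits or gaps in the data.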
if all_data:
# Convert to DataFrame
df = pd.DataFrame(all_data, columns=[
'timestamp', 'open', 'high', 'low', 'close', 'volume',
'close_time', 'quote_volume', 'trades', 'taker_buy_base',
'taker_buy_quote', 'ignore'
])
# Process columns
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
for col in ['open', 'high', 'low', 'close', 'volume']:
df[col] = df[col].astype(float)
# Keep only OHLCV columns and set timestamp as index
df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
df = df.set_index('timestamp')
df = df.sort_index()
# Remove duplicates and take last 'limit' candles
df = df[~df.index.duplicated(keep='last')]
df = df.tail(limit)
logger.info(f"Binance API: Fetched {len(df)} candles for {symbol} {timeframe} (historical, {request_count} requests)")
return df
else:
logger.warning(f"Binance API returned no data for {symbol} {timeframe}")
except Exception as e:
logger.debug(f"Binance fetch failed: {e}")
# Fallback to MEXC
try:
mexc_symbol = symbol.replace('/', '').upper()
timeframe_map = {
'1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m',
'1h': '1h', '4h': '4h', '1d': '1d'
}
mexc_timeframe = timeframe_map.get(timeframe)
if not mexc_timeframe:
logger.warning(f"MEXC doesn't support timeframe {timeframe}")
return None
# MEXC API klines endpoint (may not support startTime/endTime, so fetch latest and filter)
url = "https://api.mexc.com/api/v3/klines"
params = {
'symbol': mexc_symbol,
'interval': mexc_timeframe,
'limit': min(limit * 2, 1000) # Fetch more to account for filtering
}
logger.debug(f"Fetching from MEXC: {symbol} {timeframe}")
response = requests.get(url, params=params, timeout=10)
if response.status_code == 200:
data = response.json()
if data:
df = pd.DataFrame(data, columns=[
'timestamp', 'open', 'high', 'low', 'close', 'volume',
'close_time', 'quote_volume'
])
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
for col in ['open', 'high', 'low', 'close', 'volume']:
df[col] = df[col].astype(float)
df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
df = df.set_index('timestamp')
df = df.sort_index()
# Filter to time range
df = df[(df.index >= start_time) & (df.index <= end_time)]
df = df.tail(limit)
if len(df) > 0:
logger.info(f"MEXC API: Fetched {len(df)} candles for {symbol} {timeframe} (historical, filtered)")
return df
else:
logger.warning(f"MEXC API: No candles in time range for {symbol} {timeframe}")
else:
logger.warning(f"MEXC API returned empty data for {symbol} {timeframe}")
else:
logger.debug(f"MEXC API returned {response.status_code} for {symbol} {timeframe}")
except Exception as e:
logger.debug(f"MEXC fetch failed: {e}")
return None
except Exception as e:
logger.error(f"Error fetching historical data from API: {e}")
import traceback
logger.debug(traceback.format_exc())
return None
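# Illustrative usage (a sketch; `adapter` denotes a RealTrainingAdapter instance, and the
# timestamp and symbol are example values, not taken from this commit):
#   end = datetime(2025, 12, 1, 12, 0, tzinfo=timezone.utc)
#   start = end - timedelta(minutes=int(600 * 1.2))  # 1m window with the 20% buffer
#   df = adapter._fetch_historical_from_api('ETH/USDT', '1m', start, end, limit=600)
#   # df is None on failure, otherwise an OHLCV DataFrame indexed by UTC timestamps (at most 600 rows)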
def _extract_timeframe_data(self, tf_data: Dict, target_seq_len: int = 600) -> Optional[torch.Tensor]:
"""
Extract and normalize OHLCV data from a single timeframe
@@ -1215,28 +1561,22 @@ class RealTrainingAdapter:
if len(closes) == 0:
return None
# REQUIRED: Must have exactly target_seq_len (600) candles, no padding allowed
if len(closes) < target_seq_len:
logger.warning(f"Insufficient candles: {len(closes)} < {target_seq_len} (required)")
return None
# Take last target_seq_len candles (exactly 600)
opens = opens[-target_seq_len:]
highs = highs[-target_seq_len:]
lows = lows[-target_seq_len:]
closes = closes[-target_seq_len:]
volumes = volumes[-target_seq_len:]
# Validate we have exactly target_seq_len
if len(closes) != target_seq_len:
logger.warning(f"Extraction failed: got {len(closes)} candles, need {target_seq_len}")
return None
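# Context: requiring exactly 600 candles means the 1d series alone spans roughly 600 days
# (about 20 months) of history, which is why the fetch windows earlier use a 1.2x buffer
# (roughly 2 years for 1d).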
# Stack OHLCV [seq_len, 5]
ohlcv = np.stack([opens, highs, lows, closes, volumes], axis=-1)
@@ -1369,45 +1709,62 @@ class RealTrainingAdapter:
timeframes = market_state.get('timeframes', {})
secondary_timeframes = market_state.get('secondary_timeframes', {})
# REQUIRED: 600 candles per timeframe for transformer model
target_seq_len = 600 # Must be 600 candles for each timeframe
# Validate we have enough data in required timeframes (1m, 1h, 1d)
required_tfs = ['1m', '1h', '1d']
for tf_name in required_tfs:
if tf_name in timeframes:
tf_data = timeframes[tf_name]
if tf_data and 'close' in tf_data:
if len(tf_data['close']) < 600:
logger.warning(f"Required timeframe {tf_name} has only {len(tf_data['close'])} candles (need 600)")
return None
# Validate optional 1s timeframe if present (must have 600 candles if included)
if '1s' in timeframes:
tf_data = timeframes['1s']
if tf_data and 'close' in tf_data:
if len(tf_data['close']) < 600:
logger.warning(f"Optional timeframe 1s has only {len(tf_data['close'])} candles (need 600), excluding it")
# Remove 1s from timeframes if insufficient
timeframes = {k: v for k, v in timeframes.items() if k != '1s'}
# Extract each timeframe (returns tuple: (tensor, norm_params) or None)
# Store normalization parameters for each timeframe
norm_params_dict = {}
# OPTIONAL: Extract 1s timeframe if available (must have 600 candles if included)
result_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None
if result_1s:
price_data_1s, norm_params_dict['1s'] = result_1s
logger.debug(f"Included optional 1s timeframe with {result_1s[0].shape[1]} candles")
else:
# 1s is optional - don't fail if missing, but log it
price_data_1s = None
logger.debug("Optional 1s timeframe not available (this is OK)")
# REQUIRED: Extract all 3 timeframes (1m, 1h, 1d) with exactly 600 candles each
result_1m = self._extract_timeframe_data(timeframes.get('1m', {}), target_seq_len) if '1m' in timeframes else None
if result_1m:
price_data_1m, norm_params_dict['1m'] = result_1m
else:
# Warning: 1m data is critical
logger.warning(f"Missing 1m data for transformer batch (sample: {training_sample.get('test_case_id')})")
logger.warning(f"Missing or insufficient 1m data for transformer batch (sample: {training_sample.get('test_case_id')})")
return None
result_1h = self._extract_timeframe_data(timeframes.get('1h', {}), target_seq_len) if '1h' in timeframes else None
if result_1h:
price_data_1h, norm_params_dict['1h'] = result_1h
else:
logger.warning(f"Missing or insufficient 1h data for transformer batch (sample: {training_sample.get('test_case_id')})")
return None
result_1d = self._extract_timeframe_data(timeframes.get('1d', {}), target_seq_len) if '1d' in timeframes else None
if result_1d:
price_data_1d, norm_params_dict['1d'] = result_1d
else:
logger.warning(f"Missing or insufficient 1d data for transformer batch (sample: {training_sample.get('test_case_id')})")
return None
# Extract BTC reference data
btc_data_1m = None
@@ -1416,12 +1773,41 @@ class RealTrainingAdapter:
if result_btc:
btc_data_1m, norm_params_dict['btc'] = result_btc
# CRITICAL: Ensure ALL required timeframes are available (1m, 1h, 1d)
# REQUIRED: 1m, 1h, 1d (each with 600 candles)
# OPTIONAL: 1s (if available, include with 600 candles)
required_timeframes_present = (
price_data_1m is not None and
price_data_1h is not None and
price_data_1d is not None
)
if not required_timeframes_present:
missing = []
if price_data_1m is None:
missing.append('1m')
if price_data_1h is None:
missing.append('1h')
if price_data_1d is None:
missing.append('1d')
logger.warning(f"Missing required timeframes: {missing}. Need all 3: 1m, 1h, 1d")
return None
# Validate each required timeframe has correct shape
for tf_name, tf_data in [('1m', price_data_1m), ('1h', price_data_1h), ('1d', price_data_1d)]:
if tf_data is not None:
shape = tf_data.shape
if len(shape) != 3 or shape[1] < 600:
logger.warning(f"Timeframe {tf_name} has invalid shape {shape} (need [1, 600, 5])")
return None
# Validate optional 1s timeframe if present
if price_data_1s is not None:
shape = price_data_1s.shape
if len(shape) != 3 or shape[1] < 600:
logger.warning(f"Optional timeframe 1s has invalid shape {shape}, removing it")
price_data_1s = None
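# At this point each required tensor (price_data_1m, price_data_1h, price_data_1d) is OHLCV
# shaped [1, seq_len, 5] with seq_len >= 600; price_data_1s and the BTC reference (btc_data_1m)
# are optional and may be None.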
# Get reference timeframe for other features (prefer 1m, fallback to any available)
ref_data = price_data_1m if price_data_1m is not None else (
price_data_1h if price_data_1h is not None else (
@@ -1928,6 +2314,46 @@ class RealTrainingAdapter:
for i, data in enumerate(training_data):
batch = self._convert_annotation_to_transformer_batch(data)
if batch is not None:
# CRITICAL: Validate that ALL required timeframes are present
# REQUIRED: 1m, 1h, 1d (each with 600 candles)
# OPTIONAL: 1s (if available, include with 600 candles)
required_tf_keys = ['price_data_1m', 'price_data_1h', 'price_data_1d']
optional_tf_keys = ['price_data_1s']
missing_tfs = [tf for tf in required_tf_keys if batch.get(tf) is None]
if missing_tfs:
logger.warning(f" Skipping sample {i+1}: Missing required timeframes: {missing_tfs}")
continue
# Validate each required timeframe has correct shape [1, 600, 5]
for tf_key in required_tf_keys:
tf_data = batch.get(tf_key)
if tf_data is not None:
if not isinstance(tf_data, torch.Tensor):
logger.warning(f" Skipping sample {i+1}: {tf_key} is not a tensor")
missing_tfs.append(tf_key)
break
shape = tf_data.shape
if len(shape) != 3 or shape[1] < 600: # Must be [batch, seq_len, features] with seq_len >= 600
logger.warning(f" Skipping sample {i+1}: {tf_key} has invalid shape {shape} (need [1, 600, 5])")
missing_tfs.append(tf_key)
break
if missing_tfs:
continue
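# Note: missing_tfs doubles as the skip flag here; appending to it in the shape-validation
# loop above makes the whole sample get skipped on the first bad tensor.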
# Validate optional 1s timeframe if present
if batch.get('price_data_1s') is not None:
tf_data = batch.get('price_data_1s')
if isinstance(tf_data, torch.Tensor):
shape = tf_data.shape
if len(shape) != 3 or shape[1] < 600:
logger.warning(f" Sample {i+1}: price_data_1s has invalid shape {shape}, removing it")
batch['price_data_1s'] = None
logger.debug(f" Sample {i+1}: All required timeframes present (1m, 1h, 1d), 1s={'present' if batch.get('price_data_1s') is not None else 'not available'}")
# Move batch to GPU immediately with pinned memory for faster transfer
if use_gpu:
batch_gpu = {}