wip old MISC fix

This commit is contained in:
Dobromir Popov
2025-12-08 16:56:37 +02:00
parent 81e7e6bfe6
commit 03888b6200
5 changed files with 719 additions and 343 deletions

View File

@@ -23,6 +23,7 @@ from pathlib import Path
import torch
import numpy as np
import pandas as pd
try:
import pytz
@@ -393,10 +394,23 @@ class RealTrainingAdapter:
# Get training config
training_config = test_case.get('training_config', {})
timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d'])
# RESTORED: 200 candles per timeframe (memory leak fixed)
# With 5 timeframes * 200 candles = 1000 total positions
candles_per_timeframe = training_config.get('candles_per_timeframe', 200) # 200 candles per batch
# REQUIRED: All 3 timeframes (1m, 1h, 1d) with 600 candles each
timeframes = training_config.get('timeframes', ['1m', '1h', '1d'])
# REQUIRED: 600 candles per timeframe for transformer model
candles_per_timeframe = training_config.get('candles_per_timeframe', 600) # 600 candles per timeframe
# REQUIRED: 1m, 1h, 1d (all with 600 candles each)
# OPTIONAL: 1s (if available, include with 600 candles)
required_timeframes = ['1m', '1h', '1d']
optional_timeframes = ['1s'] # Include if available
# Ensure required timeframes are in the list
missing_tfs = [tf for tf in required_timeframes if tf not in timeframes]
if missing_tfs:
logger.warning(f" Missing required timeframes: {missing_tfs}, adding them...")
timeframes = list(set(timeframes + required_timeframes))
# Note: 1s is optional, don't add it if not present
# Determine secondary symbol based on primary symbol
# ETH/SOL -> BTC, BTC -> ETH
@@ -408,22 +422,30 @@ class RealTrainingAdapter:
logger.info(f" Candles per batch: {candles_per_timeframe}")
# Calculate time range based on candles needed
# For 600 candles at 1m = 600 minutes = 10 hours
# Use timeframe-specific windows for better efficiency
from datetime import timedelta
# Calculate time window for each timeframe to get 600 candles
# Calculate time window for each timeframe to get required candles
# Add 20% buffer to account for missing candles
buffer_multiplier = 1.2
time_windows = {
'1s': timedelta(seconds=candles_per_timeframe), # 600 seconds = 10 minutes
'1m': timedelta(minutes=candles_per_timeframe), # 600 minutes = 10 hours
'1h': timedelta(hours=candles_per_timeframe), # 600 hours = 25 days
'1d': timedelta(days=candles_per_timeframe) # 600 days = ~1.6 years
'1s': timedelta(seconds=int(candles_per_timeframe * buffer_multiplier)),
'1m': timedelta(minutes=int(candles_per_timeframe * buffer_multiplier)),
'1h': timedelta(hours=int(candles_per_timeframe * buffer_multiplier)),
'1d': timedelta(days=int(candles_per_timeframe * buffer_multiplier))
}
# For historical queries, we want data BEFORE the timestamp
# Use the largest window to ensure we have enough data for all timeframes
max_window = max(time_windows.values())
start_time = timestamp - max_window
end_time = timestamp
# REQUIRED: All required timeframes must have exactly 600 candles (no tolerance for missing data)
min_required_candles = candles_per_timeframe # Must have full 600 candles
required_timeframes = ['1m', '1h', '1d'] # All 3 timeframes are mandatory
optional_timeframes = ['1s'] # Include if available
# Fetch data for primary symbol (all timeframes) and secondary symbol (1m only)
market_state = {
'symbol': symbol,
@@ -439,53 +461,126 @@ class RealTrainingAdapter:
duckdb_storage = self.data_provider.duckdb_storage
# Fetch primary symbol data (all timeframes)
# REQUIRED: Must fetch all 3 timeframes (1m, 1h, 1d) with 600 candles each
logger.info(f" Fetching primary symbol data: {symbol}")
logger.info(f" REQUIRED timeframes: {required_timeframes} (each with {candles_per_timeframe} candles)")
fetched_timeframes = {} # Track which timeframes we successfully fetched
for timeframe in timeframes:
# Fetch required timeframes (1m, 1h, 1d) and optional 1s if present
if timeframe not in required_timeframes and timeframe not in optional_timeframes:
continue
df = None
limit = candles_per_timeframe # Always fetch 600 candles
limit = candles_per_timeframe
# Use timeframe-specific window for better efficiency
tf_window = time_windows.get(timeframe, max_window)
tf_start_time = timestamp - tf_window
tf_end_time = timestamp
# Try DuckDB storage first (has historical data)
# Use 'before' direction to get data BEFORE the timestamp
if duckdb_storage:
try:
df = duckdb_storage.get_ohlcv_data(
symbol=symbol,
timeframe=timeframe,
start_time=start_time,
end_time=end_time,
start_time=tf_start_time,
end_time=tf_end_time,
limit=limit,
direction='latest'
direction='before' # Get data BEFORE timestamp for historical training
)
if df is not None and not df.empty:
logger.debug(f" {timeframe}: {len(df)} candles from DuckDB (historical)")
logger.debug(f" {timeframe}: {len(df)} candles from DuckDB (historical, before {timestamp})")
except Exception as e:
logger.debug(f" {timeframe}: DuckDB query failed: {e}")
# Fallback to data_provider (might have cached data)
if df is None or df.empty:
# If DuckDB doesn't have enough data, try API with proper time range
if df is None or df.empty or len(df) < min_required_candles:
try:
# Try to fetch from API with historical time range
logger.info(f" {timeframe}: DuckDB insufficient ({len(df) if df is not None else 0} candles), fetching from API for timestamp {timestamp}...")
# Fetch historical data from API for the specific time range
api_df = self._fetch_historical_from_api(
symbol=symbol,
timeframe=timeframe,
start_time=tf_start_time,
end_time=tf_end_time,
limit=limit
)
if api_df is not None and not api_df.empty:
# Filter to data before timestamp (historical training needs data BEFORE the event)
try:
api_df = api_df[api_df.index <= tf_end_time]
# Take the most recent candles up to limit
api_df = api_df.tail(limit)
if len(api_df) >= min_required_candles:
df = api_df
logger.info(f" {timeframe}: {len(df)} candles from API (historical range: {tf_start_time} to {tf_end_time})")
# Store in DuckDB for future use
if duckdb_storage:
try:
duckdb_storage.store_ohlcv_data(symbol, timeframe, df)
logger.debug(f" {timeframe}: Stored {len(df)} candles in DuckDB for future use")
except Exception as e:
logger.debug(f" {timeframe}: Could not store in DuckDB: {e}")
else:
logger.warning(f" {timeframe}: API returned only {len(api_df)} candles after filtering (need {min_required_candles})")
except Exception as e:
logger.debug(f" {timeframe}: Could not filter API data: {e}")
# Use as-is if filtering fails
if len(api_df) >= min_required_candles:
df = api_df
logger.info(f" {timeframe}: {len(df)} candles from API (unfiltered)")
else:
logger.warning(f" {timeframe}: API fetch returned no data")
except Exception as e:
logger.warning(f" {timeframe}: API fetch failed: {e}")
import traceback
logger.debug(traceback.format_exc())
# Fallback to replay method
if df is None or df.empty or len(df) < min_required_candles:
try:
# Use get_historical_data_replay for time-specific data
replay_data = self.data_provider.get_historical_data_replay(
symbol=symbol,
start_time=start_time,
end_time=end_time,
start_time=tf_start_time,
end_time=tf_end_time,
timeframes=[timeframe]
)
df = replay_data.get(timeframe)
if df is not None and not df.empty:
logger.debug(f" {timeframe}: {len(df)} candles from replay")
replay_df = replay_data.get(timeframe)
if replay_df is not None and not replay_df.empty and len(replay_df) >= min_required_candles:
df = replay_df
logger.info(f" {timeframe}: {len(df)} candles from replay")
except Exception as e:
logger.debug(f" {timeframe}: Replay failed: {e}")
# Last resort: get latest data (not ideal but better than nothing)
if df is None or df.empty:
logger.warning(f" {timeframe}: No historical data found, using latest data as fallback")
df = self.data_provider.get_historical_data(
symbol=symbol,
timeframe=timeframe,
limit=limit # Use calculated limit
)
# Validate data quality before storing
if df is not None and not df.empty:
# Check minimum candle count
if len(df) < min_required_candles:
logger.warning(f" {symbol} {timeframe}: Only {len(df)} candles (need {min_required_candles}), skipping")
continue
# Validate data quality - check for NaN values
if df.isnull().any().any():
logger.warning(f" {symbol} {timeframe}: Contains NaN values, cleaning...")
df = df.dropna()
if len(df) < min_required_candles:
logger.warning(f" {symbol} {timeframe}: After cleaning, only {len(df)} candles, skipping")
continue
# Ensure we have required columns
required_cols = ['open', 'high', 'low', 'close', 'volume']
if not all(col in df.columns for col in required_cols):
logger.warning(f" {symbol} {timeframe}: Missing required columns, skipping")
continue
# Convert to dict format
market_state['timeframes'][timeframe] = {
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
@@ -495,67 +590,124 @@ class RealTrainingAdapter:
'close': df['close'].tolist(),
'volume': df['volume'].tolist()
}
logger.info(f" {symbol} {timeframe}: {len(df)} candles")
fetched_timeframes[timeframe] = len(df)
logger.info(f" {symbol} {timeframe}: {len(df)} candles [OK]")
else:
logger.warning(f" {symbol} {timeframe}: No data available")
logger.warning(f" {symbol} {timeframe}: No quality data available (need {min_required_candles} candles)")
# Fetch secondary symbol data (1m timeframe only, 600 candles)
# CRITICAL: Validate we have all required timeframes
missing_required = [tf for tf in required_timeframes if tf not in fetched_timeframes]
if missing_required:
logger.error(f" FAILED: Missing required timeframes: {missing_required}")
logger.error(f" Fetched: {list(fetched_timeframes.keys())}")
logger.error(f" Cannot proceed without all required timeframes")
return {} # Return empty dict to signal failure
# Fetch secondary symbol data (1m timeframe only)
logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)")
secondary_df = None
# Try DuckDB first
# Use 1m-specific window
tf_window = time_windows.get('1m', max_window)
tf_start_time = timestamp - tf_window
tf_end_time = timestamp
# Try DuckDB first with 'before' direction
if duckdb_storage:
try:
secondary_df = duckdb_storage.get_ohlcv_data(
symbol=secondary_symbol,
timeframe='1m',
start_time=start_time,
end_time=end_time,
start_time=tf_start_time,
end_time=tf_end_time,
limit=candles_per_timeframe,
direction='latest'
direction='before' # Get data BEFORE timestamp
)
if secondary_df is not None and not secondary_df.empty:
logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from DuckDB")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: DuckDB query failed: {e}")
# If DuckDB doesn't have enough, try API with historical time range
if secondary_df is None or secondary_df.empty or len(secondary_df) < min_required_candles:
try:
logger.info(f" {secondary_symbol} 1m: DuckDB insufficient ({len(secondary_df) if secondary_df is not None else 0} candles), fetching from API for timestamp {timestamp}...")
# Fetch historical data from API for the specific time range
api_df = self._fetch_historical_from_api(
symbol=secondary_symbol,
timeframe='1m',
start_time=tf_start_time,
end_time=tf_end_time,
limit=candles_per_timeframe
)
if api_df is not None and not api_df.empty:
# Filter to data before timestamp
try:
api_df = api_df[api_df.index <= tf_end_time].tail(candles_per_timeframe)
if len(api_df) >= min_required_candles:
secondary_df = api_df
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from API (historical range)")
# Store in DuckDB for future use
if duckdb_storage:
try:
duckdb_storage.store_ohlcv_data(secondary_symbol, '1m', secondary_df)
logger.debug(f" {secondary_symbol} 1m: Stored in DuckDB for future use")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: Could not store in DuckDB: {e}")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: Could not filter API data: {e}")
if len(api_df) >= min_required_candles:
secondary_df = api_df
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from API (unfiltered)")
else:
logger.warning(f" {secondary_symbol} 1m: API fetch returned no data")
except Exception as e:
logger.warning(f" {secondary_symbol} 1m: API fetch failed: {e}")
import traceback
logger.debug(traceback.format_exc())
# Fallback to replay
if secondary_df is None or secondary_df.empty:
if secondary_df is None or secondary_df.empty or len(secondary_df) < min_required_candles:
try:
replay_data = self.data_provider.get_historical_data_replay(
symbol=secondary_symbol,
start_time=start_time,
end_time=end_time,
start_time=tf_start_time,
end_time=tf_end_time,
timeframes=['1m']
)
secondary_df = replay_data.get('1m')
if secondary_df is not None and not secondary_df.empty:
logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from replay")
replay_df = replay_data.get('1m')
if replay_df is not None and not replay_df.empty and len(replay_df) >= min_required_candles:
secondary_df = replay_df
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from replay")
except Exception as e:
logger.debug(f" {secondary_symbol} 1m: Replay failed: {e}")
# Last resort: latest data
if secondary_df is None or secondary_df.empty:
logger.warning(f" {secondary_symbol} 1m: No historical data, using latest as fallback")
secondary_df = self.data_provider.get_historical_data(
symbol=secondary_symbol,
timeframe='1m',
limit=candles_per_timeframe
)
# Store secondary symbol data
# Validate and store secondary symbol data
if secondary_df is not None and not secondary_df.empty:
market_state['secondary_timeframes']['1m'] = {
'timestamps': secondary_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': secondary_df['open'].tolist(),
'high': secondary_df['high'].tolist(),
'low': secondary_df['low'].tolist(),
'close': secondary_df['close'].tolist(),
'volume': secondary_df['volume'].tolist()
}
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles")
if len(secondary_df) < min_required_candles:
logger.warning(f" {secondary_symbol} 1m: Only {len(secondary_df)} candles (need {min_required_candles}), skipping")
elif secondary_df.isnull().any().any():
logger.warning(f" {secondary_symbol} 1m: Contains NaN values, skipping")
elif not all(col in secondary_df.columns for col in ['open', 'high', 'low', 'close', 'volume']):
logger.warning(f" {secondary_symbol} 1m: Missing required columns, skipping")
else:
# Store in the correct structure: secondary_timeframes[symbol][timeframe]
if secondary_symbol not in market_state['secondary_timeframes']:
market_state['secondary_timeframes'][secondary_symbol] = {}
market_state['secondary_timeframes'][secondary_symbol]['1m'] = {
'timestamps': secondary_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
'open': secondary_df['open'].tolist(),
'high': secondary_df['high'].tolist(),
'low': secondary_df['low'].tolist(),
'close': secondary_df['close'].tolist(),
'volume': secondary_df['volume'].tolist()
}
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles [OK]")
else:
logger.warning(f" {secondary_symbol} 1m: No data available")
logger.warning(f" {secondary_symbol} 1m: No quality data available (need {min_required_candles} candles)")
# Verify we have data
if market_state['timeframes']:
@@ -1190,6 +1342,200 @@ class RealTrainingAdapter:
state_size = agent.state_size if hasattr(agent, 'state_size') else 100
return [0.0] * state_size
def _fetch_historical_from_api(self, symbol: str, timeframe: str, start_time: datetime, end_time: datetime, limit: int) -> Optional[pd.DataFrame]:
    """
    Fetch historical OHLCV data from exchange APIs for a specific time range

    Binance is tried first: the whole [start_time, end_time] window is paged
    through in ascending order and the most recent `limit` candles are kept.
    If Binance yields nothing, MEXC's latest candles are fetched and filtered
    to the same window.

    Args:
        symbol: Trading symbol (e.g., 'ETH/USDT')
        timeframe: Timeframe (e.g., '1m', '1h', '1d')
        start_time: Start timestamp (UTC); naive values are treated as UTC
        end_time: End timestamp (UTC); naive values are treated as UTC
        limit: Maximum number of candles to return

    Returns:
        DataFrame indexed by UTC timestamp with float open/high/low/close/volume
        columns (sorted ascending, at most `limit` rows), or None if the fetch
        fails, the timeframe is unsupported, or no candles are found
    """
    import pandas as pd
    import requests
    import time
    from datetime import timezone

    ohlcv_cols = ['open', 'high', 'low', 'close', 'volume']

    # Binance and MEXC use the same interval names for the supported timeframes
    timeframe_map = {
        '1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m',
        '1h': '1h', '4h': '4h', '1d': '1d'
    }

    def _as_utc(ts):
        # NOTE(review): naive inputs are assumed to already be UTC (per the
        # docstring); datetime.timestamp() would otherwise read them as local time
        if ts.tzinfo is None:
            return ts.replace(tzinfo=timezone.utc)
        return ts.astimezone(timezone.utc)

    def _to_frame(rows, columns):
        # Raw kline rows -> OHLCV frame indexed by UTC open time, sorted, de-duplicated
        frame = pd.DataFrame(rows, columns=columns)
        frame['timestamp'] = pd.to_datetime(frame['timestamp'], unit='ms', utc=True)
        for col in ohlcv_cols:
            frame[col] = frame[col].astype(float)
        frame = frame[['timestamp'] + ohlcv_cols].set_index('timestamp').sort_index()
        return frame[~frame.index.duplicated(keep='last')]

    try:
        # Handle 1s timeframe specially (not directly supported by most APIs)
        if timeframe == '1s':
            # Return None and let the caller decide how to source 1s candles
            logger.debug("1s timeframe requested - will try to generate from ticks or skip")
            return None

        api_timeframe = timeframe_map.get(timeframe)
        if not api_timeframe:
            logger.warning(f"Binance doesn't support timeframe {timeframe}")
            return None

        # Nothing to fetch for a non-positive limit
        if limit <= 0:
            return None

        start_time = _as_utc(start_time)
        end_time = _as_utc(end_time)
        api_symbol = symbol.replace('/', '').upper()

        # Try Binance first (supports historical queries with startTime/endTime)
        try:
            url = "https://api.binance.com/api/v3/klines"

            # Convert timestamps to milliseconds
            start_ms = int(start_time.timestamp() * 1000)
            end_ms = int(end_time.timestamp() * 1000)

            all_data = []
            current_start = start_ms
            max_per_request = 1000  # Binance max per request
            max_requests = 10  # Safety limit
            request_count = 0

            while current_start < end_ms and request_count < max_requests:
                if request_count > 0:
                    # Small delay between pages to avoid rate limiting
                    time.sleep(0.1)

                # Always request full pages: results come back ascending from
                # startTime, so capping the request at `limit` would keep the
                # OLDEST candles of the window instead of the newest ones
                params = {
                    'symbol': api_symbol,
                    'interval': api_timeframe,
                    'startTime': current_start,
                    'endTime': end_ms,
                    'limit': max_per_request
                }

                logger.debug(f"Fetching from Binance: {symbol} {timeframe} batch {request_count + 1} (start: {current_start}, end: {end_ms})")
                response = requests.get(url, params=params, timeout=10)
                request_count += 1

                if response.status_code != 200:
                    logger.debug(f"Binance API returned {response.status_code} for {symbol} {timeframe}")
                    break

                data = response.json()
                if not data:
                    break
                all_data.extend(data)

                # A short page means the window is exhausted
                if len(data) < max_per_request:
                    break

                # Continue from the last candle's close_time (index 6) + 1ms
                next_start = data[-1][6] + 1
                if next_start <= current_start:
                    # Cursor did not advance; stop rather than loop on the same page
                    break
                current_start = next_start

            if all_data:
                df = _to_frame(all_data, [
                    'timestamp', 'open', 'high', 'low', 'close', 'volume',
                    'close_time', 'quote_volume', 'trades', 'taker_buy_base',
                    'taker_buy_quote', 'ignore'
                ])
                # Keep the most recent 'limit' candles of the window
                df = df.tail(limit)

                logger.info(f"Binance API: Fetched {len(df)} candles for {symbol} {timeframe} (historical, {request_count} requests)")
                return df
            else:
                logger.warning(f"Binance API returned no data for {symbol} {timeframe}")
        except Exception as e:
            logger.debug(f"Binance fetch failed: {e}")

        # Fallback to MEXC
        try:
            # MEXC API klines endpoint (may not support startTime/endTime, so fetch latest and filter)
            url = "https://api.mexc.com/api/v3/klines"
            params = {
                'symbol': api_symbol,
                'interval': api_timeframe,
                'limit': min(limit * 2, 1000)  # Fetch more to account for filtering
            }

            logger.debug(f"Fetching from MEXC: {symbol} {timeframe}")
            response = requests.get(url, params=params, timeout=10)

            if response.status_code == 200:
                data = response.json()
                if data:
                    df = _to_frame(data, [
                        'timestamp', 'open', 'high', 'low', 'close', 'volume',
                        'close_time', 'quote_volume'
                    ])

                    # Filter to time range (both bounds are tz-aware UTC, matching the index)
                    df = df[(df.index >= start_time) & (df.index <= end_time)]
                    df = df.tail(limit)

                    if len(df) > 0:
                        logger.info(f"MEXC API: Fetched {len(df)} candles for {symbol} {timeframe} (historical, filtered)")
                        return df
                    else:
                        logger.warning(f"MEXC API: No candles in time range for {symbol} {timeframe}")
                else:
                    logger.warning(f"MEXC API returned empty data for {symbol} {timeframe}")
            else:
                logger.debug(f"MEXC API returned {response.status_code} for {symbol} {timeframe}")
        except Exception as e:
            logger.debug(f"MEXC fetch failed: {e}")

        return None

    except Exception as e:
        logger.error(f"Error fetching historical data from API: {e}")
        import traceback
        logger.debug(traceback.format_exc())
        return None
def _extract_timeframe_data(self, tf_data: Dict, target_seq_len: int = 600) -> Optional[torch.Tensor]:
"""
Extract and normalize OHLCV data from a single timeframe
@@ -1215,28 +1561,22 @@ class RealTrainingAdapter:
if len(closes) == 0:
return None
# Take last target_seq_len candles or pad if needed
if len(closes) >= target_seq_len:
# Truncate to target length
opens = opens[-target_seq_len:]
highs = highs[-target_seq_len:]
lows = lows[-target_seq_len:]
closes = closes[-target_seq_len:]
volumes = volumes[-target_seq_len:]
else:
# Pad with last candle
pad_len = target_seq_len - len(closes)
last_open = opens[-1] if len(opens) > 0 else 0.0
last_high = highs[-1] if len(highs) > 0 else 0.0
last_low = lows[-1] if len(lows) > 0 else 0.0
last_close = closes[-1] if len(closes) > 0 else 0.0
last_volume = volumes[-1] if len(volumes) > 0 else 0.0
# REQUIRED: Must have exactly target_seq_len (600) candles, no padding allowed
if len(closes) < target_seq_len:
logger.warning(f"Insufficient candles: {len(closes)} < {target_seq_len} (required)")
return None
opens = np.pad(opens, (0, pad_len), constant_values=last_open)
highs = np.pad(highs, (0, pad_len), constant_values=last_high)
lows = np.pad(lows, (0, pad_len), constant_values=last_low)
closes = np.pad(closes, (0, pad_len), constant_values=last_close)
volumes = np.pad(volumes, (0, pad_len), constant_values=last_volume)
# Take last target_seq_len candles (exactly 600)
opens = opens[-target_seq_len:]
highs = highs[-target_seq_len:]
lows = lows[-target_seq_len:]
closes = closes[-target_seq_len:]
volumes = volumes[-target_seq_len:]
# Validate we have exactly target_seq_len
if len(closes) != target_seq_len:
logger.warning(f"Extraction failed: got {len(closes)} candles, need {target_seq_len}")
return None
# Stack OHLCV [seq_len, 5]
ohlcv = np.stack([opens, highs, lows, closes, volumes], axis=-1)
@@ -1369,45 +1709,62 @@ class RealTrainingAdapter:
timeframes = market_state.get('timeframes', {})
secondary_timeframes = market_state.get('secondary_timeframes', {})
# Target sequence length - RESTORED to 200 (memory leak fixed)
# With 5 timeframes * 200 candles = 1000 sequence positions
# Memory management fixes allow full sequence length
target_seq_len = 200 # Restored to original
for tf_data in timeframes.values():
if tf_data and 'close' in tf_data and len(tf_data['close']) > 0:
target_seq_len = min(len(tf_data['close']), 200) # Cap at 200
break
# REQUIRED: 600 candles per timeframe for transformer model
target_seq_len = 600 # Must be 600 candles for each timeframe
# Validate we have enough data in required timeframes (1m, 1h, 1d)
required_tfs = ['1m', '1h', '1d']
for tf_name in required_tfs:
if tf_name in timeframes:
tf_data = timeframes[tf_name]
if tf_data and 'close' in tf_data:
if len(tf_data['close']) < 600:
logger.warning(f"Required timeframe {tf_name} has only {len(tf_data['close'])} candles (need 600)")
return None
# Validate optional 1s timeframe if present (must have 600 candles if included)
if '1s' in timeframes:
tf_data = timeframes['1s']
if tf_data and 'close' in tf_data:
if len(tf_data['close']) < 600:
logger.warning(f"Optional timeframe 1s has only {len(tf_data['close'])} candles (need 600), excluding it")
# Remove 1s from timeframes if insufficient
timeframes = {k: v for k, v in timeframes.items() if k != '1s'}
# Extract each timeframe (returns tuple: (tensor, norm_params) or None)
# Store normalization parameters for each timeframe
norm_params_dict = {}
# OPTIONAL: Extract 1s timeframe if available (must have 600 candles if included)
result_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None
if result_1s:
price_data_1s, norm_params_dict['1s'] = result_1s
logger.debug(f"Included optional 1s timeframe with {result_1s[0].shape[1]} candles")
else:
# Don't fail on missing 1s data, it's often unavailable in annotations
# 1s is optional - don't fail if missing, but log it
price_data_1s = None
logger.debug("Optional 1s timeframe not available (this is OK)")
# REQUIRED: Extract all 3 timeframes (1m, 1h, 1d) with exactly 600 candles each
result_1m = self._extract_timeframe_data(timeframes.get('1m', {}), target_seq_len) if '1m' in timeframes else None
if result_1m:
price_data_1m, norm_params_dict['1m'] = result_1m
else:
# Warning: 1m data is critical
logger.warning(f"Missing 1m data for transformer batch (sample: {training_sample.get('test_case_id')})")
logger.warning(f"Missing or insufficient 1m data for transformer batch (sample: {training_sample.get('test_case_id')})")
return None
result_1h = self._extract_timeframe_data(timeframes.get('1h', {}), target_seq_len) if '1h' in timeframes else None
if result_1h:
price_data_1h, norm_params_dict['1h'] = result_1h
else:
price_data_1h = None
logger.warning(f"Missing or insufficient 1h data for transformer batch (sample: {training_sample.get('test_case_id')})")
return None
result_1d = self._extract_timeframe_data(timeframes.get('1d', {}), target_seq_len) if '1d' in timeframes else None
if result_1d:
price_data_1d, norm_params_dict['1d'] = result_1d
else:
price_data_1d = None
logger.warning(f"Missing or insufficient 1d data for transformer batch (sample: {training_sample.get('test_case_id')})")
return None
# Extract BTC reference data
btc_data_1m = None
@@ -1416,12 +1773,41 @@ class RealTrainingAdapter:
if result_btc:
btc_data_1m, norm_params_dict['btc'] = result_btc
# Ensure at least one timeframe is available
# Check if all are None (can't use any() with tensors)
if price_data_1s is None and price_data_1m is None and price_data_1h is None and price_data_1d is None:
logger.warning("No price data available in any timeframe")
# CRITICAL: Ensure ALL required timeframes are available (1m, 1h, 1d)
# REQUIRED: 1m, 1h, 1d (each with 600 candles)
# OPTIONAL: 1s (if available, include with 600 candles)
required_timeframes_present = (
price_data_1m is not None and
price_data_1h is not None and
price_data_1d is not None
)
if not required_timeframes_present:
missing = []
if price_data_1m is None:
missing.append('1m')
if price_data_1h is None:
missing.append('1h')
if price_data_1d is None:
missing.append('1d')
logger.warning(f"Missing required timeframes: {missing}. Need all 3: 1m, 1h, 1d")
return None
# Validate each required timeframe has correct shape
for tf_name, tf_data in [('1m', price_data_1m), ('1h', price_data_1h), ('1d', price_data_1d)]:
if tf_data is not None:
shape = tf_data.shape
if len(shape) != 3 or shape[1] < 600:
logger.warning(f"Timeframe {tf_name} has invalid shape {shape} (need [1, 600, 5])")
return None
# Validate optional 1s timeframe if present
if price_data_1s is not None:
shape = price_data_1s.shape
if len(shape) != 3 or shape[1] < 600:
logger.warning(f"Optional timeframe 1s has invalid shape {shape}, removing it")
price_data_1s = None
# Get reference timeframe for other features (prefer 1m, fallback to any available)
ref_data = price_data_1m if price_data_1m is not None else (
price_data_1h if price_data_1h is not None else (
@@ -1928,6 +2314,46 @@ class RealTrainingAdapter:
for i, data in enumerate(training_data):
batch = self._convert_annotation_to_transformer_batch(data)
if batch is not None:
# CRITICAL: Validate that ALL required timeframes are present
# REQUIRED: 1m, 1h, 1d (each with 600 candles)
# OPTIONAL: 1s (if available, include with 600 candles)
required_tf_keys = ['price_data_1m', 'price_data_1h', 'price_data_1d']
optional_tf_keys = ['price_data_1s']
missing_tfs = [tf for tf in required_tf_keys if batch.get(tf) is None]
if missing_tfs:
logger.warning(f" Skipping sample {i+1}: Missing required timeframes: {missing_tfs}")
continue
# Validate each required timeframe has correct shape [1, 600, 5]
for tf_key in required_tf_keys:
tf_data = batch.get(tf_key)
if tf_data is not None:
if not isinstance(tf_data, torch.Tensor):
logger.warning(f" Skipping sample {i+1}: {tf_key} is not a tensor")
missing_tfs.append(tf_key)
break
shape = tf_data.shape
if len(shape) != 3 or shape[1] < 600: # Must be [batch, seq_len, features] with seq_len >= 600
logger.warning(f" Skipping sample {i+1}: {tf_key} has invalid shape {shape} (need [1, 600, 5])")
missing_tfs.append(tf_key)
break
if missing_tfs:
continue
# Validate optional 1s timeframe if present
if batch.get('price_data_1s') is not None:
tf_data = batch.get('price_data_1s')
if isinstance(tf_data, torch.Tensor):
shape = tf_data.shape
if len(shape) != 3 or shape[1] < 600:
logger.warning(f" Sample {i+1}: price_data_1s has invalid shape {shape}, removing it")
batch['price_data_1s'] = None
logger.debug(f" Sample {i+1}: All required timeframes present (1m, 1h, 1d), 1s={'present' if batch.get('price_data_1s') is not None else 'not available'}")
# Move batch to GPU immediately with pinned memory for faster transfer
if use_gpu:
batch_gpu = {}

View File

@@ -832,7 +832,7 @@ class AnnotationDashboard:
# Check if the specific model is already initialized
if model_name == 'Transformer':
logger.info("Checking Transformer model...")
if self.orchestrator.primary_transformer:
if hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
self.loaded_models['Transformer'] = self.orchestrator.primary_transformer
logger.info("Transformer model loaded successfully")
else:
@@ -841,7 +841,7 @@ class AnnotationDashboard:
elif model_name == 'CNN':
logger.info("Checking CNN model...")
if self.orchestrator.cnn_model:
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
self.loaded_models['CNN'] = self.orchestrator.cnn_model
logger.info("CNN model loaded successfully")
else:

View File

@@ -270,9 +270,9 @@ class DQNAgent:
self.batch_size = batch_size
self.target_update = target_update
# Set device for computation (default to GPU if available)
# Set device for computation (read from config.yaml if available)
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = self._get_device_from_config()
else:
self.device = device
@@ -282,10 +282,6 @@ class DQNAgent:
self.policy_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device)
self.target_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device)
# Ensure models are on the correct device
self.policy_net = self.policy_net.to(self.device)
self.target_net = self.target_net.to(self.device)
# Initialize the target network with the same weights as the policy network
self.target_net.load_state_dict(self.policy_net.state_dict())
@@ -317,13 +313,92 @@ class DQNAgent:
# Market regime adaptation weights
self.market_regime_weights = {
'trending': 1.0,
'sideways': 0.8,
'volatile': 1.2,
'bullish': 1.1,
'bearish': 1.1
'trending': 1.2, # Higher confidence in trending markets
'ranging': 0.8, # Lower confidence in ranging markets
'volatile': 0.6 # Much lower confidence in volatile markets
}
# Additional initialization
self.recent_actions = deque(maxlen=10)
self.recent_prices = deque(maxlen=20)
self.recent_rewards = deque(maxlen=100)
# Price direction tracking
self.last_price_direction = {
'direction': 0.0,
'confidence': 0.0
}
self.price_movement_memory = []
self.losses = []
self.no_improvement_count = 0
self.confidence_history = []
self.avg_confidence = 0.0
self.max_confidence = 0.0
self.min_confidence = 1.0
# Enhanced training features
self.use_dueling = True
self.use_prioritized_replay = priority_memory
self.alpha = 0.6
self.beta = 0.4
self.beta_increment = 0.001
self.use_double_dqn = True
self.target_update_freq = target_update
self.training_steps = 0
self.gradient_clip_norm = 1.0
self.epsilon_history = []
self.td_errors = []
# Trade settings
self.trade_action_fee = 0.0005
self.minimum_action_confidence = 0.3
# Violent move detection
self.price_history = []
self.volatility_window = 20
self.volatility_threshold = 0.0015
self.post_violent_move = False
self.violent_move_cooldown = 0
# Feature integration
self.last_hidden_features = None
self.feature_history = []
self.realtime_tick_features = None
self.tick_feature_weight = 0.3
# Mixed precision training
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
self.use_mixed_precision = True
self.scaler = torch.amp.GradScaler('cuda')
logger.info("Mixed precision training enabled")
else:
self.use_mixed_precision = False
logger.info("Mixed precision training disabled")
self.training = True
# Compatibility
self.state_size = np.prod(state_shape)
self.action_size = n_actions
self.memory_size = buffer_size
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3]
logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")
total_params = sum(p.numel() for p in self.policy_net.parameters())
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
# Position management
self.current_position = 0.0
self.position_entry_price = 0.0
self.position_entry_time = None
self.entry_confidence_threshold = 0.35
self.exit_confidence_threshold = 0.15
self.uncertainty_threshold = 0.1
# Load best checkpoint if available
if self.enable_checkpoints:
self.load_best_checkpoint()
@@ -332,113 +407,46 @@ class DQNAgent:
if enable_checkpoints:
logger.info(f"Model name: {model_name}, Checkpoint frequency: {self.checkpoint_frequency}")
# Add this line to the __init__ method
self.recent_actions = deque(maxlen=10)
self.recent_prices = deque(maxlen=20)
self.recent_rewards = deque(maxlen=100)
def _get_device_from_config(self) -> torch.device:
"""Get device from config.yaml or auto-detect"""
try:
# Try to load config
from core.config import get_config
config = get_config()
gpu_config = config._config.get('gpu', {})
# Price direction tracking - stores direction and confidence
self.last_price_direction = {
'direction': 0.0, # Single value between -1 and 1
'confidence': 0.0 # Single value between 0 and 1
}
device_setting = gpu_config.get('device', 'auto')
fallback_to_cpu = gpu_config.get('fallback_to_cpu', True)
gpu_enabled = gpu_config.get('enabled', True)
# Store separate memory for price direction examples
self.price_movement_memory = [] # For storing examples of clear price movements
# If GPU is disabled in config, use CPU
if not gpu_enabled:
logger.info("GPU disabled in config.yaml, using CPU")
return torch.device('cpu')
# Performance tracking
self.losses = []
self.no_improvement_count = 0
# Handle device selection
if device_setting == 'cpu':
logger.info("Device set to CPU in config.yaml")
return torch.device('cpu')
elif device_setting == 'cuda' or device_setting == 'auto':
# Try GPU first
if torch.cuda.is_available():
logger.info("Using GPU (CUDA available)")
return torch.device('cuda')
else:
if fallback_to_cpu:
logger.warning("CUDA not available, falling back to CPU")
return torch.device('cpu')
else:
raise RuntimeError("CUDA not available and fallback_to_cpu is False")
else:
logger.warning(f"Unknown device setting '{device_setting}', using auto-detection")
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Confidence tracking
self.confidence_history = []
self.avg_confidence = 0.0
self.max_confidence = 0.0
self.min_confidence = 1.0
# Enhanced features from EnhancedDQNAgent
# Market adaptation capabilities
self.market_regime_weights = {
'trending': 1.2, # Higher confidence in trending markets
'ranging': 0.8, # Lower confidence in ranging markets
'volatile': 0.6 # Much lower confidence in volatile markets
}
# Dueling network support (requires enhanced network architecture)
self.use_dueling = True
# Prioritized experience replay parameters
self.use_prioritized_replay = priority_memory
self.alpha = 0.6 # Priority exponent
self.beta = 0.4 # Importance sampling exponent
self.beta_increment = 0.001
# Double DQN support
self.use_double_dqn = True
# Enhanced training features from EnhancedDQNAgent
self.target_update_freq = target_update # More descriptive name
self.training_steps = 0
self.gradient_clip_norm = 1.0 # Gradient clipping
# Enhanced statistics tracking
self.epsilon_history = []
self.td_errors = [] # Track TD errors for analysis
# Trade action fee and confidence thresholds
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
# Violent move detection
self.price_history = []
self.volatility_window = 20 # Window size for volatility calculation
self.volatility_threshold = 0.0015 # Threshold for considering a move "violent"
self.post_violent_move = False # Flag for recent violent move
self.violent_move_cooldown = 0 # Cooldown after violent move
# Feature integration
self.last_hidden_features = None # Store last extracted features
self.feature_history = [] # Store history of features for analysis
# Real-time tick features integration
self.realtime_tick_features = None # Latest tick features from tick processor
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
# Check if mixed precision training should be used
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
self.use_mixed_precision = True
self.scaler = torch.amp.GradScaler('cuda')
logger.info("Mixed precision training enabled")
else:
self.use_mixed_precision = False
logger.info("Mixed precision training disabled")
# Track if we're in training mode
self.training = True
# For compatibility with old code
self.state_size = np.prod(state_shape)
self.action_size = n_actions
self.memory_size = buffer_size
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] # Default timeframes
logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")
# Log model parameters
total_params = sum(p.numel() for p in self.policy_net.parameters())
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
# Position management for 2-action system
self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long)
self.position_entry_price = 0.0
self.position_entry_time = None
# Different thresholds for entry vs exit decisions - AGGRESSIVE for more training data
self.entry_confidence_threshold = 0.35 # Lower threshold for new positions (was 0.7)
self.exit_confidence_threshold = 0.15 # Very low threshold for closing positions (was 0.3)
self.uncertainty_threshold = 0.1 # When to stay neutral
except Exception as e:
logger.warning(f"Error reading device config: {e}, using auto-detection")
# Fallback to auto-detection
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_best_checkpoint(self):
"""Load the best checkpoint for this DQN agent"""
@@ -513,103 +521,6 @@ class DQNAgent:
logger.error(f"Error saving DQN checkpoint: {e}")
return False
# Store separate memory for price direction examples
self.price_movement_memory = [] # For storing examples of clear price movements
# Performance tracking
self.losses = []
self.no_improvement_count = 0
# Confidence tracking
self.confidence_history = []
self.avg_confidence = 0.0
self.max_confidence = 0.0
self.min_confidence = 1.0
# Enhanced features from EnhancedDQNAgent
# Market adaptation capabilities
self.market_regime_weights = {
'trending': 1.2, # Higher confidence in trending markets
'ranging': 0.8, # Lower confidence in ranging markets
'volatile': 0.6 # Much lower confidence in volatile markets
}
# Dueling network support (requires enhanced network architecture)
self.use_dueling = True
# Prioritized experience replay parameters
self.use_prioritized_replay = priority_memory
self.alpha = 0.6 # Priority exponent
self.beta = 0.4 # Importance sampling exponent
self.beta_increment = 0.001
# Double DQN support
self.use_double_dqn = True
# Enhanced training features from EnhancedDQNAgent
self.target_update_freq = target_update # More descriptive name
self.training_steps = 0
self.gradient_clip_norm = 1.0 # Gradient clipping
# Enhanced statistics tracking
self.epsilon_history = []
self.td_errors = [] # Track TD errors for analysis
# Trade action fee and confidence thresholds
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
# Violent move detection
self.price_history = []
self.volatility_window = 20 # Window size for volatility calculation
self.volatility_threshold = 0.0015 # Threshold for considering a move "violent"
self.post_violent_move = False # Flag for recent violent move
self.violent_move_cooldown = 0 # Cooldown after violent move
# Feature integration
self.last_hidden_features = None # Store last extracted features
self.feature_history = [] # Store history of features for analysis
# Real-time tick features integration
self.realtime_tick_features = None # Latest tick features from tick processor
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
# Check if mixed precision training should be used
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
self.use_mixed_precision = True
self.scaler = torch.amp.GradScaler('cuda')
logger.info("Mixed precision training enabled")
else:
self.use_mixed_precision = False
logger.info("Mixed precision training disabled")
# Track if we're in training mode
self.training = True
# For compatibility with old code
self.state_size = np.prod(state_shape)
self.action_size = n_actions
self.memory_size = buffer_size
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] # Default timeframes
logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")
# Log model parameters
total_params = sum(p.numel() for p in self.policy_net.parameters())
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
# Position management for 2-action system
self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long)
self.position_entry_price = 0.0
self.position_entry_time = None
# Different thresholds for entry vs exit decisions - AGGRESSIVE for more training data
self.entry_confidence_threshold = 0.35 # Lower threshold for new positions (was 0.7)
self.exit_confidence_threshold = 0.15 # Very low threshold for closing positions (was 0.3)
self.uncertainty_threshold = 0.1 # When to stay neutral
def move_models_to_device(self, device=None):
"""Move models to the specified device (GPU/CPU)"""
if device is not None:

View File

@@ -181,3 +181,7 @@ This complements the earlier fixes:
📊 **EXPECTED** - Better trend predictions after training

View File

@@ -318,28 +318,63 @@ class TradingOrchestrator:
# Initialize confidence threshold
self.confidence_threshold = self.config.get('confidence_threshold', 0.6)
# Determine the device to use (GPU if available, else CPU)
# Initialize device - force CPU mode to avoid CUDA errors
if torch.cuda.is_available():
try:
# Test CUDA availability with actual Linear layer operation
# This catches architecture-specific issues like gfx1151 incompatibility
test_tensor = torch.randn(2, 10).cuda()
test_linear = torch.nn.Linear(10, 5).cuda()
test_result = test_linear(test_tensor)
logger.info(f"GPU compatibility test passed: {torch.cuda.get_device_name(0)}")
self.device = torch.device("cuda")
logger.info("CUDA/ROCm device initialized successfully")
except Exception as e:
logger.warning(f"CUDA/ROCm initialization failed: {e}")
logger.warning("GPU architecture may not be supported - falling back to CPU")
logger.warning("This is common with newer AMD GPUs (gfx1151+) that require specific PyTorch builds")
self.device = torch.device("cpu")
else:
self.device = torch.device("cpu")
# Determine the device to use from config.yaml
self.device = self._get_device_from_config()
logger.info(f"Using device: {self.device}")
def _get_device_from_config(self) -> torch.device:
    """Select the torch device from the ``gpu`` section of the loaded config.

    Keys read from ``self.config._config['gpu']`` (all optional):

    * ``enabled`` (default ``True``) - when falsy, CPU is always used.
    * ``device`` (default ``'auto'``) - ``'cpu'``, ``'cuda'`` or ``'auto'``;
      any other value triggers plain auto-detection.
    * ``fallback_to_cpu`` (default ``True``) - whether CPU may be used when
      CUDA is unavailable or fails the compatibility probe.

    Returns:
        The ``torch.device`` to use.

    Raises:
        RuntimeError: if a GPU was requested (``'cuda'``/``'auto'``), it is
            unavailable or fails the probe, and ``fallback_to_cpu`` is false.
    """
    # Only *reading* the config is best-effort. The selection logic below is
    # deliberately outside this try block: otherwise the RuntimeError raised
    # for ``fallback_to_cpu: false`` would be swallowed by this handler and
    # the setting would silently have no effect.
    try:
        gpu_config = self.config._config.get('gpu', {})

        device_setting = gpu_config.get('device', 'auto')
        fallback_to_cpu = gpu_config.get('fallback_to_cpu', True)
        gpu_enabled = gpu_config.get('enabled', True)
    except Exception as e:
        logger.warning(f"Error reading device config: {e}, using auto-detection")
        # Fallback to auto-detection
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # If GPU is disabled in config, use CPU
    if not gpu_enabled:
        logger.info("GPU disabled in config.yaml, using CPU")
        return torch.device('cpu')

    # Handle device selection
    if device_setting == 'cpu':
        logger.info("Device set to CPU in config.yaml")
        return torch.device('cpu')

    if device_setting not in ('cuda', 'auto'):
        logger.warning(f"Unknown device setting '{device_setting}', using auto-detection")
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if not torch.cuda.is_available():
        if fallback_to_cpu:
            logger.warning("CUDA not available, falling back to CPU")
            return torch.device('cpu')
        raise RuntimeError("CUDA not available and fallback_to_cpu is False")

    # Try GPU first with compatibility test
    try:
        # Test CUDA availability with actual Linear layer operation
        # This catches architecture-specific issues like gfx1151 incompatibility
        test_tensor = torch.randn(2, 10).cuda()
        test_linear = torch.nn.Linear(10, 5).cuda()
        test_linear(test_tensor)
        logger.info(f"GPU compatibility test passed: {torch.cuda.get_device_name(0)}")
        logger.info("CUDA/ROCm device initialized successfully")
        return torch.device("cuda")
    except Exception as e:
        logger.warning(f"CUDA/ROCm initialization failed: {e}")
        logger.warning("GPU architecture may not be supported - falling back to CPU")
        logger.warning("This is common with newer AMD GPUs (gfx1151+) that require specific PyTorch builds")
        if fallback_to_cpu:
            return torch.device("cpu")
        # Chain the probe failure so the root cause stays visible to the caller.
        raise RuntimeError("CUDA not available and fallback_to_cpu is False") from e
# Canonical model name aliases to eliminate ambiguity across UI/DB/FS
# Canonical → accepted aliases (internal/legacy)
self.model_name_aliases: Dict[str, list] = {