wip old MISC fix
@@ -23,6 +23,7 @@ from pathlib import Path

import torch
import numpy as np
import pandas as pd

try:
    import pytz
@@ -393,10 +394,23 @@ class RealTrainingAdapter:

# Get training config
training_config = test_case.get('training_config', {})
timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d'])
# RESTORED: 200 candles per timeframe (memory leak fixed)
# With 5 timeframes * 200 candles = 1000 total positions
candles_per_timeframe = training_config.get('candles_per_timeframe', 200) # 200 candles per batch
# REQUIRED: All 3 timeframes (1m, 1h, 1d) with 600 candles each
timeframes = training_config.get('timeframes', ['1m', '1h', '1d'])
# REQUIRED: 600 candles per timeframe for transformer model
candles_per_timeframe = training_config.get('candles_per_timeframe', 600) # 600 candles per timeframe
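# Illustrative only (not part of this commit): a training_config matching the defaults
# read above. Key names come from the .get() calls; the exact schema is an assumption.
example_training_config = {
    'timeframes': ['1m', '1h', '1d'],        # required timeframes; '1s' is optional
    'candles_per_timeframe': 600,            # 600 candles per timeframe for the transformer
}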

# REQUIRED: 1m, 1h, 1d (all with 600 candles each)
# OPTIONAL: 1s (if available, include with 600 candles)
required_timeframes = ['1m', '1h', '1d']
optional_timeframes = ['1s'] # Include if available

# Ensure required timeframes are in the list
missing_tfs = [tf for tf in required_timeframes if tf not in timeframes]
if missing_tfs:
logger.warning(f" Missing required timeframes: {missing_tfs}, adding them...")
timeframes = list(set(timeframes + required_timeframes))

# Note: 1s is optional, don't add it if not present

# Determine secondary symbol based on primary symbol
# ETH/SOL -> BTC, BTC -> ETH
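# Sketch of the secondary-symbol rule described above (ETH/SOL -> BTC, BTC -> ETH).
# The helper name is illustrative; the actual selection happens inline in the adapter.
def _pick_secondary_symbol(primary: str) -> str:
    """Return the reference symbol paired with the primary trading symbol."""
    return 'ETH/USDT' if primary.startswith('BTC') else 'BTC/USDT'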
@@ -408,22 +422,30 @@ class RealTrainingAdapter:
logger.info(f" Candles per batch: {candles_per_timeframe}")

# Calculate time range based on candles needed
# For 600 candles at 1m = 600 minutes = 10 hours
# Use timeframe-specific windows for better efficiency
from datetime import timedelta

# Calculate time window for each timeframe to get 600 candles
# Calculate time window for each timeframe to get required candles
# Add 20% buffer to account for missing candles
buffer_multiplier = 1.2
time_windows = {
'1s': timedelta(seconds=candles_per_timeframe), # 600 seconds = 10 minutes
'1m': timedelta(minutes=candles_per_timeframe), # 600 minutes = 10 hours
'1h': timedelta(hours=candles_per_timeframe), # 600 hours = 25 days
'1d': timedelta(days=candles_per_timeframe) # 600 days = ~1.6 years
'1s': timedelta(seconds=int(candles_per_timeframe * buffer_multiplier)),
'1m': timedelta(minutes=int(candles_per_timeframe * buffer_multiplier)),
'1h': timedelta(hours=int(candles_per_timeframe * buffer_multiplier)),
'1d': timedelta(days=int(candles_per_timeframe * buffer_multiplier))
}
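# Equivalent generic form of the buffered lookback windows above (illustrative helper,
# not part of this commit); timeframe_seconds holds standard candle durations.
def _lookback_window(timeframe: str, n_candles: int, buffer_multiplier: float = 1.2) -> timedelta:
    timeframe_seconds = {'1s': 1, '1m': 60, '1h': 3600, '1d': 86400}
    return timedelta(seconds=int(n_candles * buffer_multiplier * timeframe_seconds[timeframe]))
# e.g. _lookback_window('1h', 600) covers 720 hours (30 days) instead of a bare 600 hours.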
|
||||
|
||||
# For historical queries, we want data BEFORE the timestamp
|
||||
# Use the largest window to ensure we have enough data for all timeframes
|
||||
max_window = max(time_windows.values())
|
||||
start_time = timestamp - max_window
|
||||
end_time = timestamp
|
||||
|
||||
# REQUIRED: All required timeframes must have exactly 600 candles (no tolerance for missing data)
|
||||
min_required_candles = candles_per_timeframe # Must have full 600 candles
|
||||
required_timeframes = ['1m', '1h', '1d'] # All 3 timeframes are mandatory
|
||||
optional_timeframes = ['1s'] # Include if available
|
||||
|
||||
# Fetch data for primary symbol (all timeframes) and secondary symbol (1m only)
|
||||
market_state = {
|
||||
'symbol': symbol,
|
||||
@@ -439,53 +461,126 @@ class RealTrainingAdapter:
|
||||
duckdb_storage = self.data_provider.duckdb_storage
|
||||
|
||||
# Fetch primary symbol data (all timeframes)
|
||||
# REQUIRED: Must fetch all 3 timeframes (1m, 1h, 1d) with 600 candles each
|
||||
logger.info(f" Fetching primary symbol data: {symbol}")
|
||||
logger.info(f" REQUIRED timeframes: {required_timeframes} (each with {candles_per_timeframe} candles)")
|
||||
|
||||
fetched_timeframes = {} # Track which timeframes we successfully fetched
|
||||
|
||||
for timeframe in timeframes:
|
||||
# Fetch required timeframes (1m, 1h, 1d) and optional 1s if present
|
||||
if timeframe not in required_timeframes and timeframe not in optional_timeframes:
|
||||
continue
|
||||
df = None
|
||||
limit = candles_per_timeframe # Always fetch 600 candles
|
||||
limit = candles_per_timeframe
|
||||
|
||||
# Use timeframe-specific window for better efficiency
|
||||
tf_window = time_windows.get(timeframe, max_window)
|
||||
tf_start_time = timestamp - tf_window
|
||||
tf_end_time = timestamp
|
||||
|
||||
# Try DuckDB storage first (has historical data)
|
||||
# Use 'before' direction to get data BEFORE the timestamp
|
||||
if duckdb_storage:
|
||||
try:
|
||||
df = duckdb_storage.get_ohlcv_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
start_time=tf_start_time,
|
||||
end_time=tf_end_time,
|
||||
limit=limit,
|
||||
direction='latest'
|
||||
direction='before' # Get data BEFORE timestamp for historical training
|
||||
)
|
||||
if df is not None and not df.empty:
|
||||
logger.debug(f" {timeframe}: {len(df)} candles from DuckDB (historical)")
|
||||
logger.debug(f" {timeframe}: {len(df)} candles from DuckDB (historical, before {timestamp})")
|
||||
except Exception as e:
|
||||
logger.debug(f" {timeframe}: DuckDB query failed: {e}")
|
||||
|
||||
# Fallback to data_provider (might have cached data)
|
||||
if df is None or df.empty:
|
||||
# If DuckDB doesn't have enough data, try API with proper time range
|
||||
if df is None or df.empty or len(df) < min_required_candles:
|
||||
try:
|
||||
# Try to fetch from API with historical time range
|
||||
logger.info(f" {timeframe}: DuckDB insufficient ({len(df) if df is not None else 0} candles), fetching from API for timestamp {timestamp}...")
|
||||
|
||||
# Fetch historical data from API for the specific time range
|
||||
api_df = self._fetch_historical_from_api(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
start_time=tf_start_time,
|
||||
end_time=tf_end_time,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
if api_df is not None and not api_df.empty:
|
||||
# Filter to data before timestamp (historical training needs data BEFORE the event)
|
||||
try:
|
||||
api_df = api_df[api_df.index <= tf_end_time]
|
||||
# Take the most recent candles up to limit
|
||||
api_df = api_df.tail(limit)
|
||||
|
||||
if len(api_df) >= min_required_candles:
|
||||
df = api_df
|
||||
logger.info(f" {timeframe}: {len(df)} candles from API (historical range: {tf_start_time} to {tf_end_time})")
|
||||
|
||||
# Store in DuckDB for future use
|
||||
if duckdb_storage:
|
||||
try:
|
||||
duckdb_storage.store_ohlcv_data(symbol, timeframe, df)
|
||||
logger.debug(f" {timeframe}: Stored {len(df)} candles in DuckDB for future use")
|
||||
except Exception as e:
|
||||
logger.debug(f" {timeframe}: Could not store in DuckDB: {e}")
|
||||
else:
|
||||
logger.warning(f" {timeframe}: API returned only {len(api_df)} candles after filtering (need {min_required_candles})")
|
||||
except Exception as e:
|
||||
logger.debug(f" {timeframe}: Could not filter API data: {e}")
|
||||
# Use as-is if filtering fails
|
||||
if len(api_df) >= min_required_candles:
|
||||
df = api_df
|
||||
logger.info(f" {timeframe}: {len(df)} candles from API (unfiltered)")
|
||||
else:
|
||||
logger.warning(f" {timeframe}: API fetch returned no data")
|
||||
except Exception as e:
|
||||
logger.warning(f" {timeframe}: API fetch failed: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
|
||||
# Fallback to replay method
|
||||
if df is None or df.empty or len(df) < min_required_candles:
|
||||
try:
|
||||
# Use get_historical_data_replay for time-specific data
|
||||
replay_data = self.data_provider.get_historical_data_replay(
|
||||
symbol=symbol,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
start_time=tf_start_time,
|
||||
end_time=tf_end_time,
|
||||
timeframes=[timeframe]
|
||||
)
|
||||
df = replay_data.get(timeframe)
|
||||
if df is not None and not df.empty:
|
||||
logger.debug(f" {timeframe}: {len(df)} candles from replay")
|
||||
replay_df = replay_data.get(timeframe)
|
||||
if replay_df is not None and not replay_df.empty and len(replay_df) >= min_required_candles:
|
||||
df = replay_df
|
||||
logger.info(f" {timeframe}: {len(df)} candles from replay")
|
||||
except Exception as e:
|
||||
logger.debug(f" {timeframe}: Replay failed: {e}")
|
||||
|
||||
# Last resort: get latest data (not ideal but better than nothing)
|
||||
if df is None or df.empty:
|
||||
logger.warning(f" {timeframe}: No historical data found, using latest data as fallback")
|
||||
df = self.data_provider.get_historical_data(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
limit=limit # Use calculated limit
|
||||
)
|
||||
|
||||
# Validate data quality before storing
|
||||
if df is not None and not df.empty:
|
||||
# Check minimum candle count
|
||||
if len(df) < min_required_candles:
|
||||
logger.warning(f" {symbol} {timeframe}: Only {len(df)} candles (need {min_required_candles}), skipping")
|
||||
continue
|
||||
|
||||
# Validate data quality - check for NaN values
|
||||
if df.isnull().any().any():
|
||||
logger.warning(f" {symbol} {timeframe}: Contains NaN values, cleaning...")
|
||||
df = df.dropna()
|
||||
if len(df) < min_required_candles:
|
||||
logger.warning(f" {symbol} {timeframe}: After cleaning, only {len(df)} candles, skipping")
|
||||
continue
|
||||
|
||||
# Ensure we have required columns
|
||||
required_cols = ['open', 'high', 'low', 'close', 'volume']
|
||||
if not all(col in df.columns for col in required_cols):
|
||||
logger.warning(f" {symbol} {timeframe}: Missing required columns, skipping")
|
||||
continue
|
||||
|
||||
# Convert to dict format
|
||||
market_state['timeframes'][timeframe] = {
|
||||
'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
@@ -495,67 +590,124 @@ class RealTrainingAdapter:
|
||||
'close': df['close'].tolist(),
|
||||
'volume': df['volume'].tolist()
|
||||
}
|
||||
logger.info(f" {symbol} {timeframe}: {len(df)} candles")
|
||||
fetched_timeframes[timeframe] = len(df)
|
||||
logger.info(f" {symbol} {timeframe}: {len(df)} candles [OK]")
|
||||
else:
|
||||
logger.warning(f" {symbol} {timeframe}: No data available")
|
||||
logger.warning(f" {symbol} {timeframe}: No quality data available (need {min_required_candles} candles)")
|
||||
|
||||
# Fetch secondary symbol data (1m timeframe only, 600 candles)
|
||||
# CRITICAL: Validate we have all required timeframes
|
||||
missing_required = [tf for tf in required_timeframes if tf not in fetched_timeframes]
|
||||
if missing_required:
|
||||
logger.error(f" FAILED: Missing required timeframes: {missing_required}")
|
||||
logger.error(f" Fetched: {list(fetched_timeframes.keys())}")
|
||||
logger.error(f" Cannot proceed without all required timeframes")
|
||||
return {} # Return empty dict to signal failure
|
||||
|
||||
# Fetch secondary symbol data (1m timeframe only)
|
||||
logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)")
|
||||
secondary_df = None
|
||||
|
||||
# Try DuckDB first
|
||||
# Use 1m-specific window
|
||||
tf_window = time_windows.get('1m', max_window)
|
||||
tf_start_time = timestamp - tf_window
|
||||
tf_end_time = timestamp
|
||||
|
||||
# Try DuckDB first with 'before' direction
|
||||
if duckdb_storage:
|
||||
try:
|
||||
secondary_df = duckdb_storage.get_ohlcv_data(
|
||||
symbol=secondary_symbol,
|
||||
timeframe='1m',
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
start_time=tf_start_time,
|
||||
end_time=tf_end_time,
|
||||
limit=candles_per_timeframe,
|
||||
direction='latest'
|
||||
direction='before' # Get data BEFORE timestamp
|
||||
)
|
||||
if secondary_df is not None and not secondary_df.empty:
|
||||
logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from DuckDB")
|
||||
except Exception as e:
|
||||
logger.debug(f" {secondary_symbol} 1m: DuckDB query failed: {e}")
|
||||
|
||||
# If DuckDB doesn't have enough, try API with historical time range
|
||||
if secondary_df is None or secondary_df.empty or len(secondary_df) < min_required_candles:
|
||||
try:
|
||||
logger.info(f" {secondary_symbol} 1m: DuckDB insufficient ({len(secondary_df) if secondary_df is not None else 0} candles), fetching from API for timestamp {timestamp}...")
|
||||
|
||||
# Fetch historical data from API for the specific time range
|
||||
api_df = self._fetch_historical_from_api(
|
||||
symbol=secondary_symbol,
|
||||
timeframe='1m',
|
||||
start_time=tf_start_time,
|
||||
end_time=tf_end_time,
|
||||
limit=candles_per_timeframe
|
||||
)
|
||||
|
||||
if api_df is not None and not api_df.empty:
|
||||
# Filter to data before timestamp
|
||||
try:
|
||||
api_df = api_df[api_df.index <= tf_end_time].tail(candles_per_timeframe)
|
||||
if len(api_df) >= min_required_candles:
|
||||
secondary_df = api_df
|
||||
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from API (historical range)")
|
||||
|
||||
# Store in DuckDB for future use
|
||||
if duckdb_storage:
|
||||
try:
|
||||
duckdb_storage.store_ohlcv_data(secondary_symbol, '1m', secondary_df)
|
||||
logger.debug(f" {secondary_symbol} 1m: Stored in DuckDB for future use")
|
||||
except Exception as e:
|
||||
logger.debug(f" {secondary_symbol} 1m: Could not store in DuckDB: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f" {secondary_symbol} 1m: Could not filter API data: {e}")
|
||||
if len(api_df) >= min_required_candles:
|
||||
secondary_df = api_df
|
||||
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from API (unfiltered)")
|
||||
else:
|
||||
logger.warning(f" {secondary_symbol} 1m: API fetch returned no data")
|
||||
except Exception as e:
|
||||
logger.warning(f" {secondary_symbol} 1m: API fetch failed: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
|
||||
# Fallback to replay
|
||||
if secondary_df is None or secondary_df.empty:
|
||||
if secondary_df is None or secondary_df.empty or len(secondary_df) < min_required_candles:
|
||||
try:
|
||||
replay_data = self.data_provider.get_historical_data_replay(
|
||||
symbol=secondary_symbol,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
start_time=tf_start_time,
|
||||
end_time=tf_end_time,
|
||||
timeframes=['1m']
|
||||
)
|
||||
secondary_df = replay_data.get('1m')
|
||||
if secondary_df is not None and not secondary_df.empty:
|
||||
logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from replay")
|
||||
replay_df = replay_data.get('1m')
|
||||
if replay_df is not None and not replay_df.empty and len(replay_df) >= min_required_candles:
|
||||
secondary_df = replay_df
|
||||
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from replay")
|
||||
except Exception as e:
|
||||
logger.debug(f" {secondary_symbol} 1m: Replay failed: {e}")
|
||||
|
||||
# Last resort: latest data
|
||||
if secondary_df is None or secondary_df.empty:
|
||||
logger.warning(f" {secondary_symbol} 1m: No historical data, using latest as fallback")
|
||||
secondary_df = self.data_provider.get_historical_data(
|
||||
symbol=secondary_symbol,
|
||||
timeframe='1m',
|
||||
limit=candles_per_timeframe
|
||||
)
|
||||
|
||||
# Store secondary symbol data
|
||||
# Validate and store secondary symbol data
|
||||
if secondary_df is not None and not secondary_df.empty:
|
||||
market_state['secondary_timeframes']['1m'] = {
|
||||
'timestamps': secondary_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': secondary_df['open'].tolist(),
|
||||
'high': secondary_df['high'].tolist(),
|
||||
'low': secondary_df['low'].tolist(),
|
||||
'close': secondary_df['close'].tolist(),
|
||||
'volume': secondary_df['volume'].tolist()
|
||||
}
|
||||
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles")
|
||||
if len(secondary_df) < min_required_candles:
|
||||
logger.warning(f" {secondary_symbol} 1m: Only {len(secondary_df)} candles (need {min_required_candles}), skipping")
|
||||
elif secondary_df.isnull().any().any():
|
||||
logger.warning(f" {secondary_symbol} 1m: Contains NaN values, skipping")
|
||||
elif not all(col in secondary_df.columns for col in ['open', 'high', 'low', 'close', 'volume']):
|
||||
logger.warning(f" {secondary_symbol} 1m: Missing required columns, skipping")
|
||||
else:
|
||||
# Store in the correct structure: secondary_timeframes[symbol][timeframe]
|
||||
if secondary_symbol not in market_state['secondary_timeframes']:
|
||||
market_state['secondary_timeframes'][secondary_symbol] = {}
|
||||
market_state['secondary_timeframes'][secondary_symbol]['1m'] = {
|
||||
'timestamps': secondary_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
|
||||
'open': secondary_df['open'].tolist(),
|
||||
'high': secondary_df['high'].tolist(),
|
||||
'low': secondary_df['low'].tolist(),
|
||||
'close': secondary_df['close'].tolist(),
|
||||
'volume': secondary_df['volume'].tolist()
|
||||
}
|
||||
logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles [OK]")
|
||||
else:
|
||||
logger.warning(f" {secondary_symbol} 1m: No data available")
|
||||
logger.warning(f" {secondary_symbol} 1m: No quality data available (need {min_required_candles} candles)")
|
||||
|
||||
# Verify we have data
|
||||
if market_state['timeframes']:
|
||||
@@ -1190,6 +1342,200 @@ class RealTrainingAdapter:
|
||||
state_size = agent.state_size if hasattr(agent, 'state_size') else 100
|
||||
return [0.0] * state_size
|
||||
|
||||
def _fetch_historical_from_api(self, symbol: str, timeframe: str, start_time: datetime, end_time: datetime, limit: int) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
Fetch historical OHLCV data from exchange APIs for a specific time range
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH/USDT')
|
||||
timeframe: Timeframe (e.g., '1m', '1h', '1d')
|
||||
start_time: Start timestamp (UTC)
|
||||
end_time: End timestamp (UTC)
|
||||
limit: Maximum number of candles to fetch
|
||||
|
||||
Returns:
|
||||
DataFrame with OHLCV data or None if fetch fails
|
||||
"""
|
||||
import pandas as pd
|
||||
import requests
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
try:
|
||||
# Handle 1s timeframe specially (not directly supported by most APIs)
|
||||
if timeframe == '1s':
|
||||
logger.debug(f"1s timeframe requested - will try to generate from ticks or skip")
|
||||
# For 1s, we might need to generate from tick data or skip
|
||||
# This is handled by the data provider's _generate_1s_candles_from_ticks
|
||||
# For now, return None and let the caller handle it
|
||||
return None
|
||||
|
||||
# Try Binance first (supports historical queries with startTime/endTime)
|
||||
try:
|
||||
binance_symbol = symbol.replace('/', '').upper()
|
||||
|
||||
# Convert timeframe for Binance
|
||||
timeframe_map = {
|
||||
'1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m',
|
||||
'1h': '1h', '4h': '4h', '1d': '1d'
|
||||
}
|
||||
binance_timeframe = timeframe_map.get(timeframe)
|
||||
|
||||
if not binance_timeframe:
|
||||
logger.warning(f"Binance doesn't support timeframe {timeframe}")
|
||||
return None
|
||||
|
||||
# Binance API klines endpoint with startTime and endTime
|
||||
url = "https://api.binance.com/api/v3/klines"
|
||||
|
||||
# Convert timestamps to milliseconds
|
||||
start_ms = int(start_time.timestamp() * 1000)
|
||||
end_ms = int(end_time.timestamp() * 1000)
|
||||
|
||||
# Binance max is 1000 per request, so paginate if needed
|
||||
all_data = []
|
||||
current_start = start_ms
|
||||
max_per_request = 1000
|
||||
max_requests = 10 # Safety limit
|
||||
request_count = 0
|
||||
|
||||
while current_start < end_ms and request_count < max_requests:
|
||||
params = {
|
||||
'symbol': binance_symbol,
|
||||
'interval': binance_timeframe,
|
||||
'startTime': current_start,
|
||||
'endTime': end_ms,
|
||||
'limit': min(max_per_request, limit - len(all_data))
|
||||
}
|
||||
|
||||
logger.debug(f"Fetching from Binance: {symbol} {timeframe} batch {request_count + 1} (start: {current_start}, end: {end_ms})")
|
||||
response = requests.get(url, params=params, timeout=10)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data:
|
||||
all_data.extend(data)
|
||||
# Update current_start to the last candle's close_time + 1ms
|
||||
if len(data) > 0:
|
||||
last_close_time = data[-1][6] # close_time is at index 6
|
||||
current_start = last_close_time + 1
|
||||
else:
|
||||
break
|
||||
|
||||
# If we got less than requested, we've reached the end
|
||||
if len(data) < max_per_request:
|
||||
break
|
||||
|
||||
# If we have enough data, stop
|
||||
if len(all_data) >= limit:
|
||||
break
|
||||
else:
|
||||
break
|
||||
else:
|
||||
logger.debug(f"Binance API returned {response.status_code} for {symbol} {timeframe}")
|
||||
break
|
||||
|
||||
request_count += 1
|
||||
# Small delay to avoid rate limiting
|
||||
time.sleep(0.1)
|
||||
|
||||
if all_data:
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(all_data, columns=[
|
||||
'timestamp', 'open', 'high', 'low', 'close', 'volume',
|
||||
'close_time', 'quote_volume', 'trades', 'taker_buy_base',
|
||||
'taker_buy_quote', 'ignore'
|
||||
])
|
||||
|
||||
# Process columns
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
|
||||
for col in ['open', 'high', 'low', 'close', 'volume']:
|
||||
df[col] = df[col].astype(float)
|
||||
|
||||
# Keep only OHLCV columns and set timestamp as index
|
||||
df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
|
||||
df = df.set_index('timestamp')
|
||||
df = df.sort_index()
|
||||
|
||||
# Remove duplicates and take last 'limit' candles
|
||||
df = df[~df.index.duplicated(keep='last')]
|
||||
df = df.tail(limit)
|
||||
|
||||
logger.info(f"Binance API: Fetched {len(df)} candles for {symbol} {timeframe} (historical, {request_count} requests)")
|
||||
return df
|
||||
else:
|
||||
logger.warning(f"Binance API returned no data for {symbol} {timeframe}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Binance fetch failed: {e}")
|
||||
|
||||
# Fallback to MEXC
|
||||
try:
|
||||
mexc_symbol = symbol.replace('/', '').upper()
|
||||
|
||||
timeframe_map = {
|
||||
'1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m',
|
||||
'1h': '1h', '4h': '4h', '1d': '1d'
|
||||
}
|
||||
mexc_timeframe = timeframe_map.get(timeframe)
|
||||
|
||||
if not mexc_timeframe:
|
||||
logger.warning(f"MEXC doesn't support timeframe {timeframe}")
|
||||
return None
|
||||
|
||||
# MEXC API klines endpoint (may not support startTime/endTime, so fetch latest and filter)
|
||||
url = "https://api.mexc.com/api/v3/klines"
|
||||
params = {
|
||||
'symbol': mexc_symbol,
|
||||
'interval': mexc_timeframe,
|
||||
'limit': min(limit * 2, 1000) # Fetch more to account for filtering
|
||||
}
|
||||
|
||||
logger.debug(f"Fetching from MEXC: {symbol} {timeframe}")
|
||||
response = requests.get(url, params=params, timeout=10)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
|
||||
if data:
|
||||
df = pd.DataFrame(data, columns=[
|
||||
'timestamp', 'open', 'high', 'low', 'close', 'volume',
|
||||
'close_time', 'quote_volume'
|
||||
])
|
||||
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
|
||||
for col in ['open', 'high', 'low', 'close', 'volume']:
|
||||
df[col] = df[col].astype(float)
|
||||
|
||||
df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
|
||||
df = df.set_index('timestamp')
|
||||
df = df.sort_index()
|
||||
|
||||
# Filter to time range
|
||||
df = df[(df.index >= start_time) & (df.index <= end_time)]
|
||||
df = df.tail(limit)
|
||||
|
||||
if len(df) > 0:
|
||||
logger.info(f"MEXC API: Fetched {len(df)} candles for {symbol} {timeframe} (historical, filtered)")
|
||||
return df
|
||||
else:
|
||||
logger.warning(f"MEXC API: No candles in time range for {symbol} {timeframe}")
|
||||
else:
|
||||
logger.warning(f"MEXC API returned empty data for {symbol} {timeframe}")
|
||||
else:
|
||||
logger.debug(f"MEXC API returned {response.status_code} for {symbol} {timeframe}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"MEXC fetch failed: {e}")
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching historical data from API: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
return None
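# Illustrative call (not part of this commit): fetch 600 one-minute ETH/USDT candles ending
# at an annotation timestamp, using the same 20% lookback buffer as the caller. 'adapter'
# stands for a RealTrainingAdapter instance.
from datetime import datetime, timedelta, timezone

end = datetime(2024, 1, 15, 12, 0, tzinfo=timezone.utc)   # hypothetical annotation timestamp
start = end - timedelta(minutes=int(600 * 1.2))
df = adapter._fetch_historical_from_api('ETH/USDT', '1m', start_time=start, end_time=end, limit=600)
if df is not None and not df.empty:
    assert list(df.columns) == ['open', 'high', 'low', 'close', 'volume']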
|
||||
|
||||
def _extract_timeframe_data(self, tf_data: Dict, target_seq_len: int = 600) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Extract and normalize OHLCV data from a single timeframe
|
||||
@@ -1215,28 +1561,22 @@ class RealTrainingAdapter:
|
||||
if len(closes) == 0:
|
||||
return None
|
||||
|
||||
# Take last target_seq_len candles or pad if needed
|
||||
if len(closes) >= target_seq_len:
|
||||
# Truncate to target length
|
||||
opens = opens[-target_seq_len:]
|
||||
highs = highs[-target_seq_len:]
|
||||
lows = lows[-target_seq_len:]
|
||||
closes = closes[-target_seq_len:]
|
||||
volumes = volumes[-target_seq_len:]
|
||||
else:
|
||||
# Pad with last candle
|
||||
pad_len = target_seq_len - len(closes)
|
||||
last_open = opens[-1] if len(opens) > 0 else 0.0
|
||||
last_high = highs[-1] if len(highs) > 0 else 0.0
|
||||
last_low = lows[-1] if len(lows) > 0 else 0.0
|
||||
last_close = closes[-1] if len(closes) > 0 else 0.0
|
||||
last_volume = volumes[-1] if len(volumes) > 0 else 0.0
|
||||
# REQUIRED: Must have exactly target_seq_len (600) candles, no padding allowed
|
||||
if len(closes) < target_seq_len:
|
||||
logger.warning(f"Insufficient candles: {len(closes)} < {target_seq_len} (required)")
|
||||
return None
|
||||
|
||||
opens = np.pad(opens, (0, pad_len), constant_values=last_open)
|
||||
highs = np.pad(highs, (0, pad_len), constant_values=last_high)
|
||||
lows = np.pad(lows, (0, pad_len), constant_values=last_low)
|
||||
closes = np.pad(closes, (0, pad_len), constant_values=last_close)
|
||||
volumes = np.pad(volumes, (0, pad_len), constant_values=last_volume)
|
||||
# Take last target_seq_len candles (exactly 600)
|
||||
opens = opens[-target_seq_len:]
|
||||
highs = highs[-target_seq_len:]
|
||||
lows = lows[-target_seq_len:]
|
||||
closes = closes[-target_seq_len:]
|
||||
volumes = volumes[-target_seq_len:]
|
||||
|
||||
# Validate we have exactly target_seq_len
|
||||
if len(closes) != target_seq_len:
|
||||
logger.warning(f"Extraction failed: got {len(closes)} candles, need {target_seq_len}")
|
||||
return None
|
||||
|
||||
# Stack OHLCV [seq_len, 5]
|
||||
ohlcv = np.stack([opens, highs, lows, closes, volumes], axis=-1)
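# Sketch of the remainder of this helper (an assumption; the normalization step is outside
# this hunk): the [600, 5] OHLCV array is min/max scaled and returned as a [1, 600, 5]
# tensor together with the parameters needed to de-normalize predictions later.
price_min, price_max = ohlcv[:, :4].min(), ohlcv[:, :4].max()
vol_min, vol_max = ohlcv[:, 4].min(), ohlcv[:, 4].max()
norm = ohlcv.copy()
norm[:, :4] = (norm[:, :4] - price_min) / max(price_max - price_min, 1e-8)
norm[:, 4] = (norm[:, 4] - vol_min) / max(vol_max - vol_min, 1e-8)
tensor = torch.tensor(norm, dtype=torch.float32).unsqueeze(0)   # shape [1, 600, 5]
norm_params = {'price_min': price_min, 'price_max': price_max, 'vol_min': vol_min, 'vol_max': vol_max}
# return tensor, norm_params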
|
||||
@@ -1369,45 +1709,62 @@ class RealTrainingAdapter:
|
||||
timeframes = market_state.get('timeframes', {})
|
||||
secondary_timeframes = market_state.get('secondary_timeframes', {})
|
||||
|
||||
# Target sequence length - RESTORED to 200 (memory leak fixed)
|
||||
# With 5 timeframes * 200 candles = 1000 sequence positions
|
||||
# Memory management fixes allow full sequence length
|
||||
target_seq_len = 200 # Restored to original
|
||||
for tf_data in timeframes.values():
|
||||
if tf_data and 'close' in tf_data and len(tf_data['close']) > 0:
|
||||
target_seq_len = min(len(tf_data['close']), 200) # Cap at 200
|
||||
break
|
||||
# REQUIRED: 600 candles per timeframe for transformer model
|
||||
target_seq_len = 600 # Must be 600 candles for each timeframe
|
||||
# Validate we have enough data in required timeframes (1m, 1h, 1d)
|
||||
required_tfs = ['1m', '1h', '1d']
|
||||
for tf_name in required_tfs:
|
||||
if tf_name in timeframes:
|
||||
tf_data = timeframes[tf_name]
|
||||
if tf_data and 'close' in tf_data:
|
||||
if len(tf_data['close']) < 600:
|
||||
logger.warning(f"Required timeframe {tf_name} has only {len(tf_data['close'])} candles (need 600)")
|
||||
return None
|
||||
|
||||
# Validate optional 1s timeframe if present (must have 600 candles if included)
|
||||
if '1s' in timeframes:
|
||||
tf_data = timeframes['1s']
|
||||
if tf_data and 'close' in tf_data:
|
||||
if len(tf_data['close']) < 600:
|
||||
logger.warning(f"Optional timeframe 1s has only {len(tf_data['close'])} candles (need 600), excluding it")
|
||||
# Remove 1s from timeframes if insufficient
|
||||
timeframes = {k: v for k, v in timeframes.items() if k != '1s'}
|
||||
|
||||
# Extract each timeframe (returns tuple: (tensor, norm_params) or None)
|
||||
# Store normalization parameters for each timeframe
|
||||
norm_params_dict = {}
|
||||
|
||||
# OPTIONAL: Extract 1s timeframe if available (must have 600 candles if included)
|
||||
result_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None
|
||||
if result_1s:
|
||||
price_data_1s, norm_params_dict['1s'] = result_1s
|
||||
logger.debug(f"Included optional 1s timeframe with {result_1s[0].shape[1]} candles")
|
||||
else:
|
||||
# Don't fail on missing 1s data, it's often unavailable in annotations
|
||||
# 1s is optional - don't fail if missing, but log it
|
||||
price_data_1s = None
|
||||
logger.debug("Optional 1s timeframe not available (this is OK)")
|
||||
|
||||
# REQUIRED: Extract all 3 timeframes (1m, 1h, 1d) with exactly 600 candles each
|
||||
result_1m = self._extract_timeframe_data(timeframes.get('1m', {}), target_seq_len) if '1m' in timeframes else None
|
||||
if result_1m:
|
||||
price_data_1m, norm_params_dict['1m'] = result_1m
|
||||
else:
|
||||
# Warning: 1m data is critical
|
||||
logger.warning(f"Missing 1m data for transformer batch (sample: {training_sample.get('test_case_id')})")
|
||||
logger.warning(f"Missing or insufficient 1m data for transformer batch (sample: {training_sample.get('test_case_id')})")
|
||||
return None
|
||||
|
||||
result_1h = self._extract_timeframe_data(timeframes.get('1h', {}), target_seq_len) if '1h' in timeframes else None
|
||||
if result_1h:
|
||||
price_data_1h, norm_params_dict['1h'] = result_1h
|
||||
else:
|
||||
price_data_1h = None
|
||||
logger.warning(f"Missing or insufficient 1h data for transformer batch (sample: {training_sample.get('test_case_id')})")
|
||||
return None
|
||||
|
||||
result_1d = self._extract_timeframe_data(timeframes.get('1d', {}), target_seq_len) if '1d' in timeframes else None
|
||||
if result_1d:
|
||||
price_data_1d, norm_params_dict['1d'] = result_1d
|
||||
else:
|
||||
price_data_1d = None
|
||||
logger.warning(f"Missing or insufficient 1d data for transformer batch (sample: {training_sample.get('test_case_id')})")
|
||||
return None
|
||||
|
||||
# Extract BTC reference data
|
||||
btc_data_1m = None
|
||||
@@ -1416,12 +1773,41 @@ class RealTrainingAdapter:
|
||||
if result_btc:
|
||||
btc_data_1m, norm_params_dict['btc'] = result_btc
|
||||
|
||||
# Ensure at least one timeframe is available
|
||||
# Check if all are None (can't use any() with tensors)
|
||||
if price_data_1s is None and price_data_1m is None and price_data_1h is None and price_data_1d is None:
|
||||
logger.warning("No price data available in any timeframe")
|
||||
# CRITICAL: Ensure ALL required timeframes are available (1m, 1h, 1d)
|
||||
# REQUIRED: 1m, 1h, 1d (each with 600 candles)
|
||||
# OPTIONAL: 1s (if available, include with 600 candles)
|
||||
required_timeframes_present = (
|
||||
price_data_1m is not None and
|
||||
price_data_1h is not None and
|
||||
price_data_1d is not None
|
||||
)
|
||||
|
||||
if not required_timeframes_present:
|
||||
missing = []
|
||||
if price_data_1m is None:
|
||||
missing.append('1m')
|
||||
if price_data_1h is None:
|
||||
missing.append('1h')
|
||||
if price_data_1d is None:
|
||||
missing.append('1d')
|
||||
logger.warning(f"Missing required timeframes: {missing}. Need all 3: 1m, 1h, 1d")
|
||||
return None
|
||||
|
||||
# Validate each required timeframe has correct shape
|
||||
for tf_name, tf_data in [('1m', price_data_1m), ('1h', price_data_1h), ('1d', price_data_1d)]:
|
||||
if tf_data is not None:
|
||||
shape = tf_data.shape
|
||||
if len(shape) != 3 or shape[1] < 600:
|
||||
logger.warning(f"Timeframe {tf_name} has invalid shape {shape} (need [1, 600, 5])")
|
||||
return None
|
||||
|
||||
# Validate optional 1s timeframe if present
|
||||
if price_data_1s is not None:
|
||||
shape = price_data_1s.shape
|
||||
if len(shape) != 3 or shape[1] < 600:
|
||||
logger.warning(f"Optional timeframe 1s has invalid shape {shape}, removing it")
|
||||
price_data_1s = None
|
||||
|
||||
# Get reference timeframe for other features (prefer 1m, fallback to any available)
|
||||
ref_data = price_data_1m if price_data_1m is not None else (
|
||||
price_data_1h if price_data_1h is not None else (
|
||||
@@ -1928,6 +2314,46 @@ class RealTrainingAdapter:
|
||||
for i, data in enumerate(training_data):
|
||||
batch = self._convert_annotation_to_transformer_batch(data)
|
||||
if batch is not None:
|
||||
# CRITICAL: Validate that ALL required timeframes are present
|
||||
# REQUIRED: 1m, 1h, 1d (each with 600 candles)
|
||||
# OPTIONAL: 1s (if available, include with 600 candles)
|
||||
required_tf_keys = ['price_data_1m', 'price_data_1h', 'price_data_1d']
|
||||
optional_tf_keys = ['price_data_1s']
|
||||
|
||||
missing_tfs = [tf for tf in required_tf_keys if batch.get(tf) is None]
|
||||
|
||||
if missing_tfs:
|
||||
logger.warning(f" Skipping sample {i+1}: Missing required timeframes: {missing_tfs}")
|
||||
continue
|
||||
|
||||
# Validate each required timeframe has correct shape [1, 600, 5]
|
||||
for tf_key in required_tf_keys:
|
||||
tf_data = batch.get(tf_key)
|
||||
if tf_data is not None:
|
||||
if not isinstance(tf_data, torch.Tensor):
|
||||
logger.warning(f" Skipping sample {i+1}: {tf_key} is not a tensor")
|
||||
missing_tfs.append(tf_key)
|
||||
break
|
||||
shape = tf_data.shape
|
||||
if len(shape) != 3 or shape[1] < 600: # Must be [batch, seq_len, features] with seq_len >= 600
|
||||
logger.warning(f" Skipping sample {i+1}: {tf_key} has invalid shape {shape} (need [1, 600, 5])")
|
||||
missing_tfs.append(tf_key)
|
||||
break
|
||||
|
||||
if missing_tfs:
|
||||
continue
|
||||
|
||||
# Validate optional 1s timeframe if present
|
||||
if batch.get('price_data_1s') is not None:
|
||||
tf_data = batch.get('price_data_1s')
|
||||
if isinstance(tf_data, torch.Tensor):
|
||||
shape = tf_data.shape
|
||||
if len(shape) != 3 or shape[1] < 600:
|
||||
logger.warning(f" Sample {i+1}: price_data_1s has invalid shape {shape}, removing it")
|
||||
batch['price_data_1s'] = None
|
||||
|
||||
logger.debug(f" Sample {i+1}: All required timeframes present (1m, 1h, 1d), 1s={'present' if batch.get('price_data_1s') is not None else 'not available'}")
|
||||
|
||||
# Move batch to GPU immediately with pinned memory for faster transfer
|
||||
if use_gpu:
|
||||
batch_gpu = {}
|
||||
|
||||
@@ -832,7 +832,7 @@ class AnnotationDashboard:
|
||||
# Check if the specific model is already initialized
|
||||
if model_name == 'Transformer':
|
||||
logger.info("Checking Transformer model...")
|
||||
if self.orchestrator.primary_transformer:
|
||||
if hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer:
|
||||
self.loaded_models['Transformer'] = self.orchestrator.primary_transformer
|
||||
logger.info("Transformer model loaded successfully")
|
||||
else:
|
||||
@@ -841,7 +841,7 @@ class AnnotationDashboard:
|
||||
|
||||
elif model_name == 'CNN':
|
||||
logger.info("Checking CNN model...")
|
||||
if self.orchestrator.cnn_model:
|
||||
if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model:
|
||||
self.loaded_models['CNN'] = self.orchestrator.cnn_model
|
||||
logger.info("CNN model loaded successfully")
|
||||
else:
|
||||
|
||||
@@ -270,9 +270,9 @@ class DQNAgent:
|
||||
self.batch_size = batch_size
|
||||
self.target_update = target_update
|
||||
|
||||
# Set device for computation (default to GPU if available)
|
||||
# Set device for computation (read from config.yaml if available)
|
||||
if device is None:
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.device = self._get_device_from_config()
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
@@ -282,10 +282,6 @@ class DQNAgent:
|
||||
self.policy_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device)
|
||||
self.target_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device)
|
||||
|
||||
# Ensure models are on the correct device
|
||||
self.policy_net = self.policy_net.to(self.device)
|
||||
self.target_net = self.target_net.to(self.device)
|
||||
|
||||
# Initialize the target network with the same weights as the policy network
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
|
||||
@@ -317,13 +313,92 @@ class DQNAgent:
|
||||
|
||||
# Market regime adaptation weights
self.market_regime_weights = {
'trending': 1.0,
'sideways': 0.8,
'volatile': 1.2,
'bullish': 1.1,
'bearish': 1.1
'trending': 1.2, # Higher confidence in trending markets
'ranging': 0.8, # Lower confidence in ranging markets
'volatile': 0.6 # Much lower confidence in volatile markets
}
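# Illustrative use of these weights (how they are applied is not shown in this diff):
# a raw model confidence is scaled by the detected market regime.
def weight_confidence(confidence: float, regime: str, regime_weights: dict) -> float:
    return confidence * regime_weights.get(regime, 1.0)
# e.g. 0.50 becomes 0.60 in a trending market and 0.30 in a volatile one.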
|
||||
|
||||
# Additional initialization
|
||||
self.recent_actions = deque(maxlen=10)
|
||||
self.recent_prices = deque(maxlen=20)
|
||||
self.recent_rewards = deque(maxlen=100)
|
||||
|
||||
# Price direction tracking
|
||||
self.last_price_direction = {
|
||||
'direction': 0.0,
|
||||
'confidence': 0.0
|
||||
}
|
||||
|
||||
self.price_movement_memory = []
|
||||
self.losses = []
|
||||
self.no_improvement_count = 0
|
||||
self.confidence_history = []
|
||||
self.avg_confidence = 0.0
|
||||
self.max_confidence = 0.0
|
||||
self.min_confidence = 1.0
|
||||
|
||||
# Enhanced training features
|
||||
self.use_dueling = True
|
||||
self.use_prioritized_replay = priority_memory
|
||||
self.alpha = 0.6
|
||||
self.beta = 0.4
|
||||
self.beta_increment = 0.001
|
||||
self.use_double_dqn = True
|
||||
self.target_update_freq = target_update
|
||||
self.training_steps = 0
|
||||
self.gradient_clip_norm = 1.0
|
||||
self.epsilon_history = []
|
||||
self.td_errors = []
|
||||
|
||||
# Trade settings
|
||||
self.trade_action_fee = 0.0005
|
||||
self.minimum_action_confidence = 0.3
|
||||
|
||||
# Violent move detection
|
||||
self.price_history = []
|
||||
self.volatility_window = 20
|
||||
self.volatility_threshold = 0.0015
|
||||
self.post_violent_move = False
|
||||
self.violent_move_cooldown = 0
|
||||
|
||||
# Feature integration
|
||||
self.last_hidden_features = None
|
||||
self.feature_history = []
|
||||
self.realtime_tick_features = None
|
||||
self.tick_feature_weight = 0.3
|
||||
|
||||
# Mixed precision training
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
self.use_mixed_precision = True
self.scaler = torch.amp.GradScaler('cuda')
logger.info("Mixed precision training enabled")
else:
self.use_mixed_precision = False
logger.info("Mixed precision training disabled")
|
||||
|
||||
self.training = True
|
||||
|
||||
# Compatibility
|
||||
self.state_size = np.prod(state_shape)
|
||||
self.action_size = n_actions
|
||||
self.memory_size = buffer_size
|
||||
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3]
|
||||
|
||||
logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
|
||||
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
|
||||
logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")
|
||||
|
||||
total_params = sum(p.numel() for p in self.policy_net.parameters())
|
||||
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
|
||||
|
||||
# Position management
|
||||
self.current_position = 0.0
|
||||
self.position_entry_price = 0.0
|
||||
self.position_entry_time = None
|
||||
self.entry_confidence_threshold = 0.35
|
||||
self.exit_confidence_threshold = 0.15
|
||||
self.uncertainty_threshold = 0.1
|
||||
|
||||
# Load best checkpoint if available
|
||||
if self.enable_checkpoints:
|
||||
self.load_best_checkpoint()
|
||||
@@ -332,113 +407,46 @@ class DQNAgent:
|
||||
if enable_checkpoints:
|
||||
logger.info(f"Model name: {model_name}, Checkpoint frequency: {self.checkpoint_frequency}")
|
||||
|
||||
# Add this line to the __init__ method
|
||||
self.recent_actions = deque(maxlen=10)
|
||||
self.recent_prices = deque(maxlen=20)
|
||||
self.recent_rewards = deque(maxlen=100)
|
||||
def _get_device_from_config(self) -> torch.device:
|
||||
"""Get device from config.yaml or auto-detect"""
|
||||
try:
|
||||
# Try to load config
|
||||
from core.config import get_config
|
||||
config = get_config()
|
||||
gpu_config = config._config.get('gpu', {})
|
||||
|
||||
# Price direction tracking - stores direction and confidence
|
||||
self.last_price_direction = {
|
||||
'direction': 0.0, # Single value between -1 and 1
|
||||
'confidence': 0.0 # Single value between 0 and 1
|
||||
}
|
||||
device_setting = gpu_config.get('device', 'auto')
|
||||
fallback_to_cpu = gpu_config.get('fallback_to_cpu', True)
|
||||
gpu_enabled = gpu_config.get('enabled', True)
|
||||
|
||||
# Store separate memory for price direction examples
|
||||
self.price_movement_memory = [] # For storing examples of clear price movements
|
||||
# If GPU is disabled in config, use CPU
|
||||
if not gpu_enabled:
|
||||
logger.info("GPU disabled in config.yaml, using CPU")
|
||||
return torch.device('cpu')
|
||||
|
||||
# Performance tracking
|
||||
self.losses = []
|
||||
self.no_improvement_count = 0
|
||||
# Handle device selection
|
||||
if device_setting == 'cpu':
|
||||
logger.info("Device set to CPU in config.yaml")
|
||||
return torch.device('cpu')
|
||||
elif device_setting == 'cuda' or device_setting == 'auto':
|
||||
# Try GPU first
|
||||
if torch.cuda.is_available():
|
||||
logger.info("Using GPU (CUDA available)")
|
||||
return torch.device('cuda')
|
||||
else:
|
||||
if fallback_to_cpu:
|
||||
logger.warning("CUDA not available, falling back to CPU")
|
||||
return torch.device('cpu')
|
||||
else:
|
||||
raise RuntimeError("CUDA not available and fallback_to_cpu is False")
|
||||
else:
|
||||
logger.warning(f"Unknown device setting '{device_setting}', using auto-detection")
|
||||
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Confidence tracking
|
||||
self.confidence_history = []
|
||||
self.avg_confidence = 0.0
|
||||
self.max_confidence = 0.0
|
||||
self.min_confidence = 1.0
|
||||
|
||||
# Enhanced features from EnhancedDQNAgent
|
||||
# Market adaptation capabilities
|
||||
self.market_regime_weights = {
|
||||
'trending': 1.2, # Higher confidence in trending markets
|
||||
'ranging': 0.8, # Lower confidence in ranging markets
|
||||
'volatile': 0.6 # Much lower confidence in volatile markets
|
||||
}
|
||||
|
||||
# Dueling network support (requires enhanced network architecture)
|
||||
self.use_dueling = True
|
||||
|
||||
# Prioritized experience replay parameters
|
||||
self.use_prioritized_replay = priority_memory
|
||||
self.alpha = 0.6 # Priority exponent
|
||||
self.beta = 0.4 # Importance sampling exponent
|
||||
self.beta_increment = 0.001
|
||||
|
||||
# Double DQN support
|
||||
self.use_double_dqn = True
|
||||
|
||||
# Enhanced training features from EnhancedDQNAgent
|
||||
self.target_update_freq = target_update # More descriptive name
|
||||
self.training_steps = 0
|
||||
self.gradient_clip_norm = 1.0 # Gradient clipping
|
||||
|
||||
# Enhanced statistics tracking
|
||||
self.epsilon_history = []
|
||||
self.td_errors = [] # Track TD errors for analysis
|
||||
|
||||
# Trade action fee and confidence thresholds
|
||||
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
|
||||
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
|
||||
|
||||
# Violent move detection
|
||||
self.price_history = []
|
||||
self.volatility_window = 20 # Window size for volatility calculation
|
||||
self.volatility_threshold = 0.0015 # Threshold for considering a move "violent"
|
||||
self.post_violent_move = False # Flag for recent violent move
|
||||
self.violent_move_cooldown = 0 # Cooldown after violent move
|
||||
|
||||
# Feature integration
|
||||
self.last_hidden_features = None # Store last extracted features
|
||||
self.feature_history = [] # Store history of features for analysis
|
||||
|
||||
# Real-time tick features integration
|
||||
self.realtime_tick_features = None # Latest tick features from tick processor
|
||||
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
|
||||
|
||||
# Check if mixed precision training should be used
|
||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
|
||||
self.use_mixed_precision = True
|
||||
self.scaler = torch.amp.GradScaler('cuda')
|
||||
logger.info("Mixed precision training enabled")
|
||||
else:
|
||||
self.use_mixed_precision = False
|
||||
logger.info("Mixed precision training disabled")
|
||||
|
||||
# Track if we're in training mode
|
||||
self.training = True
|
||||
|
||||
# For compatibility with old code
|
||||
self.state_size = np.prod(state_shape)
|
||||
self.action_size = n_actions
|
||||
self.memory_size = buffer_size
|
||||
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] # Default timeframes
|
||||
|
||||
logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
|
||||
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
|
||||
logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")
|
||||
|
||||
# Log model parameters
|
||||
total_params = sum(p.numel() for p in self.policy_net.parameters())
|
||||
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
|
||||
|
||||
# Position management for 2-action system
|
||||
self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long)
|
||||
self.position_entry_price = 0.0
|
||||
self.position_entry_time = None
|
||||
|
||||
# Different thresholds for entry vs exit decisions - AGGRESSIVE for more training data
|
||||
self.entry_confidence_threshold = 0.35 # Lower threshold for new positions (was 0.7)
|
||||
self.exit_confidence_threshold = 0.15 # Very low threshold for closing positions (was 0.3)
|
||||
self.uncertainty_threshold = 0.1 # When to stay neutral
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading device config: {e}, using auto-detection")
|
||||
# Fallback to auto-detection
|
||||
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
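# The gpu section this method expects in config.yaml, shown as the equivalent Python dict
# (inferred from the .get() calls above; the exact schema is an assumption):
expected_gpu_config = {
    'enabled': True,            # False forces CPU regardless of CUDA availability
    'device': 'auto',           # one of 'auto', 'cuda', 'cpu'
    'fallback_to_cpu': True,    # if False, raise when CUDA is requested but unavailable
}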
|
||||
|
||||
def load_best_checkpoint(self):
|
||||
"""Load the best checkpoint for this DQN agent"""
|
||||
@@ -513,103 +521,6 @@ class DQNAgent:
|
||||
logger.error(f"Error saving DQN checkpoint: {e}")
|
||||
return False
|
||||
|
||||
# Store separate memory for price direction examples
|
||||
self.price_movement_memory = [] # For storing examples of clear price movements
|
||||
|
||||
# Performance tracking
|
||||
self.losses = []
|
||||
self.no_improvement_count = 0
|
||||
|
||||
# Confidence tracking
|
||||
self.confidence_history = []
|
||||
self.avg_confidence = 0.0
|
||||
self.max_confidence = 0.0
|
||||
self.min_confidence = 1.0
|
||||
|
||||
# Enhanced features from EnhancedDQNAgent
|
||||
# Market adaptation capabilities
|
||||
self.market_regime_weights = {
|
||||
'trending': 1.2, # Higher confidence in trending markets
|
||||
'ranging': 0.8, # Lower confidence in ranging markets
|
||||
'volatile': 0.6 # Much lower confidence in volatile markets
|
||||
}
|
||||
|
||||
# Dueling network support (requires enhanced network architecture)
|
||||
self.use_dueling = True
|
||||
|
||||
# Prioritized experience replay parameters
|
||||
self.use_prioritized_replay = priority_memory
|
||||
self.alpha = 0.6 # Priority exponent
|
||||
self.beta = 0.4 # Importance sampling exponent
|
||||
self.beta_increment = 0.001
|
||||
|
||||
# Double DQN support
|
||||
self.use_double_dqn = True
|
||||
|
||||
# Enhanced training features from EnhancedDQNAgent
|
||||
self.target_update_freq = target_update # More descriptive name
|
||||
self.training_steps = 0
|
||||
self.gradient_clip_norm = 1.0 # Gradient clipping
|
||||
|
||||
# Enhanced statistics tracking
|
||||
self.epsilon_history = []
|
||||
self.td_errors = [] # Track TD errors for analysis
|
||||
|
||||
# Trade action fee and confidence thresholds
|
||||
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
|
||||
self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5)
|
||||
|
||||
# Violent move detection
|
||||
self.price_history = []
|
||||
self.volatility_window = 20 # Window size for volatility calculation
|
||||
self.volatility_threshold = 0.0015 # Threshold for considering a move "violent"
|
||||
self.post_violent_move = False # Flag for recent violent move
|
||||
self.violent_move_cooldown = 0 # Cooldown after violent move
|
||||
|
||||
# Feature integration
|
||||
self.last_hidden_features = None # Store last extracted features
|
||||
self.feature_history = [] # Store history of features for analysis
|
||||
|
||||
# Real-time tick features integration
|
||||
self.realtime_tick_features = None # Latest tick features from tick processor
|
||||
self.tick_feature_weight = 0.3 # Weight for tick features in decision making
|
||||
|
||||
# Check if mixed precision training should be used
|
||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
|
||||
self.use_mixed_precision = True
|
||||
self.scaler = torch.amp.GradScaler('cuda')
|
||||
logger.info("Mixed precision training enabled")
|
||||
else:
|
||||
self.use_mixed_precision = False
|
||||
logger.info("Mixed precision training disabled")
|
||||
|
||||
# Track if we're in training mode
|
||||
self.training = True
|
||||
|
||||
# For compatibility with old code
|
||||
self.state_size = np.prod(state_shape)
|
||||
self.action_size = n_actions
|
||||
self.memory_size = buffer_size
|
||||
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] # Default timeframes
|
||||
|
||||
logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}")
|
||||
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
|
||||
logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}")
|
||||
|
||||
# Log model parameters
|
||||
total_params = sum(p.numel() for p in self.policy_net.parameters())
|
||||
logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters")
|
||||
|
||||
# Position management for 2-action system
|
||||
self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long)
|
||||
self.position_entry_price = 0.0
|
||||
self.position_entry_time = None
|
||||
|
||||
# Different thresholds for entry vs exit decisions - AGGRESSIVE for more training data
|
||||
self.entry_confidence_threshold = 0.35 # Lower threshold for new positions (was 0.7)
|
||||
self.exit_confidence_threshold = 0.15 # Very low threshold for closing positions (was 0.3)
|
||||
self.uncertainty_threshold = 0.1 # When to stay neutral
|
||||
|
||||
def move_models_to_device(self, device=None):
|
||||
"""Move models to the specified device (GPU/CPU)"""
|
||||
if device is not None:
|
||||
|
||||
@@ -181,3 +181,7 @@ This complements the earlier fixes:
|
||||
📊 **EXPECTED** - Better trend predictions after training
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -318,28 +318,63 @@ class TradingOrchestrator:
|
||||
# Initialize confidence threshold
|
||||
self.confidence_threshold = self.config.get('confidence_threshold', 0.6)
|
||||
|
||||
# Determine the device to use (GPU if available, else CPU)
|
||||
# Initialize device - force CPU mode to avoid CUDA errors
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
# Test CUDA availability with actual Linear layer operation
|
||||
# This catches architecture-specific issues like gfx1151 incompatibility
|
||||
test_tensor = torch.randn(2, 10).cuda()
|
||||
test_linear = torch.nn.Linear(10, 5).cuda()
|
||||
test_result = test_linear(test_tensor)
|
||||
logger.info(f"GPU compatibility test passed: {torch.cuda.get_device_name(0)}")
|
||||
self.device = torch.device("cuda")
|
||||
logger.info("CUDA/ROCm device initialized successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"CUDA/ROCm initialization failed: {e}")
|
||||
logger.warning("GPU architecture may not be supported - falling back to CPU")
|
||||
logger.warning("This is common with newer AMD GPUs (gfx1151+) that require specific PyTorch builds")
|
||||
self.device = torch.device("cpu")
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
# Determine the device to use from config.yaml
|
||||
self.device = self._get_device_from_config()
|
||||
logger.info(f"Using device: {self.device}")
|
||||
|
||||
def _get_device_from_config(self) -> torch.device:
|
||||
"""Get device from config.yaml or auto-detect"""
|
||||
try:
|
||||
gpu_config = self.config._config.get('gpu', {})
|
||||
|
||||
device_setting = gpu_config.get('device', 'auto')
|
||||
fallback_to_cpu = gpu_config.get('fallback_to_cpu', True)
|
||||
gpu_enabled = gpu_config.get('enabled', True)
|
||||
|
||||
# If GPU is disabled in config, use CPU
|
||||
if not gpu_enabled:
|
||||
logger.info("GPU disabled in config.yaml, using CPU")
|
||||
return torch.device('cpu')
|
||||
|
||||
# Handle device selection
|
||||
if device_setting == 'cpu':
|
||||
logger.info("Device set to CPU in config.yaml")
|
||||
return torch.device('cpu')
|
||||
elif device_setting == 'cuda' or device_setting == 'auto':
|
||||
# Try GPU first with compatibility test
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
# Test CUDA availability with actual Linear layer operation
|
||||
# This catches architecture-specific issues like gfx1151 incompatibility
|
||||
test_tensor = torch.randn(2, 10).cuda()
|
||||
test_linear = torch.nn.Linear(10, 5).cuda()
|
||||
test_result = test_linear(test_tensor)
|
||||
logger.info(f"GPU compatibility test passed: {torch.cuda.get_device_name(0)}")
|
||||
logger.info("CUDA/ROCm device initialized successfully")
|
||||
return torch.device("cuda")
|
||||
except Exception as e:
|
||||
logger.warning(f"CUDA/ROCm initialization failed: {e}")
|
||||
logger.warning("GPU architecture may not be supported - falling back to CPU")
|
||||
logger.warning("This is common with newer AMD GPUs (gfx1151+) that require specific PyTorch builds")
|
||||
if fallback_to_cpu:
|
||||
return torch.device("cpu")
|
||||
else:
|
||||
raise RuntimeError("CUDA not available and fallback_to_cpu is False")
|
||||
else:
|
||||
if fallback_to_cpu:
|
||||
logger.warning("CUDA not available, falling back to CPU")
|
||||
return torch.device('cpu')
|
||||
else:
|
||||
raise RuntimeError("CUDA not available and fallback_to_cpu is False")
|
||||
else:
|
||||
logger.warning(f"Unknown device setting '{device_setting}', using auto-detection")
|
||||
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading device config: {e}, using auto-detection")
|
||||
# Fallback to auto-detection
|
||||
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
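# Standalone version of the compatibility probe used above (illustrative helper; the name
# and placement are assumptions, not part of this commit).
def cuda_smoke_test() -> bool:
    """Run a tiny Linear forward pass on the GPU to catch unsupported architectures."""
    if not torch.cuda.is_available():
        return False
    try:
        x = torch.randn(2, 10, device='cuda')
        layer = torch.nn.Linear(10, 5).to('cuda')
        _ = layer(x)
        torch.cuda.synchronize()
        return True
    except Exception:
        return False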

# Canonical model name aliases to eliminate ambiguity across UI/DB/FS
# Canonical → accepted aliases (internal/legacy)
self.model_name_aliases: Dict[str, list] = {