diff --git a/ANNOTATE/core/real_training_adapter.py b/ANNOTATE/core/real_training_adapter.py index 0dc5750..a95eef0 100644 --- a/ANNOTATE/core/real_training_adapter.py +++ b/ANNOTATE/core/real_training_adapter.py @@ -23,6 +23,7 @@ from pathlib import Path import torch import numpy as np +import pandas as pd try: import pytz @@ -393,10 +394,23 @@ class RealTrainingAdapter: # Get training config training_config = test_case.get('training_config', {}) - timeframes = training_config.get('timeframes', ['1s', '1m', '1h', '1d']) - # RESTORED: 200 candles per timeframe (memory leak fixed) - # With 5 timeframes * 200 candles = 1000 total positions - candles_per_timeframe = training_config.get('candles_per_timeframe', 200) # 200 candles per batch + # REQUIRED: All 3 timeframes (1m, 1h, 1d) with 600 candles each + timeframes = training_config.get('timeframes', ['1m', '1h', '1d']) + # REQUIRED: 600 candles per timeframe for transformer model + candles_per_timeframe = training_config.get('candles_per_timeframe', 600) # 600 candles per timeframe + + # REQUIRED: 1m, 1h, 1d (all with 600 candles each) + # OPTIONAL: 1s (if available, include with 600 candles) + required_timeframes = ['1m', '1h', '1d'] + optional_timeframes = ['1s'] # Include if available + + # Ensure required timeframes are in the list + missing_tfs = [tf for tf in required_timeframes if tf not in timeframes] + if missing_tfs: + logger.warning(f" Missing required timeframes: {missing_tfs}, adding them...") + timeframes = list(set(timeframes + required_timeframes)) + + # Note: 1s is optional, don't add it if not present # Determine secondary symbol based on primary symbol # ETH/SOL -> BTC, BTC -> ETH @@ -408,22 +422,30 @@ class RealTrainingAdapter: logger.info(f" Candles per batch: {candles_per_timeframe}") # Calculate time range based on candles needed - # For 600 candles at 1m = 600 minutes = 10 hours + # Use timeframe-specific windows for better efficiency from datetime import timedelta - # Calculate time window for each timeframe to get 600 candles + # Calculate time window for each timeframe to get required candles + # Add 20% buffer to account for missing candles + buffer_multiplier = 1.2 time_windows = { - '1s': timedelta(seconds=candles_per_timeframe), # 600 seconds = 10 minutes - '1m': timedelta(minutes=candles_per_timeframe), # 600 minutes = 10 hours - '1h': timedelta(hours=candles_per_timeframe), # 600 hours = 25 days - '1d': timedelta(days=candles_per_timeframe) # 600 days = ~1.6 years + '1s': timedelta(seconds=int(candles_per_timeframe * buffer_multiplier)), + '1m': timedelta(minutes=int(candles_per_timeframe * buffer_multiplier)), + '1h': timedelta(hours=int(candles_per_timeframe * buffer_multiplier)), + '1d': timedelta(days=int(candles_per_timeframe * buffer_multiplier)) } + # For historical queries, we want data BEFORE the timestamp # Use the largest window to ensure we have enough data for all timeframes max_window = max(time_windows.values()) start_time = timestamp - max_window end_time = timestamp + # REQUIRED: All required timeframes must have exactly 600 candles (no tolerance for missing data) + min_required_candles = candles_per_timeframe # Must have full 600 candles + required_timeframes = ['1m', '1h', '1d'] # All 3 timeframes are mandatory + optional_timeframes = ['1s'] # Include if available + # Fetch data for primary symbol (all timeframes) and secondary symbol (1m only) market_state = { 'symbol': symbol, @@ -439,53 +461,126 @@ class RealTrainingAdapter: duckdb_storage = self.data_provider.duckdb_storage # Fetch primary symbol data (all timeframes) + # REQUIRED: Must fetch all 3 timeframes (1m, 1h, 1d) with 600 candles each logger.info(f" Fetching primary symbol data: {symbol}") + logger.info(f" REQUIRED timeframes: {required_timeframes} (each with {candles_per_timeframe} candles)") + + fetched_timeframes = {} # Track which timeframes we successfully fetched + for timeframe in timeframes: + # Fetch required timeframes (1m, 1h, 1d) and optional 1s if present + if timeframe not in required_timeframes and timeframe not in optional_timeframes: + continue df = None - limit = candles_per_timeframe # Always fetch 600 candles + limit = candles_per_timeframe + + # Use timeframe-specific window for better efficiency + tf_window = time_windows.get(timeframe, max_window) + tf_start_time = timestamp - tf_window + tf_end_time = timestamp # Try DuckDB storage first (has historical data) + # Use 'before' direction to get data BEFORE the timestamp if duckdb_storage: try: df = duckdb_storage.get_ohlcv_data( symbol=symbol, timeframe=timeframe, - start_time=start_time, - end_time=end_time, + start_time=tf_start_time, + end_time=tf_end_time, limit=limit, - direction='latest' + direction='before' # Get data BEFORE timestamp for historical training ) if df is not None and not df.empty: - logger.debug(f" {timeframe}: {len(df)} candles from DuckDB (historical)") + logger.debug(f" {timeframe}: {len(df)} candles from DuckDB (historical, before {timestamp})") except Exception as e: logger.debug(f" {timeframe}: DuckDB query failed: {e}") - # Fallback to data_provider (might have cached data) - if df is None or df.empty: + # If DuckDB doesn't have enough data, try API with proper time range + if df is None or df.empty or len(df) < min_required_candles: + try: + # Try to fetch from API with historical time range + logger.info(f" {timeframe}: DuckDB insufficient ({len(df) if df is not None else 0} candles), fetching from API for timestamp {timestamp}...") + + # Fetch historical data from API for the specific time range + api_df = self._fetch_historical_from_api( + symbol=symbol, + timeframe=timeframe, + start_time=tf_start_time, + end_time=tf_end_time, + limit=limit + ) + + if api_df is not None and not api_df.empty: + # Filter to data before timestamp (historical training needs data BEFORE the event) + try: + api_df = api_df[api_df.index <= tf_end_time] + # Take the most recent candles up to limit + api_df = api_df.tail(limit) + + if len(api_df) >= min_required_candles: + df = api_df + logger.info(f" {timeframe}: {len(df)} candles from API (historical range: {tf_start_time} to {tf_end_time})") + + # Store in DuckDB for future use + if duckdb_storage: + try: + duckdb_storage.store_ohlcv_data(symbol, timeframe, df) + logger.debug(f" {timeframe}: Stored {len(df)} candles in DuckDB for future use") + except Exception as e: + logger.debug(f" {timeframe}: Could not store in DuckDB: {e}") + else: + logger.warning(f" {timeframe}: API returned only {len(api_df)} candles after filtering (need {min_required_candles})") + except Exception as e: + logger.debug(f" {timeframe}: Could not filter API data: {e}") + # Use as-is if filtering fails + if len(api_df) >= min_required_candles: + df = api_df + logger.info(f" {timeframe}: {len(df)} candles from API (unfiltered)") + else: + logger.warning(f" {timeframe}: API fetch returned no data") + except Exception as e: + logger.warning(f" {timeframe}: API fetch failed: {e}") + import traceback + logger.debug(traceback.format_exc()) + + # Fallback to replay method + if df is None or df.empty or len(df) < min_required_candles: try: - # Use get_historical_data_replay for time-specific data replay_data = self.data_provider.get_historical_data_replay( symbol=symbol, - start_time=start_time, - end_time=end_time, + start_time=tf_start_time, + end_time=tf_end_time, timeframes=[timeframe] ) - df = replay_data.get(timeframe) - if df is not None and not df.empty: - logger.debug(f" {timeframe}: {len(df)} candles from replay") + replay_df = replay_data.get(timeframe) + if replay_df is not None and not replay_df.empty and len(replay_df) >= min_required_candles: + df = replay_df + logger.info(f" {timeframe}: {len(df)} candles from replay") except Exception as e: logger.debug(f" {timeframe}: Replay failed: {e}") - # Last resort: get latest data (not ideal but better than nothing) - if df is None or df.empty: - logger.warning(f" {timeframe}: No historical data found, using latest data as fallback") - df = self.data_provider.get_historical_data( - symbol=symbol, - timeframe=timeframe, - limit=limit # Use calculated limit - ) - + # Validate data quality before storing if df is not None and not df.empty: + # Check minimum candle count + if len(df) < min_required_candles: + logger.warning(f" {symbol} {timeframe}: Only {len(df)} candles (need {min_required_candles}), skipping") + continue + + # Validate data quality - check for NaN values + if df.isnull().any().any(): + logger.warning(f" {symbol} {timeframe}: Contains NaN values, cleaning...") + df = df.dropna() + if len(df) < min_required_candles: + logger.warning(f" {symbol} {timeframe}: After cleaning, only {len(df)} candles, skipping") + continue + + # Ensure we have required columns + required_cols = ['open', 'high', 'low', 'close', 'volume'] + if not all(col in df.columns for col in required_cols): + logger.warning(f" {symbol} {timeframe}: Missing required columns, skipping") + continue + # Convert to dict format market_state['timeframes'][timeframe] = { 'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(), @@ -495,67 +590,124 @@ class RealTrainingAdapter: 'close': df['close'].tolist(), 'volume': df['volume'].tolist() } - logger.info(f" {symbol} {timeframe}: {len(df)} candles") + fetched_timeframes[timeframe] = len(df) + logger.info(f" {symbol} {timeframe}: {len(df)} candles [OK]") else: - logger.warning(f" {symbol} {timeframe}: No data available") + logger.warning(f" {symbol} {timeframe}: No quality data available (need {min_required_candles} candles)") - # Fetch secondary symbol data (1m timeframe only, 600 candles) + # CRITICAL: Validate we have all required timeframes + missing_required = [tf for tf in required_timeframes if tf not in fetched_timeframes] + if missing_required: + logger.error(f" FAILED: Missing required timeframes: {missing_required}") + logger.error(f" Fetched: {list(fetched_timeframes.keys())}") + logger.error(f" Cannot proceed without all required timeframes") + return {} # Return empty dict to signal failure + + # Fetch secondary symbol data (1m timeframe only) logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)") secondary_df = None - # Try DuckDB first + # Use 1m-specific window + tf_window = time_windows.get('1m', max_window) + tf_start_time = timestamp - tf_window + tf_end_time = timestamp + + # Try DuckDB first with 'before' direction if duckdb_storage: try: secondary_df = duckdb_storage.get_ohlcv_data( symbol=secondary_symbol, timeframe='1m', - start_time=start_time, - end_time=end_time, + start_time=tf_start_time, + end_time=tf_end_time, limit=candles_per_timeframe, - direction='latest' + direction='before' # Get data BEFORE timestamp ) if secondary_df is not None and not secondary_df.empty: logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from DuckDB") except Exception as e: logger.debug(f" {secondary_symbol} 1m: DuckDB query failed: {e}") + # If DuckDB doesn't have enough, try API with historical time range + if secondary_df is None or secondary_df.empty or len(secondary_df) < min_required_candles: + try: + logger.info(f" {secondary_symbol} 1m: DuckDB insufficient ({len(secondary_df) if secondary_df is not None else 0} candles), fetching from API for timestamp {timestamp}...") + + # Fetch historical data from API for the specific time range + api_df = self._fetch_historical_from_api( + symbol=secondary_symbol, + timeframe='1m', + start_time=tf_start_time, + end_time=tf_end_time, + limit=candles_per_timeframe + ) + + if api_df is not None and not api_df.empty: + # Filter to data before timestamp + try: + api_df = api_df[api_df.index <= tf_end_time].tail(candles_per_timeframe) + if len(api_df) >= min_required_candles: + secondary_df = api_df + logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from API (historical range)") + + # Store in DuckDB for future use + if duckdb_storage: + try: + duckdb_storage.store_ohlcv_data(secondary_symbol, '1m', secondary_df) + logger.debug(f" {secondary_symbol} 1m: Stored in DuckDB for future use") + except Exception as e: + logger.debug(f" {secondary_symbol} 1m: Could not store in DuckDB: {e}") + except Exception as e: + logger.debug(f" {secondary_symbol} 1m: Could not filter API data: {e}") + if len(api_df) >= min_required_candles: + secondary_df = api_df + logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from API (unfiltered)") + else: + logger.warning(f" {secondary_symbol} 1m: API fetch returned no data") + except Exception as e: + logger.warning(f" {secondary_symbol} 1m: API fetch failed: {e}") + import traceback + logger.debug(traceback.format_exc()) + # Fallback to replay - if secondary_df is None or secondary_df.empty: + if secondary_df is None or secondary_df.empty or len(secondary_df) < min_required_candles: try: replay_data = self.data_provider.get_historical_data_replay( symbol=secondary_symbol, - start_time=start_time, - end_time=end_time, + start_time=tf_start_time, + end_time=tf_end_time, timeframes=['1m'] ) - secondary_df = replay_data.get('1m') - if secondary_df is not None and not secondary_df.empty: - logger.debug(f" {secondary_symbol} 1m: {len(secondary_df)} candles from replay") + replay_df = replay_data.get('1m') + if replay_df is not None and not replay_df.empty and len(replay_df) >= min_required_candles: + secondary_df = replay_df + logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles from replay") except Exception as e: logger.debug(f" {secondary_symbol} 1m: Replay failed: {e}") - # Last resort: latest data - if secondary_df is None or secondary_df.empty: - logger.warning(f" {secondary_symbol} 1m: No historical data, using latest as fallback") - secondary_df = self.data_provider.get_historical_data( - symbol=secondary_symbol, - timeframe='1m', - limit=candles_per_timeframe - ) - - # Store secondary symbol data + # Validate and store secondary symbol data if secondary_df is not None and not secondary_df.empty: - market_state['secondary_timeframes']['1m'] = { - 'timestamps': secondary_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(), - 'open': secondary_df['open'].tolist(), - 'high': secondary_df['high'].tolist(), - 'low': secondary_df['low'].tolist(), - 'close': secondary_df['close'].tolist(), - 'volume': secondary_df['volume'].tolist() - } - logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles") + if len(secondary_df) < min_required_candles: + logger.warning(f" {secondary_symbol} 1m: Only {len(secondary_df)} candles (need {min_required_candles}), skipping") + elif secondary_df.isnull().any().any(): + logger.warning(f" {secondary_symbol} 1m: Contains NaN values, skipping") + elif not all(col in secondary_df.columns for col in ['open', 'high', 'low', 'close', 'volume']): + logger.warning(f" {secondary_symbol} 1m: Missing required columns, skipping") + else: + # Store in the correct structure: secondary_timeframes[symbol][timeframe] + if secondary_symbol not in market_state['secondary_timeframes']: + market_state['secondary_timeframes'][secondary_symbol] = {} + market_state['secondary_timeframes'][secondary_symbol]['1m'] = { + 'timestamps': secondary_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(), + 'open': secondary_df['open'].tolist(), + 'high': secondary_df['high'].tolist(), + 'low': secondary_df['low'].tolist(), + 'close': secondary_df['close'].tolist(), + 'volume': secondary_df['volume'].tolist() + } + logger.info(f" {secondary_symbol} 1m: {len(secondary_df)} candles [OK]") else: - logger.warning(f" {secondary_symbol} 1m: No data available") + logger.warning(f" {secondary_symbol} 1m: No quality data available (need {min_required_candles} candles)") # Verify we have data if market_state['timeframes']: @@ -1190,6 +1342,200 @@ class RealTrainingAdapter: state_size = agent.state_size if hasattr(agent, 'state_size') else 100 return [0.0] * state_size + def _fetch_historical_from_api(self, symbol: str, timeframe: str, start_time: datetime, end_time: datetime, limit: int) -> Optional[pd.DataFrame]: + """ + Fetch historical OHLCV data from exchange APIs for a specific time range + + Args: + symbol: Trading symbol (e.g., 'ETH/USDT') + timeframe: Timeframe (e.g., '1m', '1h', '1d') + start_time: Start timestamp (UTC) + end_time: End timestamp (UTC) + limit: Maximum number of candles to fetch + + Returns: + DataFrame with OHLCV data or None if fetch fails + """ + import pandas as pd + import requests + import time + from datetime import datetime, timezone + + try: + # Handle 1s timeframe specially (not directly supported by most APIs) + if timeframe == '1s': + logger.debug(f"1s timeframe requested - will try to generate from ticks or skip") + # For 1s, we might need to generate from tick data or skip + # This is handled by the data provider's _generate_1s_candles_from_ticks + # For now, return None and let the caller handle it + return None + + # Try Binance first (supports historical queries with startTime/endTime) + try: + binance_symbol = symbol.replace('/', '').upper() + + # Convert timeframe for Binance + timeframe_map = { + '1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m', + '1h': '1h', '4h': '4h', '1d': '1d' + } + binance_timeframe = timeframe_map.get(timeframe) + + if not binance_timeframe: + logger.warning(f"Binance doesn't support timeframe {timeframe}") + return None + + # Binance API klines endpoint with startTime and endTime + url = "https://api.binance.com/api/v3/klines" + + # Convert timestamps to milliseconds + start_ms = int(start_time.timestamp() * 1000) + end_ms = int(end_time.timestamp() * 1000) + + # Binance max is 1000 per request, so paginate if needed + all_data = [] + current_start = start_ms + max_per_request = 1000 + max_requests = 10 # Safety limit + request_count = 0 + + while current_start < end_ms and request_count < max_requests: + params = { + 'symbol': binance_symbol, + 'interval': binance_timeframe, + 'startTime': current_start, + 'endTime': end_ms, + 'limit': min(max_per_request, limit - len(all_data)) + } + + logger.debug(f"Fetching from Binance: {symbol} {timeframe} batch {request_count + 1} (start: {current_start}, end: {end_ms})") + response = requests.get(url, params=params, timeout=10) + + if response.status_code == 200: + data = response.json() + if data: + all_data.extend(data) + # Update current_start to the last candle's close_time + 1ms + if len(data) > 0: + last_close_time = data[-1][6] # close_time is at index 6 + current_start = last_close_time + 1 + else: + break + + # If we got less than requested, we've reached the end + if len(data) < max_per_request: + break + + # If we have enough data, stop + if len(all_data) >= limit: + break + else: + break + else: + logger.debug(f"Binance API returned {response.status_code} for {symbol} {timeframe}") + break + + request_count += 1 + # Small delay to avoid rate limiting + time.sleep(0.1) + + if all_data: + # Convert to DataFrame + df = pd.DataFrame(all_data, columns=[ + 'timestamp', 'open', 'high', 'low', 'close', 'volume', + 'close_time', 'quote_volume', 'trades', 'taker_buy_base', + 'taker_buy_quote', 'ignore' + ]) + + # Process columns + df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True) + for col in ['open', 'high', 'low', 'close', 'volume']: + df[col] = df[col].astype(float) + + # Keep only OHLCV columns and set timestamp as index + df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']] + df = df.set_index('timestamp') + df = df.sort_index() + + # Remove duplicates and take last 'limit' candles + df = df[~df.index.duplicated(keep='last')] + df = df.tail(limit) + + logger.info(f"Binance API: Fetched {len(df)} candles for {symbol} {timeframe} (historical, {request_count} requests)") + return df + else: + logger.warning(f"Binance API returned no data for {symbol} {timeframe}") + + except Exception as e: + logger.debug(f"Binance fetch failed: {e}") + + # Fallback to MEXC + try: + mexc_symbol = symbol.replace('/', '').upper() + + timeframe_map = { + '1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m', + '1h': '1h', '4h': '4h', '1d': '1d' + } + mexc_timeframe = timeframe_map.get(timeframe) + + if not mexc_timeframe: + logger.warning(f"MEXC doesn't support timeframe {timeframe}") + return None + + # MEXC API klines endpoint (may not support startTime/endTime, so fetch latest and filter) + url = "https://api.mexc.com/api/v3/klines" + params = { + 'symbol': mexc_symbol, + 'interval': mexc_timeframe, + 'limit': min(limit * 2, 1000) # Fetch more to account for filtering + } + + logger.debug(f"Fetching from MEXC: {symbol} {timeframe}") + response = requests.get(url, params=params, timeout=10) + + if response.status_code == 200: + data = response.json() + + if data: + df = pd.DataFrame(data, columns=[ + 'timestamp', 'open', 'high', 'low', 'close', 'volume', + 'close_time', 'quote_volume' + ]) + + df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True) + for col in ['open', 'high', 'low', 'close', 'volume']: + df[col] = df[col].astype(float) + + df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']] + df = df.set_index('timestamp') + df = df.sort_index() + + # Filter to time range + df = df[(df.index >= start_time) & (df.index <= end_time)] + df = df.tail(limit) + + if len(df) > 0: + logger.info(f"MEXC API: Fetched {len(df)} candles for {symbol} {timeframe} (historical, filtered)") + return df + else: + logger.warning(f"MEXC API: No candles in time range for {symbol} {timeframe}") + else: + logger.warning(f"MEXC API returned empty data for {symbol} {timeframe}") + else: + logger.debug(f"MEXC API returned {response.status_code} for {symbol} {timeframe}") + + except Exception as e: + logger.debug(f"MEXC fetch failed: {e}") + + return None + + except Exception as e: + logger.error(f"Error fetching historical data from API: {e}") + import traceback + logger.debug(traceback.format_exc()) + return None + def _extract_timeframe_data(self, tf_data: Dict, target_seq_len: int = 600) -> Optional[torch.Tensor]: """ Extract and normalize OHLCV data from a single timeframe @@ -1215,28 +1561,22 @@ class RealTrainingAdapter: if len(closes) == 0: return None - # Take last target_seq_len candles or pad if needed - if len(closes) >= target_seq_len: - # Truncate to target length - opens = opens[-target_seq_len:] - highs = highs[-target_seq_len:] - lows = lows[-target_seq_len:] - closes = closes[-target_seq_len:] - volumes = volumes[-target_seq_len:] - else: - # Pad with last candle - pad_len = target_seq_len - len(closes) - last_open = opens[-1] if len(opens) > 0 else 0.0 - last_high = highs[-1] if len(highs) > 0 else 0.0 - last_low = lows[-1] if len(lows) > 0 else 0.0 - last_close = closes[-1] if len(closes) > 0 else 0.0 - last_volume = volumes[-1] if len(volumes) > 0 else 0.0 - - opens = np.pad(opens, (0, pad_len), constant_values=last_open) - highs = np.pad(highs, (0, pad_len), constant_values=last_high) - lows = np.pad(lows, (0, pad_len), constant_values=last_low) - closes = np.pad(closes, (0, pad_len), constant_values=last_close) - volumes = np.pad(volumes, (0, pad_len), constant_values=last_volume) + # REQUIRED: Must have exactly target_seq_len (600) candles, no padding allowed + if len(closes) < target_seq_len: + logger.warning(f"Insufficient candles: {len(closes)} < {target_seq_len} (required)") + return None + + # Take last target_seq_len candles (exactly 600) + opens = opens[-target_seq_len:] + highs = highs[-target_seq_len:] + lows = lows[-target_seq_len:] + closes = closes[-target_seq_len:] + volumes = volumes[-target_seq_len:] + + # Validate we have exactly target_seq_len + if len(closes) != target_seq_len: + logger.warning(f"Extraction failed: got {len(closes)} candles, need {target_seq_len}") + return None # Stack OHLCV [seq_len, 5] ohlcv = np.stack([opens, highs, lows, closes, volumes], axis=-1) @@ -1369,45 +1709,62 @@ class RealTrainingAdapter: timeframes = market_state.get('timeframes', {}) secondary_timeframes = market_state.get('secondary_timeframes', {}) - # Target sequence length - RESTORED to 200 (memory leak fixed) - # With 5 timeframes * 200 candles = 1000 sequence positions - # Memory management fixes allow full sequence length - target_seq_len = 200 # Restored to original - for tf_data in timeframes.values(): - if tf_data and 'close' in tf_data and len(tf_data['close']) > 0: - target_seq_len = min(len(tf_data['close']), 200) # Cap at 200 - break + # REQUIRED: 600 candles per timeframe for transformer model + target_seq_len = 600 # Must be 600 candles for each timeframe + # Validate we have enough data in required timeframes (1m, 1h, 1d) + required_tfs = ['1m', '1h', '1d'] + for tf_name in required_tfs: + if tf_name in timeframes: + tf_data = timeframes[tf_name] + if tf_data and 'close' in tf_data: + if len(tf_data['close']) < 600: + logger.warning(f"Required timeframe {tf_name} has only {len(tf_data['close'])} candles (need 600)") + return None + + # Validate optional 1s timeframe if present (must have 600 candles if included) + if '1s' in timeframes: + tf_data = timeframes['1s'] + if tf_data and 'close' in tf_data: + if len(tf_data['close']) < 600: + logger.warning(f"Optional timeframe 1s has only {len(tf_data['close'])} candles (need 600), excluding it") + # Remove 1s from timeframes if insufficient + timeframes = {k: v for k, v in timeframes.items() if k != '1s'} # Extract each timeframe (returns tuple: (tensor, norm_params) or None) # Store normalization parameters for each timeframe norm_params_dict = {} + # OPTIONAL: Extract 1s timeframe if available (must have 600 candles if included) result_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None if result_1s: price_data_1s, norm_params_dict['1s'] = result_1s + logger.debug(f"Included optional 1s timeframe with {result_1s[0].shape[1]} candles") else: - # Don't fail on missing 1s data, it's often unavailable in annotations + # 1s is optional - don't fail if missing, but log it price_data_1s = None + logger.debug("Optional 1s timeframe not available (this is OK)") + # REQUIRED: Extract all 3 timeframes (1m, 1h, 1d) with exactly 600 candles each result_1m = self._extract_timeframe_data(timeframes.get('1m', {}), target_seq_len) if '1m' in timeframes else None if result_1m: price_data_1m, norm_params_dict['1m'] = result_1m else: - # Warning: 1m data is critical - logger.warning(f"Missing 1m data for transformer batch (sample: {training_sample.get('test_case_id')})") + logger.warning(f"Missing or insufficient 1m data for transformer batch (sample: {training_sample.get('test_case_id')})") return None result_1h = self._extract_timeframe_data(timeframes.get('1h', {}), target_seq_len) if '1h' in timeframes else None if result_1h: price_data_1h, norm_params_dict['1h'] = result_1h else: - price_data_1h = None + logger.warning(f"Missing or insufficient 1h data for transformer batch (sample: {training_sample.get('test_case_id')})") + return None result_1d = self._extract_timeframe_data(timeframes.get('1d', {}), target_seq_len) if '1d' in timeframes else None if result_1d: price_data_1d, norm_params_dict['1d'] = result_1d else: - price_data_1d = None + logger.warning(f"Missing or insufficient 1d data for transformer batch (sample: {training_sample.get('test_case_id')})") + return None # Extract BTC reference data btc_data_1m = None @@ -1416,12 +1773,41 @@ class RealTrainingAdapter: if result_btc: btc_data_1m, norm_params_dict['btc'] = result_btc - # Ensure at least one timeframe is available - # Check if all are None (can't use any() with tensors) - if price_data_1s is None and price_data_1m is None and price_data_1h is None and price_data_1d is None: - logger.warning("No price data available in any timeframe") + # CRITICAL: Ensure ALL required timeframes are available (1m, 1h, 1d) + # REQUIRED: 1m, 1h, 1d (each with 600 candles) + # OPTIONAL: 1s (if available, include with 600 candles) + required_timeframes_present = ( + price_data_1m is not None and + price_data_1h is not None and + price_data_1d is not None + ) + + if not required_timeframes_present: + missing = [] + if price_data_1m is None: + missing.append('1m') + if price_data_1h is None: + missing.append('1h') + if price_data_1d is None: + missing.append('1d') + logger.warning(f"Missing required timeframes: {missing}. Need all 3: 1m, 1h, 1d") return None + # Validate each required timeframe has correct shape + for tf_name, tf_data in [('1m', price_data_1m), ('1h', price_data_1h), ('1d', price_data_1d)]: + if tf_data is not None: + shape = tf_data.shape + if len(shape) != 3 or shape[1] < 600: + logger.warning(f"Timeframe {tf_name} has invalid shape {shape} (need [1, 600, 5])") + return None + + # Validate optional 1s timeframe if present + if price_data_1s is not None: + shape = price_data_1s.shape + if len(shape) != 3 or shape[1] < 600: + logger.warning(f"Optional timeframe 1s has invalid shape {shape}, removing it") + price_data_1s = None + # Get reference timeframe for other features (prefer 1m, fallback to any available) ref_data = price_data_1m if price_data_1m is not None else ( price_data_1h if price_data_1h is not None else ( @@ -1928,6 +2314,46 @@ class RealTrainingAdapter: for i, data in enumerate(training_data): batch = self._convert_annotation_to_transformer_batch(data) if batch is not None: + # CRITICAL: Validate that ALL required timeframes are present + # REQUIRED: 1m, 1h, 1d (each with 600 candles) + # OPTIONAL: 1s (if available, include with 600 candles) + required_tf_keys = ['price_data_1m', 'price_data_1h', 'price_data_1d'] + optional_tf_keys = ['price_data_1s'] + + missing_tfs = [tf for tf in required_tf_keys if batch.get(tf) is None] + + if missing_tfs: + logger.warning(f" Skipping sample {i+1}: Missing required timeframes: {missing_tfs}") + continue + + # Validate each required timeframe has correct shape [1, 600, 5] + for tf_key in required_tf_keys: + tf_data = batch.get(tf_key) + if tf_data is not None: + if not isinstance(tf_data, torch.Tensor): + logger.warning(f" Skipping sample {i+1}: {tf_key} is not a tensor") + missing_tfs.append(tf_key) + break + shape = tf_data.shape + if len(shape) != 3 or shape[1] < 600: # Must be [batch, seq_len, features] with seq_len >= 600 + logger.warning(f" Skipping sample {i+1}: {tf_key} has invalid shape {shape} (need [1, 600, 5])") + missing_tfs.append(tf_key) + break + + if missing_tfs: + continue + + # Validate optional 1s timeframe if present + if batch.get('price_data_1s') is not None: + tf_data = batch.get('price_data_1s') + if isinstance(tf_data, torch.Tensor): + shape = tf_data.shape + if len(shape) != 3 or shape[1] < 600: + logger.warning(f" Sample {i+1}: price_data_1s has invalid shape {shape}, removing it") + batch['price_data_1s'] = None + + logger.debug(f" Sample {i+1}: All required timeframes present (1m, 1h, 1d), 1s={'present' if batch.get('price_data_1s') is not None else 'not available'}") + # Move batch to GPU immediately with pinned memory for faster transfer if use_gpu: batch_gpu = {} diff --git a/ANNOTATE/web/app.py b/ANNOTATE/web/app.py index b89acfa..795438a 100644 --- a/ANNOTATE/web/app.py +++ b/ANNOTATE/web/app.py @@ -832,7 +832,7 @@ class AnnotationDashboard: # Check if the specific model is already initialized if model_name == 'Transformer': logger.info("Checking Transformer model...") - if self.orchestrator.primary_transformer: + if hasattr(self.orchestrator, 'primary_transformer') and self.orchestrator.primary_transformer: self.loaded_models['Transformer'] = self.orchestrator.primary_transformer logger.info("Transformer model loaded successfully") else: @@ -841,13 +841,13 @@ class AnnotationDashboard: elif model_name == 'CNN': logger.info("Checking CNN model...") - if self.orchestrator.cnn_model: + if hasattr(self.orchestrator, 'cnn_model') and self.orchestrator.cnn_model: self.loaded_models['CNN'] = self.orchestrator.cnn_model logger.info("CNN model loaded successfully") else: logger.warning("CNN model not initialized in orchestrator") return - + elif model_name == 'DQN': logger.info("Checking DQN model...") if self.orchestrator.rl_agent: diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py index d43e933..3302f15 100644 --- a/NN/models/dqn_agent.py +++ b/NN/models/dqn_agent.py @@ -270,9 +270,9 @@ class DQNAgent: self.batch_size = batch_size self.target_update = target_update - # Set device for computation (default to GPU if available) + # Set device for computation (read from config.yaml if available) if device is None: - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = self._get_device_from_config() else: self.device = device @@ -282,10 +282,6 @@ class DQNAgent: self.policy_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device) self.target_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device) - # Ensure models are on the correct device - self.policy_net = self.policy_net.to(self.device) - self.target_net = self.target_net.to(self.device) - # Initialize the target network with the same weights as the policy network self.target_net.load_state_dict(self.policy_net.state_dict()) @@ -317,13 +313,92 @@ class DQNAgent: # Market regime adaptation weights self.market_regime_weights = { - 'trending': 1.0, - 'sideways': 0.8, - 'volatile': 1.2, - 'bullish': 1.1, - 'bearish': 1.1 + 'trending': 1.2, # Higher confidence in trending markets + 'ranging': 0.8, # Lower confidence in ranging markets + 'volatile': 0.6 # Much lower confidence in volatile markets } + # Additional initialization + self.recent_actions = deque(maxlen=10) + self.recent_prices = deque(maxlen=20) + self.recent_rewards = deque(maxlen=100) + + # Price direction tracking + self.last_price_direction = { + 'direction': 0.0, + 'confidence': 0.0 + } + + self.price_movement_memory = [] + self.losses = [] + self.no_improvement_count = 0 + self.confidence_history = [] + self.avg_confidence = 0.0 + self.max_confidence = 0.0 + self.min_confidence = 1.0 + + # Enhanced training features + self.use_dueling = True + self.use_prioritized_replay = priority_memory + self.alpha = 0.6 + self.beta = 0.4 + self.beta_increment = 0.001 + self.use_double_dqn = True + self.target_update_freq = target_update + self.training_steps = 0 + self.gradient_clip_norm = 1.0 + self.epsilon_history = [] + self.td_errors = [] + + # Trade settings + self.trade_action_fee = 0.0005 + self.minimum_action_confidence = 0.3 + + # Violent move detection + self.price_history = [] + self.volatility_window = 20 + self.volatility_threshold = 0.0015 + self.post_violent_move = False + self.violent_move_cooldown = 0 + + # Feature integration + self.last_hidden_features = None + self.feature_history = [] + self.realtime_tick_features = None + self.tick_feature_weight = 0.3 + + # Mixed precision training + if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ: + self.use_mixed_precision = True + self.scaler = torch.amp.GradScaler('cuda') + logger.info("Mixed precision training enabled") + else: + self.use_mixed_precision = False + logger.info("Mixed precision training disabled") + + self.training = True + + # Compatibility + self.state_size = np.prod(state_shape) + self.action_size = n_actions + self.memory_size = buffer_size + self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] + + logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}") + logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}") + logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}") + + total_params = sum(p.numel() for p in self.policy_net.parameters()) + logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters") + + # Position management + self.current_position = 0.0 + self.position_entry_price = 0.0 + self.position_entry_time = None + self.entry_confidence_threshold = 0.35 + self.exit_confidence_threshold = 0.15 + self.uncertainty_threshold = 0.1 + # Load best checkpoint if available if self.enable_checkpoints: self.load_best_checkpoint() @@ -331,114 +406,47 @@ class DQNAgent: logger.info(f"DQN Agent initialized with checkpoint management: {enable_checkpoints}") if enable_checkpoints: logger.info(f"Model name: {model_name}, Checkpoint frequency: {self.checkpoint_frequency}") - - # Add this line to the __init__ method - self.recent_actions = deque(maxlen=10) - self.recent_prices = deque(maxlen=20) - self.recent_rewards = deque(maxlen=100) - - # Price direction tracking - stores direction and confidence - self.last_price_direction = { - 'direction': 0.0, # Single value between -1 and 1 - 'confidence': 0.0 # Single value between 0 and 1 - } - - # Store separate memory for price direction examples - self.price_movement_memory = [] # For storing examples of clear price movements - - # Performance tracking - self.losses = [] - self.no_improvement_count = 0 - - # Confidence tracking - self.confidence_history = [] - self.avg_confidence = 0.0 - self.max_confidence = 0.0 - self.min_confidence = 1.0 - - # Enhanced features from EnhancedDQNAgent - # Market adaptation capabilities - self.market_regime_weights = { - 'trending': 1.2, # Higher confidence in trending markets - 'ranging': 0.8, # Lower confidence in ranging markets - 'volatile': 0.6 # Much lower confidence in volatile markets - } - - # Dueling network support (requires enhanced network architecture) - self.use_dueling = True - - # Prioritized experience replay parameters - self.use_prioritized_replay = priority_memory - self.alpha = 0.6 # Priority exponent - self.beta = 0.4 # Importance sampling exponent - self.beta_increment = 0.001 - - # Double DQN support - self.use_double_dqn = True - - # Enhanced training features from EnhancedDQNAgent - self.target_update_freq = target_update # More descriptive name - self.training_steps = 0 - self.gradient_clip_norm = 1.0 # Gradient clipping - - # Enhanced statistics tracking - self.epsilon_history = [] - self.td_errors = [] # Track TD errors for analysis - - # Trade action fee and confidence thresholds - self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading - self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5) - - # Violent move detection - self.price_history = [] - self.volatility_window = 20 # Window size for volatility calculation - self.volatility_threshold = 0.0015 # Threshold for considering a move "violent" - self.post_violent_move = False # Flag for recent violent move - self.violent_move_cooldown = 0 # Cooldown after violent move - - # Feature integration - self.last_hidden_features = None # Store last extracted features - self.feature_history = [] # Store history of features for analysis - - # Real-time tick features integration - self.realtime_tick_features = None # Latest tick features from tick processor - self.tick_feature_weight = 0.3 # Weight for tick features in decision making - - # Check if mixed precision training should be used - if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ: - self.use_mixed_precision = True - self.scaler = torch.amp.GradScaler('cuda') - logger.info("Mixed precision training enabled") - else: - self.use_mixed_precision = False - logger.info("Mixed precision training disabled") + + def _get_device_from_config(self) -> torch.device: + """Get device from config.yaml or auto-detect""" + try: + # Try to load config + from core.config import get_config + config = get_config() + gpu_config = config._config.get('gpu', {}) - # Track if we're in training mode - self.training = True - - # For compatibility with old code - self.state_size = np.prod(state_shape) - self.action_size = n_actions - self.memory_size = buffer_size - self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] # Default timeframes - - logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}") - logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}") - logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}") - - # Log model parameters - total_params = sum(p.numel() for p in self.policy_net.parameters()) - logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters") - - # Position management for 2-action system - self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long) - self.position_entry_price = 0.0 - self.position_entry_time = None - - # Different thresholds for entry vs exit decisions - AGGRESSIVE for more training data - self.entry_confidence_threshold = 0.35 # Lower threshold for new positions (was 0.7) - self.exit_confidence_threshold = 0.15 # Very low threshold for closing positions (was 0.3) - self.uncertainty_threshold = 0.1 # When to stay neutral + device_setting = gpu_config.get('device', 'auto') + fallback_to_cpu = gpu_config.get('fallback_to_cpu', True) + gpu_enabled = gpu_config.get('enabled', True) + + # If GPU is disabled in config, use CPU + if not gpu_enabled: + logger.info("GPU disabled in config.yaml, using CPU") + return torch.device('cpu') + + # Handle device selection + if device_setting == 'cpu': + logger.info("Device set to CPU in config.yaml") + return torch.device('cpu') + elif device_setting == 'cuda' or device_setting == 'auto': + # Try GPU first + if torch.cuda.is_available(): + logger.info("Using GPU (CUDA available)") + return torch.device('cuda') + else: + if fallback_to_cpu: + logger.warning("CUDA not available, falling back to CPU") + return torch.device('cpu') + else: + raise RuntimeError("CUDA not available and fallback_to_cpu is False") + else: + logger.warning(f"Unknown device setting '{device_setting}', using auto-detection") + return torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + except Exception as e: + logger.warning(f"Error reading device config: {e}, using auto-detection") + # Fallback to auto-detection + return torch.device('cuda' if torch.cuda.is_available() else 'cpu') def load_best_checkpoint(self): """Load the best checkpoint for this DQN agent""" @@ -512,104 +520,7 @@ class DQNAgent: except Exception as e: logger.error(f"Error saving DQN checkpoint: {e}") return False - - # Store separate memory for price direction examples - self.price_movement_memory = [] # For storing examples of clear price movements - - # Performance tracking - self.losses = [] - self.no_improvement_count = 0 - - # Confidence tracking - self.confidence_history = [] - self.avg_confidence = 0.0 - self.max_confidence = 0.0 - self.min_confidence = 1.0 - - # Enhanced features from EnhancedDQNAgent - # Market adaptation capabilities - self.market_regime_weights = { - 'trending': 1.2, # Higher confidence in trending markets - 'ranging': 0.8, # Lower confidence in ranging markets - 'volatile': 0.6 # Much lower confidence in volatile markets - } - - # Dueling network support (requires enhanced network architecture) - self.use_dueling = True - - # Prioritized experience replay parameters - self.use_prioritized_replay = priority_memory - self.alpha = 0.6 # Priority exponent - self.beta = 0.4 # Importance sampling exponent - self.beta_increment = 0.001 - - # Double DQN support - self.use_double_dqn = True - - # Enhanced training features from EnhancedDQNAgent - self.target_update_freq = target_update # More descriptive name - self.training_steps = 0 - self.gradient_clip_norm = 1.0 # Gradient clipping - - # Enhanced statistics tracking - self.epsilon_history = [] - self.td_errors = [] # Track TD errors for analysis - - # Trade action fee and confidence thresholds - self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading - self.minimum_action_confidence = 0.3 # Minimum confidence to consider trading (lowered from 0.5) - - # Violent move detection - self.price_history = [] - self.volatility_window = 20 # Window size for volatility calculation - self.volatility_threshold = 0.0015 # Threshold for considering a move "violent" - self.post_violent_move = False # Flag for recent violent move - self.violent_move_cooldown = 0 # Cooldown after violent move - - # Feature integration - self.last_hidden_features = None # Store last extracted features - self.feature_history = [] # Store history of features for analysis - - # Real-time tick features integration - self.realtime_tick_features = None # Latest tick features from tick processor - self.tick_feature_weight = 0.3 # Weight for tick features in decision making - - # Check if mixed precision training should be used - if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ: - self.use_mixed_precision = True - self.scaler = torch.amp.GradScaler('cuda') - logger.info("Mixed precision training enabled") - else: - self.use_mixed_precision = False - logger.info("Mixed precision training disabled") - - # Track if we're in training mode - self.training = True - - # For compatibility with old code - self.state_size = np.prod(state_shape) - self.action_size = n_actions - self.memory_size = buffer_size - self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0] if isinstance(self.state_dim, tuple) else 3] # Default timeframes - - logger.info(f"DQN Agent using Enhanced CNN with device: {self.device}") - logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}") - logger.info(f"Real-time tick feature integration enabled with weight: {self.tick_feature_weight}") - - # Log model parameters - total_params = sum(p.numel() for p in self.policy_net.parameters()) - logger.info(f"Enhanced CNN Policy Network: {total_params:,} parameters") - - # Position management for 2-action system - self.current_position = 0.0 # -1 (short), 0 (neutral), 1 (long) - self.position_entry_price = 0.0 - self.position_entry_time = None - - # Different thresholds for entry vs exit decisions - AGGRESSIVE for more training data - self.entry_confidence_threshold = 0.35 # Lower threshold for new positions (was 0.7) - self.exit_confidence_threshold = 0.15 # Very low threshold for closing positions (was 0.3) - self.uncertainty_threshold = 0.1 # When to stay neutral - + def move_models_to_device(self, device=None): """Move models to the specified device (GPU/CPU)""" if device is not None: diff --git a/TREND_TARGET_IMPROVEMENT.md b/TREND_TARGET_IMPROVEMENT.md index 8ccc8d0..5d5993f 100644 --- a/TREND_TARGET_IMPROVEMENT.md +++ b/TREND_TARGET_IMPROVEMENT.md @@ -181,3 +181,7 @@ This complements the earlier fixes: 📊 **EXPECTED** - Better trend predictions after training + + + + diff --git a/core/orchestrator.py b/core/orchestrator.py index c5fcb3d..a59a9c7 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -318,27 +318,62 @@ class TradingOrchestrator: # Initialize confidence threshold self.confidence_threshold = self.config.get('confidence_threshold', 0.6) - # Determine the device to use (GPU if available, else CPU) - # Initialize device - force CPU mode to avoid CUDA errors - if torch.cuda.is_available(): - try: - # Test CUDA availability with actual Linear layer operation - # This catches architecture-specific issues like gfx1151 incompatibility - test_tensor = torch.randn(2, 10).cuda() - test_linear = torch.nn.Linear(10, 5).cuda() - test_result = test_linear(test_tensor) - logger.info(f"GPU compatibility test passed: {torch.cuda.get_device_name(0)}") - self.device = torch.device("cuda") - logger.info("CUDA/ROCm device initialized successfully") - except Exception as e: - logger.warning(f"CUDA/ROCm initialization failed: {e}") - logger.warning("GPU architecture may not be supported - falling back to CPU") - logger.warning("This is common with newer AMD GPUs (gfx1151+) that require specific PyTorch builds") - self.device = torch.device("cpu") - else: - self.device = torch.device("cpu") - + # Determine the device to use from config.yaml + self.device = self._get_device_from_config() logger.info(f"Using device: {self.device}") + + def _get_device_from_config(self) -> torch.device: + """Get device from config.yaml or auto-detect""" + try: + gpu_config = self.config._config.get('gpu', {}) + + device_setting = gpu_config.get('device', 'auto') + fallback_to_cpu = gpu_config.get('fallback_to_cpu', True) + gpu_enabled = gpu_config.get('enabled', True) + + # If GPU is disabled in config, use CPU + if not gpu_enabled: + logger.info("GPU disabled in config.yaml, using CPU") + return torch.device('cpu') + + # Handle device selection + if device_setting == 'cpu': + logger.info("Device set to CPU in config.yaml") + return torch.device('cpu') + elif device_setting == 'cuda' or device_setting == 'auto': + # Try GPU first with compatibility test + if torch.cuda.is_available(): + try: + # Test CUDA availability with actual Linear layer operation + # This catches architecture-specific issues like gfx1151 incompatibility + test_tensor = torch.randn(2, 10).cuda() + test_linear = torch.nn.Linear(10, 5).cuda() + test_result = test_linear(test_tensor) + logger.info(f"GPU compatibility test passed: {torch.cuda.get_device_name(0)}") + logger.info("CUDA/ROCm device initialized successfully") + return torch.device("cuda") + except Exception as e: + logger.warning(f"CUDA/ROCm initialization failed: {e}") + logger.warning("GPU architecture may not be supported - falling back to CPU") + logger.warning("This is common with newer AMD GPUs (gfx1151+) that require specific PyTorch builds") + if fallback_to_cpu: + return torch.device("cpu") + else: + raise RuntimeError("CUDA not available and fallback_to_cpu is False") + else: + if fallback_to_cpu: + logger.warning("CUDA not available, falling back to CPU") + return torch.device('cpu') + else: + raise RuntimeError("CUDA not available and fallback_to_cpu is False") + else: + logger.warning(f"Unknown device setting '{device_setting}', using auto-detection") + return torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + except Exception as e: + logger.warning(f"Error reading device config: {e}, using auto-detection") + # Fallback to auto-detection + return torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Canonical model name aliases to eliminate ambiguity across UI/DB/FS # Canonical → accepted aliases (internal/legacy)