fixes
@@ -294,9 +294,13 @@ class RealTrainingAdapter:
         # Clear previous predictions for clean visualization
         # Get symbol from first test case
         symbol = test_cases[0].get('symbol', 'ETH/USDT') if test_cases else 'ETH/USDT'
-        if self.orchestrator and hasattr(self.orchestrator, 'clear_predictions'):
+        if (self.orchestrator and
+            hasattr(self.orchestrator, 'clear_predictions') and
+            hasattr(self.orchestrator, 'recent_transformer_predictions')):
             self.orchestrator.clear_predictions(symbol)
             logger.info(f" Cleared previous predictions for {symbol}")
+        else:
+            logger.info(f" Orchestrator not ready, skipping prediction clearing for {symbol}")

         # Prepare training data from test cases
         training_data = self._prepare_training_data(test_cases)
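
Note: the new guard only touches the orchestrator when it exposes both attributes. As a standalone sketch (the orchestrator object is assumed; only the attribute names come from the diff):

    def safe_clear_predictions(orchestrator, symbol: str) -> bool:
        """Clear stored predictions only if the orchestrator supports it."""
        if (orchestrator is not None
                and hasattr(orchestrator, 'clear_predictions')
                and hasattr(orchestrator, 'recent_transformer_predictions')):
            orchestrator.clear_predictions(symbol)
            return True
        return False  # orchestrator not ready; caller logs and moves on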
@@ -595,7 +599,7 @@ class RealTrainingAdapter:
         else:
             logger.warning(f" {symbol} {timeframe}: No quality data available (need {min_required_candles} candles)")

-        # CRITICAL: Validate we have all required timeframes
+        # CRITICAL: Validate we have all required timeframes (1s is optional, don't check it)
         missing_required = [tf for tf in required_timeframes if tf not in fetched_timeframes]
         if missing_required:
             logger.error(f" FAILED: Missing required timeframes: {missing_required}")
@@ -603,6 +607,10 @@ class RealTrainingAdapter:
             logger.error(f" Cannot proceed without all required timeframes")
             return {}  # Return empty dict to signal failure

+        # Log optional timeframes status
+        if '1s' not in fetched_timeframes:
+            logger.debug(f" Optional timeframe 1s not available (this is OK - 1s historical data is often unavailable)")
+
         # Fetch secondary symbol data (1m timeframe only)
         logger.info(f" Fetching secondary symbol data: {secondary_symbol} (1m)")
         secondary_df = None
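
Note: '1s' is deliberately absent from the required list, so its absence only produces a debug message. A minimal illustration with hypothetical values:

    required_timeframes = ['1m', '1h', '1d']   # '1s' is intentionally excluded
    fetched_timeframes = {'1m', '1h', '1d'}    # hypothetical fetch result

    missing_required = [tf for tf in required_timeframes if tf not in fetched_timeframes]
    assert not missing_required                # otherwise return {} to signal failure
    if '1s' not in fetched_timeframes:
        print("1s unavailable - OK, it is optional")   # mirrors the new debug log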
@@ -1080,7 +1088,8 @@ class RealTrainingAdapter:
         Create a market state snapshot at a specific candle index

         This creates a "view" of the market as it was at that specific candle,
-        which is used for negative sampling.
+        which is used for negative sampling. CRITICAL: Ensures 600 candles are available
+        by taking the last 600 candles BEFORE the target point.
         """
         snapshot = {
             'symbol': market_state.get('symbol'),
@@ -1088,19 +1097,26 @@ class RealTrainingAdapter:
             'timeframes': {}
         }

-        # For each timeframe, create a snapshot up to the candle_index
+        # CRITICAL: Training requires 600 candles BEFORE the target point
+        required_candles = 600
+
+        # For each timeframe, create a snapshot with 600 candles BEFORE the candle_index
         for tf, tf_data in market_state.get('timeframes', {}).items():
             timestamps = tf_data.get('timestamps', [])

             if candle_index < len(timestamps):
-                # Include data up to and including this candle
+                # Take the last 600 candles BEFORE and INCLUDING this candle
+                # If we don't have 600 candles, we'll pad later in extraction
+                start_idx = max(0, candle_index + 1 - required_candles)
+                end_idx = candle_index + 1
+
                 snapshot['timeframes'][tf] = {
-                    'timestamps': timestamps[:candle_index + 1],
-                    'open': tf_data.get('open', [])[:candle_index + 1],
-                    'high': tf_data.get('high', [])[:candle_index + 1],
-                    'low': tf_data.get('low', [])[:candle_index + 1],
-                    'close': tf_data.get('close', [])[:candle_index + 1],
-                    'volume': tf_data.get('volume', [])[:candle_index + 1]
+                    'timestamps': timestamps[start_idx:end_idx],
+                    'open': tf_data.get('open', [])[start_idx:end_idx],
+                    'high': tf_data.get('high', [])[start_idx:end_idx],
+                    'low': tf_data.get('low', [])[start_idx:end_idx],
+                    'close': tf_data.get('close', [])[start_idx:end_idx],
+                    'volume': tf_data.get('volume', [])[start_idx:end_idx]
                 }

                 if tf == '1m':
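
Note: the windowing math takes the 600 candles up to and including the target. A worked example (values hypothetical):

    required_candles = 600
    for candle_index in (750, 100):
        start_idx = max(0, candle_index + 1 - required_candles)
        end_idx = candle_index + 1
        print(candle_index, start_idx, end_idx, end_idx - start_idx)
    # 750 -> window [151, 751), exactly 600 candles
    # 100 -> window [0, 101), only 101 candles; padding happens later in extraction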
@@ -1561,21 +1577,28 @@ class RealTrainingAdapter:
         if len(closes) == 0:
             return None

-        # REQUIRED: Must have exactly target_seq_len (600) candles, no padding allowed
+        # ALLOW PADDING: If we have fewer than target_seq_len, pad with the first available value
         if len(closes) < target_seq_len:
-            logger.warning(f"Insufficient candles: {len(closes)} < {target_seq_len} (required)")
-            return None
-
-        # Take last target_seq_len candles (exactly 600)
-        opens = opens[-target_seq_len:]
-        highs = highs[-target_seq_len:]
-        lows = lows[-target_seq_len:]
-        closes = closes[-target_seq_len:]
-        volumes = volumes[-target_seq_len:]
-
-        # Validate we have exactly target_seq_len
+            logger.debug(f"Padding {target_seq_len - len(closes)} candles for timeframe (have {len(closes)}, need {target_seq_len})")
+            pad_len = target_seq_len - len(closes)
+
+            # Pad at the beginning with the first available value (edge padding)
+            opens = np.pad(opens, (pad_len, 0), mode='edge')
+            highs = np.pad(highs, (pad_len, 0), mode='edge')
+            lows = np.pad(lows, (pad_len, 0), mode='edge')
+            closes = np.pad(closes, (pad_len, 0), mode='edge')
+            volumes = np.pad(volumes, (pad_len, 0), mode='edge')
+        else:
+            # Take last target_seq_len candles if we have more than needed
+            opens = opens[-target_seq_len:]
+            highs = highs[-target_seq_len:]
+            lows = lows[-target_seq_len:]
+            closes = closes[-target_seq_len:]
+            volumes = volumes[-target_seq_len:]
+
+        # Validate we now have exactly target_seq_len
         if len(closes) != target_seq_len:
-            logger.warning(f"Extraction failed: got {len(closes)} candles, need {target_seq_len}")
+            logger.warning(f"Extraction failed: got {len(closes)} candles after padding, need {target_seq_len}")
             return None

         # Stack OHLCV [seq_len, 5]
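
Note: mode='edge' repeats the first element, so a short series is extended without inventing new prices. A minimal demonstration (8 stands in for 600):

    import numpy as np

    target_seq_len = 8
    closes = np.array([10.0, 11.0, 12.0])
    pad_len = target_seq_len - len(closes)
    padded = np.pad(closes, (pad_len, 0), mode='edge')
    print(padded)   # [10. 10. 10. 10. 10. 10. 11. 12.]
    assert len(padded) == target_seq_len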
@@ -1709,26 +1732,40 @@ class RealTrainingAdapter:
         timeframes = market_state.get('timeframes', {})
         secondary_timeframes = market_state.get('secondary_timeframes', {})

-        # REQUIRED: 600 candles per timeframe for transformer model
-        target_seq_len = 600  # Must be 600 candles for each timeframe
-        # Validate we have enough data in required timeframes (1m, 1h, 1d)
+        # REQUIRED: At least some candles per timeframe for transformer model (will pad if needed)
+        target_seq_len = 600  # Target 600 candles for each timeframe, but allow less and pad
+        min_required_candles = 50  # Minimum candles needed to attempt training
+        # Validate we have minimum data in required timeframes (1m, 1h, 1d)
         required_tfs = ['1m', '1h', '1d']
         for tf_name in required_tfs:
             if tf_name in timeframes:
                 tf_data = timeframes[tf_name]
                 if tf_data and 'close' in tf_data:
-                    if len(tf_data['close']) < 600:
-                        logger.warning(f"Required timeframe {tf_name} has only {len(tf_data['close'])} candles (need 600)")
+                    if len(tf_data['close']) < min_required_candles:
+                        logger.warning(f"Required timeframe {tf_name} has only {len(tf_data['close'])} candles (need at least {min_required_candles})")
                         return None
+                    elif len(tf_data['close']) < target_seq_len:
+                        logger.debug(f"Timeframe {tf_name} has {len(tf_data['close'])} candles, will pad to {target_seq_len}")
                 else:
                     logger.warning(f"Required timeframe {tf_name} missing data")
                     return None
             else:
                 logger.warning(f"Required timeframe {tf_name} not found in data")
                 return None

-        # Validate optional 1s timeframe if present (must have 600 candles if included)
+        # Validate optional 1s timeframe if present (must have minimum candles if included)
         if '1s' in timeframes:
             tf_data = timeframes['1s']
             if tf_data and 'close' in tf_data:
-                if len(tf_data['close']) < 600:
-                    logger.warning(f"Optional timeframe 1s has only {len(tf_data['close'])} candles (need 600), excluding it")
+                if len(tf_data['close']) < min_required_candles:
+                    logger.warning(f"Optional timeframe 1s has only {len(tf_data['close'])} candles (need at least {min_required_candles}), excluding it")
                     # Remove 1s from timeframes if insufficient
                     timeframes = {k: v for k, v in timeframes.items() if k != '1s'}
+                elif len(tf_data['close']) < target_seq_len:
+                    logger.debug(f"Timeframe 1s has {len(tf_data['close'])} candles, will pad to {target_seq_len}")
             else:
                 logger.warning(f"Optional timeframe 1s has invalid data, excluding it")
                 timeframes = {k: v for k, v in timeframes.items() if k != '1s'}

         # Extract each timeframe (returns tuple: (tensor, norm_params) or None)
         # Store normalization parameters for each timeframe
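
Note: the hunk above amounts to a three-way policy per timeframe. A sketch using the diff's thresholds (the helper name is illustrative):

    def classify_timeframe(n_candles: int,
                           min_required_candles: int = 50,
                           target_seq_len: int = 600) -> str:
        if n_candles < min_required_candles:
            return 'reject'    # too little data to attempt training
        if n_candles < target_seq_len:
            return 'pad'       # edge-pad up to target_seq_len during extraction
        return 'truncate'      # keep only the most recent target_seq_len candles

    assert classify_timeframe(10) == 'reject'
    assert classify_timeframe(300) == 'pad'
    assert classify_timeframe(900) == 'truncate'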
@@ -1793,18 +1830,18 @@ class RealTrainingAdapter:
             logger.warning(f"Missing required timeframes: {missing}. Need all 3: 1m, 1h, 1d")
             return None

-        # Validate each required timeframe has correct shape
+        # Validate each required timeframe has correct shape (should be exactly 600 after padding)
         for tf_name, tf_data in [('1m', price_data_1m), ('1h', price_data_1h), ('1d', price_data_1d)]:
             if tf_data is not None:
                 shape = tf_data.shape
-                if len(shape) != 3 or shape[1] < 600:
+                if len(shape) != 3 or shape[1] != 600:
                     logger.warning(f"Timeframe {tf_name} has invalid shape {shape} (need [1, 600, 5])")
                     return None

-        # Validate optional 1s timeframe if present
+        # Validate optional 1s timeframe if present (should be exactly 600 if included)
         if price_data_1s is not None:
             shape = price_data_1s.shape
-            if len(shape) != 3 or shape[1] < 600:
+            if len(shape) != 3 or shape[1] != 600:
                 logger.warning(f"Optional timeframe 1s has invalid shape {shape}, removing it")
                 price_data_1s = None
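
Note: after padding, every timeframe tensor must be exactly [1, 600, 5]; the old '< 600' test let an oversized tensor through. A sketch of the stricter check (helper name and the explicit feature-count check are illustrative):

    import torch

    def has_valid_shape(t, seq_len: int = 600, features: int = 5) -> bool:
        return (isinstance(t, torch.Tensor) and t.dim() == 3
                and t.shape[1] == seq_len and t.shape[2] == features)

    assert has_valid_shape(torch.zeros(1, 600, 5))
    assert not has_valid_shape(torch.zeros(1, 700, 5))  # passed old 'shape[1] < 600' test, fails '!= 600'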
@@ -2335,10 +2372,14 @@ class RealTrainingAdapter:
                         missing_tfs.append(tf_key)
                         break
                     shape = tf_data.shape
-                    if len(shape) != 3 or shape[1] < 600:  # Must be [batch, seq_len, features] with seq_len >= 600
+                    if len(shape) != 3 or shape[1] != 600:  # Must be [batch, seq_len, features] with seq_len == 600 (padded if needed)
                         logger.warning(f" Skipping sample {i+1}: {tf_key} has invalid shape {shape} (need [1, 600, 5])")
                         missing_tfs.append(tf_key)
                         break
+                else:
+                    logger.warning(f" Skipping sample {i+1}: {tf_key} is None")
+                    missing_tfs.append(tf_key)
+                    break

             if missing_tfs:
                 continue
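
Note: at batch level the same shape rule decides whether a whole sample is usable; any missing or mis-shaped required tensor skips the sample. A condensed sketch (names illustrative):

    import torch

    REQUIRED_KEYS = ('price_data_1m', 'price_data_1h', 'price_data_1d')

    def sample_is_trainable(batch: dict) -> bool:
        for tf_key in REQUIRED_KEYS:
            t = batch.get(tf_key)
            if not isinstance(t, torch.Tensor):
                return False   # None or wrong type: skip this sample
            if t.dim() != 3 or t.shape[1] != 600:
                return False   # must be [batch, 600, features]
        return True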
@@ -2348,7 +2389,7 @@ class RealTrainingAdapter:
             tf_data = batch.get('price_data_1s')
             if isinstance(tf_data, torch.Tensor):
                 shape = tf_data.shape
-                if len(shape) != 3 or shape[1] < 600:
+                if len(shape) != 3 or shape[1] != 600:
                     logger.warning(f" Sample {i+1}: price_data_1s has invalid shape {shape}, removing it")
                     batch['price_data_1s'] = None
@@ -3306,14 +3347,37 @@ class RealTrainingAdapter:
             return {'success': False, 'error': str(e)}

     def _fetch_market_state_for_candle(self, symbol: str, timestamp, data_provider) -> Dict:
-        """Fetch market state with OHLCV data for model training"""
+        """Fetch market state with OHLCV data for model training - ENSURES 600 CANDLES ARE AVAILABLE"""
         try:
             # Get market state with OHLCV data only (NO business logic)
             market_state = {'timeframes': {}, 'secondary_timeframes': {}}

+            # CRITICAL: Training requires exactly 600 candles per timeframe
+            required_limit = 600
+
             for tf in ['1s', '1m', '1h', '1d']:
-                df = data_provider.get_historical_data(symbol, tf, limit=200)
-                if df is not None and not df.empty:
+                # First try to get data from cache
+                df = data_provider.get_historical_data(symbol, tf, limit=required_limit)
+
+                # If insufficient data, force a refresh from API and cache it
+                if df is None or df.empty or len(df) < required_limit:
+                    logger.info(f"Fetching {required_limit} candles for {symbol} {tf} from API (insufficient cached data)")
+                    try:
+                        # Force refresh from API and persist to cache
+                        df = data_provider.get_historical_data(
+                            symbol, tf, limit=required_limit,
+                            refresh=True, persist=True
+                        )
+                        logger.info(f"Successfully cached {len(df) if df is not None else 0} candles for {symbol} {tf}")
+                    except Exception as api_error:
+                        logger.warning(f"Failed to fetch {symbol} {tf} from API: {api_error}")
+                        continue
+
+                # Verify we have enough data
+                if df is not None and not df.empty and len(df) >= required_limit:
+                    # Take the most recent required_limit candles
+                    df = df.tail(required_limit)
+
                     market_state['timeframes'][tf] = {
                         'timestamps': df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
                         'open': df['open'].tolist(),
@@ -3322,10 +3386,42 @@ class RealTrainingAdapter:
                         'close': df['close'].tolist(),
                         'volume': df['volume'].tolist()
                     }
+
+                    logger.debug(f"Prepared {len(df)} candles for {symbol} {tf}")
+                else:
+                    logger.warning(f"Still insufficient data for {symbol} {tf} after API fetch: {len(df) if df is not None else 0} < {required_limit}")
+
+            # Also fetch BTC reference data for 1m timeframe
+            btc_symbol = 'BTC/USDT'
+            btc_tf = '1m'
+            try:
+                btc_df = data_provider.get_historical_data(btc_symbol, btc_tf, limit=required_limit)
+                if btc_df is None or btc_df.empty or len(btc_df) < required_limit:
+                    logger.info(f"Fetching BTC reference data for training")
+                    btc_df = data_provider.get_historical_data(
+                        btc_symbol, btc_tf, limit=required_limit,
+                        refresh=True, persist=True
+                    )
+
+                if btc_df is not None and not btc_df.empty and len(btc_df) >= required_limit:
+                    btc_df = btc_df.tail(required_limit)
+                    market_state['secondary_timeframes'][btc_symbol] = {
+                        btc_tf: {
+                            'timestamps': btc_df.index.strftime('%Y-%m-%d %H:%M:%S').tolist(),
+                            'open': btc_df['open'].tolist(),
+                            'high': btc_df['high'].tolist(),
+                            'low': btc_df['low'].tolist(),
+                            'close': btc_df['close'].tolist(),
+                            'volume': btc_df['volume'].tolist()
+                        }
+                    }
+            except Exception as btc_error:
+                logger.warning(f"Failed to fetch BTC reference data: {btc_error}")

             return market_state
         except Exception as e:
             logger.warning(f"Error fetching market state for candle: {e}")
             import traceback
             logger.debug(traceback.format_exc())
             return {}

     def _convert_prediction_to_batch(self, prediction_sample: Dict, timeframe: str):
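
Note: the fetch strategy is cache-first with an API refresh fallback; refresh=True and persist=True are the keyword arguments used in the diff itself, and data_provider is assumed to return a pandas DataFrame. A condensed sketch:

    def fetch_candles(data_provider, symbol: str, tf: str, required_limit: int = 600):
        # Cache first
        df = data_provider.get_historical_data(symbol, tf, limit=required_limit)
        # Short or empty cache: force an API refresh and persist it
        if df is None or df.empty or len(df) < required_limit:
            df = data_provider.get_historical_data(
                symbol, tf, limit=required_limit, refresh=True, persist=True
            )
        if df is not None and not df.empty and len(df) >= required_limit:
            return df.tail(required_limit)   # most recent required_limit candles
        return None                          # caller logs and skips this timeframe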
@@ -3882,7 +3978,8 @@ class RealTrainingAdapter:
             else:
                 prediction_data['trend_vector'] = trend_vec

-            self.orchestrator.store_transformer_prediction(symbol, prediction_data)
+            if hasattr(self.orchestrator, 'store_transformer_prediction'):
+                self.orchestrator.store_transformer_prediction(symbol, prediction_data)

             # Training decision using strategy manager
             training_strategy = session.get('training_strategy')