train feedback - norm fix
This commit is contained in:
@@ -1211,16 +1211,16 @@ class RealTrainingAdapter:
|
||||
logger.error(f"Error extracting timeframe data: {e}")
|
||||
return None
|
||||
|
||||
def _extract_next_candle(self, tf_data: Dict, reference_data = None) -> Optional[torch.Tensor]:
|
||||
def _extract_next_candle(self, tf_data: Dict, norm_params: Dict = None) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Extract the NEXT candle OHLCV (after the current sequence) as training target
|
||||
|
||||
This extracts the candle that comes immediately after the sequence used for input.
|
||||
Normalized using the same price range as the reference data for consistency.
|
||||
Normalized using the ORIGINAL price range (not the already-normalized sequence data).
|
||||
|
||||
Args:
|
||||
tf_data: Timeframe data dictionary with 'open', 'high', 'low', 'close', 'volume'
|
||||
reference_data: Optional reference OHLCV data [seq_len, 5] to get normalization params
|
||||
norm_params: Normalization parameters dict with 'price_min', 'price_max', 'volume_min', 'volume_max'
|
||||
|
||||
Returns:
|
||||
Tensor of shape [1, 5] representing next candle OHLCV, or None if not available
|
||||
@@ -1251,17 +1251,20 @@ class RealTrainingAdapter:
|
||||
# Create OHLCV array [5]
|
||||
next_candle = np.array([next_open, next_high, next_low, next_close, next_volume], dtype=np.float32)
|
||||
|
||||
# Normalize using reference data if provided
|
||||
if reference_data is not None:
|
||||
# Use same normalization as input sequence
|
||||
price_min = np.min(reference_data[:, :4])
|
||||
price_max = np.max(reference_data[:, :4])
|
||||
# CRITICAL FIX: Normalize using ORIGINAL price bounds, not already-normalized data
|
||||
# The bug was: reference_data was already normalized to [0,1], so its min/max
|
||||
# would be ~0 and ~1, which when used to normalize raw prices ($3000+) created
|
||||
# astronomically large values (e.g., $3000 / 1.0 = still $3000 in "normalized" space!)
|
||||
if norm_params is not None:
|
||||
# Use ORIGINAL normalization bounds
|
||||
price_min = norm_params.get('price_min', 0.0)
|
||||
price_max = norm_params.get('price_max', 1.0)
|
||||
|
||||
if price_max > price_min:
|
||||
next_candle[:4] = (next_candle[:4] - price_min) / (price_max - price_min)
|
||||
|
||||
volume_min = np.min(reference_data[:, 4])
|
||||
volume_max = np.max(reference_data[:, 4])
|
||||
volume_min = norm_params.get('volume_min', 0.0)
|
||||
volume_max = norm_params.get('volume_max', 1.0)
|
||||
|
||||
if volume_max > volume_min:
|
||||
next_candle[4] = (next_candle[4] - volume_min) / (volume_max - volume_min)
|
||||
@@ -1598,22 +1601,18 @@ class RealTrainingAdapter:
|
||||
future_candle_1d = None
|
||||
|
||||
# For each timeframe, extract the next candle if data is available
|
||||
# Use the input sequence data as reference for normalization
|
||||
if price_data_1s is not None and '1s' in timeframes:
|
||||
ref_data_1s = price_data_1s.squeeze(0).numpy() # [seq_len, 5]
|
||||
future_candle_1s = self._extract_next_candle(timeframes['1s'], ref_data_1s)
|
||||
# CRITICAL: Pass the ORIGINAL normalization parameters, not the normalized data!
|
||||
if price_data_1s is not None and '1s' in timeframes and '1s' in norm_params_dict:
|
||||
future_candle_1s = self._extract_next_candle(timeframes['1s'], norm_params_dict['1s'])
|
||||
|
||||
if price_data_1m is not None and '1m' in timeframes:
|
||||
ref_data_1m = price_data_1m.squeeze(0).numpy() # [seq_len, 5]
|
||||
future_candle_1m = self._extract_next_candle(timeframes['1m'], ref_data_1m)
|
||||
if price_data_1m is not None and '1m' in timeframes and '1m' in norm_params_dict:
|
||||
future_candle_1m = self._extract_next_candle(timeframes['1m'], norm_params_dict['1m'])
|
||||
|
||||
if price_data_1h is not None and '1h' in timeframes:
|
||||
ref_data_1h = price_data_1h.squeeze(0).numpy() # [seq_len, 5]
|
||||
future_candle_1h = self._extract_next_candle(timeframes['1h'], ref_data_1h)
|
||||
if price_data_1h is not None and '1h' in timeframes and '1h' in norm_params_dict:
|
||||
future_candle_1h = self._extract_next_candle(timeframes['1h'], norm_params_dict['1h'])
|
||||
|
||||
if price_data_1d is not None and '1d' in timeframes:
|
||||
ref_data_1d = price_data_1d.squeeze(0).numpy() # [seq_len, 5]
|
||||
future_candle_1d = self._extract_next_candle(timeframes['1d'], ref_data_1d)
|
||||
if price_data_1d is not None and '1d' in timeframes and '1d' in norm_params_dict:
|
||||
future_candle_1d = self._extract_next_candle(timeframes['1d'], norm_params_dict['1d'])
|
||||
|
||||
# Return batch dictionary with ALL timeframes
|
||||
batch = {
|
||||
|
||||
Reference in New Issue
Block a user