train feedback - norm fix

Dobromir Popov
2025-11-12 15:20:42 +02:00
parent 352dc9cbeb
commit 7cb4201bc0


@@ -1211,16 +1211,16 @@ class RealTrainingAdapter:
logger.error(f"Error extracting timeframe data: {e}")
return None
def _extract_next_candle(self, tf_data: Dict, reference_data = None) -> Optional[torch.Tensor]:
def _extract_next_candle(self, tf_data: Dict, norm_params: Dict = None) -> Optional[torch.Tensor]:
"""
Extract the NEXT candle OHLCV (after the current sequence) as training target
This extracts the candle that comes immediately after the sequence used for input.
Normalized using the same price range as the reference data for consistency.
Normalized using the ORIGINAL price range (not the already-normalized sequence data).
Args:
tf_data: Timeframe data dictionary with 'open', 'high', 'low', 'close', 'volume'
reference_data: Optional reference OHLCV data [seq_len, 5] to get normalization params
norm_params: Normalization parameters dict with 'price_min', 'price_max', 'volume_min', 'volume_max'
Returns:
Tensor of shape [1, 5] representing next candle OHLCV, or None if not available
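
Note: a minimal, self-contained sketch of the contract described in the updated docstring, i.e. the keys expected in norm_params and the [1, 5] target shape. All numbers are invented for illustration and do not come from this repository.

# Illustrative only: apply the documented min-max scaling by hand.
import numpy as np
import torch

norm_params = {'price_min': 2990.0, 'price_max': 3025.0,   # hypothetical raw bounds
               'volume_min': 10.0, 'volume_max': 150.0}

next_ohlcv = np.array([[3028.0, 3032.0, 3020.0, 3030.0, 42.0]], dtype=np.float32)
price_range = norm_params['price_max'] - norm_params['price_min']
volume_range = norm_params['volume_max'] - norm_params['volume_min']
next_ohlcv[:, :4] = (next_ohlcv[:, :4] - norm_params['price_min']) / price_range
next_ohlcv[:, 4] = (next_ohlcv[:, 4] - norm_params['volume_min']) / volume_range

target = torch.from_numpy(next_ohlcv)
assert target.shape == (1, 5)   # matches the documented return shape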
@@ -1251,17 +1251,20 @@ class RealTrainingAdapter:
             # Create OHLCV array [5]
             next_candle = np.array([next_open, next_high, next_low, next_close, next_volume], dtype=np.float32)

-            # Normalize using reference data if provided
-            if reference_data is not None:
-                # Use same normalization as input sequence
-                price_min = np.min(reference_data[:, :4])
-                price_max = np.max(reference_data[:, :4])
+            # CRITICAL FIX: Normalize using ORIGINAL price bounds, not already-normalized data
+            # The bug was: reference_data was already normalized to [0,1], so its min/max
+            # would be ~0 and ~1, which when used to normalize raw prices ($3000+) created
+            # astronomically large values (e.g., $3000 / 1.0 = still $3000 in "normalized" space!)
+            if norm_params is not None:
+                # Use ORIGINAL normalization bounds
+                price_min = norm_params.get('price_min', 0.0)
+                price_max = norm_params.get('price_max', 1.0)

                 if price_max > price_min:
                     next_candle[:4] = (next_candle[:4] - price_min) / (price_max - price_min)

-                volume_min = np.min(reference_data[:, 4])
-                volume_max = np.max(reference_data[:, 4])
+                volume_min = norm_params.get('volume_min', 0.0)
+                volume_max = norm_params.get('volume_max', 1.0)

                 if volume_max > volume_min:
                     next_candle[4] = (next_candle[4] - volume_min) / (volume_max - volume_min)
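
Note: the arithmetic behind the comment above, as a standalone sketch with made-up prices (not taken from any training data). Bounds taken from the already-normalized sequence are roughly 0 and 1, so the raw-dollar target stays essentially unscaled, while the original bounds keep it on the same scale as the inputs.

import numpy as np

raw_close = np.array([3000.0, 3010.0, 3025.0, 2990.0], dtype=np.float32)  # hypothetical window
next_close = 3030.0

# Correct path: bounds from the RAW window, i.e. what norm_params now carries.
price_min, price_max = float(raw_close.min()), float(raw_close.max())
good = (next_close - price_min) / (price_max - price_min)        # ~1.14, same scale as inputs

# Buggy path: bounds from the already-normalized sequence (min ~0, max ~1),
# so the "normalized" target is still ~3030 and the loss explodes.
norm_close = (raw_close - price_min) / (price_max - price_min)
bad = (next_close - float(norm_close.min())) / (float(norm_close.max()) - float(norm_close.min()))

print(good, bad)   # ~1.14 vs 3030.0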
@@ -1598,22 +1601,18 @@ class RealTrainingAdapter:
             future_candle_1d = None

             # For each timeframe, extract the next candle if data is available
-            # Use the input sequence data as reference for normalization
-            if price_data_1s is not None and '1s' in timeframes:
-                ref_data_1s = price_data_1s.squeeze(0).numpy() # [seq_len, 5]
-                future_candle_1s = self._extract_next_candle(timeframes['1s'], ref_data_1s)
+            # CRITICAL: Pass the ORIGINAL normalization parameters, not the normalized data!
+            if price_data_1s is not None and '1s' in timeframes and '1s' in norm_params_dict:
+                future_candle_1s = self._extract_next_candle(timeframes['1s'], norm_params_dict['1s'])

-            if price_data_1m is not None and '1m' in timeframes:
-                ref_data_1m = price_data_1m.squeeze(0).numpy() # [seq_len, 5]
-                future_candle_1m = self._extract_next_candle(timeframes['1m'], ref_data_1m)
+            if price_data_1m is not None and '1m' in timeframes and '1m' in norm_params_dict:
+                future_candle_1m = self._extract_next_candle(timeframes['1m'], norm_params_dict['1m'])

-            if price_data_1h is not None and '1h' in timeframes:
-                ref_data_1h = price_data_1h.squeeze(0).numpy() # [seq_len, 5]
-                future_candle_1h = self._extract_next_candle(timeframes['1h'], ref_data_1h)
+            if price_data_1h is not None and '1h' in timeframes and '1h' in norm_params_dict:
+                future_candle_1h = self._extract_next_candle(timeframes['1h'], norm_params_dict['1h'])

-            if price_data_1d is not None and '1d' in timeframes:
-                ref_data_1d = price_data_1d.squeeze(0).numpy() # [seq_len, 5]
-                future_candle_1d = self._extract_next_candle(timeframes['1d'], ref_data_1d)
+            if price_data_1d is not None and '1d' in timeframes and '1d' in norm_params_dict:
+                future_candle_1d = self._extract_next_candle(timeframes['1d'], norm_params_dict['1d'])

             # Return batch dictionary with ALL timeframes
             batch = {
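
Note: the new branches assume a norm_params_dict keyed by timeframe is available earlier in this method; that code is not shown in this diff. A hypothetical sketch of how such a dict could be built from the raw candles before the input tensors are scaled (the actual helper in this repo may differ):

import numpy as np

def build_norm_params(raw_ohlcv: np.ndarray) -> dict:
    """raw_ohlcv: [seq_len, 5] array of open, high, low, close, volume (unscaled)."""
    return {
        'price_min': float(np.min(raw_ohlcv[:, :4])),
        'price_max': float(np.max(raw_ohlcv[:, :4])),
        'volume_min': float(np.min(raw_ohlcv[:, 4])),
        'volume_max': float(np.max(raw_ohlcv[:, 4])),
    }

# e.g. norm_params_dict = {tf: build_norm_params(raw_candles[tf]) for tf in ('1s', '1m', '1h', '1d') if tf in raw_candles}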