Training normalization fix
@@ -1197,12 +1197,90 @@ class RealTrainingAdapter:
                 ohlcv[:, 4] = (ohlcv[:, 4] - volume_min) / (volume_max - volume_min)
 
-            # Convert to tensor and add batch dimension [1, seq_len, 5]
-            return torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0)
+            # STORE normalization parameters for denormalization
+            # Return tuple: (tensor, normalization_params)
+            norm_params = {
+                'price_min': float(price_min),
+                'price_max': float(price_max),
+                'volume_min': float(volume_min),
+                'volume_max': float(volume_max)
+            }
+            return torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0), norm_params
 
         except Exception as e:
             logger.error(f"Error extracting timeframe data: {e}")
             return None
 
+    def _extract_next_candle(self, tf_data: Dict, reference_data=None) -> Optional[torch.Tensor]:
+        """
+        Extract the NEXT candle's OHLCV (the one after the current sequence) as the training target.
+
+        This extracts the candle that comes immediately after the sequence used as input,
+        normalized with the same price range as the reference data for consistency.
+
+        Args:
+            tf_data: Timeframe data dictionary with 'open', 'high', 'low', 'close', 'volume'
+            reference_data: Optional reference OHLCV data [seq_len, 5] supplying normalization params
+
+        Returns:
+            Tensor of shape [1, 5] representing the next candle's OHLCV, or None if not available
+        """
+        import torch
+        import numpy as np
+
+        try:
+            # Extract OHLCV arrays - take the LAST value as the "next" candle.
+            # In the annotation context the "current" sequence is historical,
+            # so the data already contains the "next" candle.
+            opens = np.array(tf_data.get('open', []), dtype=np.float32)
+            highs = np.array(tf_data.get('high', []), dtype=np.float32)
+            lows = np.array(tf_data.get('low', []), dtype=np.float32)
+            closes = np.array(tf_data.get('close', []), dtype=np.float32)
+            volumes = np.array(tf_data.get('volume', []), dtype=np.float32)
+
+            if len(closes) == 0:
+                return None
+
+            # Use the last candle as the "next" candle target.
+            # This assumes the timeframe data includes one extra candle after the sequence.
+            next_open = opens[-1]
+            next_high = highs[-1]
+            next_low = lows[-1]
+            next_close = closes[-1]
+            next_volume = volumes[-1]
+
+            # Create OHLCV array [5]
+            next_candle = np.array([next_open, next_high, next_low, next_close, next_volume], dtype=np.float32)
+
+            # Normalize using reference data if provided
+            if reference_data is not None:
+                # Use the same normalization as the input sequence
+                price_min = np.min(reference_data[:, :4])
+                price_max = np.max(reference_data[:, :4])
+
+                if price_max > price_min:
+                    next_candle[:4] = (next_candle[:4] - price_min) / (price_max - price_min)
+
+                volume_min = np.min(reference_data[:, 4])
+                volume_max = np.max(reference_data[:, 4])
+
+                if volume_max > volume_min:
+                    next_candle[4] = (next_candle[4] - volume_min) / (volume_max - volume_min)
+            else:
+                # With no reference, normalize prices relative to the candle's own close
+                if next_close > 0:
+                    next_candle[:4] = next_candle[:4] / next_close
+
+                # Volume min-maxed against itself collapses to 1.0
+                if next_volume > 0:
+                    next_candle[4] = 1.0
+
+            # Return as a [1, 5] tensor
+            return torch.tensor(next_candle, dtype=torch.float32).unsqueeze(0)
+
+        except Exception as e:
+            logger.error(f"Error extracting next candle: {e}")
+            return None
+
     def _convert_annotation_to_transformer_batch(self, training_sample: Dict) -> Dict[str, 'torch.Tensor']:
         """
         Convert annotation training sample to multi-timeframe transformer input
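The stored norm_params exist so that normalized model outputs can be mapped back to real prices later. A minimal sketch of that inverse transform, assuming a [1, 5] candle tensor; the helper name and call site are illustrative, not part of this commit:

    import torch

    def denormalize_candle(candle: torch.Tensor, norm_params: dict) -> torch.Tensor:
        # Invert the min-max normalization applied in _extract_timeframe_data.
        # candle: [1, 5] tensor, normalized OHLC in columns 0-3, volume in column 4.
        out = candle.clone()
        price_range = norm_params['price_max'] - norm_params['price_min']
        volume_range = norm_params['volume_max'] - norm_params['volume_min']
        out[:, :4] = out[:, :4] * price_range + norm_params['price_min']  # back to price units
        out[:, 4] = out[:, 4] * volume_range + norm_params['volume_min']  # back to raw volume
        return out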
@@ -1237,16 +1315,40 @@ class RealTrainingAdapter:
                     target_seq_len = min(len(tf_data['close']), 200)  # Cap at 200
                     break
 
-            # Extract each timeframe (returns None if not available)
-            price_data_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None
-            price_data_1m = self._extract_timeframe_data(timeframes.get('1m', {}), target_seq_len) if '1m' in timeframes else None
-            price_data_1h = self._extract_timeframe_data(timeframes.get('1h', {}), target_seq_len) if '1h' in timeframes else None
-            price_data_1d = self._extract_timeframe_data(timeframes.get('1d', {}), target_seq_len) if '1d' in timeframes else None
+            # Extract each timeframe (returns tuple: (tensor, norm_params) or None)
+            # Store normalization parameters for each timeframe
+            norm_params_dict = {}
+
+            result_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None
+            if result_1s:
+                price_data_1s, norm_params_dict['1s'] = result_1s
+            else:
+                price_data_1s = None
+
+            result_1m = self._extract_timeframe_data(timeframes.get('1m', {}), target_seq_len) if '1m' in timeframes else None
+            if result_1m:
+                price_data_1m, norm_params_dict['1m'] = result_1m
+            else:
+                price_data_1m = None
+
+            result_1h = self._extract_timeframe_data(timeframes.get('1h', {}), target_seq_len) if '1h' in timeframes else None
+            if result_1h:
+                price_data_1h, norm_params_dict['1h'] = result_1h
+            else:
+                price_data_1h = None
+
+            result_1d = self._extract_timeframe_data(timeframes.get('1d', {}), target_seq_len) if '1d' in timeframes else None
+            if result_1d:
+                price_data_1d, norm_params_dict['1d'] = result_1d
+            else:
+                price_data_1d = None
+
             # Extract BTC reference data
             btc_data_1m = None
             if 'BTC/USDT' in secondary_timeframes and '1m' in secondary_timeframes['BTC/USDT']:
-                btc_data_1m = self._extract_timeframe_data(secondary_timeframes['BTC/USDT']['1m'], target_seq_len)
+                result_btc = self._extract_timeframe_data(secondary_timeframes['BTC/USDT']['1m'], target_seq_len)
+                if result_btc:
+                    btc_data_1m, norm_params_dict['btc'] = result_btc
 
             # Ensure at least one timeframe is available
             # Check if all are None (can't use any() with tensors)
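The four near-identical unpacking blocks above could equally be written as one loop over the timeframe keys; a purely illustrative sketch, not part of this commit:

    price_data = {}
    for tf in ('1s', '1m', '1h', '1d'):
        result = self._extract_timeframe_data(timeframes.get(tf, {}), target_seq_len) if tf in timeframes else None
        if result:
            price_data[tf], norm_params_dict[tf] = result
        else:
            price_data[tf] = None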
@@ -1488,9 +1590,34 @@ class RealTrainingAdapter:
             # Create trend target tensor [batch, 3]: [angle, steepness, direction]
             trend_target = torch.tensor([[trend_angle, trend_steepness, trend_direction]], dtype=torch.float32)  # [1, 3]
 
+            # Extract NEXT candle OHLCV targets for each available timeframe
+            # These are the ground truth candles that the model should learn to predict
+            future_candle_1s = None
+            future_candle_1m = None
+            future_candle_1h = None
+            future_candle_1d = None
+
+            # For each timeframe, extract the next candle if data is available
+            # Use the input sequence data as reference for normalization
+            if price_data_1s is not None and '1s' in timeframes:
+                ref_data_1s = price_data_1s.squeeze(0).numpy()  # [seq_len, 5]
+                future_candle_1s = self._extract_next_candle(timeframes['1s'], ref_data_1s)
+
+            if price_data_1m is not None and '1m' in timeframes:
+                ref_data_1m = price_data_1m.squeeze(0).numpy()  # [seq_len, 5]
+                future_candle_1m = self._extract_next_candle(timeframes['1m'], ref_data_1m)
+
+            if price_data_1h is not None and '1h' in timeframes:
+                ref_data_1h = price_data_1h.squeeze(0).numpy()  # [seq_len, 5]
+                future_candle_1h = self._extract_next_candle(timeframes['1h'], ref_data_1h)
+
+            if price_data_1d is not None and '1d' in timeframes:
+                ref_data_1d = price_data_1d.squeeze(0).numpy()  # [seq_len, 5]
+                future_candle_1d = self._extract_next_candle(timeframes['1d'], ref_data_1d)
+
             # Return batch dictionary with ALL timeframes
             batch = {
-                # Multi-timeframe price data
+                # Multi-timeframe price data (INPUT)
                 'price_data_1s': price_data_1s,  # [1, 600, 5] or None
                 'price_data_1m': price_data_1m,  # [1, 600, 5] or None
                 'price_data_1h': price_data_1h,  # [1, 600, 5] or None
@@ -1503,11 +1630,20 @@ class RealTrainingAdapter:
                 'market_data': market_data,  # [1, 30]
                 'position_state': position_state,  # [1, 5]
 
-                # Training targets
+                # Training targets - Actions and prices
                 'actions': actions,  # [1]
                 'future_prices': future_prices,  # [1, 1]
                 'trade_success': trade_success,  # [1, 1]
-                'trend_target': trend_target,  # [1, 3] - NEW: [angle, steepness, direction]
+                'trend_target': trend_target,  # [1, 3] - [angle, steepness, direction]
+
+                # Training targets - Next candle OHLCV for each timeframe
+                'future_candle_1s': future_candle_1s,  # [1, 5] or None
+                'future_candle_1m': future_candle_1m,  # [1, 5] or None
+                'future_candle_1h': future_candle_1h,  # [1, 5] or None
+                'future_candle_1d': future_candle_1d,  # [1, 5] or None
+
+                # CRITICAL: Normalization parameters for denormalization
+                'norm_params': norm_params_dict,  # Dict with keys: '1s', '1m', '1h', '1d', 'btc'
 
                 # Legacy support (use 1m as default)
                 'price_data': price_data_1m if price_data_1m is not None else ref_data
@@ -1713,25 +1849,27 @@ class RealTrainingAdapter:
                     batch_accuracy = result.get('accuracy', 0.0)
                     batch_candle_accuracy = result.get('candle_accuracy', 0.0)
                     batch_trend_loss = result.get('trend_loss', 0.0)
                     batch_candle_loss = result.get('candle_loss', 0.0)
+                    batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
                     epoch_loss += batch_loss
                     epoch_accuracy += batch_accuracy
                     num_batches += 1
 
                     # Log first batch and every 5th batch for debugging
                     if (i + 1) == 1 or (i + 1) % 5 == 0:
-                        logger.info(f"  Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}")
+                        # Format denormalized losses if available
+                        denorm_str = ""
+                        if batch_candle_loss_denorm:
+                            denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
+                            denorm_str = f", Real Price Error: {', '.join(denorm_values)}"
+
+                        logger.info(f"  Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
                 else:
                     logger.warning(f"  Batch {i + 1} returned None result - skipping")
 
-                # CRITICAL FIX: Delete batch tensors immediately to free GPU memory
-                # This prevents memory accumulation during gradient accumulation
-                for key in list(batch.keys()):
-                    if isinstance(batch[key], torch.Tensor):
-                        del batch[key]
-                del batch
-
                 # CRITICAL: Clear CUDA cache after EVERY batch to prevent memory accumulation
                 # This is essential with large models and limited GPU memory
+                # NOTE: We do NOT delete batch tensors here because they are reused across epochs
+                # Deleting them would cause "At least one timeframe must be provided" error on epoch 2+
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
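The "Real Price Error" values presumably come from scaling the normalized candle loss back through the stored price range; the trainer-side computation is not shown in this diff, but a hedged sketch of the conversion:

    def denormalized_price_error(candle_loss_norm: float, norm_params: dict) -> float:
        # An absolute error in normalized [0, 1] price space, multiplied by the
        # price range, gives the error in real price units (dollars for USDT pairs).
        return candle_loss_norm * (norm_params['price_max'] - norm_params['price_min'])

    # e.g. a normalized loss of 0.002 over a $100 price range -> $0.20 real price error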
@@ -1801,6 +1939,22 @@ class RealTrainingAdapter:
             session.final_loss = session.current_loss
             session.accuracy = avg_accuracy
 
+            # Cleanup: Delete batch tensors after all epochs are complete
+            logger.info("  Cleaning up batch data...")
+            for batch in grouped_batches:
+                for key in list(batch.keys()):
+                    if isinstance(batch[key], torch.Tensor):
+                        del batch[key]
+                batch.clear()
+            grouped_batches.clear()
+            converted_batches.clear()
+
+            # Final memory cleanup
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            import gc
+            gc.collect()
+
             # Log best checkpoint info
             try:
                 checkpoint_dir = "models/checkpoints/transformer"
@@ -69,29 +69,6 @@
         "exit_state": {}
       }
     },
-    {
-      "annotation_id": "c9849a6b-e430-4305-9009-dd7471553c2f",
-      "symbol": "ETH/USDT",
-      "timeframe": "1m",
-      "entry": {
-        "timestamp": "2025-10-30 19:59",
-        "price": 3680.1,
-        "index": 272
-      },
-      "exit": {
-        "timestamp": "2025-10-30 21:59",
-        "price": 3767.82,
-        "index": 312
-      },
-      "direction": "LONG",
-      "profit_loss_pct": 2.38363087959567,
-      "notes": "",
-      "created_at": "2025-10-31T00:31:10.201966",
-      "market_context": {
-        "entry_state": {},
-        "exit_state": {}
-      }
-    },
     {
       "annotation_id": "479eb310-c963-4837-b712-70e5a42afb53",
       "symbol": "ETH/USDT",
@@ -114,10 +91,56 @@
         "entry_state": {},
         "exit_state": {}
       }
     },
+    {
+      "annotation_id": "6b529132-8a3e-488d-b354-db8785ddaa71",
+      "symbol": "ETH/USDT",
+      "timeframe": "1m",
+      "entry": {
+        "timestamp": "2025-11-11 12:07",
+        "price": 3594.33,
+        "index": 144
+      },
+      "exit": {
+        "timestamp": "2025-11-11 20:46",
+        "price": 3429.24,
+        "index": 329
+      },
+      "direction": "SHORT",
+      "profit_loss_pct": 4.593067414511193,
+      "notes": "",
+      "created_at": "2025-11-11T23:23:00.643510",
+      "market_context": {
+        "entry_state": {},
+        "exit_state": {}
+      }
+    },
+    {
+      "annotation_id": "bbafc50c-f885-4dbc-b0cb-fdfb48223b5c",
+      "symbol": "ETH/USDT",
+      "timeframe": "1m",
+      "entry": {
+        "timestamp": "2025-11-12 07:58",
+        "price": 3424.58,
+        "index": 284
+      },
+      "exit": {
+        "timestamp": "2025-11-12 11:08",
+        "price": 3546.35,
+        "index": 329
+      },
+      "direction": "LONG",
+      "profit_loss_pct": 3.5557645025083366,
+      "notes": "",
+      "created_at": "2025-11-12T13:11:31.267142",
+      "market_context": {
+        "entry_state": {},
+        "exit_state": {}
+      }
+    }
   ],
   "metadata": {
-    "total_annotations": 5,
-    "last_updated": "2025-10-31T00:35:00.549074"
+    "total_annotations": 6,
+    "last_updated": "2025-11-12T13:11:31.267456"
   }
 }