Training normalization fix
@@ -1197,12 +1197,90 @@ class RealTrainingAdapter:
                 ohlcv[:, 4] = (ohlcv[:, 4] - volume_min) / (volume_max - volume_min)

             # Convert to tensor and add batch dimension [1, seq_len, 5]
-            return torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0)
+            # STORE normalization parameters for denormalization
+            # Return tuple: (tensor, normalization_params)
+            norm_params = {
+                'price_min': float(price_min),
+                'price_max': float(price_max),
+                'volume_min': float(volume_min),
+                'volume_max': float(volume_max)
+            }
+            return torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0), norm_params

         except Exception as e:
             logger.error(f"Error extracting timeframe data: {e}")
             return None

+    def _extract_next_candle(self, tf_data: Dict, reference_data=None) -> Optional[torch.Tensor]:
+        """
+        Extract the NEXT candle OHLCV (after the current sequence) as training target
+
+        This extracts the candle that comes immediately after the sequence used for input.
+        Normalized using the same price range as the reference data for consistency.
+
+        Args:
+            tf_data: Timeframe data dictionary with 'open', 'high', 'low', 'close', 'volume'
+            reference_data: Optional reference OHLCV data [seq_len, 5] to get normalization params
+
+        Returns:
+            Tensor of shape [1, 5] representing next candle OHLCV, or None if not available
+        """
+        import torch
+        import numpy as np
+
+        try:
+            # Extract OHLCV arrays - get the LAST value as the "next" candle
+            # In annotation context, the "current" sequence is historical, and we have the "next" candle
+            opens = np.array(tf_data.get('open', []), dtype=np.float32)
+            highs = np.array(tf_data.get('high', []), dtype=np.float32)
+            lows = np.array(tf_data.get('low', []), dtype=np.float32)
+            closes = np.array(tf_data.get('close', []), dtype=np.float32)
+            volumes = np.array(tf_data.get('volume', []), dtype=np.float32)
+
+            if len(closes) == 0:
+                return None
+
+            # Get the last candle as the "next" candle target
+            # This assumes the timeframe data includes one extra candle after the sequence
+            next_open = opens[-1]
+            next_high = highs[-1]
+            next_low = lows[-1]
+            next_close = closes[-1]
+            next_volume = volumes[-1]
+
+            # Create OHLCV array [5]
+            next_candle = np.array([next_open, next_high, next_low, next_close, next_volume], dtype=np.float32)
+
+            # Normalize using reference data if provided
+            if reference_data is not None:
+                # Use same normalization as input sequence
+                price_min = np.min(reference_data[:, :4])
+                price_max = np.max(reference_data[:, :4])
+
+                if price_max > price_min:
+                    next_candle[:4] = (next_candle[:4] - price_min) / (price_max - price_min)
+
+                volume_min = np.min(reference_data[:, 4])
+                volume_max = np.max(reference_data[:, 4])
+
+                if volume_max > volume_min:
+                    next_candle[4] = (next_candle[4] - volume_min) / (volume_max - volume_min)
+            else:
+                # If no reference, normalize relative to current candle's close
+                if next_close > 0:
+                    next_candle[:4] = next_candle[:4] / next_close
+
+                # Volume normalized to 0-1 range (simple min-max with self)
+                if next_volume > 0:
+                    next_candle[4] = 1.0
+
+            # Return as [1, 5] tensor
+            return torch.tensor(next_candle, dtype=torch.float32).unsqueeze(0)
+
+        except Exception as e:
+            logger.error(f"Error extracting next candle: {e}")
+            return None
+
     def _convert_annotation_to_transformer_batch(self, training_sample: Dict) -> Dict[str, 'torch.Tensor']:
         """
         Convert annotation training sample to multi-timeframe transformer input
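Why these parameters must be stored: the per-sequence min-max scaling used above is only invertible if the min/max are kept alongside the tensor. A minimal round-trip sketch (illustrative values, not project code):

```python
# Minimal round-trip for the min-max scaling used in the hunk above.
# The price values are illustrative only.
import numpy as np

prices = np.array([2480.0, 2495.5, 2520.0], dtype=np.float32)
price_min, price_max = float(prices.min()), float(prices.max())

normalized = (prices - price_min) / (price_max - price_min)   # maps into [0, 1]
recovered = normalized * (price_max - price_min) + price_min  # back to price space

assert np.allclose(recovered, prices)  # inversion works only because min/max were kept
```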
@@ -1237,16 +1315,40 @@ class RealTrainingAdapter:
                 target_seq_len = min(len(tf_data['close']), 200)  # Cap at 200
                 break

-        # Extract each timeframe (returns None if not available)
-        price_data_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None
-        price_data_1m = self._extract_timeframe_data(timeframes.get('1m', {}), target_seq_len) if '1m' in timeframes else None
-        price_data_1h = self._extract_timeframe_data(timeframes.get('1h', {}), target_seq_len) if '1h' in timeframes else None
-        price_data_1d = self._extract_timeframe_data(timeframes.get('1d', {}), target_seq_len) if '1d' in timeframes else None
+        # Extract each timeframe (returns tuple: (tensor, norm_params) or None)
+        # Store normalization parameters for each timeframe
+        norm_params_dict = {}
+
+        result_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None
+        if result_1s:
+            price_data_1s, norm_params_dict['1s'] = result_1s
+        else:
+            price_data_1s = None
+
+        result_1m = self._extract_timeframe_data(timeframes.get('1m', {}), target_seq_len) if '1m' in timeframes else None
+        if result_1m:
+            price_data_1m, norm_params_dict['1m'] = result_1m
+        else:
+            price_data_1m = None
+
+        result_1h = self._extract_timeframe_data(timeframes.get('1h', {}), target_seq_len) if '1h' in timeframes else None
+        if result_1h:
+            price_data_1h, norm_params_dict['1h'] = result_1h
+        else:
+            price_data_1h = None
+
+        result_1d = self._extract_timeframe_data(timeframes.get('1d', {}), target_seq_len) if '1d' in timeframes else None
+        if result_1d:
+            price_data_1d, norm_params_dict['1d'] = result_1d
+        else:
+            price_data_1d = None

         # Extract BTC reference data
         btc_data_1m = None
         if 'BTC/USDT' in secondary_timeframes and '1m' in secondary_timeframes['BTC/USDT']:
-            btc_data_1m = self._extract_timeframe_data(secondary_timeframes['BTC/USDT']['1m'], target_seq_len)
+            result_btc = self._extract_timeframe_data(secondary_timeframes['BTC/USDT']['1m'], target_seq_len)
+            if result_btc:
+                btc_data_1m, norm_params_dict['btc'] = result_btc

         # Ensure at least one timeframe is available
         # Check if all are None (can't use any() with tensors)
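The four unpacking blocks above are structurally identical; a hypothetical helper that collapses them into one loop is sketched below. The function name and return shape are illustrative, not part of this commit:

```python
from typing import Callable, Dict, Optional, Tuple

def unpack_timeframes(timeframes: Dict, target_seq_len: int,
                      extract: Callable) -> Tuple[Dict, Dict]:
    """Hypothetical helper collapsing the repeated result_1s/1m/1h/1d blocks.

    `extract` is expected to behave like _extract_timeframe_data above,
    returning (tensor, norm_params) or None.
    """
    price_data: Dict[str, Optional[object]] = {}
    norm_params_dict: Dict[str, Dict[str, float]] = {}
    for tf in ('1s', '1m', '1h', '1d'):
        result = extract(timeframes.get(tf, {}), target_seq_len) if tf in timeframes else None
        if result:
            price_data[tf], norm_params_dict[tf] = result
        else:
            price_data[tf] = None
    return price_data, norm_params_dict
```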
@@ -1488,9 +1590,34 @@ class RealTrainingAdapter:
         # Create trend target tensor [batch, 3]: [angle, steepness, direction]
         trend_target = torch.tensor([[trend_angle, trend_steepness, trend_direction]], dtype=torch.float32)  # [1, 3]

+        # Extract NEXT candle OHLCV targets for each available timeframe
+        # These are the ground truth candles that the model should learn to predict
+        future_candle_1s = None
+        future_candle_1m = None
+        future_candle_1h = None
+        future_candle_1d = None
+
+        # For each timeframe, extract the next candle if data is available
+        # Use the input sequence data as reference for normalization
+        if price_data_1s is not None and '1s' in timeframes:
+            ref_data_1s = price_data_1s.squeeze(0).numpy()  # [seq_len, 5]
+            future_candle_1s = self._extract_next_candle(timeframes['1s'], ref_data_1s)
+
+        if price_data_1m is not None and '1m' in timeframes:
+            ref_data_1m = price_data_1m.squeeze(0).numpy()  # [seq_len, 5]
+            future_candle_1m = self._extract_next_candle(timeframes['1m'], ref_data_1m)
+
+        if price_data_1h is not None and '1h' in timeframes:
+            ref_data_1h = price_data_1h.squeeze(0).numpy()  # [seq_len, 5]
+            future_candle_1h = self._extract_next_candle(timeframes['1h'], ref_data_1h)
+
+        if price_data_1d is not None and '1d' in timeframes:
+            ref_data_1d = price_data_1d.squeeze(0).numpy()  # [seq_len, 5]
+            future_candle_1d = self._extract_next_candle(timeframes['1d'], ref_data_1d)
+
         # Return batch dictionary with ALL timeframes
         batch = {
-            # Multi-timeframe price data
+            # Multi-timeframe price data (INPUT)
             'price_data_1s': price_data_1s,  # [1, 600, 5] or None
             'price_data_1m': price_data_1m,  # [1, 600, 5] or None
             'price_data_1h': price_data_1h,  # [1, 600, 5] or None
@@ -1503,11 +1630,20 @@ class RealTrainingAdapter:
             'market_data': market_data,  # [1, 30]
             'position_state': position_state,  # [1, 5]

-            # Training targets
+            # Training targets - Actions and prices
             'actions': actions,  # [1]
             'future_prices': future_prices,  # [1, 1]
             'trade_success': trade_success,  # [1, 1]
-            'trend_target': trend_target,  # [1, 3] - NEW: [angle, steepness, direction]
+            'trend_target': trend_target,  # [1, 3] - [angle, steepness, direction]
+
+            # Training targets - Next candle OHLCV for each timeframe
+            'future_candle_1s': future_candle_1s,  # [1, 5] or None
+            'future_candle_1m': future_candle_1m,  # [1, 5] or None
+            'future_candle_1h': future_candle_1h,  # [1, 5] or None
+            'future_candle_1d': future_candle_1d,  # [1, 5] or None
+
+            # CRITICAL: Normalization parameters for denormalization
+            'norm_params': norm_params_dict,  # Dict with keys: '1s', '1m', '1h', '1d', 'btc'

             # Legacy support (use 1m as default)
             'price_data': price_data_1m if price_data_1m is not None else ref_data
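For reference, the new `norm_params` entry in the batch ends up shaped like this (illustrative values; the real dict is filled by `_extract_timeframe_data`):

```python
# Illustrative contents of the new batch field (values are made up):
example_norm_params = {
    '1m': {
        'price_min': 2480.0, 'price_max': 2520.0,    # min/max of OHLC over the input window
        'volume_min': 100.0, 'volume_max': 10000.0,  # min/max of volume over the input window
    },
    # one entry per available timeframe ('1s', '1h', '1d'), plus 'btc'
}
```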
@@ -1713,25 +1849,27 @@ class RealTrainingAdapter:
                     batch_accuracy = result.get('accuracy', 0.0)
                     batch_candle_accuracy = result.get('candle_accuracy', 0.0)
                     batch_trend_loss = result.get('trend_loss', 0.0)
+                    batch_candle_loss = result.get('candle_loss', 0.0)
+                    batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
                     epoch_loss += batch_loss
                     epoch_accuracy += batch_accuracy
                     num_batches += 1

                     # Log first batch and every 5th batch for debugging
                     if (i + 1) == 1 or (i + 1) % 5 == 0:
-                        logger.info(f"  Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}")
+                        # Format denormalized losses if available
+                        denorm_str = ""
+                        if batch_candle_loss_denorm:
+                            denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
+                            denorm_str = f", Real Price Error: {', '.join(denorm_values)}"
+
+                        logger.info(f"  Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
                 else:
                     logger.warning(f"  Batch {i + 1} returned None result - skipping")

-                # CRITICAL FIX: Delete batch tensors immediately to free GPU memory
-                # This prevents memory accumulation during gradient accumulation
-                for key in list(batch.keys()):
-                    if isinstance(batch[key], torch.Tensor):
-                        del batch[key]
-                del batch
-
                 # CRITICAL: Clear CUDA cache after EVERY batch to prevent memory accumulation
-                # This is essential with large models and limited GPU memory
+                # NOTE: We do NOT delete batch tensors here because they are reused across epochs
+                # Deleting them would cause "At least one timeframe must be provided" error on epoch 2+
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
@@ -1801,6 +1939,22 @@ class RealTrainingAdapter:
             session.final_loss = session.current_loss
             session.accuracy = avg_accuracy

+            # Cleanup: Delete batch tensors after all epochs are complete
+            logger.info("  Cleaning up batch data...")
+            for batch in grouped_batches:
+                for key in list(batch.keys()):
+                    if isinstance(batch[key], torch.Tensor):
+                        del batch[key]
+                batch.clear()
+            grouped_batches.clear()
+            converted_batches.clear()
+
+            # Final memory cleanup
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            import gc
+            gc.collect()
+
             # Log best checkpoint info
             try:
                 checkpoint_dir = "models/checkpoints/transformer"
@@ -69,29 +69,6 @@
                 "exit_state": {}
             }
         },
-        {
-            "annotation_id": "c9849a6b-e430-4305-9009-dd7471553c2f",
-            "symbol": "ETH/USDT",
-            "timeframe": "1m",
-            "entry": {
-                "timestamp": "2025-10-30 19:59",
-                "price": 3680.1,
-                "index": 272
-            },
-            "exit": {
-                "timestamp": "2025-10-30 21:59",
-                "price": 3767.82,
-                "index": 312
-            },
-            "direction": "LONG",
-            "profit_loss_pct": 2.38363087959567,
-            "notes": "",
-            "created_at": "2025-10-31T00:31:10.201966",
-            "market_context": {
-                "entry_state": {},
-                "exit_state": {}
-            }
-        },
         {
             "annotation_id": "479eb310-c963-4837-b712-70e5a42afb53",
             "symbol": "ETH/USDT",
@@ -114,10 +91,56 @@
                 "entry_state": {},
                 "exit_state": {}
             }
+        },
+        {
+            "annotation_id": "6b529132-8a3e-488d-b354-db8785ddaa71",
+            "symbol": "ETH/USDT",
+            "timeframe": "1m",
+            "entry": {
+                "timestamp": "2025-11-11 12:07",
+                "price": 3594.33,
+                "index": 144
+            },
+            "exit": {
+                "timestamp": "2025-11-11 20:46",
+                "price": 3429.24,
+                "index": 329
+            },
+            "direction": "SHORT",
+            "profit_loss_pct": 4.593067414511193,
+            "notes": "",
+            "created_at": "2025-11-11T23:23:00.643510",
+            "market_context": {
+                "entry_state": {},
+                "exit_state": {}
+            }
+        },
+        {
+            "annotation_id": "bbafc50c-f885-4dbc-b0cb-fdfb48223b5c",
+            "symbol": "ETH/USDT",
+            "timeframe": "1m",
+            "entry": {
+                "timestamp": "2025-11-12 07:58",
+                "price": 3424.58,
+                "index": 284
+            },
+            "exit": {
+                "timestamp": "2025-11-12 11:08",
+                "price": 3546.35,
+                "index": 329
+            },
+            "direction": "LONG",
+            "profit_loss_pct": 3.5557645025083366,
+            "notes": "",
+            "created_at": "2025-11-12T13:11:31.267142",
+            "market_context": {
+                "entry_state": {},
+                "exit_state": {}
+            }
         }
     ],
     "metadata": {
-        "total_annotations": 5,
-        "last_updated": "2025-10-31T00:35:00.549074"
+        "total_annotations": 6,
+        "last_updated": "2025-11-12T13:11:31.267456"
     }
 }
@@ -443,6 +443,8 @@ class AdvancedTradingTransformer(nn.Module):
         self.uncertainty_estimator = UncertaintyEstimation(config.d_model)

         # Enhanced price prediction head (auxiliary task)
+        # Predicts price change ratio (future_price - current_price) / current_price
+        # Use Tanh to constrain to [-1, 1] range (max 100% change up/down)
         self.price_head = nn.Sequential(
             nn.Linear(config.d_model, config.d_model // 2),
             nn.GELU(),
@@ -450,7 +452,8 @@ class AdvancedTradingTransformer(nn.Module):
             nn.Linear(config.d_model // 2, config.d_model // 4),
             nn.GELU(),
             nn.Dropout(config.dropout),
-            nn.Linear(config.d_model // 4, 1)
+            nn.Linear(config.d_model // 4, 1),
+            nn.Tanh()  # Constrain to [-1, 1] range for price change ratio
         )

         # Additional specialized heads for 46M model
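With `Tanh()` appended, `price_prediction` is a change ratio in [-1, 1], so recovering an absolute price is one multiplication. A small sketch using the ratio definition from the comment above (illustrative numbers):

```python
import torch

current_price = 2500.0
ratio = torch.tensor([0.012])                    # model output, in [-1, 1] after Tanh
predicted_price = current_price * (1.0 + ratio)  # ratio = (future - current) / current
print(predicted_price)                           # tensor([2530.])
```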
@@ -473,6 +476,7 @@ class AdvancedTradingTransformer(nn.Module):
         # NEW: Next candle OHLCV prediction heads for each timeframe (1s, 1m, 1h, 1d)
         # Each timeframe predicts: [open, high, low, close, volume] = 5 values
         # Note: self.timeframes already defined above in input projections
+        # CRITICAL: Outputs are constrained to [0, 1] range using Sigmoid since inputs are normalized
         self.next_candle_heads = nn.ModuleDict({
             tf: nn.Sequential(
                 nn.Linear(config.d_model, config.d_model // 2),
@@ -481,11 +485,13 @@ class AdvancedTradingTransformer(nn.Module):
                 nn.Linear(config.d_model // 2, config.d_model // 4),
                 nn.GELU(),
                 nn.Dropout(config.dropout),
-                nn.Linear(config.d_model // 4, 5)  # OHLCV: [open, high, low, close, volume]
+                nn.Linear(config.d_model // 4, 5),  # OHLCV: [open, high, low, close, volume]
+                nn.Sigmoid()  # Constrain to [0, 1] to match normalized input range
             ) for tf in self.timeframes
         })

         # BTC next candle prediction head
+        # CRITICAL: Outputs are constrained to [0, 1] range using Sigmoid since inputs are normalized
         self.btc_next_candle_head = nn.Sequential(
             nn.Linear(config.d_model, config.d_model // 2),
             nn.GELU(),
@@ -493,7 +499,8 @@ class AdvancedTradingTransformer(nn.Module):
             nn.Linear(config.d_model // 2, config.d_model // 4),
             nn.GELU(),
             nn.Dropout(config.dropout),
-            nn.Linear(config.d_model // 4, 5)  # OHLCV for BTC
+            nn.Linear(config.d_model // 4, 5),  # OHLCV for BTC
+            nn.Sigmoid()  # Constrain to [0, 1] to match normalized input range
         )

         # NEW: Next pivot point prediction heads for L1-L5 levels
@@ -1153,6 +1160,54 @@ class TradingTransformerTrainer:
             'learning_rates': []
         }

+    @staticmethod
+    def denormalize_prices(normalized_values: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
+        """
+        Denormalize price predictions back to real price space
+
+        Args:
+            normalized_values: Tensor of normalized values in [0, 1] range
+            norm_params: Dict with 'price_min' and 'price_max' keys
+
+        Returns:
+            Denormalized tensor in original price space
+        """
+        price_min = norm_params.get('price_min', 0.0)
+        price_max = norm_params.get('price_max', 1.0)
+
+        if price_max > price_min:
+            return normalized_values * (price_max - price_min) + price_min
+        else:
+            return normalized_values
+
+    @staticmethod
+    def denormalize_candle(normalized_candle: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
+        """
+        Denormalize a full OHLCV candle back to real values
+
+        Args:
+            normalized_candle: Tensor of shape [..., 5] with normalized OHLCV
+            norm_params: Dict with normalization parameters
+
+        Returns:
+            Denormalized OHLCV tensor
+        """
+        denorm = normalized_candle.clone()
+
+        # Denormalize OHLC (first 4 values)
+        price_min = norm_params.get('price_min', 0.0)
+        price_max = norm_params.get('price_max', 1.0)
+        if price_max > price_min:
+            denorm[..., :4] = denorm[..., :4] * (price_max - price_min) + price_min
+
+        # Denormalize volume (5th value)
+        volume_min = norm_params.get('volume_min', 0.0)
+        volume_max = norm_params.get('volume_max', 1.0)
+        if volume_max > volume_min:
+            denorm[..., 4] = denorm[..., 4] * (volume_max - volume_min) + volume_min
+
+        return denorm
+
     def train_step(self, batch: Dict[str, torch.Tensor], accumulate_gradients: bool = False) -> Dict[str, float]:
         """Single training step with optional gradient accumulation
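A usage sketch for the new static methods (illustrative tensor and parameters; assumes `TradingTransformerTrainer` is importable from this module):

```python
import torch
# Assumes TradingTransformerTrainer with the static methods added above.

norm_params = {'price_min': 2480.0, 'price_max': 2520.0,
               'volume_min': 100.0, 'volume_max': 10000.0}

# A normalized close of 0.55 maps back to 2480 + 0.55 * 40 = 2502.0
close = torch.tensor([0.55])
print(TradingTransformerTrainer.denormalize_prices(close, norm_params))  # tensor([2502.])
```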
@@ -1217,8 +1272,50 @@ class TradingTransformerTrainer:
             trend_loss = self.price_criterion(trend_pred, trend_target)
             logger.debug(f"Trend loss: {trend_loss.item():.6f} (pred={trend_pred[0].tolist()}, target={trend_target[0].tolist()})")

+        # NEW: Next candle prediction loss for each timeframe
+        # This trains the model to predict full OHLCV for the next candle on each timeframe
+        candle_loss = torch.tensor(0.0, device=self.device)
+        candle_losses_detail = {}  # Track per-timeframe losses (normalized space)
+        candle_losses_denorm = {}  # Track per-timeframe losses (denormalized/real space)
+
+        if 'next_candles' in outputs:
+            timeframe_losses = []
+
+            # Get normalization parameters if available
+            norm_params = batch.get('norm_params', {})
+
+            # Calculate loss for each timeframe that has target data
+            for tf in ['1s', '1m', '1h', '1d']:
+                future_key = f'future_candle_{tf}'
+
+                if tf in outputs['next_candles'] and future_key in batch:
+                    pred_candle = outputs['next_candles'][tf]  # [batch, 5] - predicted OHLCV (normalized)
+                    target_candle = batch[future_key]  # [batch, 5] - actual OHLCV (normalized)
+
+                    if target_candle is not None and pred_candle.shape == target_candle.shape:
+                        # MSE loss on normalized values (used for backprop)
+                        tf_loss = self.price_criterion(pred_candle, target_candle)
+                        timeframe_losses.append(tf_loss)
+                        candle_losses_detail[tf] = tf_loss.item()
+
+                        # ALSO calculate denormalized loss for better interpretability
+                        if tf in norm_params:
+                            with torch.no_grad():
+                                pred_denorm = self.denormalize_candle(pred_candle, norm_params[tf])
+                                target_denorm = self.denormalize_candle(target_candle, norm_params[tf])
+                                denorm_loss = self.price_criterion(pred_denorm, target_denorm)
+                                candle_losses_denorm[tf] = denorm_loss.item()
+
+            # Average loss across available timeframes
+            if timeframe_losses:
+                candle_loss = torch.stack(timeframe_losses).mean()
+                if candle_losses_denorm:
+                    logger.debug(f"Candle losses (normalized): {candle_losses_detail}")
+                    logger.debug(f"Candle losses (real prices): {candle_losses_denorm}")

         # Start with base losses - avoid inplace operations on computation graph
-        total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss  # Weight auxiliary tasks
+        # Weight: action=1.0, price=0.1, trend=0.05, candle=0.15
+        total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss + 0.15 * candle_loss

         # CRITICAL FIX: Scale loss for gradient accumulation
         # This prevents gradient explosion when accumulating over multiple batches
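One caveat: if `price_criterion` is the MSE described in the accompanying doc, the values stored in `candle_losses_denorm` are squared dollar errors even though the training log formats them with a `$` prefix. A hedged sketch of RMSE-based formatting, should dollar-scale readings be wanted (illustrative values, not part of this commit):

```python
import math

# candle_losses_denorm as produced above would hold MSE values in squared dollars;
# taking a square root yields an RMSE, i.e. a typical per-candle dollar error.
candle_losses_denorm = {'1m': 42.3, '1h': 156.7}  # illustrative values
rmse_str = ", ".join(f"{tf}=${math.sqrt(mse):.2f}" for tf, mse in candle_losses_denorm.items())
print(f"Real Price Error (RMSE): {rmse_str}")
```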
@@ -1322,7 +1419,9 @@ class TradingTransformerTrainer:
             'total_loss': total_loss.item(),
             'action_loss': action_loss.item(),
             'price_loss': price_loss.item(),
-            'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,  # NEW
+            'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,
+            'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0,
+            'candle_loss_denorm': candle_losses_denorm,  # Dict of denormalized losses per timeframe
             'accuracy': accuracy.item(),
             'candle_accuracy': candle_accuracy,
             'learning_rate': self.scheduler.get_last_lr()[0]
@@ -1330,7 +1429,7 @@ class TradingTransformerTrainer:

         # CRITICAL: Delete large tensors to free memory immediately
         # This prevents memory accumulation across batches
-        del outputs, total_loss, action_loss, price_loss, predictions, accuracy
+        del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, predictions, accuracy
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
docs/main/_MODEL_INPUT_OUTPUT_STRUCTURE.md (new file, 186 lines)
@@ -0,0 +1,186 @@
# Transformer Model Input/Output Structure

## FIXED ISSUE: Batch Data Deletion Bug

**Problem**: Training was failing after epoch 1 with "At least one timeframe must be provided"

**Root Cause**: Batch tensors were being deleted after each use in the training loop, but the same batch dictionaries were being reused across all epochs.

**Solution**: Removed batch deletion from inside the epoch loop and moved cleanup to after all epochs complete.
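A minimal sketch of the corrected lifecycle, with hypothetical stand-ins for the adapter's real objects (`convert_annotation`, `Trainer`, `samples` are placeholders): batches are built once, reused read-only across epochs, and only freed after the last epoch.

```python
# Sketch of the fixed lifecycle (hypothetical names, not the adapter's exact code)
def convert_annotation(sample):      # placeholder for _convert_annotation_to_transformer_batch
    return {'price_data_1m': object()}

class Trainer:                       # placeholder for TradingTransformerTrainer
    def train_step(self, batch):
        pass

samples, trainer, num_epochs = [1, 2, 3], Trainer(), 2

batches = [convert_annotation(s) for s in samples]  # built ONCE, before any epoch

for epoch in range(num_epochs):
    for batch in batches:            # reused read-only every epoch
        trainer.train_step(batch)    # no per-batch deletion here (the old bug)

for batch in batches:                # cleanup only after ALL epochs complete
    batch.clear()
batches.clear()
```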
## Current Model Architecture

### INPUT Structure (Multi-Timeframe)

The model accepts the following inputs in the `forward()` method:

```python
forward(
    # Price data for different timeframes - [batch, seq_len, 5] OHLCV
    price_data_1s=None,   # 1-second timeframe
    price_data_1m=None,   # 1-minute timeframe
    price_data_1h=None,   # 1-hour timeframe
    price_data_1d=None,   # 1-day timeframe

    # Reference data
    btc_data_1m=None,     # BTC reference - [batch, seq_len, 5]

    # Additional features
    cob_data=None,        # COB orderbook data - [batch, seq_len, 100]
    tech_data=None,       # Technical indicators - [batch, 40]
    market_data=None,     # Market context (pivots, volume) - [batch, 30]
    position_state=None,  # Current position state - [batch, 5]

    # Legacy support
    price_data=None       # Fallback to single timeframe
)
```

**At least one timeframe** (price_data_1s, 1m, 1h, or 1d) must be provided, otherwise the model raises:
```
ValueError: At least one timeframe must be provided
```
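A minimal valid invocation, sketched under the assumption that the classes are importable as in `test_normalization_fix.py` below:

```python
import torch
from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig

model = AdvancedTradingTransformer(TradingTransformerConfig(d_model=128, n_heads=4, n_layers=2, seq_len=200))
price_data_1m = torch.rand(1, 200, 5)         # any single timeframe satisfies the check
outputs = model(price_data_1m=price_data_1m)

# Calling model() with no timeframe at all raises:
# ValueError: At least one timeframe must be provided
```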
### OUTPUT Structure

The model returns a dictionary with the following predictions:

```python
outputs = {
    # PRIMARY OUTPUTS (trained with loss):
    'action_logits': tensor,      # [batch, 3] - BUY/SELL/HOLD logits
    'action_probs': tensor,       # [batch, 3] - softmax probabilities
    'price_prediction': tensor,   # [batch, 1] - next price change ratio
    'confidence': tensor,         # [batch, 1] - prediction confidence

    # TREND ANALYSIS (trained with loss):
    'trend_analysis': {
        'angle_radians': tensor,  # [batch, 1] - trend angle in radians
        'steepness': tensor,      # [batch, 1] - trend steepness (0-1)
        'direction': tensor       # [batch, 1] - direction (-1/0/+1)
    },

    # NEXT CANDLE PREDICTIONS (evaluated but NOT trained):
    'next_candles': {
        '1s': tensor,  # [batch, 5] - predicted OHLCV for 1s
        '1m': tensor,  # [batch, 5] - predicted OHLCV for 1m
        '1h': tensor,  # [batch, 5] - predicted OHLCV for 1h
        '1d': tensor,  # [batch, 5] - predicted OHLCV for 1d
    },
    'btc_next_candle': tensor,  # [batch, 5] - predicted BTC OHLCV

    # PIVOT PREDICTIONS:
    'next_pivots': {
        'L1': {
            'price': tensor,           # [batch, 1] - pivot price
            'type_prob_high': tensor,  # [batch, 1] - probability of high
            'type_prob_low': tensor,   # [batch, 1] - probability of low
            'pivot_type': tensor,      # [batch, 1] - 0=high, 1=low
            'confidence': tensor       # [batch, 1] - confidence
        },
        # Same structure for L2, L3, L4, L5
    },

    # AUXILIARY OUTPUTS:
    'volatility_prediction': tensor,      # [batch, 1]
    'trend_strength_prediction': tensor,  # [batch, 1]
    'uncertainty_mean': tensor,           # [batch, 1]
    'uncertainty_std': tensor             # [batch, 1]
}
```

### TRAINING TARGETS (in batch)

```python
batch = {
    # Input features (see INPUT Structure above)
    'price_data_1s': tensor,
    'price_data_1m': tensor,
    'price_data_1h': tensor,
    'price_data_1d': tensor,
    'btc_data_1m': tensor,
    'cob_data': tensor,
    'tech_data': tensor,
    'market_data': tensor,
    'position_state': tensor,

    # Training targets:
    'actions': tensor,        # [batch] - target action (0/1/2)
    'future_prices': tensor,  # [batch, 1] - actual price change ratio
    'trade_success': tensor,  # [batch, 1] - 1.0 if profitable
    'trend_target': tensor,   # [batch, 3] - [angle, steepness, direction]
}
```

### LOSS CALCULATION

Current loss function in `train_step()`:

```python
total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss

where:
- action_loss: CrossEntropyLoss(action_logits, actions)
- price_loss: MSELoss(price_prediction, future_prices)
- trend_loss: MSELoss(trend_pred, trend_target)
```

**NOTE**: Next candle predictions are currently only used for accuracy evaluation, NOT trained directly.

## CURRENT ISSUES AND RECOMMENDATIONS

### Issue 1: Next Candle Predictions Not Trained

**Status**: The model outputs next candle predictions for each timeframe, but these are NOT included in the loss function.

**Impact**: The model is not explicitly learning to predict next candle OHLCV values.

**Recommendation**: Add next candle loss to training:
```python
# Calculate next candle loss for each available timeframe
candle_loss = 0.0
if 'next_candles' in outputs:
    for tf in ['1s', '1m', '1h', '1d']:
        if tf in outputs['next_candles'] and f'future_candle_{tf}' in batch:
            pred_candle = outputs['next_candles'][tf]     # [batch, 5]
            target_candle = batch[f'future_candle_{tf}']  # [batch, 5]
            candle_loss += MSELoss(pred_candle, target_candle)

total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss + 0.1 * candle_loss
```

### Issue 2: Annotation Timeframe vs Prediction Timeframe

**Current Behavior**:
- Annotations are created at a specific point in time
- The model receives multiple timeframes (1s, 1m, 1h, 1d) as input
- Predictions are made for ALL timeframes simultaneously
- Only the 1m timeframe prediction is currently evaluated for accuracy

**Question**: Should predictions be specific to the annotation's timeframe?

**Options**:
1. **Multi-timeframe predictions (current)**: Keep predicting all timeframes, add loss for each
2. **Annotation-specific predictions**: Only predict/train on the timeframe that matches the annotation
3. **Weighted predictions**: Weight the loss by the annotation's timeframe (e.g., if annotated on 1m, weight 1m prediction higher)

### Issue 3: Missing Target Data for Next Candles

**Current**: The batch only contains `future_prices` (next close price change)

**Needed**: To train next candle predictions, we need full OHLCV targets:
- `future_candle_1s`: [batch, 5] - next 1s candle OHLCV
- `future_candle_1m`: [batch, 5] - next 1m candle OHLCV
- `future_candle_1h`: [batch, 5] - next 1h candle OHLCV
- `future_candle_1d`: [batch, 5] - next 1d candle OHLCV

**Location to add**: `ANNOTATE/core/real_training_adapter.py` in `_convert_annotation_to_transformer_batch()`

## SUMMARY

✅ **Fixed**: Batch deletion bug causing epoch 2+ failures
✅ **Working**: Model can predict next candles for all timeframes
✅ **Working**: Model can predict trend vector (angle, steepness, direction)
❌ **Missing**: Loss calculation for next candle predictions
❌ **Missing**: Target data (future OHLCV) for next candle training
⚠️ **Unclear**: Should predictions be timeframe-specific or multi-timeframe?

## NEXT STEPS

1. **Add future OHLCV target data** to training batches
2. **Add next candle loss** to the training loop
3. **Clarify prediction strategy**: Single timeframe vs multi-timeframe
4. **Test training** with enhanced loss function
test_normalization_fix.py (new file, 288 lines)
@@ -0,0 +1,288 @@
#!/usr/bin/env python3
"""
Test script to verify normalization fix is working correctly

This creates a simple test batch and verifies:
1. Model outputs are in expected ranges (thanks to Sigmoid/Tanh constraints)
2. Normalization parameters are stored and can be retrieved
3. Denormalization works correctly
4. Losses are in reasonable ranges (not billions!)
"""

import torch
import numpy as np
from NN.models.advanced_transformer_trading import (
    AdvancedTradingTransformer,
    TradingTransformerConfig,
    TradingTransformerTrainer
)

def create_test_batch():
    """Create a simple test batch with known normalization parameters"""
    batch_size = 1
    seq_len = 200

    # Create synthetic price data in [0, 1] range (normalized)
    price_data_1m = torch.rand(batch_size, seq_len, 5) * 0.2 + 0.4  # Range [0.4, 0.6]

    # Create normalization parameters (simulate ETHUSDT around $2500)
    norm_params = {
        '1m': {
            'price_min': 2480.0,
            'price_max': 2520.0,
            'volume_min': 100.0,
            'volume_max': 10000.0
        }
    }

    # Create future candle target (slightly higher close)
    future_candle_1m = torch.rand(batch_size, 5) * 0.2 + 0.45  # Range [0.45, 0.65]

    # Create other required inputs
    cob_data = torch.zeros(batch_size, seq_len, 100)
    tech_data = torch.zeros(batch_size, 40)
    market_data = torch.zeros(batch_size, 30)
    position_state = torch.zeros(batch_size, 5)
    actions = torch.tensor([1], dtype=torch.long)  # BUY action
    future_prices = torch.tensor([[0.01]], dtype=torch.float32)  # 1% price increase expected
    trade_success = torch.tensor([[1.0]], dtype=torch.float32)
    trend_target = torch.tensor([[0.785, 0.5, 1.0]], dtype=torch.float32)

    batch = {
        'price_data_1m': price_data_1m,
        'price_data_1s': None,
        'price_data_1h': None,
        'price_data_1d': None,
        'btc_data_1m': None,
        'cob_data': cob_data,
        'tech_data': tech_data,
        'market_data': market_data,
        'position_state': position_state,
        'actions': actions,
        'future_prices': future_prices,
        'trade_success': trade_success,
        'trend_target': trend_target,
        'future_candle_1m': future_candle_1m,
        'future_candle_1s': None,
        'future_candle_1h': None,
        'future_candle_1d': None,
        'norm_params': norm_params
    }

    return batch

def test_model_outputs():
    """Test that model outputs are in expected ranges"""
    print("=" * 80)
    print("TESTING: Model Output Constraints")
    print("=" * 80)

    # Create small model for testing
    config = TradingTransformerConfig(
        d_model=128,
        n_heads=4,
        n_layers=2,
        seq_len=200
    )

    model = AdvancedTradingTransformer(config)
    model.eval()

    batch = create_test_batch()

    with torch.no_grad():
        outputs = model(
            price_data_1m=batch['price_data_1m'],
            cob_data=batch['cob_data'],
            tech_data=batch['tech_data'],
            market_data=batch['market_data'],
            position_state=batch['position_state']
        )

    # Check candle predictions are in [0, 1] range (thanks to Sigmoid)
    if 'next_candles' in outputs and '1m' in outputs['next_candles']:
        candle_pred = outputs['next_candles']['1m']
        print(f"\nCandle Prediction (1m):")
        print(f"  Shape: {candle_pred.shape}")
        print(f"  Min value: {candle_pred.min().item():.6f}")
        print(f"  Max value: {candle_pred.max().item():.6f}")
        print(f"  Mean value: {candle_pred.mean().item():.6f}")

        if candle_pred.min() >= 0.0 and candle_pred.max() <= 1.0:
            print("  ✅ PASS: Values in [0, 1] range (Sigmoid working!)")
        else:
            print("  ❌ FAIL: Values outside [0, 1] range!")

    # Check price prediction is in [-1, 1] range (thanks to Tanh)
    if 'price_prediction' in outputs:
        price_pred = outputs['price_prediction']
        print(f"\nPrice Prediction (change ratio):")
        print(f"  Shape: {price_pred.shape}")
        print(f"  Value: {price_pred.item():.6f}")

        if price_pred.min() >= -1.0 and price_pred.max() <= 1.0:
            print("  ✅ PASS: Values in [-1, 1] range (Tanh working!)")
        else:
            print("  ❌ FAIL: Values outside [-1, 1] range!")

    # Check action probabilities sum to 1
    if 'action_probs' in outputs:
        action_probs = outputs['action_probs']
        print(f"\nAction Probabilities:")
        print(f"  BUY: {action_probs[0, 0].item():.4f}")
        print(f"  SELL: {action_probs[0, 1].item():.4f}")
        print(f"  HOLD: {action_probs[0, 2].item():.4f}")
        print(f"  Sum: {action_probs[0].sum().item():.6f}")

        if abs(action_probs[0].sum().item() - 1.0) < 0.001:
            print("  ✅ PASS: Probabilities sum to 1.0")
        else:
            print("  ❌ FAIL: Probabilities don't sum to 1.0!")

    return outputs

def test_denormalization():
    """Test denormalization functions"""
    print("\n" + "=" * 80)
    print("TESTING: Denormalization Functions")
    print("=" * 80)

    # Create test normalized candle
    normalized_candle = torch.tensor([[0.5, 0.6, 0.4, 0.55, 0.3]])  # OHLCV

    # Normalization params (ETHUSDT $2480-$2520)
    norm_params = {
        'price_min': 2480.0,
        'price_max': 2520.0,
        'volume_min': 100.0,
        'volume_max': 10000.0
    }

    print(f"\nNormalized Candle: {normalized_candle[0].tolist()}")
    print(f"Normalization Params: price [{norm_params['price_min']}, {norm_params['price_max']}], "
          f"volume [{norm_params['volume_min']}, {norm_params['volume_max']}]")

    # Denormalize
    denorm_candle = TradingTransformerTrainer.denormalize_candle(normalized_candle, norm_params)

    print(f"\nDenormalized Candle:")
    print(f"  Open: ${denorm_candle[0, 0].item():.2f}")
    print(f"  High: ${denorm_candle[0, 1].item():.2f}")
    print(f"  Low: ${denorm_candle[0, 2].item():.2f}")
    print(f"  Close: ${denorm_candle[0, 3].item():.2f}")
    print(f"  Volume: {denorm_candle[0, 4].item():.2f}")

    # Verify values are in expected range
    expected_min_price = norm_params['price_min']
    expected_max_price = norm_params['price_max']

    prices_ok = True
    for i, name in enumerate(['Open', 'High', 'Low', 'Close']):
        value = denorm_candle[0, i].item()
        if value < expected_min_price or value > expected_max_price:
            print(f"  ❌ FAIL: {name} price ${value:.2f} outside expected range!")
            prices_ok = False

    if prices_ok:
        print(f"  ✅ PASS: All prices in expected range [${expected_min_price}, ${expected_max_price}]")

    # Verify volume
    volume = denorm_candle[0, 4].item()
    if norm_params['volume_min'] <= volume <= norm_params['volume_max']:
        print(f"  ✅ PASS: Volume {volume:.2f} in expected range [{norm_params['volume_min']}, {norm_params['volume_max']}]")
    else:
        print(f"  ❌ FAIL: Volume {volume:.2f} outside expected range!")

def test_loss_magnitude():
    """Test that losses are in reasonable ranges"""
    print("\n" + "=" * 80)
    print("TESTING: Loss Magnitudes")
    print("=" * 80)

    config = TradingTransformerConfig(
        d_model=128,
        n_heads=4,
        n_layers=2,
        seq_len=200
    )

    model = AdvancedTradingTransformer(config)
    trainer = TradingTransformerTrainer(model, config)

    batch = create_test_batch()

    # Run one training step
    result = trainer.train_step(batch, accumulate_gradients=False)

    print(f"\nTraining Step Results:")
    print(f"  Total Loss: {result['total_loss']:.6f}")
    print(f"  Action Loss: {result['action_loss']:.6f}")
    print(f"  Price Loss: {result['price_loss']:.6f}")
    print(f"  Trend Loss: {result['trend_loss']:.6f}")
    print(f"  Candle Loss: {result['candle_loss']:.6f}")
    print(f"  Action Accuracy: {result['accuracy']:.2%}")
    print(f"  Candle Accuracy: {result['candle_accuracy']:.2%}")

    # Check losses are reasonable (not billions!)
    all_ok = True

    if result['total_loss'] < 100.0:
        print(f"  ✅ PASS: Total loss < 100 (was {result['total_loss']:.6f})")
    else:
        print(f"  ❌ FAIL: Total loss too high! ({result['total_loss']:.6f})")
        all_ok = False

    if result['candle_loss'] < 10.0:
        print(f"  ✅ PASS: Candle loss < 10 (was {result['candle_loss']:.6f})")
    else:
        print(f"  ❌ FAIL: Candle loss too high! ({result['candle_loss']:.6f})")
        all_ok = False

    # Check denormalized losses if available
    if 'candle_loss_denorm' in result and result['candle_loss_denorm']:
        print(f"\n  Denormalized Candle Losses (Real Price Errors):")
        for tf, loss in result['candle_loss_denorm'].items():
            print(f"    {tf}: ${loss:.2f}")
            if loss < 1000.0:
                print(f"    ✅ PASS: Real price error < $1000")
            else:
                print(f"    ❌ FAIL: Real price error too high!")
                all_ok = False

    if all_ok:
        print("\n  ✅ ALL TESTS PASSED: Losses in reasonable ranges!")
    else:
        print("\n  ❌ SOME TESTS FAILED: Check model/normalization!")

    return result

def main():
    print("\n" + "=" * 80)
    print("NORMALIZATION FIX VERIFICATION TEST")
    print("=" * 80)
    print("\nThis test verifies that:")
    print("1. Model outputs are properly constrained (Sigmoid/Tanh)")
    print("2. Normalization parameters are stored and accessible")
    print("3. Denormalization functions work correctly")
    print("4. Losses are in reasonable ranges (not billions!)")
    print("\n" + "=" * 80)

    # Run tests
    outputs = test_model_outputs()
    test_denormalization()
    test_loss_magnitude()

    print("\n" + "=" * 80)
    print("TEST SUMMARY")
    print("=" * 80)
    print("\nIf all tests passed (✅), the normalization fix is working correctly!")
    print("You should now see reasonable losses in training logs:")
    print("  - Total loss: ~0.5-1.0 (not billions!)")
    print("  - Candle loss: ~0.1-0.3")
    print("  - Real price errors: $2-20 (not $147,000!)")
    print("\nYou can now resume training and monitor these metrics.")
    print("=" * 80 + "\n")

if __name__ == "__main__":
    main()