Training normalization fix

This commit is contained in:
Dobromir Popov
2025-11-12 14:36:28 +02:00
parent 4c04503f3e
commit a7a22334fb
5 changed files with 800 additions and 50 deletions

View File

@@ -1197,12 +1197,90 @@ class RealTrainingAdapter:
ohlcv[:, 4] = (ohlcv[:, 4] - volume_min) / (volume_max - volume_min)

# Convert to tensor and add batch dimension [1, seq_len, 5]
# STORE normalization parameters for denormalization
# Return tuple: (tensor, normalization_params)
norm_params = {
'price_min': float(price_min),
'price_max': float(price_max),
'volume_min': float(volume_min),
'volume_max': float(volume_max)
}
return torch.tensor(ohlcv, dtype=torch.float32).unsqueeze(0), norm_params
except Exception as e:
    logger.error(f"Error extracting timeframe data: {e}")
    return None
def _extract_next_candle(self, tf_data: Dict, reference_data = None) -> Optional[torch.Tensor]:
"""
Extract the NEXT candle OHLCV (after the current sequence) as training target
This extracts the candle that comes immediately after the sequence used for input.
Normalized using the same price range as the reference data for consistency.
Args:
tf_data: Timeframe data dictionary with 'open', 'high', 'low', 'close', 'volume'
reference_data: Optional reference OHLCV data [seq_len, 5] to get normalization params
Returns:
Tensor of shape [1, 5] representing next candle OHLCV, or None if not available
"""
import torch
import numpy as np
try:
# Extract OHLCV arrays - get the LAST value as the "next" candle
# In annotation context, the "current" sequence is historical, and we have the "next" candle
opens = np.array(tf_data.get('open', []), dtype=np.float32)
highs = np.array(tf_data.get('high', []), dtype=np.float32)
lows = np.array(tf_data.get('low', []), dtype=np.float32)
closes = np.array(tf_data.get('close', []), dtype=np.float32)
volumes = np.array(tf_data.get('volume', []), dtype=np.float32)
if len(closes) == 0:
return None
# Get the last candle as the "next" candle target
# This assumes the timeframe data includes one extra candle after the sequence
next_open = opens[-1]
next_high = highs[-1]
next_low = lows[-1]
next_close = closes[-1]
next_volume = volumes[-1]
# Create OHLCV array [5]
next_candle = np.array([next_open, next_high, next_low, next_close, next_volume], dtype=np.float32)
# Normalize using reference data if provided
if reference_data is not None:
# Use same normalization as input sequence
price_min = np.min(reference_data[:, :4])
price_max = np.max(reference_data[:, :4])
if price_max > price_min:
next_candle[:4] = (next_candle[:4] - price_min) / (price_max - price_min)
volume_min = np.min(reference_data[:, 4])
volume_max = np.max(reference_data[:, 4])
if volume_max > volume_min:
next_candle[4] = (next_candle[4] - volume_min) / (volume_max - volume_min)
else:
# If no reference, normalize relative to current candle's close
if next_close > 0:
next_candle[:4] = next_candle[:4] / next_close
# Volume normalized to 0-1 range (simple min-max with self)
if next_volume > 0:
next_candle[4] = 1.0
# Return as [1, 5] tensor
return torch.tensor(next_candle, dtype=torch.float32).unsqueeze(0)
except Exception as e:
logger.error(f"Error extracting next candle: {e}")
return None
def _convert_annotation_to_transformer_batch(self, training_sample: Dict) -> Dict[str, 'torch.Tensor']:
    """
    Convert annotation training sample to multi-timeframe transformer input
@@ -1237,16 +1315,40 @@ class RealTrainingAdapter:
target_seq_len = min(len(tf_data['close']), 200)  # Cap at 200
break

# Extract each timeframe (returns tuple: (tensor, norm_params) or None)
# Store normalization parameters for each timeframe
norm_params_dict = {}

result_1s = self._extract_timeframe_data(timeframes.get('1s', {}), target_seq_len) if '1s' in timeframes else None
if result_1s:
price_data_1s, norm_params_dict['1s'] = result_1s
else:
price_data_1s = None
result_1m = self._extract_timeframe_data(timeframes.get('1m', {}), target_seq_len) if '1m' in timeframes else None
if result_1m:
price_data_1m, norm_params_dict['1m'] = result_1m
else:
price_data_1m = None
result_1h = self._extract_timeframe_data(timeframes.get('1h', {}), target_seq_len) if '1h' in timeframes else None
if result_1h:
price_data_1h, norm_params_dict['1h'] = result_1h
else:
price_data_1h = None
result_1d = self._extract_timeframe_data(timeframes.get('1d', {}), target_seq_len) if '1d' in timeframes else None
if result_1d:
price_data_1d, norm_params_dict['1d'] = result_1d
else:
price_data_1d = None
# Extract BTC reference data
btc_data_1m = None
if 'BTC/USDT' in secondary_timeframes and '1m' in secondary_timeframes['BTC/USDT']:
    result_btc = self._extract_timeframe_data(secondary_timeframes['BTC/USDT']['1m'], target_seq_len)
if result_btc:
btc_data_1m, norm_params_dict['btc'] = result_btc
# Ensure at least one timeframe is available
# Check if all are None (can't use any() with tensors)
@@ -1488,9 +1590,34 @@ class RealTrainingAdapter:
# Create trend target tensor [batch, 3]: [angle, steepness, direction]
trend_target = torch.tensor([[trend_angle, trend_steepness, trend_direction]], dtype=torch.float32)  # [1, 3]
# Extract NEXT candle OHLCV targets for each available timeframe
# These are the ground truth candles that the model should learn to predict
future_candle_1s = None
future_candle_1m = None
future_candle_1h = None
future_candle_1d = None
# For each timeframe, extract the next candle if data is available
# Use the input sequence data as reference for normalization
if price_data_1s is not None and '1s' in timeframes:
ref_data_1s = price_data_1s.squeeze(0).numpy() # [seq_len, 5]
future_candle_1s = self._extract_next_candle(timeframes['1s'], ref_data_1s)
if price_data_1m is not None and '1m' in timeframes:
ref_data_1m = price_data_1m.squeeze(0).numpy() # [seq_len, 5]
future_candle_1m = self._extract_next_candle(timeframes['1m'], ref_data_1m)
if price_data_1h is not None and '1h' in timeframes:
ref_data_1h = price_data_1h.squeeze(0).numpy() # [seq_len, 5]
future_candle_1h = self._extract_next_candle(timeframes['1h'], ref_data_1h)
if price_data_1d is not None and '1d' in timeframes:
ref_data_1d = price_data_1d.squeeze(0).numpy() # [seq_len, 5]
future_candle_1d = self._extract_next_candle(timeframes['1d'], ref_data_1d)
# Return batch dictionary with ALL timeframes
batch = {
# Multi-timeframe price data (INPUT)
'price_data_1s': price_data_1s,  # [1, 600, 5] or None
'price_data_1m': price_data_1m,  # [1, 600, 5] or None
'price_data_1h': price_data_1h,  # [1, 600, 5] or None
@@ -1503,11 +1630,20 @@ class RealTrainingAdapter:
'market_data': market_data,  # [1, 30]
'position_state': position_state,  # [1, 5]

# Training targets - Actions and prices
'actions': actions,  # [1]
'future_prices': future_prices,  # [1, 1]
'trade_success': trade_success,  # [1, 1]
'trend_target': trend_target,  # [1, 3] - [angle, steepness, direction]
# Training targets - Next candle OHLCV for each timeframe
'future_candle_1s': future_candle_1s, # [1, 5] or None
'future_candle_1m': future_candle_1m, # [1, 5] or None
'future_candle_1h': future_candle_1h, # [1, 5] or None
'future_candle_1d': future_candle_1d, # [1, 5] or None
# CRITICAL: Normalization parameters for denormalization
'norm_params': norm_params_dict, # Dict with keys: '1s', '1m', '1h', '1d', 'btc'
# Legacy support (use 1m as default)
'price_data': price_data_1m if price_data_1m is not None else ref_data
@@ -1713,25 +1849,27 @@ class RealTrainingAdapter:
batch_accuracy = result.get('accuracy', 0.0)
batch_candle_accuracy = result.get('candle_accuracy', 0.0)
batch_trend_loss = result.get('trend_loss', 0.0)
batch_candle_loss = result.get('candle_loss', 0.0)
batch_candle_loss_denorm = result.get('candle_loss_denorm', {})
epoch_loss += batch_loss
epoch_accuracy += batch_accuracy
num_batches += 1

# Log first batch and every 5th batch for debugging
if (i + 1) == 1 or (i + 1) % 5 == 0:
    # Format denormalized losses if available
denorm_str = ""
if batch_candle_loss_denorm:
denorm_values = [f"{tf}=${loss:.2f}" for tf, loss in batch_candle_loss_denorm.items()]
denorm_str = f", Real Price Error: {', '.join(denorm_values)}"
logger.info(f" Batch {i + 1}/{len(grouped_batches)}, Loss: {batch_loss:.6f}, Action Acc: {batch_accuracy:.2%}, Candle Acc: {batch_candle_accuracy:.2%}, Trend Loss: {batch_trend_loss:.6f}, Candle Loss (norm): {batch_candle_loss:.6f}{denorm_str}")
else:
    logger.warning(f"  Batch {i + 1} returned None result - skipping")
-# CRITICAL FIX: Delete batch tensors immediately to free GPU memory
-# This prevents memory accumulation during gradient accumulation
-for key in list(batch.keys()):
-    if isinstance(batch[key], torch.Tensor):
-        del batch[key]
-del batch

# CRITICAL: Clear CUDA cache after EVERY batch to prevent memory accumulation
# NOTE: We do NOT delete batch tensors here because they are reused across epochs
# Deleting them would cause "At least one timeframe must be provided" error on epoch 2+
if torch.cuda.is_available():
    torch.cuda.empty_cache()
@@ -1801,6 +1939,22 @@ class RealTrainingAdapter:
session.final_loss = session.current_loss
session.accuracy = avg_accuracy
# Cleanup: Delete batch tensors after all epochs are complete
logger.info(" Cleaning up batch data...")
for batch in grouped_batches:
for key in list(batch.keys()):
if isinstance(batch[key], torch.Tensor):
del batch[key]
batch.clear()
grouped_batches.clear()
converted_batches.clear()
# Final memory cleanup
if torch.cuda.is_available():
torch.cuda.empty_cache()
import gc
gc.collect()
# Log best checkpoint info
try:
    checkpoint_dir = "models/checkpoints/transformer"

View File

@@ -69,29 +69,6 @@
"exit_state": {} "exit_state": {}
} }
}, },
-{
-"annotation_id": "c9849a6b-e430-4305-9009-dd7471553c2f",
-"symbol": "ETH/USDT",
-"timeframe": "1m",
-"entry": {
-"timestamp": "2025-10-30 19:59",
-"price": 3680.1,
-"index": 272
-},
-"exit": {
-"timestamp": "2025-10-30 21:59",
-"price": 3767.82,
-"index": 312
-},
-"direction": "LONG",
-"profit_loss_pct": 2.38363087959567,
-"notes": "",
-"created_at": "2025-10-31T00:31:10.201966",
-"market_context": {
-"entry_state": {},
-"exit_state": {}
-}
-},
{
"annotation_id": "479eb310-c963-4837-b712-70e5a42afb53",
"symbol": "ETH/USDT",
@@ -114,10 +91,56 @@
"entry_state": {}, "entry_state": {},
"exit_state": {} "exit_state": {}
} }
},
{
"annotation_id": "6b529132-8a3e-488d-b354-db8785ddaa71",
"symbol": "ETH/USDT",
"timeframe": "1m",
"entry": {
"timestamp": "2025-11-11 12:07",
"price": 3594.33,
"index": 144
},
"exit": {
"timestamp": "2025-11-11 20:46",
"price": 3429.24,
"index": 329
},
"direction": "SHORT",
"profit_loss_pct": 4.593067414511193,
"notes": "",
"created_at": "2025-11-11T23:23:00.643510",
"market_context": {
"entry_state": {},
"exit_state": {}
}
},
{
"annotation_id": "bbafc50c-f885-4dbc-b0cb-fdfb48223b5c",
"symbol": "ETH/USDT",
"timeframe": "1m",
"entry": {
"timestamp": "2025-11-12 07:58",
"price": 3424.58,
"index": 284
},
"exit": {
"timestamp": "2025-11-12 11:08",
"price": 3546.35,
"index": 329
},
"direction": "LONG",
"profit_loss_pct": 3.5557645025083366,
"notes": "",
"created_at": "2025-11-12T13:11:31.267142",
"market_context": {
"entry_state": {},
"exit_state": {}
}
}
],
"metadata": {
"total_annotations": 6,
"last_updated": "2025-11-12T13:11:31.267456"
}
}

View File

@@ -443,6 +443,8 @@ class AdvancedTradingTransformer(nn.Module):
self.uncertainty_estimator = UncertaintyEstimation(config.d_model)

# Enhanced price prediction head (auxiliary task)
# Predicts price change ratio (future_price - current_price) / current_price
# Use Tanh to constrain to [-1, 1] range (max 100% change up/down)
self.price_head = nn.Sequential(
    nn.Linear(config.d_model, config.d_model // 2),
    nn.GELU(),
@@ -450,7 +452,8 @@ class AdvancedTradingTransformer(nn.Module):
    nn.Linear(config.d_model // 2, config.d_model // 4),
    nn.GELU(),
    nn.Dropout(config.dropout),
    nn.Linear(config.d_model // 4, 1),
    nn.Tanh()  # Constrain to [-1, 1] range for price change ratio
)
# Additional specialized heads for 46M model
@@ -473,6 +476,7 @@ class AdvancedTradingTransformer(nn.Module):
# NEW: Next candle OHLCV prediction heads for each timeframe (1s, 1m, 1h, 1d)
# Each timeframe predicts: [open, high, low, close, volume] = 5 values
# Note: self.timeframes already defined above in input projections
# CRITICAL: Outputs are constrained to [0, 1] range using Sigmoid since inputs are normalized
self.next_candle_heads = nn.ModuleDict({
    tf: nn.Sequential(
        nn.Linear(config.d_model, config.d_model // 2),
@@ -481,11 +485,13 @@ class AdvancedTradingTransformer(nn.Module):
        nn.Linear(config.d_model // 2, config.d_model // 4),
        nn.GELU(),
        nn.Dropout(config.dropout),
        nn.Linear(config.d_model // 4, 5),  # OHLCV: [open, high, low, close, volume]
        nn.Sigmoid()  # Constrain to [0, 1] to match normalized input range
    ) for tf in self.timeframes
})
# BTC next candle prediction head
# CRITICAL: Outputs are constrained to [0, 1] range using Sigmoid since inputs are normalized
self.btc_next_candle_head = nn.Sequential(
    nn.Linear(config.d_model, config.d_model // 2),
    nn.GELU(),
@@ -493,7 +499,8 @@ class AdvancedTradingTransformer(nn.Module):
    nn.Linear(config.d_model // 2, config.d_model // 4),
    nn.GELU(),
    nn.Dropout(config.dropout),
    nn.Linear(config.d_model // 4, 5),  # OHLCV for BTC
    nn.Sigmoid()  # Constrain to [0, 1] to match normalized input range
)

# NEW: Next pivot point prediction heads for L1-L5 levels
@@ -1153,6 +1160,54 @@ class TradingTransformerTrainer:
'learning_rates': []
}
@staticmethod
def denormalize_prices(normalized_values: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
"""
Denormalize price predictions back to real price space
Args:
normalized_values: Tensor of normalized values in [0, 1] range
norm_params: Dict with 'price_min' and 'price_max' keys
Returns:
Denormalized tensor in original price space
"""
price_min = norm_params.get('price_min', 0.0)
price_max = norm_params.get('price_max', 1.0)
if price_max > price_min:
return normalized_values * (price_max - price_min) + price_min
else:
return normalized_values
@staticmethod
def denormalize_candle(normalized_candle: torch.Tensor, norm_params: Dict[str, float]) -> torch.Tensor:
"""
Denormalize a full OHLCV candle back to real values
Args:
normalized_candle: Tensor of shape [..., 5] with normalized OHLCV
norm_params: Dict with normalization parameters
Returns:
Denormalized OHLCV tensor
"""
denorm = normalized_candle.clone()
# Denormalize OHLC (first 4 values)
price_min = norm_params.get('price_min', 0.0)
price_max = norm_params.get('price_max', 1.0)
if price_max > price_min:
denorm[..., :4] = denorm[..., :4] * (price_max - price_min) + price_min
# Denormalize volume (5th value)
volume_min = norm_params.get('volume_min', 0.0)
volume_max = norm_params.get('volume_max', 1.0)
if volume_max > volume_min:
denorm[..., 4] = denorm[..., 4] * (volume_max - volume_min) + volume_min
return denorm
def train_step(self, batch: Dict[str, torch.Tensor], accumulate_gradients: bool = False) -> Dict[str, float]:
    """Single training step with optional gradient accumulation
@@ -1217,8 +1272,50 @@ class TradingTransformerTrainer:
trend_loss = self.price_criterion(trend_pred, trend_target)
logger.debug(f"Trend loss: {trend_loss.item():.6f} (pred={trend_pred[0].tolist()}, target={trend_target[0].tolist()})")
# NEW: Next candle prediction loss for each timeframe
# This trains the model to predict full OHLCV for the next candle on each timeframe
candle_loss = torch.tensor(0.0, device=self.device)
candle_losses_detail = {} # Track per-timeframe losses (normalized space)
candle_losses_denorm = {} # Track per-timeframe losses (denormalized/real space)
if 'next_candles' in outputs:
timeframe_losses = []
# Get normalization parameters if available
norm_params = batch.get('norm_params', {})
# Calculate loss for each timeframe that has target data
for tf in ['1s', '1m', '1h', '1d']:
future_key = f'future_candle_{tf}'
if tf in outputs['next_candles'] and future_key in batch:
pred_candle = outputs['next_candles'][tf] # [batch, 5] - predicted OHLCV (normalized)
target_candle = batch[future_key] # [batch, 5] - actual OHLCV (normalized)
if target_candle is not None and pred_candle.shape == target_candle.shape:
# MSE loss on normalized values (used for backprop)
tf_loss = self.price_criterion(pred_candle, target_candle)
timeframe_losses.append(tf_loss)
candle_losses_detail[tf] = tf_loss.item()
# ALSO calculate denormalized loss for better interpretability
if tf in norm_params:
with torch.no_grad():
pred_denorm = self.denormalize_candle(pred_candle, norm_params[tf])
target_denorm = self.denormalize_candle(target_candle, norm_params[tf])
denorm_loss = self.price_criterion(pred_denorm, target_denorm)
candle_losses_denorm[tf] = denorm_loss.item()
# Average loss across available timeframes
if timeframe_losses:
candle_loss = torch.stack(timeframe_losses).mean()
if candle_losses_denorm:
logger.debug(f"Candle losses (normalized): {candle_losses_detail}")
logger.debug(f"Candle losses (real prices): {candle_losses_denorm}")
# Start with base losses - avoid inplace operations on computation graph
# Weight: action=1.0, price=0.1, trend=0.05, candle=0.15
total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss + 0.15 * candle_loss

# CRITICAL FIX: Scale loss for gradient accumulation
# This prevents gradient explosion when accumulating over multiple batches
@@ -1322,7 +1419,9 @@ class TradingTransformerTrainer:
'total_loss': total_loss.item(),
'action_loss': action_loss.item(),
'price_loss': price_loss.item(),
'trend_loss': trend_loss.item() if isinstance(trend_loss, torch.Tensor) else 0.0,
'candle_loss': candle_loss.item() if isinstance(candle_loss, torch.Tensor) else 0.0,
'candle_loss_denorm': candle_losses_denorm,  # Dict of denormalized losses per timeframe
'accuracy': accuracy.item(),
'candle_accuracy': candle_accuracy,
'learning_rate': self.scheduler.get_last_lr()[0]
@@ -1330,7 +1429,7 @@ class TradingTransformerTrainer:
# CRITICAL: Delete large tensors to free memory immediately
# This prevents memory accumulation across batches
del outputs, total_loss, action_loss, price_loss, trend_loss, candle_loss, predictions, accuracy

if torch.cuda.is_available():
    torch.cuda.empty_cache()

View File

@@ -0,0 +1,186 @@
# Transformer Model Input/Output Structure
## FIXED ISSUE: Batch Data Deletion Bug
**Problem**: Training was failing after epoch 1 with "At least one timeframe must be provided"
**Root Cause**: Batch tensors were being deleted after each use in the training loop, but the same batch dictionaries were being reused across all epochs.
**Solution**: Removed batch deletion from inside the epoch loop and moved cleanup to after all epochs complete.
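A minimal control-flow sketch of the broken and fixed patterns (names like `grouped_batches`, `trainer`, and `num_epochs` follow the adapter code in this commit's diff; the loop bodies are simplified, not the actual implementation):

```python
import torch

# BROKEN (old code): batches are converted once and reused every epoch,
# so deleting their tensors inside the epoch loop leaves empty dicts
# and triggers the ValueError on epoch 2+.
for epoch in range(num_epochs):
    for batch in grouped_batches:
        result = trainer.train_step(batch)
        for key in list(batch.keys()):
            if isinstance(batch[key], torch.Tensor):
                del batch[key]  # <-- input is gone for the next epoch

# FIXED (new code): only clear the CUDA cache per batch; free the
# batch tensors once, after all epochs complete.
for epoch in range(num_epochs):
    for batch in grouped_batches:
        result = trainer.train_step(batch)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
for batch in grouped_batches:
    batch.clear()
```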
## Current Model Architecture
### INPUT Structure (Multi-Timeframe)
The model accepts the following inputs in the `forward()` method:
```python
forward(
# Price data for different timeframes - [batch, seq_len, 5] OHLCV
price_data_1s=None, # 1-second timeframe
price_data_1m=None, # 1-minute timeframe
price_data_1h=None, # 1-hour timeframe
price_data_1d=None, # 1-day timeframe
# Reference data
btc_data_1m=None, # BTC reference - [batch, seq_len, 5]
# Additional features
cob_data=None, # COB orderbook data - [batch, seq_len, 100]
tech_data=None, # Technical indicators - [batch, 40]
market_data=None, # Market context (pivots, volume) - [batch, 30]
position_state=None, # Current position state - [batch, 5]
# Legacy support
price_data=None # Fallback to single timeframe
)
```
**At least one timeframe** (price_data_1s, 1m, 1h, or 1d) must be provided, otherwise the model raises:
```
ValueError: At least one timeframe must be provided
```
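For reference, a minimal call sketch that satisfies this requirement, assuming `model = AdvancedTradingTransformer(config)` and the input shapes documented above (only the 1m timeframe is supplied, which is enough):

```python
import torch

batch_size, seq_len = 1, 200
with torch.no_grad():
    outputs = model(
        price_data_1m=torch.rand(batch_size, seq_len, 5),  # normalized OHLCV
        tech_data=torch.zeros(batch_size, 40),
        market_data=torch.zeros(batch_size, 30),
        position_state=torch.zeros(batch_size, 5),
    )
action_idx = outputs['action_probs'].argmax(dim=-1)  # index into BUY/SELL/HOLD
```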
### OUTPUT Structure
The model returns a dictionary with the following predictions:
```python
outputs = {
# PRIMARY OUTPUTS (trained with loss):
'action_logits': tensor, # [batch, 3] - BUY/SELL/HOLD logits
'action_probs': tensor, # [batch, 3] - softmax probabilities
'price_prediction': tensor, # [batch, 1] - next price change ratio
'confidence': tensor, # [batch, 1] - prediction confidence
# TREND ANALYSIS (trained with loss):
'trend_analysis': {
'angle_radians': tensor, # [batch, 1] - trend angle in radians
'steepness': tensor, # [batch, 1] - trend steepness (0-1)
'direction': tensor # [batch, 1] - direction (-1/0/+1)
},
# NEXT CANDLE PREDICTIONS (evaluated but NOT trained):
'next_candles': {
'1s': tensor, # [batch, 5] - predicted OHLCV for 1s
'1m': tensor, # [batch, 5] - predicted OHLCV for 1m
'1h': tensor, # [batch, 5] - predicted OHLCV for 1h
'1d': tensor, # [batch, 5] - predicted OHLCV for 1d
},
'btc_next_candle': tensor, # [batch, 5] - predicted BTC OHLCV
# PIVOT PREDICTIONS:
'next_pivots': {
'L1': {
'price': tensor, # [batch, 1] - pivot price
'type_prob_high': tensor, # [batch, 1] - probability of high
'type_prob_low': tensor, # [batch, 1] - probability of low
'pivot_type': tensor, # [batch, 1] - 0=high, 1=low
'confidence': tensor # [batch, 1] - confidence
},
# Same structure for L2, L3, L4, L5
},
# AUXILIARY OUTPUTS:
'volatility_prediction': tensor, # [batch, 1]
'trend_strength_prediction': tensor, # [batch, 1]
'uncertainty_mean': tensor, # [batch, 1]
'uncertainty_std': tensor # [batch, 1]
}
```
### TRAINING TARGETS (in batch)
```python
batch = {
# Input features (see INPUT Structure above)
'price_data_1s': tensor,
'price_data_1m': tensor,
'price_data_1h': tensor,
'price_data_1d': tensor,
'btc_data_1m': tensor,
'cob_data': tensor,
'tech_data': tensor,
'market_data': tensor,
'position_state': tensor,
# Training targets:
'actions': tensor, # [batch] - target action (0/1/2)
'future_prices': tensor, # [batch, 1] - actual price change ratio
'trade_success': tensor, # [batch, 1] - 1.0 if profitable
'trend_target': tensor, # [batch, 3] - [angle, steepness, direction]
}
```
### LOSS CALCULATION
Current loss function in `train_step()`:
```python
total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss
```

where:
- `action_loss`: CrossEntropyLoss(action_logits, actions)
- `price_loss`: MSELoss(price_prediction, future_prices)
- `trend_loss`: MSELoss(trend_pred, trend_target)
**NOTE**: Next candle predictions are currently only used for accuracy evaluation, NOT trained directly.
## CURRENT ISSUES AND RECOMMENDATIONS
### Issue 1: Next Candle Predictions Not Trained
**Status**: The model outputs next candle predictions for each timeframe, but these are NOT included in the loss function.
**Impact**: The model is not explicitly learning to predict next candle OHLCV values.
**Recommendation**: Add next candle loss to training:
```python
# Calculate next candle loss for each available timeframe
import torch.nn.functional as F

candle_loss = 0.0
if 'next_candles' in outputs:
    for tf in ['1s', '1m', '1h', '1d']:
        if tf in outputs['next_candles'] and f'future_candle_{tf}' in batch:
            pred_candle = outputs['next_candles'][tf]      # [batch, 5]
            target_candle = batch[f'future_candle_{tf}']   # [batch, 5]
            candle_loss += F.mse_loss(pred_candle, target_candle)

total_loss = action_loss + 0.1 * price_loss + 0.05 * trend_loss + 0.1 * candle_loss
```
### Issue 2: Annotation Timeframe vs Prediction Timeframe
**Current Behavior**:
- Annotations are created at a specific point in time
- The model receives multiple timeframes (1s, 1m, 1h, 1d) as input
- Predictions are made for ALL timeframes simultaneously
- Only the 1m timeframe prediction is currently evaluated for accuracy
**Question**: Should predictions be specific to the annotation's timeframe?
**Options**:
1. **Multi-timeframe predictions (current)**: Keep predicting all timeframes, add loss for each
2. **Annotation-specific predictions**: Only predict/train on the timeframe that matches the annotation
3. **Weighted predictions**: Weight the loss by the annotation's timeframe (e.g., if annotated on 1m, weight the 1m prediction higher); see the sketch below
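A hypothetical sketch of option 3 (the `annotation_tf` argument and the 2.0/1.0 weights are illustrative assumptions, not existing code):

```python
import torch.nn.functional as F

def weighted_candle_loss(outputs, batch, annotation_tf='1m'):
    # Weight the annotated timeframe's candle loss higher than the others
    losses = []
    for tf in ['1s', '1m', '1h', '1d']:
        target = batch.get(f'future_candle_{tf}')
        if target is not None and tf in outputs.get('next_candles', {}):
            weight = 2.0 if tf == annotation_tf else 1.0
            losses.append(weight * F.mse_loss(outputs['next_candles'][tf], target))
    return sum(losses) / len(losses) if losses else 0.0
```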
### Issue 3: Missing Target Data for Next Candles
**Current**: The batch only contains `future_prices` (next close price change)
**Needed**: To train next candle predictions, we need full OHLCV targets:
- `future_candle_1s`: [batch, 5] - next 1s candle OHLCV
- `future_candle_1m`: [batch, 5] - next 1m candle OHLCV
- `future_candle_1h`: [batch, 5] - next 1h candle OHLCV
- `future_candle_1d`: [batch, 5] - next 1d candle OHLCV
**Location to add**: `ANNOTATE/core/real_training_adapter.py` in `_convert_annotation_to_transformer_batch()` (see the sketch below)
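A minimal sketch of that extraction, simplified from the `_extract_next_candle()` helper this commit adds (error handling and the no-reference fallback are omitted):

```python
import numpy as np
import torch

def next_candle_target(tf_data: dict, ref_ohlcv: np.ndarray) -> torch.Tensor:
    """Take the last candle as the 'next' target and normalize it with the
    same min/max ranges as the input sequence (ref_ohlcv: [seq_len, 5])."""
    candle = np.array([tf_data[k][-1] for k in
                       ('open', 'high', 'low', 'close', 'volume')], dtype=np.float32)
    p_min, p_max = ref_ohlcv[:, :4].min(), ref_ohlcv[:, :4].max()
    v_min, v_max = ref_ohlcv[:, 4].min(), ref_ohlcv[:, 4].max()
    if p_max > p_min:
        candle[:4] = (candle[:4] - p_min) / (p_max - p_min)
    if v_max > v_min:
        candle[4] = (candle[4] - v_min) / (v_max - v_min)
    return torch.tensor(candle).unsqueeze(0)  # [1, 5]
```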
## SUMMARY
✅ **Fixed**: Batch deletion bug causing epoch 2+ failures
✅ **Working**: Model can predict next candles for all timeframes
✅ **Working**: Model can predict trend vector (angle, steepness, direction)
❌ **Missing**: Loss calculation for next candle predictions
❌ **Missing**: Target data (future OHLCV) for next candle training
⚠️ **Unclear**: Should predictions be timeframe-specific or multi-timeframe?
## NEXT STEPS
1. **Add future OHLCV target data** to training batches
2. **Add next candle loss** to the training loop
3. **Clarify prediction strategy**: Single timeframe vs multi-timeframe
4. **Test training** with enhanced loss function

test_normalization_fix.py (new file, 288 lines)
View File

@@ -0,0 +1,288 @@
#!/usr/bin/env python3
"""
Test script to verify the normalization fix is working correctly
This creates a simple test batch and verifies:
1. Model outputs are in expected ranges (thanks to Sigmoid/Tanh constraints)
2. Normalization parameters are stored and can be retrieved
3. Denormalization works correctly
4. Losses are in reasonable ranges (not billions!)
"""
import torch
import numpy as np
from NN.models.advanced_transformer_trading import (
AdvancedTradingTransformer,
TradingTransformerConfig,
TradingTransformerTrainer
)
def create_test_batch():
"""Create a simple test batch with known normalization parameters"""
batch_size = 1
seq_len = 200
# Create synthetic price data in [0, 1] range (normalized)
price_data_1m = torch.rand(batch_size, seq_len, 5) * 0.2 + 0.4 # Range [0.4, 0.6]
# Create normalization parameters (simulate ETHUSDT around $2500)
norm_params = {
'1m': {
'price_min': 2480.0,
'price_max': 2520.0,
'volume_min': 100.0,
'volume_max': 10000.0
}
}
# Create future candle target (slightly higher close)
future_candle_1m = torch.rand(batch_size, 5) * 0.2 + 0.45 # Range [0.45, 0.65]
# Create other required inputs
cob_data = torch.zeros(batch_size, seq_len, 100)
tech_data = torch.zeros(batch_size, 40)
market_data = torch.zeros(batch_size, 30)
position_state = torch.zeros(batch_size, 5)
actions = torch.tensor([1], dtype=torch.long) # BUY action
future_prices = torch.tensor([[0.01]], dtype=torch.float32) # 1% price increase expected
trade_success = torch.tensor([[1.0]], dtype=torch.float32)
trend_target = torch.tensor([[0.785, 0.5, 1.0]], dtype=torch.float32)
batch = {
'price_data_1m': price_data_1m,
'price_data_1s': None,
'price_data_1h': None,
'price_data_1d': None,
'btc_data_1m': None,
'cob_data': cob_data,
'tech_data': tech_data,
'market_data': market_data,
'position_state': position_state,
'actions': actions,
'future_prices': future_prices,
'trade_success': trade_success,
'trend_target': trend_target,
'future_candle_1m': future_candle_1m,
'future_candle_1s': None,
'future_candle_1h': None,
'future_candle_1d': None,
'norm_params': norm_params
}
return batch
def test_model_outputs():
"""Test that model outputs are in expected ranges"""
print("=" * 80)
print("TESTING: Model Output Constraints")
print("=" * 80)
# Create small model for testing
config = TradingTransformerConfig(
d_model=128,
n_heads=4,
n_layers=2,
seq_len=200
)
model = AdvancedTradingTransformer(config)
model.eval()
batch = create_test_batch()
with torch.no_grad():
outputs = model(
price_data_1m=batch['price_data_1m'],
cob_data=batch['cob_data'],
tech_data=batch['tech_data'],
market_data=batch['market_data'],
position_state=batch['position_state']
)
# Check candle predictions are in [0, 1] range (thanks to Sigmoid)
if 'next_candles' in outputs and '1m' in outputs['next_candles']:
candle_pred = outputs['next_candles']['1m']
print(f"\nCandle Prediction (1m):")
print(f" Shape: {candle_pred.shape}")
print(f" Min value: {candle_pred.min().item():.6f}")
print(f" Max value: {candle_pred.max().item():.6f}")
print(f" Mean value: {candle_pred.mean().item():.6f}")
if candle_pred.min() >= 0.0 and candle_pred.max() <= 1.0:
print(" ✅ PASS: Values in [0, 1] range (Sigmoid working!)")
else:
print(" ❌ FAIL: Values outside [0, 1] range!")
# Check price prediction is in [-1, 1] range (thanks to Tanh)
if 'price_prediction' in outputs:
price_pred = outputs['price_prediction']
print(f"\nPrice Prediction (change ratio):")
print(f" Shape: {price_pred.shape}")
print(f" Value: {price_pred.item():.6f}")
if price_pred.min() >= -1.0 and price_pred.max() <= 1.0:
print(" ✅ PASS: Values in [-1, 1] range (Tanh working!)")
else:
print(" ❌ FAIL: Values outside [-1, 1] range!")
# Check action probabilities sum to 1
if 'action_probs' in outputs:
action_probs = outputs['action_probs']
print(f"\nAction Probabilities:")
print(f" BUY: {action_probs[0, 0].item():.4f}")
print(f" SELL: {action_probs[0, 1].item():.4f}")
print(f" HOLD: {action_probs[0, 2].item():.4f}")
print(f" Sum: {action_probs[0].sum().item():.6f}")
if abs(action_probs[0].sum().item() - 1.0) < 0.001:
print(" ✅ PASS: Probabilities sum to 1.0")
else:
print(" ❌ FAIL: Probabilities don't sum to 1.0!")
return outputs
def test_denormalization():
"""Test denormalization functions"""
print("\n" + "=" * 80)
print("TESTING: Denormalization Functions")
print("=" * 80)
# Create test normalized candle
normalized_candle = torch.tensor([[0.5, 0.6, 0.4, 0.55, 0.3]]) # OHLCV
# Normalization params (ETHUSDT $2480-$2520)
norm_params = {
'price_min': 2480.0,
'price_max': 2520.0,
'volume_min': 100.0,
'volume_max': 10000.0
}
print(f"\nNormalized Candle: {normalized_candle[0].tolist()}")
print(f"Normalization Params: price [{norm_params['price_min']}, {norm_params['price_max']}], "
f"volume [{norm_params['volume_min']}, {norm_params['volume_max']}]")
# Denormalize
denorm_candle = TradingTransformerTrainer.denormalize_candle(normalized_candle, norm_params)
print(f"\nDenormalized Candle:")
print(f" Open: ${denorm_candle[0, 0].item():.2f}")
print(f" High: ${denorm_candle[0, 1].item():.2f}")
print(f" Low: ${denorm_candle[0, 2].item():.2f}")
print(f" Close: ${denorm_candle[0, 3].item():.2f}")
print(f" Volume: {denorm_candle[0, 4].item():.2f}")
# Verify values are in expected range
expected_min_price = norm_params['price_min']
expected_max_price = norm_params['price_max']
prices_ok = True
for i, name in enumerate(['Open', 'High', 'Low', 'Close']):
value = denorm_candle[0, i].item()
if value < expected_min_price or value > expected_max_price:
print(f" ❌ FAIL: {name} price ${value:.2f} outside expected range!")
prices_ok = False
if prices_ok:
print(f" ✅ PASS: All prices in expected range [${expected_min_price}, ${expected_max_price}]")
# Verify volume
volume = denorm_candle[0, 4].item()
if norm_params['volume_min'] <= volume <= norm_params['volume_max']:
print(f" ✅ PASS: Volume {volume:.2f} in expected range [{norm_params['volume_min']}, {norm_params['volume_max']}]")
else:
print(f" ❌ FAIL: Volume {volume:.2f} outside expected range!")
def test_loss_magnitude():
"""Test that losses are in reasonable ranges"""
print("\n" + "=" * 80)
print("TESTING: Loss Magnitudes")
print("=" * 80)
config = TradingTransformerConfig(
d_model=128,
n_heads=4,
n_layers=2,
seq_len=200
)
model = AdvancedTradingTransformer(config)
trainer = TradingTransformerTrainer(model, config)
batch = create_test_batch()
# Run one training step
result = trainer.train_step(batch, accumulate_gradients=False)
print(f"\nTraining Step Results:")
print(f" Total Loss: {result['total_loss']:.6f}")
print(f" Action Loss: {result['action_loss']:.6f}")
print(f" Price Loss: {result['price_loss']:.6f}")
print(f" Trend Loss: {result['trend_loss']:.6f}")
print(f" Candle Loss: {result['candle_loss']:.6f}")
print(f" Action Accuracy: {result['accuracy']:.2%}")
print(f" Candle Accuracy: {result['candle_accuracy']:.2%}")
# Check losses are reasonable (not billions!)
all_ok = True
if result['total_loss'] < 100.0:
print(f" ✅ PASS: Total loss < 100 (was {result['total_loss']:.6f})")
else:
print(f" ❌ FAIL: Total loss too high! ({result['total_loss']:.6f})")
all_ok = False
if result['candle_loss'] < 10.0:
print(f" ✅ PASS: Candle loss < 10 (was {result['candle_loss']:.6f})")
else:
print(f" ❌ FAIL: Candle loss too high! ({result['candle_loss']:.6f})")
all_ok = False
# Check denormalized losses if available
if 'candle_loss_denorm' in result and result['candle_loss_denorm']:
print(f"\n Denormalized Candle Losses (Real Price Errors):")
for tf, loss in result['candle_loss_denorm'].items():
print(f" {tf}: ${loss:.2f}")
if loss < 1000.0:
print(f" ✅ PASS: Real price error < $1000")
else:
print(f" ❌ FAIL: Real price error too high!")
all_ok = False
if all_ok:
print("\n ✅ ALL TESTS PASSED: Losses in reasonable ranges!")
else:
print("\n ❌ SOME TESTS FAILED: Check model/normalization!")
return result
def main():
print("\n" + "=" * 80)
print("NORMALIZATION FIX VERIFICATION TEST")
print("=" * 80)
print("\nThis test verifies that:")
print("1. Model outputs are properly constrained (Sigmoid/Tanh)")
print("2. Normalization parameters are stored and accessible")
print("3. Denormalization functions work correctly")
print("4. Losses are in reasonable ranges (not billions!)")
print("\n" + "=" * 80)
# Run tests
outputs = test_model_outputs()
test_denormalization()
test_loss_magnitude()
print("\n" + "=" * 80)
print("TEST SUMMARY")
print("=" * 80)
print("\nIf all tests passed (✅), the normalization fix is working correctly!")
print("You should now see reasonable losses in training logs:")
print(" - Total loss: ~0.5-1.0 (not billions!)")
print(" - Candle loss: ~0.1-0.3")
print(" - Real price errors: $2-20 (not $147,000!)")
print("\nYou can now resume training and monitor these metrics.")
print("=" * 80 + "\n")
if __name__ == "__main__":
main()