raining normalization fix

This commit is contained in:
Dobromir Popov
2025-11-12 14:36:28 +02:00
parent 4c04503f3e
commit a7a22334fb
5 changed files with 800 additions and 50 deletions

288
test_normalization_fix.py Normal file
View File

@@ -0,0 +1,288 @@
#!/usr/bin/env python3
"""
Test script to verify normalization fix is working correctly
This creates a simple test batch and verifies:
1. Model outputs are in expected ranges (thanks to Sigmoid/Tanh constraints)
2. Normalization parameters are stored and can be retrieved
3. Denormalization works correctly
4. Losses are in reasonable ranges (not billions!)
"""
import torch
import numpy as np
from NN.models.advanced_transformer_trading import (
AdvancedTradingTransformer,
TradingTransformerConfig,
TradingTransformerTrainer
)
def create_test_batch():
"""Create a simple test batch with known normalization parameters"""
batch_size = 1
seq_len = 200
# Create synthetic price data in [0, 1] range (normalized)
price_data_1m = torch.rand(batch_size, seq_len, 5) * 0.2 + 0.4 # Range [0.4, 0.6]
# Create normalization parameters (simulate ETHUSDT around $2500)
norm_params = {
'1m': {
'price_min': 2480.0,
'price_max': 2520.0,
'volume_min': 100.0,
'volume_max': 10000.0
}
}
# Create future candle target (slightly higher close)
future_candle_1m = torch.rand(batch_size, 5) * 0.2 + 0.45 # Range [0.45, 0.65]
# Create other required inputs
cob_data = torch.zeros(batch_size, seq_len, 100)
tech_data = torch.zeros(batch_size, 40)
market_data = torch.zeros(batch_size, 30)
position_state = torch.zeros(batch_size, 5)
actions = torch.tensor([1], dtype=torch.long) # BUY action
future_prices = torch.tensor([[0.01]], dtype=torch.float32) # 1% price increase expected
trade_success = torch.tensor([[1.0]], dtype=torch.float32)
trend_target = torch.tensor([[0.785, 0.5, 1.0]], dtype=torch.float32)
batch = {
'price_data_1m': price_data_1m,
'price_data_1s': None,
'price_data_1h': None,
'price_data_1d': None,
'btc_data_1m': None,
'cob_data': cob_data,
'tech_data': tech_data,
'market_data': market_data,
'position_state': position_state,
'actions': actions,
'future_prices': future_prices,
'trade_success': trade_success,
'trend_target': trend_target,
'future_candle_1m': future_candle_1m,
'future_candle_1s': None,
'future_candle_1h': None,
'future_candle_1d': None,
'norm_params': norm_params
}
return batch
def test_model_outputs():
"""Test that model outputs are in expected ranges"""
print("=" * 80)
print("TESTING: Model Output Constraints")
print("=" * 80)
# Create small model for testing
config = TradingTransformerConfig(
d_model=128,
n_heads=4,
n_layers=2,
seq_len=200
)
model = AdvancedTradingTransformer(config)
model.eval()
batch = create_test_batch()
with torch.no_grad():
outputs = model(
price_data_1m=batch['price_data_1m'],
cob_data=batch['cob_data'],
tech_data=batch['tech_data'],
market_data=batch['market_data'],
position_state=batch['position_state']
)
# Check candle predictions are in [0, 1] range (thanks to Sigmoid)
if 'next_candles' in outputs and '1m' in outputs['next_candles']:
candle_pred = outputs['next_candles']['1m']
print(f"\nCandle Prediction (1m):")
print(f" Shape: {candle_pred.shape}")
print(f" Min value: {candle_pred.min().item():.6f}")
print(f" Max value: {candle_pred.max().item():.6f}")
print(f" Mean value: {candle_pred.mean().item():.6f}")
if candle_pred.min() >= 0.0 and candle_pred.max() <= 1.0:
print(" ✅ PASS: Values in [0, 1] range (Sigmoid working!)")
else:
print(" ❌ FAIL: Values outside [0, 1] range!")
# Check price prediction is in [-1, 1] range (thanks to Tanh)
if 'price_prediction' in outputs:
price_pred = outputs['price_prediction']
print(f"\nPrice Prediction (change ratio):")
print(f" Shape: {price_pred.shape}")
print(f" Value: {price_pred.item():.6f}")
if price_pred.min() >= -1.0 and price_pred.max() <= 1.0:
print(" ✅ PASS: Values in [-1, 1] range (Tanh working!)")
else:
print(" ❌ FAIL: Values outside [-1, 1] range!")
# Check action probabilities sum to 1
if 'action_probs' in outputs:
action_probs = outputs['action_probs']
print(f"\nAction Probabilities:")
print(f" BUY: {action_probs[0, 0].item():.4f}")
print(f" SELL: {action_probs[0, 1].item():.4f}")
print(f" HOLD: {action_probs[0, 2].item():.4f}")
print(f" Sum: {action_probs[0].sum().item():.6f}")
if abs(action_probs[0].sum().item() - 1.0) < 0.001:
print(" ✅ PASS: Probabilities sum to 1.0")
else:
print(" ❌ FAIL: Probabilities don't sum to 1.0!")
return outputs
def test_denormalization():
"""Test denormalization functions"""
print("\n" + "=" * 80)
print("TESTING: Denormalization Functions")
print("=" * 80)
# Create test normalized candle
normalized_candle = torch.tensor([[0.5, 0.6, 0.4, 0.55, 0.3]]) # OHLCV
# Normalization params (ETHUSDT $2480-$2520)
norm_params = {
'price_min': 2480.0,
'price_max': 2520.0,
'volume_min': 100.0,
'volume_max': 10000.0
}
print(f"\nNormalized Candle: {normalized_candle[0].tolist()}")
print(f"Normalization Params: price [{norm_params['price_min']}, {norm_params['price_max']}], "
f"volume [{norm_params['volume_min']}, {norm_params['volume_max']}]")
# Denormalize
denorm_candle = TradingTransformerTrainer.denormalize_candle(normalized_candle, norm_params)
print(f"\nDenormalized Candle:")
print(f" Open: ${denorm_candle[0, 0].item():.2f}")
print(f" High: ${denorm_candle[0, 1].item():.2f}")
print(f" Low: ${denorm_candle[0, 2].item():.2f}")
print(f" Close: ${denorm_candle[0, 3].item():.2f}")
print(f" Volume: {denorm_candle[0, 4].item():.2f}")
# Verify values are in expected range
expected_min_price = norm_params['price_min']
expected_max_price = norm_params['price_max']
prices_ok = True
for i, name in enumerate(['Open', 'High', 'Low', 'Close']):
value = denorm_candle[0, i].item()
if value < expected_min_price or value > expected_max_price:
print(f" ❌ FAIL: {name} price ${value:.2f} outside expected range!")
prices_ok = False
if prices_ok:
print(f" ✅ PASS: All prices in expected range [${expected_min_price}, ${expected_max_price}]")
# Verify volume
volume = denorm_candle[0, 4].item()
if norm_params['volume_min'] <= volume <= norm_params['volume_max']:
print(f" ✅ PASS: Volume {volume:.2f} in expected range [{norm_params['volume_min']}, {norm_params['volume_max']}]")
else:
print(f" ❌ FAIL: Volume {volume:.2f} outside expected range!")
def test_loss_magnitude():
"""Test that losses are in reasonable ranges"""
print("\n" + "=" * 80)
print("TESTING: Loss Magnitudes")
print("=" * 80)
config = TradingTransformerConfig(
d_model=128,
n_heads=4,
n_layers=2,
seq_len=200
)
model = AdvancedTradingTransformer(config)
trainer = TradingTransformerTrainer(model, config)
batch = create_test_batch()
# Run one training step
result = trainer.train_step(batch, accumulate_gradients=False)
print(f"\nTraining Step Results:")
print(f" Total Loss: {result['total_loss']:.6f}")
print(f" Action Loss: {result['action_loss']:.6f}")
print(f" Price Loss: {result['price_loss']:.6f}")
print(f" Trend Loss: {result['trend_loss']:.6f}")
print(f" Candle Loss: {result['candle_loss']:.6f}")
print(f" Action Accuracy: {result['accuracy']:.2%}")
print(f" Candle Accuracy: {result['candle_accuracy']:.2%}")
# Check losses are reasonable (not billions!)
all_ok = True
if result['total_loss'] < 100.0:
print(f" ✅ PASS: Total loss < 100 (was {result['total_loss']:.6f})")
else:
print(f" ❌ FAIL: Total loss too high! ({result['total_loss']:.6f})")
all_ok = False
if result['candle_loss'] < 10.0:
print(f" ✅ PASS: Candle loss < 10 (was {result['candle_loss']:.6f})")
else:
print(f" ❌ FAIL: Candle loss too high! ({result['candle_loss']:.6f})")
all_ok = False
# Check denormalized losses if available
if 'candle_loss_denorm' in result and result['candle_loss_denorm']:
print(f"\n Denormalized Candle Losses (Real Price Errors):")
for tf, loss in result['candle_loss_denorm'].items():
print(f" {tf}: ${loss:.2f}")
if loss < 1000.0:
print(f" ✅ PASS: Real price error < $1000")
else:
print(f" ❌ FAIL: Real price error too high!")
all_ok = False
if all_ok:
print("\n ✅ ALL TESTS PASSED: Losses in reasonable ranges!")
else:
print("\n ❌ SOME TESTS FAILED: Check model/normalization!")
return result
def main():
print("\n" + "=" * 80)
print("NORMALIZATION FIX VERIFICATION TEST")
print("=" * 80)
print("\nThis test verifies that:")
print("1. Model outputs are properly constrained (Sigmoid/Tanh)")
print("2. Normalization parameters are stored and accessible")
print("3. Denormalization functions work correctly")
print("4. Losses are in reasonable ranges (not billions!)")
print("\n" + "=" * 80)
# Run tests
outputs = test_model_outputs()
test_denormalization()
test_loss_magnitude()
print("\n" + "=" * 80)
print("TEST SUMMARY")
print("=" * 80)
print("\nIf all tests passed (✅), the normalization fix is working correctly!")
print("You should now see reasonable losses in training logs:")
print(" - Total loss: ~0.5-1.0 (not billions!)")
print(" - Candle loss: ~0.1-0.3")
print(" - Real price errors: $2-20 (not $147,000!)")
print("\nYou can now resume training and monitor these metrics.")
print("=" * 80 + "\n")
if __name__ == "__main__":
main()