raining normalization fix
This commit is contained in:
288
test_normalization_fix.py
Normal file
288
test_normalization_fix.py
Normal file
@@ -0,0 +1,288 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify normalization fix is working correctly
|
||||
|
||||
This creates a simple test batch and verifies:
|
||||
1. Model outputs are in expected ranges (thanks to Sigmoid/Tanh constraints)
|
||||
2. Normalization parameters are stored and can be retrieved
|
||||
3. Denormalization works correctly
|
||||
4. Losses are in reasonable ranges (not billions!)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from NN.models.advanced_transformer_trading import (
|
||||
AdvancedTradingTransformer,
|
||||
TradingTransformerConfig,
|
||||
TradingTransformerTrainer
|
||||
)
|
||||
|
||||
def create_test_batch():
|
||||
"""Create a simple test batch with known normalization parameters"""
|
||||
batch_size = 1
|
||||
seq_len = 200
|
||||
|
||||
# Create synthetic price data in [0, 1] range (normalized)
|
||||
price_data_1m = torch.rand(batch_size, seq_len, 5) * 0.2 + 0.4 # Range [0.4, 0.6]
|
||||
|
||||
# Create normalization parameters (simulate ETHUSDT around $2500)
|
||||
norm_params = {
|
||||
'1m': {
|
||||
'price_min': 2480.0,
|
||||
'price_max': 2520.0,
|
||||
'volume_min': 100.0,
|
||||
'volume_max': 10000.0
|
||||
}
|
||||
}
|
||||
|
||||
# Create future candle target (slightly higher close)
|
||||
future_candle_1m = torch.rand(batch_size, 5) * 0.2 + 0.45 # Range [0.45, 0.65]
|
||||
|
||||
# Create other required inputs
|
||||
cob_data = torch.zeros(batch_size, seq_len, 100)
|
||||
tech_data = torch.zeros(batch_size, 40)
|
||||
market_data = torch.zeros(batch_size, 30)
|
||||
position_state = torch.zeros(batch_size, 5)
|
||||
actions = torch.tensor([1], dtype=torch.long) # BUY action
|
||||
future_prices = torch.tensor([[0.01]], dtype=torch.float32) # 1% price increase expected
|
||||
trade_success = torch.tensor([[1.0]], dtype=torch.float32)
|
||||
trend_target = torch.tensor([[0.785, 0.5, 1.0]], dtype=torch.float32)
|
||||
|
||||
batch = {
|
||||
'price_data_1m': price_data_1m,
|
||||
'price_data_1s': None,
|
||||
'price_data_1h': None,
|
||||
'price_data_1d': None,
|
||||
'btc_data_1m': None,
|
||||
'cob_data': cob_data,
|
||||
'tech_data': tech_data,
|
||||
'market_data': market_data,
|
||||
'position_state': position_state,
|
||||
'actions': actions,
|
||||
'future_prices': future_prices,
|
||||
'trade_success': trade_success,
|
||||
'trend_target': trend_target,
|
||||
'future_candle_1m': future_candle_1m,
|
||||
'future_candle_1s': None,
|
||||
'future_candle_1h': None,
|
||||
'future_candle_1d': None,
|
||||
'norm_params': norm_params
|
||||
}
|
||||
|
||||
return batch
|
||||
|
||||
def test_model_outputs():
|
||||
"""Test that model outputs are in expected ranges"""
|
||||
print("=" * 80)
|
||||
print("TESTING: Model Output Constraints")
|
||||
print("=" * 80)
|
||||
|
||||
# Create small model for testing
|
||||
config = TradingTransformerConfig(
|
||||
d_model=128,
|
||||
n_heads=4,
|
||||
n_layers=2,
|
||||
seq_len=200
|
||||
)
|
||||
|
||||
model = AdvancedTradingTransformer(config)
|
||||
model.eval()
|
||||
|
||||
batch = create_test_batch()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
price_data_1m=batch['price_data_1m'],
|
||||
cob_data=batch['cob_data'],
|
||||
tech_data=batch['tech_data'],
|
||||
market_data=batch['market_data'],
|
||||
position_state=batch['position_state']
|
||||
)
|
||||
|
||||
# Check candle predictions are in [0, 1] range (thanks to Sigmoid)
|
||||
if 'next_candles' in outputs and '1m' in outputs['next_candles']:
|
||||
candle_pred = outputs['next_candles']['1m']
|
||||
print(f"\nCandle Prediction (1m):")
|
||||
print(f" Shape: {candle_pred.shape}")
|
||||
print(f" Min value: {candle_pred.min().item():.6f}")
|
||||
print(f" Max value: {candle_pred.max().item():.6f}")
|
||||
print(f" Mean value: {candle_pred.mean().item():.6f}")
|
||||
|
||||
if candle_pred.min() >= 0.0 and candle_pred.max() <= 1.0:
|
||||
print(" ✅ PASS: Values in [0, 1] range (Sigmoid working!)")
|
||||
else:
|
||||
print(" ❌ FAIL: Values outside [0, 1] range!")
|
||||
|
||||
# Check price prediction is in [-1, 1] range (thanks to Tanh)
|
||||
if 'price_prediction' in outputs:
|
||||
price_pred = outputs['price_prediction']
|
||||
print(f"\nPrice Prediction (change ratio):")
|
||||
print(f" Shape: {price_pred.shape}")
|
||||
print(f" Value: {price_pred.item():.6f}")
|
||||
|
||||
if price_pred.min() >= -1.0 and price_pred.max() <= 1.0:
|
||||
print(" ✅ PASS: Values in [-1, 1] range (Tanh working!)")
|
||||
else:
|
||||
print(" ❌ FAIL: Values outside [-1, 1] range!")
|
||||
|
||||
# Check action probabilities sum to 1
|
||||
if 'action_probs' in outputs:
|
||||
action_probs = outputs['action_probs']
|
||||
print(f"\nAction Probabilities:")
|
||||
print(f" BUY: {action_probs[0, 0].item():.4f}")
|
||||
print(f" SELL: {action_probs[0, 1].item():.4f}")
|
||||
print(f" HOLD: {action_probs[0, 2].item():.4f}")
|
||||
print(f" Sum: {action_probs[0].sum().item():.6f}")
|
||||
|
||||
if abs(action_probs[0].sum().item() - 1.0) < 0.001:
|
||||
print(" ✅ PASS: Probabilities sum to 1.0")
|
||||
else:
|
||||
print(" ❌ FAIL: Probabilities don't sum to 1.0!")
|
||||
|
||||
return outputs
|
||||
|
||||
def test_denormalization():
|
||||
"""Test denormalization functions"""
|
||||
print("\n" + "=" * 80)
|
||||
print("TESTING: Denormalization Functions")
|
||||
print("=" * 80)
|
||||
|
||||
# Create test normalized candle
|
||||
normalized_candle = torch.tensor([[0.5, 0.6, 0.4, 0.55, 0.3]]) # OHLCV
|
||||
|
||||
# Normalization params (ETHUSDT $2480-$2520)
|
||||
norm_params = {
|
||||
'price_min': 2480.0,
|
||||
'price_max': 2520.0,
|
||||
'volume_min': 100.0,
|
||||
'volume_max': 10000.0
|
||||
}
|
||||
|
||||
print(f"\nNormalized Candle: {normalized_candle[0].tolist()}")
|
||||
print(f"Normalization Params: price [{norm_params['price_min']}, {norm_params['price_max']}], "
|
||||
f"volume [{norm_params['volume_min']}, {norm_params['volume_max']}]")
|
||||
|
||||
# Denormalize
|
||||
denorm_candle = TradingTransformerTrainer.denormalize_candle(normalized_candle, norm_params)
|
||||
|
||||
print(f"\nDenormalized Candle:")
|
||||
print(f" Open: ${denorm_candle[0, 0].item():.2f}")
|
||||
print(f" High: ${denorm_candle[0, 1].item():.2f}")
|
||||
print(f" Low: ${denorm_candle[0, 2].item():.2f}")
|
||||
print(f" Close: ${denorm_candle[0, 3].item():.2f}")
|
||||
print(f" Volume: {denorm_candle[0, 4].item():.2f}")
|
||||
|
||||
# Verify values are in expected range
|
||||
expected_min_price = norm_params['price_min']
|
||||
expected_max_price = norm_params['price_max']
|
||||
|
||||
prices_ok = True
|
||||
for i, name in enumerate(['Open', 'High', 'Low', 'Close']):
|
||||
value = denorm_candle[0, i].item()
|
||||
if value < expected_min_price or value > expected_max_price:
|
||||
print(f" ❌ FAIL: {name} price ${value:.2f} outside expected range!")
|
||||
prices_ok = False
|
||||
|
||||
if prices_ok:
|
||||
print(f" ✅ PASS: All prices in expected range [${expected_min_price}, ${expected_max_price}]")
|
||||
|
||||
# Verify volume
|
||||
volume = denorm_candle[0, 4].item()
|
||||
if norm_params['volume_min'] <= volume <= norm_params['volume_max']:
|
||||
print(f" ✅ PASS: Volume {volume:.2f} in expected range [{norm_params['volume_min']}, {norm_params['volume_max']}]")
|
||||
else:
|
||||
print(f" ❌ FAIL: Volume {volume:.2f} outside expected range!")
|
||||
|
||||
def test_loss_magnitude():
|
||||
"""Test that losses are in reasonable ranges"""
|
||||
print("\n" + "=" * 80)
|
||||
print("TESTING: Loss Magnitudes")
|
||||
print("=" * 80)
|
||||
|
||||
config = TradingTransformerConfig(
|
||||
d_model=128,
|
||||
n_heads=4,
|
||||
n_layers=2,
|
||||
seq_len=200
|
||||
)
|
||||
|
||||
model = AdvancedTradingTransformer(config)
|
||||
trainer = TradingTransformerTrainer(model, config)
|
||||
|
||||
batch = create_test_batch()
|
||||
|
||||
# Run one training step
|
||||
result = trainer.train_step(batch, accumulate_gradients=False)
|
||||
|
||||
print(f"\nTraining Step Results:")
|
||||
print(f" Total Loss: {result['total_loss']:.6f}")
|
||||
print(f" Action Loss: {result['action_loss']:.6f}")
|
||||
print(f" Price Loss: {result['price_loss']:.6f}")
|
||||
print(f" Trend Loss: {result['trend_loss']:.6f}")
|
||||
print(f" Candle Loss: {result['candle_loss']:.6f}")
|
||||
print(f" Action Accuracy: {result['accuracy']:.2%}")
|
||||
print(f" Candle Accuracy: {result['candle_accuracy']:.2%}")
|
||||
|
||||
# Check losses are reasonable (not billions!)
|
||||
all_ok = True
|
||||
|
||||
if result['total_loss'] < 100.0:
|
||||
print(f" ✅ PASS: Total loss < 100 (was {result['total_loss']:.6f})")
|
||||
else:
|
||||
print(f" ❌ FAIL: Total loss too high! ({result['total_loss']:.6f})")
|
||||
all_ok = False
|
||||
|
||||
if result['candle_loss'] < 10.0:
|
||||
print(f" ✅ PASS: Candle loss < 10 (was {result['candle_loss']:.6f})")
|
||||
else:
|
||||
print(f" ❌ FAIL: Candle loss too high! ({result['candle_loss']:.6f})")
|
||||
all_ok = False
|
||||
|
||||
# Check denormalized losses if available
|
||||
if 'candle_loss_denorm' in result and result['candle_loss_denorm']:
|
||||
print(f"\n Denormalized Candle Losses (Real Price Errors):")
|
||||
for tf, loss in result['candle_loss_denorm'].items():
|
||||
print(f" {tf}: ${loss:.2f}")
|
||||
if loss < 1000.0:
|
||||
print(f" ✅ PASS: Real price error < $1000")
|
||||
else:
|
||||
print(f" ❌ FAIL: Real price error too high!")
|
||||
all_ok = False
|
||||
|
||||
if all_ok:
|
||||
print("\n ✅ ALL TESTS PASSED: Losses in reasonable ranges!")
|
||||
else:
|
||||
print("\n ❌ SOME TESTS FAILED: Check model/normalization!")
|
||||
|
||||
return result
|
||||
|
||||
def main():
|
||||
print("\n" + "=" * 80)
|
||||
print("NORMALIZATION FIX VERIFICATION TEST")
|
||||
print("=" * 80)
|
||||
print("\nThis test verifies that:")
|
||||
print("1. Model outputs are properly constrained (Sigmoid/Tanh)")
|
||||
print("2. Normalization parameters are stored and accessible")
|
||||
print("3. Denormalization functions work correctly")
|
||||
print("4. Losses are in reasonable ranges (not billions!)")
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
# Run tests
|
||||
outputs = test_model_outputs()
|
||||
test_denormalization()
|
||||
test_loss_magnitude()
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST SUMMARY")
|
||||
print("=" * 80)
|
||||
print("\nIf all tests passed (✅), the normalization fix is working correctly!")
|
||||
print("You should now see reasonable losses in training logs:")
|
||||
print(" - Total loss: ~0.5-1.0 (not billions!)")
|
||||
print(" - Candle loss: ~0.1-0.3")
|
||||
print(" - Real price errors: $2-20 (not $147,000!)")
|
||||
print("\nYou can now resume training and monitor these metrics.")
|
||||
print("=" * 80 + "\n")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user