Files
gogo2/test_model_size.py
2025-11-13 17:45:42 +02:00

68 lines
2.2 KiB
Python

#!/usr/bin/env python3
"""Quick test to verify model size and GPU usage"""
import torch
from NN.models.advanced_transformer_trading import TradingTransformerConfig, AdvancedTradingTransformer
def _print_model_report(config, model):
    """Print the model's key hyperparameters, parameter counts and weight footprint."""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("Model Configuration:")
    print(f" d_model: {config.d_model}")
    print(f" n_heads: {config.n_heads}")
    print(f" n_layers: {config.n_layers}")
    print(f" d_ff: {config.d_ff}")
    print(f" seq_len: {config.seq_len}")
    print()
    print("Model Parameters:")
    print(f" Total: {total_params:,} ({total_params/1e6:.2f}M)")
    print(f" Trainable: {trainable_params:,} ({trainable_params/1e6:.2f}M)")
    # Weights only: 4 bytes/param in FP32, 2 bytes/param in FP16. Buffers,
    # activations and optimizer state are not included.
    print(f" Model size (FP32): {total_params * 4 / 1024**2:.2f} MB")
    print(f" Model size (FP16): {total_params * 2 / 1024**2:.2f} MB")
    print()


def _check_gpu(model):
    """Move *model* to the GPU if one is available and smoke-test one forward pass.

    Returns the model, moved to the GPU when one is available, otherwise unchanged.
    """
    # ROCm builds of PyTorch expose AMD GPUs through the torch.cuda API, so
    # torch.cuda.is_available() is the availability check for both backends;
    # torch.version.hip only tells us which backend this build uses. Testing
    # `hip` instead of availability would try to move the model to a device
    # that does not exist.
    if not torch.cuda.is_available():
        print("GPU Available: ❌ CPU only")
        print(" Training will use CPU (slower)")
        return model

    backend = "ROCm/HIP" if getattr(torch.version, "hip", None) else "CUDA"
    print(f"GPU Available: ✅ {backend}")
    print(f" Device: {torch.cuda.get_device_name(0)}")
    print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    device = torch.device("cuda")  # ROCm also uses the 'cuda' device name
    model = model.to(device)
    print(" Model moved to GPU ✅")

    # Dummy input for the forward-pass smoke test.
    batch_size = 1
    seq_len = 200  # NOTE(review): hard-coded; confirm whether config.seq_len should drive this
    # Last dim of 5 presumably means 5 features per 1m bar (e.g. OHLCV) — TODO confirm.
    price_data_1m = torch.randn(batch_size, seq_len, 5, device=device)

    # Run in eval mode so dropout / batch-norm-style layers behave as in
    # inference (and batch-norm cannot fail on a batch of 1), then restore
    # whatever mode the model was in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model(price_data_1m=price_data_1m)
    finally:
        model.train(was_training)
    print(" Forward pass successful ✅")
    print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
    return model


def main():
    """Build the default trading transformer, report its size, and probe the GPU."""
    config = TradingTransformerConfig()
    model = AdvancedTradingTransformer(config)
    _print_model_report(config, model)
    _check_gpu(model)
    print()
    print("Model ready for training! 🚀")


# Guarded so that importing this module (e.g. pytest collecting test_*.py)
# does not build the model or touch the GPU as a side effect.
if __name__ == "__main__":
    main()