Files
gogo2/test_massive_dqn.py
Dobromir Popov 9e1684f9f8 cb ws
2025-07-27 20:56:37 +03:00

232 lines
7.2 KiB
Python

#!/usr/bin/env python3
"""
Test script for the massive 50M parameter DQN agent
Tests:
1. Model initialization and parameter count
2. Forward pass functionality
3. Gradient flow verification
4. Training step simulation
"""
import sys
import os
# Make the directory containing this script importable before the project
# package import below; this line must stay ahead of the `NN...` import.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import torch
import numpy as np
from NN.models.dqn_agent import DQNAgent, DQNNetwork
import logging
# Set up logging: INFO-level root configuration plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def test_dqn_architecture():
    """Instantiate a bare DQNNetwork, report its size and run one forward pass.

    The parameter total is printed together with whether it reaches the
    50,000,000 target (reported only, never enforced). A random batch of four
    inputs is then pushed through the network under ``torch.no_grad()`` and
    the shape of every returned tensor is printed.

    Returns:
        The constructed network instance.
    """
    print("🔥 Testing Massive DQN Architecture (Target: 50M parameters)")

    # Dimensions used for the direct network check.
    input_dim, n_actions = 7850, 3  # feature size; BUY, SELL, HOLD
    print(f"\n1. Creating DQN Network with input_dim={input_dim}, n_actions={n_actions}")
    network = DQNNetwork(input_dim, n_actions)

    # Tally the number of scalar weights across all parameter tensors.
    total_params = 0
    for weight in network.parameters():
        total_params += weight.numel()
    print(f" ✅ Total parameters: {total_params:,}")
    print(f" 🎯 Target achieved: {total_params >= 50_000_000}")

    print(f"\n2. Testing forward pass...")
    batch_size = 4
    with torch.no_grad():
        output = network(torch.randn(batch_size, input_dim))

    # A single-tensor result only needs one shape report.
    if not isinstance(output, tuple):
        print(f" ✅ Output shape: {output.shape}")
        return network

    # Tuple result: exactly five tensors are expected here.
    q_values, regime_pred, price_pred, volatility_pred, features = output
    print(f" ✅ Q-values shape: {q_values.shape}")
    print(f" ✅ Regime prediction shape: {regime_pred.shape}")
    print(f" ✅ Price prediction shape: {price_pred.shape}")
    print(f" ✅ Volatility prediction shape: {volatility_pred.shape}")
    print(f" ✅ Features shape: {features.shape}")
    return network
def test_gradient_flow():
    """Check that one backward pass yields finite gradients in the policy net.

    A DQNAgent is built with mixed precision switched off and its policy
    network put in training mode. A random batch is run through
    ``agent.policy_net``, an MSE loss between the Q-values of randomly chosen
    actions and random rewards is back-propagated, every parameter's ``.grad``
    is inspected, and finally one optimizer step is taken.

    Returns:
        bool: True when at least one parameter received a gradient, every
        gradient is finite and the optimizer step ran; False otherwise.
    """
    print(f"\n🧪 Testing Gradient Flow...")

    # Create agent
    state_dim = 7850
    agent = DQNAgent(
        state_shape=(state_dim,),
        n_actions=3,
        learning_rate=0.001,
        batch_size=16,
        buffer_size=1000
    )

    # Force disable mixed precision
    agent.use_mixed_precision = False
    print(f" ✅ Mixed precision disabled: {not agent.use_mixed_precision}")

    # Ensure model is in training mode
    agent.policy_net.train()
    print(f" ✅ Model in training mode: {agent.policy_net.training}")

    # Build the batch on whatever device the network's weights live on, so the
    # forward pass cannot fail with a CPU/accelerator mismatch.
    # NOTE(review): whether DQNAgent moves the network off the CPU is not
    # visible from this file; on a CPU-only setup this is a no-op.
    first_param = next(agent.policy_net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Create test batch
    batch_size = 8
    states = torch.randn(batch_size, state_dim, requires_grad=True, device=device)
    actions = torch.randint(0, 3, (batch_size,), device=device)
    rewards = torch.randn(batch_size, device=device)
    print(f" 📊 Test batch created - states: {states.shape}, actions: {actions.shape}")

    # Clear any stale gradients before the forward/backward pass.
    agent.optimizer.zero_grad()

    # Forward pass; a tuple result carries the Q-values in its first slot.
    output = agent.policy_net(states)
    q_values = output[0] if isinstance(output, tuple) else output
    print(f" ✅ Forward pass successful - Q-values: {q_values.shape}")
    print(f" ✅ Q-values require grad: {q_values.requires_grad}")

    # Gather Q-values for actions
    current_q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
    print(f" ✅ Gathered Q-values require grad: {current_q_values.requires_grad}")

    # Compute simple loss against a simplified target (the raw rewards).
    target_q_values = rewards
    loss = torch.nn.MSELoss()(current_q_values, target_q_values)
    print(f" ✅ Loss computed: {loss.item():.6f}")
    print(f" ✅ Loss requires grad: {loss.requires_grad}")

    # Backward pass
    loss.backward()

    # Check if gradients exist and are finite. Counts are per parameter
    # tensor, not per scalar weight.
    grad_norms = []
    total_params = 0
    for name, param in agent.policy_net.named_parameters():
        total_params += 1
        if param.grad is None:
            continue
        if not torch.isfinite(param.grad).all():
            print(f" ❌ Non-finite gradients in {name}")
            return False
        grad_norms.append(param.grad.norm().item())

    params_with_grad = len(grad_norms)
    print(f" ✅ Parameters with gradients: {params_with_grad}/{total_params}")

    # Without any gradient there is nothing to average and nothing was
    # learned; report a failure instead of crashing on max() of an empty list.
    if not grad_norms:
        print(" ❌ No parameter received a gradient")
        return False

    print(f" ✅ Average gradient norm: {np.mean(grad_norms):.6f}")
    print(f" ✅ Max gradient norm: {max(grad_norms):.6f}")

    # Test optimizer step
    agent.optimizer.step()
    print(f" ✅ Optimizer step completed successfully")
    return True
def test_training_step():
    """Fill the replay memory with random transitions and run one replay step.

    Twenty random experiences are stored through ``agent.remember`` and, when
    the memory holds at least ``agent.batch_size`` entries, ``agent.replay()``
    is invoked once.

    Returns:
        bool: True when the reported loss is greater than zero, or when there
        were too few experiences to train at all; False when the loss is not
        greater than zero.
    """
    print(f"\n🏋️ Testing Complete Training Step...")

    feature_count = 7850
    agent = DQNAgent(
        state_shape=(feature_count,),
        n_actions=3,
        learning_rate=0.001,
        batch_size=8,
        buffer_size=1000
    )
    # Mixed precision is switched off for this check.
    agent.use_mixed_precision = False

    # Populate the memory; the random draws happen in the order
    # state, action, reward, next_state, done.
    for _ in range(20):
        state = np.random.randn(feature_count).astype(np.float32)
        action = np.random.randint(0, 3)
        reward = np.random.randn() * 0.1
        next_state = np.random.randn(feature_count).astype(np.float32)
        done = np.random.random() < 0.1
        agent.remember(state, action, reward, next_state, done)
    print(f" ✅ Added {len(agent.memory)} experiences to memory")

    # Too little data: nothing to train on, which is not counted as a failure.
    if len(agent.memory) < agent.batch_size:
        print(f" ⚠️ Not enough experiences for training")
        return True

    loss = agent.replay()
    print(f" ✅ Training completed with loss: {loss:.6f}")
    if loss > 0:
        print(f" ✅ Training successful - non-zero loss indicates learning")
        return True

    print(f" ❌ Training failed - zero loss indicates gradient issues")
    return False
def main():
    """Run the three checks in order, stopping at the first failure.

    The architecture check passes as long as it does not raise; the gradient
    flow and training step checks must additionally return a truthy value.

    Returns:
        bool: True when every check passed, False as soon as one fails.
    """
    print("🚀 MASSIVE DQN AGENT TESTING SUITE")
    print("=" * 50)

    # (callable, whether its return value decides pass/fail, pass text, fail text)
    stages = (
        (test_dqn_architecture, False,
         " ✅ Architecture test PASSED", " ❌ Architecture test FAILED"),
        (test_gradient_flow, True,
         " ✅ Gradient flow test PASSED", " ❌ Gradient flow test FAILED"),
        (test_training_step, True,
         " ✅ Training step test PASSED", " ❌ Training step test FAILED"),
    )

    for run_stage, verdict_matters, passed_text, failed_text in stages:
        try:
            outcome = run_stage()
        except Exception as e:
            # Any exception from a check counts as a failure of that check.
            print(f"{failed_text}: {e}")
            return False
        if verdict_matters and not outcome:
            print(failed_text)
            return False
        print(passed_text)

    print("\n🎉 ALL TESTS PASSED!")
    print("✅ Massive DQN agent is ready for 50M parameter learning!")
    return True
# Script entry point: run the suite and report the outcome through the
# process exit status (0 on success, 1 on failure).
if __name__ == "__main__":
    success = main()
    # sys.exit is used rather than the bare exit() helper, which is injected
    # by the site module for interactive sessions and may be absent.
    sys.exit(0 if success else 1)