301 lines
12 KiB
Python
301 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Model Parameter Audit Script
|
|
Analyzes and calculates the total parameters for all model architectures in the trading system.
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import sys
|
|
import os
|
|
import json
|
|
from pathlib import Path
|
|
from collections import defaultdict
|
|
import numpy as np
|
|
|
|
# Append local directories to the import search path (relative entries are
# resolved against the current working directory). The audit functions import
# `enhanced_cnn` and `dqn_agent` by bare module name, so this must run first.
sys.path.extend(['.', './NN/models', './NN'])
|
|
|
|
def count_parameters(model):
    """Return ``(total, trainable)`` parameter-element counts for *model*.

    ``total`` sums ``numel()`` over every tensor yielded by
    ``model.parameters()``; ``trainable`` restricts that sum to the tensors
    whose ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for tensor in model.parameters():
        count = tensor.numel()
        total += count
        if tensor.requires_grad:
            trainable += count
    return total, trainable
|
|
|
|
def get_model_size_mb(model):
    """Return the memory held by a model's parameters and buffers, in MB.

    Bytes are computed as ``nelement() * element_size()`` for every tensor in
    ``model.parameters()`` and ``model.buffers()``, then divided by
    1024 * 1024. Gradients and any optimizer state are not counted.
    """
    def _nbytes(tensors):
        # Raw storage size of each tensor: element count times bytes per element.
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _nbytes(model.parameters()) + _nbytes(model.buffers())
    return total_bytes / 1024 / 1024
|
|
|
|
def analyze_layer_parameters(model, model_name):
    """Break a model's parameter count down per module.

    Every module that directly owns at least one parameter gets an entry. For
    leaf modules this equals counting all of their parameters. Unlike a
    leaf-only scan, it also captures parameters registered directly on a
    container (e.g. an ``nn.Parameter`` attribute on a module that also has
    child layers). For models without parameters shared across modules, the
    returned total therefore matches ``count_parameters``.

    Args:
        model: The ``nn.Module`` to inspect.
        model_name: Display name for the model. Currently unused; kept for
            interface compatibility with existing callers.

    Returns:
        ``(layer_info, total_params)`` where ``layer_info`` is a list of dicts
        with keys ``layer_name``, ``layer_type``, ``parameters`` and
        ``trainable``, in ``named_modules()`` order, and ``total_params`` is
        the sum of their ``parameters`` values.
    """
    layer_info = []
    total_params = 0

    for name, module in model.named_modules():
        # recurse=False: count only what this module itself owns, so a
        # parameter is never attributed to both a container and its child.
        own_params = list(module.parameters(recurse=False))
        params = sum(p.numel() for p in own_params)
        if params > 0:
            layer_info.append({
                'layer_name': name,
                'layer_type': type(module).__name__,
                'parameters': params,
                'trainable': sum(p.numel() for p in own_params if p.requires_grad),
            })
            total_params += params

    return layer_info, total_params
|
|
|
|
def audit_enhanced_cnn():
    """Build the primary EnhancedCNN architecture and report its size.

    Returns a one-element list holding the model's name, input shape,
    total/trainable parameter counts, size in MB and per-layer breakdown.
    If ``enhanced_cnn`` cannot be imported, or anything fails while building
    or analyzing the model, a message is printed and an empty list is
    returned instead.
    """
    try:
        from enhanced_cnn import EnhancedCNN
    except ImportError as e:
        print(f"❌ Cannot import EnhancedCNN: {e}")
        return []

    # Optimal configuration chosen from earlier analysis.
    name = 'EnhancedCNN_Optimized'
    input_shape = (5, 100)
    n_actions = 3

    try:
        model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
        total_params, trainable_params = count_parameters(model)
        size_mb = get_model_size_mb(model)
        layer_info, _ = analyze_layer_parameters(model, name)

        result = {
            'model_name': name,
            'input_shape': input_shape,
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'size_mb': size_mb,
            'layer_breakdown': layer_info,
        }
        print(f"✅ {name}: {total_params:,} parameters ({size_mb:.2f} MB)")
        return [result]
    except Exception as e:  # best-effort audit: report and continue
        print(f"❌ Failed to analyze {name}: {e}")
        return []
|
|
|
|
def audit_dqn_agent():
    """Build a DQNAgent and report the size of its policy and target networks.

    ``total_parameters`` and ``size_mb`` are the sums over ``policy_net`` and
    ``target_net``; ``layer_breakdown`` covers the policy network only.
    Returns a one-element list on success. On import failure, or any error
    while building or analyzing the agent, a message is printed and an empty
    list is returned.
    """
    try:
        from dqn_agent import DQNAgent
    except ImportError as e:
        print(f"❌ Cannot import DQNAgent: {e}")
        return []

    # Optimal configuration.
    name = 'DQNAgent_EnhancedCNN'
    state_shape = (5, 100)
    n_actions = 3

    try:
        agent = DQNAgent(state_shape=state_shape, n_actions=n_actions)
        policy_net = agent.policy_net
        target_net = agent.target_net

        policy_params, _ = count_parameters(policy_net)
        target_params, _ = count_parameters(target_net)
        total_params = policy_params + target_params

        total_size = get_model_size_mb(policy_net) + get_model_size_mb(target_net)

        layer_info, _ = analyze_layer_parameters(policy_net, f"{name}_policy")

        result = {
            'model_name': name,
            'state_shape': state_shape,
            'policy_parameters': policy_params,
            'target_parameters': target_params,
            'total_parameters': total_params,
            'size_mb': total_size,
            'layer_breakdown': layer_info,
        }
        print(f"✅ {name}: {total_params:,} parameters ({total_size:.2f} MB)")
        print(f" Policy: {policy_params:,}, Target: {target_params:,}")
        return [result]
    except Exception as e:  # best-effort audit: report and continue
        print(f"❌ Failed to analyze {name}: {e}")
        return []
|
|
|
|
def _count_tensor_elements(obj):
    """Return the total number of tensor elements held by *obj*.

    - ``nn.Module``: counted through its ``state_dict()``, so registered
      buffers are included, the same as for a saved state dict.
    - ``dict`` (e.g. a state dict): sum of ``numel()`` over its tensor
      values; non-tensor values are ignored.
    - Anything else, including ``None``: 0.
    """
    if isinstance(obj, nn.Module):
        obj = obj.state_dict()
    if isinstance(obj, dict):
        return sum(v.numel() for v in obj.values() if isinstance(v, torch.Tensor))
    return 0


def _estimate_checkpoint_params(checkpoint):
    """Estimate how many tensor elements a loaded checkpoint stores.

    Recognized layouts:
    - a wrapper dict with a ``state_dict`` or ``model_state_dict`` entry;
    - a DQN-style dict with ``policy_net`` and, optionally, ``target_net``;
    - a bare state dict;
    - a whole pickled ``nn.Module``.

    Returns ``None`` when the checkpoint is none of these.
    """
    if isinstance(checkpoint, nn.Module):
        return _count_tensor_elements(checkpoint)
    if not isinstance(checkpoint, dict):
        return None
    if 'state_dict' in checkpoint:
        return _count_tensor_elements(checkpoint['state_dict'])
    if 'model_state_dict' in checkpoint:
        return _count_tensor_elements(checkpoint['model_state_dict'])
    if 'policy_net' in checkpoint:
        # DQN agent format: policy and target networks saved side by side.
        # A missing or None target network counts as 0 instead of crashing.
        return (_count_tensor_elements(checkpoint['policy_net'])
                + _count_tensor_elements(checkpoint.get('target_net')))
    # Direct state dict.
    return _count_tensor_elements(checkpoint)


def audit_saved_models(model_dirs=None):
    """Inventory ``.pt`` checkpoint files and estimate their parameter counts.

    Args:
        model_dirs: Directories to scan (non-recursively). Defaults to
            ``['models/', 'NN/models/saved/']``, relative to the current
            working directory. Paths that are not directories are skipped.

    Returns:
        A list with one dict per ``.pt`` file, ordered by directory and then
        filename. Each dict has ``filename``, ``path``, ``size_mb`` and
        ``estimated_parameters`` (an int, ``'Unknown'`` or
        ``'Error loading'``). It also carries ``checkpoint_keys`` when the
        file loaded, or ``error`` when it did not.
    """
    print("\n🔍 Auditing Saved Model Files...")

    if model_dirs is None:
        model_dirs = ['models/', 'NN/models/saved/']
    saved_models = []

    for model_dir in model_dirs:
        # isdir (not exists): a plain file here would make listdir raise.
        if not os.path.isdir(model_dir):
            continue
        # Sorted so the report is stable across runs and filesystems.
        for file in sorted(os.listdir(model_dir)):
            if not file.endswith('.pt'):
                continue
            file_path = os.path.join(model_dir, file)

            try:
                file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
            except OSError as e:
                print(f"❌ Error processing {file}: {e}")
                continue

            try:
                # NOTE(security): torch.load deserializes via pickle (restricted
                # only when weights_only is in effect); audit trusted files only.
                checkpoint = torch.load(file_path, map_location='cpu')
                total_params = _estimate_checkpoint_params(checkpoint)
            except Exception as e:  # best-effort audit: record and move on
                saved_models.append({
                    'filename': file,
                    'path': file_path,
                    'size_mb': file_size,
                    'estimated_parameters': 'Error loading',
                    'error': str(e),
                })
                print(f"📁 {file}: {file_size:.1f} MB, Error: {e}")
                continue

            if total_params is None:
                saved_models.append({
                    'filename': file,
                    'path': file_path,
                    'size_mb': file_size,
                    'estimated_parameters': 'Unknown',
                    'checkpoint_keys': 'N/A',
                })
                print(f"📁 {file}: {file_size:.1f} MB, Unknown parameters")
            else:
                saved_models.append({
                    'filename': file,
                    'path': file_path,
                    'size_mb': file_size,
                    'estimated_parameters': total_params,
                    'checkpoint_keys': list(checkpoint.keys()) if isinstance(checkpoint, dict) else 'N/A',
                })
                print(f"📁 {file}: {file_size:.1f} MB, ~{total_params:,} parameters")

    return saved_models
|
|
|
|
def generate_report(enhanced_cnn_results, dqn_results, saved_models):
    """Assemble all audit results into one report dictionary.

    Args:
        enhanced_cnn_results: Result dicts from the EnhancedCNN audit.
        dqn_results: Result dicts from the DQN agent audit.
        saved_models: Per-file dicts from the saved-checkpoint audit.

    Returns:
        A dict with run metadata (UTC ISO-8601 ``timestamp``, PyTorch version
        and CUDA info), the raw results, and a ``summary``. The ``summary``
        stays empty when neither architecture audit produced a result.
        Otherwise it holds totals, the largest and smallest
        ``total_parameters``, and saved-model count and size.
    """
    # torch exposes no datetime API; use the standard library clock.
    from datetime import datetime, timezone

    cuda_available = torch.cuda.is_available()
    report = {
        'timestamp': datetime.now(timezone.utc).isoformat(),
        'pytorch_version': torch.__version__,
        'cuda_available': cuda_available,
        'device_info': {
            'cuda_device_count': torch.cuda.device_count() if cuda_available else 0,
            'current_device': str(torch.cuda.current_device()) if cuda_available else 'CPU',
        },
        'model_architectures': {
            'enhanced_cnn': enhanced_cnn_results,
            'dqn_agent': dqn_results,
        },
        'saved_models': saved_models,
        'summary': {},
    }

    # Calculate summary statistics
    all_results = enhanced_cnn_results + dqn_results

    if all_results:
        param_counts = [r.get('total_parameters', 0) for r in all_results]
        report['summary'] = {
            'total_model_architectures': len(all_results),
            'total_parameters_across_all': sum(param_counts),
            'total_size_mb': sum(r.get('size_mb', 0) for r in all_results),
            'largest_model_parameters': max(param_counts),
            'smallest_model_parameters': min(param_counts),
            'saved_models_count': len(saved_models),
            'saved_models_total_size_mb': sum(m.get('size_mb', 0) for m in saved_models),
        }

    return report
|
|
|
|
def main():
    """Run every audit step, save the JSON report and print a summary.

    Writes ``model_parameter_audit_report.json`` in the current working
    directory and returns the report dictionary built by ``generate_report``.
    """
    banner = "=" * 50
    report_path = 'model_parameter_audit_report.json'

    print("🔍 STREAMLINED MODEL PARAMETER AUDIT")
    print(banner)

    print("\n📊 Analyzing Enhanced CNN Model (Primary Architecture)...")
    enhanced_cnn_results = audit_enhanced_cnn()

    print("\n🤖 Analyzing DQN Agent with Enhanced CNN...")
    dqn_results = audit_dqn_agent()

    print("\n💾 Auditing Saved Models...")
    saved_models = audit_saved_models()

    print("\n📋 Generating Report...")
    report = generate_report(enhanced_cnn_results, dqn_results, saved_models)

    # default=str turns anything json cannot encode natively into its str().
    with open(report_path, 'w') as fh:
        json.dump(report, fh, indent=2, default=str)

    print("\n📊 STREAMLINED AUDIT SUMMARY")
    print(banner)
    summary = report['summary']
    if summary:
        summary_lines = (
            ("Streamlined Model Architectures: {}", 'total_model_architectures'),
            ("Total Parameters: {:,}", 'total_parameters_across_all'),
            ("Total Memory Usage: {:.1f} MB", 'total_size_mb'),
            ("Largest Model: {:,} parameters", 'largest_model_parameters'),
            ("Smallest Model: {:,} parameters", 'smallest_model_parameters'),
            ("Saved Models: {} files", 'saved_models_count'),
            ("Saved Models Total Size: {:.1f} MB", 'saved_models_total_size_mb'),
        )
        for template, key in summary_lines:
            print(template.format(summary[key]))

    print(f"\n📄 Detailed report saved to: {report_path}")
    print("\n🎯 STREAMLINING COMPLETE:")
    for status_line in (
        " ✅ Enhanced CNN: Primary high-performance model",
        " ✅ DQN Agent: Now uses Enhanced CNN for better performance",
        " ❌ Simple models: Removed for streamlined architecture",
    ):
        print(status_line)

    return report
|
|
|
|
if __name__ == "__main__":
    # Run the audit only when executed as a script, not when imported.
    main()