gogo2/model_parameter_audit.py
2025-05-24 23:22:34 +03:00

301 lines
12 KiB
Python

#!/usr/bin/env python3
"""
Model Parameter Audit Script
Analyzes and calculates the total parameters for all model architectures in the trading system.
"""
import torch
import torch.nn as nn
import sys
import os
import json
from pathlib import Path
from collections import defaultdict
import numpy as np
# Add paths to import local modules
sys.path.append('.')
sys.path.append('./NN/models')
sys.path.append('./NN')
def count_parameters(model):
    """Return ``(total, trainable)`` parameter element counts for *model*.

    ``total`` adds up ``numel()`` over every tensor yielded by
    ``model.parameters()``; ``trainable`` adds up only the tensors whose
    ``requires_grad`` flag is set.
    """
    total_params = 0
    trainable_params = 0
    # One walk over the parameters, feeding both tallies together.
    for tensor in model.parameters():
        n_elements = tensor.numel()
        total_params += n_elements
        if tensor.requires_grad:
            trainable_params += n_elements
    return total_params, trainable_params
def get_model_size_mb(model):
    """Return the combined size of *model*'s parameters and buffers in MB.

    Each tensor contributes its element count multiplied by its
    per-element byte width; the byte total is divided by 1024 twice.
    """
    def _tensor_bytes(tensors):
        # Bytes occupied by an iterable of tensors.
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _tensor_bytes(model.parameters()) + _tensor_bytes(model.buffers())
    return total_bytes / 1024 / 1024
def analyze_layer_parameters(model, model_name):
    """Break down the parameter count of *model* module by module.

    Every module yielded by ``model.named_modules()`` is credited only with
    the parameters it owns directly (``parameters(recurse=False)``).  For a
    module without children this equals its full parameter list; for a
    module that has children AND parameters of its own, those own
    parameters are now reported instead of being skipped.

    Args:
        model: The ``torch.nn.Module`` to inspect.
        model_name: Label for the model.  Kept for interface compatibility;
            it is not used in the computation.

    Returns:
        A ``(layer_info, total_params)`` tuple.  ``layer_info`` is a list of
        dicts with the keys ``layer_name``, ``layer_type``, ``parameters``
        and ``trainable``, one per module that directly owns at least one
        parameter element.  ``total_params`` is the sum of the
        ``parameters`` entries.
    """
    layer_info = []
    total_params = 0
    for name, module in model.named_modules():
        # recurse=False keeps each tensor attributed to exactly one module,
        # so containers that own parameters are counted without also
        # re-counting what their children hold.
        own_params = list(module.parameters(recurse=False))
        params = sum(p.numel() for p in own_params)
        if params == 0:
            continue
        layer_info.append({
            'layer_name': name,
            'layer_type': type(module).__name__,
            'parameters': params,
            'trainable': sum(p.numel() for p in own_params if p.requires_grad),
        })
        total_params += params
    return layer_info, total_params
def audit_enhanced_cnn():
    """Instantiate ``EnhancedCNN`` with a fixed configuration and measure it.

    Returns:
        A one-element list holding a dict with the keys ``model_name``,
        ``input_shape``, ``total_parameters``, ``trainable_parameters``,
        ``size_mb`` and ``layer_breakdown``.  An empty list is returned
        when ``enhanced_cnn`` cannot be imported or when building or
        analysing the model raises; a message is printed in both cases.
    """
    try:
        from enhanced_cnn import EnhancedCNN
    except ImportError as e:
        print(f"❌ Cannot import EnhancedCNN: {e}")
        return []

    # Fixed configuration used for the audit.
    model_name = 'EnhancedCNN_Optimized'
    input_shape = (5, 100)
    n_actions = 3

    try:
        model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
        total_params, trainable_params = count_parameters(model)
        size_mb = get_model_size_mb(model)
        layer_info, _ = analyze_layer_parameters(model, model_name)
        result = {
            'model_name': model_name,
            'input_shape': input_shape,
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'size_mb': size_mb,
            'layer_breakdown': layer_info,
        }
        print(f"{model_name}: {total_params:,} parameters ({size_mb:.2f} MB)")
        return [result]
    except Exception as e:
        # Best-effort audit: report the failure and carry on.
        print(f"❌ Failed to analyze {model_name}: {e}")
        return []
def audit_dqn_agent():
    """Instantiate ``DQNAgent`` with a fixed configuration and measure it.

    Both ``agent.policy_net`` and ``agent.target_net`` are counted and
    sized; the per-layer breakdown covers the policy network only.

    Returns:
        A one-element list holding a dict with the keys ``model_name``,
        ``state_shape``, ``policy_parameters``, ``target_parameters``,
        ``total_parameters``, ``size_mb`` and ``layer_breakdown``.  An
        empty list is returned when ``dqn_agent`` cannot be imported or
        when building or analysing the agent raises; a message is printed
        in both cases.
    """
    try:
        from dqn_agent import DQNAgent
    except ImportError as e:
        print(f"❌ Cannot import DQNAgent: {e}")
        return []

    # Fixed configuration used for the audit.
    agent_name = 'DQNAgent_EnhancedCNN'
    state_shape = (5, 100)
    n_actions = 3

    try:
        agent = DQNAgent(state_shape=state_shape, n_actions=n_actions)

        # Trainable counts are returned by the helper but not reported here.
        policy_params, _ = count_parameters(agent.policy_net)
        target_params, _ = count_parameters(agent.target_net)
        total_params = policy_params + target_params

        total_size = get_model_size_mb(agent.policy_net) + get_model_size_mb(agent.target_net)

        layer_info, _ = analyze_layer_parameters(agent.policy_net, f"{agent_name}_policy")
        result = {
            'model_name': agent_name,
            'state_shape': state_shape,
            'policy_parameters': policy_params,
            'target_parameters': target_params,
            'total_parameters': total_params,
            'size_mb': total_size,
            'layer_breakdown': layer_info,
        }
        print(f"{agent_name}: {total_params:,} parameters ({total_size:.2f} MB)")
        print(f" Policy: {policy_params:,}, Target: {target_params:,}")
        return [result]
    except Exception as e:
        # Best-effort audit: report the failure and carry on.
        print(f"❌ Failed to analyze {agent_name}: {e}")
        return []
def _count_tensor_elements(mapping):
    """Sum ``numel()`` over the ``torch.Tensor`` values of *mapping*."""
    return sum(v.numel() for v in mapping.values() if isinstance(v, torch.Tensor))


def _estimate_checkpoint_parameters(checkpoint):
    """Return the tensor element count held by *checkpoint*, or ``None``.

    Recognised layouts, checked in this order:
      * a pickled ``torch.nn.Module`` -> its ``state_dict()`` is counted;
      * a dict with a ``'state_dict'`` or ``'model_state_dict'`` entry;
      * a dict with ``'policy_net'`` (and optionally ``'target_net'``);
      * any other dict, treated as a state dict itself.

    ``None`` means the object is neither a dict nor a module, so nothing
    can be estimated.
    """
    if isinstance(checkpoint, torch.nn.Module):
        return _count_tensor_elements(checkpoint.state_dict())
    if not isinstance(checkpoint, dict):
        return None

    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif 'policy_net' in checkpoint:
        # DQN agent format: policy and (optional) target networks side by side.
        total = _count_tensor_elements(checkpoint['policy_net'])
        if 'target_net' in checkpoint:
            total += _count_tensor_elements(checkpoint['target_net'])
        return total
    else:
        # Direct state dict.
        state_dict = checkpoint

    # Type check only: truth-testing a multi-element tensor would raise.
    if isinstance(state_dict, dict):
        return _count_tensor_elements(state_dict)
    return 0


def _inspect_saved_model(file_name, file_path):
    """Build the report entry for one ``.pt`` file and print a status line.

    Raises whatever ``os.path.getsize`` raises; failures while loading or
    inspecting the checkpoint are turned into an ``'Error loading'`` entry.
    """
    file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
    try:
        # SECURITY NOTE(review): torch.load unpickles the file.  Whether
        # arbitrary objects may be constructed depends on the installed
        # PyTorch's default for `weights_only`; only run this audit on
        # checkpoint files from a trusted source.
        checkpoint = torch.load(file_path, map_location='cpu')
        total_params = _estimate_checkpoint_parameters(checkpoint)
    except Exception as e:
        print(f"📁 {file_name}: {file_size:.1f} MB, Error: {e}")
        return {
            'filename': file_name,
            'path': file_path,
            'size_mb': file_size,
            'estimated_parameters': 'Error loading',
            'error': str(e),
        }

    if total_params is None:
        print(f"📁 {file_name}: {file_size:.1f} MB, Unknown parameters")
        return {
            'filename': file_name,
            'path': file_path,
            'size_mb': file_size,
            'estimated_parameters': 'Unknown',
            'checkpoint_keys': 'N/A',
        }

    print(f"📁 {file_name}: {file_size:.1f} MB, ~{total_params:,} parameters")
    return {
        'filename': file_name,
        'path': file_path,
        'size_mb': file_size,
        'estimated_parameters': total_params,
        'checkpoint_keys': list(checkpoint.keys()) if isinstance(checkpoint, dict) else 'N/A',
    }


def audit_saved_models():
    """Inspect every ``.pt`` file under the known model directories.

    The directories ``models/`` and ``NN/models/saved/`` are resolved
    relative to the current working directory; missing ones are skipped.
    Files are visited in sorted name order so repeated runs produce the
    same report ordering.

    Returns:
        A list with one dict per ``.pt`` file.  Each dict has ``filename``,
        ``path``, ``size_mb`` and ``estimated_parameters`` (an int, or the
        string ``'Unknown'`` / ``'Error loading'``), plus either
        ``checkpoint_keys`` or ``error``.  A file whose size cannot be read
        is reported on stdout and left out of the list.
    """
    print("\n🔍 Auditing Saved Model Files...")
    model_dirs = ['models/', 'NN/models/saved/']
    saved_models = []
    for model_dir in model_dirs:
        # isdir (not exists): a regular file at this path cannot be listed.
        if not os.path.isdir(model_dir):
            continue
        for file_name in sorted(os.listdir(model_dir)):
            if not file_name.endswith('.pt'):
                continue
            file_path = os.path.join(model_dir, file_name)
            try:
                saved_models.append(_inspect_saved_model(file_name, file_path))
            except Exception as e:
                # Best-effort scan: one unreadable file must not stop the audit.
                print(f"❌ Error processing {file_name}: {e}")
    return saved_models
def generate_report(enhanced_cnn_results, dqn_results, saved_models):
    """Assemble the audit report dictionary.

    Args:
        enhanced_cnn_results: List of result dicts for the Enhanced CNN audit.
        dqn_results: List of result dicts for the DQN agent audit.
        saved_models: List of per-file dicts for saved checkpoints.

    Returns:
        A dict with the keys ``timestamp`` (ISO-8601 local time with UTC
        offset), ``pytorch_version``, ``cuda_available``, ``device_info``,
        ``model_architectures``, ``saved_models`` and ``summary``.
        ``summary`` stays an empty dict when both architecture result lists
        are empty; otherwise it holds aggregate parameter and size figures.
    """
    # Function-scope import: the standard datetime module is needed here
    # (torch has no `datetime` attribute, so the previous lookup always
    # produced 'N/A').
    from datetime import datetime

    cuda_ok = torch.cuda.is_available()
    report = {
        'timestamp': datetime.now().astimezone().isoformat(),
        'pytorch_version': torch.__version__,
        'cuda_available': cuda_ok,
        'device_info': {
            'cuda_device_count': torch.cuda.device_count() if cuda_ok else 0,
            'current_device': str(torch.cuda.current_device()) if cuda_ok else 'CPU',
        },
        'model_architectures': {
            'enhanced_cnn': enhanced_cnn_results,
            'dqn_agent': dqn_results,
        },
        'saved_models': saved_models,
        'summary': {},
    }

    # Summary statistics over all architecture results.
    all_results = enhanced_cnn_results + dqn_results
    if all_results:
        param_counts = [r.get('total_parameters', 0) for r in all_results]
        report['summary'] = {
            'total_model_architectures': len(all_results),
            'total_parameters_across_all': sum(param_counts),
            'total_size_mb': sum(r.get('size_mb', 0) for r in all_results),
            'largest_model_parameters': max(param_counts),
            'smallest_model_parameters': min(param_counts),
            'saved_models_count': len(saved_models),
            'saved_models_total_size_mb': sum(m.get('size_mb', 0) for m in saved_models),
        }
    return report
def main():
    """Run each audit stage, write the JSON report, and print a summary.

    The report is written to ``model_parameter_audit_report.json`` in the
    current working directory.

    Returns:
        The report dictionary produced by ``generate_report``.
    """
    report_path = 'model_parameter_audit_report.json'
    rule = "=" * 50

    print("🔍 STREAMLINED MODEL PARAMETER AUDIT")
    print(rule)

    print("\n📊 Analyzing Enhanced CNN Model (Primary Architecture)...")
    enhanced_cnn_results = audit_enhanced_cnn()

    print("\n🤖 Analyzing DQN Agent with Enhanced CNN...")
    dqn_results = audit_dqn_agent()

    print("\n💾 Auditing Saved Models...")
    saved_models = audit_saved_models()

    print("\n📋 Generating Report...")
    report = generate_report(enhanced_cnn_results, dqn_results, saved_models)

    # Persist the detailed report; default=str stringifies any value that
    # json cannot encode natively.
    with open(report_path, 'w') as f:
        json.dump(report, f, indent=2, default=str)

    print("\n📊 STREAMLINED AUDIT SUMMARY")
    print(rule)

    summary = report['summary']
    if summary:
        # (template, summary key) pairs, printed in order.
        summary_rows = (
            ("Streamlined Model Architectures: {}", 'total_model_architectures'),
            ("Total Parameters: {:,}", 'total_parameters_across_all'),
            ("Total Memory Usage: {:.1f} MB", 'total_size_mb'),
            ("Largest Model: {:,} parameters", 'largest_model_parameters'),
            ("Smallest Model: {:,} parameters", 'smallest_model_parameters'),
            ("Saved Models: {} files", 'saved_models_count'),
            ("Saved Models Total Size: {:.1f} MB", 'saved_models_total_size_mb'),
        )
        for template, key in summary_rows:
            print(template.format(summary[key]))

    print(f"\n📄 Detailed report saved to: {report_path}")

    print("\n🎯 STREAMLINING COMPLETE:")
    print(" ✅ Enhanced CNN: Primary high-performance model")
    print(" ✅ DQN Agent: Now uses Enhanced CNN for better performance")
    print(" ❌ Simple models: Removed for streamlined architecture")
    return report


if __name__ == "__main__":
    main()