cleanup models; beef up models to 500M
This commit is contained in:
301
model_parameter_audit.py
Normal file
301
model_parameter_audit.py
Normal file
@ -0,0 +1,301 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Model Parameter Audit Script
|
||||
Analyzes and calculates the total parameters for all model architectures in the trading system.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
|
||||
# Extend the module search path so the project's local model modules can be
# imported below. The entries are relative, so they resolve against the
# process's current working directory, not this file's location.
sys.path.extend(['.', './NN/models', './NN'])
|
||||
|
||||
def count_parameters(model):
    """Return the parameter element counts of a PyTorch module.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        A ``(total, trainable)`` tuple: the number of elements across all
        parameters, and the number of elements in parameters whose
        ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for tensor in model.parameters():
        n_elements = tensor.numel()
        total += n_elements
        if tensor.requires_grad:
            trainable += n_elements
    return total, trainable
|
||||
|
||||
def get_model_size_mb(model):
    """Return the in-memory size of a module's tensors in mebibytes.

    The size is the byte count of every parameter plus every registered
    buffer (element count times per-element byte width), divided by 1024
    twice.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()``.

    Returns:
        The combined parameter and buffer size in MB.
    """
    bytes_in_params = sum(
        p.nelement() * p.element_size() for p in model.parameters()
    )
    bytes_in_buffers = sum(
        b.nelement() * b.element_size() for b in model.buffers()
    )
    return (bytes_in_params + bytes_in_buffers) / 1024 / 1024
|
||||
|
||||
def analyze_layer_parameters(model, model_name):
    """Break a module's parameters down by the sub-module that owns them.

    Every module returned by ``model.named_modules()`` is inspected for the
    parameters registered directly on it (``recurse=False``). Attributing
    parameters to their owning module, rather than looking only at leaf
    modules, keeps parameters that live on a container module (for example
    the input projection of ``torch.nn.MultiheadAttention``) in the
    breakdown and in the returned total.

    Args:
        model: The ``torch.nn.Module`` to inspect.
        model_name: Unused; kept so existing call sites keep working.

    Returns:
        A ``(layer_info, total_params)`` tuple. ``layer_info`` is a list of
        dicts with keys ``layer_name``, ``layer_type``, ``parameters`` and
        ``trainable``, one per module that owns at least one parameter
        element. ``total_params`` is the sum of the ``parameters`` fields.
    """
    layer_info = []
    total_params = 0

    for name, module in model.named_modules():
        # Only parameters owned by this module itself; children are visited
        # separately by named_modules(), so nothing is counted twice.
        own_params = list(module.parameters(recurse=False))
        params = sum(p.numel() for p in own_params)
        if params == 0:
            continue

        layer_info.append({
            'layer_name': name,
            'layer_type': type(module).__name__,
            'parameters': params,
            'trainable': sum(p.numel() for p in own_params if p.requires_grad),
        })
        total_params += params

    return layer_info, total_params
|
||||
|
||||
def audit_enhanced_cnn():
    """Instantiate ``EnhancedCNN`` and report its parameter statistics.

    The model class is imported from the ``enhanced_cnn`` module and built
    with ``input_shape=(5, 100)`` and ``n_actions=3``.

    Returns:
        A one-element list holding a dict with the model name, input shape,
        total and trainable parameter counts, size in MB and the per-layer
        breakdown. An empty list is returned (after printing the reason) if
        the import fails or anything raises while analysing the model.
    """
    try:
        from enhanced_cnn import EnhancedCNN
    except ImportError as import_error:
        print(f"❌ Cannot import EnhancedCNN: {import_error}")
        return []

    # Configuration under audit.
    model_name = 'EnhancedCNN_Optimized'
    input_shape = (5, 100)
    n_actions = 3

    try:
        network = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)

        total, trainable = count_parameters(network)
        megabytes = get_model_size_mb(network)
        breakdown, _ = analyze_layer_parameters(network, model_name)

        print(f"✅ {model_name}: {total:,} parameters ({megabytes:.2f} MB)")
    except Exception as analysis_error:
        # Best-effort audit: a model that cannot be built is reported and skipped.
        print(f"❌ Failed to analyze {model_name}: {analysis_error}")
        return []

    return [{
        'model_name': model_name,
        'input_shape': input_shape,
        'total_parameters': total,
        'trainable_parameters': trainable,
        'size_mb': megabytes,
        'layer_breakdown': breakdown,
    }]
|
||||
|
||||
def audit_dqn_agent():
    """Instantiate ``DQNAgent`` and report statistics for both of its networks.

    The agent class is imported from the ``dqn_agent`` module and built with
    ``state_shape=(5, 100)`` and ``n_actions=3``. Parameter counts and sizes
    are computed for ``agent.policy_net`` and ``agent.target_net`` and added
    together; the per-layer breakdown covers the policy network only.

    Returns:
        A one-element list holding a dict with the model name, state shape,
        per-network and combined parameter counts, combined trainable
        parameter count, combined size in MB and the policy-network layer
        breakdown. An empty list is returned (after printing the reason) if
        the import fails or anything raises while analysing the agent.
    """
    try:
        from dqn_agent import DQNAgent
    except ImportError as e:
        print(f"❌ Cannot import DQNAgent: {e}")
        return []

    # Configuration under audit.
    config = {'state_shape': (5, 100), 'n_actions': 3, 'name': 'DQNAgent_EnhancedCNN'}

    try:
        agent = DQNAgent(
            state_shape=config['state_shape'],
            n_actions=config['n_actions'],
        )

        # Analyze both policy and target networks.
        policy_params, policy_trainable = count_parameters(agent.policy_net)
        target_params, target_trainable = count_parameters(agent.target_net)
        total_params = policy_params + target_params

        policy_size = get_model_size_mb(agent.policy_net)
        target_size = get_model_size_mb(agent.target_net)
        total_size = policy_size + target_size

        layer_info, _ = analyze_layer_parameters(agent.policy_net, f"{config['name']}_policy")

        result = {
            'model_name': config['name'],
            'state_shape': config['state_shape'],
            'policy_parameters': policy_params,
            'target_parameters': target_params,
            'total_parameters': total_params,
            # Reported so this entry carries the same trainable-count field
            # as the EnhancedCNN audit result.
            'trainable_parameters': policy_trainable + target_trainable,
            'size_mb': total_size,
            'layer_breakdown': layer_info,
        }

        print(f"✅ {config['name']}: {total_params:,} parameters ({total_size:.2f} MB)")
        print(f" Policy: {policy_params:,}, Target: {target_params:,}")
        return [result]

    except Exception as e:
        # Best-effort audit: an agent that cannot be built is reported and skipped.
        print(f"❌ Failed to analyze {config['name']}: {e}")
        return []
|
||||
|
||||
def _count_tensor_elements(state_dict):
    """Sum ``numel()`` over the tensor values of a mapping, ignoring other values."""
    return sum(
        value.numel()
        for value in state_dict.values()
        if isinstance(value, torch.Tensor)
    )


def _estimate_checkpoint_parameters(checkpoint):
    """Estimate the parameter count stored in a dict-style checkpoint.

    Recognised layouts, checked in this order: a ``state_dict`` entry, a
    ``model_state_dict`` entry, a ``policy_net`` entry (optionally with
    ``target_net``; both are summed), otherwise the dict itself is treated
    as the state dict. Returns 0 when the selected entry is not a dict.
    """
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif 'policy_net' in checkpoint:
        # DQN agent format
        total = _count_tensor_elements(checkpoint['policy_net'])
        if 'target_net' in checkpoint:
            total += _count_tensor_elements(checkpoint['target_net'])
        return total
    else:
        # Direct state dict
        state_dict = checkpoint

    # Type check only: truth-testing a multi-element tensor would raise.
    if isinstance(state_dict, dict):
        return _count_tensor_elements(state_dict)
    return 0


def _describe_checkpoint_file(filename, file_path, file_size):
    """Load one checkpoint file, print its summary line and return its report entry."""
    entry = {'filename': filename, 'path': file_path, 'size_mb': file_size}
    try:
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code depending on the PyTorch version's ``weights_only`` default.
        # Only point this audit at checkpoint files from trusted sources.
        checkpoint = torch.load(file_path, map_location='cpu')

        if isinstance(checkpoint, dict):
            total_params = _estimate_checkpoint_parameters(checkpoint)
            entry['estimated_parameters'] = total_params
            entry['checkpoint_keys'] = list(checkpoint.keys())
            print(f"📁 {filename}: {file_size:.1f} MB, ~{total_params:,} parameters")
        else:
            entry['estimated_parameters'] = 'Unknown'
            entry['checkpoint_keys'] = 'N/A'
            print(f"📁 {filename}: {file_size:.1f} MB, Unknown parameters")
        return entry

    except Exception as e:
        # Best-effort: an unreadable checkpoint is recorded, not fatal.
        print(f"📁 {filename}: {file_size:.1f} MB, Error: {e}")
        return {
            'filename': filename,
            'path': file_path,
            'size_mb': file_size,
            'estimated_parameters': 'Error loading',
            'error': str(e),
        }


def audit_saved_models():
    """Inspect the ``.pt`` checkpoint files in the known model directories.

    Scans ``models/`` and ``NN/models/saved/`` (relative to the current
    working directory), skipping directories that do not exist. Files are
    visited in sorted name order so repeated runs produce the same report.

    Returns:
        A list with one dict per ``.pt`` file whose size could be read. Each
        dict has ``filename``, ``path``, ``size_mb`` and
        ``estimated_parameters``; the latter is an int for dict checkpoints,
        ``'Unknown'`` for other objects and ``'Error loading'`` (with an
        ``error`` message) when loading failed. Dict and non-dict
        checkpoints also carry ``checkpoint_keys``.
    """
    print("\n🔍 Auditing Saved Model Files...")

    model_dirs = ['models/', 'NN/models/saved/']
    saved_models = []

    for model_dir in model_dirs:
        # isdir (not exists) so a stray file with this name is skipped
        # instead of making os.listdir raise.
        if not os.path.isdir(model_dir):
            continue

        for filename in sorted(os.listdir(model_dir)):
            if not filename.endswith('.pt'):
                continue

            file_path = os.path.join(model_dir, filename)
            try:
                file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
            except OSError as e:
                print(f"❌ Error processing {filename}: {e}")
                continue

            saved_models.append(
                _describe_checkpoint_file(filename, file_path, file_size)
            )

    return saved_models
|
||||
|
||||
def generate_report(enhanced_cnn_results, dqn_results, saved_models):
    """Assemble the audit results into a single report dict.

    Args:
        enhanced_cnn_results: List of result dicts from the EnhancedCNN audit.
        dqn_results: List of result dicts from the DQN agent audit.
        saved_models: List of per-file dicts from the saved-model audit.

    Returns:
        A dict with the generation timestamp (ISO 8601, local time), the
        PyTorch version, CUDA availability and device info, the per-model
        results, the saved-model entries and a ``summary`` dict. ``summary``
        is left empty when both result lists are empty; otherwise it holds
        aggregate parameter counts and sizes. Missing ``total_parameters`` /
        ``size_mb`` fields count as 0.
    """
    # Imported locally so this function is self-contained. The timestamp was
    # previously read from a non-existent ``torch.datetime`` attribute and
    # therefore always came out as 'N/A'.
    from datetime import datetime

    cuda_available = torch.cuda.is_available()

    report = {
        'timestamp': datetime.now().isoformat(),
        'pytorch_version': torch.__version__,
        'cuda_available': cuda_available,
        'device_info': {
            'cuda_device_count': torch.cuda.device_count() if cuda_available else 0,
            'current_device': str(torch.cuda.current_device()) if cuda_available else 'CPU',
        },
        'model_architectures': {
            'enhanced_cnn': enhanced_cnn_results,
            'dqn_agent': dqn_results,
        },
        'saved_models': saved_models,
        'summary': {},
    }

    # Calculate summary statistics
    all_results = enhanced_cnn_results + dqn_results

    if all_results:
        param_counts = [r.get('total_parameters', 0) for r in all_results]

        report['summary'] = {
            'total_model_architectures': len(all_results),
            'total_parameters_across_all': sum(param_counts),
            'total_size_mb': sum(r.get('size_mb', 0) for r in all_results),
            'largest_model_parameters': max(param_counts),
            'smallest_model_parameters': min(param_counts),
            'saved_models_count': len(saved_models),
            'saved_models_total_size_mb': sum(m.get('size_mb', 0) for m in saved_models),
        }

    return report
|
||||
|
||||
def main():
    """Run every audit step, write the JSON report and print a summary.

    The steps run in order: EnhancedCNN audit, DQN agent audit, saved-model
    audit, report generation. The report is written to
    ``model_parameter_audit_report.json`` in the current working directory
    (values that are not JSON-serializable are stringified).

    Returns:
        The report dict produced by ``generate_report``.
    """
    report_path = 'model_parameter_audit_report.json'
    divider = "=" * 50

    print("🔍 STREAMLINED MODEL PARAMETER AUDIT")
    print(divider)

    print("\n📊 Analyzing Enhanced CNN Model (Primary Architecture)...")
    cnn_audit = audit_enhanced_cnn()

    print("\n🤖 Analyzing DQN Agent with Enhanced CNN...")
    dqn_audit = audit_dqn_agent()

    print("\n💾 Auditing Saved Models...")
    checkpoint_audit = audit_saved_models()

    print("\n📋 Generating Report...")
    report = generate_report(cnn_audit, dqn_audit, checkpoint_audit)

    # Persist the detailed report.
    with open(report_path, 'w') as report_file:
        json.dump(report, report_file, indent=2, default=str)

    # Console summary; skipped when no architecture could be analysed
    # (the summary dict is empty in that case).
    print("\n📊 STREAMLINED AUDIT SUMMARY")
    print(divider)
    summary = report['summary']
    if summary:
        summary_rows = (
            ("Streamlined Model Architectures: {}", 'total_model_architectures'),
            ("Total Parameters: {:,}", 'total_parameters_across_all'),
            ("Total Memory Usage: {:.1f} MB", 'total_size_mb'),
            ("Largest Model: {:,} parameters", 'largest_model_parameters'),
            ("Smallest Model: {:,} parameters", 'smallest_model_parameters'),
            ("Saved Models: {} files", 'saved_models_count'),
            ("Saved Models Total Size: {:.1f} MB", 'saved_models_total_size_mb'),
        )
        for template, key in summary_rows:
            print(template.format(summary[key]))

    print(f"\n📄 Detailed report saved to: {report_path}")
    closing_lines = (
        "\n🎯 STREAMLINING COMPLETE:",
        " ✅ Enhanced CNN: Primary high-performance model",
        " ✅ DQN Agent: Now uses Enhanced CNN for better performance",
        " ❌ Simple models: Removed for streamlined architecture",
    )
    for line in closing_lines:
        print(line)

    return report
|
||||
|
||||
# Run the audit only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
Reference in New Issue
Block a user