order flow WIP, chart broken

This commit is contained in:
Dobromir Popov
2025-06-18 13:51:08 +03:00
parent 5bce17a21a
commit e238ce374b
16 changed files with 1768 additions and 1333 deletions

725
NN/models/cnn_model.py Normal file
View File

@ -0,0 +1,725 @@
#!/usr/bin/env python3
"""
Enhanced CNN Model for Trading - PyTorch Implementation
Much larger and more sophisticated architecture for better learning
"""
import os
import logging
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import torch.nn.functional as F
from typing import Dict, Any, Optional, Tuple
# Configure logging
logger = logging.getLogger(__name__)
class MultiHeadAttention(nn.Module):
"""Multi-head attention mechanism for sequence data"""
def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.scale = math.sqrt(self.d_k)
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = x.size()
# Compute Q, K, V
Q = self.w_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = self.w_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = self.w_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# Attention weights
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# Apply attention
attention_output = torch.matmul(attention_weights, V)
attention_output = attention_output.transpose(1, 2).contiguous().view(
batch_size, seq_len, self.d_model
)
return self.w_o(attention_output)
class ResidualBlock(nn.Module):
"""Residual block with normalization and dropout"""
def __init__(self, channels: int, dropout: float = 0.1):
super().__init__()
self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
self.norm1 = nn.BatchNorm1d(channels)
self.norm2 = nn.BatchNorm1d(channels)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
out = F.relu(self.norm1(self.conv1(x)))
out = self.dropout(out)
out = self.norm2(self.conv2(out))
# Add residual connection (avoid in-place operation)
out = out + residual
return F.relu(out)
class SpatialAttentionBlock(nn.Module):
"""Spatial attention for feature maps"""
def __init__(self, channels: int):
super().__init__()
self.conv = nn.Conv1d(channels, 1, kernel_size=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Compute attention weights
attention = torch.sigmoid(self.conv(x))
# Avoid in-place operation by creating new tensor
return torch.mul(x, attention)
class EnhancedCNNModel(nn.Module):
"""
Much larger and more sophisticated CNN architecture for trading
Features:
- Deep convolutional layers with residual connections
- Multi-head attention mechanisms
- Spatial attention blocks
- Multiple feature extraction paths
- Large capacity for complex pattern learning
"""
def __init__(self,
input_size: int = 60,
feature_dim: int = 50,
output_size: int = 2, # BUY/SELL for 2-action system
base_channels: int = 256, # Increased from 128 to 256
num_blocks: int = 12, # Increased from 6 to 12
num_attention_heads: int = 16, # Increased from 8 to 16
dropout_rate: float = 0.2):
super().__init__()
self.input_size = input_size
self.feature_dim = feature_dim
self.output_size = output_size
self.base_channels = base_channels
# Much larger input embedding - project features to higher dimension
self.input_embedding = nn.Sequential(
nn.Linear(feature_dim, base_channels // 2),
nn.LayerNorm(base_channels // 2), # Changed from BatchNorm1d for batch_size=1 compatibility
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(base_channels // 2, base_channels),
nn.LayerNorm(base_channels), # Changed from BatchNorm1d for batch_size=1 compatibility
nn.ReLU(),
nn.Dropout(dropout_rate)
)
# Multi-scale convolutional feature extraction with more channels
self.conv_path1 = self._build_conv_path(base_channels, base_channels, 3)
self.conv_path2 = self._build_conv_path(base_channels, base_channels, 5)
self.conv_path3 = self._build_conv_path(base_channels, base_channels, 7)
self.conv_path4 = self._build_conv_path(base_channels, base_channels, 9) # Additional path
# Feature fusion with more capacity
self.feature_fusion = nn.Sequential(
nn.Conv1d(base_channels * 4, base_channels * 3, kernel_size=1), # 4 paths now
nn.BatchNorm1d(base_channels * 3),
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Conv1d(base_channels * 3, base_channels * 2, kernel_size=1),
nn.BatchNorm1d(base_channels * 2),
nn.ReLU(),
nn.Dropout(dropout_rate)
)
# Much deeper residual blocks for complex pattern learning
self.residual_blocks = nn.ModuleList([
ResidualBlock(base_channels * 2, dropout_rate) for _ in range(num_blocks)
])
# More spatial attention blocks
self.spatial_attention = nn.ModuleList([
SpatialAttentionBlock(base_channels * 2) for _ in range(6) # Increased from 3 to 6
])
# Multiple temporal attention layers
self.temporal_attention1 = MultiHeadAttention(
d_model=base_channels * 2,
num_heads=num_attention_heads,
dropout=dropout_rate
)
self.temporal_attention2 = MultiHeadAttention(
d_model=base_channels * 2,
num_heads=num_attention_heads // 2,
dropout=dropout_rate
)
# Global feature aggregation
self.global_pool = nn.AdaptiveAvgPool1d(1)
self.global_max_pool = nn.AdaptiveMaxPool1d(1)
# Much larger advanced feature processing (using LayerNorm for batch_size=1 compatibility)
self.advanced_features = nn.Sequential(
nn.Linear(base_channels * 4, base_channels * 6), # Increased capacity
nn.LayerNorm(base_channels * 6), # Changed from BatchNorm1d
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(base_channels * 6, base_channels * 4),
nn.LayerNorm(base_channels * 4), # Changed from BatchNorm1d
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(base_channels * 4, base_channels * 3),
nn.LayerNorm(base_channels * 3), # Changed from BatchNorm1d
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(base_channels * 3, base_channels * 2),
nn.LayerNorm(base_channels * 2), # Changed from BatchNorm1d
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(base_channels * 2, base_channels),
nn.LayerNorm(base_channels), # Changed from BatchNorm1d
nn.ReLU(),
nn.Dropout(dropout_rate)
)
# Enhanced market regime detection branch (using LayerNorm for batch_size=1 compatibility)
self.regime_detector = nn.Sequential(
nn.Linear(base_channels, base_channels // 2),
nn.LayerNorm(base_channels // 2), # Changed from BatchNorm1d
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(base_channels // 2, base_channels // 4),
nn.LayerNorm(base_channels // 4), # Changed from BatchNorm1d
nn.ReLU(),
nn.Linear(base_channels // 4, 8), # 8 market regimes instead of 4
nn.Softmax(dim=1)
)
# Enhanced volatility prediction branch (using LayerNorm for batch_size=1 compatibility)
self.volatility_predictor = nn.Sequential(
nn.Linear(base_channels, base_channels // 2),
nn.LayerNorm(base_channels // 2), # Changed from BatchNorm1d
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(base_channels // 2, base_channels // 4),
nn.LayerNorm(base_channels // 4), # Changed from BatchNorm1d
nn.ReLU(),
nn.Linear(base_channels // 4, 1),
nn.Sigmoid()
)
# Main trading decision head (using LayerNorm for batch_size=1 compatibility)
self.decision_head = nn.Sequential(
nn.Linear(base_channels + 8 + 1, base_channels), # 8 regime classes + 1 volatility
nn.LayerNorm(base_channels), # Changed from BatchNorm1d
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(base_channels, base_channels // 2),
nn.LayerNorm(base_channels // 2), # Changed from BatchNorm1d
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(base_channels // 2, output_size)
)
# Confidence estimation head
self.confidence_head = nn.Sequential(
nn.Linear(base_channels, base_channels // 2),
nn.ReLU(),
nn.Linear(base_channels // 2, 1),
nn.Sigmoid()
)
# Initialize weights
self._initialize_weights()
def _build_conv_path(self, in_channels: int, out_channels: int, kernel_size: int) -> nn.Module:
"""Build a convolutional path with multiple layers"""
return nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size//2),
nn.BatchNorm1d(out_channels),
nn.ReLU(),
nn.Dropout(0.1),
nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size//2),
nn.BatchNorm1d(out_channels),
nn.ReLU(),
nn.Dropout(0.1),
nn.Conv1d(out_channels, out_channels, kernel_size, padding=kernel_size//2),
nn.BatchNorm1d(out_channels),
nn.ReLU()
)
def _initialize_weights(self):
"""Initialize model weights"""
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
Forward pass with multiple outputs
Args:
x: Input tensor of shape [batch_size, sequence_length, features]
Returns:
Dictionary with predictions, confidence, regime, and volatility
"""
# Handle input shapes flexibly
if len(x.shape) == 2:
# Input is [seq_len, features] - add batch dimension
x = x.unsqueeze(0)
elif len(x.shape) > 3:
# Input has extra dimensions - flatten to [batch, seq, features]
x = x.view(x.shape[0], -1, x.shape[-1])
batch_size, seq_len, features = x.shape
# Reshape for processing: [batch, seq, features] -> [batch*seq, features]
x_reshaped = x.view(-1, features)
# Input embedding
embedded = self.input_embedding(x_reshaped) # [batch*seq, base_channels]
# Reshape back for conv1d: [batch*seq, channels] -> [batch, channels, seq]
embedded = embedded.view(batch_size, seq_len, -1).transpose(1, 2)
# Multi-scale feature extraction
path1 = self.conv_path1(embedded)
path2 = self.conv_path2(embedded)
path3 = self.conv_path3(embedded)
path4 = self.conv_path4(embedded)
# Feature fusion
fused_features = torch.cat([path1, path2, path3, path4], dim=1)
fused_features = self.feature_fusion(fused_features)
# Apply residual blocks with spatial attention
current_features = fused_features
for i, (res_block, attention) in enumerate(zip(self.residual_blocks, self.spatial_attention)):
current_features = res_block(current_features)
if i % 2 == 0: # Apply attention every other block
current_features = attention(current_features)
# Apply remaining residual blocks
for res_block in self.residual_blocks[len(self.spatial_attention):]:
current_features = res_block(current_features)
# Temporal attention - apply both attention layers
# Reshape for attention: [batch, channels, seq] -> [batch, seq, channels]
attention_input = current_features.transpose(1, 2)
attended_features = self.temporal_attention1(attention_input)
attended_features = self.temporal_attention2(attended_features)
# Back to conv format: [batch, seq, channels] -> [batch, channels, seq]
attended_features = attended_features.transpose(1, 2)
# Global aggregation
avg_pooled = self.global_pool(attended_features).squeeze(-1) # [batch, channels]
max_pooled = self.global_max_pool(attended_features).squeeze(-1) # [batch, channels]
# Combine global features
global_features = torch.cat([avg_pooled, max_pooled], dim=1)
# Advanced feature processing
processed_features = self.advanced_features(global_features)
# Multi-task predictions
regime_probs = self.regime_detector(processed_features)
volatility_pred = self.volatility_predictor(processed_features)
confidence = self.confidence_head(processed_features)
# Combine all features for final decision (8 regime classes + 1 volatility)
combined_features = torch.cat([processed_features, regime_probs, volatility_pred], dim=1)
trading_logits = self.decision_head(combined_features)
# Apply temperature scaling for better calibration
temperature = 1.5
trading_probs = F.softmax(trading_logits / temperature, dim=1)
return {
'logits': trading_logits,
'probabilities': trading_probs,
'confidence': confidence.squeeze(-1),
'regime': regime_probs,
'volatility': volatility_pred.squeeze(-1),
'features': processed_features
}
def predict(self, feature_matrix: np.ndarray) -> Dict[str, Any]:
"""
Make predictions on feature matrix
Args:
feature_matrix: numpy array of shape [sequence_length, features]
Returns:
Dictionary with prediction results
"""
self.eval()
with torch.no_grad():
# Convert to tensor and add batch dimension
if isinstance(feature_matrix, np.ndarray):
x = torch.FloatTensor(feature_matrix).unsqueeze(0) # Add batch dim
else:
x = feature_matrix.unsqueeze(0)
# Move to device
device = next(self.parameters()).device
x = x.to(device)
# Forward pass
outputs = self.forward(x)
# Extract results
probs = outputs['probabilities'].cpu().numpy()[0]
confidence = outputs['confidence'].cpu().numpy()[0]
regime = outputs['regime'].cpu().numpy()[0]
volatility = outputs['volatility'].cpu().numpy()[0]
# Determine action (0=BUY, 1=SELL for 2-action system)
action = int(np.argmax(probs))
action_confidence = float(probs[action])
return {
'action': action,
'action_name': 'BUY' if action == 0 else 'SELL',
'confidence': float(confidence),
'action_confidence': action_confidence,
'probabilities': probs.tolist(),
'regime_probabilities': regime.tolist(),
'volatility_prediction': float(volatility),
'raw_logits': outputs['logits'].cpu().numpy()[0].tolist()
}
def get_memory_usage(self) -> Dict[str, Any]:
"""Get model memory usage statistics"""
total_params = sum(p.numel() for p in self.parameters())
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
param_size = sum(p.numel() * p.element_size() for p in self.parameters())
buffer_size = sum(b.numel() * b.element_size() for b in self.buffers())
return {
'total_parameters': total_params,
'trainable_parameters': trainable_params,
'parameter_size_mb': param_size / (1024 * 1024),
'buffer_size_mb': buffer_size / (1024 * 1024),
'total_size_mb': (param_size + buffer_size) / (1024 * 1024)
}
def to_device(self, device: str):
"""Move model to specified device"""
return self.to(torch.device(device))
class CNNModelTrainer:
"""Enhanced trainer for the beefed-up CNN model"""
def __init__(self, model: EnhancedCNNModel, learning_rate: float = 0.0001, device: str = 'cuda'):
self.model = model.to(device)
self.device = device
self.learning_rate = learning_rate
# Use AdamW optimizer with weight decay
self.optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate,
weight_decay=0.01,
betas=(0.9, 0.999)
)
# Learning rate scheduler
self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
self.optimizer,
max_lr=learning_rate * 10,
total_steps=10000, # Will be updated based on actual training
pct_start=0.1,
anneal_strategy='cos'
)
# Multi-task loss functions
self.main_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
self.confidence_criterion = nn.BCELoss()
self.regime_criterion = nn.CrossEntropyLoss()
self.volatility_criterion = nn.MSELoss()
self.training_history = []
def train_step(self, x: torch.Tensor, y: torch.Tensor,
confidence_targets: Optional[torch.Tensor] = None,
regime_targets: Optional[torch.Tensor] = None,
volatility_targets: Optional[torch.Tensor] = None) -> Dict[str, float]:
"""Single training step with multi-task learning"""
self.model.train()
self.optimizer.zero_grad()
# Forward pass
outputs = self.model(x)
# Main trading loss
main_loss = self.main_criterion(outputs['logits'], y)
total_loss = main_loss
losses = {'main_loss': main_loss.item()}
# Confidence loss (if targets provided)
if confidence_targets is not None:
conf_loss = self.confidence_criterion(outputs['confidence'], confidence_targets)
total_loss += 0.1 * conf_loss
losses['confidence_loss'] = conf_loss.item()
# Regime classification loss (if targets provided)
if regime_targets is not None:
regime_loss = self.regime_criterion(outputs['regime'], regime_targets)
total_loss += 0.05 * regime_loss
losses['regime_loss'] = regime_loss.item()
# Volatility prediction loss (if targets provided)
if volatility_targets is not None:
vol_loss = self.volatility_criterion(outputs['volatility'], volatility_targets)
total_loss += 0.05 * vol_loss
losses['volatility_loss'] = vol_loss.item()
losses['total_loss'] = total_loss.item()
# Backward pass
total_loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
self.scheduler.step()
# Calculate accuracy
with torch.no_grad():
predictions = torch.argmax(outputs['probabilities'], dim=1)
accuracy = (predictions == y).float().mean().item()
losses['accuracy'] = accuracy
return losses
def save_model(self, filepath: str, metadata: Optional[Dict] = None):
"""Save model with metadata"""
save_dict = {
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'training_history': self.training_history,
'model_config': {
'input_size': self.model.input_size,
'feature_dim': self.model.feature_dim,
'output_size': self.model.output_size,
'base_channels': self.model.base_channels
}
}
if metadata:
save_dict['metadata'] = metadata
torch.save(save_dict, filepath)
logger.info(f"Enhanced CNN model saved to {filepath}")
def load_model(self, filepath: str) -> Dict:
"""Load model from file"""
checkpoint = torch.load(filepath, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
if 'training_history' in checkpoint:
self.training_history = checkpoint['training_history']
logger.info(f"Enhanced CNN model loaded from {filepath}")
return checkpoint.get('metadata', {})
def create_enhanced_cnn_model(input_size: int = 60,
feature_dim: int = 50,
output_size: int = 2,
base_channels: int = 256,
device: str = 'cuda') -> Tuple[EnhancedCNNModel, CNNModelTrainer]:
"""Create enhanced CNN model and trainer"""
model = EnhancedCNNModel(
input_size=input_size,
feature_dim=feature_dim,
output_size=output_size,
base_channels=base_channels,
num_blocks=12,
num_attention_heads=16,
dropout_rate=0.2
)
trainer = CNNModelTrainer(model, learning_rate=0.0001, device=device)
logger.info(f"Created enhanced CNN model with {model.get_memory_usage()['total_parameters']:,} parameters")
return model, trainer
# Compatibility wrapper for williams_market_structure.py
class CNNModel:
"""
Compatibility wrapper for the enhanced CNN model
"""
def __init__(self, input_shape=(900, 50), output_size=10, model_path=None):
self.input_shape = input_shape
self.output_size = output_size
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Create the enhanced model
self.model = EnhancedCNNModel(
input_size=input_shape[0],
feature_dim=input_shape[1],
output_size=output_size
)
self.trainer = CNNModelTrainer(self.model, device=self.device)
logger.info(f"CNN Model wrapper initialized: input_shape={input_shape}, output_size={output_size}")
if model_path and os.path.exists(model_path):
self.load(model_path)
def build_model(self, **kwargs):
"""Build/configure the model"""
logger.info("CNN Model build_model called")
return self
def predict(self, X):
"""Make predictions on input data"""
try:
if isinstance(X, np.ndarray):
result = self.model.predict(X)
pred_class = np.array([result['action']])
pred_proba = np.array([result['probabilities']])
else:
# Handle tensor input
result = self.model.predict(X.cpu().numpy() if hasattr(X, 'cpu') else X)
pred_class = np.array([result['action']])
pred_proba = np.array([result['probabilities']])
logger.debug(f"CNN prediction: class={pred_class}, proba_shape={pred_proba.shape}")
return pred_class, pred_proba
except Exception as e:
logger.error(f"Error in CNN prediction: {e}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
# Return dummy prediction
pred_class = np.array([0])
pred_proba = np.array([[0.1] * self.output_size])
return pred_class, pred_proba
def fit(self, X, y, **kwargs):
"""Train the model on input data"""
try:
# Convert to tensors if needed (create new tensors to avoid in-place modifications)
if isinstance(X, np.ndarray):
X = torch.FloatTensor(X.copy()) # Use copy to avoid in-place modifications
elif isinstance(X, torch.Tensor):
X = X.clone().detach() # Clone to avoid in-place modifications
if isinstance(y, np.ndarray):
y = torch.LongTensor(y.copy()) # Use copy to avoid in-place modifications
elif isinstance(y, torch.Tensor):
y = y.clone().detach().long() # Clone to avoid in-place modifications
# Ensure proper shapes and consistent batch sizes
if len(X.shape) == 2:
X = X.unsqueeze(0) # [seq, features] -> [1, seq, features]
# Handle target tensor - ensure it matches batch size (avoid in-place operations)
if len(y.shape) == 0:
y = y.unsqueeze(0) # scalar -> [1]
elif len(y.shape) == 2 and y.shape[0] == 1:
# Already correct shape [1, num_classes] -> get class index
y = torch.argmax(y, dim=1) # [1, num_classes] -> [1]
elif len(y.shape) == 1 and len(y) > 1:
# Multi-class probabilities -> get class index, ensure batch size 1
y = torch.argmax(y).unsqueeze(0) # [num_classes] -> [1]
elif len(y.shape) == 1 and len(y) == 1:
pass # Already correct [1]
else:
# Fallback: take first element and ensure batch size 1
y = y.view(-1)[:1] # Take only first element
# Move to device (create new tensors on device, don't modify in-place)
X = X.to(self.device, non_blocking=True)
y = y.to(self.device, non_blocking=True)
# Use trainer's train_step
loss_dict = self.trainer.train_step(X, y)
logger.info(f"CNN training: X_shape={X.shape}, y_shape={y.shape}, loss={loss_dict.get('total_loss', 0):.4f}")
return self
except Exception as e:
logger.error(f"Error in CNN training: {e}")
return self
def save(self, filepath: str):
"""Save the model"""
try:
self.trainer.save_model(filepath)
logger.info(f"CNN model saved to {filepath}")
except Exception as e:
logger.error(f"Error saving CNN model: {e}")
def load(self, filepath: str):
"""Load the model"""
try:
self.trainer.load_model(filepath)
logger.info(f"CNN model loaded from {filepath}")
except Exception as e:
logger.error(f"Error loading CNN model: {e}")
def to_device(self, device):
"""Move model to device"""
self.device = device
self.model.to(device)
return self
def get_memory_usage(self):
"""Get model memory usage"""
try:
return self.model.get_memory_usage()
except Exception as e:
logger.error(f"Error getting memory usage: {e}")
return {'total_parameters': 0, 'memory_mb': 0}