fix netwrk rebuild
This commit is contained in:
@ -376,20 +376,12 @@ class EnhancedCNN(nn.Module):
|
||||
return tensor.detach().clone().requires_grad_(tensor.requires_grad)
|
||||
|
||||
def _check_rebuild_network(self, features):
|
||||
"""Check if network needs to be rebuilt for different feature dimensions"""
|
||||
# Prevent rebuilding with zero or invalid dimensions
|
||||
if features <= 0:
|
||||
logger.error(f"Invalid feature dimension: {features}. Cannot rebuild network with zero or negative dimensions.")
|
||||
logger.error(f"Current feature_dim: {self.feature_dim}. Keeping existing network.")
|
||||
return False
|
||||
|
||||
"""DEPRECATED: Network should have fixed architecture - no runtime rebuilding"""
|
||||
if features != self.feature_dim:
|
||||
logger.info(f"Rebuilding network for new feature dimension: {features} (was {self.feature_dim})")
|
||||
self.feature_dim = features
|
||||
self._build_network()
|
||||
# Move to device after rebuilding
|
||||
self.to(self.device)
|
||||
return True
|
||||
logger.error(f"CRITICAL: Input feature dimension mismatch! Expected {self.feature_dim}, got {features}")
|
||||
logger.error("This indicates a bug in data preprocessing - input should be fixed size!")
|
||||
logger.error("Network architecture should NOT change at runtime!")
|
||||
raise ValueError(f"Input dimension mismatch: expected {self.feature_dim}, got {features}")
|
||||
return False
|
||||
|
||||
def forward(self, x):
|
||||
@ -429,10 +421,11 @@ class EnhancedCNN(nn.Module):
|
||||
# Now x is 3D: [batch, timeframes, features]
|
||||
x_reshaped = x
|
||||
|
||||
# Check if the feature dimension has changed and rebuild if necessary
|
||||
if x_reshaped.size(1) * x_reshaped.size(2) != self.feature_dim:
|
||||
total_features = x_reshaped.size(1) * x_reshaped.size(2)
|
||||
self._check_rebuild_network(total_features)
|
||||
# Validate input dimensions (should be fixed)
|
||||
total_features = x_reshaped.size(1) * x_reshaped.size(2)
|
||||
if total_features != self.feature_dim:
|
||||
logger.error(f"Input dimension mismatch: expected {self.feature_dim}, got {total_features}")
|
||||
raise ValueError(f"Input dimension mismatch: expected {self.feature_dim}, got {total_features}")
|
||||
|
||||
# Apply ultra massive convolutions
|
||||
x_conv = self.conv_layers(x_reshaped)
|
||||
@ -445,9 +438,10 @@ class EnhancedCNN(nn.Module):
|
||||
# For 2D input [batch, features]
|
||||
x_flat = x
|
||||
|
||||
# Check if dimensions have changed
|
||||
# Validate input dimensions (should be fixed)
|
||||
if x_flat.size(1) != self.feature_dim:
|
||||
self._check_rebuild_network(x_flat.size(1))
|
||||
logger.error(f"Input dimension mismatch: expected {self.feature_dim}, got {x_flat.size(1)}")
|
||||
raise ValueError(f"Input dimension mismatch: expected {self.feature_dim}, got {x_flat.size(1)}")
|
||||
|
||||
# Apply ULTRA MASSIVE FC layers to get base features
|
||||
features = self.fc_layers(x_flat) # [batch, 1024]
|
||||
|
Reference in New Issue
Block a user