fix network rebuild

Dobromir Popov
2025-07-25 23:59:51 +03:00
parent 43ed694917
commit a3828c708c
3 changed files with 264 additions and 35 deletions


@@ -376,20 +376,12 @@ class EnhancedCNN(nn.Module):
         return tensor.detach().clone().requires_grad_(tensor.requires_grad)

     def _check_rebuild_network(self, features):
-        """Check if network needs to be rebuilt for different feature dimensions"""
-        # Prevent rebuilding with zero or invalid dimensions
-        if features <= 0:
-            logger.error(f"Invalid feature dimension: {features}. Cannot rebuild network with zero or negative dimensions.")
-            logger.error(f"Current feature_dim: {self.feature_dim}. Keeping existing network.")
-            return False
-
+        """DEPRECATED: Network should have fixed architecture - no runtime rebuilding"""
         if features != self.feature_dim:
-            logger.info(f"Rebuilding network for new feature dimension: {features} (was {self.feature_dim})")
-            self.feature_dim = features
-            self._build_network()
-            # Move to device after rebuilding
-            self.to(self.device)
-            return True
+            logger.error(f"CRITICAL: Input feature dimension mismatch! Expected {self.feature_dim}, got {features}")
+            logger.error("This indicates a bug in data preprocessing - input should be fixed size!")
+            logger.error("Network architecture should NOT change at runtime!")
+            raise ValueError(f"Input dimension mismatch: expected {self.feature_dim}, got {features}")
         return False

     def forward(self, x):
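
With runtime rebuilding removed, the data pipeline now has to deliver a constant feature width before calling the model. A minimal sketch of what that caller-side guarantee could look like; prepare_features and FEATURE_DIM are hypothetical names for illustration, not part of this commit:

import numpy as np

FEATURE_DIM = 500  # hypothetical fixed width chosen when the model is built

def prepare_features(raw) -> np.ndarray:
    """Pad or truncate a flat feature vector to FEATURE_DIM (hypothetical helper)."""
    flat = np.asarray(raw, dtype=np.float32).ravel()
    if flat.size > FEATURE_DIM:
        return flat[:FEATURE_DIM]                           # drop the overflow
    if flat.size < FEATURE_DIM:
        return np.pad(flat, (0, FEATURE_DIM - flat.size))   # zero-pad the shortfall
    return flat
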
@@ -429,10 +421,11 @@ class EnhancedCNN(nn.Module):
             # Now x is 3D: [batch, timeframes, features]
             x_reshaped = x

-            # Check if the feature dimension has changed and rebuild if necessary
-            if x_reshaped.size(1) * x_reshaped.size(2) != self.feature_dim:
-                total_features = x_reshaped.size(1) * x_reshaped.size(2)
-                self._check_rebuild_network(total_features)
+            # Validate input dimensions (should be fixed)
+            total_features = x_reshaped.size(1) * x_reshaped.size(2)
+            if total_features != self.feature_dim:
+                logger.error(f"Input dimension mismatch: expected {self.feature_dim}, got {total_features}")
+                raise ValueError(f"Input dimension mismatch: expected {self.feature_dim}, got {total_features}")

             # Apply ultra massive convolutions
             x_conv = self.conv_layers(x_reshaped)
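
For the 3D path, the new check simply compares the flattened per-sample size against the fixed feature_dim. A quick worked example of that arithmetic; the shape numbers are made up for illustration:

import torch

x = torch.randn(32, 5, 100)              # [batch, timeframes, features], example shape
total_features = x.size(1) * x.size(2)   # 5 * 100 = 500 per sample
assert total_features == 500             # must equal the model's fixed feature_dim
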
@@ -445,9 +438,10 @@ class EnhancedCNN(nn.Module):
             # For 2D input [batch, features]
             x_flat = x

-            # Check if dimensions have changed
+            # Validate input dimensions (should be fixed)
             if x_flat.size(1) != self.feature_dim:
-                self._check_rebuild_network(x_flat.size(1))
+                logger.error(f"Input dimension mismatch: expected {self.feature_dim}, got {x_flat.size(1)}")
+                raise ValueError(f"Input dimension mismatch: expected {self.feature_dim}, got {x_flat.size(1)}")

             # Apply ULTRA MASSIVE FC layers to get base features
             features = self.fc_layers(x_flat)  # [batch, 1024]
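
Taken together, the change turns a silent runtime rebuild into a fail-fast error. A hedged usage sketch of that behaviour; the EnhancedCNN constructor arguments and import path shown here are assumptions, not the real signature:

import torch
# EnhancedCNN import path depends on the repo layout and is omitted here

model = EnhancedCNN(input_shape=(5, 100), n_actions=3)  # hypothetical constructor args

good = torch.randn(8, 5, 100)   # 5 * 100 matches model.feature_dim
bad = torch.randn(8, 5, 120)    # flattened size differs from model.feature_dim

model(good)                     # passes the new validation

try:
    model(bad)                  # previously triggered a rebuild; now raises
except ValueError as exc:
    print(f"rejected malformed batch: {exc}")
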