order flow WIP, chart broken
This commit is contained in:
@ -80,8 +80,8 @@ class ResidualBlock(nn.Module):
|
||||
out = self.dropout(out)
|
||||
out = self.norm2(self.conv2(out))
|
||||
|
||||
# Add residual connection
|
||||
out += residual
|
||||
# Add residual connection (avoid in-place operation)
|
||||
out = out + residual
|
||||
return F.relu(out)
|
||||
|
||||
class SpatialAttentionBlock(nn.Module):
|
||||
@ -94,7 +94,8 @@ class SpatialAttentionBlock(nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Compute attention weights
|
||||
attention = torch.sigmoid(self.conv(x))
|
||||
return x * attention
|
||||
# Avoid in-place operation by creating new tensor
|
||||
return torch.mul(x, attention)
|
||||
|
||||
class EnhancedCNNModel(nn.Module):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user