remove dummy data, improve training , follow architecture
This commit is contained in:
@ -161,7 +161,7 @@ class OrderBookEncoder(nn.Module):
|
||||
attended_features, attention_weights = self.cross_attention(combined_seq)
|
||||
|
||||
# Flatten attended features
|
||||
attended_flat = attended_features.view(attended_features.size(0), -1) # [batch, 512]
|
||||
attended_flat = attended_features.reshape(attended_features.size(0), -1) # [batch, 512]
|
||||
|
||||
# Combine with microstructure features
|
||||
combined_features = torch.cat([attended_flat, micro_encoded], dim=1) # [batch, 640]
|
||||
@ -210,8 +210,7 @@ class VolumeProfileEncoder(nn.Module):
|
||||
if isinstance(volume_profile_data, list):
|
||||
if not volume_profile_data:
|
||||
# Return zero features if no data
|
||||
batch_size = 1
|
||||
return torch.zeros(batch_size, self.aggregator[-1].out_features)
|
||||
return torch.zeros(1, 256, device=torch.device('cpu')) # Hardcoded output dim as per hidden_dim in class init
|
||||
|
||||
# Convert to tensor
|
||||
features = []
|
||||
@ -239,7 +238,7 @@ class VolumeProfileEncoder(nn.Module):
|
||||
|
||||
# Encode each level
|
||||
level_features = self.level_encoder(volume_tensor.view(-1, feature_dim))
|
||||
level_features = level_features.view(batch_size, num_levels, -1)
|
||||
level_features = level_features.reshape(batch_size, num_levels, -1)
|
||||
|
||||
# Apply attention across levels
|
||||
attended_levels, _ = self.level_attention(level_features)
|
||||
@ -423,14 +422,14 @@ class EnhancedCNNWithOrderBook(nn.Module):
|
||||
Returns:
|
||||
Dictionary with Q-values, confidence, regime, and auxiliary predictions
|
||||
"""
|
||||
batch_size = market_data.size(0)
|
||||
|
||||
# Process market data
|
||||
# Process market data - ensure batch dimension first
|
||||
if len(market_data.shape) == 2:
|
||||
market_data = market_data.unsqueeze(0)
|
||||
|
||||
# Reshape for convolutional processing
|
||||
market_reshaped = market_data.view(batch_size, -1, market_data.size(-1))
|
||||
batch_size = market_data.size(0) # Get correct batch size after shape adjustment
|
||||
|
||||
# Reshape for convolutional processing with safe dimensions
|
||||
market_reshaped = market_data.reshape(batch_size, -1, market_data.size(-1))
|
||||
market_features = self.market_encoder(market_reshaped.transpose(1, 2))
|
||||
|
||||
# Process order book data
|
||||
@ -440,7 +439,7 @@ class EnhancedCNNWithOrderBook(nn.Module):
|
||||
if volume_profile_data is not None:
|
||||
volume_features = self.volume_encoder(volume_profile_data)
|
||||
else:
|
||||
volume_features = torch.zeros(batch_size, 256, device=self.device)
|
||||
volume_features = torch.zeros(batch_size, 256, device=market_data.device)
|
||||
|
||||
# Fuse all features
|
||||
combined_features = torch.cat([
|
||||
|
Reference in New Issue
Block a user