remove dummy data, improve training , follow architecture

This commit is contained in:
Dobromir Popov
2025-07-04 23:51:35 +03:00
parent e8b9c05148
commit ce8c00a9d1
13 changed files with 435 additions and 838 deletions

View File

@ -161,7 +161,7 @@ class OrderBookEncoder(nn.Module):
attended_features, attention_weights = self.cross_attention(combined_seq)
# Flatten attended features
attended_flat = attended_features.view(attended_features.size(0), -1) # [batch, 512]
attended_flat = attended_features.reshape(attended_features.size(0), -1) # [batch, 512]
# Combine with microstructure features
combined_features = torch.cat([attended_flat, micro_encoded], dim=1) # [batch, 640]
@ -210,8 +210,7 @@ class VolumeProfileEncoder(nn.Module):
if isinstance(volume_profile_data, list):
if not volume_profile_data:
# Return zero features if no data
batch_size = 1
return torch.zeros(batch_size, self.aggregator[-1].out_features)
return torch.zeros(1, 256, device=torch.device('cpu')) # Hardcoded output dim as per hidden_dim in class init
# Convert to tensor
features = []
@ -239,7 +238,7 @@ class VolumeProfileEncoder(nn.Module):
# Encode each level
level_features = self.level_encoder(volume_tensor.view(-1, feature_dim))
level_features = level_features.view(batch_size, num_levels, -1)
level_features = level_features.reshape(batch_size, num_levels, -1)
# Apply attention across levels
attended_levels, _ = self.level_attention(level_features)
@ -423,14 +422,14 @@ class EnhancedCNNWithOrderBook(nn.Module):
Returns:
Dictionary with Q-values, confidence, regime, and auxiliary predictions
"""
batch_size = market_data.size(0)
# Process market data
# Process market data - ensure batch dimension first
if len(market_data.shape) == 2:
market_data = market_data.unsqueeze(0)
# Reshape for convolutional processing
market_reshaped = market_data.view(batch_size, -1, market_data.size(-1))
batch_size = market_data.size(0) # Get correct batch size after shape adjustment
# Reshape for convolutional processing with safe dimensions
market_reshaped = market_data.reshape(batch_size, -1, market_data.size(-1))
market_features = self.market_encoder(market_reshaped.transpose(1, 2))
# Process order book data
@ -440,7 +439,7 @@ class EnhancedCNNWithOrderBook(nn.Module):
if volume_profile_data is not None:
volume_features = self.volume_encoder(volume_profile_data)
else:
volume_features = torch.zeros(batch_size, 256, device=self.device)
volume_features = torch.zeros(batch_size, 256, device=market_data.device)
# Fuse all features
combined_features = torch.cat([