remove dummy data, improve training, follow architecture
@@ -451,7 +451,13 @@ class DQNAgent:
         state_tensor = state.unsqueeze(0).to(self.device)
 
         # Get Q-values
-        q_values = self.policy_net(state_tensor)
+        policy_output = self.policy_net(state_tensor)
+        if isinstance(policy_output, dict):
+            q_values = policy_output.get('q_values', policy_output.get('Q_values', list(policy_output.values())[0]))
+        elif isinstance(policy_output, tuple):
+            q_values = policy_output[0]  # Assume first element is Q-values
+        else:
+            q_values = policy_output
         action_values = q_values.cpu().data.numpy()[0]
 
         # Calculate confidence scores
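The new branching makes the agent tolerant of the richer head outputs used elsewhere in the architecture (a dict with Q-values plus auxiliary predictions, a tuple, or a bare tensor). A standalone sketch of the same extraction logic; the helper name extract_q_values is illustrative only, not part of the commit:

import torch

def extract_q_values(policy_output):
    """Pull a Q-value tensor out of a dict, tuple, or plain tensor output.

    Mirrors the branching added above: dict outputs are assumed to keep
    Q-values under 'q_values' or 'Q_values'; tuple outputs put them first.
    """
    if isinstance(policy_output, dict):
        return policy_output.get(
            'q_values',
            policy_output.get('Q_values', list(policy_output.values())[0]))
    if isinstance(policy_output, tuple):
        return policy_output[0]  # assume the first element is the Q-value tensor
    return policy_output

# quick sanity check over the three shapes the agent may receive
q = torch.randn(1, 3)
assert torch.equal(extract_q_values({'q_values': q, 'confidence': torch.rand(1)}), q)
assert torch.equal(extract_q_values((q, torch.rand(1))), q)
assert torch.equal(extract_q_values(q), q)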
@@ -161,7 +161,7 @@ class OrderBookEncoder(nn.Module):
         attended_features, attention_weights = self.cross_attention(combined_seq)
 
         # Flatten attended features
-        attended_flat = attended_features.view(attended_features.size(0), -1)  # [batch, 512]
+        attended_flat = attended_features.reshape(attended_features.size(0), -1)  # [batch, 512]
 
         # Combine with microstructure features
         combined_features = torch.cat([attended_flat, micro_encoded], dim=1)  # [batch, 640]
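view requires contiguous memory, and attention outputs that have been transposed or permuted are generally not contiguous, so the old line could raise a RuntimeError at runtime; reshape falls back to a copy when needed. A minimal standalone illustration (the shapes are arbitrary, not taken from the model):

import torch

attended = torch.randn(8, 4, 128).transpose(1, 2)  # transpose makes the tensor non-contiguous

try:
    attended.view(attended.size(0), -1)             # fails on non-contiguous memory
except RuntimeError as err:
    print('view failed:', err)

flat = attended.reshape(attended.size(0), -1)       # copies if necessary and succeeds
print(flat.shape)                                   # torch.Size([8, 512])

The same reasoning applies to the level-feature change in VolumeProfileEncoder below.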
@@ -210,8 +210,7 @@ class VolumeProfileEncoder(nn.Module):
         if isinstance(volume_profile_data, list):
             if not volume_profile_data:
                 # Return zero features if no data
-                batch_size = 1
-                return torch.zeros(batch_size, self.aggregator[-1].out_features)
+                return torch.zeros(1, 256, device=torch.device('cpu'))  # Hardcoded output dim as per hidden_dim in class init
 
         # Convert to tensor
         features = []
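Returning a fixed 256-wide zero tensor drops the lookup through self.aggregator[-1].out_features, but it hardcodes both the width and the CPU device, so callers running on GPU would need to move the result. A small sketch of the fallback, assuming 256 matches the encoder's hidden_dim (the helper name is illustrative only):

import torch

HIDDEN_DIM = 256  # assumed to equal hidden_dim in VolumeProfileEncoder.__init__

def empty_volume_features(device=torch.device('cpu')):
    # Zero feature block returned when the volume-profile list is empty;
    # pass the batch's device explicitly to avoid a CPU/GPU mismatch later.
    return torch.zeros(1, HIDDEN_DIM, device=device)

print(empty_volume_features().shape)  # torch.Size([1, 256])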
@@ -239,7 +238,7 @@ class VolumeProfileEncoder(nn.Module):
 
         # Encode each level
         level_features = self.level_encoder(volume_tensor.view(-1, feature_dim))
-        level_features = level_features.view(batch_size, num_levels, -1)
+        level_features = level_features.reshape(batch_size, num_levels, -1)
 
         # Apply attention across levels
         attended_levels, _ = self.level_attention(level_features)
@@ -423,14 +422,14 @@ class EnhancedCNNWithOrderBook(nn.Module):
         Returns:
             Dictionary with Q-values, confidence, regime, and auxiliary predictions
         """
-        batch_size = market_data.size(0)
-
-        # Process market data
+        # Process market data - ensure batch dimension first
+        if len(market_data.shape) == 2:
+            market_data = market_data.unsqueeze(0)
+
-        # Reshape for convolutional processing
-        market_reshaped = market_data.view(batch_size, -1, market_data.size(-1))
+        batch_size = market_data.size(0)  # Get correct batch size after shape adjustment
+
+        # Reshape for convolutional processing with safe dimensions
+        market_reshaped = market_data.reshape(batch_size, -1, market_data.size(-1))
         market_features = self.market_encoder(market_reshaped.transpose(1, 2))
 
         # Process order book data
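Reading batch_size only after the optional unsqueeze is the substantive fix here: a 2-D [seq, features] input would otherwise be reshaped with the wrong leading dimension. A standalone sketch of the shape flow with an illustrative Conv1d encoder (channel and sequence sizes are assumptions, not the model's real config):

import torch
import torch.nn as nn

market_encoder = nn.Conv1d(in_channels=5, out_channels=16, kernel_size=3, padding=1)

def prepare_market_features(market_data: torch.Tensor) -> torch.Tensor:
    # Ensure a batch dimension first: [seq, features] -> [1, seq, features]
    if len(market_data.shape) == 2:
        market_data = market_data.unsqueeze(0)

    batch_size = market_data.size(0)  # read only after the shape adjustment

    # reshape (not view) tolerates non-contiguous inputs; Conv1d wants [batch, channels, seq]
    market_reshaped = market_data.reshape(batch_size, -1, market_data.size(-1))
    return market_encoder(market_reshaped.transpose(1, 2))

print(prepare_market_features(torch.randn(60, 5)).shape)     # torch.Size([1, 16, 60])
print(prepare_market_features(torch.randn(4, 60, 5)).shape)  # torch.Size([4, 16, 60])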
@@ -440,7 +439,7 @@ class EnhancedCNNWithOrderBook(nn.Module):
         if volume_profile_data is not None:
             volume_features = self.volume_encoder(volume_profile_data)
         else:
-            volume_features = torch.zeros(batch_size, 256, device=self.device)
+            volume_features = torch.zeros(batch_size, 256, device=market_data.device)
 
         # Fuse all features
         combined_features = torch.cat([
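Allocating the fallback on market_data.device instead of self.device guarantees the zero tensor lives on the same device as the features it is concatenated with in the torch.cat below, even when the module attribute and the actual inputs disagree. A sketch of the pattern (the function name is a placeholder):

import torch

def volume_features_fallback(market_data: torch.Tensor, dim: int = 256) -> torch.Tensor:
    # Allocate on the input's device so the later torch.cat cannot mix
    # CPU and GPU tensors, whatever self.device happens to be set to.
    return torch.zeros(market_data.size(0), dim, device=market_data.device)

x = torch.randn(4, 60, 5)  # CPU input
fallback = volume_features_fallback(x)
assert fallback.device == x.device and fallback.shape == (4, 256)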
@@ -81,15 +81,15 @@
       "wandb_artifact_name": null
     },
     {
-      "checkpoint_id": "decision_20250704_082452",
+      "checkpoint_id": "decision_20250704_214714",
       "model_name": "decision",
       "model_type": "decision_fusion",
-      "file_path": "NN\\models\\saved\\decision\\decision_20250704_082452.pt",
-      "created_at": "2025-07-04T08:24:52.949705",
+      "file_path": "NN\\models\\saved\\decision\\decision_20250704_214714.pt",
+      "created_at": "2025-07-04T21:47:14.427187",
       "file_size_mb": 0.06720924377441406,
-      "performance_score": 102.79965677530546,
+      "performance_score": 102.79966325731509,
       "accuracy": null,
-      "loss": 3.432258725613987e-06,
+      "loss": 3.3674381887394134e-06,
       "val_accuracy": null,
       "val_loss": null,
       "reward": null,