Optional numeric return head (predicts percent change for 1s,1m,1h,1d)

This commit is contained in:
Dobromir Popov
2025-08-23 15:17:04 +03:00
parent 9992b226ea
commit 81749ee18e
8 changed files with 124 additions and 30 deletions

View File

@@ -66,6 +66,15 @@ class StandardizedCNN(nn.Module):
# Output processing layers
self.output_processor = self._build_output_processor()
# Optional numeric return head (predicts percent change for 1s,1m,1h,1d)
# Uses cnn_features (1024) to regress predicted returns per timeframe
self.return_head = nn.Sequential(
nn.Linear(1024, 256),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, 4) # [return_1s, return_1m, return_1h, return_1d]
)
# Device management
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.to(self.device)
@@ -175,6 +184,9 @@ class StandardizedCNN(nn.Module):
# Process outputs for standardized format
action_probs = self.output_processor(cnn_features) # [batch, 3]
# Predict numeric returns per timeframe from cnn_features
predicted_returns = self.return_head(cnn_features) # [batch, 4]
# Prepare hidden states for cross-model feeding
hidden_states = {
'processed_features': processed_features.detach(),
@@ -186,7 +198,7 @@ class StandardizedCNN(nn.Module):
'attention_weights': torch.ones(batch_size, 1, device=x.device) # Placeholder
}
return action_probs, hidden_states
return action_probs, hidden_states, predicted_returns.detach()
def predict_from_base_input(self, base_input: BaseDataInput) -> ModelOutput:
"""
@@ -210,7 +222,7 @@ class StandardizedCNN(nn.Module):
with torch.no_grad():
# Forward pass
action_probs, hidden_states = self.forward(input_tensor)
action_probs, hidden_states, predicted_returns = self.forward(input_tensor)
# Get action and confidence
action_probs_np = action_probs.squeeze(0).cpu().numpy()
@@ -233,6 +245,19 @@ class StandardizedCNN(nn.Module):
'market_conditions': self._interpret_advanced_predictions(hidden_states.get('advanced_predictions'))
}
# Add numeric predicted returns per timeframe if available
try:
pr = predicted_returns.squeeze(0).cpu().numpy().tolist()
# Ensure length 4; if not, safely handle
if isinstance(pr, list) and len(pr) >= 4:
predictions['predicted_returns'] = pr[:4]
predictions['predicted_return_1s'] = float(pr[0])
predictions['predicted_return_1m'] = float(pr[1])
predictions['predicted_return_1h'] = float(pr[2])
predictions['predicted_return_1d'] = float(pr[3])
except Exception:
pass
# Prepare hidden states for cross-model feeding (convert tensors to numpy)
cross_model_states = {}
for key, tensor in hidden_states.items():