cleanup, cob ladder still broken
This commit is contained in:
@ -113,6 +113,15 @@ class DQNAgent:
|
||||
# Initialize avg_reward for dashboard compatibility
|
||||
self.avg_reward = 0.0 # Average reward tracking for dashboard
|
||||
|
||||
# Market regime adaptation weights
|
||||
self.market_regime_weights = {
|
||||
'trending': 1.0,
|
||||
'sideways': 0.8,
|
||||
'volatile': 1.2,
|
||||
'bullish': 1.1,
|
||||
'bearish': 1.1
|
||||
}
|
||||
|
||||
# Load best checkpoint if available
|
||||
if self.enable_checkpoints:
|
||||
self.load_best_checkpoint()
|
||||
@ -490,7 +499,17 @@ class DQNAgent:
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||
q_values = self.policy_net(state_tensor)
|
||||
|
||||
# Ensure q_values has correct shape for softmax
|
||||
# Handle case where network might return a tuple instead of tensor
|
||||
if isinstance(q_values, tuple):
|
||||
# If it's a tuple, take the first element (usually the main output)
|
||||
q_values = q_values[0]
|
||||
|
||||
# Ensure q_values is a tensor and has correct shape for softmax
|
||||
if not hasattr(q_values, 'dim'):
|
||||
logger.error(f"DQN: q_values is not a tensor: {type(q_values)}")
|
||||
# Return default action with low confidence
|
||||
return 1, 0.1 # Default to HOLD action
|
||||
|
||||
if q_values.dim() == 1:
|
||||
q_values = q_values.unsqueeze(0)
|
||||
|
||||
|
Reference in New Issue
Block a user