training wip
@@ -200,7 +200,11 @@ class DQNNetwork(nn.Module):
         """
         # Convert state to tensor if needed
         if isinstance(state, np.ndarray):
-            state = torch.FloatTensor(state).to(next(self.parameters()).device)
+            state = torch.FloatTensor(state)
+
+        # Move to device
+        device = next(self.parameters()).device
+        state = state.to(device)
 
         # Ensure proper shape
         if state.dim() == 1:
@@ -209,9 +213,8 @@ class DQNNetwork(nn.Module):
         with torch.no_grad():
             q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
 
-        # Process price direction predictions
-        if price_direction_pred is not None:
-            self.process_price_direction_predictions(price_direction_pred)
+        # Price direction predictions are processed in the agent's act method
+        # This is just the network forward pass
 
         # Get action probabilities using softmax
        action_probs = F.softmax(q_values, dim=1)
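Note: the softmax-confidence pattern used here (and again in act_with_confidence below) turns raw Q-values into a probability distribution and treats the probability of the greedy action as a confidence score. A minimal, self-contained sketch with placeholder Q-values (not the repository's DQNNetwork):

import torch
import torch.nn.functional as F

# Placeholder Q-values: a batch of one state, three actions (BUY, SELL, HOLD)
q_values = torch.tensor([[1.2, -0.3, 0.4]])

# Softmax converts Q-values into action probabilities
action_probs = F.softmax(q_values, dim=1)

# Greedy action; its probability serves as the confidence score
action_idx = torch.argmax(q_values, dim=1).item()
confidence = float(action_probs[0, action_idx].item())
print(action_idx, confidence)  # 0, approximately 0.60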
@@ -234,7 +237,7 @@ class DQNAgent:
     """
     def __init__(self,
                  state_shape: Tuple[int, ...],
-                 n_actions: int = 2,
+                 n_actions: int = 3, # BUY=0, SELL=1, HOLD=2
                  learning_rate: float = 0.001,
                  epsilon: float = 1.0,
                  epsilon_min: float = 0.01,
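The n_actions default changes from 2 to 3 to add an explicit HOLD action alongside BUY and SELL. A hypothetical enum (not part of this commit) that makes the index convention explicit:

from enum import IntEnum

class TradeAction(IntEnum):
    BUY = 0
    SELL = 1
    HOLD = 2

assert len(TradeAction) == 3  # consistent with n_actions=3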
@@ -761,6 +764,13 @@ class DQNAgent:
         # Use the DQNNetwork's act method for consistent behavior
         action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)
 
+        # Process price direction predictions from the network
+        # Get the raw predictions from the network's forward pass
+        with torch.no_grad():
+            q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state)
+        if price_direction_pred is not None:
+            self.process_price_direction_predictions(price_direction_pred)
+
         # Apply epsilon-greedy exploration if requested
         if explore and np.random.random() <= self.epsilon:
             action_idx = np.random.choice(self.n_actions)
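The exploration branch above is standard epsilon-greedy: with probability epsilon, a uniformly random action replaces the greedy one. A minimal sketch of that policy in isolation (names are illustrative, not the agent's API):

import numpy as np

def epsilon_greedy(greedy_action: int, epsilon: float, n_actions: int) -> int:
    # Explore uniformly with probability epsilon, otherwise exploit
    if np.random.random() <= epsilon:
        return int(np.random.choice(n_actions))
    return greedy_action

# epsilon typically decays from 1.0 toward epsilon_min over training
action = epsilon_greedy(greedy_action=2, epsilon=0.05, n_actions=3)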
@@ -780,15 +790,44 @@ class DQNAgent:
     def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
         """Choose action with confidence score adapted to market regime"""
         try:
-            # Use the DQNNetwork's act method which handles the state properly
-            action_idx, base_confidence, action_probs = self.policy_net.act(state, explore=False)
+            # Convert state to tensor if needed
+            if isinstance(state, np.ndarray):
+                state_tensor = torch.FloatTensor(state)
+                device = next(self.policy_net.parameters()).device
+                state_tensor = state_tensor.to(device)
+
+                # Ensure proper shape
+                if state_tensor.dim() == 1:
+                    state_tensor = state_tensor.unsqueeze(0)
+            else:
+                state_tensor = state
+
+            # Get network outputs
+            with torch.no_grad():
+                q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state_tensor)
+
+            # Process price direction predictions
+            if price_direction_pred is not None:
+                self.process_price_direction_predictions(price_direction_pred)
+
+            # Get action probabilities using softmax
+            action_probs = F.softmax(q_values, dim=1)
+
+            # Select action (greedy for inference)
+            action_idx = torch.argmax(q_values, dim=1).item()
+
+            # Calculate confidence as max probability
+            base_confidence = float(action_probs[0, action_idx].item())
 
             # Adapt confidence based on market regime
             regime_weight = self.market_regime_weights.get(market_regime, 1.0)
             adapted_confidence = min(base_confidence * regime_weight, 1.0)
 
+            # Convert probabilities to list
+            probs_list = action_probs.squeeze(0).cpu().numpy().tolist()
+
             # Return action, confidence, and probabilities (for orchestrator compatibility)
-            return int(action_idx), float(adapted_confidence), action_probs
+            return int(action_idx), float(adapted_confidence), probs_list
 
         except Exception as e:
             logger.error(f"Error in act_with_confidence: {e}")
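The regime adaptation at the end of act_with_confidence is a clamped multiplication: the base confidence is scaled by a per-regime weight and capped at 1.0. A standalone sketch, assuming illustrative weight values (the actual market_regime_weights contents are not shown in this diff):

def adapt_confidence(base_confidence: float, market_regime: str) -> float:
    # Assumed weights for illustration; unknown regimes fall back to a neutral 1.0
    market_regime_weights = {'trending': 1.0, 'ranging': 0.8, 'volatile': 0.6}
    regime_weight = market_regime_weights.get(market_regime, 1.0)
    return min(base_confidence * regime_weight, 1.0)

print(adapt_confidence(0.75, 'volatile'))  # approximately 0.45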