training wip
@@ -200,7 +200,11 @@ class DQNNetwork(nn.Module):
         """
         # Convert state to tensor if needed
         if isinstance(state, np.ndarray):
-            state = torch.FloatTensor(state).to(next(self.parameters()).device)
+            state = torch.FloatTensor(state)
+
+        # Move to device
+        device = next(self.parameters()).device
+        state = state.to(device)
 
         # Ensure proper shape
         if state.dim() == 1:
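The hunk above splits tensor conversion and device placement into separate steps, so tensor inputs also get moved to the model's device. A minimal, self-contained sketch of the same pattern (the helper name is illustrative, not taken from the repo):

import numpy as np
import torch
import torch.nn as nn

def prepare_state(model: nn.Module, state) -> torch.Tensor:
    # Convert numpy input to a float tensor
    if isinstance(state, np.ndarray):
        state = torch.FloatTensor(state)
    # Move to the same device as the model parameters
    device = next(model.parameters()).device
    state = state.to(device)
    # Add a batch dimension for a single sample
    if state.dim() == 1:
        state = state.unsqueeze(0)
    return state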
@@ -209,9 +213,8 @@ class DQNNetwork(nn.Module):
         with torch.no_grad():
             q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
 
-            # Process price direction predictions
-            if price_direction_pred is not None:
-                self.process_price_direction_predictions(price_direction_pred)
+            # Price direction predictions are processed in the agent's act method
+            # This is just the network forward pass
 
             # Get action probabilities using softmax
             action_probs = F.softmax(q_values, dim=1)
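For context, the confidence the network reports appears to be the softmax probability of the chosen action over the Q-values. A minimal sketch of that computation, assuming a [1, n_actions] q_values tensor (the example values are made up):

import torch
import torch.nn.functional as F

q_values = torch.tensor([[0.2, 1.5, -0.3]])    # example Q-values for one state
action_probs = F.softmax(q_values, dim=1)      # turn Q-values into a probability distribution
action_idx = torch.argmax(q_values, dim=1).item()
confidence = float(action_probs[0, action_idx].item())  # max probability as confidence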
@@ -234,7 +237,7 @@ class DQNAgent:
     """
     def __init__(self,
                  state_shape: Tuple[int, ...],
-                 n_actions: int = 2,
+                 n_actions: int = 3,  # BUY=0, SELL=1, HOLD=2
                  learning_rate: float = 0.001,
                  epsilon: float = 1.0,
                  epsilon_min: float = 0.01,
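The default action space grows from 2 to 3 actions here. A hypothetical enum (not present in the diff) that pins the BUY=0, SELL=1, HOLD=2 convention in one place and keeps n_actions consistent with it:

from enum import IntEnum

class TradeAction(IntEnum):
    BUY = 0
    SELL = 1
    HOLD = 2

# n_actions then follows directly from the enum
n_actions = len(TradeAction)  # 3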
@@ -761,6 +764,13 @@ class DQNAgent:
         # Use the DQNNetwork's act method for consistent behavior
         action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)
 
+        # Process price direction predictions from the network
+        # Get the raw predictions from the network's forward pass
+        with torch.no_grad():
+            q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state)
+            if price_direction_pred is not None:
+                self.process_price_direction_predictions(price_direction_pred)
+
         # Apply epsilon-greedy exploration if requested
         if explore and np.random.random() <= self.epsilon:
             action_idx = np.random.choice(self.n_actions)
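The added block re-runs the forward pass to extract the price-direction head, and the existing epsilon-greedy branch below can still overwrite the greedy action. A minimal sketch of the epsilon-greedy step in isolation, under the assumption that epsilon is managed elsewhere (the decay schedule is not shown in this hunk):

import numpy as np

def epsilon_greedy(q_action: int, n_actions: int, epsilon: float) -> int:
    # With probability epsilon, explore with a random action;
    # otherwise exploit the greedy (Q-maximizing) action.
    if np.random.random() <= epsilon:
        return int(np.random.choice(n_actions))
    return q_action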
@@ -780,15 +790,44 @@ class DQNAgent:
     def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
         """Choose action with confidence score adapted to market regime"""
         try:
-            # Use the DQNNetwork's act method which handles the state properly
-            action_idx, base_confidence, action_probs = self.policy_net.act(state, explore=False)
+            # Convert state to tensor if needed
+            if isinstance(state, np.ndarray):
+                state_tensor = torch.FloatTensor(state)
+                device = next(self.policy_net.parameters()).device
+                state_tensor = state_tensor.to(device)
+
+                # Ensure proper shape
+                if state_tensor.dim() == 1:
+                    state_tensor = state_tensor.unsqueeze(0)
+            else:
+                state_tensor = state
+
+            # Get network outputs
+            with torch.no_grad():
+                q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state_tensor)
+
+            # Process price direction predictions
+            if price_direction_pred is not None:
+                self.process_price_direction_predictions(price_direction_pred)
+
+            # Get action probabilities using softmax
+            action_probs = F.softmax(q_values, dim=1)
+
+            # Select action (greedy for inference)
+            action_idx = torch.argmax(q_values, dim=1).item()
+
+            # Calculate confidence as max probability
+            base_confidence = float(action_probs[0, action_idx].item())
 
             # Adapt confidence based on market regime
             regime_weight = self.market_regime_weights.get(market_regime, 1.0)
             adapted_confidence = min(base_confidence * regime_weight, 1.0)
 
+            # Convert probabilities to list
+            probs_list = action_probs.squeeze(0).cpu().numpy().tolist()
+
             # Return action, confidence, and probabilities (for orchestrator compatibility)
-            return int(action_idx), float(adapted_confidence), action_probs
+            return int(action_idx), float(adapted_confidence), probs_list
 
         except Exception as e:
             logger.error(f"Error in act_with_confidence: {e}")
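With this rewrite, callers receive a plain Python list of probabilities instead of a tensor. A sketch of how a caller might consume the return value; the threshold and the execute_action helper are assumptions for illustration, not part of the diff:

# state is a 1-D numpy feature vector prepared elsewhere
action_idx, confidence, probs = agent.act_with_confidence(state, market_regime='trending')

if confidence >= 0.6:             # example threshold, not from the diff
    execute_action(action_idx)    # hypothetical downstream call
else:
    pass                          # skip low-confidence signals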
@@ -80,6 +80,9 @@ class EnhancedCNN(nn.Module):
         self.n_actions = n_actions
         self.confidence_threshold = confidence_threshold
 
+        # Training data storage
+        self.training_data = []
+
         # Calculate input dimensions
         if isinstance(input_shape, (list, tuple)):
             if len(input_shape) == 3:  # [channels, height, width]
@@ -648,6 +651,30 @@ class EnhancedCNN(nn.Module):
             'strength': 0.0,
             'weighted_strength': 0.0
         }
 
+    def add_training_data(self, state, action, reward):
+        """
+        Add training data to the model's training buffer
+
+        Args:
+            state: Input state
+            action: Action taken
+            reward: Reward received
+        """
+        try:
+            self.training_data.append({
+                'state': state,
+                'action': action,
+                'reward': reward,
+                'timestamp': time.time()
+            })
+
+            # Keep only the last 1000 training samples
+            if len(self.training_data) > 1000:
+                self.training_data = self.training_data[-1000:]
+
+        except Exception as e:
+            logger.error(f"Error adding training data: {e}")
+
     def save(self, path):
         """Save model weights and architecture"""
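Note that add_training_data calls time.time(), so the module needs an import time at the top level (not visible in this hunk). A short usage sketch with placeholder values; the constructor arguments and the model variable are illustrative only:

import time  # required by add_training_data's timestamping
import numpy as np

state = np.zeros(50, dtype=np.float32)   # placeholder state vector
model.add_training_data(state=state, action=1, reward=0.02)  # model: an EnhancedCNN instance

# The buffer keeps only the most recent 1000 samples
print(len(model.training_data))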