training wip

Dobromir Popov
2025-07-27 23:45:57 +03:00
parent 39267697f3
commit b4076241c9
4 changed files with 283 additions and 66 deletions


@@ -200,7 +200,11 @@ class DQNNetwork(nn.Module):
         """
         # Convert state to tensor if needed
         if isinstance(state, np.ndarray):
-            state = torch.FloatTensor(state).to(next(self.parameters()).device)
+            state = torch.FloatTensor(state)
+
+        # Move to device
+        device = next(self.parameters()).device
+        state = state.to(device)
 
         # Ensure proper shape
         if state.dim() == 1:
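For reference, the conversion pattern this hunk splits out, as a standalone sketch (the helper name is illustrative; the real code lives inline in DQNNetwork):

    import numpy as np
    import torch
    import torch.nn as nn

    def prepare_state(state, model: nn.Module) -> torch.Tensor:
        # Illustrative helper mirroring the hunk: ndarray -> float tensor,
        # moved to the model's device, with a batch dimension ensured.
        if isinstance(state, np.ndarray):
            state = torch.FloatTensor(state)
        device = next(model.parameters()).device
        state = state.to(device)
        if state.dim() == 1:
            state = state.unsqueeze(0)
        return state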
@@ -209,9 +213,8 @@ class DQNNetwork(nn.Module):
         with torch.no_grad():
             q_values, regime_pred, price_direction_pred, volatility_pred, features = self.forward(state)
 
-            # Process price direction predictions
-            if price_direction_pred is not None:
-                self.process_price_direction_predictions(price_direction_pred)
+            # Price direction predictions are processed in the agent's act method
+            # This is just the network forward pass
 
             # Get action probabilities using softmax
             action_probs = F.softmax(q_values, dim=1)
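The surrounding context derives the action and its confidence from the Q-values via softmax. A minimal sketch of that pattern; note that a softmax over Q-values is a heuristic confidence score, not a calibrated probability:

    import torch
    import torch.nn.functional as F

    def action_and_confidence(q_values: torch.Tensor):
        # q_values: [batch, n_actions]; confidence is the softmax probability
        # assigned to the greedy action.
        action_probs = F.softmax(q_values, dim=1)
        action_idx = int(torch.argmax(q_values, dim=1).item())
        confidence = float(action_probs[0, action_idx].item())
        return action_idx, confidence, action_probs

    # action_and_confidence(torch.tensor([[1.2, 0.3, -0.5]])) -> (0, ~0.63, ...)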
@@ -234,7 +237,7 @@ class DQNAgent:
     """
     def __init__(self,
                  state_shape: Tuple[int, ...],
-                 n_actions: int = 2,
+                 n_actions: int = 3,  # BUY=0, SELL=1, HOLD=2
                  learning_rate: float = 0.001,
                  epsilon: float = 1.0,
                  epsilon_min: float = 0.01,
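The new default encodes a three-action space. A hypothetical IntEnum that makes the mapping from the inline comment explicit (the enum itself is not part of the commit):

    from enum import IntEnum

    class TradeAction(IntEnum):
        BUY = 0
        SELL = 1
        HOLD = 2

    # n_actions = len(TradeAction) keeps the agent in sync with the mapping.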
@@ -761,6 +764,13 @@ class DQNAgent:
         # Use the DQNNetwork's act method for consistent behavior
         action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)
 
+        # Process price direction predictions from the network
+        # Get the raw predictions from the network's forward pass
+        with torch.no_grad():
+            q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state)
+            if price_direction_pred is not None:
+                self.process_price_direction_predictions(price_direction_pred)
+
         # Apply epsilon-greedy exploration if requested
         if explore and np.random.random() <= self.epsilon:
             action_idx = np.random.choice(self.n_actions)
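The exploration branch above is standard epsilon-greedy. A self-contained sketch; the decay step in the trailing comment uses illustrative values, not ones taken from this diff:

    import numpy as np

    def epsilon_greedy(greedy_action: int, n_actions: int, epsilon: float) -> int:
        # With probability epsilon, override the greedy choice with a uniform
        # random action; otherwise keep the network's pick.
        if np.random.random() <= epsilon:
            return int(np.random.choice(n_actions))
        return greedy_action

    # Typical per-step decay (illustrative):
    # epsilon = max(epsilon_min, epsilon * 0.995)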
@@ -780,15 +790,44 @@ class DQNAgent:
     def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
         """Choose action with confidence score adapted to market regime"""
         try:
-            # Use the DQNNetwork's act method which handles the state properly
-            action_idx, base_confidence, action_probs = self.policy_net.act(state, explore=False)
+            # Convert state to tensor if needed
+            if isinstance(state, np.ndarray):
+                state_tensor = torch.FloatTensor(state)
+                device = next(self.policy_net.parameters()).device
+                state_tensor = state_tensor.to(device)
+
+                # Ensure proper shape
+                if state_tensor.dim() == 1:
+                    state_tensor = state_tensor.unsqueeze(0)
+            else:
+                state_tensor = state
+
+            # Get network outputs
+            with torch.no_grad():
+                q_values, regime_pred, price_direction_pred, volatility_pred, features = self.policy_net.forward(state_tensor)
+
+                # Process price direction predictions
+                if price_direction_pred is not None:
+                    self.process_price_direction_predictions(price_direction_pred)
+
+            # Get action probabilities using softmax
+            action_probs = F.softmax(q_values, dim=1)
+
+            # Select action (greedy for inference)
+            action_idx = torch.argmax(q_values, dim=1).item()
+
+            # Calculate confidence as max probability
+            base_confidence = float(action_probs[0, action_idx].item())
 
             # Adapt confidence based on market regime
             regime_weight = self.market_regime_weights.get(market_regime, 1.0)
             adapted_confidence = min(base_confidence * regime_weight, 1.0)
 
+            # Convert probabilities to list
+            probs_list = action_probs.squeeze(0).cpu().numpy().tolist()
+
             # Return action, confidence, and probabilities (for orchestrator compatibility)
-            return int(action_idx), float(adapted_confidence), action_probs
+            return int(action_idx), float(adapted_confidence), probs_list
 
         except Exception as e:
             logger.error(f"Error in act_with_confidence: {e}")


@@ -80,6 +80,9 @@ class EnhancedCNN(nn.Module):
         self.n_actions = n_actions
         self.confidence_threshold = confidence_threshold
 
+        # Training data storage
+        self.training_data = []
+
         # Calculate input dimensions
         if isinstance(input_shape, (list, tuple)):
             if len(input_shape) == 3:  # [channels, height, width]
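The commit stores samples in a plain list and trims it manually in add_training_data (next hunk). An alternative sketch using collections.deque, which evicts old samples automatically; this is not what the commit does:

    from collections import deque

    class BoundedBufferSketch:
        def __init__(self, buffer_size: int = 1000):
            # maxlen drops the oldest entry on append, so no manual
            # re-slicing is needed to bound the buffer.
            self.training_data = deque(maxlen=buffer_size)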
@@ -648,6 +651,30 @@ class EnhancedCNN(nn.Module):
             'strength': 0.0,
             'weighted_strength': 0.0
         }
 
+    def add_training_data(self, state, action, reward):
+        """
+        Add training data to the model's training buffer
+
+        Args:
+            state: Input state
+            action: Action taken
+            reward: Reward received
+        """
+        try:
+            self.training_data.append({
+                'state': state,
+                'action': action,
+                'reward': reward,
+                'timestamp': time.time()
+            })
+
+            # Keep only the last 1000 training samples
+            if len(self.training_data) > 1000:
+                self.training_data = self.training_data[-1000:]
+
+        except Exception as e:
+            logger.error(f"Error adding training data: {e}")
+
     def save(self, path):
         """Save model weights and architecture"""