fix DQN RL inference, rebuild model

This commit is contained in:
Dobromir Popov
2025-07-26 23:57:03 +03:00
parent 87942d3807
commit 36a8e256a8
2 changed files with 201 additions and 98 deletions

View File

@ -21,6 +21,112 @@ from utils.training_integration import get_training_integration
# Configure logger # Configure logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DQNNetwork(nn.Module):
"""
Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
Handles 7850 input features from multi-timeframe, multi-asset data
"""
def __init__(self, input_dim: int, n_actions: int):
super(DQNNetwork, self).__init__()
# Handle different input dimension formats
if isinstance(input_dim, (tuple, list)):
if len(input_dim) == 1:
self.input_size = input_dim[0]
else:
self.input_size = np.prod(input_dim) # Flatten multi-dimensional input
else:
self.input_size = input_dim
self.n_actions = n_actions
# Deep network architecture optimized for trading features
self.network = nn.Sequential(
# Input layer
nn.Linear(self.input_size, 2048),
nn.ReLU(),
nn.Dropout(0.3),
# Hidden layers with residual-like connections
nn.Linear(2048, 1024),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.2),
# Output layer for Q-values
nn.Linear(128, n_actions)
)
# Initialize weights
self._initialize_weights()
def _initialize_weights(self):
"""Initialize network weights using Xavier initialization"""
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
def forward(self, x):
"""Forward pass through the network"""
# Ensure input is properly shaped
if x.dim() > 2:
x = x.view(x.size(0), -1) # Flatten if needed
elif x.dim() == 1:
x = x.unsqueeze(0) # Add batch dimension if needed
return self.network(x)
def act(self, state, explore=True):
"""
Select action using epsilon-greedy policy
Args:
state: Current state (numpy array or tensor)
explore: Whether to use epsilon-greedy exploration
Returns:
action_idx: Selected action index
confidence: Confidence score
action_probs: Action probabilities
"""
# Convert state to tensor if needed
if isinstance(state, np.ndarray):
state = torch.FloatTensor(state).to(next(self.parameters()).device)
# Ensure proper shape
if state.dim() == 1:
state = state.unsqueeze(0)
with torch.no_grad():
q_values = self.forward(state)
# Get action probabilities using softmax
action_probs = F.softmax(q_values, dim=1)
# Select action (greedy for inference)
action_idx = torch.argmax(q_values, dim=1).item()
# Calculate confidence as max probability
confidence = float(action_probs[0, action_idx].item())
# Convert probabilities to list
probs_list = action_probs.squeeze(0).cpu().numpy().tolist()
return action_idx, confidence, probs_list
class DQNAgent: class DQNAgent:
""" """
Deep Q-Network agent for trading Deep Q-Network agent for trading
@ -80,12 +186,9 @@ class DQNAgent:
else: else:
self.device = device self.device = device
# Initialize models with Enhanced CNN architecture for better performance # Initialize models with RL-specific network architecture
from NN.models.enhanced_cnn import EnhancedCNN self.policy_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
self.target_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
# Use Enhanced CNN for both policy and target networks
self.policy_net = EnhancedCNN(self.state_dim, self.n_actions)
self.target_net = EnhancedCNN(self.state_dim, self.n_actions)
# Initialize the target network with the same weights as the policy network # Initialize the target network with the same weights as the policy network
self.target_net.load_state_dict(self.policy_net.state_dict()) self.target_net.load_state_dict(self.policy_net.state_dict())
@ -578,83 +681,45 @@ class DQNAgent:
market_context: Additional market context for decision making market_context: Additional market context for decision making
Returns: Returns:
int: Action (0=BUY, 1=SELL, 2=HOLD) or None if should hold position int: Action (0=BUY, 1=SELL)
""" """
try:
# Convert state to tensor # Use the DQNNetwork's act method for consistent behavior
if isinstance(state, np.ndarray): action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
else: # Apply epsilon-greedy exploration if requested
state_tensor = state.unsqueeze(0).to(self.device) if explore and np.random.random() <= self.epsilon:
action_idx = np.random.choice(self.n_actions)
# Get Q-values
policy_output = self.policy_net(state_tensor) # Update tracking
if isinstance(policy_output, dict): if current_price:
q_values = policy_output.get('q_values', policy_output.get('Q_values', list(policy_output.values())[0])) self.recent_prices.append(current_price)
elif isinstance(policy_output, tuple):
q_values = policy_output[0] # Assume first element is Q-values self.recent_actions.append(action_idx)
else: return action_idx
q_values = policy_output
action_values = q_values.cpu().data.numpy()[0] except Exception as e:
logger.error(f"Error in act method: {e}")
# Calculate confidence scores # Return default action (HOLD/SELL)
# Ensure q_values has correct shape for softmax
if q_values.dim() == 1:
q_values = q_values.unsqueeze(0)
# FIXED ACTION MAPPING: 0=BUY, 1=SELL, 2=HOLD
buy_confidence = torch.softmax(q_values, dim=1)[0, 0].item()
sell_confidence = torch.softmax(q_values, dim=1)[0, 1].item()
# Determine action based on current position and confidence thresholds
action = self._determine_action_with_position_management(
sell_confidence, buy_confidence, current_price, market_context, explore
)
# Update tracking
if current_price:
self.recent_prices.append(current_price)
if action is not None:
self.recent_actions.append(action)
return action
else:
# Return 1 (HOLD) as a safe default if action is None
return 1 return 1
def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]: def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
"""Choose action with confidence score adapted to market regime (from Enhanced DQN)""" """Choose action with confidence score adapted to market regime"""
with torch.no_grad(): try:
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device) # Use the DQNNetwork's act method which handles the state properly
q_values = self.policy_net(state_tensor) action_idx, base_confidence, action_probs = self.policy_net.act(state, explore=False)
# Handle case where network might return a tuple instead of tensor
if isinstance(q_values, tuple):
# If it's a tuple, take the first element (usually the main output)
q_values = q_values[0]
# Ensure q_values is a tensor and has correct shape for softmax
if not hasattr(q_values, 'dim'):
logger.error(f"DQN: q_values is not a tensor: {type(q_values)}")
# Return default action with low confidence
return 1, 0.1 # Default to HOLD action
if q_values.dim() == 1:
q_values = q_values.unsqueeze(0)
# Convert Q-values to probabilities
action_probs = torch.softmax(q_values, dim=1)
action = q_values.argmax().item()
base_confidence = action_probs[0, action].item()
# Adapt confidence based on market regime # Adapt confidence based on market regime
regime_weight = self.market_regime_weights.get(market_regime, 1.0) regime_weight = self.market_regime_weights.get(market_regime, 1.0)
adapted_confidence = min(base_confidence * regime_weight, 1.0) adapted_confidence = min(base_confidence * regime_weight, 1.0)
# Always return int, float # Return action, confidence, and probabilities (for orchestrator compatibility)
if action is None: return int(action_idx), float(adapted_confidence), action_probs
return 1, 0.1
return int(action), float(adapted_confidence) except Exception as e:
logger.error(f"Error in act_with_confidence: {e}")
# Return default action with low confidence
return 1, 0.1, [0.45, 0.55] # Default to HOLD action
def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore): def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
""" """

View File

@ -283,9 +283,22 @@ class TradingOrchestrator:
# Initialize DQN Agent # Initialize DQN Agent
try: try:
from NN.models.dqn_agent import DQNAgent from NN.models.dqn_agent import DQNAgent
state_size = self.config.rl.get('state_size', 13800) # Enhanced with COB features
# Determine actual state size from BaseDataInput
try:
base_data = self.data_provider.build_base_data_input(self.symbol)
if base_data:
actual_state_size = len(base_data.get_feature_vector())
logger.info(f"Detected actual state size: {actual_state_size}")
else:
actual_state_size = 7850 # Fallback based on error message
logger.warning(f"Could not determine state size, using fallback: {actual_state_size}")
except Exception as e:
actual_state_size = 7850 # Fallback based on error message
logger.warning(f"Error determining state size: {e}, using fallback: {actual_state_size}")
action_size = self.config.rl.get('action_space', 3) action_size = self.config.rl.get('action_space', 3)
self.rl_agent = DQNAgent(state_shape=state_size, n_actions=action_size) self.rl_agent = DQNAgent(state_shape=actual_state_size, n_actions=action_size)
self.rl_agent.to(self.device) # Move DQN agent to the determined device self.rl_agent.to(self.device) # Move DQN agent to the determined device
# Load best checkpoint and capture initial state (using database metadata) # Load best checkpoint and capture initial state (using database metadata)
@ -306,7 +319,10 @@ class TradingOrchestrator:
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}" loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})") logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
except Exception as e: except Exception as e:
logger.warning(f"Error loading DQN checkpoint: {e}") logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
logger.info("DQN will start fresh due to checkpoint incompatibility")
# Reset the agent to handle dimension mismatch
checkpoint_loaded = False
if not checkpoint_loaded: if not checkpoint_loaded:
# New model - no synthetic data, start fresh # New model - no synthetic data, start fresh
@ -1144,7 +1160,10 @@ class TradingOrchestrator:
# Collect input data for all models # Collect input data for all models
input_data = await self._collect_model_input_data(symbol) input_data = await self._collect_model_input_data(symbol)
# log all registered models
logger.debug(f"inferencing registered models: {self.model_registry.models}")
for model_name, model in self.model_registry.models.items(): for model_name, model in self.model_registry.models.items():
try: try:
prediction = None prediction = None
@ -2058,8 +2077,17 @@ class TradingOrchestrator:
) )
predictions.append(prediction) predictions.append(prediction)
# Store prediction in queue for future use # Store prediction in SQLite database for training
self.update_data_queue('model_predictions', symbol, result) logger.debug(f"Added CNN prediction to database: {prediction}")
# Store CNN prediction as inference record
await self._store_inference_data_async(
model_name="enhanced_cnn",
model_input=base_data,
prediction=prediction,
timestamp=datetime.now(),
symbol=symbol
)
except Exception as e: except Exception as e:
logger.error(f"Error using CNN adapter: {e}") logger.error(f"Error using CNN adapter: {e}")
@ -2111,6 +2139,15 @@ class TradingOrchestrator:
) )
predictions.append(pred) predictions.append(pred)
# Store CNN fallback prediction as inference record
await self._store_inference_data_async(
model_name=model.name,
model_input=base_data,
prediction=pred,
timestamp=datetime.now(),
symbol=symbol
)
# Capture for dashboard # Capture for dashboard
current_price = self._get_current_price(symbol) current_price = self._get_current_price(symbol)
if current_price is not None: if current_price is not None:
@ -2188,12 +2225,22 @@ class TradingOrchestrator:
elif raw_q_values is not None and isinstance(raw_q_values, list): elif raw_q_values is not None and isinstance(raw_q_values, list):
q_values_for_capture = raw_q_values q_values_for_capture = raw_q_values
# Create prediction object # Create prediction object with safe probability calculation
probabilities = {}
if q_values_for_capture and len(q_values_for_capture) == len(action_names):
# Use actual q_values if they match the expected length
probabilities = {action_names[i]: float(q_values_for_capture[i]) for i in range(len(action_names))}
else:
# Use default uniform probabilities if q_values are unavailable or mismatched
default_prob = 1.0 / len(action_names)
probabilities = {name: default_prob for name in action_names}
if q_values_for_capture:
logger.warning(f"Q-values length mismatch: expected {len(action_names)}, got {len(q_values_for_capture)}. Using default probabilities.")
prediction = Prediction( prediction = Prediction(
action=action, action=action,
confidence=float(confidence), confidence=float(confidence),
# Use actual q_values if available, otherwise default probabilities probabilities=probabilities,
probabilities={action_names[i]: float(q_values_for_capture[i]) if q_values_for_capture else (1.0 / len(action_names)) for i in range(len(action_names))},
timeframe='mixed', # RL uses mixed timeframes timeframe='mixed', # RL uses mixed timeframes
timestamp=datetime.now(), timestamp=datetime.now(),
model_name=model.name, model_name=model.name,
@ -2279,7 +2326,7 @@ class TradingOrchestrator:
return None return None
def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]: def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
"""Get current state for RL agent - ensure compatibility with saved model""" """Get current state for RL agent using unified BaseDataInput"""
try: try:
# Use unified BaseDataInput approach # Use unified BaseDataInput approach
base_data = self.build_base_data_input(symbol) base_data = self.build_base_data_input(symbol)
@ -2287,21 +2334,12 @@ class TradingOrchestrator:
logger.warning(f"Cannot build BaseDataInput for RL state: {symbol}") logger.warning(f"Cannot build BaseDataInput for RL state: {symbol}")
return None return None
# Get unified feature vector # Get unified feature vector (7850 features including all timeframes and COB data)
feature_vector = base_data.get_feature_vector() feature_vector = base_data.get_feature_vector()
# Ensure compatibility with saved model (expects 403 features) # Return the full unified feature vector for RL agent
target_size = 403 # Match the saved model's expected input size # The DQN agent is now initialized with the correct size to match this
if len(feature_vector) < target_size: return feature_vector
# Pad with zeros
padded_state = np.zeros(target_size)
padded_state[:len(feature_vector)] = feature_vector
return padded_state
elif len(feature_vector) > target_size:
# Truncate to target size
return feature_vector[:target_size]
else:
return feature_vector
except Exception as e: except Exception as e:
logger.error(f"Error creating RL state for {symbol}: {e}") logger.error(f"Error creating RL state for {symbol}: {e}")