fix DQN RL inference, rebuild model
@@ -21,6 +21,112 @@ from utils.training_integration import get_training_integration
# Configure logger
logger = logging.getLogger(__name__)

class DQNNetwork(nn.Module):
    """
    Deep Q-Network specifically designed for RL trading with unified BaseDataInput features.
    Handles 7850 input features from multi-timeframe, multi-asset data.
    """
    def __init__(self, input_dim: int, n_actions: int):
        super(DQNNetwork, self).__init__()

        # Handle different input dimension formats
        if isinstance(input_dim, (tuple, list)):
            if len(input_dim) == 1:
                self.input_size = input_dim[0]
            else:
                self.input_size = np.prod(input_dim)  # Flatten multi-dimensional input
        else:
            self.input_size = input_dim

        self.n_actions = n_actions

        # Deep network architecture optimized for trading features
        self.network = nn.Sequential(
            # Input layer
            nn.Linear(self.input_size, 2048),
            nn.ReLU(),
            nn.Dropout(0.3),

            # Hidden layers
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Output layer for Q-values
            nn.Linear(128, n_actions)
        )

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize network weights using Xavier initialization"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Forward pass through the network"""
        # Ensure input is properly shaped
        if x.dim() > 2:
            x = x.view(x.size(0), -1)  # Flatten if needed
        elif x.dim() == 1:
            x = x.unsqueeze(0)  # Add batch dimension if needed

        return self.network(x)

    def act(self, state, explore=True):
        """
        Select an action from the network's Q-values (greedy inference;
        epsilon-greedy exploration is applied by the owning agent).

        Args:
            state: Current state (numpy array or tensor)
            explore: Whether the caller intends to apply exploration

        Returns:
            action_idx: Selected action index
            confidence: Confidence score
            action_probs: Action probabilities
        """
        # Convert state to tensor if needed
        if isinstance(state, np.ndarray):
            state = torch.FloatTensor(state).to(next(self.parameters()).device)

        # Ensure proper shape
        if state.dim() == 1:
            state = state.unsqueeze(0)

        with torch.no_grad():
            q_values = self.forward(state)

            # Get action probabilities using softmax
            action_probs = F.softmax(q_values, dim=1)

            # Select action (greedy for inference)
            action_idx = torch.argmax(q_values, dim=1).item()

            # Calculate confidence as max probability
            confidence = float(action_probs[0, action_idx].item())

            # Convert probabilities to list
            probs_list = action_probs.squeeze(0).cpu().numpy().tolist()

        return action_idx, confidence, probs_list

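For reference, a minimal smoke-test sketch of the new network, assuming `DQNNetwork` is importable from `NN.models.dqn_agent` (the module this commit touches); the 7850-feature size comes from the docstring above, while `n_actions=3` and the random state are illustrative stand-ins for `BaseDataInput.get_feature_vector()`:

```python
import numpy as np

# Hedged sketch: module path assumed from the import of DQNAgent elsewhere in this diff.
from NN.models.dqn_agent import DQNNetwork

net = DQNNetwork(input_dim=7850, n_actions=3)
net.eval()

state = np.random.rand(7850).astype(np.float32)  # stand-in for BaseDataInput.get_feature_vector()
action_idx, confidence, action_probs = net.act(state, explore=False)
print(action_idx, round(confidence, 3), len(action_probs))  # e.g. 1 0.334 3
```
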
class DQNAgent:
    """
    Deep Q-Network agent for trading
@@ -80,12 +186,9 @@ class DQNAgent:
        else:
            self.device = device

        # Initialize models with Enhanced CNN architecture for better performance
        from NN.models.enhanced_cnn import EnhancedCNN

        # Use Enhanced CNN for both policy and target networks
        self.policy_net = EnhancedCNN(self.state_dim, self.n_actions)
        self.target_net = EnhancedCNN(self.state_dim, self.n_actions)
        # Initialize models with RL-specific network architecture
        self.policy_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
        self.target_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)

        # Initialize the target network with the same weights as the policy network
        self.target_net.load_state_dict(self.policy_net.state_dict())
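The hunk above only shows the initial copy of the policy weights into the target network; the re-sync schedule lives elsewhere in the agent. A minimal sketch of the conventional hard update, with `step_count` and `target_update_freq` as assumed, hypothetical attribute names:

```python
# Hedged sketch: periodic hard target-network update (not part of this commit).
# `step_count` and `target_update_freq` are hypothetical attribute names.
def maybe_sync_target(agent) -> None:
    if agent.step_count % agent.target_update_freq == 0:
        agent.target_net.load_state_dict(agent.policy_net.state_dict())
```
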
@ -578,83 +681,45 @@ class DQNAgent:
            market_context: Additional market context for decision making

        Returns:
            int: Action (0=BUY, 1=SELL, 2=HOLD) or None if should hold position
            int: Action (0=BUY, 1=SELL)
        """
        try:
            # Use the DQNNetwork's act method for consistent behavior
            action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)

            # Convert state to tensor
            if isinstance(state, np.ndarray):
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            else:
                state_tensor = state.unsqueeze(0).to(self.device)

            # Get Q-values
            policy_output = self.policy_net(state_tensor)
            if isinstance(policy_output, dict):
                q_values = policy_output.get('q_values', policy_output.get('Q_values', list(policy_output.values())[0]))
            elif isinstance(policy_output, tuple):
                q_values = policy_output[0]  # Assume first element is Q-values
            else:
                q_values = policy_output
            action_values = q_values.cpu().data.numpy()[0]

            # Calculate confidence scores
            # Ensure q_values has correct shape for softmax
            if q_values.dim() == 1:
                q_values = q_values.unsqueeze(0)

            # FIXED ACTION MAPPING: 0=BUY, 1=SELL, 2=HOLD
            buy_confidence = torch.softmax(q_values, dim=1)[0, 0].item()
            sell_confidence = torch.softmax(q_values, dim=1)[0, 1].item()

            # Determine action based on current position and confidence thresholds
            action = self._determine_action_with_position_management(
                sell_confidence, buy_confidence, current_price, market_context, explore
            )
            # Apply epsilon-greedy exploration if requested
            if explore and np.random.random() <= self.epsilon:
                action_idx = np.random.choice(self.n_actions)

            # Update tracking
            if current_price:
                self.recent_prices.append(current_price)

            if action is not None:
                self.recent_actions.append(action)
                return action
            else:
                # Fall back to the network's chosen action when position management returns None
                self.recent_actions.append(action_idx)
                return action_idx

        except Exception as e:
            logger.error(f"Error in act method: {e}")
            # Return default action (SELL) on error
            return 1

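The exploration branch above depends on `self.epsilon`, whose decay schedule is outside this hunk. A minimal sketch of a standard multiplicative decay, with `epsilon_min` and `epsilon_decay` as assumed, hypothetical attributes:

```python
# Hedged sketch: conventional epsilon decay after each training step
# (epsilon_min and epsilon_decay are not shown in this commit).
def decay_epsilon(agent) -> float:
    agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)
    return agent.epsilon

# With epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.05,
# epsilon reaches the floor after roughly 600 decay steps.
```
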
    def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
        """Choose action with confidence score adapted to market regime (from Enhanced DQN)"""
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.policy_net(state_tensor)

            # Handle case where network might return a tuple instead of tensor
            if isinstance(q_values, tuple):
                # If it's a tuple, take the first element (usually the main output)
                q_values = q_values[0]

            # Ensure q_values is a tensor and has correct shape for softmax
            if not hasattr(q_values, 'dim'):
                logger.error(f"DQN: q_values is not a tensor: {type(q_values)}")
                # Return default action with low confidence
                return 1, 0.1  # Default to SELL action

            if q_values.dim() == 1:
                q_values = q_values.unsqueeze(0)

            # Convert Q-values to probabilities
            action_probs = torch.softmax(q_values, dim=1)
            action = q_values.argmax().item()
            base_confidence = action_probs[0, action].item()
    def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
        """Choose action with confidence score adapted to market regime"""
        try:
            # Use the DQNNetwork's act method which handles the state properly
            action_idx, base_confidence, action_probs = self.policy_net.act(state, explore=False)

            # Adapt confidence based on market regime
            regime_weight = self.market_regime_weights.get(market_regime, 1.0)
            adapted_confidence = min(base_confidence * regime_weight, 1.0)

            # Always return int, float
            if action is None:
                return 1, 0.1
            return int(action), float(adapted_confidence)
            # Return action, confidence, and probabilities (for orchestrator compatibility)
            return int(action_idx), float(adapted_confidence), action_probs

        except Exception as e:
            logger.error(f"Error in act_with_confidence: {e}")
            # Return default action with low confidence
            return 1, 0.1, [0.45, 0.55]  # Default to SELL action

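`self.market_regime_weights` is referenced but not defined in this hunk; a plausible shape and the resulting confidence scaling, shown purely as an assumption:

```python
# Hedged sketch: hypothetical regime weights, not taken from this commit.
market_regime_weights = {'trending': 1.0, 'ranging': 0.8, 'volatile': 0.6}

base_confidence = 0.72
regime_weight = market_regime_weights.get('ranging', 1.0)
adapted_confidence = min(base_confidence * regime_weight, 1.0)  # 0.576
```
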
    def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
        """

@@ -283,9 +283,22 @@ class TradingOrchestrator:
        # Initialize DQN Agent
        try:
            from NN.models.dqn_agent import DQNAgent
            state_size = self.config.rl.get('state_size', 13800)  # Enhanced with COB features

            # Determine actual state size from BaseDataInput
            try:
                base_data = self.data_provider.build_base_data_input(self.symbol)
                if base_data:
                    actual_state_size = len(base_data.get_feature_vector())
                    logger.info(f"Detected actual state size: {actual_state_size}")
                else:
                    actual_state_size = 7850  # Fallback based on error message
                    logger.warning(f"Could not determine state size, using fallback: {actual_state_size}")
            except Exception as e:
                actual_state_size = 7850  # Fallback based on error message
                logger.warning(f"Error determining state size: {e}, using fallback: {actual_state_size}")

            action_size = self.config.rl.get('action_space', 3)
            self.rl_agent = DQNAgent(state_shape=state_size, n_actions=action_size)
            self.rl_agent = DQNAgent(state_shape=actual_state_size, n_actions=action_size)
            self.rl_agent.to(self.device)  # Move DQN agent to the determined device

            # Load best checkpoint and capture initial state (using database metadata)
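The detection-with-fallback pattern above can be read as a small helper; a sketch under the assumption that only `build_base_data_input()` and `get_feature_vector()` come from the surrounding code, while the helper itself is illustrative:

```python
# Hedged sketch: illustrative helper mirroring the logic above; not part of this commit.
def detect_state_size(data_provider, symbol: str, fallback: int = 7850) -> int:
    try:
        base_data = data_provider.build_base_data_input(symbol)
        if base_data:
            return len(base_data.get_feature_vector())
    except Exception:
        pass
    return fallback
```
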
@@ -306,7 +319,10 @@ class TradingOrchestrator:
                    loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
                    logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
            except Exception as e:
                logger.warning(f"Error loading DQN checkpoint: {e}")
                logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
                logger.info("DQN will start fresh due to checkpoint incompatibility")
                # Reset the agent to handle dimension mismatch
                checkpoint_loaded = False

            if not checkpoint_loaded:
                # New model - no synthetic data, start fresh
@ -1145,6 +1161,9 @@ class TradingOrchestrator:
        # Collect input data for all models
        input_data = await self._collect_model_input_data(symbol)

        # log all registered models
        logger.debug(f"inferencing registered models: {self.model_registry.models}")

        for model_name, model in self.model_registry.models.items():
            try:
                prediction = None
@ -2058,8 +2077,17 @@ class TradingOrchestrator:
                    )
                    predictions.append(prediction)

                    # Store prediction in queue for future use
                    self.update_data_queue('model_predictions', symbol, result)
                    # Store prediction in SQLite database for training
                    logger.debug(f"Added CNN prediction to database: {prediction}")

                    # Store CNN prediction as inference record
                    await self._store_inference_data_async(
                        model_name="enhanced_cnn",
                        model_input=base_data,
                        prediction=prediction,
                        timestamp=datetime.now(),
                        symbol=symbol
                    )

                except Exception as e:
                    logger.error(f"Error using CNN adapter: {e}")
@ -2111,6 +2139,15 @@ class TradingOrchestrator:
                    )
                    predictions.append(pred)

                    # Store CNN fallback prediction as inference record
                    await self._store_inference_data_async(
                        model_name=model.name,
                        model_input=base_data,
                        prediction=pred,
                        timestamp=datetime.now(),
                        symbol=symbol
                    )

                    # Capture for dashboard
                    current_price = self._get_current_price(symbol)
                    if current_price is not None:
@ -2188,12 +2225,22 @@ class TradingOrchestrator:
                elif raw_q_values is not None and isinstance(raw_q_values, list):
                    q_values_for_capture = raw_q_values

                # Create prediction object
                # Create prediction object with safe probability calculation
                probabilities = {}
                if q_values_for_capture and len(q_values_for_capture) == len(action_names):
                    # Use actual q_values if they match the expected length
                    probabilities = {action_names[i]: float(q_values_for_capture[i]) for i in range(len(action_names))}
                else:
                    # Use default uniform probabilities if q_values are unavailable or mismatched
                    default_prob = 1.0 / len(action_names)
                    probabilities = {name: default_prob for name in action_names}
                    if q_values_for_capture:
                        logger.warning(f"Q-values length mismatch: expected {len(action_names)}, got {len(q_values_for_capture)}. Using default probabilities.")

                prediction = Prediction(
                    action=action,
                    confidence=float(confidence),
                    # Use actual q_values if available, otherwise default probabilities
                    probabilities={action_names[i]: float(q_values_for_capture[i]) if q_values_for_capture else (1.0 / len(action_names)) for i in range(len(action_names))},
                    probabilities=probabilities,
                    timeframe='mixed',  # RL uses mixed timeframes
                    timestamp=datetime.now(),
                    model_name=model.name,
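The safe mapping introduced above can be pulled into a standalone helper for clarity; a minimal sketch of the same logic, with the mismatch branch falling back to a uniform distribution:

```python
from typing import Dict, List, Optional

# Hedged sketch: mirrors the safe-probability construction above; not part of this commit.
def safe_probabilities(q_values: Optional[List[float]], action_names: List[str]) -> Dict[str, float]:
    if q_values and len(q_values) == len(action_names):
        return {name: float(q) for name, q in zip(action_names, q_values)}
    return {name: 1.0 / len(action_names) for name in action_names}

# Mismatched lengths fall back to uniform probabilities:
print(safe_probabilities([0.2, 0.5], ['BUY', 'SELL', 'HOLD']))  # each action gets ~0.333
```
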
@ -2279,7 +2326,7 @@ class TradingOrchestrator:
            return None

    def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
        """Get current state for RL agent - ensure compatibility with saved model"""
        """Get current state for RL agent using unified BaseDataInput"""
        try:
            # Use unified BaseDataInput approach
            base_data = self.build_base_data_input(symbol)
@ -2287,20 +2334,11 @@ class TradingOrchestrator:
                logger.warning(f"Cannot build BaseDataInput for RL state: {symbol}")
                return None

            # Get unified feature vector
            # Get unified feature vector (7850 features including all timeframes and COB data)
            feature_vector = base_data.get_feature_vector()

            # Ensure compatibility with saved model (expects 403 features)
            target_size = 403  # Match the saved model's expected input size
            if len(feature_vector) < target_size:
                # Pad with zeros
                padded_state = np.zeros(target_size)
                padded_state[:len(feature_vector)] = feature_vector
                return padded_state
            elif len(feature_vector) > target_size:
                # Truncate to target size
                return feature_vector[:target_size]
            else:
                # Return the full unified feature vector for RL agent
                # The DQN agent is now initialized with the correct size to match this
                return feature_vector

        except Exception as e:
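The pad/truncate shim removed in this hunk can be summarized as a small helper; shown here only as a sketch for comparison with the new approach, which instead sizes the DQN to the full unified feature vector:

```python
import numpy as np

# Hedged sketch of the compatibility shim this commit removes: pad or truncate
# a feature vector to a fixed model input size.
def fit_to_size(feature_vector: np.ndarray, target_size: int) -> np.ndarray:
    if len(feature_vector) < target_size:
        padded = np.zeros(target_size, dtype=feature_vector.dtype)
        padded[:len(feature_vector)] = feature_vector
        return padded
    return feature_vector[:target_size]
```
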