fix DQN RL inference, rebuild model
@@ -21,6 +21,112 @@ from utils.training_integration import get_training_integration
 # Configure logger
 logger = logging.getLogger(__name__)
 
+
+class DQNNetwork(nn.Module):
+    """
+    Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
+    Handles 7850 input features from multi-timeframe, multi-asset data
+    """
+
+    def __init__(self, input_dim: int, n_actions: int):
+        super(DQNNetwork, self).__init__()
+
+        # Handle different input dimension formats
+        if isinstance(input_dim, (tuple, list)):
+            if len(input_dim) == 1:
+                self.input_size = input_dim[0]
+            else:
+                self.input_size = np.prod(input_dim)  # Flatten multi-dimensional input
+        else:
+            self.input_size = input_dim
+
+        self.n_actions = n_actions
+
+        # Deep network architecture optimized for trading features
+        self.network = nn.Sequential(
+            # Input layer
+            nn.Linear(self.input_size, 2048),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+
+            # Hidden layers with residual-like connections
+            nn.Linear(2048, 1024),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+
+            nn.Linear(1024, 512),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Dropout(0.2),
+
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Dropout(0.2),
+
+            # Output layer for Q-values
+            nn.Linear(128, n_actions)
+        )
+
+        # Initialize weights
+        self._initialize_weights()
+
+    def _initialize_weights(self):
+        """Initialize network weights using Xavier initialization"""
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+
+    def forward(self, x):
+        """Forward pass through the network"""
+        # Ensure input is properly shaped
+        if x.dim() > 2:
+            x = x.view(x.size(0), -1)  # Flatten if needed
+        elif x.dim() == 1:
+            x = x.unsqueeze(0)  # Add batch dimension if needed
+
+        return self.network(x)
+
+    def act(self, state, explore=True):
+        """
+        Select action using epsilon-greedy policy
+
+        Args:
+            state: Current state (numpy array or tensor)
+            explore: Whether to use epsilon-greedy exploration
+
+        Returns:
+            action_idx: Selected action index
+            confidence: Confidence score
+            action_probs: Action probabilities
+        """
+        # Convert state to tensor if needed
+        if isinstance(state, np.ndarray):
+            state = torch.FloatTensor(state).to(next(self.parameters()).device)
+
+        # Ensure proper shape
+        if state.dim() == 1:
+            state = state.unsqueeze(0)
+
+        with torch.no_grad():
+            q_values = self.forward(state)
+
+            # Get action probabilities using softmax
+            action_probs = F.softmax(q_values, dim=1)
+
+            # Select action (greedy for inference)
+            action_idx = torch.argmax(q_values, dim=1).item()
+
+            # Calculate confidence as max probability
+            confidence = float(action_probs[0, action_idx].item())
+
+            # Convert probabilities to list
+            probs_list = action_probs.squeeze(0).cpu().numpy().tolist()
+
+        return action_idx, confidence, probs_list
+
+
 class DQNAgent:
     """
     Deep Q-Network agent for trading
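
A minimal usage sketch (not part of the commit) of the new DQNNetwork.act(). Importing DQNNetwork from NN.models.dqn_agent and the 7850-feature state size are assumptions taken from this diff; only the DQNAgent import is shown elsewhere in the commit.

    import numpy as np
    from NN.models.dqn_agent import DQNNetwork  # assumed import path

    net = DQNNetwork(input_dim=7850, n_actions=3)
    state = np.random.randn(7850).astype(np.float32)  # stand-in for BaseDataInput.get_feature_vector()

    action_idx, confidence, action_probs = net.act(state)
    # action_idx: greedy argmax over the Q-values (exploration is applied by DQNAgent, not here)
    # confidence: softmax probability assigned to the chosen action
    # action_probs: list of n_actions probabilities summing to 1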
@@ -80,12 +186,9 @@ class DQNAgent:
         else:
             self.device = device
 
-        # Initialize models with Enhanced CNN architecture for better performance
-        from NN.models.enhanced_cnn import EnhancedCNN
-
-        # Use Enhanced CNN for both policy and target networks
-        self.policy_net = EnhancedCNN(self.state_dim, self.n_actions)
-        self.target_net = EnhancedCNN(self.state_dim, self.n_actions)
+        # Initialize models with RL-specific network architecture
+        self.policy_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
+        self.target_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
 
         # Initialize the target network with the same weights as the policy network
         self.target_net.load_state_dict(self.policy_net.state_dict())
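
A minimal sketch of the hard target-network sync this hunk sets up. The sync interval and helper function are illustrative assumptions, not code from this commit; the commit itself only shows the initial weight copy.

    from NN.models.dqn_agent import DQNNetwork  # assumed import path, as above

    policy_net = DQNNetwork(input_dim=7850, n_actions=3)
    target_net = DQNNetwork(input_dim=7850, n_actions=3)
    target_net.load_state_dict(policy_net.state_dict())  # start both networks from identical weights

    TARGET_SYNC_INTERVAL = 1_000  # illustrative hyperparameter

    def maybe_sync_target(step: int) -> None:
        # Periodically copy policy weights into the target network; the frozen copy
        # keeps DQN bootstrap targets stable between syncs.
        if step % TARGET_SYNC_INTERVAL == 0:
            target_net.load_state_dict(policy_net.state_dict())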
@@ -578,83 +681,45 @@ class DQNAgent:
            market_context: Additional market context for decision making
 
        Returns:
-            int: Action (0=BUY, 1=SELL, 2=HOLD) or None if should hold position
+            int: Action (0=BUY, 1=SELL)
        """
-        # Convert state to tensor
-        if isinstance(state, np.ndarray):
-            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
-        else:
-            state_tensor = state.unsqueeze(0).to(self.device)
-
-        # Get Q-values
-        policy_output = self.policy_net(state_tensor)
-        if isinstance(policy_output, dict):
-            q_values = policy_output.get('q_values', policy_output.get('Q_values', list(policy_output.values())[0]))
-        elif isinstance(policy_output, tuple):
-            q_values = policy_output[0]  # Assume first element is Q-values
-        else:
-            q_values = policy_output
-        action_values = q_values.cpu().data.numpy()[0]
-
-        # Calculate confidence scores
-        # Ensure q_values has correct shape for softmax
-        if q_values.dim() == 1:
-            q_values = q_values.unsqueeze(0)
-
-        # FIXED ACTION MAPPING: 0=BUY, 1=SELL, 2=HOLD
-        buy_confidence = torch.softmax(q_values, dim=1)[0, 0].item()
-        sell_confidence = torch.softmax(q_values, dim=1)[0, 1].item()
-
-        # Determine action based on current position and confidence thresholds
-        action = self._determine_action_with_position_management(
-            sell_confidence, buy_confidence, current_price, market_context, explore
-        )
-
-        # Update tracking
-        if current_price:
-            self.recent_prices.append(current_price)
-
-        if action is not None:
-            self.recent_actions.append(action)
-            return action
-        else:
-            # Return 1 (HOLD) as a safe default if action is None
+        try:
+            # Use the DQNNetwork's act method for consistent behavior
+            action_idx, confidence, action_probs = self.policy_net.act(state, explore=explore)
+
+            # Apply epsilon-greedy exploration if requested
+            if explore and np.random.random() <= self.epsilon:
+                action_idx = np.random.choice(self.n_actions)
+
+            # Update tracking
+            if current_price:
+                self.recent_prices.append(current_price)
+
+            self.recent_actions.append(action_idx)
+            return action_idx
+
+        except Exception as e:
+            logger.error(f"Error in act method: {e}")
+            # Return default action (HOLD/SELL)
             return 1
 
-    def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
-        """Choose action with confidence score adapted to market regime (from Enhanced DQN)"""
-        with torch.no_grad():
-            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
-            q_values = self.policy_net(state_tensor)
-
-            # Handle case where network might return a tuple instead of tensor
-            if isinstance(q_values, tuple):
-                # If it's a tuple, take the first element (usually the main output)
-                q_values = q_values[0]
-
-            # Ensure q_values is a tensor and has correct shape for softmax
-            if not hasattr(q_values, 'dim'):
-                logger.error(f"DQN: q_values is not a tensor: {type(q_values)}")
-                # Return default action with low confidence
-                return 1, 0.1  # Default to HOLD action
-
-            if q_values.dim() == 1:
-                q_values = q_values.unsqueeze(0)
-
-            # Convert Q-values to probabilities
-            action_probs = torch.softmax(q_values, dim=1)
-            action = q_values.argmax().item()
-            base_confidence = action_probs[0, action].item()
+    def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float, List[float]]:
+        """Choose action with confidence score adapted to market regime"""
+        try:
+            # Use the DQNNetwork's act method which handles the state properly
+            action_idx, base_confidence, action_probs = self.policy_net.act(state, explore=False)
 
             # Adapt confidence based on market regime
             regime_weight = self.market_regime_weights.get(market_regime, 1.0)
             adapted_confidence = min(base_confidence * regime_weight, 1.0)
 
-            # Always return int, float
-            if action is None:
-                return 1, 0.1
-            return int(action), float(adapted_confidence)
+            # Return action, confidence, and probabilities (for orchestrator compatibility)
+            return int(action_idx), float(adapted_confidence), action_probs
+        except Exception as e:
+            logger.error(f"Error in act_with_confidence: {e}")
+            # Return default action with low confidence
+            return 1, 0.1, [0.45, 0.55]  # Default to HOLD action
 
     def _determine_action_with_position_management(self, sell_conf, buy_conf, current_price, market_context, explore):
         """
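
A standalone sketch of the epsilon-greedy wrapper the rewritten act() applies on top of the network's greedy choice. The epsilon and n_actions values below are placeholders; the branching mirrors the new agent code in this hunk.

    import numpy as np

    def epsilon_greedy(greedy_action: int, epsilon: float, n_actions: int, explore: bool = True) -> int:
        # With probability epsilon take a uniformly random action, otherwise keep
        # the greedy action returned by DQNNetwork.act().
        if explore and np.random.random() <= epsilon:
            return int(np.random.choice(n_actions))
        return greedy_action

    action = epsilon_greedy(greedy_action=0, epsilon=0.1, n_actions=3)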
@@ -283,9 +283,22 @@ class TradingOrchestrator:
         # Initialize DQN Agent
         try:
             from NN.models.dqn_agent import DQNAgent
-            state_size = self.config.rl.get('state_size', 13800)  # Enhanced with COB features
+
+            # Determine actual state size from BaseDataInput
+            try:
+                base_data = self.data_provider.build_base_data_input(self.symbol)
+                if base_data:
+                    actual_state_size = len(base_data.get_feature_vector())
+                    logger.info(f"Detected actual state size: {actual_state_size}")
+                else:
+                    actual_state_size = 7850  # Fallback based on error message
+                    logger.warning(f"Could not determine state size, using fallback: {actual_state_size}")
+            except Exception as e:
+                actual_state_size = 7850  # Fallback based on error message
+                logger.warning(f"Error determining state size: {e}, using fallback: {actual_state_size}")
+
             action_size = self.config.rl.get('action_space', 3)
-            self.rl_agent = DQNAgent(state_shape=state_size, n_actions=action_size)
+            self.rl_agent = DQNAgent(state_shape=actual_state_size, n_actions=action_size)
             self.rl_agent.to(self.device)  # Move DQN agent to the determined device
 
             # Load best checkpoint and capture initial state (using database metadata)
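
A small reproduction (assumed, not from the repo) of the failure this hunk fixes: a network built with the old hard-coded state_size rejects the 7850 features that BaseDataInput.get_feature_vector() actually produces, which is why the orchestrator now measures the real vector length first.

    import torch
    import torch.nn as nn

    old_first_layer = nn.Linear(13800, 2048)  # old assumed state size
    actual_state = torch.randn(1, 7850)       # size the runtime error reported

    try:
        old_first_layer(actual_state)
    except RuntimeError as e:
        print(f"dimension mismatch: {e}")     # mat1 and mat2 shapes cannot be multiplied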
@@ -306,7 +319,10 @@ class TradingOrchestrator:
                 loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
                 logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
             except Exception as e:
-                logger.warning(f"Error loading DQN checkpoint: {e}")
+                logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
+                logger.info("DQN will start fresh due to checkpoint incompatibility")
+                # Reset the agent to handle dimension mismatch
+                checkpoint_loaded = False
 
             if not checkpoint_loaded:
                 # New model - no synthetic data, start fresh
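
A minimal sketch of the guarded checkpoint-load pattern this hunk adopts: try to restore, and on any failure (typically a shape mismatch after the state-size change) fall back to a fresh model. The load_best_checkpoint helper name is hypothetical; the logging mirrors the hunk above.

    import logging

    logger = logging.getLogger(__name__)

    def try_load_dqn_checkpoint(agent) -> bool:
        try:
            agent.load_best_checkpoint()  # hypothetical helper; may raise on shape mismatch
            return True
        except Exception as e:
            logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
            logger.info("DQN will start fresh due to checkpoint incompatibility")
            return False  # caller treats this as checkpoint_loaded = False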
@@ -1145,6 +1161,9 @@ class TradingOrchestrator:
         # Collect input data for all models
         input_data = await self._collect_model_input_data(symbol)
 
+        # Log all registered models
+        logger.debug(f"inferencing registered models: {self.model_registry.models}")
+
         for model_name, model in self.model_registry.models.items():
             try:
                 prediction = None
@@ -2058,8 +2077,17 @@ class TradingOrchestrator:
                     )
                     predictions.append(prediction)
 
-                    # Store prediction in queue for future use
-                    self.update_data_queue('model_predictions', symbol, result)
+                    # Store prediction in SQLite database for training
+                    logger.debug(f"Added CNN prediction to database: {prediction}")
+
+                    # Store CNN prediction as inference record
+                    await self._store_inference_data_async(
+                        model_name="enhanced_cnn",
+                        model_input=base_data,
+                        prediction=prediction,
+                        timestamp=datetime.now(),
+                        symbol=symbol
+                    )
 
             except Exception as e:
                 logger.error(f"Error using CNN adapter: {e}")
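
A hedged sketch of what persisting one inference record could look like. The repo's _store_inference_data_async body is not shown in this diff, so the table name, schema, and use of aiosqlite below are illustrative assumptions; only the "SQLite database for training" intent comes from the commit.

    import aiosqlite
    from datetime import datetime

    async def store_inference_record(db_path: str, model_name: str, symbol: str,
                                     prediction, timestamp: datetime) -> None:
        # Persist one inference so it can later be joined with realized outcomes for training.
        async with aiosqlite.connect(db_path) as db:
            await db.execute(
                "INSERT INTO inference_records (model_name, symbol, timestamp, prediction) "
                "VALUES (?, ?, ?, ?)",
                (model_name, symbol, timestamp.isoformat(), repr(prediction)),
            )
            await db.commit()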
@@ -2111,6 +2139,15 @@ class TradingOrchestrator:
                         )
                         predictions.append(pred)
 
+                        # Store CNN fallback prediction as inference record
+                        await self._store_inference_data_async(
+                            model_name=model.name,
+                            model_input=base_data,
+                            prediction=pred,
+                            timestamp=datetime.now(),
+                            symbol=symbol
+                        )
+
                         # Capture for dashboard
                         current_price = self._get_current_price(symbol)
                         if current_price is not None:
@@ -2188,12 +2225,22 @@ class TradingOrchestrator:
             elif raw_q_values is not None and isinstance(raw_q_values, list):
                 q_values_for_capture = raw_q_values
 
-            # Create prediction object
+            # Create prediction object with safe probability calculation
+            probabilities = {}
+            if q_values_for_capture and len(q_values_for_capture) == len(action_names):
+                # Use actual q_values if they match the expected length
+                probabilities = {action_names[i]: float(q_values_for_capture[i]) for i in range(len(action_names))}
+            else:
+                # Use default uniform probabilities if q_values are unavailable or mismatched
+                default_prob = 1.0 / len(action_names)
+                probabilities = {name: default_prob for name in action_names}
+                if q_values_for_capture:
+                    logger.warning(f"Q-values length mismatch: expected {len(action_names)}, got {len(q_values_for_capture)}. Using default probabilities.")
+
             prediction = Prediction(
                 action=action,
                 confidence=float(confidence),
-                # Use actual q_values if available, otherwise default probabilities
-                probabilities={action_names[i]: float(q_values_for_capture[i]) if q_values_for_capture else (1.0 / len(action_names)) for i in range(len(action_names))},
+                probabilities=probabilities,
                 timeframe='mixed',  # RL uses mixed timeframes
                 timestamp=datetime.now(),
                 model_name=model.name,
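
A quick sketch of the safe probability mapping introduced above, runnable outside the orchestrator. The action_names values are assumed; the branching matches the new hunk, including the uniform fallback when the Q-value list length does not match.

    action_names = ["BUY", "SELL", "HOLD"]  # assumed ordering

    def safe_probabilities(q_values_for_capture):
        if q_values_for_capture and len(q_values_for_capture) == len(action_names):
            return {action_names[i]: float(q_values_for_capture[i]) for i in range(len(action_names))}
        default_prob = 1.0 / len(action_names)
        return {name: default_prob for name in action_names}

    print(safe_probabilities([0.2, 0.5, 0.3]))  # uses the real values
    print(safe_probabilities([0.9]))            # length mismatch -> uniform fallback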
@@ -2279,7 +2326,7 @@ class TradingOrchestrator:
             return None
 
     def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
-        """Get current state for RL agent - ensure compatibility with saved model"""
+        """Get current state for RL agent using unified BaseDataInput"""
         try:
             # Use unified BaseDataInput approach
             base_data = self.build_base_data_input(symbol)
@@ -2287,21 +2334,12 @@ class TradingOrchestrator:
                 logger.warning(f"Cannot build BaseDataInput for RL state: {symbol}")
                 return None
 
-            # Get unified feature vector
+            # Get unified feature vector (7850 features including all timeframes and COB data)
             feature_vector = base_data.get_feature_vector()
 
-            # Ensure compatibility with saved model (expects 403 features)
-            target_size = 403  # Match the saved model's expected input size
-            if len(feature_vector) < target_size:
-                # Pad with zeros
-                padded_state = np.zeros(target_size)
-                padded_state[:len(feature_vector)] = feature_vector
-                return padded_state
-            elif len(feature_vector) > target_size:
-                # Truncate to target size
-                return feature_vector[:target_size]
-            else:
-                return feature_vector
+            # Return the full unified feature vector for RL agent
+            # The DQN agent is now initialized with the correct size to match this
+            return feature_vector
 
         except Exception as e:
             logger.error(f"Error creating RL state for {symbol}: {e}")
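
A small illustration (with numbers taken from this diff) of why the old pad/truncate-to-403 path was dropped: truncating the unified feature vector silently discarded roughly 95% of the state the agent was supposed to see.

    import numpy as np

    feature_vector = np.random.randn(7850)           # full BaseDataInput feature vector
    old_state = feature_vector[:403]                  # what the old code handed to the agent
    print(1 - old_state.size / feature_vector.size)   # ~0.95 of the features were lost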