fix DQN RL inference, rebuild model
This commit is contained in:
@ -283,9 +283,22 @@ class TradingOrchestrator:
|
||||
# Initialize DQN Agent
|
||||
try:
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
state_size = self.config.rl.get('state_size', 13800) # Enhanced with COB features
|
||||
|
||||
# Determine actual state size from BaseDataInput
|
||||
try:
|
||||
base_data = self.data_provider.build_base_data_input(self.symbol)
|
||||
if base_data:
|
||||
actual_state_size = len(base_data.get_feature_vector())
|
||||
logger.info(f"Detected actual state size: {actual_state_size}")
|
||||
else:
|
||||
actual_state_size = 7850 # Fallback based on error message
|
||||
logger.warning(f"Could not determine state size, using fallback: {actual_state_size}")
|
||||
except Exception as e:
|
||||
actual_state_size = 7850 # Fallback based on error message
|
||||
logger.warning(f"Error determining state size: {e}, using fallback: {actual_state_size}")
|
||||
|
||||
action_size = self.config.rl.get('action_space', 3)
|
||||
self.rl_agent = DQNAgent(state_shape=state_size, n_actions=action_size)
|
||||
self.rl_agent = DQNAgent(state_shape=actual_state_size, n_actions=action_size)
|
||||
self.rl_agent.to(self.device) # Move DQN agent to the determined device
|
||||
|
||||
# Load best checkpoint and capture initial state (using database metadata)
|
||||
@ -306,7 +319,10 @@ class TradingOrchestrator:
|
||||
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
|
||||
logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading DQN checkpoint: {e}")
|
||||
logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
|
||||
logger.info("DQN will start fresh due to checkpoint incompatibility")
|
||||
# Reset the agent to handle dimension mismatch
|
||||
checkpoint_loaded = False
|
||||
|
||||
if not checkpoint_loaded:
|
||||
# New model - no synthetic data, start fresh
|
||||
@ -1144,7 +1160,10 @@ class TradingOrchestrator:
|
||||
|
||||
# Collect input data for all models
|
||||
input_data = await self._collect_model_input_data(symbol)
|
||||
|
||||
|
||||
# log all registered models
|
||||
logger.debug(f"inferencing registered models: {self.model_registry.models}")
|
||||
|
||||
for model_name, model in self.model_registry.models.items():
|
||||
try:
|
||||
prediction = None
|
||||
@ -2058,8 +2077,17 @@ class TradingOrchestrator:
|
||||
)
|
||||
predictions.append(prediction)
|
||||
|
||||
# Store prediction in queue for future use
|
||||
self.update_data_queue('model_predictions', symbol, result)
|
||||
# Store prediction in SQLite database for training
|
||||
logger.debug(f"Added CNN prediction to database: {prediction}")
|
||||
|
||||
# Store CNN prediction as inference record
|
||||
await self._store_inference_data_async(
|
||||
model_name="enhanced_cnn",
|
||||
model_input=base_data,
|
||||
prediction=prediction,
|
||||
timestamp=datetime.now(),
|
||||
symbol=symbol
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error using CNN adapter: {e}")
|
||||
@ -2111,6 +2139,15 @@ class TradingOrchestrator:
|
||||
)
|
||||
predictions.append(pred)
|
||||
|
||||
# Store CNN fallback prediction as inference record
|
||||
await self._store_inference_data_async(
|
||||
model_name=model.name,
|
||||
model_input=base_data,
|
||||
prediction=pred,
|
||||
timestamp=datetime.now(),
|
||||
symbol=symbol
|
||||
)
|
||||
|
||||
# Capture for dashboard
|
||||
current_price = self._get_current_price(symbol)
|
||||
if current_price is not None:
|
||||
@ -2188,12 +2225,22 @@ class TradingOrchestrator:
|
||||
elif raw_q_values is not None and isinstance(raw_q_values, list):
|
||||
q_values_for_capture = raw_q_values
|
||||
|
||||
# Create prediction object
|
||||
# Create prediction object with safe probability calculation
|
||||
probabilities = {}
|
||||
if q_values_for_capture and len(q_values_for_capture) == len(action_names):
|
||||
# Use actual q_values if they match the expected length
|
||||
probabilities = {action_names[i]: float(q_values_for_capture[i]) for i in range(len(action_names))}
|
||||
else:
|
||||
# Use default uniform probabilities if q_values are unavailable or mismatched
|
||||
default_prob = 1.0 / len(action_names)
|
||||
probabilities = {name: default_prob for name in action_names}
|
||||
if q_values_for_capture:
|
||||
logger.warning(f"Q-values length mismatch: expected {len(action_names)}, got {len(q_values_for_capture)}. Using default probabilities.")
|
||||
|
||||
prediction = Prediction(
|
||||
action=action,
|
||||
confidence=float(confidence),
|
||||
# Use actual q_values if available, otherwise default probabilities
|
||||
probabilities={action_names[i]: float(q_values_for_capture[i]) if q_values_for_capture else (1.0 / len(action_names)) for i in range(len(action_names))},
|
||||
probabilities=probabilities,
|
||||
timeframe='mixed', # RL uses mixed timeframes
|
||||
timestamp=datetime.now(),
|
||||
model_name=model.name,
|
||||
@ -2279,7 +2326,7 @@ class TradingOrchestrator:
|
||||
return None
|
||||
|
||||
def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Get current state for RL agent - ensure compatibility with saved model"""
|
||||
"""Get current state for RL agent using unified BaseDataInput"""
|
||||
try:
|
||||
# Use unified BaseDataInput approach
|
||||
base_data = self.build_base_data_input(symbol)
|
||||
@ -2287,21 +2334,12 @@ class TradingOrchestrator:
|
||||
logger.warning(f"Cannot build BaseDataInput for RL state: {symbol}")
|
||||
return None
|
||||
|
||||
# Get unified feature vector
|
||||
# Get unified feature vector (7850 features including all timeframes and COB data)
|
||||
feature_vector = base_data.get_feature_vector()
|
||||
|
||||
# Ensure compatibility with saved model (expects 403 features)
|
||||
target_size = 403 # Match the saved model's expected input size
|
||||
if len(feature_vector) < target_size:
|
||||
# Pad with zeros
|
||||
padded_state = np.zeros(target_size)
|
||||
padded_state[:len(feature_vector)] = feature_vector
|
||||
return padded_state
|
||||
elif len(feature_vector) > target_size:
|
||||
# Truncate to target size
|
||||
return feature_vector[:target_size]
|
||||
else:
|
||||
return feature_vector
|
||||
# Return the full unified feature vector for RL agent
|
||||
# The DQN agent is now initialized with the correct size to match this
|
||||
return feature_vector
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating RL state for {symbol}: {e}")
|
||||
|
Reference in New Issue
Block a user