fix DQN RL inference, rebuild model

This commit is contained in:
Dobromir Popov
2025-07-26 23:57:03 +03:00
parent 87942d3807
commit 36a8e256a8
2 changed files with 201 additions and 98 deletions

View File

@ -283,9 +283,22 @@ class TradingOrchestrator:
# Initialize DQN Agent
try:
from NN.models.dqn_agent import DQNAgent
state_size = self.config.rl.get('state_size', 13800) # Enhanced with COB features
# Determine actual state size from BaseDataInput
try:
base_data = self.data_provider.build_base_data_input(self.symbol)
if base_data:
actual_state_size = len(base_data.get_feature_vector())
logger.info(f"Detected actual state size: {actual_state_size}")
else:
actual_state_size = 7850 # Fallback based on error message
logger.warning(f"Could not determine state size, using fallback: {actual_state_size}")
except Exception as e:
actual_state_size = 7850 # Fallback based on error message
logger.warning(f"Error determining state size: {e}, using fallback: {actual_state_size}")
action_size = self.config.rl.get('action_space', 3)
self.rl_agent = DQNAgent(state_shape=state_size, n_actions=action_size)
self.rl_agent = DQNAgent(state_shape=actual_state_size, n_actions=action_size)
self.rl_agent.to(self.device) # Move DQN agent to the determined device
# Load best checkpoint and capture initial state (using database metadata)
@ -306,7 +319,10 @@ class TradingOrchestrator:
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
except Exception as e:
logger.warning(f"Error loading DQN checkpoint: {e}")
logger.warning(f"Error loading DQN checkpoint (likely dimension mismatch): {e}")
logger.info("DQN will start fresh due to checkpoint incompatibility")
# Mark the checkpoint as not loaded so the fresh-start branch below runs (the agent itself is not re-created here)
checkpoint_loaded = False
if not checkpoint_loaded:
# New model - no synthetic data, start fresh
@ -1144,7 +1160,10 @@ class TradingOrchestrator:
# Collect input data for all models
input_data = await self._collect_model_input_data(symbol)
# log all registered models
logger.debug(f"inferencing registered models: {self.model_registry.models}")
for model_name, model in self.model_registry.models.items():
try:
prediction = None
@ -2058,8 +2077,17 @@ class TradingOrchestrator:
)
predictions.append(prediction)
# Store prediction in queue for future use
self.update_data_queue('model_predictions', symbol, result)
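# Log the CNN prediction for debugging (NOTE(review): this statement only logs; nothing is written to SQLite here - confirm where persistence happens)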
logger.debug(f"Added CNN prediction to database: {prediction}")
# Store CNN prediction as inference record
await self._store_inference_data_async(
model_name="enhanced_cnn",
model_input=base_data,
prediction=prediction,
timestamp=datetime.now(),
symbol=symbol
)
except Exception as e:
logger.error(f"Error using CNN adapter: {e}")
@ -2111,6 +2139,15 @@ class TradingOrchestrator:
)
predictions.append(pred)
# Store CNN fallback prediction as inference record
await self._store_inference_data_async(
model_name=model.name,
model_input=base_data,
prediction=pred,
timestamp=datetime.now(),
symbol=symbol
)
# Capture for dashboard
current_price = self._get_current_price(symbol)
if current_price is not None:
@ -2188,12 +2225,22 @@ class TradingOrchestrator:
elif raw_q_values is not None and isinstance(raw_q_values, list):
q_values_for_capture = raw_q_values
# Create prediction object
# Create prediction object with safe probability calculation
probabilities = {}
if q_values_for_capture and len(q_values_for_capture) == len(action_names):
# Use actual q_values if they match the expected length
probabilities = {action_names[i]: float(q_values_for_capture[i]) for i in range(len(action_names))}
else:
# Use default uniform probabilities if q_values are unavailable or mismatched
default_prob = 1.0 / len(action_names)
probabilities = {name: default_prob for name in action_names}
if q_values_for_capture:
logger.warning(f"Q-values length mismatch: expected {len(action_names)}, got {len(q_values_for_capture)}. Using default probabilities.")
prediction = Prediction(
action=action,
confidence=float(confidence),
# Use actual q_values if available, otherwise default probabilities
probabilities={action_names[i]: float(q_values_for_capture[i]) if q_values_for_capture else (1.0 / len(action_names)) for i in range(len(action_names))},
probabilities=probabilities,
timeframe='mixed', # RL uses mixed timeframes
timestamp=datetime.now(),
model_name=model.name,
@ -2279,7 +2326,7 @@ class TradingOrchestrator:
return None
def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
"""Get current state for RL agent - ensure compatibility with saved model"""
"""Get current state for RL agent using unified BaseDataInput"""
try:
# Use unified BaseDataInput approach
base_data = self.build_base_data_input(symbol)
@ -2287,21 +2334,12 @@ class TradingOrchestrator:
logger.warning(f"Cannot build BaseDataInput for RL state: {symbol}")
return None
# Get unified feature vector
# Get unified feature vector (7850 features including all timeframes and COB data)
feature_vector = base_data.get_feature_vector()
# Ensure compatibility with saved model (expects 403 features)
target_size = 403 # Match the saved model's expected input size
if len(feature_vector) < target_size:
# Pad with zeros
padded_state = np.zeros(target_size)
padded_state[:len(feature_vector)] = feature_vector
return padded_state
elif len(feature_vector) > target_size:
# Truncate to target size
return feature_vector[:target_size]
else:
return feature_vector
# Return the full unified feature vector for RL agent
# The DQN agent is now initialized with the correct size to match this
return feature_vector
except Exception as e:
logger.error(f"Error creating RL state for {symbol}: {e}")