enable training
@@ -1330,110 +1330,6 @@ class TradingOrchestrator:
            logger.debug(f"Error building COB state for {symbol}: {e}")
            return None

    def _combine_predictions(self, symbol: str, price: float, predictions: List[Prediction],
                             timestamp: datetime) -> TradingDecision:
                # Call act_with_confidence and handle different return formats
                result = model.model.act_with_confidence(state)

                if len(result) == 3:
                    # EnhancedCNN format: (action, confidence, q_values)
                    action_idx, confidence, raw_q_values = result
                elif len(result) == 2:
                    # DQN format: (action, confidence)
                    action_idx, confidence = result
                    raw_q_values = None
                else:
                    logger.warning(f"Unexpected result format from RL model: {result}")
                    return None
            else:
                # Fallback to standard act method
                action_idx = model.model.act(state)
                confidence = 0.6  # Default confidence
                raw_q_values = None

            # Convert action index to action name
            action_names = ['BUY', 'SELL', 'HOLD']
            if 0 <= action_idx < len(action_names):
                action = action_names[action_idx]
            else:
                logger.warning(f"Invalid action index from RL model: {action_idx}")
                return None

            # Store prediction in database for tracking
            if (hasattr(self, 'enhanced_training_system') and
                    self.enhanced_training_system and
                    hasattr(self.enhanced_training_system, 'store_model_prediction')):

                current_price = self._get_current_price_safe(symbol)
                if current_price > 0:
                    prediction_id = self.enhanced_training_system.store_model_prediction(
                        model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN",
                        symbol=symbol,
                        prediction_type=action,
                        confidence=confidence,
                        current_price=current_price
                    )
                    logger.debug(f"Stored DQN prediction {prediction_id} for {symbol}")

            # Create prediction object
            prediction = Prediction(
                model_name=f"DQN_{model.model_name}" if hasattr(model, 'model_name') else "DQN",
                symbol=symbol,
                signal=action,
                confidence=confidence,
                reasoning=f"DQN agent prediction with Q-values: {raw_q_values}",
                features=state.tolist() if isinstance(state, np.ndarray) else [],
                metadata={
                    'action_idx': action_idx,
                    'q_values': raw_q_values.tolist() if raw_q_values is not None else None,
                    'state_size': len(state) if state is not None else 0
                }
            )

            return prediction

        except Exception as e:
            logger.error(f"Error getting RL prediction for {symbol}: {e}")
            return None
                raw_q_values = None  # No raw q_values from simple act
            else:
                logger.error(f"RL model {model.name} has no act method")
                return None

            action_names = ['SELL', 'HOLD', 'BUY']
            action = action_names[action_idx]

            # Convert raw_q_values to list if they are a tensor
            q_values_for_capture = None
            if raw_q_values is not None and hasattr(raw_q_values, 'tolist'):
                q_values_for_capture = raw_q_values.tolist()
            elif raw_q_values is not None and isinstance(raw_q_values, list):
                q_values_for_capture = raw_q_values

            # Create prediction object
            prediction = Prediction(
                action=action,
                confidence=float(confidence),
                # Use actual q_values if available, otherwise default probabilities
                probabilities={
                    action_names[i]: (float(q_values_for_capture[i]) if q_values_for_capture
                                      else 1.0 / len(action_names))
                    for i in range(len(action_names))
                },
                timeframe='mixed',  # RL uses mixed timeframes
                timestamp=datetime.now(),
                model_name=model.name,
                metadata={'state_size': len(state)}
            )

            # Capture DQN prediction for dashboard visualization
            current_price = self._get_current_price(symbol)
            if current_price:
                # Only pass q_values if they exist, otherwise pass empty list
                q_values_to_pass = q_values_for_capture if q_values_for_capture is not None else []
                self.capture_dqn_prediction(symbol, action_idx, float(confidence), current_price, q_values_to_pass)

            return prediction

        except Exception as e:
            logger.error(f"Error getting RL prediction: {e}")
            return None

    async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]:
        """Get prediction from generic model"""
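Note: the core of the removed block is the dispatch on tuple length returned by act_with_confidence. A minimal standalone sketch of that normalization, under the assumption that models return either a 3-tuple with Q-values or a plain 2-tuple (normalize_act_result is a hypothetical name, not part of this codebase):

from typing import List, Optional, Tuple

def normalize_act_result(result: tuple) -> Optional[Tuple[int, float, Optional[List[float]]]]:
    # (action, confidence, q_values) -- EnhancedCNN-style return
    if len(result) == 3:
        action_idx, confidence, raw_q = result
        # Tensors/arrays expose .tolist(); plain sequences are copied as-is
        q_values = raw_q.tolist() if hasattr(raw_q, "tolist") else list(raw_q)
        return int(action_idx), float(confidence), q_values
    # (action, confidence) -- plain DQN-style return, no Q-values
    if len(result) == 2:
        action_idx, confidence = result
        return int(action_idx), float(confidence), None
    # Any other shape is an unexpected format; the caller logs and aborts
    return None

assert normalize_act_result((2, 0.81, [0.1, 0.2, 0.7])) == (2, 0.81, [0.1, 0.2, 0.7])
assert normalize_act_result((0, 0.6)) == (0, 0.6, None)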
@@ -33,7 +33,7 @@ def create_safe_orchestrator() -> Optional[TradingOrchestrator]:
    try:
        # Create orchestrator with basic configuration (uses correct constructor parameters)
        orchestrator = TradingOrchestrator(
-            enhanced_rl_training=False  # Disable problematic training initially
+            enhanced_rl_training=True  # Enable RL training for model improvement
        )

        logger.info("Trading orchestrator created successfully")
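Note: this second hunk is the actual "enable training" change, flipping enhanced_rl_training from False to True. A hedged usage sketch, assuming create_safe_orchestrator is in scope and returns None on failure as its Optional[TradingOrchestrator] annotation indicates:

# Hypothetical caller; logger setup and module layout are assumptions.
import logging

logger = logging.getLogger(__name__)

orchestrator = create_safe_orchestrator()
if orchestrator is None:
    logger.error("Orchestrator creation failed; RL training not started")
else:
    # After this commit, new orchestrators come up with RL training enabled.
    logger.info("Orchestrator ready with enhanced RL training enabled")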