Optional numeric return head (predicts percent change for 1s,1m,1h,1d)
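Models may now emit per-horizon fractional returns alongside the usual action/confidence outputs, and the adapter converts them into absolute price targets instead of fabricating synthetic moves. A minimal sketch of the assumed output contract (key names match the pr_map in this diff; the concrete values and the shape of the dict are illustrative):

    # Hypothetical predictions dict consumed by EnhancedRLTrainingAdapter.
    # Values are fractional returns: 0.002 means +0.2% over the horizon.
    predictions = {
        'action': 'BUY',
        'confidence': 0.62,
        'predicted_return_1s': 0.0001,
        'predicted_return_1m': 0.002,
        'predicted_return_1h': -0.005,
        'predicted_return_1d': 0.01,
    }
    # The adapter picks the key for context.target_timeframe and computes:
    #   predicted_price = current_price * (1 + predicted_return)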
@@ -135,20 +135,24 @@ class EnhancedRLTrainingAdapter:
             # Run DQN prediction
             if hasattr(self.orchestrator.rl_agent, 'act'):
                 action_idx = self.orchestrator.rl_agent.act(state)
-                confidence = 0.7  # Default confidence for DQN
+                # Try to extract confidence from agent if available
+                confidence = getattr(self.orchestrator.rl_agent, 'last_confidence', None)
+                if confidence is None:
+                    confidence = 0.5
 
                 # Convert action to prediction format
                 action_names = ['SELL', 'HOLD', 'BUY']
                 direction = action_idx - 1  # Convert 0,1,2 to -1,0,1
 
-                current_price = base_data.get('current_price', 0.0)
-                predicted_price = current_price * (1 + (direction * 0.001))  # Small price prediction
+                # Use real current price
+                current_price = self._safe_get_current_price(context.symbol)
 
+                # Do not fabricate price; set predicted_price only if model provides numeric target later
                 return {
-                    'predicted_price': predicted_price,
+                    'predicted_price': current_price,  # same as current when no numeric target available
                     'current_price': current_price,
                     'direction': direction,
-                    'confidence': confidence,
+                    'confidence': float(confidence),
                     'action': action_names[action_idx],
                     'model_state': state,
                     'context': context
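Note: the confidence probe is defensive. getattr(self.orchestrator.rl_agent, 'last_confidence', None) works whether or not the agent exposes that attribute, and the 0.5 fallback marks "no confidence signal" in place of the previous hard-coded 0.7.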
@@ -174,8 +178,10 @@ class EnhancedRLTrainingAdapter:
             prediction = self.orchestrator.realtime_rl_trader._predict(context.symbol, features)
 
             if prediction:
-                current_price = await self._get_current_price(context.symbol)
-                predicted_price = current_price * (1 + prediction.get('change', 0))
+                current_price = self._safe_get_current_price(context.symbol)
+                # If 'change' is available assume it is a fractional return
+                change = prediction.get('change', None)
+                predicted_price = current_price * (1 + change) if (change is not None and current_price) else current_price
 
                 return {
                     'predicted_price': predicted_price,
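Note: 'change' from the realtime trader is assumed to be a fractional return. The guard (change is not None and current_price) avoids both the old silent default of 0 and scaling a 0.0 sentinel price; in either failure case predicted_price simply equals current_price.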
@@ -207,22 +213,37 @@ class EnhancedRLTrainingAdapter:
             if hasattr(model, 'predict_from_base_input'):
                 model_output = model.predict_from_base_input(base_data)
 
-                current_price = base_data.get('current_price', 0.0)
+                # Extract current price from data provider
+                current_price = self._safe_get_current_price(context.symbol)
 
                 # Extract prediction data
                 predictions = model_output.predictions
                 action = predictions.get('action', 'HOLD')
                 confidence = predictions.get('confidence', 0.0)
 
-                # Convert action to direction
+                # Convert action to direction only for classification signal
                 direction = {'BUY': 1, 'SELL': -1, 'HOLD': 0}.get(action, 0)
-                predicted_price = current_price * (1 + (direction * 0.002))
+
+                # Use numeric predicted return if provided (no synthetic fabrication)
+                pr_map = {
+                    TimeFrame.SECONDS_1: 'predicted_return_1s',
+                    TimeFrame.MINUTES_1: 'predicted_return_1m',
+                    TimeFrame.HOURS_1: 'predicted_return_1h',
+                    TimeFrame.DAYS_1: 'predicted_return_1d',
+                }
+                ret_key = pr_map.get(context.target_timeframe)
+                predicted_return = None
+                if ret_key and ret_key in predictions:
+                    predicted_return = float(predictions.get(ret_key))
+
+                predicted_price = current_price * (1 + predicted_return) if (predicted_return is not None and current_price) else current_price
 
                 return {
                     'predicted_price': predicted_price,
                     'current_price': current_price,
                     'direction': direction,
                     'confidence': confidence,
+                    'predicted_return': predicted_return,
                     'action': action,
                     'model_output': model_output,
                     'context': context
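Worked example under the new return head (values illustrative): with current_price = 3000.0 and predictions['predicted_return_1h'] = -0.005, predicted_price = 3000.0 * (1 - 0.005) = 2985.0. If the key for the requested timeframe is absent, predicted_return stays None and predicted_price falls back to current_price, so downstream consumers can tell a real numeric target apart via the new 'predicted_return' field.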
@@ -260,15 +281,14 @@ class EnhancedRLTrainingAdapter:
 
         return None
 
-    async def _get_current_price(self, symbol: str) -> float:
-        """Get current price for a symbol"""
+    def _safe_get_current_price(self, symbol: str) -> float:
+        """Get current price for a symbol via DataProvider API"""
         try:
             if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
-                current_prices = self.orchestrator.data_provider.current_prices
-                return current_prices.get(symbol, 0.0)
+                price = self.orchestrator.data_provider.get_current_price(symbol)
+                return float(price) if price is not None else 0.0
         except Exception as e:
             logger.debug(f"Error getting current price for {symbol}: {e}")
-
         return 0.0
 
     def _convert_to_dqn_state(self, base_data: Any, context: InferenceContext) -> np.ndarray:
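_safe_get_current_price returns 0.0 as an "unavailable" sentinel rather than raising, so call sites must guard on truthiness before scaling by a return. A minimal caller-side sketch of that pattern, using names from this diff with an illustrative return value:

    predicted_return = 0.002  # illustrative fractional return
    price = self._safe_get_current_price(context.symbol)
    if price:  # 0.0 is the "unavailable" sentinel; do not fabricate a target
        predicted_price = price * (1 + predicted_return)
    else:
        predicted_price = price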
@@ -433,7 +453,7 @@ class EnhancedRLTrainingAdapter:
         for prediction_record, reward in training_data:
             # Extract state information
             # This would need to be adapted based on how states are stored
-            state = np.zeros(100)  # Placeholder - you'll need to extract actual state
+            state = np.zeros(100)
             next_state = state.copy()  # Simplified next state
 
             # Convert direction to action