Optional numeric return head (predicts percent change for 1s,1m,1h,1d)

Author: Dobromir Popov
Date: 2025-08-23 15:17:04 +03:00
Parent: 9992b226ea
Commit: 81749ee18e
8 changed files with 124 additions and 30 deletions
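
The change wires optional per-horizon return predictions into the inference adapter: when a model emits a fractional return for the target timeframe, the adapter converts it into a price target instead of fabricating one from the trade direction. Below is a minimal sketch of that mapping; the key names mirror those handled in the diff, while the dict values and the price_from_return helper are illustrative only.

# Sketch of how per-horizon return keys map to a price target (illustrative).
def price_from_return(current_price: float, predictions: dict, ret_key: str) -> float:
    """Apply a fractional predicted return to the current price, if present."""
    predicted_return = predictions.get(ret_key)
    if predicted_return is None or not current_price:
        return current_price  # no numeric target: fall back to current price
    return current_price * (1 + float(predicted_return))

predictions = {
    'action': 'BUY',
    'confidence': 0.62,
    'predicted_return_1s': 0.0004,   # +0.04% over 1 second
    'predicted_return_1m': 0.0021,   # +0.21% over 1 minute
    'predicted_return_1h': -0.0050,  # -0.50% over 1 hour
    'predicted_return_1d': 0.0120,   # +1.20% over 1 day
}

print(price_from_return(3000.0, predictions, 'predicted_return_1m'))  # ~3006.3
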


@@ -135,20 +135,24 @@ class EnhancedRLTrainingAdapter:
                 # Run DQN prediction
                 if hasattr(self.orchestrator.rl_agent, 'act'):
                     action_idx = self.orchestrator.rl_agent.act(state)
-                    confidence = 0.7 # Default confidence for DQN
+                    # Try to extract confidence from agent if available
+                    confidence = getattr(self.orchestrator.rl_agent, 'last_confidence', None)
+                    if confidence is None:
+                        confidence = 0.5

                     # Convert action to prediction format
                     action_names = ['SELL', 'HOLD', 'BUY']
                     direction = action_idx - 1 # Convert 0,1,2 to -1,0,1
-                    current_price = base_data.get('current_price', 0.0)
-                    predicted_price = current_price * (1 + (direction * 0.001)) # Small price prediction
+                    # Use real current price
+                    current_price = self._safe_get_current_price(context.symbol)
+                    # Do not fabricate price; set predicted_price only if model provides numeric target later

                     return {
-                        'predicted_price': predicted_price,
+                        'predicted_price': current_price, # same as current when no numeric target available
                         'current_price': current_price,
                         'direction': direction,
-                        'confidence': confidence,
+                        'confidence': float(confidence),
                         'action': action_names[action_idx],
                         'model_state': state,
                         'context': context
@@ -174,8 +178,10 @@ class EnhancedRLTrainingAdapter:
                 prediction = self.orchestrator.realtime_rl_trader._predict(context.symbol, features)

                 if prediction:
-                    current_price = await self._get_current_price(context.symbol)
-                    predicted_price = current_price * (1 + prediction.get('change', 0))
+                    current_price = self._safe_get_current_price(context.symbol)
+                    # If 'change' is available assume it is a fractional return
+                    change = prediction.get('change', None)
+                    predicted_price = current_price * (1 + change) if (change is not None and current_price) else current_price

                     return {
                         'predicted_price': predicted_price,
@@ -207,22 +213,37 @@ class EnhancedRLTrainingAdapter:
                 if hasattr(model, 'predict_from_base_input'):
                     model_output = model.predict_from_base_input(base_data)
-                    current_price = base_data.get('current_price', 0.0)
+                    # Extract current price from data provider
+                    current_price = self._safe_get_current_price(context.symbol)

                     # Extract prediction data
                     predictions = model_output.predictions
                     action = predictions.get('action', 'HOLD')
                     confidence = predictions.get('confidence', 0.0)

-                    # Convert action to direction
+                    # Convert action to direction only for classification signal
                     direction = {'BUY': 1, 'SELL': -1, 'HOLD': 0}.get(action, 0)
-                    predicted_price = current_price * (1 + (direction * 0.002))
+
+                    # Use numeric predicted return if provided (no synthetic fabrication)
+                    pr_map = {
+                        TimeFrame.SECONDS_1: 'predicted_return_1s',
+                        TimeFrame.MINUTES_1: 'predicted_return_1m',
+                        TimeFrame.HOURS_1: 'predicted_return_1h',
+                        TimeFrame.DAYS_1: 'predicted_return_1d',
+                    }
+                    ret_key = pr_map.get(context.target_timeframe)
+                    predicted_return = None
+                    if ret_key and ret_key in predictions:
+                        predicted_return = float(predictions.get(ret_key))
+                    predicted_price = current_price * (1 + predicted_return) if (predicted_return is not None and current_price) else current_price

                     return {
                         'predicted_price': predicted_price,
                         'current_price': current_price,
                         'direction': direction,
                         'confidence': confidence,
+                        'predicted_return': predicted_return,
                         'action': action,
                         'model_output': model_output,
                         'context': context
@@ -260,15 +281,14 @@ class EnhancedRLTrainingAdapter:
             return None

-    async def _get_current_price(self, symbol: str) -> float:
-        """Get current price for a symbol"""
+    def _safe_get_current_price(self, symbol: str) -> float:
+        """Get current price for a symbol via DataProvider API"""
         try:
             if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
-                current_prices = self.orchestrator.data_provider.current_prices
-                return current_prices.get(symbol, 0.0)
+                price = self.orchestrator.data_provider.get_current_price(symbol)
+                return float(price) if price is not None else 0.0
         except Exception as e:
             logger.debug(f"Error getting current price for {symbol}: {e}")
         return 0.0

     def _convert_to_dqn_state(self, base_data: Any, context: InferenceContext) -> np.ndarray:
@@ -433,7 +453,7 @@ class EnhancedRLTrainingAdapter:
             for prediction_record, reward in training_data:
                 # Extract state information
                 # This would need to be adapted based on how states are stored
-                state = np.zeros(100) # Placeholder - you'll need to extract actual state
+                state = np.zeros(100)
                 next_state = state.copy() # Simplified next state

                 # Convert direction to action
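
The last hunk keeps a zero-vector placeholder state, leaving open how stored predictions become DQN training transitions. A hedged sketch of one way to do that follows; it assumes each prediction_record is a dict carrying the 'model_state' and 'direction' fields returned earlier in the adapter, and the build_transitions helper and terminal flag are illustrative, not the repository's confirmed interface.

# Hedged sketch: turn (prediction_record, reward) pairs into DQN transitions.
# Assumes prediction_record is a dict with 'model_state' and 'direction' keys.
import numpy as np

def build_transitions(training_data, state_size=100):
    transitions = []
    for prediction_record, reward in training_data:
        # Prefer a stored model state over the zero placeholder used above
        state = np.asarray(prediction_record.get('model_state', np.zeros(state_size)), dtype=np.float32)
        next_state = state.copy()  # simplified: no lookahead state available
        # Map direction -1/0/1 back to action index 0/1/2 (SELL/HOLD/BUY)
        action = int(prediction_record.get('direction', 0)) + 1
        transitions.append((state, action, float(reward), next_state, True))
    return transitions
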