Optional numeric return head (predicts percent change for 1s, 1m, 1h, 1d)

This commit is contained in:
Dobromir Popov
2025-08-23 15:17:04 +03:00
parent 9992b226ea
commit 81749ee18e
8 changed files with 124 additions and 30 deletions

View File

@@ -1268,15 +1268,12 @@ class DataProvider:
logger.debug(f"No valid candles generated for {symbol}")
return None
# Convert to DataFrame (timestamps remain UTC tz-aware)
# Convert to DataFrame and normalize timestamps to UTC tz-aware
df = pd.DataFrame(candles)
# Ensure timestamps are timezone-aware (UTC to match COB WebSocket data)
if not df.empty and 'timestamp' in df.columns:
# Normalize to UTC tz-aware using pandas idioms
if df['timestamp'].dt.tz is None:
df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
else:
df['timestamp'] = df['timestamp'].dt.tz_convert('UTC')
# Coerce to datetime with UTC; avoid .dt on non-datetimelike
df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True, errors='coerce')
df = df.dropna(subset=['timestamp'])
df = df.sort_values('timestamp').reset_index(drop=True)

View File

@@ -140,6 +140,7 @@ class EnhancedRewardCalculator:
symbol: str,
timeframe: TimeFrame,
predicted_price: float,
predicted_return: Optional[float] = None,
predicted_direction: int,
confidence: float,
current_price: float,
@@ -171,6 +172,14 @@ class EnhancedRewardCalculator:
model_name=model_name
)
# If predicted_return provided, prefer computing implied predicted_price
# to avoid synthetic price fabrication
try:
if predicted_return is not None and current_price > 0:
prediction.predicted_price = current_price * (1.0 + float(predicted_return))
except Exception:
pass
# Store prediction
if symbol not in self.predictions:
self._initialize_data_structures()

View File

@@ -136,8 +136,8 @@ class EnhancedRewardSystemIntegration:
"""Get current price for a symbol"""
try:
if hasattr(self.orchestrator, 'data_provider'):
current_prices = self.orchestrator.data_provider.current_prices
return current_prices.get(symbol, 0.0)
price = self.orchestrator.data_provider.get_current_price(symbol)
return float(price) if price is not None else 0.0
except Exception as e:
logger.debug(f"Error getting current price for {symbol}: {e}")

View File

@@ -135,20 +135,24 @@ class EnhancedRLTrainingAdapter:
# Run DQN prediction
if hasattr(self.orchestrator.rl_agent, 'act'):
action_idx = self.orchestrator.rl_agent.act(state)
confidence = 0.7 # Default confidence for DQN
# Try to extract confidence from agent if available
confidence = getattr(self.orchestrator.rl_agent, 'last_confidence', None)
if confidence is None:
confidence = 0.5
# Convert action to prediction format
action_names = ['SELL', 'HOLD', 'BUY']
direction = action_idx - 1 # Convert 0,1,2 to -1,0,1
current_price = base_data.get('current_price', 0.0)
predicted_price = current_price * (1 + (direction * 0.001)) # Small price prediction
# Use real current price
current_price = self._safe_get_current_price(context.symbol)
# Do not fabricate price; set predicted_price only if model provides numeric target later
return {
'predicted_price': predicted_price,
'predicted_price': current_price, # same as current when no numeric target available
'current_price': current_price,
'direction': direction,
'confidence': confidence,
'confidence': float(confidence),
'action': action_names[action_idx],
'model_state': state,
'context': context
@@ -174,8 +178,10 @@ class EnhancedRLTrainingAdapter:
prediction = self.orchestrator.realtime_rl_trader._predict(context.symbol, features)
if prediction:
current_price = await self._get_current_price(context.symbol)
predicted_price = current_price * (1 + prediction.get('change', 0))
current_price = self._safe_get_current_price(context.symbol)
# If 'change' is available assume it is a fractional return
change = prediction.get('change', None)
predicted_price = current_price * (1 + change) if (change is not None and current_price) else current_price
return {
'predicted_price': predicted_price,
@@ -207,22 +213,37 @@ class EnhancedRLTrainingAdapter:
if hasattr(model, 'predict_from_base_input'):
model_output = model.predict_from_base_input(base_data)
current_price = base_data.get('current_price', 0.0)
# Extract current price from data provider
current_price = self._safe_get_current_price(context.symbol)
# Extract prediction data
predictions = model_output.predictions
action = predictions.get('action', 'HOLD')
confidence = predictions.get('confidence', 0.0)
# Convert action to direction
# Convert action to direction only for classification signal
direction = {'BUY': 1, 'SELL': -1, 'HOLD': 0}.get(action, 0)
predicted_price = current_price * (1 + (direction * 0.002))
# Use numeric predicted return if provided (no synthetic fabrication)
pr_map = {
TimeFrame.SECONDS_1: 'predicted_return_1s',
TimeFrame.MINUTES_1: 'predicted_return_1m',
TimeFrame.HOURS_1: 'predicted_return_1h',
TimeFrame.DAYS_1: 'predicted_return_1d',
}
ret_key = pr_map.get(context.target_timeframe)
predicted_return = None
if ret_key and ret_key in predictions:
predicted_return = float(predictions.get(ret_key))
predicted_price = current_price * (1 + predicted_return) if (predicted_return is not None and current_price) else current_price
return {
'predicted_price': predicted_price,
'current_price': current_price,
'direction': direction,
'confidence': confidence,
'predicted_return': predicted_return,
'action': action,
'model_output': model_output,
'context': context
@@ -260,15 +281,14 @@ class EnhancedRLTrainingAdapter:
return None
async def _get_current_price(self, symbol: str) -> float:
"""Get current price for a symbol"""
def _safe_get_current_price(self, symbol: str) -> float:
"""Get current price for a symbol via DataProvider API"""
try:
if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
current_prices = self.orchestrator.data_provider.current_prices
return current_prices.get(symbol, 0.0)
price = self.orchestrator.data_provider.get_current_price(symbol)
return float(price) if price is not None else 0.0
except Exception as e:
logger.debug(f"Error getting current price for {symbol}: {e}")
return 0.0
def _convert_to_dqn_state(self, base_data: Any, context: InferenceContext) -> np.ndarray:
@@ -433,7 +453,7 @@ class EnhancedRLTrainingAdapter:
for prediction_record, reward in training_data:
# Extract state information
# This would need to be adapted based on how states are stored
state = np.zeros(100) # Placeholder - you'll need to extract actual state
state = np.zeros(100)
next_state = state.copy() # Simplified next state
# Convert direction to action

View File

@@ -6899,6 +6899,16 @@ class TradingOrchestrator:
if hasattr(self.enhanced_training_system, "start_training"):
self.enhanced_training_system.start_training()
logger.info("Enhanced real-time training started")
# Start Enhanced Reward System integration
try:
from core.enhanced_reward_system_integration import start_enhanced_rewards_for_orchestrator
# Fire and forget task to start integration
import asyncio as _asyncio
_asyncio.create_task(start_enhanced_rewards_for_orchestrator(self, symbols=[self.symbol] + self.ref_symbols))
logger.info("Enhanced reward system started")
except Exception as e:
logger.error(f"Error starting enhanced reward system: {e}")
return True
else:
logger.warning(

View File

@@ -215,7 +215,7 @@ class TimeframeInferenceCoordinator:
await asyncio.sleep(1.0) # Wait longer on error
async def _hourly_inference_scheduler(self):
"""Scheduler for hourly multi-timeframe inference"""
"""Scheduler for hourly multi-timeframe inference and timeframe-boundary triggers"""
logger.info("Starting hourly inference scheduler")
while self.running:
@@ -231,6 +231,17 @@ class TimeframeInferenceCoordinator:
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
self.next_hourly_inference[symbol] = next_hour
self.last_hourly_inference[symbol] = current_time
# Trigger at each new timeframe boundary: 1m, 1h, 1d
if current_time.second == 0:
# New minute
await self._execute_boundary_inference(symbol, current_time, TimeFrame.MINUTES_1)
if current_time.minute == 0 and current_time.second == 0:
# New hour
await self._execute_boundary_inference(symbol, current_time, TimeFrame.HOURS_1)
if current_time.hour == 0 and current_time.minute == 0 and current_time.second == 0:
# New day
await self._execute_boundary_inference(symbol, current_time, TimeFrame.DAYS_1)
# Sleep for 30 seconds between checks
await asyncio.sleep(30)
@@ -238,6 +249,21 @@ class TimeframeInferenceCoordinator:
except Exception as e:
logger.error(f"Error in hourly inference scheduler: {e}")
await asyncio.sleep(60) # Wait longer on error
async def _execute_boundary_inference(self, symbol: str, timestamp: datetime, timeframe: TimeFrame):
    """Run one inference for ``symbol`` tagged as a timeframe-boundary event.

    An ``InferenceContext`` is built in which both ``timeframe`` and
    ``target_timeframe`` are set to ``timeframe``, ``is_hourly_inference``
    is ``False`` and ``inference_type`` is ``"boundary"``; the context is
    then passed to ``self._execute_inference``.

    Args:
        symbol: Symbol the inference is run for.
        timestamp: Timestamp recorded on the inference context.
        timeframe: Timeframe used as both source and target of the inference.

    Any exception raised while building the context or running the
    inference is caught and logged at debug level; nothing is re-raised.
    """
    try:
        # Source and target timeframe are deliberately the same value here.
        context_fields = {
            "symbol": symbol,
            "timeframe": timeframe,
            "timestamp": timestamp,
            "target_timeframe": timeframe,
            "is_hourly_inference": False,
            "inference_type": "boundary",
        }
        boundary_context = InferenceContext(**context_fields)
        await self._execute_inference(boundary_context)
    except Exception as exc:
        # Best-effort: a failed boundary inference must not stop the caller.
        logger.debug(f"Boundary inference error for {symbol} {timeframe.value}: {exc}")
async def _execute_hourly_inference(self, symbol: str, timestamp: datetime):
"""
@@ -327,6 +353,7 @@ class TimeframeInferenceCoordinator:
try:
# Update price cache if data provider available
if self.data_provider:
# DataProvider.get_current_price is synchronous; do not await
await self._update_price_cache()
# Evaluate predictions and get training data
@@ -352,7 +379,7 @@ class TimeframeInferenceCoordinator:
for symbol in self.symbols:
# Get current price from data provider
if hasattr(self.data_provider, 'get_current_price'):
current_price = await self.data_provider.get_current_price(symbol)
current_price = self.data_provider.get_current_price(symbol)
if current_price:
self.reward_calculator.update_price(symbol, current_price)