Optional numeric return head (predicts percent change for 1s,1m,1h,1d)
@@ -1268,15 +1268,12 @@ class DataProvider:
                 logger.debug(f"No valid candles generated for {symbol}")
                 return None

-            # Convert to DataFrame (timestamps remain UTC tz-aware)
+            # Convert to DataFrame and normalize timestamps to UTC tz-aware
             df = pd.DataFrame(candles)
-            # Ensure timestamps are timezone-aware (UTC to match COB WebSocket data)
             if not df.empty and 'timestamp' in df.columns:
-                # Normalize to UTC tz-aware using pandas idioms
-                if df['timestamp'].dt.tz is None:
-                    df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
-                else:
-                    df['timestamp'] = df['timestamp'].dt.tz_convert('UTC')
+                # Coerce to datetime with UTC; avoid .dt on non-datetimelike
+                df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True, errors='coerce')
+                df = df.dropna(subset=['timestamp'])

             df = df.sort_values('timestamp').reset_index(drop=True)
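The single coerce-based path works because pd.to_datetime(..., utc=True, errors='coerce') handles naive values, tz-aware values, and junk in one pass, so the .dt accessor is never applied to a non-datetimelike column. A minimal standalone sketch of that behavior (column contents are illustrative):

    import pandas as pd

    df = pd.DataFrame({'timestamp': ['2024-01-01 00:00:00',        # naive -> localized to UTC
                                     '2024-01-01 00:00:00+02:00',  # tz-aware -> converted to UTC
                                     'not-a-date']})               # unparseable -> NaT

    # errors='coerce' yields NaT instead of raising, so after dropna the
    # column is a proper datetime64[ns, UTC] series
    df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True, errors='coerce')
    df = df.dropna(subset=['timestamp'])
    print(df['timestamp'].dt.tz)  # UTC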
@@ -140,6 +140,7 @@ class EnhancedRewardCalculator:
         symbol: str,
         timeframe: TimeFrame,
         predicted_price: float,
         predicted_direction: int,
         confidence: float,
         current_price: float,
+        predicted_return: Optional[float] = None,
@@ -171,6 +172,14 @@ class EnhancedRewardCalculator:
             model_name=model_name
         )

+        # If predicted_return provided, prefer computing implied predicted_price
+        # to avoid synthetic price fabrication
+        try:
+            if predicted_return is not None and current_price > 0:
+                prediction.predicted_price = current_price * (1.0 + float(predicted_return))
+        except Exception:
+            pass
+
         # Store prediction
         if symbol not in self.predictions:
             self._initialize_data_structures()
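Since the calculator now derives the stored price from the return when both are supplied, the arithmetic is worth spelling out. A worked example, assuming predicted_return is a fraction (0.002 = +0.2%) rather than a percentage:

    current_price = 2500.0      # illustrative quote
    predicted_return = 0.002    # numeric return head output: +0.2%

    # Same arithmetic as the try-block above
    predicted_price = current_price * (1.0 + float(predicted_return))
    assert abs(predicted_price - 2505.0) < 1e-9

    # A negative return implies a price below current
    assert abs(current_price * (1.0 + (-0.001)) - 2497.5) < 1e-9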
@@ -136,8 +136,8 @@ class EnhancedRewardSystemIntegration:
         """Get current price for a symbol"""
         try:
             if hasattr(self.orchestrator, 'data_provider'):
-                current_prices = self.orchestrator.data_provider.current_prices
-                return current_prices.get(symbol, 0.0)
+                price = self.orchestrator.data_provider.get_current_price(symbol)
+                return float(price) if price is not None else 0.0
         except Exception as e:
             logger.debug(f"Error getting current price for {symbol}: {e}")
@@ -135,20 +135,24 @@ class EnhancedRLTrainingAdapter:
                 # Run DQN prediction
                 if hasattr(self.orchestrator.rl_agent, 'act'):
                     action_idx = self.orchestrator.rl_agent.act(state)
-                    confidence = 0.7  # Default confidence for DQN
+                    # Try to extract confidence from agent if available
+                    confidence = getattr(self.orchestrator.rl_agent, 'last_confidence', None)
+                    if confidence is None:
+                        confidence = 0.5

                     # Convert action to prediction format
                     action_names = ['SELL', 'HOLD', 'BUY']
                     direction = action_idx - 1  # Convert 0,1,2 to -1,0,1

-                    current_price = base_data.get('current_price', 0.0)
-                    predicted_price = current_price * (1 + (direction * 0.001))  # Small price prediction
+                    # Use real current price
+                    current_price = self._safe_get_current_price(context.symbol)
+
+                    # Do not fabricate price; set predicted_price only if model provides numeric target later
                     return {
-                        'predicted_price': predicted_price,
+                        'predicted_price': current_price,  # same as current when no numeric target available
                         'current_price': current_price,
                         'direction': direction,
-                        'confidence': confidence,
+                        'confidence': float(confidence),
                         'action': action_names[action_idx],
                         'model_state': state,
                         'context': context
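The direction arithmetic depends on the DQN's action ordering; a quick sketch confirming the index-to-direction mapping and the last_confidence fallback (StubAgent is a stand-in, not the real agent class):

    action_names = ['SELL', 'HOLD', 'BUY']
    for action_idx in (0, 1, 2):
        direction = action_idx - 1  # 0,1,2 -> -1,0,+1
        print(action_names[action_idx], direction)  # SELL -1, HOLD 0, BUY 1

    class StubAgent:
        pass  # exposes no last_confidence attribute

    confidence = getattr(StubAgent(), 'last_confidence', None)
    if confidence is None:
        confidence = 0.5  # neutral default when the agent reports nothing
    assert float(confidence) == 0.5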
@@ -174,8 +178,10 @@ class EnhancedRLTrainingAdapter:
                 prediction = self.orchestrator.realtime_rl_trader._predict(context.symbol, features)

                 if prediction:
-                    current_price = await self._get_current_price(context.symbol)
-                    predicted_price = current_price * (1 + prediction.get('change', 0))
+                    current_price = self._safe_get_current_price(context.symbol)
+                    # If 'change' is available assume it is a fractional return
+                    change = prediction.get('change', None)
+                    predicted_price = current_price * (1 + change) if (change is not None and current_price) else current_price

                     return {
                         'predicted_price': predicted_price,
@@ -207,22 +213,37 @@ class EnhancedRLTrainingAdapter:
                 if hasattr(model, 'predict_from_base_input'):
                     model_output = model.predict_from_base_input(base_data)

-                    current_price = base_data.get('current_price', 0.0)
+                    # Extract current price from data provider
+                    current_price = self._safe_get_current_price(context.symbol)

                     # Extract prediction data
                     predictions = model_output.predictions
                     action = predictions.get('action', 'HOLD')
                     confidence = predictions.get('confidence', 0.0)

-                    # Convert action to direction
+                    # Convert action to direction only for classification signal
                     direction = {'BUY': 1, 'SELL': -1, 'HOLD': 0}.get(action, 0)
-                    predicted_price = current_price * (1 + (direction * 0.002))
+
+                    # Use numeric predicted return if provided (no synthetic fabrication)
+                    pr_map = {
+                        TimeFrame.SECONDS_1: 'predicted_return_1s',
+                        TimeFrame.MINUTES_1: 'predicted_return_1m',
+                        TimeFrame.HOURS_1: 'predicted_return_1h',
+                        TimeFrame.DAYS_1: 'predicted_return_1d',
+                    }
+                    ret_key = pr_map.get(context.target_timeframe)
+                    predicted_return = None
+                    if ret_key and ret_key in predictions:
+                        predicted_return = float(predictions.get(ret_key))
+
+                    predicted_price = current_price * (1 + predicted_return) if (predicted_return is not None and current_price) else current_price

                     return {
                         'predicted_price': predicted_price,
                         'current_price': current_price,
                         'direction': direction,
                         'confidence': confidence,
+                        'predicted_return': predicted_return,
                         'action': action,
                         'model_output': model_output,
                         'context': context
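The pr_map lookup is what makes the numeric head optional per horizon: a missing key leaves predicted_return as None and pins predicted_price to current_price. A self-contained sketch with a stand-in TimeFrame enum and a predictions dict shaped like model_output.predictions:

    from enum import Enum

    class TimeFrame(Enum):  # stand-in mirroring the names used in the diff
        SECONDS_1 = '1s'
        MINUTES_1 = '1m'
        HOURS_1 = '1h'
        DAYS_1 = '1d'

    pr_map = {
        TimeFrame.SECONDS_1: 'predicted_return_1s',
        TimeFrame.MINUTES_1: 'predicted_return_1m',
        TimeFrame.HOURS_1: 'predicted_return_1h',
        TimeFrame.DAYS_1: 'predicted_return_1d',
    }

    predictions = {'action': 'BUY', 'confidence': 0.8, 'predicted_return_1m': 0.0015}
    current_price = 2500.0

    for tf in (TimeFrame.MINUTES_1, TimeFrame.DAYS_1):
        ret_key = pr_map.get(tf)
        predicted_return = float(predictions[ret_key]) if ret_key in predictions else None
        predicted_price = current_price * (1 + predicted_return) if predicted_return is not None else current_price
        print(tf.value, predicted_return, predicted_price)
    # 1m 0.0015 2503.75 ; 1d None 2500.0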
@@ -260,15 +281,14 @@ class EnhancedRLTrainingAdapter:

         return None

-    async def _get_current_price(self, symbol: str) -> float:
-        """Get current price for a symbol"""
+    def _safe_get_current_price(self, symbol: str) -> float:
+        """Get current price for a symbol via DataProvider API"""
         try:
             if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
-                current_prices = self.orchestrator.data_provider.current_prices
-                return current_prices.get(symbol, 0.0)
+                price = self.orchestrator.data_provider.get_current_price(symbol)
+                return float(price) if price is not None else 0.0
         except Exception as e:
             logger.debug(f"Error getting current price for {symbol}: {e}")

         return 0.0

     def _convert_to_dqn_state(self, base_data: Any, context: InferenceContext) -> np.ndarray:
@@ -433,7 +453,7 @@ class EnhancedRLTrainingAdapter:
                 for prediction_record, reward in training_data:
                     # Extract state information
                     # This would need to be adapted based on how states are stored
-                    state = np.zeros(100)  # Placeholder - you'll need to extract actual state
+                    state = np.zeros(100)
                     next_state = state.copy()  # Simplified next state

                     # Convert direction to action
@@ -6899,6 +6899,16 @@ class TradingOrchestrator:
                 if hasattr(self.enhanced_training_system, "start_training"):
                     self.enhanced_training_system.start_training()
                     logger.info("Enhanced real-time training started")

+                # Start Enhanced Reward System integration
+                try:
+                    from core.enhanced_reward_system_integration import start_enhanced_rewards_for_orchestrator
+                    # Fire and forget task to start integration
+                    import asyncio as _asyncio
+                    _asyncio.create_task(start_enhanced_rewards_for_orchestrator(self, symbols=[self.symbol] + self.ref_symbols))
+                    logger.info("Enhanced reward system started")
+                except Exception as e:
+                    logger.error(f"Error starting enhanced reward system: {e}")
+
                 return True
             else:
                 logger.warning(
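asyncio.create_task requires an already-running event loop and, per the asyncio documentation, the returned task should be kept referenced so it is not garbage-collected mid-flight. A more defensive variant of the fire-and-forget start, sketched under the assumption that an orchestrator._bg_tasks set is acceptable (neither the helper name nor that attribute is in the diff):

    import asyncio
    import logging

    logger = logging.getLogger(__name__)

    def start_enhanced_rewards_safely(orchestrator) -> bool:
        """Schedule the reward-system coroutine without blocking startup."""
        from core.enhanced_reward_system_integration import start_enhanced_rewards_for_orchestrator
        try:
            loop = asyncio.get_running_loop()  # raises RuntimeError outside a running loop
        except RuntimeError:
            logger.warning("No running event loop; enhanced rewards not started")
            return False
        symbols = [orchestrator.symbol] + orchestrator.ref_symbols
        task = loop.create_task(start_enhanced_rewards_for_orchestrator(orchestrator, symbols=symbols))
        # Hold a strong reference so the task survives garbage collection
        orchestrator._bg_tasks = getattr(orchestrator, '_bg_tasks', set())
        orchestrator._bg_tasks.add(task)
        task.add_done_callback(orchestrator._bg_tasks.discard)
        return True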
@@ -215,7 +215,7 @@ class TimeframeInferenceCoordinator:
                 await asyncio.sleep(1.0)  # Wait longer on error

     async def _hourly_inference_scheduler(self):
-        """Scheduler for hourly multi-timeframe inference"""
+        """Scheduler for hourly multi-timeframe inference and timeframe-boundary triggers"""
         logger.info("Starting hourly inference scheduler")

         while self.running:
@@ -231,6 +231,17 @@ class TimeframeInferenceCoordinator:
                         next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
                         self.next_hourly_inference[symbol] = next_hour
                         self.last_hourly_inference[symbol] = current_time

+                    # Trigger at each new timeframe boundary: 1m, 1h, 1d
+                    if current_time.second == 0:
+                        # New minute
+                        await self._execute_boundary_inference(symbol, current_time, TimeFrame.MINUTES_1)
+                    if current_time.minute == 0 and current_time.second == 0:
+                        # New hour
+                        await self._execute_boundary_inference(symbol, current_time, TimeFrame.HOURS_1)
+                    if current_time.hour == 0 and current_time.minute == 0 and current_time.second == 0:
+                        # New day
+                        await self._execute_boundary_inference(symbol, current_time, TimeFrame.DAYS_1)
+
                 # Sleep for 30 seconds between checks
                 await asyncio.sleep(30)
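Because the scheduler sleeps 30 seconds between checks, current_time.second == 0 only matches when a wake-up happens to land exactly on a boundary, so most minute boundaries will be skipped. A period-tracking check that triggers exactly once per elapsed period regardless of wake-up jitter (helper name and dict are illustrative, not part of the diff):

    from datetime import datetime, timezone

    _last_period = {}  # timeframe label -> last period index seen

    def crossed_boundary(now: datetime, label: str, seconds: int) -> bool:
        """True once per period of length `seconds`, however late the check runs."""
        period = int(now.timestamp()) // seconds
        if _last_period.get(label) != period:
            _last_period[label] = period
            return True
        return False

    now = datetime.now(timezone.utc)
    for label, seconds in (('1m', 60), ('1h', 3600), ('1d', 86400)):
        if crossed_boundary(now, label, seconds):
            print(f"new {label} boundary at {now}")  # trigger boundary inference here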
@@ -238,6 +249,21 @@ class TimeframeInferenceCoordinator:
             except Exception as e:
                 logger.error(f"Error in hourly inference scheduler: {e}")
                 await asyncio.sleep(60)  # Wait longer on error

+    async def _execute_boundary_inference(self, symbol: str, timestamp: datetime, timeframe: TimeFrame):
+        """Execute an inference exactly at timeframe boundary"""
+        try:
+            context = InferenceContext(
+                symbol=symbol,
+                timeframe=timeframe,
+                timestamp=timestamp,
+                target_timeframe=timeframe,
+                is_hourly_inference=False,
+                inference_type="boundary"
+            )
+            await self._execute_inference(context)
+        except Exception as e:
+            logger.debug(f"Boundary inference error for {symbol} {timeframe.value}: {e}")
+
     async def _execute_hourly_inference(self, symbol: str, timestamp: datetime):
         """
@@ -327,6 +353,7 @@ class TimeframeInferenceCoordinator:
         try:
             # Update price cache if data provider available
             if self.data_provider:
+                # DataProvider.get_current_price is synchronous; do not await
                 await self._update_price_cache()

             # Evaluate predictions and get training data
@@ -352,7 +379,7 @@ class TimeframeInferenceCoordinator:
         for symbol in self.symbols:
             # Get current price from data provider
             if hasattr(self.data_provider, 'get_current_price'):
-                current_price = await self.data_provider.get_current_price(symbol)
+                current_price = self.data_provider.get_current_price(symbol)
                 if current_price:
                     self.reward_calculator.update_price(symbol, current_price)
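Dropping the await matters because awaiting a plain float raises "TypeError: object float can't be used in 'await' expression". If some providers might still expose an async get_current_price, an inspect-based guard tolerates both styles (this guard is an assumption, not part of the diff):

    import asyncio
    import inspect

    async def fetch_price(data_provider, symbol: str) -> float:
        """Call get_current_price whether it is sync or async."""
        result = data_provider.get_current_price(symbol)
        if inspect.isawaitable(result):  # async implementation returns a coroutine
            result = await result
        return float(result) if result is not None else 0.0

    class SyncProvider:  # stand-in for the synchronous DataProvider
        def get_current_price(self, symbol: str) -> float:
            return 2500.0

    print(asyncio.run(fetch_price(SyncProvider(), 'ETH/USDT')))  # 2500.0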