Optional numeric return head (predicts percent change for 1s,1m,1h,1d)
@@ -66,6 +66,15 @@ class StandardizedCNN(nn.Module):
         # Output processing layers
         self.output_processor = self._build_output_processor()
 
+        # Optional numeric return head (predicts percent change for 1s,1m,1h,1d)
+        # Uses cnn_features (1024) to regress predicted returns per timeframe
+        self.return_head = nn.Sequential(
+            nn.Linear(1024, 256),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(256, 4)  # [return_1s, return_1m, return_1h, return_1d]
+        )
+
         # Device management
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.to(self.device)
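For orientation, a minimal standalone sketch of the new head's shape contract (dimensions taken from the hunk above; the batch size is illustrative):

```python
import torch
import torch.nn as nn

# Mirror of the return head added above, isolated from the full model
return_head = nn.Sequential(
    nn.Linear(1024, 256),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(256, 4),  # [return_1s, return_1m, return_1h, return_1d]
)

cnn_features = torch.randn(8, 1024)        # hypothetical batch of pooled CNN features
predicted_returns = return_head(cnn_features)
print(predicted_returns.shape)             # torch.Size([8, 4]): one fractional return per timeframe
```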
@@ -175,6 +184,9 @@ class StandardizedCNN(nn.Module):
         # Process outputs for standardized format
         action_probs = self.output_processor(cnn_features)  # [batch, 3]
 
+        # Predict numeric returns per timeframe from cnn_features
+        predicted_returns = self.return_head(cnn_features)  # [batch, 4]
+
         # Prepare hidden states for cross-model feeding
         hidden_states = {
             'processed_features': processed_features.detach(),
@@ -186,7 +198,7 @@ class StandardizedCNN(nn.Module):
             'attention_weights': torch.ones(batch_size, 1, device=x.device)  # Placeholder
         }
 
-        return action_probs, hidden_states
+        return action_probs, hidden_states, predicted_returns.detach()
 
     def predict_from_base_input(self, base_input: BaseDataInput) -> ModelOutput:
         """
@@ -210,7 +222,7 @@ class StandardizedCNN(nn.Module):
 
         with torch.no_grad():
             # Forward pass
-            action_probs, hidden_states = self.forward(input_tensor)
+            action_probs, hidden_states, predicted_returns = self.forward(input_tensor)
 
             # Get action and confidence
             action_probs_np = action_probs.squeeze(0).cpu().numpy()
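Note that `forward` now returns a 3-tuple, so any caller still unpacking two values will raise at runtime. The diff updates `predict_from_base_input` accordingly; for other call sites, a defensive unpacking helper is one option (a sketch under the assumption that both shapes may still be encountered):

```python
def unpack_forward(result):
    """Accept both the old (probs, hidden) and new (probs, hidden, returns) result shapes."""
    if len(result) == 3:
        return result
    action_probs, hidden_states = result
    return action_probs, hidden_states, None
```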
@@ -233,6 +245,19 @@ class StandardizedCNN(nn.Module):
                 'market_conditions': self._interpret_advanced_predictions(hidden_states.get('advanced_predictions'))
             }
 
+            # Add numeric predicted returns per timeframe if available
+            try:
+                pr = predicted_returns.squeeze(0).cpu().numpy().tolist()
+                # Ensure length 4; if not, safely handle
+                if isinstance(pr, list) and len(pr) >= 4:
+                    predictions['predicted_returns'] = pr[:4]
+                    predictions['predicted_return_1s'] = float(pr[0])
+                    predictions['predicted_return_1m'] = float(pr[1])
+                    predictions['predicted_return_1h'] = float(pr[2])
+                    predictions['predicted_return_1d'] = float(pr[3])
+            except Exception:
+                pass
+
             # Prepare hidden states for cross-model feeding (convert tensors to numpy)
             cross_model_states = {}
             for key, tensor in hidden_states.items():
@@ -111,3 +111,9 @@ THINK REALY HARD
 
 
+do we evaluate and reward/punish each model at each inference?
+
+in our realtime Reinforcement learning training, how do we calculate the score (reward/penalty)?
+Let's use the mean squared difference between the prediction and the empirical outcome. We should do a training run at each inference, using the last inference's prediction and the current price as the outcome. Do that for up to the last 6 predictions, calculating accuracy separately for each, to get a better picture of the ability to predict a couple of timeframes into the future. In addition to the frequent inference every 1 or 5s (I forgot the current CNN rate), do an inference at each new timeframe interval. The model should get the full data (multi-timeframe: 1s, 1m, 1h, 1d for ETH, the main symbol, plus 1m for BTC, SPX and one more) but should also know what timeframe it is predicting on. We predict only on the main symbol, so in 4 timeframes; but on every hour we will do 4 inferences, one for each timeframe.
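A minimal sketch of the scoring described above, keeping a per-timeframe window of the last 6 predictions; the class and method names are illustrative, not the repo's actual API:

```python
from collections import deque

MAX_HISTORY = 6  # score up to the 6 most recent predictions, as noted above

class ReturnRewardTracker:
    def __init__(self, timeframes=("1s", "1m", "1h", "1d")):
        # One window of (predicted_return, price_at_prediction) per timeframe
        self.history = {tf: deque(maxlen=MAX_HISTORY) for tf in timeframes}

    def record(self, timeframe: str, predicted_return: float, price: float) -> None:
        self.history[timeframe].append((predicted_return, price))

    def rewards(self, timeframe: str, current_price: float) -> list:
        """Negative squared error between each retained predicted return and
        the return actually realized since that prediction (higher is better)."""
        scores = []
        for predicted_return, entry_price in self.history[timeframe]:
            realized_return = (current_price - entry_price) / entry_price
            scores.append(-(predicted_return - realized_return) ** 2)
        return scores

tracker = ReturnRewardTracker()
tracker.record("1m", 0.002, 2500.0)   # predicted +0.2% one minute ago
print(tracker.rewards("1m", 2502.5))  # realized +0.1% -> [-1e-06]
```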
@@ -1268,15 +1268,12 @@ class DataProvider:
                 logger.debug(f"No valid candles generated for {symbol}")
                 return None
 
-            # Convert to DataFrame (timestamps remain UTC tz-aware)
+            # Convert to DataFrame and normalize timestamps to UTC tz-aware
             df = pd.DataFrame(candles)
-            # Ensure timestamps are timezone-aware (UTC to match COB WebSocket data)
             if not df.empty and 'timestamp' in df.columns:
-                # Normalize to UTC tz-aware using pandas idioms
-                if df['timestamp'].dt.tz is None:
-                    df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
-                else:
-                    df['timestamp'] = df['timestamp'].dt.tz_convert('UTC')
+                # Coerce to datetime with UTC; avoid .dt on non-datetimelike
+                df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True, errors='coerce')
+                df = df.dropna(subset=['timestamp'])
 
             df = df.sort_values('timestamp').reset_index(drop=True)
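As a side note, the coerce-based normalization handles naive, tz-aware, and malformed timestamps in a single pass instead of branching on `.dt.tz` (which raises on non-datetimelike columns). A small demonstration with illustrative data:

```python
import pandas as pd

raw = pd.DataFrame({"timestamp": [
    "2024-01-01 00:00:00",        # naive
    "2024-01-01 00:00:01+00:00",  # tz-aware
    "not-a-date",                 # malformed -> NaT
]})

raw["timestamp"] = pd.to_datetime(raw["timestamp"], utc=True, errors="coerce")
raw = raw.dropna(subset=["timestamp"])  # drop the NaT row
print(raw["timestamp"].dt.tz)           # UTC
```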
@@ -140,6 +140,7 @@ class EnhancedRewardCalculator:
                        symbol: str,
                        timeframe: TimeFrame,
                        predicted_price: float,
                        predicted_direction: int,
                        confidence: float,
                        current_price: float,
+                       predicted_return: Optional[float] = None,
@@ -171,6 +172,14 @@ class EnhancedRewardCalculator:
             model_name=model_name
         )
 
+        # If predicted_return provided, prefer computing implied predicted_price
+        # to avoid synthetic price fabrication
+        try:
+            if predicted_return is not None and current_price > 0:
+                prediction.predicted_price = current_price * (1.0 + float(predicted_return))
+        except Exception:
+            pass
+
         # Store prediction
         if symbol not in self.predictions:
             self._initialize_data_structures()
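For intuition, the implied predicted price is just the fractional return applied to the current price; a worked example with illustrative numbers:

```python
current_price = 2500.00
predicted_return = 0.004  # model predicts +0.4% over the target timeframe
predicted_price = current_price * (1.0 + predicted_return)
print(predicted_price)    # 2510.0
```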
@@ -136,8 +136,8 @@ class EnhancedRewardSystemIntegration:
         """Get current price for a symbol"""
         try:
             if hasattr(self.orchestrator, 'data_provider'):
-                current_prices = self.orchestrator.data_provider.current_prices
-                return current_prices.get(symbol, 0.0)
+                price = self.orchestrator.data_provider.get_current_price(symbol)
+                return float(price) if price is not None else 0.0
         except Exception as e:
             logger.debug(f"Error getting current price for {symbol}: {e}")
@@ -135,20 +135,24 @@ class EnhancedRLTrainingAdapter:
                 # Run DQN prediction
                 if hasattr(self.orchestrator.rl_agent, 'act'):
                     action_idx = self.orchestrator.rl_agent.act(state)
-                    confidence = 0.7  # Default confidence for DQN
+                    # Try to extract confidence from agent if available
+                    confidence = getattr(self.orchestrator.rl_agent, 'last_confidence', None)
+                    if confidence is None:
+                        confidence = 0.5
 
                     # Convert action to prediction format
                     action_names = ['SELL', 'HOLD', 'BUY']
                     direction = action_idx - 1  # Convert 0,1,2 to -1,0,1
 
-                    current_price = base_data.get('current_price', 0.0)
-                    predicted_price = current_price * (1 + (direction * 0.001))  # Small price prediction
+                    # Use real current price
+                    current_price = self._safe_get_current_price(context.symbol)
+
+                    # Do not fabricate price; set predicted_price only if model provides numeric target later
                     return {
-                        'predicted_price': predicted_price,
+                        'predicted_price': current_price,  # same as current when no numeric target available
                         'current_price': current_price,
                         'direction': direction,
-                        'confidence': confidence,
+                        'confidence': float(confidence),
                         'action': action_names[action_idx],
                         'model_state': state,
                         'context': context
@@ -174,8 +178,10 @@ class EnhancedRLTrainingAdapter:
                 prediction = self.orchestrator.realtime_rl_trader._predict(context.symbol, features)
 
                 if prediction:
-                    current_price = await self._get_current_price(context.symbol)
-                    predicted_price = current_price * (1 + prediction.get('change', 0))
+                    current_price = self._safe_get_current_price(context.symbol)
+                    # If 'change' is available assume it is a fractional return
+                    change = prediction.get('change', None)
+                    predicted_price = current_price * (1 + change) if (change is not None and current_price) else current_price
 
                     return {
                         'predicted_price': predicted_price,
@@ -207,22 +213,37 @@ class EnhancedRLTrainingAdapter:
                 if hasattr(model, 'predict_from_base_input'):
                     model_output = model.predict_from_base_input(base_data)
 
-                    current_price = base_data.get('current_price', 0.0)
+                    # Extract current price from data provider
+                    current_price = self._safe_get_current_price(context.symbol)
 
                     # Extract prediction data
                     predictions = model_output.predictions
                     action = predictions.get('action', 'HOLD')
                     confidence = predictions.get('confidence', 0.0)
 
-                    # Convert action to direction
+                    # Convert action to direction only for classification signal
                     direction = {'BUY': 1, 'SELL': -1, 'HOLD': 0}.get(action, 0)
-                    predicted_price = current_price * (1 + (direction * 0.002))
 
+                    # Use numeric predicted return if provided (no synthetic fabrication)
+                    pr_map = {
+                        TimeFrame.SECONDS_1: 'predicted_return_1s',
+                        TimeFrame.MINUTES_1: 'predicted_return_1m',
+                        TimeFrame.HOURS_1: 'predicted_return_1h',
+                        TimeFrame.DAYS_1: 'predicted_return_1d',
+                    }
+                    ret_key = pr_map.get(context.target_timeframe)
+                    predicted_return = None
+                    if ret_key and ret_key in predictions:
+                        predicted_return = float(predictions.get(ret_key))
+
+                    predicted_price = current_price * (1 + predicted_return) if (predicted_return is not None and current_price) else current_price
+
                     return {
                         'predicted_price': predicted_price,
                         'current_price': current_price,
                         'direction': direction,
                         'confidence': confidence,
+                        'predicted_return': predicted_return,
                         'action': action,
                         'model_output': model_output,
                         'context': context
@@ -260,15 +281,14 @@ class EnhancedRLTrainingAdapter:
 
         return None
 
-    async def _get_current_price(self, symbol: str) -> float:
-        """Get current price for a symbol"""
+    def _safe_get_current_price(self, symbol: str) -> float:
+        """Get current price for a symbol via DataProvider API"""
         try:
             if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
-                current_prices = self.orchestrator.data_provider.current_prices
-                return current_prices.get(symbol, 0.0)
+                price = self.orchestrator.data_provider.get_current_price(symbol)
+                return float(price) if price is not None else 0.0
         except Exception as e:
             logger.debug(f"Error getting current price for {symbol}: {e}")
 
         return 0.0
 
     def _convert_to_dqn_state(self, base_data: Any, context: InferenceContext) -> np.ndarray:
@@ -433,7 +453,7 @@ class EnhancedRLTrainingAdapter:
             for prediction_record, reward in training_data:
                 # Extract state information
                 # This would need to be adapted based on how states are stored
-                state = np.zeros(100)  # Placeholder - you'll need to extract actual state
+                state = np.zeros(100)
                 next_state = state.copy()  # Simplified next state
 
                 # Convert direction to action
@@ -6899,6 +6899,16 @@ class TradingOrchestrator:
                 if hasattr(self.enhanced_training_system, "start_training"):
                     self.enhanced_training_system.start_training()
                     logger.info("Enhanced real-time training started")
+
+                    # Start Enhanced Reward System integration
+                    try:
+                        from core.enhanced_reward_system_integration import start_enhanced_rewards_for_orchestrator
+                        # Fire and forget task to start integration
+                        import asyncio as _asyncio
+                        _asyncio.create_task(start_enhanced_rewards_for_orchestrator(self, symbols=[self.symbol] + self.ref_symbols))
+                        logger.info("Enhanced reward system started")
+                    except Exception as e:
+                        logger.error(f"Error starting enhanced reward system: {e}")
                     return True
                 else:
                     logger.warning(
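One caveat with the fire-and-forget pattern above: `asyncio.create_task` raises `RuntimeError` when no event loop is running, so if this branch can be reached from synchronous code, a guarded variant may be safer (a sketch, not the repo's code):

```python
import asyncio

def schedule_fire_and_forget(coro) -> None:
    """Schedule a coroutine on the running loop, or run it directly if there is none."""
    try:
        asyncio.get_running_loop().create_task(coro)
    except RuntimeError:
        # No running loop (synchronous caller): run the coroutine to completion
        asyncio.run(coro)
```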
@@ -215,7 +215,7 @@ class TimeframeInferenceCoordinator:
                 await asyncio.sleep(1.0)  # Wait longer on error
 
     async def _hourly_inference_scheduler(self):
-        """Scheduler for hourly multi-timeframe inference"""
+        """Scheduler for hourly multi-timeframe inference and timeframe-boundary triggers"""
         logger.info("Starting hourly inference scheduler")
 
         while self.running:
@@ -232,6 +232,17 @@ class TimeframeInferenceCoordinator:
                         self.next_hourly_inference[symbol] = next_hour
                         self.last_hourly_inference[symbol] = current_time
 
+                    # Trigger at each new timeframe boundary: 1m, 1h, 1d
+                    if current_time.second == 0:
+                        # New minute
+                        await self._execute_boundary_inference(symbol, current_time, TimeFrame.MINUTES_1)
+                    if current_time.minute == 0 and current_time.second == 0:
+                        # New hour
+                        await self._execute_boundary_inference(symbol, current_time, TimeFrame.HOURS_1)
+                    if current_time.hour == 0 and current_time.minute == 0 and current_time.second == 0:
+                        # New day
+                        await self._execute_boundary_inference(symbol, current_time, TimeFrame.DAYS_1)
+
                 # Sleep for 30 seconds between checks
                 await asyncio.sleep(30)
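With a 30-second sleep between checks, `current_time.second == 0` only matches when a wake-up happens to land exactly on a boundary, so most minute boundaries will be missed. A period-tracking alternative fires once per new period regardless of wake-up phase (an illustrative sketch, not the repo's code):

```python
from datetime import datetime, timezone

_last_period = {}  # (symbol, timeframe) -> last period string seen

def new_boundaries(symbol: str, now: datetime) -> list:
    """Return timeframes whose period changed since the previous check."""
    periods = {
        "1m": now.strftime("%Y-%m-%d %H:%M"),
        "1h": now.strftime("%Y-%m-%d %H"),
        "1d": now.strftime("%Y-%m-%d"),
    }
    fired = []
    for tf, period in periods.items():
        if _last_period.get((symbol, tf)) != period:
            _last_period[(symbol, tf)] = period
            fired.append(tf)
    return fired

# Polling every 30s still catches each minute boundary exactly once
print(new_boundaries("ETH/USDT", datetime.now(timezone.utc)))
```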
@@ -239,6 +250,21 @@ class TimeframeInferenceCoordinator:
                 logger.error(f"Error in hourly inference scheduler: {e}")
                 await asyncio.sleep(60)  # Wait longer on error
 
+    async def _execute_boundary_inference(self, symbol: str, timestamp: datetime, timeframe: TimeFrame):
+        """Execute an inference exactly at timeframe boundary"""
+        try:
+            context = InferenceContext(
+                symbol=symbol,
+                timeframe=timeframe,
+                timestamp=timestamp,
+                target_timeframe=timeframe,
+                is_hourly_inference=False,
+                inference_type="boundary"
+            )
+            await self._execute_inference(context)
+        except Exception as e:
+            logger.debug(f"Boundary inference error for {symbol} {timeframe.value}: {e}")
+
     async def _execute_hourly_inference(self, symbol: str, timestamp: datetime):
         """
         Execute hourly multi-timeframe inference for a symbol
@@ -327,6 +353,7 @@ class TimeframeInferenceCoordinator:
         try:
             # Update price cache if data provider available
             if self.data_provider:
+                # DataProvider.get_current_price is synchronous; do not await
                 await self._update_price_cache()
 
             # Evaluate predictions and get training data
@@ -352,7 +379,7 @@ class TimeframeInferenceCoordinator:
         for symbol in self.symbols:
             # Get current price from data provider
             if hasattr(self.data_provider, 'get_current_price'):
-                current_price = await self.data_provider.get_current_price(symbol)
+                current_price = self.data_provider.get_current_price(symbol)
                 if current_price:
                     self.reward_calculator.update_price(symbol, current_price)
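Since `get_current_price` is synchronous, calling it from the async loop blocks the event loop for as long as the lookup takes. If the lookup can hit the network, offloading it keeps the scheduler responsive (a sketch using the standard library, assuming Python 3.9+):

```python
import asyncio

async def fetch_price_nonblocking(data_provider, symbol: str) -> float:
    # Run the synchronous lookup in a worker thread so the event loop stays free
    price = await asyncio.to_thread(data_provider.get_current_price, symbol)
    return float(price) if price is not None else 0.0
```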