predict price direction
This commit is contained in:
@ -719,6 +719,95 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing ML models: {e}")
|
||||
|
||||
def _calculate_cnn_price_direction_loss(self, price_direction_pred: torch.Tensor, rewards: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate price direction loss for CNN model
|
||||
|
||||
Args:
|
||||
price_direction_pred: Tensor of shape [batch, 2] containing [direction, confidence]
|
||||
rewards: Tensor of shape [batch] containing rewards
|
||||
actions: Tensor of shape [batch] containing actions
|
||||
|
||||
Returns:
|
||||
Price direction loss tensor
|
||||
"""
|
||||
try:
|
||||
if price_direction_pred.size(1) != 2:
|
||||
return None
|
||||
|
||||
batch_size = price_direction_pred.size(0)
|
||||
|
||||
# Extract direction and confidence predictions
|
||||
direction_pred = price_direction_pred[:, 0] # -1 to 1
|
||||
confidence_pred = price_direction_pred[:, 1] # 0 to 1
|
||||
|
||||
# Create targets based on rewards and actions
|
||||
with torch.no_grad():
|
||||
# Direction targets: 1 if reward > 0 and action is BUY, -1 if reward > 0 and action is SELL, 0 otherwise
|
||||
direction_targets = torch.zeros(batch_size, device=price_direction_pred.device)
|
||||
for i in range(batch_size):
|
||||
if rewards[i] > 0.01: # Positive reward threshold
|
||||
if actions[i] == 0: # BUY action
|
||||
direction_targets[i] = 1.0 # UP
|
||||
elif actions[i] == 1: # SELL action
|
||||
direction_targets[i] = -1.0 # DOWN
|
||||
# else: targets remain 0 (sideways)
|
||||
|
||||
# Confidence targets: based on reward magnitude (higher reward = higher confidence)
|
||||
confidence_targets = torch.abs(rewards).clamp(0, 1)
|
||||
|
||||
# Calculate losses for each component
|
||||
direction_loss = nn.MSELoss()(direction_pred, direction_targets)
|
||||
confidence_loss = nn.MSELoss()(confidence_pred, confidence_targets)
|
||||
|
||||
# Combined loss (direction is more important than confidence)
|
||||
total_loss = direction_loss + 0.3 * confidence_loss
|
||||
|
||||
return total_loss
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error calculating CNN price direction loss: {e}")
|
||||
return None
|
||||
|
||||
def _calculate_cnn_extrema_loss(self, extrema_pred: torch.Tensor, rewards: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate extrema loss for CNN model
|
||||
|
||||
Args:
|
||||
extrema_pred: Extrema predictions
|
||||
rewards: Tensor containing rewards
|
||||
actions: Tensor containing actions
|
||||
|
||||
Returns:
|
||||
Extrema loss tensor
|
||||
"""
|
||||
try:
|
||||
batch_size = extrema_pred.size(0)
|
||||
|
||||
# Create targets based on reward patterns
|
||||
with torch.no_grad():
|
||||
extrema_targets = torch.ones(batch_size, dtype=torch.long, device=extrema_pred.device) * 2 # Default to "neither"
|
||||
|
||||
for i in range(batch_size):
|
||||
# High positive reward suggests we're at a good entry point
|
||||
if rewards[i] > 0.05:
|
||||
if actions[i] == 0: # BUY action
|
||||
extrema_targets[i] = 0 # Bottom
|
||||
elif actions[i] == 1: # SELL action
|
||||
extrema_targets[i] = 1 # Top
|
||||
|
||||
# Calculate cross-entropy loss
|
||||
if extrema_pred.size(1) >= 3:
|
||||
extrema_loss = nn.CrossEntropyLoss()(extrema_pred[:, :3], extrema_targets)
|
||||
else:
|
||||
extrema_loss = nn.CrossEntropyLoss()(extrema_pred, extrema_targets)
|
||||
|
||||
return extrema_loss
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error calculating CNN extrema loss: {e}")
|
||||
return None
|
||||
|
||||
def update_model_loss(self, model_name: str, current_loss: float, best_loss: Optional[float] = None):
|
||||
"""Update model loss and potentially best loss"""
|
||||
if model_name in self.model_states:
|
||||
@ -1938,7 +2027,71 @@ class TradingOrchestrator:
|
||||
# Evaluate the previous prediction and train the model immediately
|
||||
await self._evaluate_and_train_on_record(inference_record, current_price)
|
||||
|
||||
logger.info(f"Completed immediate training for {model_name}")
|
||||
# Log predicted vs actual outcome
|
||||
prediction = inference_record.get('prediction', {})
|
||||
predicted_action = prediction.get('action', 'UNKNOWN')
|
||||
predicted_confidence = prediction.get('confidence', 0.0)
|
||||
|
||||
# Calculate actual outcome
|
||||
symbol = inference_record.get('symbol', 'ETH/USDT')
|
||||
predicted_price = None
|
||||
actual_price_change_pct = 0.0
|
||||
|
||||
# Try to get price direction vectors from metadata (new format)
|
||||
if 'price_direction' in prediction and prediction['price_direction']:
|
||||
try:
|
||||
price_direction_data = prediction['price_direction']
|
||||
# Process price direction data
|
||||
if isinstance(price_direction_data, dict) and 'direction' in price_direction_data:
|
||||
direction = price_direction_data['direction']
|
||||
confidence = price_direction_data.get('confidence', 1.0)
|
||||
|
||||
# Convert direction to price change percentage
|
||||
# Scale by confidence and direction strength
|
||||
predicted_price_change_pct = direction * confidence * 0.02 # 2% max change
|
||||
predicted_price = current_price * (1 + predicted_price_change_pct)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error processing price direction data: {e}")
|
||||
|
||||
# Fallback to old price prediction format
|
||||
elif 'price_prediction' in prediction and prediction['price_prediction']:
|
||||
try:
|
||||
price_prediction_data = prediction['price_prediction']
|
||||
if isinstance(price_prediction_data, list) and len(price_prediction_data) > 0:
|
||||
predicted_price_change_pct = float(price_prediction_data[0]) * 0.01
|
||||
predicted_price = current_price * (1 + predicted_price_change_pct)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Calculate price change
|
||||
if predicted_price is not None:
|
||||
actual_price_change_pct = (current_price - predicted_price) / predicted_price * 100
|
||||
price_outcome = f"Predicted: ${predicted_price:.2f} -> Actual: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||
else:
|
||||
# Fall back to historical price comparison
|
||||
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
|
||||
if historical_data is not None and not historical_data.empty:
|
||||
historical_price = historical_data['close'].iloc[-1]
|
||||
actual_price_change_pct = (current_price - historical_price) / historical_price * 100
|
||||
price_outcome = f"Historical: ${historical_price:.2f} -> Actual: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
|
||||
else:
|
||||
price_outcome = f"Actual: ${current_price:.2f}"
|
||||
|
||||
# Determine if prediction was correct based on action and price movement
|
||||
was_correct = False
|
||||
if predicted_action == 'BUY' and actual_price_change_pct > 0.1: # Price went up
|
||||
was_correct = True
|
||||
elif predicted_action == 'SELL' and actual_price_change_pct < -0.1: # Price went down
|
||||
was_correct = True
|
||||
elif predicted_action == 'HOLD' and abs(actual_price_change_pct) < 0.5: # Price stayed stable
|
||||
was_correct = True
|
||||
|
||||
outcome_status = "✅ CORRECT" if was_correct else "❌ INCORRECT"
|
||||
|
||||
logger.info(f"Completed immediate training for {model_name} - {outcome_status}")
|
||||
logger.info(f" Prediction: {predicted_action} ({predicted_confidence:.3f})")
|
||||
logger.info(f" {price_outcome}")
|
||||
logger.info(f" Outcome: {outcome_status}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in immediate training for {model_name}: {e}")
|
||||
@ -2412,12 +2565,33 @@ class TradingOrchestrator:
|
||||
self.cnn_optimizer.zero_grad()
|
||||
|
||||
# Forward pass
|
||||
q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.cnn_model(features_tensor)
|
||||
q_values, extrema_pred, price_direction_pred, features_refined, advanced_pred = self.cnn_model(features_tensor)
|
||||
|
||||
# Calculate loss
|
||||
# Calculate primary Q-value loss
|
||||
q_values_selected = q_values.gather(1, action_tensor.unsqueeze(1)).squeeze(1)
|
||||
target_q = reward_tensor # Simplified target
|
||||
loss = nn.MSELoss()(q_values_selected, target_q)
|
||||
q_loss = nn.MSELoss()(q_values_selected, target_q)
|
||||
|
||||
# Calculate auxiliary losses for price direction and extrema
|
||||
total_loss = q_loss
|
||||
|
||||
# Price direction loss
|
||||
if price_direction_pred is not None and price_direction_pred.shape[0] > 0:
|
||||
price_direction_loss = self._calculate_cnn_price_direction_loss(
|
||||
price_direction_pred, reward_tensor, action_tensor
|
||||
)
|
||||
if price_direction_loss is not None:
|
||||
total_loss = total_loss + 0.2 * price_direction_loss
|
||||
|
||||
# Extrema loss
|
||||
if extrema_pred is not None and extrema_pred.shape[0] > 0:
|
||||
extrema_loss = self._calculate_cnn_extrema_loss(
|
||||
extrema_pred, reward_tensor, action_tensor
|
||||
)
|
||||
if extrema_loss is not None:
|
||||
total_loss = total_loss + 0.1 * extrema_loss
|
||||
|
||||
loss = total_loss
|
||||
|
||||
# Backward pass
|
||||
training_start_time = time.time()
|
||||
@ -2640,9 +2814,17 @@ class TradingOrchestrator:
|
||||
'HOLD': float(action_probs[0, 2].item())
|
||||
}
|
||||
|
||||
# Extract price predictions if available
|
||||
price_prediction = None
|
||||
# Extract price direction predictions if available
|
||||
price_direction_data = None
|
||||
if price_pred is not None:
|
||||
# Process price direction predictions
|
||||
if hasattr(model.model, 'process_price_direction_predictions'):
|
||||
try:
|
||||
price_direction_data = model.model.process_price_direction_predictions(price_pred)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error processing CNN price direction: {e}")
|
||||
|
||||
# Fallback to old format for compatibility
|
||||
price_prediction = price_pred.squeeze(0).cpu().numpy().tolist()
|
||||
|
||||
prediction = Prediction(
|
||||
@ -2656,6 +2838,7 @@ class TradingOrchestrator:
|
||||
'feature_size': len(base_data.get_feature_vector()),
|
||||
'data_sources': ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d', 'btc', 'cob', 'indicators'],
|
||||
'price_prediction': price_prediction,
|
||||
'price_direction': price_direction_data,
|
||||
'extrema_prediction': extrema_pred.squeeze(0).cpu().numpy().tolist() if extrema_pred is not None else None
|
||||
}
|
||||
)
|
||||
@ -2694,6 +2877,14 @@ class TradingOrchestrator:
|
||||
action_names = ['BUY', 'SELL', 'HOLD'] # Note: enhanced_cnn uses this order
|
||||
best_action = action_names[action_idx]
|
||||
|
||||
# Get price direction vectors from CNN model if available
|
||||
price_direction_data = None
|
||||
if hasattr(model.model, 'get_price_direction_vector'):
|
||||
try:
|
||||
price_direction_data = model.model.get_price_direction_vector()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting price direction from CNN: {e}")
|
||||
|
||||
pred = Prediction(
|
||||
action=best_action,
|
||||
confidence=float(confidence),
|
||||
@ -2708,7 +2899,8 @@ class TradingOrchestrator:
|
||||
metadata={
|
||||
'feature_vector_size': len(feature_vector),
|
||||
'unified_input': True,
|
||||
'fallback_method': 'direct_model_inference'
|
||||
'fallback_method': 'direct_model_inference',
|
||||
'price_direction': price_direction_data
|
||||
}
|
||||
)
|
||||
predictions.append(pred)
|
||||
@ -2811,6 +3003,14 @@ class TradingOrchestrator:
|
||||
if q_values_for_capture:
|
||||
logger.warning(f"Q-values length mismatch: expected {len(action_names)}, got {len(q_values_for_capture)}. Using default probabilities.")
|
||||
|
||||
# Get price direction vectors from DQN model if available
|
||||
price_direction_data = None
|
||||
if hasattr(model.model, 'get_price_direction_vector'):
|
||||
try:
|
||||
price_direction_data = model.model.get_price_direction_vector()
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting price direction from DQN: {e}")
|
||||
|
||||
prediction = Prediction(
|
||||
action=action,
|
||||
confidence=float(confidence),
|
||||
@ -2818,7 +3018,10 @@ class TradingOrchestrator:
|
||||
timeframe='mixed', # RL uses mixed timeframes
|
||||
timestamp=datetime.now(),
|
||||
model_name=model.name,
|
||||
metadata={'state_size': len(state)}
|
||||
metadata={
|
||||
'state_size': len(state),
|
||||
'price_direction': price_direction_data
|
||||
}
|
||||
)
|
||||
|
||||
# Capture DQN prediction for dashboard visualization
|
||||
|
Reference in New Issue
Block a user