predict price direction

This commit is contained in:
Dobromir Popov
2025-07-27 23:20:47 +03:00
parent dfa18035f1
commit 39267697f3
4 changed files with 572 additions and 101 deletions

View File

@ -719,6 +719,95 @@ class TradingOrchestrator:
except Exception as e:
logger.error(f"Error initializing ML models: {e}")
def _calculate_cnn_price_direction_loss(self, price_direction_pred: torch.Tensor, rewards: torch.Tensor, actions: torch.Tensor) -> Optional[torch.Tensor]:
    """
    Calculate price direction loss for CNN model

    Targets are derived from the outcome of the action that was taken:
    a reward above 0.01 after a BUY (action 0) is labelled UP (+1), a reward
    above 0.01 after a SELL (action 1) is labelled DOWN (-1), and every
    other sample is labelled sideways (0). The confidence target is the
    absolute reward clamped to [0, 1].

    Args:
        price_direction_pred: Tensor of shape [batch, 2] containing [direction, confidence]
        rewards: Tensor of shape [batch] containing rewards
        actions: Tensor of shape [batch] containing actions

    Returns:
        Scalar loss tensor (direction MSE + 0.3 * confidence MSE), or None
        when the prediction is not a [batch, 2] tensor or the computation
        fails for any reason.
    """
    try:
        # The prediction must split cleanly into a direction column and a
        # confidence column; reject any other layout up front.
        if price_direction_pred.dim() != 2 or price_direction_pred.size(1) != 2:
            return None

        batch_size = price_direction_pred.size(0)

        # Extract direction and confidence predictions
        direction_pred = price_direction_pred[:, 0]  # -1 to 1
        confidence_pred = price_direction_pred[:, 1]  # 0 to 1

        # Targets are labels, not part of the autograd graph
        with torch.no_grad():
            # Boolean masks replace the per-sample Python loop
            profitable = rewards > 0.01  # Positive reward threshold
            direction_targets = torch.zeros(batch_size, device=price_direction_pred.device)
            direction_targets[profitable & (actions == 0)] = 1.0  # BUY paid off -> UP
            direction_targets[profitable & (actions == 1)] = -1.0  # SELL paid off -> DOWN
            # Everything else stays 0 (sideways)

            # Confidence targets: based on reward magnitude (higher reward = higher confidence)
            confidence_targets = torch.abs(rewards).clamp(0, 1)

        # Calculate losses for each component
        direction_loss = nn.MSELoss()(direction_pred, direction_targets)
        confidence_loss = nn.MSELoss()(confidence_pred, confidence_targets)

        # Combined loss (direction is more important than confidence)
        return direction_loss + 0.3 * confidence_loss
    except Exception as e:
        # Auxiliary loss is best-effort: the caller skips it when None is returned
        logger.debug(f"Error calculating CNN price direction loss: {e}")
        return None
def _calculate_cnn_extrema_loss(self, extrema_pred: torch.Tensor, rewards: torch.Tensor, actions: torch.Tensor) -> Optional[torch.Tensor]:
    """
    Calculate extrema loss for CNN model

    Class targets are derived from reward patterns: a reward above 0.05
    after a BUY (action 0) is labelled "bottom" (class 0), a reward above
    0.05 after a SELL (action 1) is labelled "top" (class 1), and every
    other sample is labelled "neither" (class 2).

    Args:
        extrema_pred: Extrema predictions (logits); only the first three
            columns are used
        rewards: Tensor containing rewards
        actions: Tensor containing actions

    Returns:
        Scalar cross-entropy loss tensor, or None when the computation fails
        (for example when extrema_pred has fewer than three columns and a
        sample carries the default "neither" class, which is then out of
        range for the loss).
    """
    try:
        batch_size = extrema_pred.size(0)

        # Targets are labels, not part of the autograd graph
        with torch.no_grad():
            # Default to "neither"
            extrema_targets = torch.full(
                (batch_size,), 2, dtype=torch.long, device=extrema_pred.device
            )
            # High positive reward suggests we're at a good entry point
            strong_reward = rewards > 0.05
            extrema_targets[strong_reward & (actions == 0)] = 0  # BUY -> Bottom
            extrema_targets[strong_reward & (actions == 1)] = 1  # SELL -> Top

        # Slicing to three columns is a no-op for narrower tensors, so a
        # single expression covers both the >=3 and <3 column cases.
        return nn.CrossEntropyLoss()(extrema_pred[:, :3], extrema_targets)
    except Exception as e:
        # Auxiliary loss is best-effort: the caller skips it when None is returned
        logger.debug(f"Error calculating CNN extrema loss: {e}")
        return None
def update_model_loss(self, model_name: str, current_loss: float, best_loss: Optional[float] = None):
"""Update model loss and potentially best loss"""
if model_name in self.model_states:
@ -1938,7 +2027,71 @@ class TradingOrchestrator:
# Evaluate the previous prediction and train the model immediately
await self._evaluate_and_train_on_record(inference_record, current_price)
logger.info(f"Completed immediate training for {model_name}")
# Log predicted vs actual outcome
prediction = inference_record.get('prediction', {})
predicted_action = prediction.get('action', 'UNKNOWN')
predicted_confidence = prediction.get('confidence', 0.0)
# Calculate actual outcome
symbol = inference_record.get('symbol', 'ETH/USDT')
predicted_price = None
actual_price_change_pct = 0.0
# Try to get price direction vectors from metadata (new format)
if 'price_direction' in prediction and prediction['price_direction']:
try:
price_direction_data = prediction['price_direction']
# Process price direction data
if isinstance(price_direction_data, dict) and 'direction' in price_direction_data:
direction = price_direction_data['direction']
confidence = price_direction_data.get('confidence', 1.0)
# Convert direction to price change percentage
# Scale by confidence and direction strength
predicted_price_change_pct = direction * confidence * 0.02 # 2% max change
predicted_price = current_price * (1 + predicted_price_change_pct)
except Exception as e:
logger.debug(f"Error processing price direction data: {e}")
# Fallback to old price prediction format
elif 'price_prediction' in prediction and prediction['price_prediction']:
try:
price_prediction_data = prediction['price_prediction']
if isinstance(price_prediction_data, list) and len(price_prediction_data) > 0:
predicted_price_change_pct = float(price_prediction_data[0]) * 0.01
predicted_price = current_price * (1 + predicted_price_change_pct)
except Exception:
pass
# Calculate price change
if predicted_price is not None:
actual_price_change_pct = (current_price - predicted_price) / predicted_price * 100
price_outcome = f"Predicted: ${predicted_price:.2f} -> Actual: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
else:
# Fall back to historical price comparison
historical_data = self.data_provider.get_historical_data(symbol, '1m', limit=10)
if historical_data is not None and not historical_data.empty:
historical_price = historical_data['close'].iloc[-1]
actual_price_change_pct = (current_price - historical_price) / historical_price * 100
price_outcome = f"Historical: ${historical_price:.2f} -> Actual: ${current_price:.2f} ({actual_price_change_pct:+.2f}%)"
else:
price_outcome = f"Actual: ${current_price:.2f}"
# Determine if prediction was correct based on action and price movement
was_correct = False
if predicted_action == 'BUY' and actual_price_change_pct > 0.1: # Price went up
was_correct = True
elif predicted_action == 'SELL' and actual_price_change_pct < -0.1: # Price went down
was_correct = True
elif predicted_action == 'HOLD' and abs(actual_price_change_pct) < 0.5: # Price stayed stable
was_correct = True
outcome_status = "✅ CORRECT" if was_correct else "❌ INCORRECT"
logger.info(f"Completed immediate training for {model_name} - {outcome_status}")
logger.info(f" Prediction: {predicted_action} ({predicted_confidence:.3f})")
logger.info(f" {price_outcome}")
logger.info(f" Outcome: {outcome_status}")
except Exception as e:
logger.error(f"Error in immediate training for {model_name}: {e}")
@ -2412,12 +2565,33 @@ class TradingOrchestrator:
self.cnn_optimizer.zero_grad()
# Forward pass
q_values, extrema_pred, price_pred, features_refined, advanced_pred = self.cnn_model(features_tensor)
q_values, extrema_pred, price_direction_pred, features_refined, advanced_pred = self.cnn_model(features_tensor)
# Calculate loss
# Calculate primary Q-value loss
q_values_selected = q_values.gather(1, action_tensor.unsqueeze(1)).squeeze(1)
target_q = reward_tensor # Simplified target
loss = nn.MSELoss()(q_values_selected, target_q)
q_loss = nn.MSELoss()(q_values_selected, target_q)
# Calculate auxiliary losses for price direction and extrema
total_loss = q_loss
# Price direction loss
if price_direction_pred is not None and price_direction_pred.shape[0] > 0:
price_direction_loss = self._calculate_cnn_price_direction_loss(
price_direction_pred, reward_tensor, action_tensor
)
if price_direction_loss is not None:
total_loss = total_loss + 0.2 * price_direction_loss
# Extrema loss
if extrema_pred is not None and extrema_pred.shape[0] > 0:
extrema_loss = self._calculate_cnn_extrema_loss(
extrema_pred, reward_tensor, action_tensor
)
if extrema_loss is not None:
total_loss = total_loss + 0.1 * extrema_loss
loss = total_loss
# Backward pass
training_start_time = time.time()
@ -2640,9 +2814,17 @@ class TradingOrchestrator:
'HOLD': float(action_probs[0, 2].item())
}
# Extract price predictions if available
price_prediction = None
# Extract price direction predictions if available
price_direction_data = None
if price_pred is not None:
# Process price direction predictions
if hasattr(model.model, 'process_price_direction_predictions'):
try:
price_direction_data = model.model.process_price_direction_predictions(price_pred)
except Exception as e:
logger.debug(f"Error processing CNN price direction: {e}")
# Fallback to old format for compatibility
price_prediction = price_pred.squeeze(0).cpu().numpy().tolist()
prediction = Prediction(
@ -2656,6 +2838,7 @@ class TradingOrchestrator:
'feature_size': len(base_data.get_feature_vector()),
'data_sources': ['ohlcv_1s', 'ohlcv_1m', 'ohlcv_1h', 'ohlcv_1d', 'btc', 'cob', 'indicators'],
'price_prediction': price_prediction,
'price_direction': price_direction_data,
'extrema_prediction': extrema_pred.squeeze(0).cpu().numpy().tolist() if extrema_pred is not None else None
}
)
@ -2694,6 +2877,14 @@ class TradingOrchestrator:
action_names = ['BUY', 'SELL', 'HOLD'] # Note: enhanced_cnn uses this order
best_action = action_names[action_idx]
# Get price direction vectors from CNN model if available
price_direction_data = None
if hasattr(model.model, 'get_price_direction_vector'):
try:
price_direction_data = model.model.get_price_direction_vector()
except Exception as e:
logger.debug(f"Error getting price direction from CNN: {e}")
pred = Prediction(
action=best_action,
confidence=float(confidence),
@ -2708,7 +2899,8 @@ class TradingOrchestrator:
metadata={
'feature_vector_size': len(feature_vector),
'unified_input': True,
'fallback_method': 'direct_model_inference'
'fallback_method': 'direct_model_inference',
'price_direction': price_direction_data
}
)
predictions.append(pred)
@ -2811,6 +3003,14 @@ class TradingOrchestrator:
if q_values_for_capture:
logger.warning(f"Q-values length mismatch: expected {len(action_names)}, got {len(q_values_for_capture)}. Using default probabilities.")
# Get price direction vectors from DQN model if available
price_direction_data = None
if hasattr(model.model, 'get_price_direction_vector'):
try:
price_direction_data = model.model.get_price_direction_vector()
except Exception as e:
logger.debug(f"Error getting price direction from DQN: {e}")
prediction = Prediction(
action=action,
confidence=float(confidence),
@ -2818,7 +3018,10 @@ class TradingOrchestrator:
timeframe='mixed', # RL uses mixed timeframes
timestamp=datetime.now(),
model_name=model.name,
metadata={'state_size': len(state)}
metadata={
'state_size': len(state),
'price_direction': price_direction_data
}
)
# Capture DQN prediction for dashboard visualization