"""
|
|
Enhanced Pivot-Based RL Trainer
|
|
|
|
Integrates Williams Market Structure pivot points with CNN predictions
|
|
for improved trading decisions and training rewards.
|
|
|
|
Key Features:
|
|
- Train RL model to buy/sell at local pivot points
|
|
- CNN predicts next pivot to avoid late signals
|
|
- Different thresholds for entry vs exit
|
|
- Rewards for staying uninvested when uncertain
|
|
- Uncertainty-based confidence adjustment
|
|
"""
|
|
|
|
import asyncio
import logging
import time
import numpy as np
import pandas as pd
from collections import deque, namedtuple
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Union, TYPE_CHECKING
import matplotlib.pyplot as plt
from pathlib import Path

from core.config import get_config
from core.data_provider import DataProvider
from training.williams_market_structure import WilliamsMarketStructure, SwingType, SwingPoint

# Use TYPE_CHECKING to avoid circular import
if TYPE_CHECKING:
    from core.enhanced_orchestrator import EnhancedTradingOrchestrator

logger = logging.getLogger(__name__)


class PivotReward:
    """Reward structure for pivot-based trading decisions"""

    def __init__(self):
        # Pivot-based reward weights
        self.pivot_hit_bonus = 2.0            # Bonus for trading at actual pivot points
        self.pivot_anticipation_bonus = 1.5   # Bonus for trading before pivot (CNN prediction)
        self.wrong_direction_penalty = -1.0   # Penalty for trading opposite to pivot direction
        self.late_entry_penalty = -0.5        # Penalty for entering after pivot is confirmed

        # Stay-uninvested rewards
        self.uninvested_reward = 0.1          # Small positive reward for staying out of poor setups
        self.avoid_false_signal_bonus = 0.5   # Bonus for avoiding false signals

        # Uncertainty penalties
        self.overconfidence_penalty = -0.3    # Penalty for being overconfident on losses
        self.underconfidence_penalty = -0.1   # Small penalty for being underconfident on wins


class EnhancedPivotRLTrainer:
    """Enhanced RL trainer focused on Williams pivot points and CNN predictions"""

    def __init__(self,
                 data_provider: DataProvider = None,
                 orchestrator: Optional["EnhancedTradingOrchestrator"] = None):

        self.config = get_config()
        self.data_provider = data_provider or DataProvider()
        self.orchestrator = orchestrator

        # Initialize Williams Market Structure with CNN
        self.williams = WilliamsMarketStructure(
            swing_strengths=[2, 4, 6, 8, 10],  # Multiple strengths for better detection
            enable_cnn_feature=True,
            training_data_provider=data_provider
        )

        # Pivot tracking
        self.recent_pivots = deque(maxlen=50)
        self.pivot_predictions = deque(maxlen=20)
        self.trade_outcomes = deque(maxlen=100)

        # Threshold management - different for entry vs exit
        self.entry_threshold = 0.65  # Higher threshold for entering positions
        self.exit_threshold = 0.35   # Lower threshold for exiting positions
        self.max_uninvested_reward_threshold = 0.60  # Stay out if confidence below this

        # Confidence learning parameters
        self.confidence_history = deque(maxlen=200)
        self.mistake_severity_tracker = deque(maxlen=50)

        # Reward calculator
        self.pivot_reward = PivotReward()

        logger.info("Enhanced Pivot RL Trainer initialized")
        logger.info(f"Entry threshold: {self.entry_threshold:.2%}")
        logger.info(f"Exit threshold: {self.exit_threshold:.2%}")
        logger.info(f"Uninvested reward threshold: {self.max_uninvested_reward_threshold:.2%}")

    def calculate_pivot_based_reward(self,
                                     trade_decision: Dict[str, Any],
                                     market_data: pd.DataFrame,
                                     trade_outcome: Dict[str, Any]) -> float:
        """
        Calculate enhanced reward based on pivot points and CNN predictions

        Args:
            trade_decision: The trading decision made by the model
            market_data: Market data context
            trade_outcome: Actual trade outcome

        Returns:
            Enhanced reward score
        """
        try:
            base_pnl = trade_outcome.get('net_pnl', 0.0)
            confidence = trade_decision.get('confidence', 0.5)
            action = trade_decision.get('action', 'HOLD')
            entry_price = trade_decision.get('price', 0.0)
            exit_price = trade_outcome.get('exit_price', entry_price)
            duration = trade_outcome.get('duration', timedelta(0))

            # Base PnL reward
            base_reward = base_pnl / 5.0

            # 1. Pivot point analysis rewards
            pivot_reward = self._calculate_pivot_rewards(
                trade_decision, market_data, trade_outcome
            )

            # 2. CNN prediction accuracy rewards
            cnn_reward = self._calculate_cnn_prediction_rewards(
                trade_decision, market_data, trade_outcome
            )

            # 3. Uninvested period rewards
            uninvested_reward = self._calculate_uninvested_rewards(
                trade_decision, confidence
            )

            # 4. Uncertainty-based confidence adjustment
            confidence_adjustment = self._calculate_confidence_adjustment(
                trade_decision, trade_outcome
            )

            # 5. Time efficiency with pivot context
            time_reward = self._calculate_time_efficiency_reward(
                duration, base_pnl, market_data
            )

            # Combine all rewards
            total_reward = (
                base_reward +
                pivot_reward +
                cnn_reward +
                uninvested_reward +
                confidence_adjustment +
                time_reward
            )

            # Log detailed reward breakdown
            self._log_reward_breakdown(
                trade_decision, trade_outcome, {
                    'base': base_reward,
                    'pivot': pivot_reward,
                    'cnn': cnn_reward,
                    'uninvested': uninvested_reward,
                    'confidence': confidence_adjustment,
                    'time': time_reward,
                    'total': total_reward
                }
            )

            # Track for learning
            self._track_reward_outcome(trade_decision, trade_outcome, total_reward)

            return np.clip(total_reward, -15.0, 10.0)

        except Exception as e:
            logger.error(f"Error calculating pivot-based reward: {e}")
            return 0.0

    def _calculate_pivot_rewards(self,
                                 trade_decision: Dict[str, Any],
                                 market_data: pd.DataFrame,
                                 trade_outcome: Dict[str, Any]) -> float:
        """Calculate rewards based on proximity to pivot points"""
        try:
            entry_price = trade_decision.get('price', 0.0)
            action = trade_decision.get('action', 'HOLD')
            entry_time = trade_decision.get('timestamp', datetime.now())
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            # Find recent pivot points from Williams analysis
            ohlcv_array = self._convert_dataframe_to_ohlcv_array(market_data)
            if ohlcv_array is None or len(ohlcv_array) < 20:
                return 0.0

            # Get pivot points from Williams structure
            structure_levels = self.williams.calculate_recursive_pivot_points(ohlcv_array)
            if not structure_levels or 'level_0' not in structure_levels:
                return 0.0

            level_0_pivots = structure_levels['level_0'].swing_points
            if not level_0_pivots:
                return 0.0

            # Find closest pivot to entry
            closest_pivot = self._find_closest_pivot(entry_price, entry_time, level_0_pivots)
            if not closest_pivot:
                return 0.0

            # Calculate distance to pivot (price and time)
            price_distance = abs(entry_price - closest_pivot.price) / closest_pivot.price
            time_distance = abs((entry_time - closest_pivot.timestamp).total_seconds()) / 3600.0  # hours

            pivot_reward = 0.0

            # Reward trading at or near pivot points
            if price_distance < 0.005:  # Within 0.5% of pivot
                if time_distance < 0.5:  # Within 30 minutes
                    pivot_reward += self.pivot_reward.pivot_hit_bonus
                    logger.debug(f"PIVOT HIT BONUS: {self.pivot_reward.pivot_hit_bonus:.2f}")

            # Check if trade direction aligns with pivot
            if self._trade_aligns_with_pivot(action, closest_pivot, net_pnl):
                pivot_reward += self.pivot_reward.pivot_anticipation_bonus
                logger.debug(f"PIVOT DIRECTION BONUS: {self.pivot_reward.pivot_anticipation_bonus:.2f}")
            else:
                pivot_reward += self.pivot_reward.wrong_direction_penalty
                logger.debug(f"WRONG DIRECTION PENALTY: {self.pivot_reward.wrong_direction_penalty:.2f}")

            # Penalty for late entry after pivot confirmation
            if time_distance > 2.0:  # More than 2 hours after pivot
                pivot_reward += self.pivot_reward.late_entry_penalty
                logger.debug(f"LATE ENTRY PENALTY: {self.pivot_reward.late_entry_penalty:.2f}")

            return pivot_reward

        except Exception as e:
            logger.error(f"Error calculating pivot rewards: {e}")
            return 0.0

    def _calculate_cnn_prediction_rewards(self,
                                          trade_decision: Dict[str, Any],
                                          market_data: pd.DataFrame,
                                          trade_outcome: Dict[str, Any]) -> float:
        """Calculate rewards based on CNN pivot predictions"""
        try:
            # Check if we have CNN predictions available
            if not hasattr(self.williams, 'cnn_model') or not self.williams.cnn_model:
                return 0.0

            action = trade_decision.get('action', 'HOLD')
            confidence = trade_decision.get('confidence', 0.5)
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            # Get latest CNN prediction if available
            # This would be the prediction made before the trade
            cnn_prediction = self._get_latest_cnn_prediction()
            if cnn_prediction is None:  # explicit None check; truth-testing a numpy array raises
                return 0.0

            cnn_reward = 0.0

            # Reward for following CNN predictions that turn out correct
            predicted_direction = self._interpret_cnn_prediction(cnn_prediction)

            if predicted_direction == action and net_pnl > 0:
                # CNN prediction was correct and we followed it
                cnn_reward += 1.0 * confidence  # Scale by confidence
                logger.debug(f"CNN CORRECT FOLLOW: +{1.0 * confidence:.2f}")

            elif predicted_direction != action and net_pnl < 0:
                # We didn't follow CNN and it was right (we were wrong)
                cnn_reward -= 0.5
                logger.debug("CNN IGNORE PENALTY: -0.5")

            elif predicted_direction == action and net_pnl < 0:
                # We followed CNN but it was wrong
                cnn_reward -= 0.2  # Small penalty, CNN predictions can be wrong
                logger.debug("CNN WRONG FOLLOW: -0.2")

            return cnn_reward

        except Exception as e:
            logger.error(f"Error calculating CNN prediction rewards: {e}")
            return 0.0

    def _calculate_uninvested_rewards(self,
                                      trade_decision: Dict[str, Any],
                                      confidence: float) -> float:
        """Calculate rewards for staying uninvested when uncertain"""
        try:
            action = trade_decision.get('action', 'HOLD')

            # Reward staying out when confidence is low
            if action == 'HOLD' and confidence < self.max_uninvested_reward_threshold:
                uninvested_reward = self.pivot_reward.uninvested_reward

                # Bonus for avoiding very uncertain setups
                if confidence < 0.4:
                    uninvested_reward += self.pivot_reward.avoid_false_signal_bonus
                    logger.debug(f"AVOID FALSE SIGNAL BONUS: +{self.pivot_reward.avoid_false_signal_bonus:.2f}")

                logger.debug(f"UNINVESTED REWARD: +{uninvested_reward:.2f}")
                return uninvested_reward

            return 0.0

        except Exception as e:
            logger.error(f"Error calculating uninvested rewards: {e}")
            return 0.0

    def _calculate_confidence_adjustment(self,
                                         trade_decision: Dict[str, Any],
                                         trade_outcome: Dict[str, Any]) -> float:
        """Adjust rewards based on confidence vs outcome to reduce overconfidence"""
        try:
            confidence = trade_decision.get('confidence', 0.5)
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            confidence_adjustment = 0.0

            # Track mistake severity
            mistake_severity = abs(net_pnl) if net_pnl < 0 else 0.0
            self.mistake_severity_tracker.append(mistake_severity)

            # Penalize overconfidence on losses
            if net_pnl < 0 and confidence > 0.7:
                # High confidence but loss - penalize overconfidence
                overconfidence_factor = (confidence - 0.7) / 0.3  # 0-1 scale
                severity_factor = min(mistake_severity / 2.0, 1.0)  # Scale by loss size

                penalty = self.pivot_reward.overconfidence_penalty * overconfidence_factor * severity_factor
                confidence_adjustment += penalty

                logger.debug(f"OVERCONFIDENCE PENALTY: {penalty:.2f} (conf: {confidence:.2f}, loss: ${net_pnl:.2f})")

            # Small penalty for underconfidence on wins
            elif net_pnl > 0 and confidence < 0.4:
                underconfidence_factor = (0.4 - confidence) / 0.4  # 0-1 scale
                penalty = self.pivot_reward.underconfidence_penalty * underconfidence_factor
                confidence_adjustment += penalty

                logger.debug(f"UNDERCONFIDENCE PENALTY: {penalty:.2f} (conf: {confidence:.2f}, profit: ${net_pnl:.2f})")

            # Update confidence learning
            self._update_confidence_learning(confidence, net_pnl, mistake_severity)

            return confidence_adjustment

        except Exception as e:
            logger.error(f"Error calculating confidence adjustment: {e}")
            return 0.0

    def _calculate_time_efficiency_reward(self,
                                          duration: timedelta,
                                          net_pnl: float,
                                          market_data: pd.DataFrame) -> float:
        """Calculate time-based rewards considering market context"""
        try:
            duration_hours = duration.total_seconds() / 3600.0

            # Quick profitable trades get a bonus
            if net_pnl > 0 and duration_hours < 0.5:  # Less than 30 minutes
                return 0.3

            # Holding losses too long gets a penalty
            elif net_pnl < 0 and duration_hours > 2.0:  # More than 2 hours
                return -0.5

            return 0.0

        except Exception as e:
            logger.error(f"Error calculating time efficiency reward: {e}")
            return 0.0

    def update_thresholds_based_on_performance(self):
        """Dynamically adjust entry/exit thresholds based on recent performance"""
        try:
            if len(self.trade_outcomes) < 20:
                return

            recent_outcomes = list(self.trade_outcomes)[-20:]

            # Calculate win rate and average PnL
            wins = sum(1 for outcome in recent_outcomes if outcome['net_pnl'] > 0)
            win_rate = wins / len(recent_outcomes)
            avg_pnl = np.mean([outcome['net_pnl'] for outcome in recent_outcomes])

            # Adjust thresholds based on performance
            if win_rate < 0.4:  # Low win rate - be more selective
                self.entry_threshold = min(self.entry_threshold + 0.02, 0.80)
                logger.info(f"Low win rate ({win_rate:.2%}) - increased entry threshold to {self.entry_threshold:.2%}")

            elif win_rate > 0.6 and avg_pnl > 0:  # High win rate - can be more aggressive
                self.entry_threshold = max(self.entry_threshold - 0.01, 0.50)
                logger.info(f"High win rate ({win_rate:.2%}) - decreased entry threshold to {self.entry_threshold:.2%}")

            # Adjust exit threshold based on loss severity
            avg_loss_severity = np.mean(list(self.mistake_severity_tracker)) if self.mistake_severity_tracker else 0

            if avg_loss_severity > 1.0:  # Large average losses
                self.exit_threshold = max(self.exit_threshold - 0.01, 0.20)
                logger.info(f"High loss severity - decreased exit threshold to {self.exit_threshold:.2%}")

        except Exception as e:
            logger.error(f"Error updating thresholds: {e}")

    def get_current_thresholds(self) -> Dict[str, float]:
        """Get current entry and exit thresholds"""
        return {
            'entry_threshold': self.entry_threshold,
            'exit_threshold': self.exit_threshold,
            'uninvested_threshold': self.max_uninvested_reward_threshold
        }

    # Helper methods

    def _convert_dataframe_to_ohlcv_array(self, df: pd.DataFrame) -> Optional[np.ndarray]:
        """Convert pandas DataFrame to numpy array for Williams analysis"""
        try:
            if df.empty:
                return None

            # Ensure we have required columns
            required_cols = ['open', 'high', 'low', 'close', 'volume']
            if not all(col in df.columns for col in required_cols):
                return None

            # Convert to numpy array
            timestamps = df.index.astype(np.int64) // 10**9  # Convert to Unix timestamp
            ohlcv_array = np.column_stack([
                timestamps,
                df['open'].values,
                df['high'].values,
                df['low'].values,
                df['close'].values,
                df['volume'].values
            ])

            return ohlcv_array.astype(np.float64)

        except Exception as e:
            logger.error(f"Error converting DataFrame to OHLCV array: {e}")
            return None

    def _find_closest_pivot(self,
                            entry_price: float,
                            entry_time: datetime,
                            pivots: List[SwingPoint]) -> Optional[SwingPoint]:
        """Find the closest pivot point to the trade entry"""
        try:
            if not pivots:
                return None

            # Find pivot closest in time and price
            best_pivot = None
            best_score = float('inf')

            for pivot in pivots:
                time_diff = abs((entry_time - pivot.timestamp).total_seconds()) / 3600.0
                price_diff = abs(entry_price - pivot.price) / pivot.price

                # Combined score (weighted by time and price proximity)
                score = time_diff * 0.3 + price_diff * 100  # Weight price difference more heavily

                if score < best_score:
                    best_score = score
                    best_pivot = pivot

            return best_pivot

        except Exception as e:
            logger.error(f"Error finding closest pivot: {e}")
            return None

    def _trade_aligns_with_pivot(self,
                                 action: str,
                                 pivot: SwingPoint,
                                 net_pnl: float) -> bool:
        """Check if trade direction aligns with pivot type and was profitable"""
        try:
            if net_pnl <= 0:  # Only consider profitable trades as aligned
                return False

            if action == 'BUY' and pivot.swing_type == SwingType.SWING_LOW:
                return True  # Bought at/near swing low
            elif action == 'SELL' and pivot.swing_type == SwingType.SWING_HIGH:
                return True  # Sold at/near swing high

            return False

        except Exception as e:
            logger.error(f"Error checking trade alignment: {e}")
            return False

    def _get_latest_cnn_prediction(self) -> Optional[np.ndarray]:
        """Get the latest CNN prediction from the Williams structure"""
        try:
            # This would access the Williams CNN model's latest prediction
            # For now, return None if not available
            if hasattr(self.williams, 'latest_cnn_prediction'):
                return self.williams.latest_cnn_prediction
            return None

        except Exception as e:
            logger.error(f"Error getting CNN prediction: {e}")
            return None

    def _interpret_cnn_prediction(self, prediction: np.ndarray) -> str:
        """Interpret CNN prediction array as a trading action"""
        try:
            if len(prediction) < 2:
                return 'HOLD'

            # Assuming prediction format: [type, price] for level 0
            predicted_type = prediction[0]  # 0 = LOW, 1 = HIGH

            if predicted_type > 0.5:
                return 'SELL'  # Expecting swing high - sell
            else:
                return 'BUY'   # Expecting swing low - buy

        except Exception as e:
            logger.error(f"Error interpreting CNN prediction: {e}")
            return 'HOLD'

    def _update_confidence_learning(self,
                                    confidence: float,
                                    net_pnl: float,
                                    mistake_severity: float):
        """Update confidence learning parameters"""
        try:
            self.confidence_history.append({
                'confidence': confidence,
                'net_pnl': net_pnl,
                'mistake_severity': mistake_severity,
                'timestamp': datetime.now()
            })

            # Periodically update thresholds based on confidence patterns
            if len(self.confidence_history) % 10 == 0:
                self.update_thresholds_based_on_performance()

        except Exception as e:
            logger.error(f"Error updating confidence learning: {e}")

    def _track_reward_outcome(self,
                              trade_decision: Dict[str, Any],
                              trade_outcome: Dict[str, Any],
                              total_reward: float):
        """Track reward outcomes for analysis"""
        try:
            outcome_record = {
                'timestamp': datetime.now(),
                'action': trade_decision.get('action'),
                'confidence': trade_decision.get('confidence'),
                'net_pnl': trade_outcome.get('net_pnl'),
                'reward': total_reward,
                'duration': trade_outcome.get('duration')
            }

            self.trade_outcomes.append(outcome_record)

        except Exception as e:
            logger.error(f"Error tracking reward outcome: {e}")

    def _log_reward_breakdown(self,
                              trade_decision: Dict[str, Any],
                              trade_outcome: Dict[str, Any],
                              rewards: Dict[str, float]):
        """Log detailed reward breakdown"""
        try:
            action = trade_decision.get('action', 'UNKNOWN')
            confidence = trade_decision.get('confidence', 0.0)
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            logger.info(f"[REWARD] {action} (conf: {confidence:.2%}) PnL: ${net_pnl:.2f} -> Total: {rewards['total']:.2f}")
            logger.debug(f"  Base: {rewards['base']:.2f}, Pivot: {rewards['pivot']:.2f}, CNN: {rewards['cnn']:.2f}")
            logger.debug(f"  Uninvested: {rewards['uninvested']:.2f}, Confidence: {rewards['confidence']:.2f}, Time: {rewards['time']:.2f}")

        except Exception as e:
            logger.error(f"Error logging reward breakdown: {e}")


def create_enhanced_pivot_trainer(data_provider: DataProvider = None,
                                  orchestrator: Optional["EnhancedTradingOrchestrator"] = None) -> EnhancedPivotRLTrainer:
    """Factory function to create enhanced pivot trainer"""
    return EnhancedPivotRLTrainer(data_provider, orchestrator)
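

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the trainer's public API).
    # Assumptions: get_config() and DataProvider() can be constructed with project
    # defaults, and the Williams structure can run on a small synthetic OHLCV frame.
    # The trade_decision / trade_outcome dicts below are hypothetical examples of the
    # keys this module reads (action, confidence, price, timestamp, net_pnl,
    # exit_price, duration).
    logging.basicConfig(level=logging.INFO)

    trainer = create_enhanced_pivot_trainer()

    # Synthetic 1-minute OHLCV data with the columns the trainer expects
    index = pd.date_range(end=datetime.now(), periods=120, freq="1min")
    closes = 3000.0 + np.cumsum(np.random.randn(120))
    market_df = pd.DataFrame({
        'open': closes,
        'high': closes + 1.0,
        'low': closes - 1.0,
        'close': closes,
        'volume': np.random.rand(120) * 10,
    }, index=index)

    trade_decision = {
        'action': 'BUY',
        'confidence': 0.72,
        'price': float(closes[-30]),
        'timestamp': index[-30].to_pydatetime(),
    }
    trade_outcome = {
        'net_pnl': 4.2,
        'exit_price': float(closes[-1]),
        'duration': timedelta(minutes=29),
    }

    reward = trainer.calculate_pivot_based_reward(trade_decision, market_df, trade_outcome)
    print(f"Example pivot-based reward: {reward:.3f}")
    print(f"Current thresholds: {trainer.get_current_thresholds()}")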