pivot improvement
training/enhanced_pivot_rl_trainer.py (new file, 584 lines added)
@@ -0,0 +1,584 @@
"""
Enhanced Pivot-Based RL Trainer

Integrates Williams Market Structure pivot points with CNN predictions
for improved trading decisions and training rewards.

Key Features:
- Train RL model to buy/sell at local pivot points
- CNN predicts next pivot to avoid late signals
- Different thresholds for entry vs exit
- Rewards for staying uninvested when uncertain
- Uncertainty-based confidence adjustment
"""

import asyncio
import logging
import time
import numpy as np
import pandas as pd
from collections import deque, namedtuple
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Union, TYPE_CHECKING
import matplotlib.pyplot as plt
from pathlib import Path

from core.config import get_config
from core.data_provider import DataProvider
from training.williams_market_structure import WilliamsMarketStructure, SwingType, SwingPoint

# Use TYPE_CHECKING to avoid circular import
if TYPE_CHECKING:
    from core.enhanced_orchestrator import EnhancedTradingOrchestrator

logger = logging.getLogger(__name__)


class PivotReward:
    """Reward structure for pivot-based trading decisions"""

    def __init__(self):
        # Pivot-based reward weights
        self.pivot_hit_bonus = 2.0            # Bonus for trading at actual pivot points
        self.pivot_anticipation_bonus = 1.5   # Bonus for trading before pivot (CNN prediction)
        self.wrong_direction_penalty = -1.0   # Penalty for trading opposite to pivot direction
        self.late_entry_penalty = -0.5        # Penalty for entering after pivot is confirmed

        # Stay uninvested rewards
        self.uninvested_reward = 0.1          # Small positive reward for staying out of poor setups
        self.avoid_false_signal_bonus = 0.5   # Bonus for avoiding false signals

        # Uncertainty penalties
        self.overconfidence_penalty = -0.3    # Penalty for being overconfident on losses
        self.underconfidence_penalty = -0.1   # Small penalty for being underconfident on wins

class EnhancedPivotRLTrainer:
    """Enhanced RL trainer focused on Williams pivot points and CNN predictions"""

    def __init__(self,
                 data_provider: DataProvider = None,
                 orchestrator: Optional["EnhancedTradingOrchestrator"] = None):

        self.config = get_config()
        self.data_provider = data_provider or DataProvider()
        self.orchestrator = orchestrator

        # Initialize Williams Market Structure with CNN
        self.williams = WilliamsMarketStructure(
            swing_strengths=[2, 4, 6, 8, 10],  # Multiple strengths for better detection
            enable_cnn_feature=True,
            training_data_provider=data_provider
        )

        # Pivot tracking
        self.recent_pivots = deque(maxlen=50)
        self.pivot_predictions = deque(maxlen=20)
        self.trade_outcomes = deque(maxlen=100)

        # Threshold management - different for entry vs exit
        self.entry_threshold = 0.65                  # Higher threshold for entering positions
        self.exit_threshold = 0.35                   # Lower threshold for exiting positions
        self.max_uninvested_reward_threshold = 0.60  # Stay out if confidence below this

        # Confidence learning parameters
        self.confidence_history = deque(maxlen=200)
        self.mistake_severity_tracker = deque(maxlen=50)

        # Reward calculator
        self.pivot_reward = PivotReward()

        logger.info("Enhanced Pivot RL Trainer initialized")
        logger.info(f"Entry threshold: {self.entry_threshold:.2%}")
        logger.info(f"Exit threshold: {self.exit_threshold:.2%}")
        logger.info(f"Uninvested reward threshold: {self.max_uninvested_reward_threshold:.2%}")

    def calculate_pivot_based_reward(self,
                                     trade_decision: Dict[str, Any],
                                     market_data: pd.DataFrame,
                                     trade_outcome: Dict[str, Any]) -> float:
        """
        Calculate enhanced reward based on pivot points and CNN predictions

        Args:
            trade_decision: The trading decision made by the model
            market_data: Market data context
            trade_outcome: Actual trade outcome

        Returns:
            Enhanced reward score
        """
        try:
            base_pnl = trade_outcome.get('net_pnl', 0.0)
            confidence = trade_decision.get('confidence', 0.5)
            action = trade_decision.get('action', 'HOLD')
            entry_price = trade_decision.get('price', 0.0)
            exit_price = trade_outcome.get('exit_price', entry_price)
            duration = trade_outcome.get('duration', timedelta(0))

            # Base PnL reward
            base_reward = base_pnl / 5.0

            # 1. Pivot Point Analysis Rewards
            pivot_reward = self._calculate_pivot_rewards(
                trade_decision, market_data, trade_outcome
            )

            # 2. CNN Prediction Accuracy Rewards
            cnn_reward = self._calculate_cnn_prediction_rewards(
                trade_decision, market_data, trade_outcome
            )

            # 3. Uninvested Period Rewards
            uninvested_reward = self._calculate_uninvested_rewards(
                trade_decision, confidence
            )

            # 4. Uncertainty-based Confidence Adjustment
            confidence_adjustment = self._calculate_confidence_adjustment(
                trade_decision, trade_outcome
            )

            # 5. Time efficiency with pivot context
            time_reward = self._calculate_time_efficiency_reward(
                duration, base_pnl, market_data
            )

            # Combine all rewards
            total_reward = (
                base_reward +
                pivot_reward +
                cnn_reward +
                uninvested_reward +
                confidence_adjustment +
                time_reward
            )

            # Log detailed reward breakdown
            self._log_reward_breakdown(
                trade_decision, trade_outcome, {
                    'base': base_reward,
                    'pivot': pivot_reward,
                    'cnn': cnn_reward,
                    'uninvested': uninvested_reward,
                    'confidence': confidence_adjustment,
                    'time': time_reward,
                    'total': total_reward
                }
            )

            # Track for learning
            self._track_reward_outcome(trade_decision, trade_outcome, total_reward)

            return np.clip(total_reward, -15.0, 10.0)

        except Exception as e:
            logger.error(f"Error calculating pivot-based reward: {e}")
            return 0.0

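    # Illustrative composition with hypothetical numbers (not a guarantee of actual model
    # behavior): a BUY filled within 0.5% of a swing low, ~20 minutes after it, that nets
    # +$2.50 in 25 minutes at 0.72 confidence would score roughly
    #   base    = 2.50 / 5.0                 = +0.50
    #   pivot   = +2.0 (hit) + 1.5 (aligned) = +3.50
    #   time    = +0.30 (profit in < 30 min)
    #   cnn / uninvested / confidence terms  =  0.00
    #   total   = +4.30, then clipped to [-15.0, 10.0]
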
    def _calculate_pivot_rewards(self,
                                 trade_decision: Dict[str, Any],
                                 market_data: pd.DataFrame,
                                 trade_outcome: Dict[str, Any]) -> float:
        """Calculate rewards based on proximity to pivot points"""
        try:
            entry_price = trade_decision.get('price', 0.0)
            action = trade_decision.get('action', 'HOLD')
            entry_time = trade_decision.get('timestamp', datetime.now())
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            # Find recent pivot points from Williams analysis
            ohlcv_array = self._convert_dataframe_to_ohlcv_array(market_data)
            if ohlcv_array is None or len(ohlcv_array) < 20:
                return 0.0

            # Get pivot points from Williams structure
            structure_levels = self.williams.calculate_recursive_pivot_points(ohlcv_array)
            if not structure_levels or 'level_0' not in structure_levels:
                return 0.0

            level_0_pivots = structure_levels['level_0'].swing_points
            if not level_0_pivots:
                return 0.0

            # Find closest pivot to entry
            closest_pivot = self._find_closest_pivot(entry_price, entry_time, level_0_pivots)
            if not closest_pivot:
                return 0.0

            # Calculate distance to pivot (price and time)
            price_distance = abs(entry_price - closest_pivot.price) / closest_pivot.price
            time_distance = abs((entry_time - closest_pivot.timestamp).total_seconds()) / 3600.0  # hours

            pivot_reward = 0.0

            # Reward trading at or near pivot points
            if price_distance < 0.005:  # Within 0.5% of pivot
                if time_distance < 0.5:  # Within 30 minutes
                    pivot_reward += self.pivot_reward.pivot_hit_bonus
                    logger.debug(f"PIVOT HIT BONUS: {self.pivot_reward.pivot_hit_bonus:.2f}")

            # Check if trade direction aligns with pivot
            if self._trade_aligns_with_pivot(action, closest_pivot, net_pnl):
                pivot_reward += self.pivot_reward.pivot_anticipation_bonus
                logger.debug(f"PIVOT DIRECTION BONUS: {self.pivot_reward.pivot_anticipation_bonus:.2f}")
            else:
                pivot_reward += self.pivot_reward.wrong_direction_penalty
                logger.debug(f"WRONG DIRECTION PENALTY: {self.pivot_reward.wrong_direction_penalty:.2f}")

            # Penalty for late entry after pivot confirmation
            if time_distance > 2.0:  # More than 2 hours after pivot
                pivot_reward += self.pivot_reward.late_entry_penalty
                logger.debug(f"LATE ENTRY PENALTY: {self.pivot_reward.late_entry_penalty:.2f}")

            return pivot_reward

        except Exception as e:
            logger.error(f"Error calculating pivot rewards: {e}")
            return 0.0

    def _calculate_cnn_prediction_rewards(self,
                                          trade_decision: Dict[str, Any],
                                          market_data: pd.DataFrame,
                                          trade_outcome: Dict[str, Any]) -> float:
        """Calculate rewards based on CNN pivot predictions"""
        try:
            # Check if we have CNN predictions available
            if not hasattr(self.williams, 'cnn_model') or not self.williams.cnn_model:
                return 0.0

            action = trade_decision.get('action', 'HOLD')
            confidence = trade_decision.get('confidence', 0.5)
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            # Get latest CNN prediction if available
            # This would be the prediction made before the trade
            cnn_prediction = self._get_latest_cnn_prediction()
            if cnn_prediction is None or len(cnn_prediction) == 0:
                return 0.0

            cnn_reward = 0.0

            # Reward for following CNN predictions that turn out correct
            predicted_direction = self._interpret_cnn_prediction(cnn_prediction)

            if predicted_direction == action and net_pnl > 0:
                # CNN prediction was correct and we followed it
                cnn_reward += 1.0 * confidence  # Scale by confidence
                logger.debug(f"CNN CORRECT FOLLOW: +{1.0 * confidence:.2f}")

            elif predicted_direction != action and net_pnl < 0:
                # We didn't follow CNN and it was right (we were wrong)
                cnn_reward -= 0.5
                logger.debug("CNN IGNORE PENALTY: -0.5")

            elif predicted_direction == action and net_pnl < 0:
                # We followed CNN but it was wrong
                cnn_reward -= 0.2  # Small penalty, CNN predictions can be wrong
                logger.debug("CNN WRONG FOLLOW: -0.2")

            return cnn_reward

        except Exception as e:
            logger.error(f"Error calculating CNN prediction rewards: {e}")
            return 0.0

    def _calculate_uninvested_rewards(self,
                                      trade_decision: Dict[str, Any],
                                      confidence: float) -> float:
        """Calculate rewards for staying uninvested when uncertain"""
        try:
            action = trade_decision.get('action', 'HOLD')

            # Reward staying out when confidence is low
            if action == 'HOLD' and confidence < self.max_uninvested_reward_threshold:
                uninvested_reward = self.pivot_reward.uninvested_reward

                # Bonus for avoiding very uncertain setups
                if confidence < 0.4:
                    uninvested_reward += self.pivot_reward.avoid_false_signal_bonus
                    logger.debug(f"AVOID FALSE SIGNAL BONUS: +{self.pivot_reward.avoid_false_signal_bonus:.2f}")

                logger.debug(f"UNINVESTED REWARD: +{uninvested_reward:.2f}")
                return uninvested_reward

            return 0.0

        except Exception as e:
            logger.error(f"Error calculating uninvested rewards: {e}")
            return 0.0

    def _calculate_confidence_adjustment(self,
                                         trade_decision: Dict[str, Any],
                                         trade_outcome: Dict[str, Any]) -> float:
        """Adjust rewards based on confidence vs outcome to reduce overconfidence"""
        try:
            confidence = trade_decision.get('confidence', 0.5)
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            confidence_adjustment = 0.0

            # Track mistake severity
            mistake_severity = abs(net_pnl) if net_pnl < 0 else 0.0
            self.mistake_severity_tracker.append(mistake_severity)

            # Penalize overconfidence on losses
            if net_pnl < 0 and confidence > 0.7:
                # High confidence but loss - penalize overconfidence
                overconfidence_factor = (confidence - 0.7) / 0.3  # 0-1 scale
                severity_factor = min(mistake_severity / 2.0, 1.0)  # Scale by loss size

                penalty = self.pivot_reward.overconfidence_penalty * overconfidence_factor * severity_factor
                confidence_adjustment += penalty

                logger.debug(f"OVERCONFIDENCE PENALTY: {penalty:.2f} (conf: {confidence:.2f}, loss: ${net_pnl:.2f})")

            # Small penalty for underconfidence on wins
            elif net_pnl > 0 and confidence < 0.4:
                underconfidence_factor = (0.4 - confidence) / 0.4  # 0-1 scale
                penalty = self.pivot_reward.underconfidence_penalty * underconfidence_factor
                confidence_adjustment += penalty

                logger.debug(f"UNDERCONFIDENCE PENALTY: {penalty:.2f} (conf: {confidence:.2f}, profit: ${net_pnl:.2f})")

            # Update confidence learning
            self._update_confidence_learning(confidence, net_pnl, mistake_severity)

            return confidence_adjustment

        except Exception as e:
            logger.error(f"Error calculating confidence adjustment: {e}")
            return 0.0

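    # Illustrative arithmetic with hypothetical numbers: a losing trade (net_pnl = -$1.00)
    # taken at 0.85 confidence is adjusted by
    #   overconfidence_factor = (0.85 - 0.70) / 0.30 = 0.50
    #   severity_factor       = min(1.00 / 2.0, 1.0) = 0.50
    #   adjustment            = -0.3 * 0.50 * 0.50   = -0.075
    # while a winning trade taken at 0.30 confidence is adjusted by
    #   underconfidence_factor = (0.40 - 0.30) / 0.40 = 0.25
    #   adjustment             = -0.1 * 0.25          = -0.025
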
    def _calculate_time_efficiency_reward(self,
                                          duration: timedelta,
                                          net_pnl: float,
                                          market_data: pd.DataFrame) -> float:
        """Calculate time-based rewards considering market context"""
        try:
            duration_hours = duration.total_seconds() / 3600.0

            # Quick profitable trades get bonus
            if net_pnl > 0 and duration_hours < 0.5:  # Less than 30 minutes
                return 0.3

            # Holding losses too long gets penalty
            elif net_pnl < 0 and duration_hours > 2.0:  # More than 2 hours
                return -0.5

            return 0.0

        except Exception as e:
            logger.error(f"Error calculating time efficiency reward: {e}")
            return 0.0

    def update_thresholds_based_on_performance(self):
        """Dynamically adjust entry/exit thresholds based on recent performance"""
        try:
            if len(self.trade_outcomes) < 20:
                return

            recent_outcomes = list(self.trade_outcomes)[-20:]

            # Calculate win rate and average PnL
            wins = sum(1 for outcome in recent_outcomes if outcome['net_pnl'] > 0)
            win_rate = wins / len(recent_outcomes)
            avg_pnl = np.mean([outcome['net_pnl'] for outcome in recent_outcomes])

            # Adjust thresholds based on performance
            if win_rate < 0.4:  # Low win rate - be more selective
                self.entry_threshold = min(self.entry_threshold + 0.02, 0.80)
                logger.info(f"Low win rate ({win_rate:.2%}) - increased entry threshold to {self.entry_threshold:.2%}")

            elif win_rate > 0.6 and avg_pnl > 0:  # High win rate - can be more aggressive
                self.entry_threshold = max(self.entry_threshold - 0.01, 0.50)
                logger.info(f"High win rate ({win_rate:.2%}) - decreased entry threshold to {self.entry_threshold:.2%}")

            # Adjust exit threshold based on loss severity
            avg_loss_severity = np.mean(list(self.mistake_severity_tracker)) if self.mistake_severity_tracker else 0

            if avg_loss_severity > 1.0:  # Large average losses
                self.exit_threshold = max(self.exit_threshold - 0.01, 0.20)
                logger.info(f"High loss severity - decreased exit threshold to {self.exit_threshold:.2%}")

        except Exception as e:
            logger.error(f"Error updating thresholds: {e}")

    def get_current_thresholds(self) -> Dict[str, float]:
        """Get current entry and exit thresholds"""
        return {
            'entry_threshold': self.entry_threshold,
            'exit_threshold': self.exit_threshold,
            'uninvested_threshold': self.max_uninvested_reward_threshold
        }

    # Helper methods

    def _convert_dataframe_to_ohlcv_array(self, df: pd.DataFrame) -> Optional[np.ndarray]:
        """Convert pandas DataFrame to numpy array for Williams analysis"""
        try:
            if df.empty:
                return None

            # Ensure we have required columns
            required_cols = ['open', 'high', 'low', 'close', 'volume']
            if not all(col in df.columns for col in required_cols):
                return None

            # Convert to numpy array
            timestamps = df.index.astype(np.int64) // 10**9  # Convert to Unix timestamp
            ohlcv_array = np.column_stack([
                timestamps,
                df['open'].values,
                df['high'].values,
                df['low'].values,
                df['close'].values,
                df['volume'].values
            ])

            return ohlcv_array.astype(np.float64)

        except Exception as e:
            logger.error(f"Error converting DataFrame to OHLCV array: {e}")
            return None

    def _find_closest_pivot(self,
                            entry_price: float,
                            entry_time: datetime,
                            pivots: List[SwingPoint]) -> Optional[SwingPoint]:
        """Find the closest pivot point to the trade entry"""
        try:
            if not pivots:
                return None

            # Find pivot closest in time and price
            best_pivot = None
            best_score = float('inf')

            for pivot in pivots:
                time_diff = abs((entry_time - pivot.timestamp).total_seconds()) / 3600.0
                price_diff = abs(entry_price - pivot.price) / pivot.price

                # Combined score (weighted by time and price proximity)
                score = time_diff * 0.3 + price_diff * 100  # Weight price difference more heavily

                if score < best_score:
                    best_score = score
                    best_pivot = pivot

            return best_pivot

        except Exception as e:
            logger.error(f"Error finding closest pivot: {e}")
            return None

    def _trade_aligns_with_pivot(self,
                                 action: str,
                                 pivot: SwingPoint,
                                 net_pnl: float) -> bool:
        """Check if trade direction aligns with pivot type and was profitable"""
        try:
            if net_pnl <= 0:  # Only consider profitable trades as aligned
                return False

            if action == 'BUY' and pivot.swing_type == SwingType.SWING_LOW:
                return True  # Bought at/near swing low
            elif action == 'SELL' and pivot.swing_type == SwingType.SWING_HIGH:
                return True  # Sold at/near swing high

            return False

        except Exception as e:
            logger.error(f"Error checking trade alignment: {e}")
            return False

    def _get_latest_cnn_prediction(self) -> Optional[np.ndarray]:
        """Get the latest CNN prediction from Williams structure"""
        try:
            # This would access the Williams CNN model's latest prediction
            # For now, return None if not available
            if hasattr(self.williams, 'latest_cnn_prediction'):
                return self.williams.latest_cnn_prediction
            return None

        except Exception as e:
            logger.error(f"Error getting CNN prediction: {e}")
            return None

    def _interpret_cnn_prediction(self, prediction: np.ndarray) -> str:
        """Interpret a CNN prediction array as a trading action"""
        try:
            if len(prediction) < 2:
                return 'HOLD'

            # Assuming prediction format: [type, price] for level 0
            predicted_type = prediction[0]  # 0 = LOW, 1 = HIGH

            if predicted_type > 0.5:
                return 'SELL'  # Expecting swing high - sell
            else:
                return 'BUY'   # Expecting swing low - buy

        except Exception as e:
            logger.error(f"Error interpreting CNN prediction: {e}")
            return 'HOLD'

    def _update_confidence_learning(self,
                                    confidence: float,
                                    net_pnl: float,
                                    mistake_severity: float):
        """Update confidence learning parameters"""
        try:
            self.confidence_history.append({
                'confidence': confidence,
                'net_pnl': net_pnl,
                'mistake_severity': mistake_severity,
                'timestamp': datetime.now()
            })

            # Periodically update thresholds based on confidence patterns
            if len(self.confidence_history) % 10 == 0:
                self.update_thresholds_based_on_performance()

        except Exception as e:
            logger.error(f"Error updating confidence learning: {e}")

    def _track_reward_outcome(self,
                              trade_decision: Dict[str, Any],
                              trade_outcome: Dict[str, Any],
                              total_reward: float):
        """Track reward outcomes for analysis"""
        try:
            outcome_record = {
                'timestamp': datetime.now(),
                'action': trade_decision.get('action'),
                'confidence': trade_decision.get('confidence'),
                'net_pnl': trade_outcome.get('net_pnl'),
                'reward': total_reward,
                'duration': trade_outcome.get('duration')
            }

            self.trade_outcomes.append(outcome_record)

        except Exception as e:
            logger.error(f"Error tracking reward outcome: {e}")

    def _log_reward_breakdown(self,
                              trade_decision: Dict[str, Any],
                              trade_outcome: Dict[str, Any],
                              rewards: Dict[str, float]):
        """Log detailed reward breakdown"""
        try:
            action = trade_decision.get('action', 'UNKNOWN')
            confidence = trade_decision.get('confidence', 0.0)
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            logger.info(f"[REWARD] {action} (conf: {confidence:.2%}) PnL: ${net_pnl:.2f} -> Total: {rewards['total']:.2f}")
            logger.debug(f"  Base: {rewards['base']:.2f}, Pivot: {rewards['pivot']:.2f}, CNN: {rewards['cnn']:.2f}")
            logger.debug(f"  Uninvested: {rewards['uninvested']:.2f}, Confidence: {rewards['confidence']:.2f}, Time: {rewards['time']:.2f}")

        except Exception as e:
            logger.error(f"Error logging reward breakdown: {e}")

def create_enhanced_pivot_trainer(data_provider: DataProvider = None,
                                  orchestrator: Optional["EnhancedTradingOrchestrator"] = None) -> EnhancedPivotRLTrainer:
    """Factory function to create enhanced pivot trainer"""
    return EnhancedPivotRLTrainer(data_provider, orchestrator)
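A minimal usage sketch, assuming the repo's core.config, core.data_provider, and training.williams_market_structure modules construct cleanly with defaults. The synthetic OHLCV frame and the decision/outcome dictionaries below are hypothetical and only mirror the keys this module reads (action, confidence, price, timestamp, net_pnl, exit_price, duration):

from datetime import timedelta

import numpy as np
import pandas as pd

from training.enhanced_pivot_rl_trainer import create_enhanced_pivot_trainer

# Synthetic 1-minute OHLCV frame indexed by timestamp (hypothetical data).
idx = pd.date_range("2024-01-01", periods=120, freq="1min")
close = 3000.0 + np.cumsum(np.random.randn(120))
market_data = pd.DataFrame({
    "open": close,
    "high": close + 1.0,
    "low": close - 1.0,
    "close": close,
    "volume": np.full(120, 10.0),
}, index=idx)

trainer = create_enhanced_pivot_trainer()

# Hypothetical trade decision and outcome, using the keys the reward function reads.
decision = {
    "action": "BUY",
    "confidence": 0.72,
    "price": float(close[-30]),
    "timestamp": idx[-30].to_pydatetime(),
}
outcome = {
    "net_pnl": 2.50,
    "exit_price": float(close[-1]),
    "duration": timedelta(minutes=25),
}

reward = trainer.calculate_pivot_based_reward(decision, market_data, outcome)
print(f"reward={reward:.2f}, thresholds={trainer.get_current_thresholds()}")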