pivot improvement
This commit is contained in:
@ -31,6 +31,7 @@ from .extrema_trainer import ExtremaTrainer
|
||||
from .trading_action import TradingAction
|
||||
from .negative_case_trainer import NegativeCaseTrainer
|
||||
from .trading_executor import TradingExecutor
|
||||
from training.enhanced_pivot_rl_trainer import EnhancedPivotRLTrainer, create_enhanced_pivot_trainer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -157,12 +158,22 @@ class EnhancedTradingOrchestrator:
|
||||
self.current_positions = {} # symbol -> {'side': 'LONG'|'SHORT'|'FLAT', 'entry_price': float, 'timestamp': datetime}
|
||||
self.last_signals = {} # symbol -> {'action': 'BUY'|'SELL', 'timestamp': datetime, 'confidence': float}
|
||||
|
||||
# Different thresholds for entry vs exit
|
||||
self.entry_threshold = self.config.orchestrator.get('entry_threshold', 0.75) # Higher threshold for entries
|
||||
self.exit_threshold = self.config.orchestrator.get('exit_threshold', 0.35) # Lower threshold for exits
|
||||
# Initialize Enhanced Pivot RL Trainer
|
||||
self.pivot_rl_trainer = create_enhanced_pivot_trainer(
|
||||
data_provider=self.data_provider,
|
||||
orchestrator=self
|
||||
)
|
||||
|
||||
logger.info(f"Entry threshold: {self.entry_threshold:.3f} (more certain)")
|
||||
logger.info(f"Exit threshold: {self.exit_threshold:.3f} (easier to exit)")
|
||||
# Get dynamic thresholds from pivot trainer
|
||||
thresholds = self.pivot_rl_trainer.get_current_thresholds()
|
||||
self.entry_threshold = thresholds['entry_threshold'] # Higher threshold for entries
|
||||
self.exit_threshold = thresholds['exit_threshold'] # Lower threshold for exits
|
||||
self.uninvested_threshold = thresholds['uninvested_threshold'] # Stay out threshold
|
||||
|
||||
logger.info(f"Dynamic Pivot-Based Thresholds:")
|
||||
logger.info(f" Entry threshold: {self.entry_threshold:.3f} (more certain)")
|
||||
logger.info(f" Exit threshold: {self.exit_threshold:.3f} (easier to exit)")
|
||||
logger.info(f" Uninvested threshold: {self.uninvested_threshold:.3f} (stay out when uncertain)")
|
||||
|
||||
# Initialize universal data adapter
|
||||
self.universal_adapter = UniversalDataAdapter(self.data_provider)
|
||||
@ -2046,29 +2057,33 @@ class EnhancedTradingOrchestrator:
|
||||
|
||||
def _make_2_action_decision(self, symbol: str, predictions: List[EnhancedPrediction],
|
||||
market_state: MarketState) -> Optional[TradingAction]:
|
||||
"""
|
||||
Make trading decision using strict 2-action system (BUY/SELL only)
|
||||
|
||||
STRICT Logic:
|
||||
- When FLAT: BUY signal -> go LONG, SELL signal -> go SHORT
|
||||
- When LONG: SELL signal -> close LONG immediately (and optionally enter SHORT if no other positions)
|
||||
- When SHORT: BUY signal -> close SHORT immediately (and optionally enter LONG if no other positions)
|
||||
- ALWAYS close opposite positions first before opening new ones
|
||||
"""
|
||||
if not predictions:
|
||||
return None
|
||||
|
||||
"""Enhanced 2-action decision making with pivot analysis and CNN predictions"""
|
||||
try:
|
||||
# Get best prediction
|
||||
best_pred = max(predictions, key=lambda p: p.overall_confidence)
|
||||
raw_action = best_pred.overall_action
|
||||
confidence = best_pred.overall_confidence
|
||||
if not predictions:
|
||||
return None
|
||||
|
||||
# Get current position for this symbol
|
||||
current_position = self.current_positions.get(symbol, {'side': 'FLAT'})
|
||||
position_side = current_position['side']
|
||||
# Get the best prediction
|
||||
best_pred = max(predictions, key=lambda p: p.confidence)
|
||||
confidence = best_pred.confidence
|
||||
raw_action = best_pred.action
|
||||
|
||||
# STRICT LOGIC: Determine action type
|
||||
# Update dynamic thresholds periodically
|
||||
if hasattr(self, '_last_threshold_update'):
|
||||
if (datetime.now() - self._last_threshold_update).total_seconds() > 3600: # Every hour
|
||||
self.update_dynamic_thresholds()
|
||||
self._last_threshold_update = datetime.now()
|
||||
else:
|
||||
self._last_threshold_update = datetime.now()
|
||||
|
||||
# Check if we should stay uninvested due to low confidence
|
||||
if confidence < self.uninvested_threshold:
|
||||
logger.info(f"[{symbol}] Staying uninvested - confidence {confidence:.3f} below threshold {self.uninvested_threshold:.3f}")
|
||||
return None
|
||||
|
||||
# Get current position
|
||||
position_side = self._get_current_position_side(symbol)
|
||||
|
||||
# Determine if this is entry or exit
|
||||
is_entry = False
|
||||
is_exit = False
|
||||
final_action = raw_action
|
||||
@ -2098,10 +2113,29 @@ class EnhancedTradingOrchestrator:
|
||||
logger.info(f"[{symbol}] SHORT position - SELL signal ignored (already short)")
|
||||
return None
|
||||
|
||||
# Apply appropriate threshold
|
||||
# Apply appropriate threshold with CNN enhancement
|
||||
if is_entry:
|
||||
threshold = self.entry_threshold
|
||||
threshold_type = "ENTRY"
|
||||
|
||||
# For entries, check if CNN predicts favorable pivot
|
||||
if hasattr(self.pivot_rl_trainer, 'williams') and self.pivot_rl_trainer.williams.cnn_model:
|
||||
try:
|
||||
# Get market data for CNN analysis
|
||||
current_price = market_state.prices.get(self.timeframes[0], 0)
|
||||
|
||||
# CNN prediction could lower entry threshold if it predicts favorable pivot
|
||||
# This allows earlier entry before pivot is confirmed
|
||||
cnn_adjustment = self._get_cnn_threshold_adjustment(symbol, raw_action, market_state)
|
||||
adjusted_threshold = max(threshold - cnn_adjustment, threshold * 0.8) # Max 20% reduction
|
||||
|
||||
if cnn_adjustment > 0:
|
||||
logger.info(f"[{symbol}] CNN predicts favorable pivot - adjusted entry threshold: {threshold:.3f} -> {adjusted_threshold:.3f}")
|
||||
threshold = adjusted_threshold
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting CNN threshold adjustment: {e}")
|
||||
|
||||
elif is_exit:
|
||||
threshold = self.exit_threshold
|
||||
threshold_type = "EXIT"
|
||||
@ -2130,7 +2164,8 @@ class EnhancedTradingOrchestrator:
|
||||
'position_before': position_side,
|
||||
'action_type': threshold_type,
|
||||
'threshold_used': threshold,
|
||||
'strict_mode': True,
|
||||
'pivot_enhanced': True,
|
||||
'cnn_integrated': hasattr(self.pivot_rl_trainer, 'williams') and self.pivot_rl_trainer.williams.cnn_model is not None,
|
||||
'timeframe_breakdown': [(tf.timeframe, tf.action, tf.confidence)
|
||||
for tf in best_pred.timeframe_predictions],
|
||||
'market_regime': market_state.market_regime
|
||||
@ -2148,14 +2183,72 @@ class EnhancedTradingOrchestrator:
|
||||
'confidence': confidence
|
||||
}
|
||||
|
||||
logger.info(f"[{symbol}] STRICT {threshold_type} Decision: {final_action} (conf: {confidence:.3f}, threshold: {threshold:.3f})")
|
||||
logger.info(f"[{symbol}] ENHANCED {threshold_type} Decision: {final_action} (conf: {confidence:.3f}, threshold: {threshold:.3f})")
|
||||
|
||||
return action
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error making strict 2-action decision for {symbol}: {e}")
|
||||
logger.error(f"Error making enhanced 2-action decision for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _get_cnn_threshold_adjustment(self, symbol: str, action: str, market_state: MarketState) -> float:
|
||||
"""Get threshold adjustment based on CNN pivot predictions"""
|
||||
try:
|
||||
# This would analyze CNN predictions to determine if we should lower entry threshold
|
||||
# For example, if CNN predicts a swing low and we want to BUY, we can be more aggressive
|
||||
|
||||
# Placeholder implementation - in real scenario, this would:
|
||||
# 1. Get recent market data
|
||||
# 2. Run CNN prediction through Williams structure
|
||||
# 3. Check if predicted pivot aligns with our intended action
|
||||
# 4. Return threshold adjustment (0.0 to 0.1 typically)
|
||||
|
||||
# For now, return small adjustment to demonstrate concept
|
||||
if hasattr(self.pivot_rl_trainer.williams, 'cnn_model') and self.pivot_rl_trainer.williams.cnn_model:
|
||||
# CNN is available, could provide small threshold reduction for better entries
|
||||
return 0.05 # 5% threshold reduction when CNN available
|
||||
|
||||
return 0.0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting CNN threshold adjustment: {e}")
|
||||
return 0.0
|
||||
|
||||
def update_dynamic_thresholds(self):
|
||||
"""Update thresholds based on recent performance"""
|
||||
try:
|
||||
# Update thresholds in pivot trainer
|
||||
self.pivot_rl_trainer.update_thresholds_based_on_performance()
|
||||
|
||||
# Get updated thresholds
|
||||
thresholds = self.pivot_rl_trainer.get_current_thresholds()
|
||||
old_entry = self.entry_threshold
|
||||
old_exit = self.exit_threshold
|
||||
|
||||
self.entry_threshold = thresholds['entry_threshold']
|
||||
self.exit_threshold = thresholds['exit_threshold']
|
||||
self.uninvested_threshold = thresholds['uninvested_threshold']
|
||||
|
||||
# Log changes if significant
|
||||
if abs(old_entry - self.entry_threshold) > 0.01 or abs(old_exit - self.exit_threshold) > 0.01:
|
||||
logger.info(f"Threshold Update - Entry: {old_entry:.3f} -> {self.entry_threshold:.3f}, "
|
||||
f"Exit: {old_exit:.3f} -> {self.exit_threshold:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating dynamic thresholds: {e}")
|
||||
|
||||
def calculate_enhanced_pivot_reward(self, trade_decision: Dict[str, Any],
|
||||
market_data: pd.DataFrame,
|
||||
trade_outcome: Dict[str, Any]) -> float:
|
||||
"""Calculate reward using the enhanced pivot-based system"""
|
||||
try:
|
||||
return self.pivot_rl_trainer.calculate_pivot_based_reward(
|
||||
trade_decision, market_data, trade_outcome
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating enhanced pivot reward: {e}")
|
||||
return 0.0
|
||||
|
||||
def _update_2_action_position(self, symbol: str, action: TradingAction):
|
||||
"""Update position tracking for strict 2-action system"""
|
||||
try:
|
||||
|
Reference in New Issue
Block a user