From 0cc104f1ef95a1a5592976ac1b6f296200d937ef Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Wed, 23 Jul 2025 00:48:14 +0300 Subject: [PATCH] wip cob --- core/orchestrator.py | 84 ++++++++++++++++++++++++++++++++++++++-- core/trading_executor.py | 58 ++++++++++++++++----------- 2 files changed, 115 insertions(+), 27 deletions(-) diff --git a/core/orchestrator.py b/core/orchestrator.py index ecbec58..4afb30d 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -948,6 +948,12 @@ class TradingOrchestrator: rl_prediction = await self._get_rl_prediction(model, symbol) if rl_prediction: predictions.append(rl_prediction) + + elif isinstance(model, COBRLModelInterface): + # Get COB RL prediction + cob_prediction = await self._get_cob_rl_prediction(model, symbol) + if cob_prediction: + predictions.append(cob_prediction) else: # Generic model interface @@ -1007,6 +1013,19 @@ class TradingOrchestrator: logger.debug(f"Could not enhance CNN features with COB data: {cob_error}") enhanced_features = feature_matrix + # Add extrema features if available + if self.extrema_trainer: + try: + extrema_features = self.extrema_trainer.get_context_features_for_model(symbol) + if extrema_features is not None: + # Reshape and tile to match the enhanced_features shape + extrema_features = extrema_features.flatten() + tiled_extrema = np.tile(extrema_features, (enhanced_features.shape[0], enhanced_features.shape[1], 1)) + enhanced_features = np.concatenate([enhanced_features, tiled_extrema], axis=2) + logger.debug(f"Enhanced CNN features with Extrema data for {symbol}") + except Exception as extrema_error: + logger.debug(f"Could not enhance CNN features with Extrema data: {extrema_error}") + if enhanced_features is not None: # Get CNN prediction - use the actual underlying model try: @@ -1219,9 +1238,35 @@ class TradingOrchestrator: # Shape: (n_timeframes, window_size, n_features) -> (n_timeframes * window_size * n_features,) state = feature_matrix.flatten() - # Add additional state information (position, balance, etc.) - # This would come from a portfolio manager in a real implementation - additional_state = np.array([0.0, 1.0, 0.0]) # [position, balance, unrealized_pnl] + # Add extrema features if available + if self.extrema_trainer: + try: + extrema_features = self.extrema_trainer.get_context_features_for_model(symbol) + if extrema_features is not None: + state = np.concatenate([state, extrema_features.flatten()]) + logger.debug(f"Enhanced RL state with Extrema data for {symbol}") + except Exception as extrema_error: + logger.debug(f"Could not enhance RL state with Extrema data: {extrema_error}") + + # Get real-time portfolio information from the trading executor + position_size = 0.0 + balance = 1.0 # Default to a normalized value if not available + unrealized_pnl = 0.0 + + if self.trading_executor: + position = self.trading_executor.get_current_position(symbol) + if position: + position_size = position.get('quantity', 0.0) + + # Normalize balance or use a realistic value + current_balance = self.trading_executor.get_balance() + if current_balance and current_balance.get('total', 0) > 0: + # Simple normalization - can be improved + balance = min(1.0, current_balance.get('free', 0) / current_balance.get('total', 1)) + + unrealized_pnl = self._get_current_position_pnl(symbol, self.data_provider.get_current_price(symbol)) + + additional_state = np.array([position_size, balance, unrealized_pnl]) return np.concatenate([state, additional_state]) @@ -1955,4 +2000,35 @@ class TradingOrchestrator: } self.recent_cnn_predictions[symbol].append(prediction_data) except Exception as e: - logger.debug(f"Error capturing CNN prediction: {e}") \ No newline at end of file + logger.debug(f"Error capturing CNN prediction: {e}") + + async def _get_cob_rl_prediction(self, model: COBRLModelInterface, symbol: str) -> Optional[Prediction]: + """Get prediction from COB RL model""" + try: + cob_feature_matrix = self.get_cob_feature_matrix(symbol, sequence_length=1) + if cob_feature_matrix is None: + return None + + # The model expects a 1D array of features + cob_features = cob_feature_matrix.flatten() + + prediction_result = model.predict(cob_features) + + if prediction_result: + direction_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'} + action = direction_map.get(prediction_result['predicted_direction'], 'HOLD') + + prediction = Prediction( + action=action, + confidence=float(prediction_result['confidence']), + probabilities={direction_map.get(i, 'HOLD'): float(prob) for i, prob in enumerate(prediction_result['probabilities'])}, + timeframe='cob', + timestamp=datetime.now(), + model_name=model.name, + metadata={'value': prediction_result['value']} + ) + return prediction + return None + except Exception as e: + logger.error(f"Error getting COB RL prediction: {e}") + return None \ No newline at end of file diff --git a/core/trading_executor.py b/core/trading_executor.py index fae776f..93b6c52 100644 --- a/core/trading_executor.py +++ b/core/trading_executor.py @@ -59,6 +59,7 @@ class TradeRecord: fees: float confidence: float hold_time_seconds: float = 0.0 # Hold time in seconds + leverage: float = 1.0 # Leverage applied to this trade class TradingExecutor: """Handles trade execution through MEXC API with risk management""" @@ -344,7 +345,8 @@ class TradingExecutor: logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Trade logged but not executed") # Calculate simulated fees in simulation mode taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006) - simulated_fees = quantity * current_price * taker_fee_rate + current_leverage = self.get_leverage() + simulated_fees = quantity * current_price * taker_fee_rate * current_leverage # Create mock position for tracking self.positions[symbol] = Position( @@ -391,7 +393,8 @@ class TradingExecutor: if order: # Calculate simulated fees in simulation mode taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006) - simulated_fees = quantity * current_price * taker_fee_rate + current_leverage = self.get_leverage() + simulated_fees = quantity * current_price * taker_fee_rate * current_leverage # Create position record self.positions[symbol] = Position( @@ -424,6 +427,7 @@ class TradingExecutor: return self._execute_short(symbol, confidence, current_price) position = self.positions[symbol] + current_leverage = self.get_leverage() logger.info(f"Executing SELL: {position.quantity:.6f} {symbol} at ${current_price:.2f} " f"(confidence: {confidence:.2f}) [{'SIMULATION' if self.simulation_mode else 'LIVE'}]") @@ -431,13 +435,13 @@ class TradingExecutor: if self.simulation_mode: logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Trade logged but not executed") # Calculate P&L and hold time - pnl = position.calculate_pnl(current_price) + pnl = position.calculate_pnl(current_price) * current_leverage # Apply leverage to PnL exit_time = datetime.now() hold_time_seconds = (exit_time - position.entry_time).total_seconds() # Calculate simulated fees in simulation mode taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006) - simulated_fees = position.quantity * current_price * taker_fee_rate + simulated_fees = position.quantity * current_price * taker_fee_rate * current_leverage # Apply leverage to fees # Create trade record trade_record = TradeRecord( @@ -448,14 +452,15 @@ class TradingExecutor: exit_price=current_price, entry_time=position.entry_time, exit_time=exit_time, - pnl=pnl, + pnl=pnl - simulated_fees, fees=simulated_fees, confidence=confidence, - hold_time_seconds=hold_time_seconds + hold_time_seconds=hold_time_seconds, + leverage=current_leverage # Store leverage ) self.trade_history.append(trade_record) - self.daily_loss += max(0, -pnl) # Add to daily loss if negative + self.daily_loss += max(0, -(pnl - simulated_fees)) # Add to daily loss if negative # Update consecutive losses if pnl < -0.001: # A losing trade @@ -470,7 +475,7 @@ class TradingExecutor: self.last_trade_time[symbol] = datetime.now() self.daily_trades += 1 - logger.info(f"Position closed - P&L: ${pnl:.2f}") + logger.info(f"Position closed - P&L: ${pnl - simulated_fees:.2f}") return True try: @@ -505,10 +510,10 @@ class TradingExecutor: if order: # Calculate simulated fees in simulation mode taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006) - simulated_fees = position.quantity * current_price * taker_fee_rate + simulated_fees = position.quantity * current_price * taker_fee_rate * current_leverage # Apply leverage # Calculate P&L, fees, and hold time - pnl = position.calculate_pnl(current_price) + pnl = position.calculate_pnl(current_price) * current_leverage # Apply leverage to PnL fees = simulated_fees exit_time = datetime.now() hold_time_seconds = (exit_time - position.entry_time).total_seconds() @@ -525,7 +530,8 @@ class TradingExecutor: pnl=pnl - fees, fees=fees, confidence=confidence, - hold_time_seconds=hold_time_seconds + hold_time_seconds=hold_time_seconds, + leverage=current_leverage # Store leverage ) self.trade_history.append(trade_record) @@ -574,7 +580,8 @@ class TradingExecutor: logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Short position logged but not executed") # Calculate simulated fees in simulation mode taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006) - simulated_fees = quantity * current_price * taker_fee_rate + current_leverage = self.get_leverage() + simulated_fees = quantity * current_price * taker_fee_rate * current_leverage # Create mock short position for tracking self.positions[symbol] = Position( @@ -621,7 +628,8 @@ class TradingExecutor: if order: # Calculate simulated fees in simulation mode taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006) - simulated_fees = quantity * current_price * taker_fee_rate + current_leverage = self.get_leverage() + simulated_fees = quantity * current_price * taker_fee_rate * current_leverage # Create short position record self.positions[symbol] = Position( @@ -653,6 +661,8 @@ class TradingExecutor: return False position = self.positions[symbol] + current_leverage = self.get_leverage() # Get current leverage + if position.side != 'SHORT': logger.warning(f"Position in {symbol} is not SHORT, cannot close with BUY") return False @@ -664,10 +674,10 @@ class TradingExecutor: logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Short close logged but not executed") # Calculate simulated fees in simulation mode taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006) - simulated_fees = position.quantity * current_price * taker_fee_rate + simulated_fees = position.quantity * current_price * taker_fee_rate * current_leverage # Calculate P&L for short position and hold time - pnl = position.calculate_pnl(current_price) + pnl = position.calculate_pnl(current_price) * current_leverage # Apply leverage to PnL exit_time = datetime.now() hold_time_seconds = (exit_time - position.entry_time).total_seconds() @@ -680,21 +690,22 @@ class TradingExecutor: exit_price=current_price, entry_time=position.entry_time, exit_time=exit_time, - pnl=pnl, + pnl=pnl - simulated_fees, fees=simulated_fees, confidence=confidence, - hold_time_seconds=hold_time_seconds + hold_time_seconds=hold_time_seconds, + leverage=current_leverage # Store leverage ) self.trade_history.append(trade_record) - self.daily_loss += max(0, -pnl) # Add to daily loss if negative + self.daily_loss += max(0, -(pnl - simulated_fees)) # Add to daily loss if negative # Remove position del self.positions[symbol] self.last_trade_time[symbol] = datetime.now() self.daily_trades += 1 - logger.info(f"SHORT position closed - P&L: ${pnl:.2f}") + logger.info(f"SHORT position closed - P&L: ${pnl - simulated_fees:.2f}") return True try: @@ -729,10 +740,10 @@ class TradingExecutor: if order: # Calculate simulated fees in simulation mode taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006) - simulated_fees = position.quantity * current_price * taker_fee_rate + simulated_fees = position.quantity * current_price * taker_fee_rate * current_leverage # Calculate P&L, fees, and hold time - pnl = position.calculate_pnl(current_price) + pnl = position.calculate_pnl(current_price) * current_leverage # Apply leverage to PnL fees = simulated_fees exit_time = datetime.now() hold_time_seconds = (exit_time - position.entry_time).total_seconds() @@ -749,7 +760,8 @@ class TradingExecutor: pnl=pnl - fees, fees=fees, confidence=confidence, - hold_time_seconds=hold_time_seconds + hold_time_seconds=hold_time_seconds, + leverage=current_leverage # Store leverage ) self.trade_history.append(trade_record) @@ -875,7 +887,7 @@ class TradingExecutor: 'losing_trades': losing_trades, 'breakeven_trades': breakeven_trades, 'total_trades': total_trades, - 'win_rate': winning_trades / max(1, total_trades), + 'win_rate': winning_trades / max(1, winning_trades + losing_trades) if (winning_trades + losing_trades) > 0 else 0.0, 'avg_trade_pnl': avg_trade_pnl, 'avg_trade_fee': avg_trade_fee, 'avg_winning_trade': avg_winning_trade,