This commit is contained in:
Dobromir Popov
2025-07-23 00:48:14 +03:00
parent 8898f71832
commit 0cc104f1ef
2 changed files with 115 additions and 27 deletions

View File

@ -948,6 +948,12 @@ class TradingOrchestrator:
rl_prediction = await self._get_rl_prediction(model, symbol)
if rl_prediction:
predictions.append(rl_prediction)
elif isinstance(model, COBRLModelInterface):
# Get COB RL prediction
cob_prediction = await self._get_cob_rl_prediction(model, symbol)
if cob_prediction:
predictions.append(cob_prediction)
else:
# Generic model interface
@ -1007,6 +1013,19 @@ class TradingOrchestrator:
logger.debug(f"Could not enhance CNN features with COB data: {cob_error}")
enhanced_features = feature_matrix
# Add extrema features if available
if self.extrema_trainer:
try:
extrema_features = self.extrema_trainer.get_context_features_for_model(symbol)
if extrema_features is not None:
# Reshape and tile to match the enhanced_features shape
extrema_features = extrema_features.flatten()
tiled_extrema = np.tile(extrema_features, (enhanced_features.shape[0], enhanced_features.shape[1], 1))
enhanced_features = np.concatenate([enhanced_features, tiled_extrema], axis=2)
logger.debug(f"Enhanced CNN features with Extrema data for {symbol}")
except Exception as extrema_error:
logger.debug(f"Could not enhance CNN features with Extrema data: {extrema_error}")
if enhanced_features is not None:
# Get CNN prediction - use the actual underlying model
try:
@ -1219,9 +1238,35 @@ class TradingOrchestrator:
# Shape: (n_timeframes, window_size, n_features) -> (n_timeframes * window_size * n_features,)
state = feature_matrix.flatten()
# Add additional state information (position, balance, etc.)
# This would come from a portfolio manager in a real implementation
additional_state = np.array([0.0, 1.0, 0.0]) # [position, balance, unrealized_pnl]
# Add extrema features if available
if self.extrema_trainer:
try:
extrema_features = self.extrema_trainer.get_context_features_for_model(symbol)
if extrema_features is not None:
state = np.concatenate([state, extrema_features.flatten()])
logger.debug(f"Enhanced RL state with Extrema data for {symbol}")
except Exception as extrema_error:
logger.debug(f"Could not enhance RL state with Extrema data: {extrema_error}")
# Get real-time portfolio information from the trading executor
position_size = 0.0
balance = 1.0 # Default to a normalized value if not available
unrealized_pnl = 0.0
if self.trading_executor:
position = self.trading_executor.get_current_position(symbol)
if position:
position_size = position.get('quantity', 0.0)
# Normalize balance or use a realistic value
current_balance = self.trading_executor.get_balance()
if current_balance and current_balance.get('total', 0) > 0:
# Simple normalization - can be improved
balance = min(1.0, current_balance.get('free', 0) / current_balance.get('total', 1))
unrealized_pnl = self._get_current_position_pnl(symbol, self.data_provider.get_current_price(symbol))
additional_state = np.array([position_size, balance, unrealized_pnl])
return np.concatenate([state, additional_state])
@ -1955,4 +2000,35 @@ class TradingOrchestrator:
}
self.recent_cnn_predictions[symbol].append(prediction_data)
except Exception as e:
logger.debug(f"Error capturing CNN prediction: {e}")
logger.debug(f"Error capturing CNN prediction: {e}")
async def _get_cob_rl_prediction(self, model: COBRLModelInterface, symbol: str) -> Optional[Prediction]:
"""Get prediction from COB RL model"""
try:
cob_feature_matrix = self.get_cob_feature_matrix(symbol, sequence_length=1)
if cob_feature_matrix is None:
return None
# The model expects a 1D array of features
cob_features = cob_feature_matrix.flatten()
prediction_result = model.predict(cob_features)
if prediction_result:
direction_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}
action = direction_map.get(prediction_result['predicted_direction'], 'HOLD')
prediction = Prediction(
action=action,
confidence=float(prediction_result['confidence']),
probabilities={direction_map.get(i, 'HOLD'): float(prob) for i, prob in enumerate(prediction_result['probabilities'])},
timeframe='cob',
timestamp=datetime.now(),
model_name=model.name,
metadata={'value': prediction_result['value']}
)
return prediction
return None
except Exception as e:
logger.error(f"Error getting COB RL prediction: {e}")
return None

View File

@ -59,6 +59,7 @@ class TradeRecord:
fees: float
confidence: float
hold_time_seconds: float = 0.0 # Hold time in seconds
leverage: float = 1.0 # Leverage applied to this trade
class TradingExecutor:
"""Handles trade execution through MEXC API with risk management"""
@ -344,7 +345,8 @@ class TradingExecutor:
logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Trade logged but not executed")
# Calculate simulated fees in simulation mode
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
simulated_fees = quantity * current_price * taker_fee_rate
current_leverage = self.get_leverage()
simulated_fees = quantity * current_price * taker_fee_rate * current_leverage
# Create mock position for tracking
self.positions[symbol] = Position(
@ -391,7 +393,8 @@ class TradingExecutor:
if order:
# Calculate simulated fees in simulation mode
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
simulated_fees = quantity * current_price * taker_fee_rate
current_leverage = self.get_leverage()
simulated_fees = quantity * current_price * taker_fee_rate * current_leverage
# Create position record
self.positions[symbol] = Position(
@ -424,6 +427,7 @@ class TradingExecutor:
return self._execute_short(symbol, confidence, current_price)
position = self.positions[symbol]
current_leverage = self.get_leverage()
logger.info(f"Executing SELL: {position.quantity:.6f} {symbol} at ${current_price:.2f} "
f"(confidence: {confidence:.2f}) [{'SIMULATION' if self.simulation_mode else 'LIVE'}]")
@ -431,13 +435,13 @@ class TradingExecutor:
if self.simulation_mode:
logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Trade logged but not executed")
# Calculate P&L and hold time
pnl = position.calculate_pnl(current_price)
pnl = position.calculate_pnl(current_price) * current_leverage # Apply leverage to PnL
exit_time = datetime.now()
hold_time_seconds = (exit_time - position.entry_time).total_seconds()
# Calculate simulated fees in simulation mode
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
simulated_fees = position.quantity * current_price * taker_fee_rate
simulated_fees = position.quantity * current_price * taker_fee_rate * current_leverage # Apply leverage to fees
# Create trade record
trade_record = TradeRecord(
@ -448,14 +452,15 @@ class TradingExecutor:
exit_price=current_price,
entry_time=position.entry_time,
exit_time=exit_time,
pnl=pnl,
pnl=pnl - simulated_fees,
fees=simulated_fees,
confidence=confidence,
hold_time_seconds=hold_time_seconds
hold_time_seconds=hold_time_seconds,
leverage=current_leverage # Store leverage
)
self.trade_history.append(trade_record)
self.daily_loss += max(0, -pnl) # Add to daily loss if negative
self.daily_loss += max(0, -(pnl - simulated_fees)) # Add to daily loss if negative
# Update consecutive losses
if pnl < -0.001: # A losing trade
@ -470,7 +475,7 @@ class TradingExecutor:
self.last_trade_time[symbol] = datetime.now()
self.daily_trades += 1
logger.info(f"Position closed - P&L: ${pnl:.2f}")
logger.info(f"Position closed - P&L: ${pnl - simulated_fees:.2f}")
return True
try:
@ -505,10 +510,10 @@ class TradingExecutor:
if order:
# Calculate simulated fees in simulation mode
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
simulated_fees = position.quantity * current_price * taker_fee_rate
simulated_fees = position.quantity * current_price * taker_fee_rate * current_leverage # Apply leverage
# Calculate P&L, fees, and hold time
pnl = position.calculate_pnl(current_price)
pnl = position.calculate_pnl(current_price) * current_leverage # Apply leverage to PnL
fees = simulated_fees
exit_time = datetime.now()
hold_time_seconds = (exit_time - position.entry_time).total_seconds()
@ -525,7 +530,8 @@ class TradingExecutor:
pnl=pnl - fees,
fees=fees,
confidence=confidence,
hold_time_seconds=hold_time_seconds
hold_time_seconds=hold_time_seconds,
leverage=current_leverage # Store leverage
)
self.trade_history.append(trade_record)
@ -574,7 +580,8 @@ class TradingExecutor:
logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Short position logged but not executed")
# Calculate simulated fees in simulation mode
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
simulated_fees = quantity * current_price * taker_fee_rate
current_leverage = self.get_leverage()
simulated_fees = quantity * current_price * taker_fee_rate * current_leverage
# Create mock short position for tracking
self.positions[symbol] = Position(
@ -621,7 +628,8 @@ class TradingExecutor:
if order:
# Calculate simulated fees in simulation mode
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
simulated_fees = quantity * current_price * taker_fee_rate
current_leverage = self.get_leverage()
simulated_fees = quantity * current_price * taker_fee_rate * current_leverage
# Create short position record
self.positions[symbol] = Position(
@ -653,6 +661,8 @@ class TradingExecutor:
return False
position = self.positions[symbol]
current_leverage = self.get_leverage() # Get current leverage
if position.side != 'SHORT':
logger.warning(f"Position in {symbol} is not SHORT, cannot close with BUY")
return False
@ -664,10 +674,10 @@ class TradingExecutor:
logger.info(f"SIMULATION MODE ({self.trading_mode.upper()}) - Short close logged but not executed")
# Calculate simulated fees in simulation mode
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
simulated_fees = position.quantity * current_price * taker_fee_rate
simulated_fees = position.quantity * current_price * taker_fee_rate * current_leverage
# Calculate P&L for short position and hold time
pnl = position.calculate_pnl(current_price)
pnl = position.calculate_pnl(current_price) * current_leverage # Apply leverage to PnL
exit_time = datetime.now()
hold_time_seconds = (exit_time - position.entry_time).total_seconds()
@ -680,21 +690,22 @@ class TradingExecutor:
exit_price=current_price,
entry_time=position.entry_time,
exit_time=exit_time,
pnl=pnl,
pnl=pnl - simulated_fees,
fees=simulated_fees,
confidence=confidence,
hold_time_seconds=hold_time_seconds
hold_time_seconds=hold_time_seconds,
leverage=current_leverage # Store leverage
)
self.trade_history.append(trade_record)
self.daily_loss += max(0, -pnl) # Add to daily loss if negative
self.daily_loss += max(0, -(pnl - simulated_fees)) # Add to daily loss if negative
# Remove position
del self.positions[symbol]
self.last_trade_time[symbol] = datetime.now()
self.daily_trades += 1
logger.info(f"SHORT position closed - P&L: ${pnl:.2f}")
logger.info(f"SHORT position closed - P&L: ${pnl - simulated_fees:.2f}")
return True
try:
@ -729,10 +740,10 @@ class TradingExecutor:
if order:
# Calculate simulated fees in simulation mode
taker_fee_rate = self.mexc_config.get('trading_fees', {}).get('taker_fee', 0.0006)
simulated_fees = position.quantity * current_price * taker_fee_rate
simulated_fees = position.quantity * current_price * taker_fee_rate * current_leverage
# Calculate P&L, fees, and hold time
pnl = position.calculate_pnl(current_price)
pnl = position.calculate_pnl(current_price) * current_leverage # Apply leverage to PnL
fees = simulated_fees
exit_time = datetime.now()
hold_time_seconds = (exit_time - position.entry_time).total_seconds()
@ -749,7 +760,8 @@ class TradingExecutor:
pnl=pnl - fees,
fees=fees,
confidence=confidence,
hold_time_seconds=hold_time_seconds
hold_time_seconds=hold_time_seconds,
leverage=current_leverage # Store leverage
)
self.trade_history.append(trade_record)
@ -875,7 +887,7 @@ class TradingExecutor:
'losing_trades': losing_trades,
'breakeven_trades': breakeven_trades,
'total_trades': total_trades,
'win_rate': winning_trades / max(1, total_trades),
'win_rate': winning_trades / max(1, winning_trades + losing_trades) if (winning_trades + losing_trades) > 0 else 0.0,
'avg_trade_pnl': avg_trade_pnl,
'avg_trade_fee': avg_trade_fee,
'avg_winning_trade': avg_winning_trade,