This commit is contained in:
Dobromir Popov
2025-07-23 00:48:14 +03:00
parent 8898f71832
commit 0cc104f1ef
2 changed files with 115 additions and 27 deletions

View File

@ -948,6 +948,12 @@ class TradingOrchestrator:
rl_prediction = await self._get_rl_prediction(model, symbol)
if rl_prediction:
predictions.append(rl_prediction)
elif isinstance(model, COBRLModelInterface):
# Get COB RL prediction
cob_prediction = await self._get_cob_rl_prediction(model, symbol)
if cob_prediction:
predictions.append(cob_prediction)
else:
# Generic model interface
@ -1007,6 +1013,19 @@ class TradingOrchestrator:
logger.debug(f"Could not enhance CNN features with COB data: {cob_error}")
enhanced_features = feature_matrix
# Add extrema features if available
if self.extrema_trainer:
try:
extrema_features = self.extrema_trainer.get_context_features_for_model(symbol)
if extrema_features is not None:
# Reshape and tile to match the enhanced_features shape
extrema_features = extrema_features.flatten()
tiled_extrema = np.tile(extrema_features, (enhanced_features.shape[0], enhanced_features.shape[1], 1))
enhanced_features = np.concatenate([enhanced_features, tiled_extrema], axis=2)
logger.debug(f"Enhanced CNN features with Extrema data for {symbol}")
except Exception as extrema_error:
logger.debug(f"Could not enhance CNN features with Extrema data: {extrema_error}")
if enhanced_features is not None:
# Get CNN prediction - use the actual underlying model
try:
@ -1219,9 +1238,35 @@ class TradingOrchestrator:
# Shape: (n_timeframes, window_size, n_features) -> (n_timeframes * window_size * n_features,)
state = feature_matrix.flatten()
# Add additional state information (position, balance, etc.)
# This would come from a portfolio manager in a real implementation
additional_state = np.array([0.0, 1.0, 0.0]) # [position, balance, unrealized_pnl]
# Add extrema features if available
if self.extrema_trainer:
try:
extrema_features = self.extrema_trainer.get_context_features_for_model(symbol)
if extrema_features is not None:
state = np.concatenate([state, extrema_features.flatten()])
logger.debug(f"Enhanced RL state with Extrema data for {symbol}")
except Exception as extrema_error:
logger.debug(f"Could not enhance RL state with Extrema data: {extrema_error}")
# Get real-time portfolio information from the trading executor
position_size = 0.0
balance = 1.0 # Default to a normalized value if not available
unrealized_pnl = 0.0
if self.trading_executor:
position = self.trading_executor.get_current_position(symbol)
if position:
position_size = position.get('quantity', 0.0)
# Normalize balance or use a realistic value
current_balance = self.trading_executor.get_balance()
if current_balance and current_balance.get('total', 0) > 0:
# Simple normalization - can be improved
balance = min(1.0, current_balance.get('free', 0) / current_balance.get('total', 1))
unrealized_pnl = self._get_current_position_pnl(symbol, self.data_provider.get_current_price(symbol))
additional_state = np.array([position_size, balance, unrealized_pnl])
return np.concatenate([state, additional_state])
@ -1955,4 +2000,35 @@ class TradingOrchestrator:
}
self.recent_cnn_predictions[symbol].append(prediction_data)
except Exception as e:
logger.debug(f"Error capturing CNN prediction: {e}")
logger.debug(f"Error capturing CNN prediction: {e}")
async def _get_cob_rl_prediction(self, model: COBRLModelInterface, symbol: str) -> Optional[Prediction]:
"""Get prediction from COB RL model"""
try:
cob_feature_matrix = self.get_cob_feature_matrix(symbol, sequence_length=1)
if cob_feature_matrix is None:
return None
# The model expects a 1D array of features
cob_features = cob_feature_matrix.flatten()
prediction_result = model.predict(cob_features)
if prediction_result:
direction_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}
action = direction_map.get(prediction_result['predicted_direction'], 'HOLD')
prediction = Prediction(
action=action,
confidence=float(prediction_result['confidence']),
probabilities={direction_map.get(i, 'HOLD'): float(prob) for i, prob in enumerate(prediction_result['probabilities'])},
timeframe='cob',
timestamp=datetime.now(),
model_name=model.name,
metadata={'value': prediction_result['value']}
)
return prediction
return None
except Exception as e:
logger.error(f"Error getting COB RL prediction: {e}")
return None