wip cob
This commit is contained in:
@ -948,6 +948,12 @@ class TradingOrchestrator:
|
||||
rl_prediction = await self._get_rl_prediction(model, symbol)
|
||||
if rl_prediction:
|
||||
predictions.append(rl_prediction)
|
||||
|
||||
elif isinstance(model, COBRLModelInterface):
|
||||
# Get COB RL prediction
|
||||
cob_prediction = await self._get_cob_rl_prediction(model, symbol)
|
||||
if cob_prediction:
|
||||
predictions.append(cob_prediction)
|
||||
|
||||
else:
|
||||
# Generic model interface
|
||||
@ -1007,6 +1013,19 @@ class TradingOrchestrator:
|
||||
logger.debug(f"Could not enhance CNN features with COB data: {cob_error}")
|
||||
enhanced_features = feature_matrix
|
||||
|
||||
# Add extrema features if available
|
||||
if self.extrema_trainer:
|
||||
try:
|
||||
extrema_features = self.extrema_trainer.get_context_features_for_model(symbol)
|
||||
if extrema_features is not None:
|
||||
# Reshape and tile to match the enhanced_features shape
|
||||
extrema_features = extrema_features.flatten()
|
||||
tiled_extrema = np.tile(extrema_features, (enhanced_features.shape[0], enhanced_features.shape[1], 1))
|
||||
enhanced_features = np.concatenate([enhanced_features, tiled_extrema], axis=2)
|
||||
logger.debug(f"Enhanced CNN features with Extrema data for {symbol}")
|
||||
except Exception as extrema_error:
|
||||
logger.debug(f"Could not enhance CNN features with Extrema data: {extrema_error}")
|
||||
|
||||
if enhanced_features is not None:
|
||||
# Get CNN prediction - use the actual underlying model
|
||||
try:
|
||||
@ -1219,9 +1238,35 @@ class TradingOrchestrator:
|
||||
# Shape: (n_timeframes, window_size, n_features) -> (n_timeframes * window_size * n_features,)
|
||||
state = feature_matrix.flatten()
|
||||
|
||||
# Add additional state information (position, balance, etc.)
|
||||
# This would come from a portfolio manager in a real implementation
|
||||
additional_state = np.array([0.0, 1.0, 0.0]) # [position, balance, unrealized_pnl]
|
||||
# Add extrema features if available
|
||||
if self.extrema_trainer:
|
||||
try:
|
||||
extrema_features = self.extrema_trainer.get_context_features_for_model(symbol)
|
||||
if extrema_features is not None:
|
||||
state = np.concatenate([state, extrema_features.flatten()])
|
||||
logger.debug(f"Enhanced RL state with Extrema data for {symbol}")
|
||||
except Exception as extrema_error:
|
||||
logger.debug(f"Could not enhance RL state with Extrema data: {extrema_error}")
|
||||
|
||||
# Get real-time portfolio information from the trading executor
|
||||
position_size = 0.0
|
||||
balance = 1.0 # Default to a normalized value if not available
|
||||
unrealized_pnl = 0.0
|
||||
|
||||
if self.trading_executor:
|
||||
position = self.trading_executor.get_current_position(symbol)
|
||||
if position:
|
||||
position_size = position.get('quantity', 0.0)
|
||||
|
||||
# Normalize balance or use a realistic value
|
||||
current_balance = self.trading_executor.get_balance()
|
||||
if current_balance and current_balance.get('total', 0) > 0:
|
||||
# Simple normalization - can be improved
|
||||
balance = min(1.0, current_balance.get('free', 0) / current_balance.get('total', 1))
|
||||
|
||||
unrealized_pnl = self._get_current_position_pnl(symbol, self.data_provider.get_current_price(symbol))
|
||||
|
||||
additional_state = np.array([position_size, balance, unrealized_pnl])
|
||||
|
||||
return np.concatenate([state, additional_state])
|
||||
|
||||
@ -1955,4 +2000,35 @@ class TradingOrchestrator:
|
||||
}
|
||||
self.recent_cnn_predictions[symbol].append(prediction_data)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error capturing CNN prediction: {e}")
|
||||
logger.debug(f"Error capturing CNN prediction: {e}")
|
||||
|
||||
async def _get_cob_rl_prediction(self, model: COBRLModelInterface, symbol: str) -> Optional[Prediction]:
|
||||
"""Get prediction from COB RL model"""
|
||||
try:
|
||||
cob_feature_matrix = self.get_cob_feature_matrix(symbol, sequence_length=1)
|
||||
if cob_feature_matrix is None:
|
||||
return None
|
||||
|
||||
# The model expects a 1D array of features
|
||||
cob_features = cob_feature_matrix.flatten()
|
||||
|
||||
prediction_result = model.predict(cob_features)
|
||||
|
||||
if prediction_result:
|
||||
direction_map = {0: 'SELL', 1: 'HOLD', 2: 'BUY'}
|
||||
action = direction_map.get(prediction_result['predicted_direction'], 'HOLD')
|
||||
|
||||
prediction = Prediction(
|
||||
action=action,
|
||||
confidence=float(prediction_result['confidence']),
|
||||
probabilities={direction_map.get(i, 'HOLD'): float(prob) for i, prob in enumerate(prediction_result['probabilities'])},
|
||||
timeframe='cob',
|
||||
timestamp=datetime.now(),
|
||||
model_name=model.name,
|
||||
metadata={'value': prediction_result['value']}
|
||||
)
|
||||
return prediction
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting COB RL prediction: {e}")
|
||||
return None
|
Reference in New Issue
Block a user