mtp
This commit is contained in:
@@ -295,6 +295,7 @@ class TradingOrchestrator:
|
||||
file_path, metadata = result
|
||||
# Actually load the model weights from the checkpoint
|
||||
try:
|
||||
# TODO(Guideline: initialize required attributes before use) Define self.device (CUDA/CPU) before loading checkpoints.
|
||||
checkpoint_data = torch.load(file_path, map_location=self.device)
|
||||
if 'model_state_dict' in checkpoint_data:
|
||||
self.cnn_model.load_state_dict(checkpoint_data['model_state_dict'])
|
||||
@@ -1127,14 +1128,9 @@ class TradingOrchestrator:
|
||||
predictions = await self._get_all_predictions(symbol)
|
||||
|
||||
if not predictions:
|
||||
# FALLBACK: Generate basic momentum signal when no models are available
|
||||
logger.debug(f"No model predictions available for {symbol}, generating fallback signal")
|
||||
fallback_prediction = await self._generate_fallback_prediction(symbol, current_price)
|
||||
if fallback_prediction:
|
||||
predictions = [fallback_prediction]
|
||||
else:
|
||||
logger.debug(f"No fallback prediction available for {symbol}")
|
||||
return None
|
||||
# TODO(Guideline: no stubs / no synthetic data) Replace this short-circuit with a real aggregated signal path.
|
||||
logger.warning("No model predictions available for %s; skipping decision per guidelines", symbol)
|
||||
return None
|
||||
|
||||
# Combine predictions
|
||||
decision = self._combine_predictions(
|
||||
@@ -1171,17 +1167,8 @@ class TradingOrchestrator:
|
||||
|
||||
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
|
||||
"""Get predictions from all registered models via ModelManager"""
|
||||
predictions = []
|
||||
|
||||
# This method now delegates to ModelManager for model iteration
|
||||
# The actual model prediction logic has been moved to individual methods
|
||||
# that are called by the ModelManager
|
||||
|
||||
logger.debug(f"Getting predictions for {symbol} - model management handled by ModelManager")
|
||||
|
||||
# For now, return empty list as this method needs to be restructured
|
||||
# to work with the new ModelManager architecture
|
||||
return predictions
|
||||
# TODO(Guideline: remove stubs / integrate existing code) Implement ModelManager-driven prediction aggregation.
|
||||
raise RuntimeError("_get_all_predictions requires a real ModelManager integration (guideline: no stubs / no synthetic data).")
|
||||
|
||||
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
|
||||
"""Get CNN predictions for multiple timeframes"""
|
||||
@@ -1497,16 +1484,19 @@ class TradingOrchestrator:
|
||||
balance = 1.0 # Default to a normalized value if not available
|
||||
unrealized_pnl = 0.0
|
||||
|
||||
if self.trading_executor:
|
||||
position = self.trading_executor.get_current_position(symbol)
|
||||
if position:
|
||||
position_size = position.get('quantity', 0.0)
|
||||
|
||||
# Normalize balance or use a realistic value
|
||||
if self.trading_executor:
|
||||
position = self.trading_executor.get_current_position(symbol)
|
||||
if position:
|
||||
position_size = position.get('quantity', 0.0)
|
||||
|
||||
if hasattr(self.trading_executor, "get_balance"):
|
||||
current_balance = self.trading_executor.get_balance()
|
||||
if current_balance and current_balance.get('total', 0) > 0:
|
||||
# Simple normalization - can be improved
|
||||
balance = min(1.0, current_balance.get('free', 0) / current_balance.get('total', 1))
|
||||
else:
|
||||
# TODO(Guideline: ensure integrations call real APIs) Expose a balance accessor on TradingExecutor for decision-state enrichment.
|
||||
logger.warning("TradingExecutor lacks get_balance(); implement real balance access per guidelines")
|
||||
current_balance = {}
|
||||
if current_balance and current_balance.get('total', 0) > 0:
|
||||
balance = min(1.0, current_balance.get('free', 0) / current_balance.get('total', 1))
|
||||
|
||||
unrealized_pnl = self._get_current_position_pnl(symbol, self.data_provider.get_current_price(symbol))
|
||||
|
||||
@@ -2232,24 +2222,9 @@ class TradingOrchestrator:
|
||||
|
||||
# SINGLE-USE FUNCTION - Called only once in codebase
|
||||
def _generate_fallback_prediction(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Generate fallback prediction when models fail"""
|
||||
try:
|
||||
return {
|
||||
'action': 'HOLD',
|
||||
'confidence': 0.5,
|
||||
'price': self._get_current_price(symbol) or 2500.0,
|
||||
'timestamp': datetime.now(),
|
||||
'model': 'fallback'
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug(f"Error generating fallback prediction: {e}")
|
||||
return {
|
||||
'action': 'HOLD',
|
||||
'confidence': 0.5,
|
||||
'price': 2500.0,
|
||||
'timestamp': datetime.now(),
|
||||
'model': 'fallback'
|
||||
}
|
||||
"""Fallback predictions were removed to avoid synthetic signals."""
|
||||
# TODO(Guideline: no synthetic data / no stubs) Provide a real degraded-mode signal pipeline or remove this hook entirely.
|
||||
raise RuntimeError("Fallback predictions disabled per guidelines; supply real model output instead.")
|
||||
|
||||
# UNUSED FUNCTION - Not called anywhere in codebase
|
||||
def capture_dqn_prediction(self, symbol: str, action_idx: int, confidence: float, price: float, q_values: List[float] = None):
|
||||
|
Reference in New Issue
Block a user