more fixes
This commit is contained in:
@ -1693,7 +1693,7 @@ class DataProvider:
|
||||
# Stack all timeframe channels
|
||||
feature_matrix = np.stack(feature_channels, axis=0)
|
||||
|
||||
logger.info(f"Created feature matrix for {symbol}: {feature_matrix.shape} "
|
||||
logger.debug(f"Created feature matrix for {symbol}: {feature_matrix.shape} "
|
||||
f"({len(feature_channels)} timeframes, {window_size} steps, {len(common_feature_names)} features)")
|
||||
|
||||
return feature_matrix
|
||||
|
@ -419,7 +419,7 @@ class EnhancedTradingOrchestrator(TradingOrchestrator):
|
||||
symbol_predictions = {}
|
||||
for symbol in self.symbols:
|
||||
if symbol in market_states:
|
||||
predictions = await self._get_enhanced_predictions_universal(
|
||||
predictions = await self._get_enhanced_predictions_for_symbol(
|
||||
symbol, market_states[symbol], universal_stream
|
||||
)
|
||||
symbol_predictions[symbol] = predictions
|
||||
@ -444,6 +444,77 @@ class EnhancedTradingOrchestrator(TradingOrchestrator):
|
||||
|
||||
return decisions
|
||||
|
||||
async def _get_enhanced_predictions_for_symbol(self, symbol: str, market_state: MarketState,
|
||||
universal_stream: UniversalDataStream) -> List[EnhancedPrediction]:
|
||||
"""Get enhanced predictions for a symbol using universal data format"""
|
||||
predictions = []
|
||||
|
||||
try:
|
||||
# Get predictions from all registered models using the parent class method
|
||||
base_predictions = await self._get_all_predictions(symbol)
|
||||
|
||||
if not base_predictions:
|
||||
logger.warning(f"No base predictions available for {symbol}")
|
||||
return predictions
|
||||
|
||||
# Group predictions by model and create enhanced predictions
|
||||
model_predictions = {}
|
||||
for pred in base_predictions:
|
||||
if pred.model_name not in model_predictions:
|
||||
model_predictions[pred.model_name] = []
|
||||
model_predictions[pred.model_name].append(pred)
|
||||
|
||||
# Create enhanced predictions for each model
|
||||
for model_name, model_preds in model_predictions.items():
|
||||
# Convert base predictions to timeframe predictions
|
||||
timeframe_predictions = []
|
||||
for pred in model_preds:
|
||||
tf_pred = TimeframePrediction(
|
||||
timeframe=pred.timeframe,
|
||||
action=pred.action,
|
||||
confidence=pred.confidence,
|
||||
probabilities=pred.probabilities,
|
||||
timestamp=pred.timestamp,
|
||||
market_features=pred.metadata or {}
|
||||
)
|
||||
timeframe_predictions.append(tf_pred)
|
||||
|
||||
# Combine timeframe predictions into overall action
|
||||
if timeframe_predictions:
|
||||
overall_action, overall_confidence = self._combine_timeframe_predictions(
|
||||
timeframe_predictions, symbol
|
||||
)
|
||||
|
||||
# Enhance confidence with universal context
|
||||
enhanced_confidence = self._enhance_confidence_with_universal_context(
|
||||
overall_confidence, 'mixed', market_state, universal_stream
|
||||
)
|
||||
|
||||
# Create enhanced prediction
|
||||
enhanced_pred = EnhancedPrediction(
|
||||
symbol=symbol,
|
||||
timeframe_predictions=timeframe_predictions,
|
||||
overall_action=overall_action,
|
||||
overall_confidence=enhanced_confidence,
|
||||
model_name=model_name,
|
||||
timestamp=datetime.now(),
|
||||
metadata={
|
||||
'universal_data_used': True,
|
||||
'market_regime': market_state.market_regime,
|
||||
'volatility': market_state.volatility,
|
||||
'volume': market_state.volume
|
||||
}
|
||||
)
|
||||
|
||||
predictions.append(enhanced_pred)
|
||||
logger.debug(f"Created enhanced prediction for {symbol} from {model_name}: "
|
||||
f"{overall_action} (confidence: {enhanced_confidence:.3f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting enhanced predictions for {symbol}: {e}")
|
||||
|
||||
return predictions
|
||||
|
||||
async def _get_all_market_states_universal(self, universal_stream: UniversalDataStream) -> Dict[str, MarketState]:
|
||||
"""Get market states for all symbols with comprehensive data for RL"""
|
||||
market_states = {}
|
||||
|
Reference in New Issue
Block a user