This commit is contained in:
Dobromir Popov
2025-07-07 23:39:12 +03:00
parent 2d8f763eeb
commit 9cd2d5d8a4
6 changed files with 188 additions and 49 deletions

View File

@ -1007,6 +1007,17 @@ class TradingOrchestrator:
if enhanced_features is not None:
# Get CNN prediction - use the actual underlying model
try:
# Ensure features are properly shaped and limited
if isinstance(enhanced_features, np.ndarray):
# Flatten and limit features to prevent shape mismatches
enhanced_features = enhanced_features.flatten()
if len(enhanced_features) > 100: # Limit to 100 features
enhanced_features = enhanced_features[:100]
elif len(enhanced_features) < 100: # Pad with zeros
padded = np.zeros(100)
padded[:len(enhanced_features)] = enhanced_features
enhanced_features = padded
if hasattr(model.model, 'act'):
# Use the CNN's act method
action_result = model.model.act(enhanced_features, explore=False)
@ -1020,7 +1031,7 @@ class TradingOrchestrator:
action_probs = [0.1, 0.1, 0.8] # Default distribution
action_probs[action_idx] = confidence
else:
# Fallback to generic predict method
# Fallback to generic predict method
action_probs, confidence = model.predict(enhanced_features)
except Exception as e:
logger.warning(f"CNN prediction failed: {e}")
@ -1138,6 +1149,17 @@ class TradingOrchestrator:
)
if feature_matrix is not None:
# Ensure feature_matrix is properly shaped and limited
if isinstance(feature_matrix, np.ndarray):
# Flatten and limit features to prevent shape mismatches
feature_matrix = feature_matrix.flatten()
if len(feature_matrix) > 2000: # Limit to 2000 features for generic models
feature_matrix = feature_matrix[:2000]
elif len(feature_matrix) < 2000: # Pad with zeros
padded = np.zeros(2000)
padded[:len(feature_matrix)] = feature_matrix
feature_matrix = padded
prediction_result = model.predict(feature_matrix)
# Handle different return formats from model.predict()
@ -1833,4 +1855,101 @@ class TradingOrchestrator:
def set_trading_executor(self, trading_executor):
"""Set the trading executor for position tracking"""
self.trading_executor = trading_executor
logger.info("Trading executor set for position tracking and P&L feedback")
logger.info("Trading executor set for position tracking and P&L feedback")
def _get_current_price(self, symbol: str) -> float:
"""Get current price for symbol"""
try:
# Try to get from data provider
if self.data_provider:
try:
# Try different methods to get current price
if hasattr(self.data_provider, 'get_latest_data'):
latest_data = self.data_provider.get_latest_data(symbol)
if latest_data and 'price' in latest_data:
return float(latest_data['price'])
elif latest_data and 'close' in latest_data:
return float(latest_data['close'])
elif hasattr(self.data_provider, 'get_current_price'):
return float(self.data_provider.get_current_price(symbol))
elif hasattr(self.data_provider, 'get_latest_candle'):
latest_candle = self.data_provider.get_latest_candle(symbol, '1m')
if latest_candle and 'close' in latest_candle:
return float(latest_candle['close'])
except Exception as e:
logger.debug(f"Could not get price from data provider: {e}")
# Try to get from universal adapter
if self.universal_adapter:
try:
data_stream = self.universal_adapter.get_latest_data(symbol)
if data_stream and hasattr(data_stream, 'current_price'):
return float(data_stream.current_price)
except Exception as e:
logger.debug(f"Could not get price from universal adapter: {e}")
# Fallback to default prices
default_prices = {
'ETH/USDT': 2500.0,
'BTC/USDT': 108000.0
}
return default_prices.get(symbol, 1000.0)
except Exception as e:
logger.error(f"Error getting current price for {symbol}: {e}")
# Return default price based on symbol
if 'ETH' in symbol:
return 2500.0
elif 'BTC' in symbol:
return 108000.0
else:
return 1000.0
def _generate_fallback_prediction(self, symbol: str) -> Dict[str, Any]:
"""Generate fallback prediction when models fail"""
try:
return {
'action': 'HOLD',
'confidence': 0.5,
'price': self._get_current_price(symbol) or 2500.0,
'timestamp': datetime.now(),
'model': 'fallback'
}
except Exception as e:
logger.debug(f"Error generating fallback prediction: {e}")
return {
'action': 'HOLD',
'confidence': 0.5,
'price': 2500.0,
'timestamp': datetime.now(),
'model': 'fallback'
}
def capture_dqn_prediction(self, symbol: str, action_idx: int, confidence: float, price: float, q_values: List[float] = None):
"""Capture DQN prediction for dashboard visualization"""
try:
if symbol not in self.recent_dqn_predictions:
self.recent_dqn_predictions[symbol] = deque(maxlen=100)
prediction_data = {
'timestamp': datetime.now(),
'action': ['SELL', 'HOLD', 'BUY'][action_idx],
'confidence': confidence,
'price': price,
'q_values': q_values or [0.33, 0.33, 0.34]
}
self.recent_dqn_predictions[symbol].append(prediction_data)
except Exception as e:
logger.debug(f"Error capturing DQN prediction: {e}")
def capture_cnn_prediction(self, symbol: str, direction: int, confidence: float, current_price: float, predicted_price: float):
"""Capture CNN prediction for dashboard visualization"""
try:
if symbol not in self.recent_cnn_predictions:
self.recent_cnn_predictions[symbol] = deque(maxlen=50)
prediction_data = {
'timestamp': datetime.now(),
'direction': ['DOWN', 'SAME', 'UP'][direction],
'confidence': confidence,
'current_price': current_price,
'predicted_price': predicted_price
}
self.recent_cnn_predictions[symbol].append(prediction_data)
except Exception as e:
logger.debug(f"Error capturing CNN prediction: {e}")