better CNN info in the dash
This commit is contained in:
@ -31,6 +31,7 @@ from .extrema_trainer import ExtremaTrainer
|
||||
from .trading_action import TradingAction
|
||||
from .negative_case_trainer import NegativeCaseTrainer
|
||||
from .trading_executor import TradingExecutor
|
||||
from .cnn_monitor import log_cnn_prediction, start_cnn_training_session
|
||||
# Enhanced pivot RL trainer functionality integrated into orchestrator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -790,19 +791,145 @@ class EnhancedTradingOrchestrator:
|
||||
async def _get_timeframe_prediction_universal(self, model: CNNModelInterface, feature_matrix: np.ndarray,
|
||||
timeframe: str, market_state: MarketState,
|
||||
universal_stream: UniversalDataStream) -> Tuple[Optional[np.ndarray], float]:
|
||||
"""Get prediction for specific timeframe using universal data format"""
|
||||
"""Get prediction for specific timeframe using universal data format with CNN monitoring"""
|
||||
try:
|
||||
# Check if model supports timeframe-specific prediction
|
||||
# Measure prediction timing
|
||||
prediction_start_time = time.time()
|
||||
|
||||
# Get current price for context
|
||||
current_price = market_state.prices.get(timeframe)
|
||||
|
||||
# Check if model supports timeframe-specific prediction or enhanced predict method
|
||||
if hasattr(model, 'predict_timeframe'):
|
||||
action_probs, confidence = model.predict_timeframe(feature_matrix, timeframe)
|
||||
elif hasattr(model, 'predict') and hasattr(model.predict, '__call__'):
|
||||
# Enhanced CNN model with detailed output
|
||||
if hasattr(model, 'enhanced_predict'):
|
||||
# Get detailed prediction results
|
||||
prediction_result = model.enhanced_predict(feature_matrix)
|
||||
action_probs = prediction_result.get('probabilities', [])
|
||||
confidence = prediction_result.get('confidence', 0.0)
|
||||
else:
|
||||
# Standard prediction
|
||||
prediction_result = model.predict(feature_matrix)
|
||||
if isinstance(prediction_result, dict):
|
||||
action_probs = prediction_result.get('probabilities', [])
|
||||
confidence = prediction_result.get('confidence', 0.0)
|
||||
else:
|
||||
action_probs, confidence = prediction_result
|
||||
else:
|
||||
action_probs, confidence = model.predict(feature_matrix)
|
||||
|
||||
# Calculate prediction latency
|
||||
prediction_latency_ms = (time.time() - prediction_start_time) * 1000
|
||||
|
||||
if action_probs is not None and confidence is not None:
|
||||
# Enhance confidence based on universal data quality and market conditions
|
||||
enhanced_confidence = self._enhance_confidence_with_universal_context(
|
||||
confidence, timeframe, market_state, universal_stream
|
||||
)
|
||||
|
||||
# Log detailed CNN prediction for monitoring
|
||||
try:
|
||||
# Convert probabilities to list if needed
|
||||
if hasattr(action_probs, 'tolist'):
|
||||
prob_list = action_probs.tolist()
|
||||
elif isinstance(action_probs, (list, tuple)):
|
||||
prob_list = list(action_probs)
|
||||
else:
|
||||
prob_list = [float(action_probs)]
|
||||
|
||||
# Determine action and action confidence
|
||||
if len(prob_list) >= 2:
|
||||
action_idx = np.argmax(prob_list)
|
||||
action_name = ['SELL', 'BUY'][action_idx] if len(prob_list) == 2 else ['SELL', 'HOLD', 'BUY'][action_idx]
|
||||
action_confidence = prob_list[action_idx]
|
||||
else:
|
||||
action_idx = 0
|
||||
action_name = 'HOLD'
|
||||
action_confidence = enhanced_confidence
|
||||
|
||||
# Get model memory usage if available
|
||||
model_memory_mb = None
|
||||
if hasattr(model, 'get_memory_usage'):
|
||||
try:
|
||||
memory_info = model.get_memory_usage()
|
||||
if isinstance(memory_info, dict):
|
||||
model_memory_mb = memory_info.get('total_size_mb', 0.0)
|
||||
else:
|
||||
model_memory_mb = float(memory_info)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Create detailed prediction result for monitoring
|
||||
detailed_prediction = {
|
||||
'action': action_idx,
|
||||
'action_name': action_name,
|
||||
'confidence': float(enhanced_confidence),
|
||||
'action_confidence': float(action_confidence),
|
||||
'probabilities': prob_list,
|
||||
'raw_logits': prob_list # Use probabilities as proxy for logits if not available
|
||||
}
|
||||
|
||||
# Add enhanced model outputs if available
|
||||
if hasattr(model, 'enhanced_predict') and isinstance(prediction_result, dict):
|
||||
detailed_prediction.update({
|
||||
'regime_probabilities': prediction_result.get('regime_probabilities'),
|
||||
'volatility_prediction': prediction_result.get('volatility_prediction'),
|
||||
'extrema_prediction': prediction_result.get('extrema_prediction'),
|
||||
'risk_assessment': prediction_result.get('risk_assessment')
|
||||
})
|
||||
|
||||
# Calculate price changes for context
|
||||
price_change_1m = None
|
||||
price_change_5m = None
|
||||
volume_ratio = None
|
||||
|
||||
if current_price and timeframe in market_state.prices:
|
||||
# Try to get historical prices for context
|
||||
try:
|
||||
# Get 1m and 5m price changes if available
|
||||
if '1m' in market_state.prices and market_state.prices['1m'] != current_price:
|
||||
price_change_1m = (current_price - market_state.prices['1m']) / market_state.prices['1m']
|
||||
if '5m' in market_state.prices and market_state.prices['5m'] != current_price:
|
||||
price_change_5m = (current_price - market_state.prices['5m']) / market_state.prices['5m']
|
||||
|
||||
# Volume ratio (current vs average)
|
||||
volume_ratio = market_state.volume
|
||||
except:
|
||||
pass
|
||||
|
||||
# Log the CNN prediction with full context
|
||||
log_cnn_prediction(
|
||||
model_name=getattr(model, 'name', model.__class__.__name__),
|
||||
symbol=market_state.symbol,
|
||||
prediction_result=detailed_prediction,
|
||||
feature_matrix_shape=feature_matrix.shape,
|
||||
current_price=current_price,
|
||||
prediction_latency_ms=prediction_latency_ms,
|
||||
model_memory_usage_mb=model_memory_mb
|
||||
)
|
||||
|
||||
# Enhanced logging for detailed analysis
|
||||
logger.info(f"CNN [{getattr(model, 'name', 'Unknown')}] {market_state.symbol} {timeframe}: "
|
||||
f"{action_name} (conf: {enhanced_confidence:.3f}, "
|
||||
f"action_conf: {action_confidence:.3f}, "
|
||||
f"latency: {prediction_latency_ms:.1f}ms)")
|
||||
|
||||
if detailed_prediction.get('regime_probabilities'):
|
||||
regime_idx = np.argmax(detailed_prediction['regime_probabilities'])
|
||||
regime_conf = detailed_prediction['regime_probabilities'][regime_idx]
|
||||
logger.info(f" Regime: {regime_idx} (conf: {regime_conf:.3f})")
|
||||
|
||||
if detailed_prediction.get('volatility_prediction') is not None:
|
||||
logger.info(f" Volatility: {detailed_prediction['volatility_prediction']:.3f}")
|
||||
|
||||
if price_change_1m is not None:
|
||||
logger.info(f" Context: 1m_change: {price_change_1m:.4f}, volume_ratio: {volume_ratio:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error logging CNN prediction details: {e}")
|
||||
|
||||
return action_probs, enhanced_confidence
|
||||
|
||||
except Exception as e:
|
||||
|
Reference in New Issue
Block a user