wip
This commit is contained in:
@@ -1182,137 +1182,156 @@ class TradingOrchestrator:
|
||||
logger.debug(f"Error getting current price for {symbol}: {e}")
|
||||
return 0.0
|
||||
|
||||
# Get standard feature matrix for this timeframe
|
||||
feature_matrix = self.data_provider.get_feature_matrix(
|
||||
symbol=symbol,
|
||||
timeframes=[timeframe],
|
||||
window_size=getattr(model, 'window_size', 20)
|
||||
)
|
||||
|
||||
# Enhance with COB feature matrix if available
|
||||
enhanced_features = feature_matrix
|
||||
if feature_matrix is not None and self.cob_integration:
|
||||
try:
|
||||
# Get COB feature matrix (5-minute history)
|
||||
cob_feature_matrix = self.get_cob_feature_matrix(symbol, sequence_length=60)
|
||||
|
||||
if cob_feature_matrix is not None:
|
||||
# Take the latest COB features to augment the standard features
|
||||
latest_cob_features = cob_feature_matrix[-1:, :] # Shape: (1, 400)
|
||||
|
||||
# Resize to match the feature matrix timeframe dimension
|
||||
timeframe_count = feature_matrix.shape[0]
|
||||
cob_features_expanded = np.repeat(latest_cob_features, timeframe_count, axis=0)
|
||||
|
||||
# Concatenate COB features with standard features
|
||||
# Standard features shape: (timeframes, window_size, features)
|
||||
# COB features shape: (timeframes, 400)
|
||||
# We'll add COB as additional features to each timeframe
|
||||
window_size = feature_matrix.shape[1]
|
||||
cob_features_reshaped = cob_features_expanded.reshape(timeframe_count, 1, 400)
|
||||
cob_features_tiled = np.tile(cob_features_reshaped, (1, window_size, 1))
|
||||
|
||||
# Concatenate along feature dimension
|
||||
enhanced_features = np.concatenate([feature_matrix, cob_features_tiled], axis=2)
|
||||
|
||||
logger.debug(f"Enhanced CNN features with COB data for {symbol}: "
|
||||
f"{feature_matrix.shape} + COB -> {enhanced_features.shape}")
|
||||
|
||||
except Exception as cob_error:
|
||||
logger.debug(f"Could not enhance CNN features with COB data: {cob_error}")
|
||||
enhanced_features = feature_matrix
|
||||
|
||||
# Add extrema features if available
|
||||
if self.extrema_trainer:
|
||||
try:
|
||||
extrema_features = self.extrema_trainer.get_context_features_for_model(symbol)
|
||||
if extrema_features is not None:
|
||||
# Reshape and tile to match the enhanced_features shape
|
||||
extrema_features = extrema_features.flatten()
|
||||
tiled_extrema = np.tile(extrema_features, (enhanced_features.shape[0], enhanced_features.shape[1], 1))
|
||||
enhanced_features = np.concatenate([enhanced_features, tiled_extrema], axis=2)
|
||||
logger.debug(f"Enhanced CNN features with Extrema data for {symbol}")
|
||||
except Exception as extrema_error:
|
||||
logger.debug(f"Could not enhance CNN features with Extrema data: {extrema_error}")
|
||||
|
||||
if enhanced_features is not None:
|
||||
# Get CNN prediction - use the actual underlying model
|
||||
try:
|
||||
# Ensure features are properly shaped and limited
|
||||
if isinstance(enhanced_features, np.ndarray):
|
||||
# Flatten and limit features to prevent shape mismatches
|
||||
enhanced_features = enhanced_features.flatten()
|
||||
if len(enhanced_features) > 100: # Limit to 100 features
|
||||
enhanced_features = enhanced_features[:100]
|
||||
elif len(enhanced_features) < 100: # Pad with zeros
|
||||
padded = np.zeros(100)
|
||||
padded[:len(enhanced_features)] = enhanced_features
|
||||
enhanced_features = padded
|
||||
|
||||
if hasattr(model.model, 'act'):
|
||||
# Use the CNN's act method
|
||||
action_result = model.model.act(enhanced_features, explore=False)
|
||||
if isinstance(action_result, tuple):
|
||||
action_idx, confidence = action_result
|
||||
else:
|
||||
action_idx = action_result
|
||||
confidence = 0.7 # Default confidence
|
||||
|
||||
# Convert to action probabilities
|
||||
action_probs = [0.1, 0.1, 0.8] # Default distribution
|
||||
action_probs[action_idx] = confidence
|
||||
else:
|
||||
# Fallback to generic predict method
|
||||
action_probs, confidence = model.predict(enhanced_features)
|
||||
except Exception as e:
|
||||
logger.warning(f"CNN prediction failed: {e}")
|
||||
action_probs, confidence = None, None
|
||||
|
||||
if action_probs is not None:
|
||||
# Convert to prediction object
|
||||
action_names = ['SELL', 'HOLD', 'BUY']
|
||||
best_action_idx = np.argmax(action_probs)
|
||||
best_action = action_names[best_action_idx]
|
||||
|
||||
prediction = Prediction(
|
||||
action=best_action,
|
||||
confidence=float(confidence) if confidence is not None else float(action_probs[best_action_idx]),
|
||||
probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
|
||||
timeframe=timeframe,
|
||||
timestamp=datetime.now(),
|
||||
model_name=model.name,
|
||||
metadata={
|
||||
'timeframe_specific': True,
|
||||
'cob_enhanced': enhanced_features is not feature_matrix,
|
||||
'feature_shape': str(enhanced_features.shape)
|
||||
}
|
||||
)
|
||||
|
||||
predictions.append(prediction)
|
||||
|
||||
# Capture CNN prediction for dashboard visualization
|
||||
current_price = self._get_current_price(symbol)
|
||||
if current_price:
|
||||
direction = best_action_idx # 0=SELL, 1=HOLD, 2=BUY
|
||||
pred_confidence = float(confidence) if confidence is not None else float(action_probs[best_action_idx])
|
||||
predicted_price = current_price * (1 + (pred_confidence * 0.01 if best_action == 'BUY' else -pred_confidence * 0.01 if best_action == 'SELL' else 0))
|
||||
self.capture_cnn_prediction(symbol, int(direction), pred_confidence, current_price, predicted_price)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting CNN predictions: {e}")
|
||||
|
||||
return predictions
|
||||
|
||||
async def _get_rl_prediction(self, model: RLAgentInterface, symbol: str) -> Optional[Prediction]:
|
||||
"""Get prediction from RL agent"""
|
||||
async def _get_cob_rl_prediction(self, model: COBRLModelInterface, symbol: str) -> Optional[Prediction]:
|
||||
"""Get prediction from COB RL model"""
|
||||
try:
|
||||
# Get current state for RL agent
|
||||
state = self._get_rl_state(symbol)
|
||||
if state is None:
|
||||
# Get COB state from current market data
|
||||
cob_state = self._get_cob_state(symbol)
|
||||
if cob_state is None:
|
||||
return None
|
||||
|
||||
# Get RL agent's action, confidence, and q_values from the underlying model
|
||||
# Get prediction from COB RL model
|
||||
if hasattr(model.model, 'act_with_confidence'):
|
||||
result = model.model.act_with_confidence(cob_state)
|
||||
if len(result) == 2:
|
||||
action_idx, confidence = result
|
||||
else:
|
||||
action_idx = result[0] if isinstance(result, (list, tuple)) else result
|
||||
confidence = 0.6
|
||||
else:
|
||||
action_idx = model.model.act(cob_state)
|
||||
confidence = 0.6
|
||||
|
||||
# Convert to action name
|
||||
action_names = ['BUY', 'SELL', 'HOLD']
|
||||
if 0 <= action_idx < len(action_names):
|
||||
action = action_names[action_idx]
|
||||
else:
|
||||
return None
|
||||
|
||||
# Store prediction in database for tracking
|
||||
if (hasattr(self, 'enhanced_training_system') and
|
||||
self.enhanced_training_system and
|
||||
hasattr(self.enhanced_training_system, 'store_model_prediction')):
|
||||
|
||||
current_price = self._get_current_price_safe(symbol)
|
||||
if current_price > 0:
|
||||
prediction_id = self.enhanced_training_system.store_model_prediction(
|
||||
model_name=f"COB_RL_{model.model_name}" if hasattr(model, 'model_name') else "COB_RL",
|
||||
symbol=symbol,
|
||||
prediction_type=action,
|
||||
confidence=confidence,
|
||||
current_price=current_price
|
||||
)
|
||||
logger.debug(f"Stored COB RL prediction {prediction_id} for {symbol}")
|
||||
|
||||
# Create prediction object
|
||||
prediction = Prediction(
|
||||
model_name=f"COB_RL_{model.model_name}" if hasattr(model, 'model_name') else "COB_RL",
|
||||
symbol=symbol,
|
||||
signal=action,
|
||||
confidence=confidence,
|
||||
reasoning=f"COB RL model prediction based on order book imbalance",
|
||||
features=cob_state.tolist() if isinstance(cob_state, np.ndarray) else [],
|
||||
metadata={
|
||||
'action_idx': action_idx,
|
||||
'cob_state_size': len(cob_state) if cob_state is not None else 0
|
||||
}
|
||||
)
|
||||
|
||||
return prediction
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting COB RL prediction for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
async def _get_generic_prediction(self, model, symbol: str) -> Optional[Prediction]:
|
||||
"""Get prediction from generic model interface"""
|
||||
try:
|
||||
# Placeholder for generic model prediction
|
||||
logger.debug(f"Getting generic prediction from {model} for {symbol}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting generic prediction for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Build RL state vector for DQN agent"""
|
||||
try:
|
||||
# Use data provider to get comprehensive RL state
|
||||
if hasattr(self.data_provider, 'get_dqn_state_for_inference'):
|
||||
symbols_timeframes = [(symbol, '1m'), (symbol, '5m'), (symbol, '1h')]
|
||||
state = self.data_provider.get_dqn_state_for_inference(symbols_timeframes, target_size=100)
|
||||
if state is not None:
|
||||
return state
|
||||
|
||||
# Fallback: build basic state from market data
|
||||
market_features = []
|
||||
|
||||
# Get latest price data
|
||||
latest_data = self.data_provider.get_latest_data(symbol)
|
||||
if latest_data and 'close' in latest_data:
|
||||
current_price = float(latest_data['close'])
|
||||
market_features.extend([
|
||||
current_price,
|
||||
latest_data.get('volume', 0.0),
|
||||
latest_data.get('high', current_price) - latest_data.get('low', current_price), # Range
|
||||
latest_data.get('open', current_price)
|
||||
])
|
||||
else:
|
||||
market_features.extend([4300.0, 100.0, 10.0, 4295.0]) # Default values
|
||||
|
||||
# Pad to standard size
|
||||
while len(market_features) < 100:
|
||||
market_features.append(0.0)
|
||||
|
||||
return np.array(market_features[:100], dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error building RL state for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _get_cob_state(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Build COB state vector for COB RL agent"""
|
||||
try:
|
||||
# Get COB data from integration
|
||||
if hasattr(self, 'cob_integration') and self.cob_integration:
|
||||
cob_snapshot = self.cob_integration.get_cob_snapshot(symbol)
|
||||
if cob_snapshot:
|
||||
# Extract features from COB snapshot
|
||||
features = []
|
||||
|
||||
# Add bid/ask imbalance
|
||||
bid_volume = sum([level['volume'] for level in cob_snapshot.get('bids', [])])
|
||||
ask_volume = sum([level['volume'] for level in cob_snapshot.get('asks', [])])
|
||||
if bid_volume + ask_volume > 0:
|
||||
imbalance = (bid_volume - ask_volume) / (bid_volume + ask_volume)
|
||||
else:
|
||||
imbalance = 0.0
|
||||
features.append(imbalance)
|
||||
|
||||
# Add spread
|
||||
if cob_snapshot.get('bids') and cob_snapshot.get('asks'):
|
||||
spread = cob_snapshot['asks'][0]['price'] - cob_snapshot['bids'][0]['price']
|
||||
features.append(spread)
|
||||
else:
|
||||
features.append(0.0)
|
||||
|
||||
# Pad to standard size
|
||||
while len(features) < 50:
|
||||
features.append(0.0)
|
||||
|
||||
return np.array(features[:50], dtype=np.float32)
|
||||
|
||||
# Fallback state
|
||||
return np.zeros(50, dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error building COB state for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _combine_predictions(self, symbol: str, price: float, predictions: List[Prediction],
|
||||
timestamp: datetime) -> TradingDecision:
|
||||
# Call act_with_confidence and handle different return formats
|
||||
result = model.model.act_with_confidence(state)
|
||||
|
||||
@@ -1728,7 +1747,7 @@ class TradingOrchestrator:
|
||||
)
|
||||
|
||||
if needs_refresh:
|
||||
result = load_best_checkpoint(model_name)
|
||||
result = load_best_checkpoint(model_name)
|
||||
self._checkpoint_cache[model_name] = result
|
||||
self._checkpoint_cache_time[model_name] = current_time
|
||||
|
||||
@@ -1874,9 +1893,9 @@ class TradingOrchestrator:
|
||||
try:
|
||||
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem
|
||||
|
||||
self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
|
||||
orchestrator=self,
|
||||
data_provider=self.data_provider,
|
||||
self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
|
||||
orchestrator=self,
|
||||
data_provider=self.data_provider,
|
||||
dashboard=None
|
||||
)
|
||||
|
||||
|
@@ -13,7 +13,7 @@ import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any, Optional
|
||||
import numpy as np
|
||||
from utils.reward_calculator import RewardCalculator
|
||||
from core.reward_calculator import RewardCalculator
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
Binary file not shown.
@@ -345,18 +345,61 @@ class CleanTradingDashboard:
|
||||
'timestamp': datetime.now().isoformat()
|
||||
})
|
||||
|
||||
@self.app.server.route('/api/predictions/stats', methods=['GET'])
|
||||
def get_prediction_stats():
|
||||
"""Get model prediction statistics"""
|
||||
@self.app.server.route('/api/predictions/recent', methods=['GET'])
|
||||
def get_recent_predictions():
|
||||
"""Get recent predictions with their outcomes"""
|
||||
try:
|
||||
if (hasattr(self.orchestrator, 'enhanced_training_system') and
|
||||
self.orchestrator.enhanced_training_system):
|
||||
stats = self.orchestrator.enhanced_training_system.get_model_performance_stats()
|
||||
return jsonify(stats)
|
||||
|
||||
# Get predictions from database
|
||||
from core.prediction_database import get_prediction_db
|
||||
db = get_prediction_db()
|
||||
|
||||
# Get recent predictions (last 24 hours)
|
||||
predictions = []
|
||||
|
||||
# Mock data for now - replace with actual database query
|
||||
import sqlite3
|
||||
try:
|
||||
with sqlite3.connect(db.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT model_name, symbol, prediction_type, confidence,
|
||||
timestamp, price_at_prediction, outcome_timestamp,
|
||||
actual_price_change, reward, is_correct
|
||||
FROM predictions
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 50
|
||||
""")
|
||||
|
||||
for row in cursor.fetchall():
|
||||
predictions.append({
|
||||
'model_name': row[0],
|
||||
'symbol': row[1],
|
||||
'prediction_type': row[2],
|
||||
'confidence': row[3],
|
||||
'timestamp': row[4],
|
||||
'price_at_prediction': row[5],
|
||||
'outcome_timestamp': row[6],
|
||||
'actual_price_change': row[7],
|
||||
'reward': row[8],
|
||||
'is_correct': row[9],
|
||||
'is_resolved': row[6] is not None
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug(f"Error fetching predictions from database: {e}")
|
||||
|
||||
return jsonify({
|
||||
'predictions': predictions,
|
||||
'total_predictions': len(predictions),
|
||||
'active_predictions': len([p for p in predictions if not p['is_resolved']]),
|
||||
'timestamp': datetime.now().isoformat()
|
||||
})
|
||||
else:
|
||||
return jsonify({"error": "Training system not available"}), 503
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting prediction stats: {e}")
|
||||
logger.error(f"Error getting recent predictions: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
def _get_ohlcv_data_with_indicators(self, symbol: str, timeframe: str, limit: int = 300):
|
||||
@@ -980,6 +1023,135 @@ class CleanTradingDashboard:
|
||||
logger.error(f"Training status error: {e}")
|
||||
return 'Error', 'badge bg-danger small'
|
||||
|
||||
@self.app.callback(
|
||||
[Output('total-predictions-count', 'children'),
|
||||
Output('pending-predictions-count', 'children'),
|
||||
Output('active-models-count', 'children'),
|
||||
Output('total-rewards-sum', 'children'),
|
||||
Output('prediction-timeline-chart', 'figure'),
|
||||
Output('model-performance-chart', 'figure')],
|
||||
[Input('interval-component', 'n_intervals')]
|
||||
)
|
||||
def update_prediction_tracking(n_intervals):
|
||||
"""Update prediction tracking charts and metrics"""
|
||||
try:
|
||||
if (hasattr(self.orchestrator, 'enhanced_training_system') and
|
||||
self.orchestrator.enhanced_training_system):
|
||||
|
||||
# Get prediction data
|
||||
stats = self.orchestrator.enhanced_training_system.get_model_performance_stats()
|
||||
models = stats.get('models', [])
|
||||
total_active = stats.get('total_active_predictions', 0)
|
||||
|
||||
# Calculate totals
|
||||
total_predictions = sum(m.get('total_predictions', 0) for m in models)
|
||||
total_rewards = sum(m.get('total_reward', 0) for m in models)
|
||||
active_models = len(models)
|
||||
|
||||
# Create timeline chart (simplified)
|
||||
timeline_fig = {
|
||||
'data': [],
|
||||
'layout': {
|
||||
'title': 'Recent Predictions Timeline',
|
||||
'xaxis': {'title': 'Time'},
|
||||
'yaxis': {'title': 'Confidence'},
|
||||
'template': 'plotly_dark',
|
||||
'height': 300,
|
||||
'showlegend': False
|
||||
}
|
||||
}
|
||||
|
||||
# Add empty annotation if no data
|
||||
if not models:
|
||||
timeline_fig['layout']['annotations'] = [{
|
||||
'text': 'No prediction data yet',
|
||||
'xref': 'paper', 'yref': 'paper',
|
||||
'x': 0.5, 'y': 0.5,
|
||||
'showarrow': False,
|
||||
'font': {'size': 16, 'color': 'gray'}
|
||||
}]
|
||||
|
||||
# Create performance chart
|
||||
performance_fig = {
|
||||
'data': [],
|
||||
'layout': {
|
||||
'title': 'Model Performance',
|
||||
'template': 'plotly_dark',
|
||||
'height': 300,
|
||||
'showlegend': True
|
||||
}
|
||||
}
|
||||
|
||||
if models:
|
||||
model_names = [m.get('model_name', 'Unknown') for m in models]
|
||||
accuracies = [m.get('accuracy', 0) * 100 for m in models]
|
||||
rewards = [m.get('total_reward', 0) for m in models]
|
||||
|
||||
# Add accuracy bars
|
||||
performance_fig['data'].append({
|
||||
'x': model_names,
|
||||
'y': accuracies,
|
||||
'type': 'bar',
|
||||
'name': 'Accuracy (%)',
|
||||
'marker': {'color': 'lightblue'}
|
||||
})
|
||||
|
||||
performance_fig['layout']['xaxis'] = {'title': 'Model'}
|
||||
performance_fig['layout']['yaxis'] = {'title': 'Accuracy (%)'}
|
||||
else:
|
||||
performance_fig['layout']['annotations'] = [{
|
||||
'text': 'No model data yet',
|
||||
'xref': 'paper', 'yref': 'paper',
|
||||
'x': 0.5, 'y': 0.5,
|
||||
'showarrow': False,
|
||||
'font': {'size': 16, 'color': 'gray'}
|
||||
}]
|
||||
|
||||
return (
|
||||
str(total_predictions),
|
||||
str(total_active),
|
||||
str(active_models),
|
||||
f"{total_rewards:.1f}",
|
||||
timeline_fig,
|
||||
performance_fig
|
||||
)
|
||||
else:
|
||||
# Training system not available
|
||||
empty_fig = {
|
||||
'data': [],
|
||||
'layout': {
|
||||
'template': 'plotly_dark',
|
||||
'height': 300,
|
||||
'annotations': [{
|
||||
'text': 'Training system not available',
|
||||
'xref': 'paper', 'yref': 'paper',
|
||||
'x': 0.5, 'y': 0.5,
|
||||
'showarrow': False,
|
||||
'font': {'size': 16, 'color': 'red'}
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
return "N/A", "N/A", "N/A", "N/A", empty_fig, empty_fig
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating prediction tracking: {e}")
|
||||
error_fig = {
|
||||
'data': [],
|
||||
'layout': {
|
||||
'template': 'plotly_dark',
|
||||
'height': 300,
|
||||
'annotations': [{
|
||||
'text': f'Error: {str(e)}',
|
||||
'xref': 'paper', 'yref': 'paper',
|
||||
'x': 0.5, 'y': 0.5,
|
||||
'showarrow': False,
|
||||
'font': {'size': 14, 'color': 'red'}
|
||||
}]
|
||||
}
|
||||
}
|
||||
return "Error", "Error", "Error", "Error", error_fig, error_fig
|
||||
|
||||
@self.app.callback(
|
||||
[Output('eth-cob-content', 'children'),
|
||||
Output('btc-cob-content', 'children')],
|
||||
|
@@ -19,13 +19,72 @@ class DashboardLayoutManager:
|
||||
return html.Div([
|
||||
self._create_header(),
|
||||
self._create_interval_component(),
|
||||
self._create_main_content()
|
||||
self._create_main_content(),
|
||||
self._create_prediction_tracking_section() # NEW: Prediction tracking
|
||||
], className="container-fluid", style={
|
||||
"backgroundColor": "#111827",
|
||||
"minHeight": "100vh",
|
||||
"color": "#f8f9fa"
|
||||
})
|
||||
|
||||
def _create_prediction_tracking_section(self):
|
||||
"""Create prediction tracking and model performance section"""
|
||||
return html.Div([
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6([
|
||||
html.I(className="fas fa-brain me-2"),
|
||||
"🧠 Model Predictions & Performance Tracking"
|
||||
], className="text-light mb-3"),
|
||||
|
||||
# Summary cards row
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6("0", id="total-predictions-count", className="mb-0 text-primary"),
|
||||
html.Small("Total Predictions", className="text-light")
|
||||
], className="card-body text-center p-2 bg-dark")
|
||||
], className="card col-md-3 mx-1 bg-dark border-secondary"),
|
||||
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6("0", id="pending-predictions-count", className="mb-0 text-warning"),
|
||||
html.Small("Pending Resolution", className="text-light")
|
||||
], className="card-body text-center p-2 bg-dark")
|
||||
], className="card col-md-3 mx-1 bg-dark border-secondary"),
|
||||
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6("0", id="active-models-count", className="mb-0 text-info"),
|
||||
html.Small("Active Models", className="text-light")
|
||||
], className="card-body text-center p-2 bg-dark")
|
||||
], className="card col-md-3 mx-1 bg-dark border-secondary"),
|
||||
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6("0.0", id="total-rewards-sum", className="mb-0 text-success"),
|
||||
html.Small("Total Rewards", className="text-light")
|
||||
], className="card-body text-center p-2 bg-dark")
|
||||
], className="card col-md-3 mx-1 bg-dark border-secondary")
|
||||
], className="row mb-3"),
|
||||
|
||||
# Charts row
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6("Recent Predictions Timeline", className="mb-2 text-light"),
|
||||
dcc.Graph(id="prediction-timeline-chart", style={"height": "300px"})
|
||||
], className="col-md-6"),
|
||||
|
||||
html.Div([
|
||||
html.H6("Model Performance", className="mb-2 text-light"),
|
||||
dcc.Graph(id="model-performance-chart", style={"height": "300px"})
|
||||
], className="col-md-6")
|
||||
], className="row")
|
||||
|
||||
], className="p-3")
|
||||
], className="card bg-dark border-secondary mb-3")
|
||||
], className="mt-3")
|
||||
|
||||
def _create_header(self):
|
||||
"""Create the dashboard header"""
|
||||
trading_mode = "SIMULATION" if (not self.trading_executor or
|
||||
|
352
web/prediction_chart.py
Normal file
352
web/prediction_chart.py
Normal file
@@ -0,0 +1,352 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Prediction Chart Component - Visualizes model predictions and their outcomes
|
||||
"""
|
||||
|
||||
import dash
|
||||
from dash import dcc, html, dash_table
|
||||
import plotly.graph_objs as go
|
||||
import plotly.express as px
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PredictionChartComponent:
|
||||
"""Component for visualizing prediction tracking and outcomes"""
|
||||
|
||||
def __init__(self):
|
||||
self.colors = {
|
||||
'BUY': '#28a745', # Green
|
||||
'SELL': '#dc3545', # Red
|
||||
'HOLD': '#6c757d', # Gray
|
||||
'reward': '#28a745', # Green for positive rewards
|
||||
'penalty': '#dc3545' # Red for negative rewards
|
||||
}
|
||||
|
||||
def create_prediction_timeline_chart(self, predictions_data: List[Dict[str, Any]]) -> dcc.Graph:
|
||||
"""Create a timeline chart showing predictions and their outcomes"""
|
||||
try:
|
||||
if not predictions_data:
|
||||
# Empty chart
|
||||
fig = go.Figure()
|
||||
fig.add_annotation(
|
||||
text="No prediction data available",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, xanchor='center', yanchor='middle',
|
||||
showarrow=False, font=dict(size=16, color="gray")
|
||||
)
|
||||
fig.update_layout(
|
||||
title="Model Predictions Timeline",
|
||||
xaxis_title="Time",
|
||||
yaxis_title="Confidence",
|
||||
height=300
|
||||
)
|
||||
return dcc.Graph(figure=fig, id="prediction-timeline")
|
||||
|
||||
# Convert to DataFrame
|
||||
df = pd.DataFrame(predictions_data)
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'])
|
||||
|
||||
# Create the plot
|
||||
fig = go.Figure()
|
||||
|
||||
# Add prediction points
|
||||
for prediction_type in ['BUY', 'SELL', 'HOLD']:
|
||||
type_data = df[df['prediction_type'] == prediction_type]
|
||||
if not type_data.empty:
|
||||
# Different markers for resolved vs pending
|
||||
resolved_data = type_data[type_data['is_resolved'] == True]
|
||||
pending_data = type_data[type_data['is_resolved'] == False]
|
||||
|
||||
if not resolved_data.empty:
|
||||
# Resolved predictions
|
||||
colors = [self.colors['reward'] if r > 0 else self.colors['penalty']
|
||||
for r in resolved_data['reward']]
|
||||
fig.add_trace(go.Scatter(
|
||||
x=resolved_data['timestamp'],
|
||||
y=resolved_data['confidence'],
|
||||
mode='markers',
|
||||
marker=dict(
|
||||
size=10,
|
||||
color=colors,
|
||||
symbol='circle',
|
||||
line=dict(width=2, color=self.colors[prediction_type])
|
||||
),
|
||||
name=f'{prediction_type} (Resolved)',
|
||||
text=[f"Model: {m}<br>Confidence: {c:.3f}<br>Reward: {r:.2f}"
|
||||
for m, c, r in zip(resolved_data['model_name'],
|
||||
resolved_data['confidence'],
|
||||
resolved_data['reward'])],
|
||||
hovertemplate='%{text}<extra></extra>'
|
||||
))
|
||||
|
||||
if not pending_data.empty:
|
||||
# Pending predictions
|
||||
fig.add_trace(go.Scatter(
|
||||
x=pending_data['timestamp'],
|
||||
y=pending_data['confidence'],
|
||||
mode='markers',
|
||||
marker=dict(
|
||||
size=8,
|
||||
color=self.colors[prediction_type],
|
||||
symbol='circle-open',
|
||||
line=dict(width=2)
|
||||
),
|
||||
name=f'{prediction_type} (Pending)',
|
||||
text=[f"Model: {m}<br>Confidence: {c:.3f}<br>Status: Pending"
|
||||
for m, c in zip(pending_data['model_name'],
|
||||
pending_data['confidence'])],
|
||||
hovertemplate='%{text}<extra></extra>'
|
||||
))
|
||||
|
||||
# Update layout
|
||||
fig.update_layout(
|
||||
title="Model Predictions Timeline",
|
||||
xaxis_title="Time",
|
||||
yaxis_title="Confidence",
|
||||
yaxis=dict(range=[0, 1]),
|
||||
height=400,
|
||||
showlegend=True,
|
||||
legend=dict(x=0.02, y=0.98),
|
||||
hovermode='closest'
|
||||
)
|
||||
|
||||
return dcc.Graph(figure=fig, id="prediction-timeline")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating prediction timeline chart: {e}")
|
||||
# Return empty chart on error
|
||||
fig = go.Figure()
|
||||
fig.add_annotation(text=f"Error: {str(e)}", x=0.5, y=0.5)
|
||||
return dcc.Graph(figure=fig, id="prediction-timeline")
|
||||
|
||||
def create_model_performance_chart(self, model_stats: List[Dict[str, Any]]) -> dcc.Graph:
|
||||
"""Create a bar chart showing model performance metrics"""
|
||||
try:
|
||||
if not model_stats:
|
||||
fig = go.Figure()
|
||||
fig.add_annotation(
|
||||
text="No model performance data available",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, xanchor='center', yanchor='middle',
|
||||
showarrow=False, font=dict(size=16, color="gray")
|
||||
)
|
||||
fig.update_layout(
|
||||
title="Model Performance Comparison",
|
||||
height=300
|
||||
)
|
||||
return dcc.Graph(figure=fig, id="model-performance")
|
||||
|
||||
# Extract data
|
||||
model_names = [stats['model_name'] for stats in model_stats]
|
||||
accuracies = [stats['accuracy'] * 100 for stats in model_stats] # Convert to percentage
|
||||
total_rewards = [stats['total_reward'] for stats in model_stats]
|
||||
total_predictions = [stats['total_predictions'] for stats in model_stats]
|
||||
|
||||
# Create subplots
|
||||
fig = go.Figure()
|
||||
|
||||
# Add accuracy bars
|
||||
fig.add_trace(go.Bar(
|
||||
x=model_names,
|
||||
y=accuracies,
|
||||
name='Accuracy (%)',
|
||||
marker_color='lightblue',
|
||||
yaxis='y',
|
||||
text=[f"{a:.1f}%" for a in accuracies],
|
||||
textposition='auto'
|
||||
))
|
||||
|
||||
# Add total reward on secondary y-axis
|
||||
fig.add_trace(go.Scatter(
|
||||
x=model_names,
|
||||
y=total_rewards,
|
||||
mode='markers+text',
|
||||
name='Total Reward',
|
||||
marker=dict(
|
||||
size=12,
|
||||
color='orange',
|
||||
symbol='diamond'
|
||||
),
|
||||
yaxis='y2',
|
||||
text=[f"{r:.1f}" for r in total_rewards],
|
||||
textposition='top center'
|
||||
))
|
||||
|
||||
# Update layout
|
||||
fig.update_layout(
|
||||
title="Model Performance Comparison",
|
||||
xaxis_title="Model",
|
||||
yaxis=dict(
|
||||
title="Accuracy (%)",
|
||||
side="left",
|
||||
range=[0, 100]
|
||||
),
|
||||
yaxis2=dict(
|
||||
title="Total Reward",
|
||||
side="right",
|
||||
overlaying="y"
|
||||
),
|
||||
height=400,
|
||||
showlegend=True,
|
||||
legend=dict(x=0.02, y=0.98)
|
||||
)
|
||||
|
||||
return dcc.Graph(figure=fig, id="model-performance")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating model performance chart: {e}")
|
||||
fig = go.Figure()
|
||||
fig.add_annotation(text=f"Error: {str(e)}", x=0.5, y=0.5)
|
||||
return dcc.Graph(figure=fig, id="model-performance")
|
||||
|
||||
def create_prediction_table(self, recent_predictions: List[Dict[str, Any]]) -> dash_table.DataTable:
|
||||
"""Create a table showing recent predictions"""
|
||||
try:
|
||||
if not recent_predictions:
|
||||
return dash_table.DataTable(
|
||||
id="prediction-table",
|
||||
columns=[
|
||||
{"name": "Model", "id": "model_name"},
|
||||
{"name": "Symbol", "id": "symbol"},
|
||||
{"name": "Prediction", "id": "prediction_type"},
|
||||
{"name": "Confidence", "id": "confidence"},
|
||||
{"name": "Status", "id": "status"},
|
||||
{"name": "Reward", "id": "reward"}
|
||||
],
|
||||
data=[],
|
||||
style_cell={'textAlign': 'center'},
|
||||
style_header={'backgroundColor': 'rgb(230, 230, 230)', 'fontWeight': 'bold'},
|
||||
page_size=10
|
||||
)
|
||||
|
||||
# Format data for table
|
||||
table_data = []
|
||||
for pred in recent_predictions[-20:]: # Show last 20 predictions
|
||||
table_data.append({
|
||||
'model_name': pred.get('model_name', 'Unknown'),
|
||||
'symbol': pred.get('symbol', 'N/A'),
|
||||
'prediction_type': pred.get('prediction_type', 'N/A'),
|
||||
'confidence': f"{pred.get('confidence', 0):.3f}",
|
||||
'status': 'Resolved' if pred.get('is_resolved', False) else 'Pending',
|
||||
'reward': f"{pred.get('reward', 0):.2f}" if pred.get('is_resolved', False) else 'Pending'
|
||||
})
|
||||
|
||||
return dash_table.DataTable(
|
||||
id="prediction-table",
|
||||
columns=[
|
||||
{"name": "Model", "id": "model_name"},
|
||||
{"name": "Symbol", "id": "symbol"},
|
||||
{"name": "Prediction", "id": "prediction_type"},
|
||||
{"name": "Confidence", "id": "confidence"},
|
||||
{"name": "Status", "id": "status"},
|
||||
{"name": "Reward", "id": "reward"}
|
||||
],
|
||||
data=table_data,
|
||||
style_cell={'textAlign': 'center', 'fontSize': '12px'},
|
||||
style_header={'backgroundColor': 'rgb(230, 230, 230)', 'fontWeight': 'bold'},
|
||||
style_data_conditional=[
|
||||
{
|
||||
'if': {'filter_query': '{status} = Resolved and {reward} > 0'},
|
||||
'backgroundColor': 'rgba(40, 167, 69, 0.1)',
|
||||
'color': 'black',
|
||||
},
|
||||
{
|
||||
'if': {'filter_query': '{status} = Resolved and {reward} < 0'},
|
||||
'backgroundColor': 'rgba(220, 53, 69, 0.1)',
|
||||
'color': 'black',
|
||||
},
|
||||
{
|
||||
'if': {'filter_query': '{status} = Pending'},
|
||||
'backgroundColor': 'rgba(108, 117, 125, 0.1)',
|
||||
'color': 'black',
|
||||
}
|
||||
],
|
||||
page_size=10,
|
||||
sort_action="native"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating prediction table: {e}")
|
||||
return dash_table.DataTable(
|
||||
id="prediction-table",
|
||||
columns=[{"name": "Error", "id": "error"}],
|
||||
data=[{"error": str(e)}]
|
||||
)
|
||||
|
||||
def create_prediction_panel(self, prediction_stats: Dict[str, Any]) -> html.Div:
|
||||
"""Create a complete prediction tracking panel"""
|
||||
try:
|
||||
predictions_data = prediction_stats.get('predictions', [])
|
||||
model_stats = prediction_stats.get('models', [])
|
||||
|
||||
return html.Div([
|
||||
html.H4("📊 Prediction Tracking & Performance", className="mb-3"),
|
||||
|
||||
# Summary cards
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6(f"{prediction_stats.get('total_predictions', 0)}", className="mb-0"),
|
||||
html.Small("Total Predictions", className="text-muted")
|
||||
], className="card-body text-center"),
|
||||
], className="card col-md-3 mx-1"),
|
||||
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6(f"{prediction_stats.get('active_predictions', 0)}", className="mb-0"),
|
||||
html.Small("Pending Resolution", className="text-muted")
|
||||
], className="card-body text-center"),
|
||||
], className="card col-md-3 mx-1"),
|
||||
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6(f"{len(model_stats)}", className="mb-0"),
|
||||
html.Small("Active Models", className="text-muted")
|
||||
], className="card-body text-center"),
|
||||
], className="card col-md-3 mx-1"),
|
||||
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6(f"{sum(s.get('total_reward', 0) for s in model_stats):.1f}", className="mb-0"),
|
||||
html.Small("Total Rewards", className="text-muted")
|
||||
], className="card-body text-center"),
|
||||
], className="card col-md-3 mx-1")
|
||||
|
||||
], className="row mb-4"),
|
||||
|
||||
# Charts
|
||||
html.Div([
|
||||
html.Div([
|
||||
self.create_prediction_timeline_chart(predictions_data)
|
||||
], className="col-md-6"),
|
||||
|
||||
html.Div([
|
||||
self.create_model_performance_chart(model_stats)
|
||||
], className="col-md-6")
|
||||
], className="row mb-4"),
|
||||
|
||||
# Recent predictions table
|
||||
html.Div([
|
||||
html.H5("Recent Predictions", className="mb-2"),
|
||||
self.create_prediction_table(predictions_data)
|
||||
], className="mb-3")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating prediction panel: {e}")
|
||||
return html.Div([
|
||||
html.H4("📊 Prediction Tracking & Performance"),
|
||||
html.P(f"Error loading prediction data: {str(e)}", className="text-danger")
|
||||
])
|
||||
|
||||
# Global instance
|
||||
_prediction_chart = None
|
||||
|
||||
def get_prediction_chart() -> PredictionChartComponent:
|
||||
"""Get global prediction chart component"""
|
||||
global _prediction_chart
|
||||
if _prediction_chart is None:
|
||||
_prediction_chart = PredictionChartComponent()
|
||||
return _prediction_chart
|
Reference in New Issue
Block a user