353 lines
15 KiB
Python
353 lines
15 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Prediction Chart Component - Visualizes model predictions and their outcomes
|
|
"""
|
|
|
|
import dash
|
|
from dash import dcc, html, dash_table
|
|
import plotly.graph_objs as go
|
|
import plotly.express as px
|
|
import pandas as pd
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Any, Optional
|
|
import logging
|
|
|
|
# Module-level logger; its name follows this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
class PredictionChartComponent:
    """Component for visualizing prediction tracking and outcomes.

    Builds Dash/Plotly elements (a prediction timeline, a model performance
    chart, a recent-predictions table, and a panel combining them) from plain
    dictionaries. Each builder catches its own exceptions, logs them with a
    traceback, and returns a placeholder element instead of raising.
    """

    # Fallbacks for optional prediction fields in the timeline. They mirror the
    # defaults used by the prediction table so both views accept the same records.
    _PREDICTION_DEFAULTS = (
        ('model_name', 'Unknown'),
        ('prediction_type', 'N/A'),
        ('confidence', 0.0),
        ('is_resolved', False),
        ('reward', 0.0),
    )

    def __init__(self):
        # Marker colours keyed by prediction type, plus outcome colours used to
        # fill resolved markers according to the sign of their reward.
        self.colors = {
            'BUY': '#28a745',  # Green
            'SELL': '#dc3545',  # Red
            'HOLD': '#6c757d',  # Gray
            'reward': '#28a745',  # Green for positive rewards
            'penalty': '#dc3545'  # Red for zero or negative rewards
        }

    @staticmethod
    def _message_figure(message: str, **layout: Any) -> go.Figure:
        """Return an empty figure with ``message`` centred on the plot area.

        Any ``layout`` keyword arguments are forwarded to ``fig.update_layout``.
        """
        fig = go.Figure()
        fig.add_annotation(
            text=message,
            # Paper coordinates centre the text independently of axis ranges;
            # showarrow=False suppresses Plotly's default annotation arrow.
            xref="paper", yref="paper",
            x=0.5, y=0.5, xanchor='center', yanchor='middle',
            showarrow=False, font=dict(size=16, color="gray"),
        )
        if layout:
            fig.update_layout(**layout)
        return fig

    def _resolved_trace(self, prediction_type: str, data: pd.DataFrame) -> go.Scatter:
        """Build the scatter trace for resolved predictions of one type."""
        # Fill encodes the outcome (reward > 0 is green, otherwise red); the
        # outline keeps the prediction-type colour.
        fill_colors = [self.colors['reward'] if r > 0 else self.colors['penalty']
                       for r in data['reward']]
        hover = [f"Model: {m}<br>Confidence: {c:.3f}<br>Reward: {r:.2f}"
                 for m, c, r in zip(data['model_name'], data['confidence'], data['reward'])]
        return go.Scatter(
            x=data['timestamp'],
            y=data['confidence'],
            mode='markers',
            marker=dict(
                size=10,
                color=fill_colors,
                symbol='circle',
                line=dict(width=2, color=self.colors[prediction_type])
            ),
            name=f'{prediction_type} (Resolved)',
            text=hover,
            hovertemplate='%{text}<extra></extra>'
        )

    def _pending_trace(self, prediction_type: str, data: pd.DataFrame) -> go.Scatter:
        """Build the open-circle scatter trace for pending predictions of one type."""
        hover = [f"Model: {m}<br>Confidence: {c:.3f}<br>Status: Pending"
                 for m, c in zip(data['model_name'], data['confidence'])]
        return go.Scatter(
            x=data['timestamp'],
            y=data['confidence'],
            mode='markers',
            marker=dict(
                size=8,
                color=self.colors[prediction_type],
                symbol='circle-open',
                line=dict(width=2)
            ),
            name=f'{prediction_type} (Pending)',
            text=hover,
            hovertemplate='%{text}<extra></extra>'
        )

    def create_prediction_timeline_chart(self, predictions_data: List[Dict[str, Any]]) -> dcc.Graph:
        """Create a timeline chart showing predictions and their outcomes.

        Args:
            predictions_data: Prediction records. Each needs a ``timestamp``
                parseable by ``pd.to_datetime``. The fields ``model_name``,
                ``prediction_type``, ``confidence``, ``is_resolved`` and
                ``reward`` are optional; missing or null values use the same
                defaults as the prediction table.

        Returns:
            A ``dcc.Graph`` with id ``"prediction-timeline"``. For each of
            BUY/SELL/HOLD it holds a resolved trace and a pending trace. An
            empty input, or any error, yields a placeholder message figure.
        """
        title = "Model Predictions Timeline"
        try:
            if not predictions_data:
                fig = self._message_figure(
                    "No prediction data available",
                    title=title,
                    xaxis_title="Time",
                    yaxis_title="Confidence",
                    height=300,
                )
                return dcc.Graph(figure=fig, id="prediction-timeline")

            df = pd.DataFrame(predictions_data)
            df['timestamp'] = pd.to_datetime(df['timestamp'])
            for column, default in self._PREDICTION_DEFAULTS:
                if column in df.columns:
                    df[column] = df[column].fillna(default)
                else:
                    df[column] = default

            # Truthiness matches how the prediction table derives its status.
            resolved_mask = df['is_resolved'].astype(bool)

            fig = go.Figure()
            for prediction_type in ('BUY', 'SELL', 'HOLD'):
                type_mask = df['prediction_type'] == prediction_type
                resolved_data = df[type_mask & resolved_mask]
                pending_data = df[type_mask & ~resolved_mask]
                if not resolved_data.empty:
                    fig.add_trace(self._resolved_trace(prediction_type, resolved_data))
                if not pending_data.empty:
                    fig.add_trace(self._pending_trace(prediction_type, pending_data))

            fig.update_layout(
                title=title,
                xaxis_title="Time",
                yaxis_title="Confidence",
                yaxis=dict(range=[0, 1]),
                height=400,
                showlegend=True,
                legend=dict(x=0.02, y=0.98),
                hovermode='closest'
            )
            return dcc.Graph(figure=fig, id="prediction-timeline")

        except Exception as e:
            logger.exception("Error creating prediction timeline chart: %s", e)
            fig = self._message_figure(f"Error: {e}", title=title, height=300)
            return dcc.Graph(figure=fig, id="prediction-timeline")

    def create_model_performance_chart(self, model_stats: List[Dict[str, Any]]) -> dcc.Graph:
        """Create a chart comparing accuracy and total reward per model.

        Args:
            model_stats: One dict per model with optional ``model_name``
                (default ``'Unknown'``), ``accuracy`` as a fraction (scaled to
                percent here) and ``total_reward``. Missing or ``None``
                numbers count as 0, as in the panel's reward total.

        Returns:
            A ``dcc.Graph`` with id ``"model-performance"``. Accuracy bars use
            the left axis and total-reward markers the overlaid right axis. An
            empty input, or any error, yields a placeholder message figure.
        """
        title = "Model Performance Comparison"
        try:
            if not model_stats:
                fig = self._message_figure(
                    "No model performance data available",
                    title=title,
                    height=300,
                )
                return dcc.Graph(figure=fig, id="model-performance")

            model_names = [stats.get('model_name', 'Unknown') for stats in model_stats]
            # ``or 0`` also covers keys that are present but set to None.
            accuracies = [(stats.get('accuracy') or 0) * 100 for stats in model_stats]
            total_rewards = [stats.get('total_reward') or 0 for stats in model_stats]

            fig = go.Figure()
            fig.add_trace(go.Bar(
                x=model_names,
                y=accuracies,
                name='Accuracy (%)',
                marker_color='lightblue',
                yaxis='y',
                text=[f"{a:.1f}%" for a in accuracies],
                textposition='auto'
            ))
            fig.add_trace(go.Scatter(
                x=model_names,
                y=total_rewards,
                mode='markers+text',
                name='Total Reward',
                marker=dict(size=12, color='orange', symbol='diamond'),
                yaxis='y2',
                text=[f"{r:.1f}" for r in total_rewards],
                textposition='top center'
            ))

            fig.update_layout(
                title=title,
                xaxis_title="Model",
                yaxis=dict(title="Accuracy (%)", side="left", range=[0, 100]),
                # Reward values have an unrelated scale, so they get their own
                # right-hand axis drawn over the accuracy axis.
                yaxis2=dict(title="Total Reward", side="right", overlaying="y"),
                height=400,
                showlegend=True,
                legend=dict(x=0.02, y=0.98)
            )
            return dcc.Graph(figure=fig, id="model-performance")

        except Exception as e:
            logger.exception("Error creating model performance chart: %s", e)
            fig = self._message_figure(f"Error: {e}", title=title, height=300)
            return dcc.Graph(figure=fig, id="model-performance")

    @staticmethod
    def _table_columns() -> List[Dict[str, str]]:
        """Return fresh column definitions for the prediction table."""
        return [
            {"name": "Model", "id": "model_name"},
            {"name": "Symbol", "id": "symbol"},
            {"name": "Prediction", "id": "prediction_type"},
            {"name": "Confidence", "id": "confidence"},
            {"name": "Status", "id": "status"},
            {"name": "Reward", "id": "reward"}
        ]

    @staticmethod
    def _table_row(pred: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one prediction record into the table's display values."""
        def fmt(value: Any, spec: str) -> str:
            # Missing and explicit None values both render as zero.
            return format(0 if value is None else value, spec)

        is_resolved = bool(pred.get('is_resolved', False))
        return {
            'model_name': pred.get('model_name', 'Unknown'),
            'symbol': pred.get('symbol', 'N/A'),
            'prediction_type': pred.get('prediction_type', 'N/A'),
            'confidence': fmt(pred.get('confidence'), '.3f'),
            'status': 'Resolved' if is_resolved else 'Pending',
            'reward': fmt(pred.get('reward'), '.2f') if is_resolved else 'Pending'
        }

    def create_prediction_table(self, recent_predictions: List[Dict[str, Any]]) -> dash_table.DataTable:
        """Create a table of the most recent predictions.

        Args:
            recent_predictions: Prediction records. Only the last 20 entries of
                the list are shown, 10 per page.

        Returns:
            A ``dash_table.DataTable`` with id ``"prediction-table"``. Rows are
            tinted by outcome. On error, a one-column table holds the message.
        """
        try:
            if not recent_predictions:
                return dash_table.DataTable(
                    id="prediction-table",
                    columns=self._table_columns(),
                    data=[],
                    style_cell={'textAlign': 'center'},
                    style_header={'backgroundColor': 'rgb(230, 230, 230)', 'fontWeight': 'bold'},
                    page_size=10
                )

            # NOTE(review): confidence and reward are pre-formatted strings, so
            # native sorting and the numeric filter_query comparisons below
            # may treat them as text. Consider numeric columns with a format
            # specifier; confirm against the Dash DataTable behaviour.
            table_data = [self._table_row(pred) for pred in recent_predictions[-20:]]

            return dash_table.DataTable(
                id="prediction-table",
                columns=self._table_columns(),
                data=table_data,
                style_cell={'textAlign': 'center', 'fontSize': '12px'},
                style_header={'backgroundColor': 'rgb(230, 230, 230)', 'fontWeight': 'bold'},
                style_data_conditional=[
                    {
                        'if': {'filter_query': '{status} = Resolved and {reward} > 0'},
                        'backgroundColor': 'rgba(40, 167, 69, 0.1)',
                        'color': 'black',
                    },
                    {
                        'if': {'filter_query': '{status} = Resolved and {reward} < 0'},
                        'backgroundColor': 'rgba(220, 53, 69, 0.1)',
                        'color': 'black',
                    },
                    {
                        'if': {'filter_query': '{status} = Pending'},
                        'backgroundColor': 'rgba(108, 117, 125, 0.1)',
                        'color': 'black',
                    }
                ],
                page_size=10,
                sort_action="native"
            )

        except Exception as e:
            logger.exception("Error creating prediction table: %s", e)
            return dash_table.DataTable(
                id="prediction-table",
                columns=[{"name": "Error", "id": "error"}],
                data=[{"error": str(e)}]
            )

    @staticmethod
    def _summary_card(value: Any, label: str) -> html.Div:
        """Build one summary card: ``value`` as a heading above a muted ``label``."""
        return html.Div([
            html.Div([
                html.H6(f"{value}", className="mb-0"),
                html.Small(label, className="text-muted")
            ], className="card-body text-center"),
        ], className="card col-md-3 mx-1")

    def create_prediction_panel(self, prediction_stats: Dict[str, Any]) -> html.Div:
        """Create the complete prediction tracking panel.

        Args:
            prediction_stats: Dict with optional keys ``predictions`` (list of
                prediction records), ``models`` (list of per-model stats),
                ``total_predictions`` and ``active_predictions``. ``None``
                is treated like an empty dict.

        Returns:
            An ``html.Div`` with summary cards, the timeline and performance
            charts, and the recent-predictions table. On error, it holds a
            heading and the error message instead.
        """
        try:
            prediction_stats = prediction_stats or {}
            # ``or []`` also covers keys that are present but set to None.
            predictions_data = prediction_stats.get('predictions') or []
            model_stats = prediction_stats.get('models') or []
            total_reward = sum((s.get('total_reward') or 0) for s in model_stats)

            return html.Div([
                html.H4("📊 Prediction Tracking & Performance", className="mb-3"),

                # Summary cards
                html.Div([
                    self._summary_card(prediction_stats.get('total_predictions', 0), "Total Predictions"),
                    self._summary_card(prediction_stats.get('active_predictions', 0), "Pending Resolution"),
                    self._summary_card(len(model_stats), "Active Models"),
                    self._summary_card(f"{total_reward:.1f}", "Total Rewards"),
                ], className="row mb-4"),

                # Charts
                html.Div([
                    html.Div([
                        self.create_prediction_timeline_chart(predictions_data)
                    ], className="col-md-6"),
                    html.Div([
                        self.create_model_performance_chart(model_stats)
                    ], className="col-md-6")
                ], className="row mb-4"),

                # Recent predictions table
                html.Div([
                    html.H5("Recent Predictions", className="mb-2"),
                    self.create_prediction_table(predictions_data)
                ], className="mb-3")
            ])

        except Exception as e:
            logger.exception("Error creating prediction panel: %s", e)
            return html.Div([
                html.H4("📊 Prediction Tracking & Performance"),
                html.P(f"Error loading prediction data: {e}", className="text-danger")
            ])
|
|
|
|
# Shared component instance: None until get_prediction_chart() first runs,
# then reused by every later call.
_prediction_chart = None


def get_prediction_chart() -> PredictionChartComponent:
    """Return the module-wide PredictionChartComponent, building it on first use.

    The instance is cached in the module-level ``_prediction_chart`` variable,
    so repeated calls hand back the same object.
    """
    global _prediction_chart
    if _prediction_chart is not None:
        return _prediction_chart
    _prediction_chart = PredictionChartComponent()
    return _prediction_chart
|