# Source: gogo2/web/prediction_chart.py
# Commit 4fe952dbee ("wip") by Dobromir Popov, 2025-09-08 11:44:15 +03:00
# 353 lines, 15 KiB, Python
#!/usr/bin/env python3
"""
Prediction Chart Component - Visualizes model predictions and their outcomes
"""
import dash
from dash import dcc, html, dash_table
import plotly.graph_objs as go
import plotly.express as px
import pandas as pd
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
import logging
logger = logging.getLogger(__name__)
class PredictionChartComponent:
    """Component for visualizing prediction tracking and outcomes.

    Builds Dash/Plotly views from prediction-tracker records: a timeline
    scatter of individual predictions, a per-model performance comparison,
    a table of recent predictions, and a combined panel with all three.
    """

    def __init__(self):
        # Bootstrap-flavored palette. Action keys color pending markers and
        # resolved-marker outlines; 'reward'/'penalty' fill resolved markers
        # by outcome sign.
        self.colors = {
            'BUY': '#28a745',     # Green
            'SELL': '#dc3545',    # Red
            'HOLD': '#6c757d',    # Gray
            'reward': '#28a745',  # Green for positive rewards
            'penalty': '#dc3545'  # Red for negative rewards
        }

    def _placeholder_graph(self, graph_id: str, title: str, message: str,
                           xaxis_title: Optional[str] = None,
                           yaxis_title: Optional[str] = None) -> dcc.Graph:
        """Return an empty annotated graph shown when no data is available."""
        fig = go.Figure()
        fig.add_annotation(
            text=message,
            xref="paper", yref="paper",
            x=0.5, y=0.5, xanchor='center', yanchor='middle',
            showarrow=False, font=dict(size=16, color="gray")
        )
        layout: Dict[str, Any] = dict(title=title, height=300)
        if xaxis_title is not None:
            layout['xaxis_title'] = xaxis_title
        if yaxis_title is not None:
            layout['yaxis_title'] = yaxis_title
        fig.update_layout(**layout)
        return dcc.Graph(figure=fig, id=graph_id)

    def create_prediction_timeline_chart(self, predictions_data: List[Dict[str, Any]]) -> dcc.Graph:
        """Create a timeline chart showing predictions and their outcomes.

        Each prediction is plotted at (timestamp, confidence). Resolved
        predictions are filled circles colored by reward sign (green > 0,
        red otherwise) with an action-colored outline; pending predictions
        are open circles in the action color.

        Args:
            predictions_data: Records with 'timestamp', 'prediction_type'
                ('BUY'/'SELL'/'HOLD'), 'confidence', 'model_name',
                'is_resolved' and, for resolved rows, 'reward'.

        Returns:
            dcc.Graph with id "prediction-timeline"; an annotated empty
            figure when there is no data or on error.
        """
        try:
            if not predictions_data:
                return self._placeholder_graph(
                    "prediction-timeline", "Model Predictions Timeline",
                    "No prediction data available",
                    xaxis_title="Time", yaxis_title="Confidence"
                )

            df = pd.DataFrame(predictions_data)
            df['timestamp'] = pd.to_datetime(df['timestamp'])

            fig = go.Figure()
            for prediction_type in ('BUY', 'SELL', 'HOLD'):
                type_data = df[df['prediction_type'] == prediction_type]
                if type_data.empty:
                    continue
                # Resolved and pending predictions get distinct marker styles.
                resolved_data = type_data[type_data['is_resolved'].eq(True)]
                pending_data = type_data[type_data['is_resolved'].eq(False)]

                if not resolved_data.empty:
                    # Fill color encodes the outcome: green for positive reward.
                    fill_colors = [
                        self.colors['reward'] if reward > 0 else self.colors['penalty']
                        for reward in resolved_data['reward']
                    ]
                    fig.add_trace(go.Scatter(
                        x=resolved_data['timestamp'],
                        y=resolved_data['confidence'],
                        mode='markers',
                        marker=dict(
                            size=10,
                            color=fill_colors,
                            symbol='circle',
                            line=dict(width=2, color=self.colors[prediction_type])
                        ),
                        name=f'{prediction_type} (Resolved)',
                        text=[f"Model: {m}<br>Confidence: {c:.3f}<br>Reward: {r:.2f}"
                              for m, c, r in zip(resolved_data['model_name'],
                                                 resolved_data['confidence'],
                                                 resolved_data['reward'])],
                        hovertemplate='%{text}<extra></extra>'
                    ))

                if not pending_data.empty:
                    fig.add_trace(go.Scatter(
                        x=pending_data['timestamp'],
                        y=pending_data['confidence'],
                        mode='markers',
                        marker=dict(
                            size=8,
                            color=self.colors[prediction_type],
                            symbol='circle-open',
                            line=dict(width=2)
                        ),
                        name=f'{prediction_type} (Pending)',
                        text=[f"Model: {m}<br>Confidence: {c:.3f}<br>Status: Pending"
                              for m, c in zip(pending_data['model_name'],
                                              pending_data['confidence'])],
                        hovertemplate='%{text}<extra></extra>'
                    ))

            fig.update_layout(
                title="Model Predictions Timeline",
                xaxis_title="Time",
                yaxis_title="Confidence",
                yaxis=dict(range=[0, 1]),  # confidence is a 0..1 score
                height=400,
                showlegend=True,
                legend=dict(x=0.02, y=0.98),
                hovermode='closest'
            )
            return dcc.Graph(figure=fig, id="prediction-timeline")
        except Exception as e:
            logger.error(f"Error creating prediction timeline chart: {e}")
            # Return an annotated empty chart rather than crashing the layout.
            fig = go.Figure()
            fig.add_annotation(text=f"Error: {str(e)}", x=0.5, y=0.5)
            return dcc.Graph(figure=fig, id="prediction-timeline")

    def create_model_performance_chart(self, model_stats: List[Dict[str, Any]]) -> dcc.Graph:
        """Create a bar chart showing model performance metrics.

        Accuracy (%) is drawn as bars on the left axis; total reward as
        diamond markers on an overlaid right axis.

        Args:
            model_stats: Records with 'model_name', 'accuracy' (0..1 fraction),
                'total_reward' and 'total_predictions'.

        Returns:
            dcc.Graph with id "model-performance"; an annotated empty figure
            when there is no data or on error.
        """
        try:
            if not model_stats:
                return self._placeholder_graph(
                    "model-performance", "Model Performance Comparison",
                    "No model performance data available"
                )

            # Extract per-model series.
            model_names = [stats['model_name'] for stats in model_stats]
            accuracies = [stats['accuracy'] * 100 for stats in model_stats]  # Convert to percentage
            total_rewards = [stats['total_reward'] for stats in model_stats]

            fig = go.Figure()

            # Accuracy bars on the primary (left) axis.
            fig.add_trace(go.Bar(
                x=model_names,
                y=accuracies,
                name='Accuracy (%)',
                marker_color='lightblue',
                yaxis='y',
                text=[f"{a:.1f}%" for a in accuracies],
                textposition='auto'
            ))

            # Total reward markers on the secondary (right) axis.
            fig.add_trace(go.Scatter(
                x=model_names,
                y=total_rewards,
                mode='markers+text',
                name='Total Reward',
                marker=dict(
                    size=12,
                    color='orange',
                    symbol='diamond'
                ),
                yaxis='y2',
                text=[f"{r:.1f}" for r in total_rewards],
                textposition='top center'
            ))

            fig.update_layout(
                title="Model Performance Comparison",
                xaxis_title="Model",
                yaxis=dict(
                    title="Accuracy (%)",
                    side="left",
                    range=[0, 100]
                ),
                yaxis2=dict(
                    title="Total Reward",
                    side="right",
                    overlaying="y"
                ),
                height=400,
                showlegend=True,
                legend=dict(x=0.02, y=0.98)
            )
            return dcc.Graph(figure=fig, id="model-performance")
        except Exception as e:
            logger.error(f"Error creating model performance chart: {e}")
            fig = go.Figure()
            fig.add_annotation(text=f"Error: {str(e)}", x=0.5, y=0.5)
            return dcc.Graph(figure=fig, id="model-performance")

    def create_prediction_table(self, recent_predictions: List[Dict[str, Any]]) -> dash_table.DataTable:
        """Create a table showing recent predictions.

        Shows at most the last 20 entries, with row shading by status:
        green for resolved with positive reward, red for resolved with
        negative reward, gray for pending.
        """
        # Single column definition shared by the empty and populated tables.
        columns = [
            {"name": "Model", "id": "model_name"},
            {"name": "Symbol", "id": "symbol"},
            {"name": "Prediction", "id": "prediction_type"},
            {"name": "Confidence", "id": "confidence"},
            {"name": "Status", "id": "status"},
            {"name": "Reward", "id": "reward"}
        ]
        try:
            if not recent_predictions:
                return dash_table.DataTable(
                    id="prediction-table",
                    columns=columns,
                    data=[],
                    style_cell={'textAlign': 'center'},
                    style_header={'backgroundColor': 'rgb(230, 230, 230)', 'fontWeight': 'bold'},
                    page_size=10
                )

            # Format the last 20 predictions for display; numeric fields are
            # rendered as fixed-precision strings.
            table_data = []
            for pred in recent_predictions[-20:]:
                resolved = pred.get('is_resolved', False)
                table_data.append({
                    'model_name': pred.get('model_name', 'Unknown'),
                    'symbol': pred.get('symbol', 'N/A'),
                    'prediction_type': pred.get('prediction_type', 'N/A'),
                    'confidence': f"{pred.get('confidence', 0):.3f}",
                    'status': 'Resolved' if resolved else 'Pending',
                    'reward': f"{pred.get('reward', 0):.2f}" if resolved else 'Pending'
                })

            return dash_table.DataTable(
                id="prediction-table",
                columns=columns,
                data=table_data,
                style_cell={'textAlign': 'center', 'fontSize': '12px'},
                style_header={'backgroundColor': 'rgb(230, 230, 230)', 'fontWeight': 'bold'},
                # NOTE(review): 'reward' holds formatted strings, so the
                # numeric filter_query comparisons below may not match on all
                # Dash versions — verify the row shading renders as intended.
                style_data_conditional=[
                    {
                        'if': {'filter_query': '{status} = Resolved and {reward} > 0'},
                        'backgroundColor': 'rgba(40, 167, 69, 0.1)',
                        'color': 'black',
                    },
                    {
                        'if': {'filter_query': '{status} = Resolved and {reward} < 0'},
                        'backgroundColor': 'rgba(220, 53, 69, 0.1)',
                        'color': 'black',
                    },
                    {
                        'if': {'filter_query': '{status} = Pending'},
                        'backgroundColor': 'rgba(108, 117, 125, 0.1)',
                        'color': 'black',
                    }
                ],
                page_size=10,
                sort_action="native"
            )
        except Exception as e:
            logger.error(f"Error creating prediction table: {e}")
            return dash_table.DataTable(
                id="prediction-table",
                columns=[{"name": "Error", "id": "error"}],
                data=[{"error": str(e)}]
            )

    def create_prediction_panel(self, prediction_stats: Dict[str, Any]) -> html.Div:
        """Create a complete prediction tracking panel.

        Assembles a row of four summary cards, the timeline and performance
        charts side by side, and the recent-predictions table.

        Args:
            prediction_stats: Dict with 'predictions' (list of prediction
                records), 'models' (list of per-model stats), and optional
                'total_predictions' / 'active_predictions' counters.
        """
        try:
            predictions_data = prediction_stats.get('predictions', [])
            model_stats = prediction_stats.get('models', [])

            def summary_card(value: str, label: str) -> html.Div:
                # One Bootstrap card: big value on top, muted label below.
                return html.Div([
                    html.Div([
                        html.H6(value, className="mb-0"),
                        html.Small(label, className="text-muted")
                    ], className="card-body text-center"),
                ], className="card col-md-3 mx-1")

            # BUGFIX: the original was missing the row wrapper around the
            # four summary cards and the closing brackets of the return,
            # leaving the layout unbalanced.
            return html.Div([
                html.H4("📊 Prediction Tracking & Performance", className="mb-3"),
                # Summary cards
                html.Div([
                    summary_card(f"{prediction_stats.get('total_predictions', 0)}", "Total Predictions"),
                    summary_card(f"{prediction_stats.get('active_predictions', 0)}", "Pending Resolution"),
                    summary_card(f"{len(model_stats)}", "Active Models"),
                    summary_card(f"{sum(s.get('total_reward', 0) for s in model_stats):.1f}", "Total Rewards")
                ], className="row mb-4"),
                # Charts
                html.Div([
                    html.Div([
                        self.create_prediction_timeline_chart(predictions_data)
                    ], className="col-md-6"),
                    html.Div([
                        self.create_model_performance_chart(model_stats)
                    ], className="col-md-6")
                ], className="row mb-4"),
                # Recent predictions table
                html.Div([
                    html.H5("Recent Predictions", className="mb-2"),
                    self.create_prediction_table(predictions_data)
                ], className="mb-3")
            ])
        except Exception as e:
            logger.error(f"Error creating prediction panel: {e}")
            return html.Div([
                html.H4("📊 Prediction Tracking & Performance"),
                html.P(f"Error loading prediction data: {str(e)}", className="text-danger")
            ])
# Lazily-created module-wide singleton.
_prediction_chart = None


def get_prediction_chart() -> PredictionChartComponent:
    """Return the shared PredictionChartComponent, creating it on first use."""
    global _prediction_chart
    # None is falsy, so this builds the component exactly once.
    _prediction_chart = _prediction_chart or PredictionChartComponent()
    return _prediction_chart