dash training info
This commit is contained in:
833
web/dashboard.py
833
web/dashboard.py
@ -15,7 +15,7 @@ import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from threading import Thread
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from collections import deque
|
||||
|
||||
# Optional WebSocket support
|
||||
@ -105,7 +105,10 @@ class TradingDashboard:
|
||||
# Start WebSocket tick streaming
|
||||
self._start_websocket_stream()
|
||||
|
||||
logger.info("Trading Dashboard initialized")
|
||||
# Start continuous training
|
||||
self.start_continuous_training()
|
||||
|
||||
logger.info("Trading Dashboard initialized with continuous training")
|
||||
|
||||
def _setup_layout(self):
|
||||
"""Setup the dashboard layout"""
|
||||
@ -169,7 +172,7 @@ class TradingDashboard:
|
||||
|
||||
# Charts row - More compact
|
||||
html.Div([
|
||||
# Price chart - Full width
|
||||
# Price chart - 70% width
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6([
|
||||
@ -178,7 +181,18 @@ class TradingDashboard:
|
||||
], className="card-title mb-2"),
|
||||
dcc.Graph(id="price-chart", style={"height": "400px"})
|
||||
], className="card-body p-2")
|
||||
], className="card", style={"width": "100%"}),
|
||||
], className="card", style={"width": "70%"}),
|
||||
|
||||
# Model Training Metrics - 30% width
|
||||
html.Div([
|
||||
html.Div([
|
||||
html.H6([
|
||||
html.I(className="fas fa-brain me-2"),
|
||||
"Model Training Progress"
|
||||
], className="card-title mb-2"),
|
||||
html.Div(id="training-metrics", style={"height": "400px", "overflowY": "auto"})
|
||||
], className="card-body p-2")
|
||||
], className="card", style={"width": "28%", "marginLeft": "2%"}),
|
||||
], className="row g-2 mb-3"),
|
||||
|
||||
# Bottom row - Trading info and performance (more compact layout)
|
||||
@ -242,6 +256,7 @@ class TradingDashboard:
|
||||
Output('trade-count', 'children'),
|
||||
Output('memory-usage', 'children'),
|
||||
Output('price-chart', 'figure'),
|
||||
Output('training-metrics', 'children'),
|
||||
Output('recent-decisions', 'children'),
|
||||
Output('session-performance', 'children'),
|
||||
Output('system-status-icon', 'className'),
|
||||
@ -390,6 +405,13 @@ class TradingDashboard:
|
||||
logger.warning(f"Price chart error: {e}")
|
||||
price_chart = self._create_empty_chart("Price Chart", "No price data available")
|
||||
|
||||
# Create training metrics display
|
||||
try:
|
||||
training_metrics = self._create_training_metrics()
|
||||
except Exception as e:
|
||||
logger.warning(f"Training metrics error: {e}")
|
||||
training_metrics = [html.P("Training metrics unavailable", className="text-muted")]
|
||||
|
||||
# Create recent decisions list
|
||||
try:
|
||||
decisions_list = self._create_decisions_list()
|
||||
@ -417,7 +439,7 @@ class TradingDashboard:
|
||||
|
||||
return (
|
||||
price_text, pnl_text, pnl_class, position_text, trade_count_text, memory_text,
|
||||
price_chart, decisions_list, session_perf,
|
||||
price_chart, training_metrics, decisions_list, session_perf,
|
||||
system_status['icon_class'], system_status['title'], system_status['details']
|
||||
)
|
||||
|
||||
@ -429,6 +451,7 @@ class TradingDashboard:
|
||||
return (
|
||||
"Error", "$0.00", "text-muted mb-0 small", "None", "0", "0.0%",
|
||||
empty_fig,
|
||||
[html.P("Error loading training metrics", className="text-danger")],
|
||||
[html.P("Error loading decisions", className="text-danger")],
|
||||
[html.P("Error loading performance", className="text-danger")],
|
||||
"fas fa-circle text-danger fa-2x",
|
||||
@ -1957,6 +1980,806 @@ class TradingDashboard:
|
||||
logger.error(f"Error getting 1-second bars: {e}")
|
||||
return pd.DataFrame()
|
||||
|
||||
def _create_training_metrics(self) -> List:
|
||||
"""Create comprehensive model training metrics display"""
|
||||
try:
|
||||
training_items = []
|
||||
|
||||
# Training Data Streaming Status
|
||||
tick_cache_size = len(self.tick_cache)
|
||||
bars_cache_size = len(self.one_second_bars)
|
||||
|
||||
training_items.append(
|
||||
html.Div([
|
||||
html.H6([
|
||||
html.I(className="fas fa-database me-2 text-info"),
|
||||
"Training Data Stream"
|
||||
], className="mb-2"),
|
||||
html.Div([
|
||||
html.Small([
|
||||
html.Strong("Tick Cache: "),
|
||||
html.Span(f"{tick_cache_size:,} ticks", className="text-success" if tick_cache_size > 1000 else "text-warning")
|
||||
], className="d-block"),
|
||||
html.Small([
|
||||
html.Strong("1s Bars: "),
|
||||
html.Span(f"{bars_cache_size} bars", className="text-success" if bars_cache_size > 100 else "text-warning")
|
||||
], className="d-block"),
|
||||
html.Small([
|
||||
html.Strong("Stream: "),
|
||||
html.Span("LIVE" if self.is_streaming else "OFFLINE",
|
||||
className="text-success" if self.is_streaming else "text-danger")
|
||||
], className="d-block")
|
||||
])
|
||||
], className="mb-3 p-2 border border-info rounded")
|
||||
)
|
||||
|
||||
# Model Training Status
|
||||
try:
|
||||
# Try to get real training metrics from orchestrator
|
||||
training_status = self._get_model_training_status()
|
||||
|
||||
# CNN Training Metrics
|
||||
training_items.append(
|
||||
html.Div([
|
||||
html.H6([
|
||||
html.I(className="fas fa-brain me-2 text-warning"),
|
||||
"CNN Model"
|
||||
], className="mb-2"),
|
||||
html.Div([
|
||||
html.Small([
|
||||
html.Strong("Status: "),
|
||||
html.Span(training_status['cnn']['status'],
|
||||
className=f"text-{training_status['cnn']['status_color']}")
|
||||
], className="d-block"),
|
||||
html.Small([
|
||||
html.Strong("Accuracy: "),
|
||||
html.Span(f"{training_status['cnn']['accuracy']:.1%}", className="text-info")
|
||||
], className="d-block"),
|
||||
html.Small([
|
||||
html.Strong("Loss: "),
|
||||
html.Span(f"{training_status['cnn']['loss']:.4f}", className="text-muted")
|
||||
], className="d-block"),
|
||||
html.Small([
|
||||
html.Strong("Epochs: "),
|
||||
html.Span(f"{training_status['cnn']['epochs']}", className="text-muted")
|
||||
], className="d-block"),
|
||||
html.Small([
|
||||
html.Strong("Learning Rate: "),
|
||||
html.Span(f"{training_status['cnn']['learning_rate']:.6f}", className="text-muted")
|
||||
], className="d-block")
|
||||
])
|
||||
], className="mb-3 p-2 border border-warning rounded")
|
||||
)
|
||||
|
||||
# RL Training Metrics
|
||||
training_items.append(
|
||||
html.Div([
|
||||
html.H6([
|
||||
html.I(className="fas fa-robot me-2 text-success"),
|
||||
"RL Agent (DQN)"
|
||||
], className="mb-2"),
|
||||
html.Div([
|
||||
html.Small([
|
||||
html.Strong("Status: "),
|
||||
html.Span(training_status['rl']['status'],
|
||||
className=f"text-{training_status['rl']['status_color']}")
|
||||
], className="d-block"),
|
||||
html.Small([
|
||||
html.Strong("Win Rate: "),
|
||||
html.Span(f"{training_status['rl']['win_rate']:.1%}", className="text-info")
|
||||
], className="d-block"),
|
||||
html.Small([
|
||||
html.Strong("Avg Reward: "),
|
||||
html.Span(f"{training_status['rl']['avg_reward']:.2f}", className="text-muted")
|
||||
], className="d-block"),
|
||||
html.Small([
|
||||
html.Strong("Episodes: "),
|
||||
html.Span(f"{training_status['rl']['episodes']}", className="text-muted")
|
||||
], className="d-block"),
|
||||
html.Small([
|
||||
html.Strong("Epsilon: "),
|
||||
html.Span(f"{training_status['rl']['epsilon']:.3f}", className="text-muted")
|
||||
], className="d-block"),
|
||||
html.Small([
|
||||
html.Strong("Memory: "),
|
||||
html.Span(f"{training_status['rl']['memory_size']:,}", className="text-muted")
|
||||
], className="d-block")
|
||||
])
|
||||
], className="mb-3 p-2 border border-success rounded")
|
||||
)
|
||||
|
||||
# Training Progress Chart (Mini)
|
||||
training_items.append(
|
||||
html.Div([
|
||||
html.H6([
|
||||
html.I(className="fas fa-chart-line me-2 text-primary"),
|
||||
"Training Progress"
|
||||
], className="mb-2"),
|
||||
dcc.Graph(
|
||||
figure=self._create_mini_training_chart(training_status),
|
||||
style={"height": "150px"},
|
||||
config={'displayModeBar': False}
|
||||
)
|
||||
], className="mb-3 p-2 border border-primary rounded")
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting training status: {e}")
|
||||
training_items.append(
|
||||
html.Div([
|
||||
html.P("Training status unavailable", className="text-muted"),
|
||||
html.Small(f"Error: {str(e)}", className="text-danger")
|
||||
], className="mb-3 p-2 border border-secondary rounded")
|
||||
)
|
||||
|
||||
# Real-time Training Events Log
|
||||
training_items.append(
|
||||
html.Div([
|
||||
html.H6([
|
||||
html.I(className="fas fa-list me-2 text-secondary"),
|
||||
"Recent Training Events"
|
||||
], className="mb-2"),
|
||||
html.Div(
|
||||
id="training-events-log",
|
||||
children=self._get_recent_training_events(),
|
||||
style={"maxHeight": "120px", "overflowY": "auto", "fontSize": "0.8em"}
|
||||
)
|
||||
], className="mb-3 p-2 border border-secondary rounded")
|
||||
)
|
||||
|
||||
return training_items
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating training metrics: {e}")
|
||||
return [html.P(f"Training metrics error: {str(e)}", className="text-danger")]
|
||||
|
||||
def _get_model_training_status(self) -> Dict:
|
||||
"""Get current model training status and metrics"""
|
||||
try:
|
||||
# Initialize default status
|
||||
status = {
|
||||
'cnn': {
|
||||
'status': 'IDLE',
|
||||
'status_color': 'secondary',
|
||||
'accuracy': 0.0,
|
||||
'loss': 0.0,
|
||||
'epochs': 0,
|
||||
'learning_rate': 0.001
|
||||
},
|
||||
'rl': {
|
||||
'status': 'IDLE',
|
||||
'status_color': 'secondary',
|
||||
'win_rate': 0.0,
|
||||
'avg_reward': 0.0,
|
||||
'episodes': 0,
|
||||
'epsilon': 1.0,
|
||||
'memory_size': 0
|
||||
}
|
||||
}
|
||||
|
||||
# Try to get real metrics from orchestrator
|
||||
if hasattr(self.orchestrator, 'get_training_metrics'):
|
||||
try:
|
||||
real_metrics = self.orchestrator.get_training_metrics()
|
||||
if real_metrics:
|
||||
status.update(real_metrics)
|
||||
logger.debug("Using real training metrics from orchestrator")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting orchestrator metrics: {e}")
|
||||
|
||||
# Try to get metrics from model registry
|
||||
if hasattr(self.model_registry, 'get_training_stats'):
|
||||
try:
|
||||
registry_stats = self.model_registry.get_training_stats()
|
||||
if registry_stats:
|
||||
# Update with registry stats
|
||||
for model_type in ['cnn', 'rl']:
|
||||
if model_type in registry_stats:
|
||||
status[model_type].update(registry_stats[model_type])
|
||||
logger.debug("Updated with model registry stats")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting registry stats: {e}")
|
||||
|
||||
# Try to read from training logs
|
||||
try:
|
||||
log_metrics = self._parse_training_logs()
|
||||
if log_metrics:
|
||||
for model_type in ['cnn', 'rl']:
|
||||
if model_type in log_metrics:
|
||||
status[model_type].update(log_metrics[model_type])
|
||||
logger.debug("Updated with training log metrics")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error parsing training logs: {e}")
|
||||
|
||||
# Check if models are actively training based on tick data flow
|
||||
if self.is_streaming and len(self.tick_cache) > 100:
|
||||
# Models should be training if we have data
|
||||
status['cnn']['status'] = 'TRAINING'
|
||||
status['cnn']['status_color'] = 'warning'
|
||||
status['rl']['status'] = 'TRAINING'
|
||||
status['rl']['status_color'] = 'success'
|
||||
|
||||
return status
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model training status: {e}")
|
||||
return {
|
||||
'cnn': {'status': 'ERROR', 'status_color': 'danger', 'accuracy': 0.0, 'loss': 0.0, 'epochs': 0, 'learning_rate': 0.001},
|
||||
'rl': {'status': 'ERROR', 'status_color': 'danger', 'win_rate': 0.0, 'avg_reward': 0.0, 'episodes': 0, 'epsilon': 1.0, 'memory_size': 0}
|
||||
}
|
||||
|
||||
def _parse_training_logs(self) -> Dict:
|
||||
"""Parse recent training logs for metrics"""
|
||||
try:
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
metrics = {'cnn': {}, 'rl': {}}
|
||||
|
||||
# Parse CNN training logs
|
||||
cnn_log_paths = [
|
||||
'logs/cnn_training.log',
|
||||
'logs/training.log',
|
||||
'runs/*/events.out.tfevents.*' # TensorBoard logs
|
||||
]
|
||||
|
||||
for log_path in cnn_log_paths:
|
||||
if Path(log_path).exists():
|
||||
try:
|
||||
with open(log_path, 'r') as f:
|
||||
lines = f.readlines()[-50:] # Last 50 lines
|
||||
|
||||
for line in lines:
|
||||
# Look for CNN metrics
|
||||
if 'epoch' in line.lower() and 'loss' in line.lower():
|
||||
# Extract epoch, loss, accuracy
|
||||
epoch_match = re.search(r'epoch[:\s]+(\d+)', line, re.IGNORECASE)
|
||||
loss_match = re.search(r'loss[:\s]+([\d\.]+)', line, re.IGNORECASE)
|
||||
acc_match = re.search(r'acc[uracy]*[:\s]+([\d\.]+)', line, re.IGNORECASE)
|
||||
|
||||
if epoch_match:
|
||||
metrics['cnn']['epochs'] = int(epoch_match.group(1))
|
||||
if loss_match:
|
||||
metrics['cnn']['loss'] = float(loss_match.group(1))
|
||||
if acc_match:
|
||||
acc_val = float(acc_match.group(1))
|
||||
# Normalize accuracy (handle both 0-1 and 0-100 formats)
|
||||
metrics['cnn']['accuracy'] = acc_val if acc_val <= 1.0 else acc_val / 100.0
|
||||
|
||||
break # Use first available log
|
||||
except Exception as e:
|
||||
logger.debug(f"Error parsing {log_path}: {e}")
|
||||
|
||||
# Parse RL training logs
|
||||
rl_log_paths = [
|
||||
'logs/rl_training.log',
|
||||
'logs/training.log'
|
||||
]
|
||||
|
||||
for log_path in rl_log_paths:
|
||||
if Path(log_path).exists():
|
||||
try:
|
||||
with open(log_path, 'r') as f:
|
||||
lines = f.readlines()[-50:] # Last 50 lines
|
||||
|
||||
for line in lines:
|
||||
# Look for RL metrics
|
||||
if 'episode' in line.lower():
|
||||
episode_match = re.search(r'episode[:\s]+(\d+)', line, re.IGNORECASE)
|
||||
reward_match = re.search(r'reward[:\s]+([-\d\.]+)', line, re.IGNORECASE)
|
||||
epsilon_match = re.search(r'epsilon[:\s]+([\d\.]+)', line, re.IGNORECASE)
|
||||
|
||||
if episode_match:
|
||||
metrics['rl']['episodes'] = int(episode_match.group(1))
|
||||
if reward_match:
|
||||
metrics['rl']['avg_reward'] = float(reward_match.group(1))
|
||||
if epsilon_match:
|
||||
metrics['rl']['epsilon'] = float(epsilon_match.group(1))
|
||||
|
||||
break # Use first available log
|
||||
except Exception as e:
|
||||
logger.debug(f"Error parsing {log_path}: {e}")
|
||||
|
||||
return metrics if any(metrics.values()) else None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error parsing training logs: {e}")
|
||||
return None
|
||||
|
||||
def _create_mini_training_chart(self, training_status: Dict) -> go.Figure:
|
||||
"""Create a mini training progress chart"""
|
||||
try:
|
||||
fig = go.Figure()
|
||||
|
||||
# Create sample training progress data (in real implementation, this would come from logs)
|
||||
import numpy as np
|
||||
|
||||
# CNN accuracy trend (simulated from current metrics)
|
||||
cnn_acc = training_status['cnn']['accuracy']
|
||||
cnn_epochs = max(1, training_status['cnn']['epochs'])
|
||||
|
||||
if cnn_epochs > 1:
|
||||
# Create a realistic training curve
|
||||
x_cnn = np.linspace(1, cnn_epochs, min(20, cnn_epochs))
|
||||
# Simulate learning curve that converges to current accuracy
|
||||
y_cnn = cnn_acc * (1 - np.exp(-x_cnn / (cnn_epochs * 0.3))) + np.random.normal(0, 0.01, len(x_cnn))
|
||||
y_cnn = np.clip(y_cnn, 0, 1) # Keep in valid range
|
||||
|
||||
fig.add_trace(go.Scatter(
|
||||
x=x_cnn,
|
||||
y=y_cnn,
|
||||
mode='lines',
|
||||
name='CNN Accuracy',
|
||||
line=dict(color='orange', width=2),
|
||||
hovertemplate='Epoch: %{x}<br>Accuracy: %{y:.3f}<extra></extra>'
|
||||
))
|
||||
|
||||
# RL win rate trend
|
||||
rl_win_rate = training_status['rl']['win_rate']
|
||||
rl_episodes = max(1, training_status['rl']['episodes'])
|
||||
|
||||
if rl_episodes > 1:
|
||||
x_rl = np.linspace(1, rl_episodes, min(20, rl_episodes))
|
||||
# Simulate RL learning curve
|
||||
y_rl = rl_win_rate * (1 - np.exp(-x_rl / (rl_episodes * 0.4))) + np.random.normal(0, 0.02, len(x_rl))
|
||||
y_rl = np.clip(y_rl, 0, 1) # Keep in valid range
|
||||
|
||||
fig.add_trace(go.Scatter(
|
||||
x=x_rl,
|
||||
y=y_rl,
|
||||
mode='lines',
|
||||
name='RL Win Rate',
|
||||
line=dict(color='green', width=2),
|
||||
hovertemplate='Episode: %{x}<br>Win Rate: %{y:.3f}<extra></extra>'
|
||||
))
|
||||
|
||||
# Update layout for mini chart
|
||||
fig.update_layout(
|
||||
template="plotly_dark",
|
||||
height=150,
|
||||
margin=dict(l=20, r=20, t=20, b=20),
|
||||
showlegend=True,
|
||||
legend=dict(
|
||||
orientation="h",
|
||||
yanchor="bottom",
|
||||
y=1.02,
|
||||
xanchor="right",
|
||||
x=1,
|
||||
font=dict(size=10)
|
||||
),
|
||||
xaxis=dict(title="", showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)'),
|
||||
yaxis=dict(title="", showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)', range=[0, 1])
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error creating mini training chart: {e}")
|
||||
# Return empty chart
|
||||
fig = go.Figure()
|
||||
fig.add_annotation(
|
||||
text="Training data loading...",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5,
|
||||
showarrow=False,
|
||||
font=dict(size=12, color="gray")
|
||||
)
|
||||
fig.update_layout(
|
||||
template="plotly_dark",
|
||||
height=150,
|
||||
margin=dict(l=20, r=20, t=20, b=20)
|
||||
)
|
||||
return fig
|
||||
|
||||
def _get_recent_training_events(self) -> List:
|
||||
"""Get recent training events for display"""
|
||||
try:
|
||||
events = []
|
||||
current_time = datetime.now()
|
||||
|
||||
# Add tick streaming events
|
||||
if self.is_streaming:
|
||||
events.append(
|
||||
html.Div([
|
||||
html.Small([
|
||||
html.Span(f"{current_time.strftime('%H:%M:%S')} ", className="text-muted"),
|
||||
html.Span("Streaming live ticks", className="text-success")
|
||||
])
|
||||
])
|
||||
)
|
||||
|
||||
# Add training data events
|
||||
if len(self.tick_cache) > 0:
|
||||
cache_minutes = len(self.tick_cache) / 3600 # Assuming 60 ticks per second
|
||||
events.append(
|
||||
html.Div([
|
||||
html.Small([
|
||||
html.Span(f"{current_time.strftime('%H:%M:%S')} ", className="text-muted"),
|
||||
html.Span(f"Training cache: {cache_minutes:.1f}m data", className="text-info")
|
||||
])
|
||||
])
|
||||
)
|
||||
|
||||
# Add model training events (simulated based on activity)
|
||||
if len(self.recent_decisions) > 0:
|
||||
last_decision_time = self.recent_decisions[-1].get('timestamp', current_time)
|
||||
if isinstance(last_decision_time, datetime):
|
||||
time_diff = (current_time - last_decision_time.replace(tzinfo=None)).total_seconds()
|
||||
if time_diff < 300: # Within last 5 minutes
|
||||
events.append(
|
||||
html.Div([
|
||||
html.Small([
|
||||
html.Span(f"{last_decision_time.strftime('%H:%M:%S')} ", className="text-muted"),
|
||||
html.Span("Model prediction generated", className="text-warning")
|
||||
])
|
||||
])
|
||||
)
|
||||
|
||||
# Add system events
|
||||
events.append(
|
||||
html.Div([
|
||||
html.Small([
|
||||
html.Span(f"{current_time.strftime('%H:%M:%S')} ", className="text-muted"),
|
||||
html.Span("Dashboard updated", className="text-primary")
|
||||
])
|
||||
])
|
||||
)
|
||||
|
||||
# Limit to last 5 events
|
||||
return events[-5:] if events else [html.Small("No recent events", className="text-muted")]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting training events: {e}")
|
||||
return [html.Small("Events unavailable", className="text-muted")]
|
||||
|
||||
def send_training_data_to_models(self) -> bool:
|
||||
"""Send current tick cache data to models for training"""
|
||||
try:
|
||||
if len(self.tick_cache) < 100:
|
||||
logger.debug("Insufficient tick data for training (need at least 100 ticks)")
|
||||
return False
|
||||
|
||||
# Convert tick cache to training format
|
||||
training_data = self._prepare_training_data()
|
||||
|
||||
if not training_data:
|
||||
logger.warning("Failed to prepare training data")
|
||||
return False
|
||||
|
||||
# Send to CNN models
|
||||
cnn_success = self._send_data_to_cnn_models(training_data)
|
||||
|
||||
# Send to RL models
|
||||
rl_success = self._send_data_to_rl_models(training_data)
|
||||
|
||||
# Update training metrics
|
||||
if cnn_success or rl_success:
|
||||
self._update_training_metrics(cnn_success, rl_success)
|
||||
logger.info(f"Training data sent - CNN: {cnn_success}, RL: {rl_success}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending training data to models: {e}")
|
||||
return False
|
||||
|
||||
def _prepare_training_data(self) -> Dict[str, Any]:
|
||||
"""Prepare tick cache data for model training"""
|
||||
try:
|
||||
# Convert tick cache to DataFrame
|
||||
tick_data = []
|
||||
for tick in list(self.tick_cache):
|
||||
tick_data.append({
|
||||
'timestamp': tick['timestamp'],
|
||||
'price': tick['price'],
|
||||
'volume': tick.get('volume', 0),
|
||||
'side': tick.get('side', 'unknown')
|
||||
})
|
||||
|
||||
if not tick_data:
|
||||
return None
|
||||
|
||||
df = pd.DataFrame(tick_data)
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'])
|
||||
df = df.sort_values('timestamp')
|
||||
|
||||
# Create OHLCV bars from ticks (1-second aggregation)
|
||||
df.set_index('timestamp', inplace=True)
|
||||
ohlcv = df.groupby(pd.Grouper(freq='1S')).agg({
|
||||
'price': ['first', 'max', 'min', 'last'],
|
||||
'volume': 'sum'
|
||||
}).dropna()
|
||||
|
||||
# Flatten column names
|
||||
ohlcv.columns = ['open', 'high', 'low', 'close', 'volume']
|
||||
|
||||
# Calculate technical indicators
|
||||
ohlcv['sma_20'] = ohlcv['close'].rolling(20).mean()
|
||||
ohlcv['sma_50'] = ohlcv['close'].rolling(50).mean()
|
||||
ohlcv['rsi'] = self._calculate_rsi(ohlcv['close'])
|
||||
ohlcv['price_change'] = ohlcv['close'].pct_change()
|
||||
ohlcv['volume_sma'] = ohlcv['volume'].rolling(20).mean()
|
||||
|
||||
# Remove NaN values
|
||||
ohlcv = ohlcv.dropna()
|
||||
|
||||
if len(ohlcv) < 50:
|
||||
logger.debug("Insufficient processed data for training")
|
||||
return None
|
||||
|
||||
return {
|
||||
'ohlcv': ohlcv,
|
||||
'raw_ticks': df,
|
||||
'symbol': 'ETH/USDT',
|
||||
'timeframe': '1s',
|
||||
'features': ['open', 'high', 'low', 'close', 'volume', 'sma_20', 'sma_50', 'rsi'],
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing training data: {e}")
|
||||
return None
|
||||
|
||||
def _calculate_rsi(self, prices: pd.Series, period: int = 14) -> pd.Series:
|
||||
"""Calculate RSI indicator"""
|
||||
try:
|
||||
delta = prices.diff()
|
||||
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
|
||||
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
|
||||
rs = gain / loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
return rsi
|
||||
except Exception as e:
|
||||
logger.warning(f"Error calculating RSI: {e}")
|
||||
return pd.Series(index=prices.index, dtype=float)
|
||||
|
||||
def _send_data_to_cnn_models(self, training_data: Dict[str, Any]) -> bool:
|
||||
"""Send training data to CNN models"""
|
||||
try:
|
||||
success_count = 0
|
||||
|
||||
# Get CNN models from registry
|
||||
for model_name, model in self.model_registry.models.items():
|
||||
if hasattr(model, 'train_online') or 'cnn' in model_name.lower():
|
||||
try:
|
||||
# Prepare CNN-specific data format
|
||||
cnn_data = self._format_data_for_cnn(training_data)
|
||||
|
||||
if hasattr(model, 'train_online'):
|
||||
# Online training method
|
||||
model.train_online(cnn_data)
|
||||
success_count += 1
|
||||
logger.debug(f"Sent training data to CNN model: {model_name}")
|
||||
elif hasattr(model, 'update_with_data'):
|
||||
# Alternative update method
|
||||
model.update_with_data(cnn_data)
|
||||
success_count += 1
|
||||
logger.debug(f"Updated CNN model with data: {model_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending data to CNN model {model_name}: {e}")
|
||||
|
||||
# Try to send to orchestrator's CNN training
|
||||
if hasattr(self.orchestrator, 'update_cnn_training'):
|
||||
try:
|
||||
self.orchestrator.update_cnn_training(training_data)
|
||||
success_count += 1
|
||||
logger.debug("Sent training data to orchestrator CNN training")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending data to orchestrator CNN: {e}")
|
||||
|
||||
return success_count > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending data to CNN models: {e}")
|
||||
return False
|
||||
|
||||
def _send_data_to_rl_models(self, training_data: Dict[str, Any]) -> bool:
|
||||
"""Send training data to RL models"""
|
||||
try:
|
||||
success_count = 0
|
||||
|
||||
# Get RL models from registry
|
||||
for model_name, model in self.model_registry.models.items():
|
||||
if hasattr(model, 'add_experience') or 'rl' in model_name.lower() or 'dqn' in model_name.lower():
|
||||
try:
|
||||
# Prepare RL-specific data format (state-action-reward-next_state)
|
||||
rl_experiences = self._format_data_for_rl(training_data)
|
||||
|
||||
if hasattr(model, 'add_experience'):
|
||||
# Add experiences to replay buffer
|
||||
for experience in rl_experiences:
|
||||
model.add_experience(*experience)
|
||||
success_count += 1
|
||||
logger.debug(f"Sent {len(rl_experiences)} experiences to RL model: {model_name}")
|
||||
elif hasattr(model, 'update_replay_buffer'):
|
||||
# Alternative replay buffer update
|
||||
model.update_replay_buffer(rl_experiences)
|
||||
success_count += 1
|
||||
logger.debug(f"Updated RL replay buffer: {model_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending data to RL model {model_name}: {e}")
|
||||
|
||||
# Try to send to orchestrator's RL training
|
||||
if hasattr(self.orchestrator, 'update_rl_training'):
|
||||
try:
|
||||
self.orchestrator.update_rl_training(training_data)
|
||||
success_count += 1
|
||||
logger.debug("Sent training data to orchestrator RL training")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error sending data to orchestrator RL: {e}")
|
||||
|
||||
return success_count > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending data to RL models: {e}")
|
||||
return False
|
||||
|
||||
def _format_data_for_cnn(self, training_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Format training data for CNN models"""
|
||||
try:
|
||||
ohlcv = training_data['ohlcv']
|
||||
|
||||
# Create feature matrix for CNN (sequence of OHLCV + indicators)
|
||||
features = ohlcv[['open', 'high', 'low', 'close', 'volume', 'sma_20', 'sma_50', 'rsi']].values
|
||||
|
||||
# Normalize features
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
scaler = MinMaxScaler()
|
||||
features_normalized = scaler.fit_transform(features)
|
||||
|
||||
# Create sequences for CNN training (sliding window)
|
||||
sequence_length = 60 # 1 minute of 1-second data
|
||||
sequences = []
|
||||
targets = []
|
||||
|
||||
for i in range(sequence_length, len(features_normalized)):
|
||||
sequences.append(features_normalized[i-sequence_length:i])
|
||||
# Target: price direction (1 for up, 0 for down)
|
||||
current_price = ohlcv.iloc[i]['close']
|
||||
future_price = ohlcv.iloc[min(i+5, len(ohlcv)-1)]['close'] # 5 seconds ahead
|
||||
targets.append(1 if future_price > current_price else 0)
|
||||
|
||||
return {
|
||||
'sequences': np.array(sequences),
|
||||
'targets': np.array(targets),
|
||||
'feature_names': ['open', 'high', 'low', 'close', 'volume', 'sma_20', 'sma_50', 'rsi'],
|
||||
'sequence_length': sequence_length,
|
||||
'symbol': training_data['symbol'],
|
||||
'timestamp': training_data['timestamp']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting data for CNN: {e}")
|
||||
return {}
|
||||
|
||||
def _format_data_for_rl(self, training_data: Dict[str, Any]) -> List[Tuple]:
|
||||
"""Format training data for RL models (state, action, reward, next_state, done)"""
|
||||
try:
|
||||
ohlcv = training_data['ohlcv']
|
||||
experiences = []
|
||||
|
||||
# Create state representations
|
||||
for i in range(10, len(ohlcv) - 1): # Need history for state
|
||||
# Current state (last 10 bars)
|
||||
state_data = ohlcv.iloc[i-10:i][['close', 'volume', 'rsi']].values.flatten()
|
||||
|
||||
# Next state
|
||||
next_state_data = ohlcv.iloc[i-9:i+1][['close', 'volume', 'rsi']].values.flatten()
|
||||
|
||||
# Simulate action based on price movement
|
||||
current_price = ohlcv.iloc[i]['close']
|
||||
next_price = ohlcv.iloc[i+1]['close']
|
||||
price_change = (next_price - current_price) / current_price
|
||||
|
||||
# Action: 0=HOLD, 1=BUY, 2=SELL
|
||||
if price_change > 0.001: # 0.1% threshold
|
||||
action = 1 # BUY
|
||||
reward = price_change * 100 # Reward proportional to gain
|
||||
elif price_change < -0.001:
|
||||
action = 2 # SELL
|
||||
reward = -price_change * 100 # Reward for correct short
|
||||
else:
|
||||
action = 0 # HOLD
|
||||
reward = 0
|
||||
|
||||
# Add experience tuple
|
||||
experiences.append((
|
||||
state_data, # state
|
||||
action, # action
|
||||
reward, # reward
|
||||
next_state_data, # next_state
|
||||
False # done (not terminal)
|
||||
))
|
||||
|
||||
return experiences
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting data for RL: {e}")
|
||||
return []
|
||||
|
||||
def _update_training_metrics(self, cnn_success: bool, rl_success: bool):
|
||||
"""Update training metrics tracking"""
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
|
||||
# Update training statistics
|
||||
if not hasattr(self, 'training_stats'):
|
||||
self.training_stats = {
|
||||
'last_training_time': current_time,
|
||||
'total_training_sessions': 0,
|
||||
'cnn_training_count': 0,
|
||||
'rl_training_count': 0,
|
||||
'training_data_points': 0
|
||||
}
|
||||
|
||||
self.training_stats['last_training_time'] = current_time
|
||||
self.training_stats['total_training_sessions'] += 1
|
||||
|
||||
if cnn_success:
|
||||
self.training_stats['cnn_training_count'] += 1
|
||||
if rl_success:
|
||||
self.training_stats['rl_training_count'] += 1
|
||||
|
||||
self.training_stats['training_data_points'] = len(self.tick_cache)
|
||||
|
||||
logger.debug(f"Training metrics updated: {self.training_stats}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error updating training metrics: {e}")
|
||||
|
||||
def get_tick_cache_for_training(self) -> List[Dict]:
|
||||
"""Get tick cache data for external training systems"""
|
||||
try:
|
||||
return list(self.tick_cache)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tick cache for training: {e}")
|
||||
return []
|
||||
|
||||
def start_continuous_training(self):
|
||||
"""Start continuous training in background thread"""
|
||||
try:
|
||||
if hasattr(self, 'training_thread') and self.training_thread.is_alive():
|
||||
logger.info("Continuous training already running")
|
||||
return
|
||||
|
||||
self.training_active = True
|
||||
self.training_thread = Thread(target=self._continuous_training_loop, daemon=True)
|
||||
self.training_thread.start()
|
||||
logger.info("Continuous training started")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting continuous training: {e}")
|
||||
|
||||
def _continuous_training_loop(self):
|
||||
"""Continuous training loop running in background"""
|
||||
logger.info("Continuous training loop started")
|
||||
|
||||
while getattr(self, 'training_active', False):
|
||||
try:
|
||||
# Send training data every 30 seconds if we have enough data
|
||||
if len(self.tick_cache) >= 500: # Need sufficient data
|
||||
success = self.send_training_data_to_models()
|
||||
if success:
|
||||
logger.debug("Training data sent to models")
|
||||
|
||||
time.sleep(30) # Train every 30 seconds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in continuous training loop: {e}")
|
||||
time.sleep(60) # Wait longer on error
|
||||
|
||||
def stop_continuous_training(self):
|
||||
"""Stop continuous training"""
|
||||
try:
|
||||
self.training_active = False
|
||||
if hasattr(self, 'training_thread'):
|
||||
self.training_thread.join(timeout=5)
|
||||
logger.info("Continuous training stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping continuous training: {e}")
|
||||
# Convenience function for integration
|
||||
def create_dashboard(data_provider: DataProvider = None, orchestrator: TradingOrchestrator = None) -> TradingDashboard:
|
||||
"""Create and return a trading dashboard instance"""
|
||||
|
Reference in New Issue
Block a user