dash training info

This commit is contained in:
Dobromir Popov
2025-05-26 23:05:04 +03:00
parent 392dbb4b61
commit 678cf951a5
3 changed files with 1272 additions and 5 deletions

View File

@ -15,7 +15,7 @@ import logging
import time
from datetime import datetime, timedelta, timezone
from threading import Thread
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Tuple
from collections import deque
# Optional WebSocket support
@ -105,7 +105,10 @@ class TradingDashboard:
# Start WebSocket tick streaming
self._start_websocket_stream()
logger.info("Trading Dashboard initialized")
# Start continuous training
self.start_continuous_training()
logger.info("Trading Dashboard initialized with continuous training")
def _setup_layout(self):
"""Setup the dashboard layout"""
@ -169,7 +172,7 @@ class TradingDashboard:
# Charts row - More compact
html.Div([
# Price chart - Full width
# Price chart - 70% width
html.Div([
html.Div([
html.H6([
@ -178,7 +181,18 @@ class TradingDashboard:
], className="card-title mb-2"),
dcc.Graph(id="price-chart", style={"height": "400px"})
], className="card-body p-2")
], className="card", style={"width": "100%"}),
], className="card", style={"width": "70%"}),
# Model Training Metrics - 30% width
html.Div([
html.Div([
html.H6([
html.I(className="fas fa-brain me-2"),
"Model Training Progress"
], className="card-title mb-2"),
html.Div(id="training-metrics", style={"height": "400px", "overflowY": "auto"})
], className="card-body p-2")
], className="card", style={"width": "28%", "marginLeft": "2%"}),
], className="row g-2 mb-3"),
# Bottom row - Trading info and performance (more compact layout)
@ -242,6 +256,7 @@ class TradingDashboard:
Output('trade-count', 'children'),
Output('memory-usage', 'children'),
Output('price-chart', 'figure'),
Output('training-metrics', 'children'),
Output('recent-decisions', 'children'),
Output('session-performance', 'children'),
Output('system-status-icon', 'className'),
@ -390,6 +405,13 @@ class TradingDashboard:
logger.warning(f"Price chart error: {e}")
price_chart = self._create_empty_chart("Price Chart", "No price data available")
# Create training metrics display
try:
training_metrics = self._create_training_metrics()
except Exception as e:
logger.warning(f"Training metrics error: {e}")
training_metrics = [html.P("Training metrics unavailable", className="text-muted")]
# Create recent decisions list
try:
decisions_list = self._create_decisions_list()
@ -417,7 +439,7 @@ class TradingDashboard:
return (
price_text, pnl_text, pnl_class, position_text, trade_count_text, memory_text,
price_chart, decisions_list, session_perf,
price_chart, training_metrics, decisions_list, session_perf,
system_status['icon_class'], system_status['title'], system_status['details']
)
@ -429,6 +451,7 @@ class TradingDashboard:
return (
"Error", "$0.00", "text-muted mb-0 small", "None", "0", "0.0%",
empty_fig,
[html.P("Error loading training metrics", className="text-danger")],
[html.P("Error loading decisions", className="text-danger")],
[html.P("Error loading performance", className="text-danger")],
"fas fa-circle text-danger fa-2x",
@ -1957,6 +1980,806 @@ class TradingDashboard:
logger.error(f"Error getting 1-second bars: {e}")
return pd.DataFrame()
def _create_training_metrics(self) -> List:
    """Build the child components for the ``training-metrics`` panel.

    The result is a list of bordered cards in this order: data-stream
    counters, CNN metrics, RL metrics, a mini progress chart and a recent
    events log. If the model status cannot be fetched or rendered, the model
    cards built so far are kept and a "Training status unavailable" card is
    appended in place of the remaining ones.

    Returns:
        A list of Dash components. On an unexpected failure the list holds a
        single paragraph carrying the error text.
    """

    def metric_row(label, value, css_class):
        # One "<label><value>" line; d-block puts every metric on its own row.
        return html.Small(
            [html.Strong(label), html.Span(value, className=css_class)],
            className="d-block",
        )

    def titled_card(icon_class, title, content, accent):
        # Bordered card: icon + heading on top, ``content`` underneath.
        return html.Div(
            [
                html.H6([html.I(className=icon_class), title], className="mb-2"),
                content,
            ],
            className=f"mb-3 p-2 border border-{accent} rounded",
        )

    def fill_level(count, threshold):
        # Green once the cache holds more than ``threshold`` items.
        return "text-success" if count > threshold else "text-warning"

    try:
        tick_total = len(self.tick_cache)
        bar_total = len(self.one_second_bars)
        live = self.is_streaming

        cards = [
            titled_card(
                "fas fa-database me-2 text-info",
                "Training Data Stream",
                html.Div([
                    metric_row("Tick Cache: ", f"{tick_total:,} ticks", fill_level(tick_total, 1000)),
                    metric_row("1s Bars: ", f"{bar_total} bars", fill_level(bar_total, 100)),
                    metric_row(
                        "Stream: ",
                        "LIVE" if live else "OFFLINE",
                        "text-success" if live else "text-danger",
                    ),
                ]),
                "info",
            )
        ]

        try:
            training_status = self._get_model_training_status()

            # Each card is appended as soon as it is complete so that a
            # failure in a later card does not discard the earlier ones.
            cnn = training_status['cnn']
            cards.append(
                titled_card(
                    "fas fa-brain me-2 text-warning",
                    "CNN Model",
                    html.Div([
                        metric_row("Status: ", cnn['status'], f"text-{cnn['status_color']}"),
                        metric_row("Accuracy: ", f"{cnn['accuracy']:.1%}", "text-info"),
                        metric_row("Loss: ", f"{cnn['loss']:.4f}", "text-muted"),
                        metric_row("Epochs: ", f"{cnn['epochs']}", "text-muted"),
                        metric_row("Learning Rate: ", f"{cnn['learning_rate']:.6f}", "text-muted"),
                    ]),
                    "warning",
                )
            )

            rl = training_status['rl']
            cards.append(
                titled_card(
                    "fas fa-robot me-2 text-success",
                    "RL Agent (DQN)",
                    html.Div([
                        metric_row("Status: ", rl['status'], f"text-{rl['status_color']}"),
                        metric_row("Win Rate: ", f"{rl['win_rate']:.1%}", "text-info"),
                        metric_row("Avg Reward: ", f"{rl['avg_reward']:.2f}", "text-muted"),
                        metric_row("Episodes: ", f"{rl['episodes']}", "text-muted"),
                        metric_row("Epsilon: ", f"{rl['epsilon']:.3f}", "text-muted"),
                        metric_row("Memory: ", f"{rl['memory_size']:,}", "text-muted"),
                    ]),
                    "success",
                )
            )

            cards.append(
                titled_card(
                    "fas fa-chart-line me-2 text-primary",
                    "Training Progress",
                    dcc.Graph(
                        figure=self._create_mini_training_chart(training_status),
                        style={"height": "150px"},
                        config={'displayModeBar': False},
                    ),
                    "primary",
                )
            )
        except Exception as e:
            logger.warning(f"Error getting training status: {e}")
            cards.append(
                html.Div(
                    [
                        html.P("Training status unavailable", className="text-muted"),
                        html.Small(f"Error: {str(e)}", className="text-danger"),
                    ],
                    className="mb-3 p-2 border border-secondary rounded",
                )
            )

        cards.append(
            titled_card(
                "fas fa-list me-2 text-secondary",
                "Recent Training Events",
                html.Div(
                    id="training-events-log",
                    children=self._get_recent_training_events(),
                    style={"maxHeight": "120px", "overflowY": "auto", "fontSize": "0.8em"},
                ),
                "secondary",
            )
        )
        return cards

    except Exception as e:
        logger.error(f"Error creating training metrics: {e}")
        return [html.P(f"Training metrics error: {str(e)}", className="text-danger")]
def _get_model_training_status(self) -> Dict:
    """Collect the current CNN and RL training metrics for display.

    Starts from a complete set of defaults and layers on, in order, whatever
    the orchestrator (``get_training_metrics``), the model registry
    (``get_training_stats``) and the parsed training logs provide. Every
    source is optional and a failure in one is logged without affecting the
    others.

    Sources are merged per model rather than replacing the whole ``'cnn'`` /
    ``'rl'`` entry, so a source that reports only some metrics cannot remove
    the default keys the renderer reads.

    Returns:
        ``{'cnn': {...}, 'rl': {...}}`` where each entry always carries
        ``status`` and ``status_color`` plus the numeric metrics. Both
        entries report ``ERROR`` if an unexpected exception occurs.
    """

    def merge_into(target, updates):
        # Nested dicts are merged key by key; anything else is assigned.
        for key, value in updates.items():
            if isinstance(value, dict) and isinstance(target.get(key), dict):
                target[key].update(value)
            else:
                target[key] = value

    try:
        status = {
            'cnn': {
                'status': 'IDLE',
                'status_color': 'secondary',
                'accuracy': 0.0,
                'loss': 0.0,
                'epochs': 0,
                'learning_rate': 0.001
            },
            'rl': {
                'status': 'IDLE',
                'status_color': 'secondary',
                'win_rate': 0.0,
                'avg_reward': 0.0,
                'episodes': 0,
                'epsilon': 1.0,
                'memory_size': 0
            }
        }

        # Source 1: orchestrator
        if hasattr(self.orchestrator, 'get_training_metrics'):
            try:
                real_metrics = self.orchestrator.get_training_metrics()
                if real_metrics:
                    merge_into(status, real_metrics)
                    logger.debug("Using real training metrics from orchestrator")
            except Exception as e:
                logger.warning(f"Error getting orchestrator metrics: {e}")

        # Source 2: model registry (only the known model types are read)
        if hasattr(self.model_registry, 'get_training_stats'):
            try:
                registry_stats = self.model_registry.get_training_stats()
                if registry_stats:
                    for model_type in ('cnn', 'rl'):
                        if model_type in registry_stats:
                            status[model_type].update(registry_stats[model_type])
                    logger.debug("Updated with model registry stats")
            except Exception as e:
                logger.warning(f"Error getting registry stats: {e}")

        # Source 3: training log files
        try:
            log_metrics = self._parse_training_logs()
            if log_metrics:
                for model_type in ('cnn', 'rl'):
                    if model_type in log_metrics:
                        status[model_type].update(log_metrics[model_type])
                logger.debug("Updated with training log metrics")
        except Exception as e:
            logger.warning(f"Error parsing training logs: {e}")

        # NOTE(review): while ticks are streaming both models are shown as
        # TRAINING regardless of the status a source reported above.
        if self.is_streaming and len(self.tick_cache) > 100:
            status['cnn']['status'] = 'TRAINING'
            status['cnn']['status_color'] = 'warning'
            status['rl']['status'] = 'TRAINING'
            status['rl']['status_color'] = 'success'

        return status

    except Exception as e:
        logger.error(f"Error getting model training status: {e}")
        return {
            'cnn': {'status': 'ERROR', 'status_color': 'danger', 'accuracy': 0.0, 'loss': 0.0, 'epochs': 0, 'learning_rate': 0.001},
            'rl': {'status': 'ERROR', 'status_color': 'danger', 'win_rate': 0.0, 'avg_reward': 0.0, 'episodes': 0, 'epsilon': 1.0, 'memory_size': 0}
        }
def _parse_training_logs(self) -> Optional[Dict]:
    """Extract the most recent CNN and RL metrics from plain-text training logs.

    For each model the candidate log paths (relative to the current working
    directory) are tried in order and the first one that can be read is used.
    Only the last 50 lines are inspected; later lines override earlier ones,
    so the values reflect the newest matching line.

    Values may follow the metric name after ``:``, ``=`` or whitespace and
    may be signed or in scientific notation (``loss=3.2e-4``).

    Returns:
        ``{'cnn': {...}, 'rl': {...}}`` holding whichever of ``epochs``,
        ``loss``, ``accuracy`` (CNN) and ``episodes``, ``avg_reward``,
        ``epsilon`` (RL) were found, or ``None`` when nothing was found or
        parsing failed.
    """
    try:
        import re
        from collections import deque
        from pathlib import Path

        tail_lines = 50
        # Signed decimal with optional exponent, e.g. "0.42", "-1.5", "3e-4".
        number = r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
        integer = r'(\d+)'
        separator = r'[:=\s]+'

        def metric(name, value):
            return re.compile(name + separator + value, re.IGNORECASE)

        def as_fraction(text):
            # Logs may report accuracy either as 0-1 or as a 0-100 percentage.
            value = float(text)
            return value if value <= 1.0 else value / 100.0

        def scan(candidates, required_words, extractors, found):
            """Fill ``found`` from the first readable file in ``candidates``."""
            for log_path in candidates:
                if not Path(log_path).is_file():
                    continue
                try:
                    # A bounded deque keeps just the tail without loading the
                    # whole (possibly very large) log into memory.
                    with open(log_path, 'r', encoding='utf-8', errors='replace') as handle:
                        recent = deque(handle, maxlen=tail_lines)
                    for line in recent:
                        lowered = line.lower()
                        if not all(word in lowered for word in required_words):
                            continue
                        for pattern, key, convert in extractors:
                            match = pattern.search(line)
                            if match:
                                found[key] = convert(match.group(1))
                    return  # first readable log wins
                except (OSError, ValueError) as e:
                    logger.debug(f"Error parsing {log_path}: {e}")

        metrics = {'cnn': {}, 'rl': {}}

        # TensorBoard event files are intentionally not listed: they are
        # binary and a glob pattern cannot be opened as a single path.
        scan(
            ['logs/cnn_training.log', 'logs/training.log'],
            ('epoch', 'loss'),
            [
                (metric(r'epoch', integer), 'epochs', int),
                (metric(r'loss', number), 'loss', float),
                (metric(r'acc(?:uracy)?', number), 'accuracy', as_fraction),
            ],
            metrics['cnn'],
        )
        scan(
            ['logs/rl_training.log', 'logs/training.log'],
            ('episode',),
            [
                (metric(r'episode', integer), 'episodes', int),
                (metric(r'reward', number), 'avg_reward', float),
                (metric(r'epsilon', number), 'epsilon', float),
            ],
            metrics['rl'],
        )

        return metrics if any(metrics.values()) else None

    except Exception as e:
        logger.warning(f"Error parsing training logs: {e}")
        return None
def _create_mini_training_chart(self, training_status: Dict) -> go.Figure:
    """Draw the compact CNN-accuracy / RL-win-rate chart for the metrics panel.

    The curves are not recorded history: each one is a saturating exponential
    that ends near the model's current metric, with Gaussian noise added, so
    the shape changes on every call. A curve is drawn only when the model
    reports more than one epoch/episode.

    Args:
        training_status: Mapping with ``'cnn'`` (``accuracy``, ``epochs``)
            and ``'rl'`` (``win_rate``, ``episodes``) entries.

    Returns:
        The figure, or a placeholder figure with a "Training data loading..."
        annotation if anything goes wrong.
    """
    mini_layout = dict(
        template="plotly_dark",
        height=150,
        margin=dict(l=20, r=20, t=20, b=20),
    )
    try:
        import numpy as np

        fig = go.Figure()

        def add_simulated_curve(model, metric, counter, time_constant, noise_std,
                                name, color, hover):
            # Reads the metric first and the step count second, then draws
            # at most 20 points between step 1 and the current step.
            target = training_status[model][metric]
            steps = max(1, training_status[model][counter])
            if steps <= 1:
                return
            xs = np.linspace(1, steps, min(20, steps))
            ys = target * (1 - np.exp(-xs / (steps * time_constant))) + np.random.normal(0, noise_std, len(xs))
            ys = np.clip(ys, 0, 1)  # keep inside the 0-1 axis range
            fig.add_trace(go.Scatter(
                x=xs,
                y=ys,
                mode='lines',
                name=name,
                line=dict(color=color, width=2),
                hovertemplate=hover,
            ))

        # CNN first, RL second: the order fixes both trace order and the
        # order in which random noise is drawn.
        add_simulated_curve(
            'cnn', 'accuracy', 'epochs', 0.3, 0.01,
            'CNN Accuracy', 'orange',
            'Epoch: %{x}<br>Accuracy: %{y:.3f}<extra></extra>',
        )
        add_simulated_curve(
            'rl', 'win_rate', 'episodes', 0.4, 0.02,
            'RL Win Rate', 'green',
            'Episode: %{x}<br>Win Rate: %{y:.3f}<extra></extra>',
        )

        grid_axis = dict(title="", showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
        fig.update_layout(
            showlegend=True,
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1,
                font=dict(size=10),
            ),
            xaxis=dict(grid_axis),
            yaxis=dict(grid_axis, range=[0, 1]),
            **mini_layout,
        )
        return fig

    except Exception as e:
        logger.warning(f"Error creating mini training chart: {e}")
        placeholder = go.Figure()
        placeholder.add_annotation(
            text="Training data loading...",
            xref="paper", yref="paper",
            x=0.5, y=0.5,
            showarrow=False,
            font=dict(size=12, color="gray"),
        )
        placeholder.update_layout(**mini_layout)
        return placeholder
def _get_recent_training_events(self) -> List:
    """Build the short event list shown under "Recent Training Events".

    Events are derived from the current state on each call, not read from a
    stored log: streaming status, tick-cache size, the latest decision (only
    if it is less than five minutes old) and a "Dashboard updated" marker.

    The age of the latest decision is measured in the timestamp's own time
    zone when it is timezone-aware. Dropping the zone and comparing against
    naive local time would be wrong by the UTC offset.

    Returns:
        At most five Dash components, or a single "Events unavailable" note
        if anything fails.
    """

    def event_row(stamp, message, css_class):
        # "<HH:MM:SS> <message>" with a muted timestamp.
        return html.Div([
            html.Small([
                html.Span(f"{stamp.strftime('%H:%M:%S')} ", className="text-muted"),
                html.Span(message, className=css_class)
            ])
        ])

    try:
        events = []
        current_time = datetime.now()

        if self.is_streaming:
            events.append(event_row(current_time, "Streaming live ticks", "text-success"))

        if len(self.tick_cache) > 0:
            # 3600 ticks per minute, i.e. the original assumption of 60 ticks
            # per second.
            cache_minutes = len(self.tick_cache) / 3600
            events.append(
                event_row(current_time, f"Training cache: {cache_minutes:.1f}m data", "text-info")
            )

        if len(self.recent_decisions) > 0:
            last_decision_time = self.recent_decisions[-1].get('timestamp', current_time)
            if isinstance(last_decision_time, datetime):
                if last_decision_time.utcoffset() is not None:
                    # Aware timestamp: take "now" in the same zone.
                    reference = datetime.now(last_decision_time.tzinfo)
                else:
                    reference = current_time
                age_seconds = (reference - last_decision_time).total_seconds()
                if age_seconds < 300:  # within the last 5 minutes
                    events.append(
                        event_row(last_decision_time, "Model prediction generated", "text-warning")
                    )

        events.append(event_row(current_time, "Dashboard updated", "text-primary"))

        # Keep only the five most recent entries.
        return events[-5:] if events else [html.Small("No recent events", className="text-muted")]

    except Exception as e:
        logger.warning(f"Error getting training events: {e}")
        return [html.Small("Events unavailable", className="text-muted")]
def send_training_data_to_models(self) -> bool:
    """Push the current tick cache to the CNN and RL models.

    Needs at least 100 cached ticks. The prepared data set is offered to both
    model families; training statistics are updated when at least one of
    them accepted it.

    Returns:
        ``True`` if the CNN side or the RL side accepted the data, ``False``
        otherwise (too few ticks, preparation failed, nobody accepted, or an
        exception was raised).
    """
    try:
        if len(self.tick_cache) < 100:
            logger.debug("Insufficient tick data for training (need at least 100 ticks)")
            return False

        payload = self._prepare_training_data()
        if not payload:
            logger.warning("Failed to prepare training data")
            return False

        # Both families are always attempted, CNN first.
        cnn_ok = self._send_data_to_cnn_models(payload)
        rl_ok = self._send_data_to_rl_models(payload)

        if not (cnn_ok or rl_ok):
            return False

        self._update_training_metrics(cnn_ok, rl_ok)
        logger.info(f"Training data sent - CNN: {cnn_ok}, RL: {rl_ok}")
        return True

    except Exception as e:
        logger.error(f"Error sending training data to models: {e}")
        return False
def _prepare_training_data(self) -> Optional[Dict[str, Any]]:
    """Turn the tick cache into 1-second OHLCV bars with indicators.

    Ticks are aggregated per second (seconds without ticks are dropped), then
    20/50-bar SMAs, RSI, percentage price change and a 20-bar volume SMA are
    added. Rows whose indicators are not yet defined are removed.

    Returns:
        A dict with ``ohlcv`` (the indicator frame), ``raw_ticks`` (the
        time-indexed tick frame), ``symbol``, ``timeframe``, ``features`` and
        ``timestamp``; or ``None`` when the cache is empty, fewer than 50
        bars remain after cleaning, or an error occurs.
    """
    try:
        # list(...) takes a snapshot first so the cache is only iterated once
        # before the per-tick dicts are built.
        tick_rows = [
            {
                'timestamp': tick['timestamp'],
                'price': tick['price'],
                'volume': tick.get('volume', 0),
                'side': tick.get('side', 'unknown'),
            }
            for tick in list(self.tick_cache)
        ]
        if not tick_rows:
            return None

        df = pd.DataFrame(tick_rows)
        df['timestamp'] = pd.to_datetime(df['timestamp'])
        df = df.sort_values('timestamp')
        df.set_index('timestamp', inplace=True)

        # '1s' rather than '1S': the upper-case alias is deprecated in recent
        # pandas, the lower-case one is accepted by old and new versions.
        ohlcv = df.groupby(pd.Grouper(freq='1s')).agg({
            'price': ['first', 'max', 'min', 'last'],
            'volume': 'sum'
        }).dropna()

        # Flatten the (column, aggregation) MultiIndex.
        ohlcv.columns = ['open', 'high', 'low', 'close', 'volume']

        ohlcv['sma_20'] = ohlcv['close'].rolling(20).mean()
        ohlcv['sma_50'] = ohlcv['close'].rolling(50).mean()
        ohlcv['rsi'] = self._calculate_rsi(ohlcv['close'])
        ohlcv['price_change'] = ohlcv['close'].pct_change()
        ohlcv['volume_sma'] = ohlcv['volume'].rolling(20).mean()

        # Drop the warm-up rows where an indicator is still undefined.
        ohlcv = ohlcv.dropna()

        if len(ohlcv) < 50:
            logger.debug("Insufficient processed data for training")
            return None

        return {
            'ohlcv': ohlcv,
            'raw_ticks': df,
            'symbol': 'ETH/USDT',
            'timeframe': '1s',
            'features': ['open', 'high', 'low', 'close', 'volume', 'sma_20', 'sma_50', 'rsi'],
            'timestamp': datetime.now()
        }

    except Exception as e:
        logger.error(f"Error preparing training data: {e}")
        return None
def _calculate_rsi(self, prices: pd.Series, period: int = 14) -> pd.Series:
    """Relative Strength Index from simple rolling means of gains and losses.

    Args:
        prices: Price series, one value per bar.
        period: Rolling window length in bars.

    Returns:
        RSI values aligned with ``prices``. The first ``period - 1`` entries
        are NaN, as is any window with neither gains nor losses. If the
        computation fails an all-NaN series with the same index is returned.
    """
    try:
        change = prices.diff()
        # The leading NaN difference counts as "no move" on both sides.
        upward = change.where(change > 0, 0)
        downward = -change.where(change < 0, 0)
        mean_gain = upward.rolling(window=period).mean()
        mean_loss = downward.rolling(window=period).mean()
        relative_strength = mean_gain / mean_loss
        return 100 - (100 / (1 + relative_strength))
    except Exception as e:
        logger.warning(f"Error calculating RSI: {e}")
        return pd.Series(index=prices.index, dtype=float)
def _send_data_to_cnn_models(self, training_data: Dict[str, Any]) -> bool:
    """Offer the prepared training data to every CNN-capable consumer.

    A registry model qualifies when it has a ``train_online`` method or its
    name contains ``cnn``. The CNN-formatted data set is built once, on first
    use, and shared by all qualifying models. If it comes back empty the
    registry models are skipped rather than being handed an empty payload.
    The orchestrator, when it has ``update_cnn_training``, receives the
    unformatted ``training_data``.

    Args:
        training_data: Data set produced by ``_prepare_training_data``.

    Returns:
        ``True`` if at least one model or the orchestrator accepted the data.
    """
    try:
        success_count = 0
        cnn_data = None  # built lazily so it is computed at most once

        for model_name, model in self.model_registry.models.items():
            trains_online = hasattr(model, 'train_online')
            if not (trains_online or 'cnn' in model_name.lower()):
                continue
            try:
                if cnn_data is None:
                    cnn_data = self._format_data_for_cnn(training_data)

                has_no_samples = not cnn_data or (
                    'sequences' in cnn_data and len(cnn_data['sequences']) == 0
                )
                if has_no_samples:
                    logger.debug("No CNN training sequences available; skipping registry models")
                    break

                if trains_online:
                    model.train_online(cnn_data)
                    success_count += 1
                    logger.debug(f"Sent training data to CNN model: {model_name}")
                elif hasattr(model, 'update_with_data'):
                    model.update_with_data(cnn_data)
                    success_count += 1
                    logger.debug(f"Updated CNN model with data: {model_name}")
            except Exception as e:
                # One failing model must not stop the others.
                logger.warning(f"Error sending data to CNN model {model_name}: {e}")

        if hasattr(self.orchestrator, 'update_cnn_training'):
            try:
                self.orchestrator.update_cnn_training(training_data)
                success_count += 1
                logger.debug("Sent training data to orchestrator CNN training")
            except Exception as e:
                logger.warning(f"Error sending data to orchestrator CNN: {e}")

        return success_count > 0

    except Exception as e:
        logger.error(f"Error sending data to CNN models: {e}")
        return False
def _send_data_to_rl_models(self, training_data: Dict[str, Any]) -> bool:
    """Offer the prepared training data to every RL-capable consumer.

    A registry model qualifies when it has an ``add_experience`` method or
    its name contains ``rl`` or ``dqn``. The experience list is built once,
    on first use, and shared by all qualifying models. If no experiences
    could be built the registry models are skipped. The orchestrator, when
    it has ``update_rl_training``, receives the unformatted
    ``training_data``.

    Args:
        training_data: Data set produced by ``_prepare_training_data``.

    Returns:
        ``True`` if at least one model or the orchestrator accepted the data.
    """
    try:
        success_count = 0
        rl_experiences = None  # built lazily so it is computed at most once

        for model_name, model in self.model_registry.models.items():
            lowered_name = model_name.lower()
            takes_experience = hasattr(model, 'add_experience')
            # NOTE(review): the substring test also matches unrelated names
            # that merely contain "rl".
            if not (takes_experience or 'rl' in lowered_name or 'dqn' in lowered_name):
                continue
            try:
                if rl_experiences is None:
                    rl_experiences = self._format_data_for_rl(training_data)

                if not rl_experiences:
                    logger.debug("No RL experiences available; skipping registry models")
                    break

                if takes_experience:
                    # Feed the replay buffer one transition at a time.
                    for experience in rl_experiences:
                        model.add_experience(*experience)
                    success_count += 1
                    logger.debug(f"Sent {len(rl_experiences)} experiences to RL model: {model_name}")
                elif hasattr(model, 'update_replay_buffer'):
                    model.update_replay_buffer(rl_experiences)
                    success_count += 1
                    logger.debug(f"Updated RL replay buffer: {model_name}")
            except Exception as e:
                # One failing model must not stop the others.
                logger.warning(f"Error sending data to RL model {model_name}: {e}")

        if hasattr(self.orchestrator, 'update_rl_training'):
            try:
                self.orchestrator.update_rl_training(training_data)
                success_count += 1
                logger.debug("Sent training data to orchestrator RL training")
            except Exception as e:
                logger.warning(f"Error sending data to orchestrator RL: {e}")

        return success_count > 0

    except Exception as e:
        logger.error(f"Error sending data to RL models: {e}")
        return False
def _format_data_for_cnn(self, training_data: Dict[str, Any],
                         sequence_length: int = 60, horizon: int = 5) -> Dict[str, Any]:
    """Build sliding-window sequences and direction labels for CNN training.

    Each sample is the ``sequence_length`` bars preceding bar ``i``, min-max
    scaled per feature to 0..1 over the whole frame (a constant column maps
    to 0). Its label is 1 when the close ``horizon`` bars after ``i`` is
    above the close at ``i``, else 0.

    Samples are generated only where the full look-ahead exists. Clamping
    the look-ahead to the last bar would compare the final bar with itself
    and always label it 0.

    Args:
        training_data: Dict with an ``ohlcv`` DataFrame containing the
            feature columns, plus ``symbol`` and ``timestamp``.
        sequence_length: Bars per input sequence (default: 60 one-second bars).
        horizon: Look-ahead in bars used for the label.

    Returns:
        Dict with ``sequences`` of shape ``(samples, sequence_length, 8)``,
        integer ``targets``, ``feature_names``, ``sequence_length``,
        ``symbol`` and ``timestamp``. An empty dict is returned when there
        are too few bars for a single sample or when an error occurs.
    """
    feature_names = ['open', 'high', 'low', 'close', 'volume', 'sma_20', 'sma_50', 'rsi']
    try:
        # Local import: numpy may not be bound as ``np`` at module level.
        import numpy as np

        if sequence_length < 1 or horizon < 1:
            raise ValueError("sequence_length and horizon must be positive")

        ohlcv = training_data['ohlcv']
        features = ohlcv[feature_names].to_numpy(dtype=float)
        closes = ohlcv['close'].to_numpy(dtype=float)

        # Bar indices that have a full history behind and a full horizon ahead.
        anchors = np.arange(sequence_length, len(features) - horizon)
        if anchors.size == 0:
            logger.debug("Not enough bars to build CNN training sequences")
            return {}

        # Per-feature min-max scaling; a zero range is replaced by 1 so that
        # constant columns become 0 instead of dividing by zero.
        lows = features.min(axis=0)
        spans = features.max(axis=0) - lows
        spans[spans == 0] = 1.0
        scaled = (features - lows) / spans

        sequences = np.stack([scaled[i - sequence_length:i] for i in anchors])
        targets = (closes[anchors + horizon] > closes[anchors]).astype(int)

        return {
            'sequences': sequences,
            'targets': targets,
            'feature_names': feature_names,
            'sequence_length': sequence_length,
            'symbol': training_data['symbol'],
            'timestamp': training_data['timestamp']
        }

    except Exception as e:
        logger.error(f"Error formatting data for CNN: {e}")
        return {}
def _format_data_for_rl(self, training_data: Dict[str, Any],
                        window: int = 10, threshold: float = 0.001) -> List[Tuple]:
    """Build ``(state, action, reward, next_state, done)`` tuples from bars.

    For each bar ``i`` the action is chosen in hindsight from the relative
    close-to-close move between bar ``i`` and bar ``i + 1``: BUY (1) above
    ``+threshold``, SELL (2) below ``-threshold``, otherwise HOLD (0). The
    reward is the absolute move in percent (0 for HOLD).

    The state is the flattened ``close``/``volume``/``rsi`` values of the
    ``window`` bars ending at bar ``i`` (the bar the decision is made on).
    The next state is the same window shifted by one, so it ends at bar
    ``i + 1`` where the reward is realised.

    Args:
        training_data: Dict with an ``ohlcv`` DataFrame containing ``close``,
            ``volume`` and ``rsi`` columns.
        window: Number of bars per state.
        threshold: Minimum relative price move that counts as BUY/SELL.

    Returns:
        List of experience tuples; empty on error or when there are too few
        bars. Bars whose close is 0 are skipped because the relative move is
        undefined there.
    """
    try:
        ohlcv = training_data['ohlcv']
        # Pull the columns out once; per-row .iloc access inside the loop is
        # far slower and yields the same values.
        state_values = ohlcv[['close', 'volume', 'rsi']].to_numpy()
        closes = ohlcv['close'].to_numpy()

        experiences = []
        for i in range(window, len(closes) - 1):
            current_price = closes[i]
            if current_price == 0:
                continue
            price_change = (closes[i + 1] - current_price) / current_price

            # Action: 0=HOLD, 1=BUY, 2=SELL
            if price_change > threshold:
                action = 1
                reward = float(price_change * 100)   # gain from going long
            elif price_change < -threshold:
                action = 2
                reward = float(-price_change * 100)  # gain from going short
            else:
                action = 0
                reward = 0.0

            experiences.append((
                state_values[i - window + 1:i + 1].flatten(),  # state, ends at bar i
                action,
                reward,
                state_values[i - window + 2:i + 2].flatten(),  # next_state, ends at bar i+1
                False                                          # done (not terminal)
            ))

        return experiences

    except Exception as e:
        logger.error(f"Error formatting data for RL: {e}")
        return []
def _update_training_metrics(self, cnn_success: bool, rl_success: bool):
    """Record one training hand-off in ``self.training_stats``.

    The stats dict is created on first use. Each call stamps the time, bumps
    the session counter, bumps the per-family counter for every family that
    accepted the data, and stores the current tick-cache size.

    Args:
        cnn_success: Whether the CNN side accepted the data.
        rl_success: Whether the RL side accepted the data.
    """
    try:
        now = datetime.now()

        try:
            stats = self.training_stats
        except AttributeError:
            # First call: start every counter from zero.
            stats = self.training_stats = {
                'last_training_time': now,
                'total_training_sessions': 0,
                'cnn_training_count': 0,
                'rl_training_count': 0,
                'training_data_points': 0
            }

        stats['last_training_time'] = now
        stats['total_training_sessions'] += 1

        outcomes = (
            (cnn_success, 'cnn_training_count'),
            (rl_success, 'rl_training_count'),
        )
        for succeeded, counter in outcomes:
            if succeeded:
                stats[counter] += 1

        stats['training_data_points'] = len(self.tick_cache)

        logger.debug(f"Training metrics updated: {self.training_stats}")

    except Exception as e:
        logger.warning(f"Error updating training metrics: {e}")
def get_tick_cache_for_training(self) -> List[Dict]:
    """Return a list copy of the tick cache for external training systems.

    Returns:
        A new list with the cached ticks in cache order, or an empty list if
        the cache cannot be read.
    """
    try:
        snapshot = [*self.tick_cache]
    except Exception as e:
        logger.error(f"Error getting tick cache for training: {e}")
        return []
    return snapshot
def start_continuous_training(self):
    """Launch the background training loop unless one is already running.

    Sets ``training_active`` and starts ``_continuous_training_loop`` on a
    daemon thread stored in ``training_thread``. Errors are logged and not
    raised.
    """
    try:
        already_running = hasattr(self, 'training_thread') and self.training_thread.is_alive()
        if already_running:
            logger.info("Continuous training already running")
            return

        # The flag is raised before the thread starts because the loop
        # checks it on entry.
        self.training_active = True
        worker = Thread(target=self._continuous_training_loop, daemon=True)
        self.training_thread = worker
        worker.start()

        logger.info("Continuous training started")

    except Exception as e:
        logger.error(f"Error starting continuous training: {e}")
def _continuous_training_loop(self, interval: float = 30.0, error_delay: float = 60.0,
                              poll_step: float = 0.5):
    """Feed the models periodically until ``training_active`` is cleared.

    Each cycle sends the tick cache to the models when it holds at least 500
    ticks, then waits ``interval`` seconds (``error_delay`` seconds after a
    failed cycle). The wait is done in slices of at most ``poll_step``
    seconds and ends as soon as ``training_active`` becomes false, so a stop
    request takes effect within about ``poll_step`` seconds instead of after
    a full sleep.

    Args:
        interval: Seconds between training cycles.
        error_delay: Seconds to back off after a cycle raised.
        poll_step: Longest uninterrupted sleep while waiting.
    """

    def still_active():
        return getattr(self, 'training_active', False)

    def wait(seconds):
        # Interruptible sleep: returns early once training is deactivated.
        deadline = time.monotonic() + seconds
        while still_active():
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(min(poll_step, remaining))

    logger.info("Continuous training loop started")

    while still_active():
        try:
            if len(self.tick_cache) >= 500:  # need sufficient data
                if self.send_training_data_to_models():
                    logger.debug("Training data sent to models")
        except Exception as e:
            logger.error(f"Error in continuous training loop: {e}")
            wait(error_delay)  # back off longer after a failure
        else:
            wait(interval)
def stop_continuous_training(self):
    """Ask the background training loop to stop and wait briefly for it.

    Clears ``training_active`` and joins ``training_thread`` for up to five
    seconds. "Stopped" is logged only once the thread has actually exited
    (or none was started). If the thread is still running after the timeout
    a warning is logged instead. Errors are logged and not raised.
    """
    try:
        self.training_active = False

        worker = getattr(self, 'training_thread', None)
        if worker is not None:
            worker.join(timeout=5)
            if worker.is_alive():
                # The loop only notices the cleared flag between waits.
                logger.warning("Continuous training thread still running after 5s; it will exit once it observes the stop flag")
                return

        logger.info("Continuous training stopped")

    except Exception as e:
        logger.error(f"Error stopping continuous training: {e}")
# Convenience function for integration
def create_dashboard(data_provider: DataProvider = None, orchestrator: TradingOrchestrator = None) -> TradingDashboard:
"""Create and return a trading dashboard instance"""