420 lines
16 KiB
Python
420 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Realtime RL Training with TensorBoard and Web UI Monitoring
|
|
|
|
This script runs RL training with:
|
|
- TensorBoard monitoring for training metrics
|
|
- Web UI for real-time trading visualization
|
|
- Real market data integration
|
|
- PnL tracking and performance analysis
|
|
"""
|
|
|
|
import asyncio
|
|
import threading
|
|
import time
|
|
import logging
|
|
import argparse
|
|
from datetime import datetime
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Let the script import the repository's own packages (e.g. ``core`` and
# ``training``) when it is executed directly from the project directory.
project_root = Path(__file__).parent
sys.path[:0] = [str(project_root)]
|
|
|
|
from core.config import setup_logging, get_config
|
|
from core.data_provider import DataProvider
|
|
from training.rl_trainer import RLTrainer
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
class RealtimeRLTrainer:
    """Realtime RL Trainer with TensorBoard and Web UI.

    Wraps an ``RLTrainer`` (which owns the TensorBoard writer) and can serve a
    Dash dashboard from a background thread. The dashboard polls this object's
    session state and the trainer's metric history every 2 seconds.
    """

    def __init__(self, symbol="ETH/USDT", initial_balance=1000.0):
        """Create the data provider and RL trainer for ``symbol``.

        Args:
            symbol: Trading pair handed to ``DataProvider``.
            initial_balance: Starting balance; used as the PnL baseline and as
                the reference line on the balance chart.
        """
        self.symbol = symbol
        self.initial_balance = initial_balance

        # Initialize data provider
        self.data_provider = DataProvider(
            symbols=[symbol],
            timeframes=['1s', '1m', '5m', '15m', '1h']
        )

        # Initialize RL trainer with TensorBoard
        self.rl_trainer = RLTrainer(self.data_provider)

        # Training state (read by the dashboard refresh callback)
        self.current_episode = 0
        self.session_trades = []
        self.session_balance = initial_balance
        self.session_pnl = 0.0
        self.training_active = False

        # Web dashboard
        self.dashboard = None
        self.dashboard_thread = None
        # Port the dashboard thread was started on; None until started.
        self.web_port = None

        logger.info(f"RealtimeRLTrainer initialized for {symbol}")
        logger.info(f"TensorBoard logs: {self.rl_trainer.tensorboard_dir}")

    def setup_web_dashboard(self, port=8051):
        """Build the Dash app (layout + refresh callback) into ``self.dashboard``.

        The app is only constructed here; ``start_web_dashboard`` runs it. On
        failure the error is logged and ``self.dashboard`` is left as ``None``.

        Args:
            port: Port the dashboard is meant to run on (only used for logging).
        """
        try:
            import dash
            from dash import dcc, html, Input, Output
            # NOTE: plotly.express is deliberately not imported: it was unused
            # and it requires pandas, so importing it could disable the
            # dashboard on machines without pandas.

            def stat_card(title, value_id, color):
                # One tile of the status row: a heading plus a value paragraph
                # that the refresh callback fills in.
                return html.Div([
                    html.H3(title, style={'color': color}),
                    html.P(id=value_id, style={'fontSize': 18})
                ], className='three columns')

            def graph_cell(graph_id):
                # Half-width container holding a single chart.
                return html.Div([dcc.Graph(id=graph_id)], className='six columns')

            # Create Dash app
            app = dash.Dash(__name__)

            app.layout = html.Div([
                html.H1(f"RL Training Monitor - {self.symbol}",
                        style={'textAlign': 'center', 'color': '#2c3e50'}),

                # Drives the refresh callback every 2 seconds
                dcc.Interval(
                    id='interval-component',
                    interval=2000,
                    n_intervals=0
                ),

                # Status row
                html.Div([
                    stat_card("Training Status", 'training-status', '#34495e'),
                    stat_card("Current Episode", 'current-episode', '#34495e'),
                    stat_card("Session Balance", 'session-balance', '#27ae60'),
                    stat_card("Session PnL", 'session-pnl', '#e74c3c'),
                ], className='row', style={'margin': '20px'}),

                # Charts rows
                html.Div([
                    graph_cell('rewards-chart'),
                    graph_cell('balance-chart'),
                ], className='row'),

                html.Div([
                    graph_cell('trades-chart'),
                    graph_cell('win-rate-chart'),
                ], className='row'),

                # TensorBoard link
                html.Div([
                    html.H3("TensorBoard Monitoring"),
                    html.A("Open TensorBoard",
                           href="http://localhost:6006",
                           target="_blank",
                           style={'fontSize': 16, 'color': '#3498db'})
                ], style={'textAlign': 'center', 'margin': '20px'})
            ])

            @app.callback(
                [Output('training-status', 'children'),
                 Output('current-episode', 'children'),
                 Output('session-balance', 'children'),
                 Output('session-pnl', 'children'),
                 Output('rewards-chart', 'figure'),
                 Output('balance-chart', 'figure'),
                 Output('trades-chart', 'figure'),
                 Output('win-rate-chart', 'figure')],
                [Input('interval-component', 'n_intervals')]
            )
            def update_dashboard(n):
                # Rebuild every status value and chart from current state;
                # the tick count ``n`` itself is not needed.
                status = "TRAINING" if self.training_active else "IDLE"
                episode = f"{self.current_episode}"
                balance = f"${self.session_balance:.2f}"
                pnl = f"${self.session_pnl:.2f}"

                return (
                    status, episode, balance, pnl,
                    self._create_rewards_chart(),
                    self._create_balance_chart(),
                    self._create_trades_chart(),
                    self._create_win_rate_chart(),
                )

            self.dashboard = app
            logger.info(f"Web dashboard created for port {port}")

        except ImportError as e:
            # Dash (or one of its dependencies) is not installed.
            logger.error(f"Error setting up web dashboard: {e}")
            self.dashboard = None
        except Exception as e:
            logger.exception(f"Error setting up web dashboard: {e}")
            self.dashboard = None

    def _metric_history(self, name):
        """Return a list snapshot of ``self.rl_trainer.<name>``.

        Returns an empty list when the attribute is missing or ``None``, so a
        trainer that does not track a given metric yields a "No data yet"
        chart instead of breaking the whole dashboard refresh.
        """
        values = getattr(self.rl_trainer, name, None)
        return list(values) if values is not None else []

    @staticmethod
    def _no_data_figure():
        """Return an empty figure with a centred "No data yet" annotation."""
        import plotly.graph_objects as go

        fig = go.Figure()
        fig.add_annotation(text="No data yet", x=0.5, y=0.5, xref="paper", yref="paper")
        return fig

    def _create_rewards_chart(self):
        """Create rewards chart: per-episode reward plus the trainer's moving average."""
        import plotly.graph_objects as go

        rewards = self._metric_history('episode_rewards')
        if not rewards:
            fig = self._no_data_figure()
        else:
            fig = go.Figure()
            fig.add_trace(go.Scatter(
                y=rewards,
                mode='lines',
                name='Episode Rewards',
                line=dict(color='#3498db')
            ))

            # Add moving average if the trainer has produced any
            avg_rewards = self._metric_history('avg_rewards')
            if avg_rewards:
                fig.add_trace(go.Scatter(
                    y=avg_rewards,
                    mode='lines',
                    name='Moving Average',
                    line=dict(color='#e74c3c', width=2)
                ))

        fig.update_layout(title="Episode Rewards", xaxis_title="Episode", yaxis_title="Reward")
        return fig

    def _create_balance_chart(self):
        """Create balance chart, with a dashed reference line at the initial balance."""
        import plotly.graph_objects as go

        balances = self._metric_history('episode_balances')
        if not balances:
            fig = self._no_data_figure()
        else:
            fig = go.Figure()
            fig.add_trace(go.Scatter(
                y=balances,
                mode='lines',
                name='Balance',
                line=dict(color='#27ae60')
            ))

            # Add initial balance line
            fig.add_hline(y=self.initial_balance, line_dash="dash",
                          annotation_text="Initial Balance")

        fig.update_layout(title="Portfolio Balance", xaxis_title="Episode", yaxis_title="Balance ($)")
        return fig

    def _create_trades_chart(self):
        """Create trades per episode bar chart."""
        import plotly.graph_objects as go

        trades = self._metric_history('episode_trades')
        if not trades:
            fig = self._no_data_figure()
        else:
            fig = go.Figure()
            fig.add_trace(go.Bar(
                y=trades,
                name='Trades per Episode',
                marker_color='#f39c12'
            ))

        fig.update_layout(title="Trades per Episode", xaxis_title="Episode", yaxis_title="Number of Trades")
        return fig

    def _create_win_rate_chart(self):
        """Create win rate chart, with a dashed reference line at 0.5."""
        import plotly.graph_objects as go

        win_rates = self._metric_history('win_rates')
        if not win_rates:
            fig = self._no_data_figure()
        else:
            fig = go.Figure()
            fig.add_trace(go.Scatter(
                y=win_rates,
                mode='lines+markers',
                name='Win Rate',
                line=dict(color='#9b59b6')
            ))

            # Add 50% line
            fig.add_hline(y=0.5, line_dash="dash",
                          annotation_text="Break Even")

        fig.update_layout(title="Win Rate", xaxis_title="Evaluation", yaxis_title="Win Rate")
        return fig

    def start_web_dashboard(self, port=8051):
        """Serve the dashboard on ``port`` from a daemon thread.

        Builds the app first if necessary. If the app cannot be built, a
        warning is logged and nothing is started.
        """
        if self.dashboard is None:
            self.setup_web_dashboard(port)

        if self.dashboard is None:
            logger.warning("Dashboard not available")
            return

        app = self.dashboard

        def run_dashboard():
            try:
                # Newer Dash releases provide ``run`` (``run_server`` was later
                # removed); older ones only provide ``run_server``.
                run_app = getattr(app, 'run', None) or app.run_server
                run_app(port=port, debug=False, use_reloader=False)
            except Exception as e:
                logger.error(f"Error running dashboard: {e}")

        self.dashboard_thread = threading.Thread(target=run_dashboard, daemon=True)
        self.dashboard_thread.start()
        self.web_port = port
        logger.info(f"Web dashboard started on http://localhost:{port}")

    def _record_episode(self, episode, results, agent):
        """Update session balance/PnL from one episode and log its metrics to TensorBoard.

        Args:
            episode: Episode index.
            results: Dict returned by ``RLTrainer.run_episode``; must contain
                ``reward``, ``balance``, ``pnl_percentage``, ``steps``,
                ``trades`` and ``pnl``.
            agent: The trained agent; its ``epsilon`` (and ``memory`` if
                present) are logged.
        """
        self.session_balance = results.get('balance', self.initial_balance)
        self.session_pnl = self.session_balance - self.initial_balance

        self.rl_trainer.log_episode_metrics(episode, {
            'total_reward': results['reward'],
            'final_balance': results['balance'],
            'total_return': results['pnl_percentage'],
            'steps': results['steps'],
            'total_trades': results['trades'],
            # Per-episode win indicator: 1.0 if this episode's PnL was positive.
            'win_rate': 1.0 if results['pnl'] > 0 else 0.0,
            'epsilon': agent.epsilon,
            'memory_size': len(agent.memory) if hasattr(agent, 'memory') else 0
        })

    async def train_realtime(self, episodes=100, evaluation_interval=10, log_interval=10):
        """Run ``episodes`` training episodes with TensorBoard and dashboard monitoring.

        ``training_active`` is True for the duration of training and is reset
        even if setup or an episode raises.

        Args:
            episodes: Number of training episodes.
            evaluation_interval: Evaluate the agent every this many episodes
                (never at episode 0); values < 1 disable evaluation.
            log_interval: Log a progress line every this many episodes;
                values < 1 disable progress logging.

        Returns:
            dict with ``episodes``, ``final_balance``, ``final_pnl`` and
            ``model_path``.
        """
        logger.info(f"Starting realtime RL training for {episodes} episodes")
        logger.info("TensorBoard: http://localhost:6006")
        if self.web_port is not None:
            logger.info(f"Web UI: http://localhost:{self.web_port}")

        self.training_active = True
        try:
            # Setup environment and agent (only the agent is used here)
            _environment, agent = self.rl_trainer.setup_environment_and_agent()

            for episode in range(episodes):
                self.current_episode = episode

                episode_start = time.perf_counter()
                results = self.rl_trainer.run_episode(episode, training=True)
                episode_time = time.perf_counter() - episode_start

                self._record_episode(episode, results, agent)

                if log_interval > 0 and episode % log_interval == 0:
                    logger.info(
                        f"Episode {episode}/{episodes} - "
                        f"Reward: {results['reward']:.4f}, "
                        f"Balance: ${results['balance']:.2f}, "
                        f"PnL: {results['pnl_percentage']:.2f}%, "
                        f"Trades: {results['trades']}, "
                        f"Time: {episode_time:.2f}s"
                    )

                if evaluation_interval > 0 and episode > 0 and episode % evaluation_interval == 0:
                    eval_results = self.rl_trainer.evaluate_agent(num_episodes=3)
                    logger.info(
                        f"Evaluation - Avg Reward: {eval_results['avg_reward']:.4f}, "
                        f"Win Rate: {eval_results['win_rate']:.2%}"
                    )

                # Yield to the event loop between (blocking) episodes.
                await asyncio.sleep(0.1)
        finally:
            # Never leave the dashboard reporting TRAINING after a failure.
            self.training_active = False

        logger.info("Training completed!")

        # Save final model; create the target directory if it does not exist yet.
        save_path = f"models/rl/realtime_agent_{int(time.time())}.pt"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        agent.save(save_path)
        logger.info(f"Model saved: {save_path}")

        return {
            'episodes': episodes,
            'final_balance': self.session_balance,
            'final_pnl': self.session_pnl,
            'model_path': save_path
        }
|
|
|
|
async def main():
    """Entry point: parse CLI options, start monitoring, train, then keep serving.

    After training finishes, the coroutine keeps running so the dashboard stays
    reachable until the user presses Ctrl+C. Errors raised during setup or
    training are logged with their traceback instead of propagating.
    """
    parser = argparse.ArgumentParser(description='Realtime RL Training with Monitoring')
    parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol')
    parser.add_argument('--episodes', type=int, default=50, help='Number of episodes')
    parser.add_argument('--balance', type=float, default=1000.0, help='Initial balance')
    parser.add_argument('--web-port', type=int, default=8051, help='Web dashboard port')
    parser.add_argument('--eval-interval', type=int, default=10,
                        help='Evaluate the agent every N episodes')

    args = parser.parse_args()
    if args.eval_interval < 1:
        parser.error("--eval-interval must be a positive integer")

    # Setup logging
    setup_logging()

    logger.info("=" * 60)
    logger.info("REALTIME RL TRAINING WITH MONITORING")
    logger.info(f"Symbol: {args.symbol}")
    logger.info(f"Episodes: {args.episodes}")
    logger.info(f"Initial Balance: ${args.balance:.2f}")
    logger.info("=" * 60)

    try:
        trainer = RealtimeRLTrainer(
            symbol=args.symbol,
            initial_balance=args.balance
        )

        # Start web dashboard (runs in a daemon thread)
        trainer.start_web_dashboard(port=args.web_port)

        # Give the dashboard server a moment to start
        await asyncio.sleep(2)

        logger.info("MONITORING READY!")
        logger.info("TensorBoard: http://localhost:6006")
        logger.info(f"Web Dashboard: http://localhost:{args.web_port}")
        logger.info("=" * 60)

        results = await trainer.train_realtime(
            episodes=args.episodes,
            evaluation_interval=args.eval_interval
        )

        logger.info("Training Results:")
        logger.info(f"  Final Balance: ${results['final_balance']:.2f}")
        logger.info(f"  Final PnL: ${results['final_pnl']:.2f}")
        logger.info(f"  Model Saved: {results['model_path']}")

        # Keep the process (and thus the dashboard thread) alive for monitoring
        logger.info("Training complete. Press Ctrl+C to exit monitoring.")
        while True:
            await asyncio.sleep(1)

    except (KeyboardInterrupt, asyncio.CancelledError):
        # On Python 3.11+ asyncio.run() converts the first Ctrl+C into a
        # cancellation of this task, so CancelledError (not KeyboardInterrupt)
        # is what arrives here; older versions raise KeyboardInterrupt directly.
        logger.info("Training stopped by user")
    except Exception as e:
        logger.exception(f"Error in training: {e}")
|
|
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Reached when an interrupt is not absorbed inside main() (e.g. a second
        # Ctrl+C during shutdown): exit with the conventional SIGINT status
        # (128 + 2) instead of printing a traceback.
        sys.exit(130)