gogo2/train_realtime_with_tensorboard.py
Dobromir Popov 310f3c5bf9 wip
2025-05-24 09:59:11 +03:00

476 lines
18 KiB
Python

#!/usr/bin/env python3
"""
Realtime RL Training with TensorBoard and Web UI Monitoring
This script runs RL training with:
- TensorBoard monitoring for training metrics
- Web UI for real-time trading visualization
- Real market data integration
- PnL tracking and performance analysis
"""
import asyncio
import threading
import time
import logging
import argparse
from datetime import datetime
import os
import sys
from pathlib import Path
# Add project path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from core.config import setup_logging, get_config
from core.data_provider import DataProvider
from training.rl_trainer import RLTrainer
from torch.utils.tensorboard import SummaryWriter
logger = logging.getLogger(__name__)
class RealtimeRLTrainer:
    """Run RL training while exposing progress through a Dash web UI.

    The class wires a ``DataProvider`` into an ``RLTrainer``, drives the
    episode loop itself (``train_realtime``) and serves a small Dash app
    from a daemon thread whose callback reads the trainer's metric lists
    (``episode_rewards``, ``avg_rewards``, ``episode_balances``,
    ``episode_trades``, ``win_rates``).
    """

    def __init__(self, symbol="ETH/USDT", initial_balance=1000.0):
        """Create the data provider and RL trainer for ``symbol``.

        Args:
            symbol: Trading pair passed to the data provider.
            initial_balance: Starting balance; also the baseline against
                which the session PnL is computed.
        """
        self.symbol = symbol
        self.initial_balance = initial_balance

        # Initialize data provider
        self.data_provider = DataProvider(
            symbols=[symbol],
            timeframes=['1s', '1m', '5m', '15m', '1h']
        )

        # Initialize RL trainer with TensorBoard
        self.rl_trainer = RLTrainer(self.data_provider)

        # Training state (read by the dashboard callback)
        self.current_episode = 0
        self.session_trades = []
        self.session_balance = initial_balance
        self.session_pnl = 0.0
        self.training_active = False

        # Web dashboard
        self.dashboard = None
        self.dashboard_thread = None
        # Port the dashboard was last started on; None until it is started.
        self.dashboard_port = None

        logger.info(f"RealtimeRLTrainer initialized for {symbol}")
        logger.info(f"TensorBoard logs: {self.rl_trainer.tensorboard_dir}")

    def setup_web_dashboard(self, port=8051):
        """Build the Dash app and store it in ``self.dashboard``.

        Any failure (including Dash not being installed) is logged and
        leaves ``self.dashboard`` set to ``None``.

        Args:
            port: Only used for the log message; the server is bound in
                ``start_web_dashboard``.
        """
        try:
            import dash
            from dash import dcc, html, Input, Output

            def status_card(title, element_id, color):
                # One quarter-width status tile: heading plus a live value.
                return html.Div([
                    html.H3(title, style={'color': color}),
                    html.P(id=element_id, style={'fontSize': 18})
                ], className='three columns')

            def graph_cell(graph_id):
                # One half-width chart slot.
                return html.Div([
                    dcc.Graph(id=graph_id)
                ], className='six columns')

            # Create Dash app
            app = dash.Dash(__name__)

            # Layout
            app.layout = html.Div([
                html.H1(f"RL Training Monitor - {self.symbol}",
                        style={'textAlign': 'center', 'color': '#2c3e50'}),
                # Refresh interval
                dcc.Interval(
                    id='interval-component',
                    interval=2000,  # Update every 2 seconds
                    n_intervals=0
                ),
                # Status row
                html.Div([
                    status_card("Training Status", 'training-status', '#34495e'),
                    status_card("Current Episode", 'current-episode', '#34495e'),
                    status_card("Session Balance", 'session-balance', '#27ae60'),
                    status_card("Session PnL", 'session-pnl', '#e74c3c'),
                ], className='row', style={'margin': '20px'}),
                # Charts rows
                html.Div([
                    graph_cell('rewards-chart'),
                    graph_cell('balance-chart'),
                ], className='row'),
                html.Div([
                    graph_cell('trades-chart'),
                    graph_cell('win-rate-chart'),
                ], className='row'),
                # TensorBoard link
                html.Div([
                    html.H3("TensorBoard Monitoring"),
                    html.A("Open TensorBoard",
                           href="http://localhost:6006",
                           target="_blank",
                           style={'fontSize': 16, 'color': '#3498db'})
                ], style={'textAlign': 'center', 'margin': '20px'})
            ])

            # Callbacks
            @app.callback(
                [Output('training-status', 'children'),
                 Output('current-episode', 'children'),
                 Output('session-balance', 'children'),
                 Output('session-pnl', 'children'),
                 Output('rewards-chart', 'figure'),
                 Output('balance-chart', 'figure'),
                 Output('trades-chart', 'figure'),
                 Output('win-rate-chart', 'figure')],
                [Input('interval-component', 'n_intervals')]
            )
            def update_dashboard(n):
                # Status updates
                status = "TRAINING" if self.training_active else "IDLE"
                episode = f"{self.current_episode}"
                balance = f"${self.session_balance:.2f}"
                pnl = f"${self.session_pnl:.2f}"
                return (
                    status, episode, balance, pnl,
                    self._create_rewards_chart(),
                    self._create_balance_chart(),
                    self._create_trades_chart(),
                    self._create_win_rate_chart(),
                )

            self.dashboard = app
            logger.info(f"Web dashboard created for port {port}")
        except Exception as e:
            logger.error(f"Error setting up web dashboard: {e}")
            self.dashboard = None

    @staticmethod
    def _new_figure(has_data):
        """Return an empty figure, annotated "No data yet" when there is no data."""
        import plotly.graph_objects as go

        fig = go.Figure()
        if not has_data:
            fig.add_annotation(text="No data yet", x=0.5, y=0.5, xref="paper", yref="paper")
        return fig

    def _create_rewards_chart(self):
        """Create the per-episode rewards chart with its moving average."""
        import plotly.graph_objects as go

        rewards = self.rl_trainer.episode_rewards
        fig = self._new_figure(bool(rewards))
        if rewards:
            fig.add_trace(go.Scatter(
                y=rewards,
                mode='lines',
                name='Episode Rewards',
                line=dict(color='#3498db')
            ))
            # Add moving average if enough data
            if len(self.rl_trainer.avg_rewards) > 0:
                fig.add_trace(go.Scatter(
                    y=self.rl_trainer.avg_rewards,
                    mode='lines',
                    name='Moving Average',
                    line=dict(color='#e74c3c', width=2)
                ))
        fig.update_layout(title="Episode Rewards", xaxis_title="Episode", yaxis_title="Reward")
        return fig

    def _create_balance_chart(self):
        """Create the per-episode balance chart with the initial balance marked."""
        import plotly.graph_objects as go

        balances = self.rl_trainer.episode_balances
        fig = self._new_figure(bool(balances))
        if balances:
            fig.add_trace(go.Scatter(
                y=balances,
                mode='lines',
                name='Balance',
                line=dict(color='#27ae60')
            ))
            # Add initial balance line
            fig.add_hline(y=self.initial_balance, line_dash="dash",
                          annotation_text="Initial Balance")
        fig.update_layout(title="Portfolio Balance", xaxis_title="Episode", yaxis_title="Balance ($)")
        return fig

    def _create_trades_chart(self):
        """Create the trades-per-episode bar chart."""
        import plotly.graph_objects as go

        trades = self.rl_trainer.episode_trades
        fig = self._new_figure(bool(trades))
        if trades:
            fig.add_trace(go.Bar(
                y=trades,
                name='Trades per Episode',
                marker_color='#f39c12'
            ))
        fig.update_layout(title="Trades per Episode", xaxis_title="Episode", yaxis_title="Number of Trades")
        return fig

    def _create_win_rate_chart(self):
        """Create the win-rate chart with a 50% reference line."""
        import plotly.graph_objects as go

        win_rates = self.rl_trainer.win_rates
        fig = self._new_figure(bool(win_rates))
        if win_rates:
            fig.add_trace(go.Scatter(
                y=win_rates,
                mode='lines+markers',
                name='Win Rate',
                line=dict(color='#9b59b6')
            ))
            # Add 50% line
            fig.add_hline(y=0.5, line_dash="dash",
                          annotation_text="Break Even")
        fig.update_layout(title="Win Rate", xaxis_title="Evaluation", yaxis_title="Win Rate")
        return fig

    def start_web_dashboard(self, port=8051):
        """Start the web dashboard in a daemon background thread.

        Builds the Dash app first if it does not exist yet. If the app
        cannot be built, a warning is logged and nothing is started.

        Args:
            port: TCP port to serve the dashboard on.
        """
        if self.dashboard is None:
            self.setup_web_dashboard(port)
        if self.dashboard is None:
            logger.warning("Dashboard not available")
            return

        app = self.dashboard
        # Remember the port so train_realtime can log the real URL.
        self.dashboard_port = port

        def run_dashboard():
            try:
                # Prefer run(); fall back to run_server() for Dash
                # releases that do not provide run() yet.
                serve = getattr(app, "run", None) or app.run_server
                serve(port=port, debug=False, use_reloader=False)
            except Exception as e:
                logger.error(f"Error running dashboard: {e}")

        self.dashboard_thread = threading.Thread(target=run_dashboard, daemon=True)
        self.dashboard_thread.start()
        logger.info(f"Web dashboard started on http://localhost:{port}")

    async def train_realtime(self, episodes=100, evaluation_interval=10):
        """Run the training loop, updating session state after each episode.

        Args:
            episodes: Number of training episodes to run.
            evaluation_interval: Evaluate the agent every this many
                episodes (never at episode 0). A value of 0 or ``None``
                disables evaluation.

        Returns:
            dict with ``episodes``, ``final_balance``, ``final_pnl`` and
            ``model_path`` (where the agent was saved).

        Raises:
            Whatever the underlying trainer raises; ``training_active`` is
            reset to ``False`` before the exception propagates.
        """
        web_port = self.dashboard_port or 8051
        logger.info(f"Starting realtime RL training for {episodes} episodes")
        logger.info("TensorBoard: http://localhost:6006")
        logger.info(f"Web UI: http://localhost:{web_port}")

        self.training_active = True
        try:
            # Setup environment and agent
            environment, agent = self.rl_trainer.setup_environment_and_agent()
            # Assign to trainer instance
            self.rl_trainer.environment = environment
            self.rl_trainer.agent = agent

            # Training loop
            for episode in range(episodes):
                self.current_episode = episode

                # Run episode
                episode_start = time.time()
                results = self.rl_trainer.run_episode(episode, training=True)
                episode_time = time.time() - episode_start

                # Update session tracking
                self.session_balance = results.get('balance', self.initial_balance)
                self.session_pnl = self.session_balance - self.initial_balance

                # Log episode metrics to TensorBoard
                self.rl_trainer.log_episode_metrics(episode, {
                    'total_reward': results['reward'],
                    'final_balance': results['balance'],
                    'total_return': results['pnl_percentage'],
                    'steps': results['steps'],
                    'total_trades': results['trades'],
                    'win_rate': 1.0 if results['pnl'] > 0 else 0.0,
                    'epsilon': agent.epsilon,
                    'memory_size': len(agent.memory) if hasattr(agent, 'memory') else 0
                })

                # Log progress
                if episode % 10 == 0:
                    logger.info(
                        f"Episode {episode}/{episodes} - "
                        f"Reward: {results['reward']:.4f}, "
                        f"Balance: ${results['balance']:.2f}, "
                        f"PnL: {results['pnl_percentage']:.2f}%, "
                        f"Trades: {results['trades']}, "
                        f"Time: {episode_time:.2f}s"
                    )

                # Evaluation (guarded so a zero/None interval cannot
                # raise ZeroDivisionError)
                if (evaluation_interval and evaluation_interval > 0
                        and episode > 0
                        and episode % evaluation_interval == 0):
                    eval_results = self.rl_trainer.evaluate_agent(num_episodes=3)
                    logger.info(
                        f"Evaluation - Avg Reward: {eval_results['avg_reward']:.4f}, "
                        f"Win Rate: {eval_results['win_rate']:.2%}"
                    )

                # Small delay to allow UI updates
                await asyncio.sleep(0.1)
        finally:
            # Never leave the dashboard showing "TRAINING" after a failure.
            self.training_active = False

        logger.info("Training completed!")

        # Save final model; make sure the target directory exists so a
        # finished run is not lost to a missing folder.
        save_path = f"models/rl/realtime_agent_{int(time.time())}.pt"
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        agent.save(save_path)
        logger.info(f"Model saved: {save_path}")

        return {
            'episodes': episodes,
            'final_balance': self.session_balance,
            'final_pnl': self.session_pnl,
            'model_path': save_path
        }
def _load_tensorboard_port(config_path="monitoring_ports.json", default=6006):
    """Return the TensorBoard port from the optional JSON port config, else ``default``."""
    import json

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError:
        logger.info("📋 No port config file found, using default ports")
        return default
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError: a broken config file must
        # not abort the training run.
        logger.warning(f"Could not read {config_path} ({e}), using default ports")
        return default

    if not isinstance(config, dict):
        logger.warning(f"Unexpected content in {config_path}, using default ports")
        return default

    port = config.get("tensorboard_port", default)
    logger.info(f"📋 Using TensorBoard port {port} from config")
    return port


async def _wait_for_tensorboard(port, attempts=10, delay=2):
    """Poll ``http://localhost:<port>`` until it answers or ``attempts`` run out."""
    try:
        import requests
    except ImportError:
        logger.warning("requests module not available for TensorBoard check")
        return

    logger.info("Checking TensorBoard accessibility...")
    for attempt in range(attempts):
        try:
            requests.get(f"http://localhost:{port}", timeout=2)
        except requests.exceptions.RequestException:
            if attempt == 0:
                logger.info("⏳ Waiting for TensorBoard to start...")
            await asyncio.sleep(delay)
        else:
            logger.info(f"✅ TensorBoard is accessible at http://localhost:{port}")
            return

    logger.warning(f"⚠️ TensorBoard may not be running on port {port}")
    logger.warning(" Run: python start_monitoring.py")


async def main():
    """Parse CLI arguments, start monitoring, train, then keep the UI alive.

    Steps: resolve the TensorBoard port (optional ``monitoring_ports.json``),
    probe TensorBoard, start the web dashboard, run
    ``RealtimeRLTrainer.train_realtime`` and finally keep the process alive
    for ``--keep-alive`` seconds so the dashboards stay reachable.
    Errors during training are logged with a traceback instead of raised.
    """
    parser = argparse.ArgumentParser(description='Realtime RL Training with Monitoring')
    parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol')
    parser.add_argument('--episodes', type=int, default=50, help='Number of episodes')
    parser.add_argument('--balance', type=float, default=1000.0, help='Initial balance')
    parser.add_argument('--web-port', type=int, default=8051, help='Web dashboard port')
    parser.add_argument('--keep-alive', type=int, default=300,
                        help='Keep monitoring alive for N seconds after training')
    args = parser.parse_args()

    # Setup logging
    setup_logging()
    logger.info("=" * 60)
    logger.info("REALTIME RL TRAINING WITH MONITORING")
    logger.info(f"Symbol: {args.symbol}")
    logger.info(f"Episodes: {args.episodes}")
    logger.info(f"Initial Balance: ${args.balance:.2f}")
    logger.info("=" * 60)

    # The port is resolved independently of the availability check so the
    # configured value is used even when `requests` is not installed.
    tensorboard_port = _load_tensorboard_port()
    await _wait_for_tensorboard(tensorboard_port)

    try:
        # Create trainer
        trainer = RealtimeRLTrainer(
            symbol=args.symbol,
            initial_balance=args.balance
        )

        # Start web dashboard
        logger.info("🚀 Starting web dashboard...")
        trainer.start_web_dashboard(port=args.web_port)

        # Wait for dashboard to start
        await asyncio.sleep(3)

        # Check if web dashboard is accessible (best effort: any failure,
        # including a missing `requests`, only produces a warning)
        try:
            import requests
            requests.get(f"http://localhost:{args.web_port}", timeout=5)
            logger.info(f"✅ Web Dashboard is accessible at http://localhost:{args.web_port}")
        except Exception:
            logger.warning(f"⚠️ Web Dashboard may not be fully ready at http://localhost:{args.web_port}")

        logger.info("MONITORING READY!")
        logger.info(f"📊 TensorBoard: http://localhost:{tensorboard_port}")
        logger.info(f"🌐 Web Dashboard: http://localhost:{args.web_port}")
        logger.info("=" * 60)

        # Run training
        results = await trainer.train_realtime(
            episodes=args.episodes,
            evaluation_interval=10
        )

        logger.info("Training Results:")
        logger.info(f" Final Balance: ${results['final_balance']:.2f}")
        logger.info(f" Final PnL: ${results['final_pnl']:.2f}")
        logger.info(f" Model Saved: {results['model_path']}")

        # Keep monitoring alive for specified time
        logger.info(f"🔄 Keeping monitoring alive for {args.keep_alive} seconds...")
        logger.info(f"📊 TensorBoard: http://localhost:{tensorboard_port}")
        logger.info(f"🌐 Web Dashboard: http://localhost:{args.web_port}")
        logger.info("Press Ctrl+C to exit monitoring.")

        remaining = args.keep_alive
        while remaining > 0:
            logger.info(f"⏰ Monitoring active - {remaining} seconds remaining")
            # Sleep at most what is left so the total never exceeds --keep-alive.
            step = min(10, remaining)
            await asyncio.sleep(step)
            remaining -= step

        logger.info("✅ Monitoring session completed.")
    except KeyboardInterrupt:
        logger.info("Training stopped by user")
    except asyncio.CancelledError:
        # Newer asyncio.run() delivers Ctrl+C as a task cancellation;
        # log it and let the cancellation propagate.
        logger.info("Training stopped by user")
        raise
    except Exception as e:
        logger.error(f"Error in training: {e}")
        import traceback
        logger.error(traceback.format_exc())
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run() may re-raise KeyboardInterrupt after cancelling the
        # main task on Ctrl+C; exit quietly instead of with a traceback.
        logger.info("Training stopped by user")