plot charts
This commit is contained in:
parent
715261a3f9
commit
6f78703ba1
@ -1 +1 @@
|
||||
{"best_reward": 202.7441047517104, "best_pnl": -1.285678343969877, "best_win_rate": 38.70967741935484, "last_episode": 20, "timestamp": "2025-03-10T13:31:02.938465"}
|
||||
{"best_reward": 202.7441047517104, "best_pnl": 0.25999080227362914, "best_win_rate": 44.44444444444444, "last_episode": 0, "timestamp": "2025-03-10T14:42:11.838854"}
|
@ -23,6 +23,9 @@ import copy
|
||||
import argparse
|
||||
import traceback
|
||||
import math
|
||||
import matplotlib.dates as mdates
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@ -1758,18 +1761,11 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
|
||||
# Set risk factor for this episode
|
||||
env.risk_factor = risk_factor
|
||||
|
||||
# Refresh data with latest candles if exchange is provided
|
||||
if exchange is not None:
|
||||
try:
|
||||
logger.info(f"Fetching latest data for episode {episode}")
|
||||
latest_data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 100)
|
||||
if latest_data:
|
||||
# Add new data to environment
|
||||
for candle in latest_data:
|
||||
env.add_data(candle)
|
||||
logger.info(f"Added {len(latest_data)} new candles for episode {episode}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing data: {e}")
|
||||
# Update training data if exchange is available
|
||||
if exchange and args.refresh_data:
|
||||
# Fetch new data at the start of each episode
|
||||
logger.info(f"Refreshing data for episode {episode}")
|
||||
await env.fetch_new_data(exchange, "ETH/USDT", "1m", 100)
|
||||
|
||||
# Reset environment
|
||||
state = env.reset()
|
||||
@ -1796,6 +1792,59 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
|
||||
# Update price predictions
|
||||
env.update_price_predictions()
|
||||
|
||||
# Log OHLCV data to TensorBoard at the start of the episode
|
||||
if episode % 5 == 0: # Log every 5 episodes to avoid too much data
|
||||
# Create a DataFrame from the environment's data
|
||||
df_ohlcv = pd.DataFrame([{
|
||||
'timestamp': candle['timestamp'],
|
||||
'open': candle['open'],
|
||||
'high': candle['high'],
|
||||
'low': candle['low'],
|
||||
'close': candle['close'],
|
||||
'volume': candle['volume']
|
||||
} for candle in env.data[-100:]]) # Use last 100 candles
|
||||
|
||||
# Convert timestamp to datetime
|
||||
df_ohlcv['timestamp'] = pd.to_datetime(df_ohlcv['timestamp'], unit='ms')
|
||||
df_ohlcv.set_index('timestamp', inplace=True)
|
||||
|
||||
# Extract buy/sell signals from trades
|
||||
buy_signals = []
|
||||
sell_signals = []
|
||||
|
||||
if hasattr(env, 'trades') and env.trades:
|
||||
for trade in env.trades:
|
||||
if 'entry_time' in trade and 'entry' in trade:
|
||||
if trade['type'] == 'long':
|
||||
# Buy signal
|
||||
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||
buy_signals.append((entry_time, trade['entry']))
|
||||
|
||||
# Sell signal if closed
|
||||
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
|
||||
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||
sell_signals.append((exit_time, trade['exit']))
|
||||
|
||||
elif trade['type'] == 'short':
|
||||
# Sell short signal
|
||||
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||
sell_signals.append((entry_time, trade['entry']))
|
||||
|
||||
# Buy to cover signal if closed
|
||||
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
|
||||
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||
buy_signals.append((exit_time, trade['exit']))
|
||||
|
||||
# Log to TensorBoard
|
||||
log_ohlcv_to_tensorboard(
|
||||
agent.writer,
|
||||
df_ohlcv,
|
||||
buy_signals,
|
||||
sell_signals,
|
||||
episode,
|
||||
tag_prefix=f"episode_{episode}"
|
||||
)
|
||||
|
||||
while not done:
|
||||
# Select action
|
||||
action = agent.select_action(state)
|
||||
@ -1920,6 +1969,72 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
|
||||
if episode % 10 == 0 or episode == num_episodes - 1:
|
||||
visualize_training_results(env, agent, episode)
|
||||
|
||||
# After episode is complete, log final state with all trades
|
||||
if episode % 10 == 0 or episode == num_episodes - 1:
|
||||
# Create a DataFrame from the environment's data
|
||||
df_ohlcv = pd.DataFrame([{
|
||||
'timestamp': candle['timestamp'],
|
||||
'open': candle['open'],
|
||||
'high': candle['high'],
|
||||
'low': candle['low'],
|
||||
'close': candle['close'],
|
||||
'volume': candle['volume']
|
||||
} for candle in env.data[-100:]]) # Use last 100 candles
|
||||
|
||||
# Convert timestamp to datetime
|
||||
df_ohlcv['timestamp'] = pd.to_datetime(df_ohlcv['timestamp'], unit='ms')
|
||||
df_ohlcv.set_index('timestamp', inplace=True)
|
||||
|
||||
# Extract buy/sell signals from trades
|
||||
buy_signals = []
|
||||
sell_signals = []
|
||||
|
||||
if hasattr(env, 'trades') and env.trades:
|
||||
for trade in env.trades:
|
||||
if 'entry_time' in trade and 'entry' in trade:
|
||||
if trade['type'] == 'long':
|
||||
# Buy signal
|
||||
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||
buy_signals.append((entry_time, trade['entry']))
|
||||
|
||||
# Sell signal if closed
|
||||
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
|
||||
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||
sell_signals.append((exit_time, trade['exit']))
|
||||
|
||||
elif trade['type'] == 'short':
|
||||
# Sell short signal
|
||||
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||
sell_signals.append((entry_time, trade['entry']))
|
||||
|
||||
# Buy to cover signal if closed
|
||||
if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0:
|
||||
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||
buy_signals.append((exit_time, trade['exit']))
|
||||
|
||||
# Log to TensorBoard - use a fixed tag to overwrite previous charts
|
||||
log_ohlcv_to_tensorboard(
|
||||
agent.writer,
|
||||
df_ohlcv,
|
||||
buy_signals,
|
||||
sell_signals,
|
||||
episode,
|
||||
tag_prefix="latest_trading_data" # Use a fixed tag to overwrite previous charts
|
||||
)
|
||||
|
||||
# Create visualization - only keep the latest one
|
||||
os.makedirs("visualizations", exist_ok=True)
|
||||
# Remove previous visualizations to save disk space
|
||||
for file in os.listdir("visualizations"):
|
||||
if file.startswith("training_episode_") and file.endswith(".png"):
|
||||
try:
|
||||
os.remove(os.path.join("visualizations", file))
|
||||
except:
|
||||
pass
|
||||
|
||||
# Create new visualization
|
||||
visualize_training_results(env, agent, episode)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in episode {episode}: {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
@ -1945,8 +2060,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
|
||||
return stats
|
||||
|
||||
def plot_training_results(stats):
|
||||
"""Plot detailed training results"""
|
||||
plt.figure(figsize=(20, 15))
|
||||
"""Plot training results"""
|
||||
plt.figure(figsize=(15, 15))
|
||||
|
||||
# Plot rewards
|
||||
plt.subplot(3, 2, 1)
|
||||
@ -1957,10 +2072,10 @@ def plot_training_results(stats):
|
||||
|
||||
# Plot balance
|
||||
plt.subplot(3, 2, 2)
|
||||
plt.plot(stats['episode_profits'])
|
||||
plt.title('Account Balance')
|
||||
plt.plot(stats['episode_profits']) # Changed from 'episode_pnls'
|
||||
plt.title('Episode Profits')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Balance ($)')
|
||||
plt.ylabel('Profit ($)')
|
||||
|
||||
# Plot win rate
|
||||
plt.subplot(3, 2, 3)
|
||||
@ -1969,35 +2084,24 @@ def plot_training_results(stats):
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Win Rate (%)')
|
||||
|
||||
# Plot episode PnL
|
||||
# Plot trade count
|
||||
plt.subplot(3, 2, 4)
|
||||
plt.plot(stats['episode_pnls'])
|
||||
plt.title('Episode PnL')
|
||||
plt.plot(stats['trade_counts']) # Changed from 'episode_lengths'
|
||||
plt.title('Number of Trades')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('PnL ($)')
|
||||
plt.ylabel('Trades')
|
||||
|
||||
# Plot cumulative PnL
|
||||
# Plot prediction accuracy
|
||||
plt.subplot(3, 2, 5)
|
||||
plt.plot(stats['cumulative_pnl'])
|
||||
plt.title('Cumulative PnL')
|
||||
plt.plot(stats['prediction_accuracies']) # Changed from 'prediction_accuracy'
|
||||
plt.title('Prediction Accuracy')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Cumulative PnL ($)')
|
||||
|
||||
# Plot drawdown
|
||||
plt.subplot(3, 2, 6)
|
||||
plt.plot(stats['drawdowns'])
|
||||
plt.title('Maximum Drawdown')
|
||||
plt.xlabel('Episode')
|
||||
plt.ylabel('Drawdown (%)')
|
||||
plt.ylabel('Accuracy (%)')
|
||||
|
||||
# Save the figure
|
||||
plt.tight_layout()
|
||||
plt.savefig('training_results.png')
|
||||
|
||||
# Save statistics to CSV
|
||||
df = pd.DataFrame(stats)
|
||||
df.to_csv('training_stats.csv', index=False)
|
||||
|
||||
logger.info("Training statistics saved to training_stats.csv and training_results.png")
|
||||
plt.close()
|
||||
|
||||
def evaluate_agent(agent, env, num_episodes=10):
|
||||
"""Evaluate the agent on test data"""
|
||||
@ -2515,6 +2619,128 @@ def visualize_training_results(env, agent, episode_num):
|
||||
logger.error(f"Error creating visualization: {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
def log_ohlcv_to_tensorboard(writer, df_ohlcv, buy_signals, sell_signals, step, tag_prefix="trading"):
|
||||
"""
|
||||
Log OHLCV chart with buy/sell signals to TensorBoard
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
writer : torch.utils.tensorboard.SummaryWriter
|
||||
TensorBoard writer instance
|
||||
df_ohlcv : pandas.DataFrame
|
||||
DataFrame with OHLCV data
|
||||
buy_signals : list of tuples
|
||||
List of (datetime, price) tuples for buy signals
|
||||
sell_signals : list of tuples
|
||||
List of (datetime, price) tuples for sell signals
|
||||
step : int
|
||||
Global step value to record
|
||||
tag_prefix : str
|
||||
Prefix for the tag in TensorBoard
|
||||
"""
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
||||
import numpy as np
|
||||
|
||||
# Check if DataFrame is empty
|
||||
if df_ohlcv.empty:
|
||||
logger.warning("Empty OHLCV DataFrame, skipping visualization")
|
||||
return
|
||||
|
||||
# Create figure
|
||||
fig = Figure(figsize=(12, 8))
|
||||
canvas = FigureCanvas(fig)
|
||||
|
||||
# Create subplots for price and volume
|
||||
ax1 = fig.add_subplot(2, 1, 1) # Price chart
|
||||
ax2 = fig.add_subplot(2, 1, 2, sharex=ax1) # Volume chart
|
||||
|
||||
# Plot OHLC
|
||||
dates = mdates.date2num(df_ohlcv.index.to_pydatetime())
|
||||
ohlc = np.column_stack((dates, df_ohlcv['open'], df_ohlcv['high'], df_ohlcv['low'], df_ohlcv['close']))
|
||||
|
||||
# Plot candlestick chart
|
||||
from matplotlib.lines import Line2D
|
||||
from matplotlib.patches import Rectangle
|
||||
|
||||
width = 0.6 / (len(df_ohlcv) + 1) # Adjust width based on number of candles
|
||||
|
||||
for i, (date, open_price, high, low, close) in enumerate(ohlc):
|
||||
# Determine candle color
|
||||
if close >= open_price:
|
||||
color = 'green'
|
||||
body_bottom = open_price
|
||||
body_height = close - open_price
|
||||
else:
|
||||
color = 'red'
|
||||
body_bottom = close
|
||||
body_height = open_price - close
|
||||
|
||||
# Plot candle body
|
||||
rect = Rectangle(
|
||||
xy=(date - width/2, body_bottom),
|
||||
width=width,
|
||||
height=body_height,
|
||||
facecolor=color,
|
||||
edgecolor='black',
|
||||
alpha=0.8
|
||||
)
|
||||
ax1.add_patch(rect)
|
||||
|
||||
# Plot wick
|
||||
ax1.plot([date, date], [low, high], color='black', linewidth=1)
|
||||
|
||||
# Plot buy signals
|
||||
if buy_signals:
|
||||
buy_dates = mdates.date2num([x[0] for x in buy_signals])
|
||||
buy_prices = [x[1] for x in buy_signals]
|
||||
ax1.scatter(buy_dates, buy_prices, marker='^', color='green', s=100, label='Buy')
|
||||
|
||||
# Plot sell signals
|
||||
if sell_signals:
|
||||
sell_dates = mdates.date2num([x[0] for x in sell_signals])
|
||||
sell_prices = [x[1] for x in sell_signals]
|
||||
ax1.scatter(sell_dates, sell_prices, marker='v', color='red', s=100, label='Sell')
|
||||
|
||||
# Plot volume
|
||||
ax2.bar(dates, df_ohlcv['volume'], width=width, color='blue', alpha=0.5)
|
||||
|
||||
# Format axes
|
||||
ax1.set_title(f'OHLC with Buy/Sell Signals - {tag_prefix}')
|
||||
ax1.set_ylabel('Price')
|
||||
ax1.legend()
|
||||
ax1.grid(True)
|
||||
|
||||
ax2.set_xlabel('Date')
|
||||
ax2.set_ylabel('Volume')
|
||||
ax2.grid(True)
|
||||
|
||||
# Format date
|
||||
date_format = mdates.DateFormatter('%Y-%m-%d %H:%M')
|
||||
ax2.xaxis.set_major_formatter(date_format)
|
||||
fig.autofmt_xdate()
|
||||
|
||||
# Adjust layout
|
||||
fig.tight_layout()
|
||||
|
||||
# Log to TensorBoard
|
||||
if tag_prefix == "latest_trading_data":
|
||||
# For the latest data, use a fixed tag without step to overwrite previous charts
|
||||
writer.add_figure(f"{tag_prefix}/ohlcv_chart", fig)
|
||||
else:
|
||||
# For other charts, include the step
|
||||
writer.add_figure(f"{tag_prefix}/ohlcv_chart", fig, global_step=step)
|
||||
|
||||
# Clean up
|
||||
plt.close(fig)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in log_ohlcv_to_tensorboard: {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
|
Loading…
x
Reference in New Issue
Block a user