diff --git a/crypto/gogo2/checkpoints/best_metrics.json b/crypto/gogo2/checkpoints/best_metrics.json index 73f16e3..2656a75 100644 --- a/crypto/gogo2/checkpoints/best_metrics.json +++ b/crypto/gogo2/checkpoints/best_metrics.json @@ -1 +1 @@ -{"best_reward": 202.7441047517104, "best_pnl": -1.285678343969877, "best_win_rate": 38.70967741935484, "last_episode": 20, "timestamp": "2025-03-10T13:31:02.938465"} \ No newline at end of file +{"best_reward": 202.7441047517104, "best_pnl": 0.25999080227362914, "best_win_rate": 44.44444444444444, "last_episode": 0, "timestamp": "2025-03-10T14:42:11.838854"} \ No newline at end of file diff --git a/crypto/gogo2/main.py b/crypto/gogo2/main.py index 2559b6b..ad806c6 100644 --- a/crypto/gogo2/main.py +++ b/crypto/gogo2/main.py @@ -23,6 +23,9 @@ import copy import argparse import traceback import math +import matplotlib.dates as mdates +from matplotlib.figure import Figure +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas # Configure logging logging.basicConfig( @@ -1758,18 +1761,11 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, # Set risk factor for this episode env.risk_factor = risk_factor - # Refresh data with latest candles if exchange is provided - if exchange is not None: - try: - logger.info(f"Fetching latest data for episode {episode}") - latest_data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 100) - if latest_data: - # Add new data to environment - for candle in latest_data: - env.add_data(candle) - logger.info(f"Added {len(latest_data)} new candles for episode {episode}") - except Exception as e: - logger.error(f"Error refreshing data: {e}") + # Update training data if exchange is available + if exchange and args.refresh_data: + # Fetch new data at the start of each episode + logger.info(f"Refreshing data for episode {episode}") + await env.fetch_new_data(exchange, "ETH/USDT", "1m", 100) # Reset environment state = env.reset() @@ -1796,6 +1792,59 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, # Update price predictions env.update_price_predictions() + # Log OHLCV data to TensorBoard at the start of the episode + if episode % 5 == 0: # Log every 5 episodes to avoid too much data + # Create a DataFrame from the environment's data + df_ohlcv = pd.DataFrame([{ + 'timestamp': candle['timestamp'], + 'open': candle['open'], + 'high': candle['high'], + 'low': candle['low'], + 'close': candle['close'], + 'volume': candle['volume'] + } for candle in env.data[-100:]]) # Use last 100 candles + + # Convert timestamp to datetime + df_ohlcv['timestamp'] = pd.to_datetime(df_ohlcv['timestamp'], unit='ms') + df_ohlcv.set_index('timestamp', inplace=True) + + # Extract buy/sell signals from trades + buy_signals = [] + sell_signals = [] + + if hasattr(env, 'trades') and env.trades: + for trade in env.trades: + if 'entry_time' in trade and 'entry' in trade: + if trade['type'] == 'long': + # Buy signal + entry_time = pd.to_datetime(trade['entry_time'], unit='ms') + buy_signals.append((entry_time, trade['entry'])) + + # Sell signal if closed + if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0: + exit_time = pd.to_datetime(trade['exit_time'], unit='ms') + sell_signals.append((exit_time, trade['exit'])) + + elif trade['type'] == 'short': + # Sell short signal + entry_time = pd.to_datetime(trade['entry_time'], unit='ms') + sell_signals.append((entry_time, trade['entry'])) + + # Buy to cover signal if closed + if 'exit_time' in trade and 'exit' in 
trade and trade['exit'] > 0: + exit_time = pd.to_datetime(trade['exit_time'], unit='ms') + buy_signals.append((exit_time, trade['exit'])) + + # Log to TensorBoard + log_ohlcv_to_tensorboard( + agent.writer, + df_ohlcv, + buy_signals, + sell_signals, + episode, + tag_prefix=f"episode_{episode}" + ) + while not done: # Select action action = agent.select_action(state) @@ -1920,6 +1969,72 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, if episode % 10 == 0 or episode == num_episodes - 1: visualize_training_results(env, agent, episode) + # After episode is complete, log final state with all trades + if episode % 10 == 0 or episode == num_episodes - 1: + # Create a DataFrame from the environment's data + df_ohlcv = pd.DataFrame([{ + 'timestamp': candle['timestamp'], + 'open': candle['open'], + 'high': candle['high'], + 'low': candle['low'], + 'close': candle['close'], + 'volume': candle['volume'] + } for candle in env.data[-100:]]) # Use last 100 candles + + # Convert timestamp to datetime + df_ohlcv['timestamp'] = pd.to_datetime(df_ohlcv['timestamp'], unit='ms') + df_ohlcv.set_index('timestamp', inplace=True) + + # Extract buy/sell signals from trades + buy_signals = [] + sell_signals = [] + + if hasattr(env, 'trades') and env.trades: + for trade in env.trades: + if 'entry_time' in trade and 'entry' in trade: + if trade['type'] == 'long': + # Buy signal + entry_time = pd.to_datetime(trade['entry_time'], unit='ms') + buy_signals.append((entry_time, trade['entry'])) + + # Sell signal if closed + if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0: + exit_time = pd.to_datetime(trade['exit_time'], unit='ms') + sell_signals.append((exit_time, trade['exit'])) + + elif trade['type'] == 'short': + # Sell short signal + entry_time = pd.to_datetime(trade['entry_time'], unit='ms') + sell_signals.append((entry_time, trade['entry'])) + + # Buy to cover signal if closed + if 'exit_time' in trade and 'exit' in trade and trade['exit'] > 0: + exit_time = pd.to_datetime(trade['exit_time'], unit='ms') + buy_signals.append((exit_time, trade['exit'])) + + # Log to TensorBoard - use a fixed tag to overwrite previous charts + log_ohlcv_to_tensorboard( + agent.writer, + df_ohlcv, + buy_signals, + sell_signals, + episode, + tag_prefix="latest_trading_data" # Use a fixed tag to overwrite previous charts + ) + + # Create visualization - only keep the latest one + os.makedirs("visualizations", exist_ok=True) + # Remove previous visualizations to save disk space + for file in os.listdir("visualizations"): + if file.startswith("training_episode_") and file.endswith(".png"): + try: + os.remove(os.path.join("visualizations", file)) + except: + pass + + # Create new visualization + visualize_training_results(env, agent, episode) + except Exception as e: logger.error(f"Error in episode {episode}: {e}") logger.error(f"Traceback: {traceback.format_exc()}") @@ -1945,8 +2060,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, return stats def plot_training_results(stats): - """Plot detailed training results""" - plt.figure(figsize=(20, 15)) + """Plot training results""" + plt.figure(figsize=(15, 15)) # Plot rewards plt.subplot(3, 2, 1) @@ -1957,10 +2072,10 @@ def plot_training_results(stats): # Plot balance plt.subplot(3, 2, 2) - plt.plot(stats['episode_profits']) - plt.title('Account Balance') + plt.plot(stats['episode_profits']) # Changed from 'episode_pnls' + plt.title('Episode Profits') plt.xlabel('Episode') - plt.ylabel('Balance ($)') + 
plt.ylabel('Profit ($)') # Plot win rate plt.subplot(3, 2, 3) @@ -1969,35 +2084,24 @@ def plot_training_results(stats): plt.xlabel('Episode') plt.ylabel('Win Rate (%)') - # Plot episode PnL + # Plot trade count plt.subplot(3, 2, 4) - plt.plot(stats['episode_pnls']) - plt.title('Episode PnL') + plt.plot(stats['trade_counts']) # Changed from 'episode_lengths' + plt.title('Number of Trades') plt.xlabel('Episode') - plt.ylabel('PnL ($)') + plt.ylabel('Trades') - # Plot cumulative PnL + # Plot prediction accuracy plt.subplot(3, 2, 5) - plt.plot(stats['cumulative_pnl']) - plt.title('Cumulative PnL') + plt.plot(stats['prediction_accuracies']) # Changed from 'prediction_accuracy' + plt.title('Prediction Accuracy') plt.xlabel('Episode') - plt.ylabel('Cumulative PnL ($)') - - # Plot drawdown - plt.subplot(3, 2, 6) - plt.plot(stats['drawdowns']) - plt.title('Maximum Drawdown') - plt.xlabel('Episode') - plt.ylabel('Drawdown (%)') + plt.ylabel('Accuracy (%)') + # Save the figure plt.tight_layout() plt.savefig('training_results.png') - - # Save statistics to CSV - df = pd.DataFrame(stats) - df.to_csv('training_stats.csv', index=False) - - logger.info("Training statistics saved to training_stats.csv and training_results.png") + plt.close() def evaluate_agent(agent, env, num_episodes=10): """Evaluate the agent on test data""" @@ -2515,6 +2619,128 @@ def visualize_training_results(env, agent, episode_num): logger.error(f"Error creating visualization: {e}") logger.error(f"Traceback: {traceback.format_exc()}") +def log_ohlcv_to_tensorboard(writer, df_ohlcv, buy_signals, sell_signals, step, tag_prefix="trading"): + """ + Log OHLCV chart with buy/sell signals to TensorBoard + + Parameters: + ----------- + writer : torch.utils.tensorboard.SummaryWriter + TensorBoard writer instance + df_ohlcv : pandas.DataFrame + DataFrame with OHLCV data + buy_signals : list of tuples + List of (datetime, price) tuples for buy signals + sell_signals : list of tuples + List of (datetime, price) tuples for sell signals + step : int + Global step value to record + tag_prefix : str + Prefix for the tag in TensorBoard + """ + try: + import matplotlib.pyplot as plt + import matplotlib.dates as mdates + from matplotlib.figure import Figure + from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas + import numpy as np + + # Check if DataFrame is empty + if df_ohlcv.empty: + logger.warning("Empty OHLCV DataFrame, skipping visualization") + return + + # Create figure + fig = Figure(figsize=(12, 8)) + canvas = FigureCanvas(fig) + + # Create subplots for price and volume + ax1 = fig.add_subplot(2, 1, 1) # Price chart + ax2 = fig.add_subplot(2, 1, 2, sharex=ax1) # Volume chart + + # Plot OHLC + dates = mdates.date2num(df_ohlcv.index.to_pydatetime()) + ohlc = np.column_stack((dates, df_ohlcv['open'], df_ohlcv['high'], df_ohlcv['low'], df_ohlcv['close'])) + + # Plot candlestick chart + from matplotlib.lines import Line2D + from matplotlib.patches import Rectangle + + width = 0.6 / (len(df_ohlcv) + 1) # Adjust width based on number of candles + + for i, (date, open_price, high, low, close) in enumerate(ohlc): + # Determine candle color + if close >= open_price: + color = 'green' + body_bottom = open_price + body_height = close - open_price + else: + color = 'red' + body_bottom = close + body_height = open_price - close + + # Plot candle body + rect = Rectangle( + xy=(date - width/2, body_bottom), + width=width, + height=body_height, + facecolor=color, + edgecolor='black', + alpha=0.8 + ) + ax1.add_patch(rect) + + # Plot 
wick + ax1.plot([date, date], [low, high], color='black', linewidth=1) + + # Plot buy signals + if buy_signals: + buy_dates = mdates.date2num([x[0] for x in buy_signals]) + buy_prices = [x[1] for x in buy_signals] + ax1.scatter(buy_dates, buy_prices, marker='^', color='green', s=100, label='Buy') + + # Plot sell signals + if sell_signals: + sell_dates = mdates.date2num([x[0] for x in sell_signals]) + sell_prices = [x[1] for x in sell_signals] + ax1.scatter(sell_dates, sell_prices, marker='v', color='red', s=100, label='Sell') + + # Plot volume + ax2.bar(dates, df_ohlcv['volume'], width=width, color='blue', alpha=0.5) + + # Format axes + ax1.set_title(f'OHLC with Buy/Sell Signals - {tag_prefix}') + ax1.set_ylabel('Price') + ax1.legend() + ax1.grid(True) + + ax2.set_xlabel('Date') + ax2.set_ylabel('Volume') + ax2.grid(True) + + # Format date + date_format = mdates.DateFormatter('%Y-%m-%d %H:%M') + ax2.xaxis.set_major_formatter(date_format) + fig.autofmt_xdate() + + # Adjust layout + fig.tight_layout() + + # Log to TensorBoard + if tag_prefix == "latest_trading_data": + # For the latest data, use a fixed tag without step to overwrite previous charts + writer.add_figure(f"{tag_prefix}/ohlcv_chart", fig) + else: + # For other charts, include the step + writer.add_figure(f"{tag_prefix}/ohlcv_chart", fig, global_step=step) + + # Clean up + plt.close(fig) + + except Exception as e: + logger.error(f"Error in log_ohlcv_to_tensorboard: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") + if __name__ == "__main__": try: asyncio.run(main())
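
Usage note (illustrative, not part of the patch): the sketch below exercises the new log_ohlcv_to_tensorboard helper in isolation with synthetic one-minute candles, so the candlestick/volume layout and the buy/sell markers can be checked in TensorBoard without running a training episode. The import path (main), the synthetic data, and the runs/ohlcv_demo log directory are assumptions made for this example; only the helper's signature and expected input shapes come from the diff above.

import numpy as np
import pandas as pd
from torch.utils.tensorboard import SummaryWriter

# Assumes main.py is importable from the working directory and that importing it
# only sets up imports/logging (training is guarded by __name__ == "__main__").
from main import log_ohlcv_to_tensorboard

# Build 100 synthetic one-minute candles in the shape train_agent() passes in:
# a DataFrame with open/high/low/close/volume columns and a DatetimeIndex.
rng = np.random.default_rng(0)
idx = pd.date_range("2025-03-10 13:00", periods=100, freq="1min")
close = 2000 + rng.normal(0, 2, 100).cumsum()
open_ = close + rng.normal(0, 1, 100)
df_ohlcv = pd.DataFrame({
    "open": open_,
    "high": np.maximum(open_, close) + 1.0,
    "low": np.minimum(open_, close) - 1.0,
    "close": close,
    "volume": rng.uniform(10, 100, 100),
}, index=idx)

# Signals are (timestamp, price) tuples, mirroring what train_agent() extracts
# from env.trades before calling the helper.
buy_signals = [(idx[10], float(df_ohlcv["close"].iloc[10]))]
sell_signals = [(idx[60], float(df_ohlcv["close"].iloc[60]))]

writer = SummaryWriter("runs/ohlcv_demo")
log_ohlcv_to_tensorboard(writer, df_ohlcv, buy_signals, sell_signals,
                         step=0, tag_prefix="episode_0")
writer.close()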