From 715261a3f9d1b0454a2221abd8a2190d2a563616 Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Mon, 10 Mar 2025 13:32:35 +0200
Subject: [PATCH] improvements

---
 crypto/gogo2/checkpoints/best_metrics.json |  2 +-
 crypto/gogo2/main.py                       | 94 +++++++++++++++++-----
 2 files changed, 75 insertions(+), 21 deletions(-)

diff --git a/crypto/gogo2/checkpoints/best_metrics.json b/crypto/gogo2/checkpoints/best_metrics.json
index 5f5dcc1..73f16e3 100644
--- a/crypto/gogo2/checkpoints/best_metrics.json
+++ b/crypto/gogo2/checkpoints/best_metrics.json
@@ -1 +1 @@
-{"best_reward": 202.7441047517104, "best_pnl": -10.072078721366783, "best_win_rate": 30.864197530864196, "last_episode": 10, "timestamp": "2025-03-10T12:45:27.247997"}
\ No newline at end of file
+{"best_reward": 202.7441047517104, "best_pnl": -1.285678343969877, "best_win_rate": 38.70967741935484, "last_episode": 20, "timestamp": "2025-03-10T13:31:02.938465"}
\ No newline at end of file
diff --git a/crypto/gogo2/main.py b/crypto/gogo2/main.py
index 7dd2f6b..2559b6b 100644
--- a/crypto/gogo2/main.py
+++ b/crypto/gogo2/main.py
@@ -499,6 +499,8 @@ class TradingEnvironment:
                     'type': 'long',
                     'entry': self.entry_price,
                     'exit': self.stop_loss,
+                    'entry_time': self.data[self.entry_index]['timestamp'],
+                    'exit_time': self.data[self.current_step]['timestamp'],
                     'pnl_percent': pnl_percent,
                     'pnl_dollar': pnl_dollar,
                     'duration': self.current_step - self.entry_index,
@@ -542,6 +544,8 @@ class TradingEnvironment:
                     'type': 'long',
                     'entry': self.entry_price,
                     'exit': self.take_profit,
+                    'entry_time': self.data[self.entry_index]['timestamp'],
+                    'exit_time': self.data[self.current_step]['timestamp'],
                     'pnl_percent': pnl_percent,
                     'pnl_dollar': pnl_dollar,
                     'duration': self.current_step - self.entry_index,
@@ -588,6 +592,8 @@ class TradingEnvironment:
                     'type': 'short',
                     'entry': self.entry_price,
                     'exit': self.stop_loss,
+                    'entry_time': self.data[self.entry_index]['timestamp'],
+                    'exit_time': self.data[self.current_step]['timestamp'],
                     'pnl_percent': pnl_percent,
                     'pnl_dollar': pnl_dollar,
                     'duration': self.current_step - self.entry_index,
@@ -631,6 +637,8 @@ class TradingEnvironment:
                     'type': 'short',
                     'entry': self.entry_price,
                     'exit': self.take_profit,
+                    'entry_time': self.data[self.entry_index]['timestamp'],
+                    'exit_time': self.data[self.current_step]['timestamp'],
                     'pnl_percent': pnl_percent,
                     'pnl_dollar': pnl_dollar,
                     'duration': self.current_step - self.entry_index,
@@ -808,6 +816,8 @@ class TradingEnvironment:
                 'type': 'short',
                 'entry': self.entry_price,
                 'exit': self.current_price,
+                'entry_time': self.data[self.entry_index]['timestamp'],
+                'exit_time': self.data[self.current_step]['timestamp'],
                 'pnl_percent': pnl_percent,
                 'pnl_dollar': pnl_dollar,
                 'duration': self.current_step - self.entry_index,
@@ -872,6 +882,8 @@ class TradingEnvironment:
                 'type': 'long',
                 'entry': self.entry_price,
                 'exit': self.current_price,
+                'entry_time': self.data[self.entry_index]['timestamp'],
+                'exit_time': self.data[self.current_step]['timestamp'],
                 'pnl_percent': pnl_percent,
                 'pnl_dollar': pnl_dollar
             })
@@ -925,6 +937,8 @@ class TradingEnvironment:
                 'type': 'long',
                 'entry': self.entry_price,
                 'exit': self.current_price,
+                'entry_time': self.data[self.entry_index]['timestamp'],
+                'exit_time': self.data[self.current_step]['timestamp'],
                 'pnl_percent': pnl_percent,
                 'pnl_dollar': pnl_dollar
             })
@@ -970,6 +984,8 @@ class TradingEnvironment:
                 'type': 'short',
                 'entry': self.entry_price,
                 'exit': self.current_price,
+                'entry_time': self.data[self.entry_index]['timestamp'],
+                'exit_time': self.data[self.current_step]['timestamp'],
                 'pnl_percent': pnl_percent,
                 'pnl_dollar': pnl_dollar
             })
@@ -1667,6 +1683,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
     # Add early stopping based on performance
     patience = 50  # Episodes to wait for improvement
     best_pnl = -float('inf')
+    best_reward = -float('inf')  # Initialize best_reward
+    best_win_rate = 0  # Initialize best_win_rate
     episodes_without_improvement = 0
 
     # Add adaptive learning rate
@@ -1684,18 +1702,42 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
     stats = {
         'episode_rewards': [],
-        'episode_lengths': [],
-        'balances': [],
+        'episode_profits': [],
         'win_rates': [],
-        'episode_pnls': [],
-        'cumulative_pnl': [],
-        'drawdowns': [],
-        'prediction_accuracy': []
+        'trade_counts': [],
+        'prediction_accuracies': []
     }
 
+    # Create checkpoint directory if it doesn't exist
+    os.makedirs("checkpoints", exist_ok=True)
+
+    # Load best model if it exists (to resume training)
+    best_model_path = "models/trading_agent_best_pnl.pt"
+    if os.path.exists(best_model_path):
+        try:
+            logger.info(f"Loading best model from {best_model_path} to resume training")
+            agent.load(best_model_path)
+            # Try to load best metrics from checkpoint file
+            checkpoint_info_path = "checkpoints/best_metrics.json"
+            if os.path.exists(checkpoint_info_path):
+                with open(checkpoint_info_path, 'r') as f:
+                    best_metrics = json.load(f)
+                    best_reward = best_metrics.get('best_reward', best_reward)
+                    best_pnl = best_metrics.get('best_pnl', best_pnl)
+                    best_win_rate = best_metrics.get('best_win_rate', best_win_rate)
+                logger.info(f"Resumed with best metrics - Reward: {best_reward:.2f}, PnL: ${best_pnl:.2f}, Win Rate: {best_win_rate:.1f}%")
+        except Exception as e:
+            logger.warning(f"Could not load best model: {e}")
+
     try:
-        # Initialize price predictor
-        env.initialize_price_predictor(agent.device)
+        # Initialize price predictor and attach it to the environment
+        price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
+        price_predictor.to(agent.device)
+        price_predictor_optimizer = optim.Adam(price_predictor.parameters(), lr=1e-4)
+
+        # Attach the price predictor to the environment
+        env.price_predictor = price_predictor
+        env.price_predictor_optimizer = price_predictor_optimizer
 
         for episode in range(num_episodes):
             try:
@@ -1731,11 +1773,22 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
                 # Reset environment
                 state = env.reset()
-                episode_reward = 0
-                env.episode_pnl = 0.0  # Reset episode PnL
-                # Identify optimal trade points for this episode
-                env.identify_optimal_trades()
+                # Initialize episode variables
+                episode_reward = 0
+                done = False
+                step = 0
+
+                # Initialize trade analysis dictionary
+                trade_analysis = {
+                    'win_rate': 0,
+                    'uptrend_win_rate': 0,
+                    'downtrend_win_rate': 0,
+                    'sideways_win_rate': 0,
+                    'avg_win_pnl': 0,
+                    'avg_loss_pnl': 0,
+                    'max_drawdown': 0
+                }
 
                 # Train price predictor
                 prediction_loss = env.train_price_predictor()
@@ -1743,7 +1796,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
                 # Update price predictions
                 env.update_price_predictions()
 
-                for step in range(max_steps_per_episode):
+                while not done:
                     # Select action
                     action = agent.select_action(state)
 
@@ -1782,13 +1835,10 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
                 # Update stats
                 stats['episode_rewards'].append(episode_reward)
-                stats['episode_lengths'].append(step + 1)
-                stats['balances'].append(env.balance)
+                stats['episode_profits'].append(env.episode_pnl)
                 stats['win_rates'].append(win_rate)
-                stats['episode_pnls'].append(env.episode_pnl)
-                stats['cumulative_pnl'].append(env.total_pnl)
-                stats['drawdowns'].append(env.max_drawdown * 100)
-                stats['prediction_accuracy'].append(prediction_accuracy)
+                stats['trade_counts'].append(total_trades)
+                stats['prediction_accuracies'].append(prediction_accuracy)
 
                 # Log detailed trade analysis
                 if trade_analysis:
@@ -1866,6 +1916,10 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
                     logger.info(f"Early stopping triggered after {episode+1} episodes without improvement")
                     break
 
+                # Create visualization every 10 episodes or on the last episode
+                if episode % 10 == 0 or episode == num_episodes - 1:
+                    visualize_training_results(env, agent, episode)
+
             except Exception as e:
                 logger.error(f"Error in episode {episode}: {e}")
                 logger.error(f"Traceback: {traceback.format_exc()}")
@@ -1903,7 +1957,7 @@ def plot_training_results(stats):
     # Plot balance
     plt.subplot(3, 2, 2)
-    plt.plot(stats['balances'])
+    plt.plot(stats['episode_profits'])
     plt.title('Account Balance')
     plt.xlabel('Episode')
     plt.ylabel('Balance ($)')
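
Note: the resume logic added in this patch reads checkpoints/best_metrics.json, but the hunk that writes that file is not shown here. Below is a minimal sketch of how the file could be kept in sync whenever a new best PnL is saved; the helper name save_best_metrics and its call site are assumptions for illustration, not code from main.py.

import json
import os
from datetime import datetime

def save_best_metrics(best_reward, best_pnl, best_win_rate, episode,
                      path="checkpoints/best_metrics.json"):
    # Hypothetical helper: persists the same keys that train_agent() reads on resume.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    metrics = {
        "best_reward": best_reward,
        "best_pnl": best_pnl,
        "best_win_rate": best_win_rate,
        "last_episode": episode,
        "timestamp": datetime.now().isoformat(),
    }
    with open(path, "w") as f:
        json.dump(metrics, f)

Calling a helper like this right after saving models/trading_agent_best_pnl.pt whenever a new best PnL is recorded would keep the JSON aligned with the checkpoint that the loader expects.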