diff --git a/crypto/gogo2/main.py b/crypto/gogo2/main.py index ad806c6..6260071 100644 --- a/crypto/gogo2/main.py +++ b/crypto/gogo2/main.py @@ -1703,6 +1703,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, ] current_stage = 0 + # Initialize stats dictionary with the correct keys stats = { 'episode_rewards': [], 'episode_profits': [], @@ -2070,9 +2071,9 @@ def plot_training_results(stats): plt.xlabel('Episode') plt.ylabel('Reward') - # Plot balance + # Plot balance/profits plt.subplot(3, 2, 2) - plt.plot(stats['episode_profits']) # Changed from 'episode_pnls' + plt.plot(stats['episode_profits']) plt.title('Episode Profits') plt.xlabel('Episode') plt.ylabel('Profit ($)') @@ -2086,14 +2087,14 @@ def plot_training_results(stats): # Plot trade count plt.subplot(3, 2, 4) - plt.plot(stats['trade_counts']) # Changed from 'episode_lengths' + plt.plot(stats['trade_counts']) plt.title('Number of Trades') plt.xlabel('Episode') plt.ylabel('Trades') # Plot prediction accuracy plt.subplot(3, 2, 5) - plt.plot(stats['prediction_accuracies']) # Changed from 'prediction_accuracy' + plt.plot(stats['prediction_accuracies']) plt.title('Prediction Accuracy') plt.xlabel('Episode') plt.ylabel('Accuracy (%)') @@ -2102,6 +2103,8 @@ def plot_training_results(stats): plt.tight_layout() plt.savefig('training_results.png') plt.close() + + logger.info("Training results saved to training_results.png") def evaluate_agent(agent, env, num_episodes=10): """Evaluate the agent on test data""" @@ -2385,8 +2388,9 @@ async def fetch_ohlcv_data(exchange, symbol, timeframe, limit): async def main(): """Main function to run the trading bot""" parser = argparse.ArgumentParser(description='Crypto Trading Bot') - parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live'], - help='Mode to run the bot in') + parser.add_argument('--mode', type=str, default='train', + choices=['train', 'evaluate', 'live', 'continuous'], + help='Mode to run the bot in (train, evaluate, live, or continuous)') parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train') parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)') parser.add_argument('--refresh-data', action='store_true', help='Refresh data during training') @@ -2398,56 +2402,104 @@ async def main(): device = get_device(args.device) exchange = None - try: # Initialize exchange - exchange = await initialize_exchange() - - # Create environment - env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=args.demo) + exchange_id = 'mexc' + exchange_class = getattr(ccxt, exchange_id) + exchange = exchange_class({ + 'apiKey': MEXC_API_KEY, + 'secret': MEXC_SECRET_KEY, + 'enableRateLimit': True, + 'options': { + 'defaultType': 'future', + } + }) + logger.info(f"Exchange initialized with standard CCXT: {exchange.id}") # Fetch initial data - await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000) + logger.info("Fetching initial data for ETH/USDT") + data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500) - # Create agent + # Initialize environment + env = TradingEnvironment( + data=data, + symbol="ETH/USDT", + timeframe="1m", + leverage=MAX_LEVERAGE, + initial_balance=INITIAL_BALANCE, + is_demo=args.demo or args.mode != 'live' + ) + logger.info(f"Initialized environment with {len(data)} candles") + + # Initialize agent agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device) if args.mode == 'train': # Train the agent - logger.info(f"Starting training for {args.episodes} episodes...") + stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange) - # Pass exchange to training function if refresh-data is enabled - if args.refresh_data: - logger.info("Data refresh enabled during training") - stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange) - else: - stats = await train_agent(agent, env, num_episodes=args.episodes) + elif args.mode == 'continuous': + # Run in continuous mode - train indefinitely + logger.info("Starting continuous training mode. Press Ctrl+C to stop.") + episode_counter = 0 + try: + while True: # Run indefinitely until manually stopped + # Train for a batch of episodes + batch_size = 50 # Train in batches of 50 episodes + logger.info(f"Starting training batch {episode_counter // batch_size + 1}") + + # Refresh data at the start of each batch + if exchange: + logger.info("Refreshing data for new training batch") + new_data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500) + if new_data: + # Replace environment data with fresh data + env.data = new_data + env.reset() + logger.info(f"Updated environment with {len(new_data)} fresh candles") + + # Train for a batch of episodes + stats = await train_agent(agent, env, num_episodes=batch_size, exchange=exchange) + + # Save model after each batch + agent.save(f"models/trading_agent_continuous_{episode_counter}.pt") + + # Increment counter + episode_counter += batch_size + + # Sleep briefly to prevent excessive API calls + await asyncio.sleep(5) + + except KeyboardInterrupt: + logger.info("Continuous training stopped by user") + # Save final model + agent.save("models/trading_agent_continuous_final.pt") + logger.info("Final model saved") elif args.mode == 'evaluate': - # Load trained model + # Load the best model agent.load("models/trading_agent_best_pnl.pt") # Evaluate the agent - logger.info("Evaluating agent...") - avg_reward, avg_profit, win_rate = evaluate_agent(agent, env) + results = evaluate_agent(agent, env, num_episodes=10) + logger.info(f"Evaluation results: {results}") elif args.mode == 'live': - # Load trained model + # Load the best model agent.load("models/trading_agent_best_pnl.pt") # Run live trading logger.info("Starting live trading...") - await live_trading(agent, env, exchange, demo=args.demo) - + await live_trading(agent, env, exchange) + + except Exception as e: + logger.error(f"Error: {e}") + logger.error(f"Traceback: {traceback.format_exc()}") finally: - # Clean up exchange connection - safely close if possible + # Close exchange connection if exchange: try: - # Some CCXT exchanges have close method, others don't - if hasattr(exchange, 'close'): - await exchange.close() - elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'): - await exchange.client.close() + await exchange.client.close() logger.info("Exchange connection closed") except Exception as e: logger.warning(f"Could not properly close exchange connection: {e}") diff --git a/crypto/gogo2/requirements.txt b/crypto/gogo2/requirements.txt index 5a89b65..cda528e 100644 --- a/crypto/gogo2/requirements.txt +++ b/crypto/gogo2/requirements.txt @@ -6,4 +6,5 @@ python-dotenv>=0.19.0 ccxt>=2.0.0 websockets>=10.0 tensorboard>=2.6.0 -scikit-learn \ No newline at end of file +scikit-learn +mplfinance \ No newline at end of file