From e884f0c9e64a4445d185af32595794a65d6cf109 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Mon, 10 Mar 2025 15:37:02 +0200 Subject: [PATCH] added continious mode. fixed errors --- crypto/gogo2/.vscode/launch.json | 11 +++++- crypto/gogo2/main.py | 58 +++++++++++++------------------- 2 files changed, 33 insertions(+), 36 deletions(-) diff --git a/crypto/gogo2/.vscode/launch.json b/crypto/gogo2/.vscode/launch.json index 7e77d5b..5245d38 100644 --- a/crypto/gogo2/.vscode/launch.json +++ b/crypto/gogo2/.vscode/launch.json @@ -6,7 +6,7 @@ "type": "python", "request": "launch", "program": "main.py", - "args": ["--mode", "train", "--episodes", "1000"], + "args": ["--mode", "train", "--episodes", "100"], "console": "integratedTerminal", "justMyCode": true }, @@ -36,6 +36,15 @@ "args": ["--mode", "live"], "console": "integratedTerminal", "justMyCode": true + }, + { + "name": "Continuous Training", + "type": "python", + "request": "launch", + "program": "main.py", + "args": ["--mode", "continuous", "--refresh-data"], + "console": "integratedTerminal", + "justMyCode": true } ] } \ No newline at end of file diff --git a/crypto/gogo2/main.py b/crypto/gogo2/main.py index 6260071..e8bb07a 100644 --- a/crypto/gogo2/main.py +++ b/crypto/gogo2/main.py @@ -1679,7 +1679,7 @@ async def get_live_prices(symbol="ETH/USDT", timeframe="1m"): await asyncio.sleep(5) break -async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None): +async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None, args=None): """Train the agent using historical and live data with GPU acceleration""" logger.info(f"Starting training on device: {agent.device}") @@ -2404,39 +2404,26 @@ async def main(): exchange = None try: # Initialize exchange - exchange_id = 'mexc' - exchange_class = getattr(ccxt, exchange_id) - exchange = exchange_class({ - 'apiKey': MEXC_API_KEY, - 'secret': MEXC_SECRET_KEY, - 'enableRateLimit': True, - 'options': { - 'defaultType': 'future', - } - }) - logger.info(f"Exchange initialized with standard CCXT: {exchange.id}") + exchange = await initialize_exchange() + + # Create environment with the correct parameters + env = TradingEnvironment( + initial_balance=INITIAL_BALANCE, + window_size=30, + demo=args.demo or args.mode != 'live' + ) # Fetch initial data logger.info("Fetching initial data for ETH/USDT") - data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500) - - # Initialize environment - env = TradingEnvironment( - data=data, - symbol="ETH/USDT", - timeframe="1m", - leverage=MAX_LEVERAGE, - initial_balance=INITIAL_BALANCE, - is_demo=args.demo or args.mode != 'live' - ) - logger.info(f"Initialized environment with {len(data)} candles") + await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 500) # Initialize agent agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device) if args.mode == 'train': # Train the agent - stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange) + logger.info(f"Starting training for {args.episodes} episodes...") + stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange, args=args) elif args.mode == 'continuous': # Run in continuous mode - train indefinitely @@ -2449,17 +2436,13 @@ async def main(): logger.info(f"Starting training batch {episode_counter // batch_size + 1}") # Refresh data at the start of each batch - if exchange: + if exchange and args.refresh_data: logger.info("Refreshing data for new training batch") - new_data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500) - if new_data: - # Replace environment data with fresh data - env.data = new_data - env.reset() - logger.info(f"Updated environment with {len(new_data)} fresh candles") + await env.fetch_new_data(exchange, "ETH/USDT", "1m", 500) + logger.info(f"Updated environment with fresh candles") # Train for a batch of episodes - stats = await train_agent(agent, env, num_episodes=batch_size, exchange=exchange) + stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange, args=args) # Save model after each batch agent.save(f"models/trading_agent_continuous_{episode_counter}.pt") @@ -2481,6 +2464,7 @@ async def main(): agent.load("models/trading_agent_best_pnl.pt") # Evaluate the agent + logger.info("Evaluating agent...") results = evaluate_agent(agent, env, num_episodes=10) logger.info(f"Evaluation results: {results}") @@ -2490,7 +2474,7 @@ async def main(): # Run live trading logger.info("Starting live trading...") - await live_trading(agent, env, exchange) + await live_trading(agent, env, exchange, demo=args.demo) except Exception as e: logger.error(f"Error: {e}") @@ -2499,7 +2483,11 @@ async def main(): # Close exchange connection if exchange: try: - await exchange.client.close() + # Some CCXT exchanges have close method, others don't + if hasattr(exchange, 'close'): + await exchange.close() + elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'): + await exchange.client.close() logger.info("Exchange connection closed") except Exception as e: logger.warning(f"Could not properly close exchange connection: {e}")