Added continuous mode. Fixed errors.

This commit is contained in:
Dobromir Popov 2025-03-10 15:37:02 +02:00
parent cfddc996d7
commit e884f0c9e6
2 changed files with 33 additions and 36 deletions

View File

@ -6,7 +6,7 @@
"type": "python", "type": "python",
"request": "launch", "request": "launch",
"program": "main.py", "program": "main.py",
"args": ["--mode", "train", "--episodes", "1000"], "args": ["--mode", "train", "--episodes", "100"],
"console": "integratedTerminal", "console": "integratedTerminal",
"justMyCode": true "justMyCode": true
}, },
@ -36,6 +36,15 @@
"args": ["--mode", "live"], "args": ["--mode", "live"],
"console": "integratedTerminal", "console": "integratedTerminal",
"justMyCode": true "justMyCode": true
},
{
"name": "Continuous Training",
"type": "python",
"request": "launch",
"program": "main.py",
"args": ["--mode", "continuous", "--refresh-data"],
"console": "integratedTerminal",
"justMyCode": true
} }
] ]
} }

View File

@ -1679,7 +1679,7 @@ async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
await asyncio.sleep(5) await asyncio.sleep(5)
break break
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None): async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None, args=None):
"""Train the agent using historical and live data with GPU acceleration""" """Train the agent using historical and live data with GPU acceleration"""
logger.info(f"Starting training on device: {agent.device}") logger.info(f"Starting training on device: {agent.device}")
@ -2404,39 +2404,26 @@ async def main():
exchange = None exchange = None
try: try:
# Initialize exchange # Initialize exchange
exchange_id = 'mexc' exchange = await initialize_exchange()
exchange_class = getattr(ccxt, exchange_id)
exchange = exchange_class({ # Create environment with the correct parameters
'apiKey': MEXC_API_KEY, env = TradingEnvironment(
'secret': MEXC_SECRET_KEY, initial_balance=INITIAL_BALANCE,
'enableRateLimit': True, window_size=30,
'options': { demo=args.demo or args.mode != 'live'
'defaultType': 'future', )
}
})
logger.info(f"Exchange initialized with standard CCXT: {exchange.id}")
# Fetch initial data # Fetch initial data
logger.info("Fetching initial data for ETH/USDT") logger.info("Fetching initial data for ETH/USDT")
data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500) await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 500)
# Initialize environment
env = TradingEnvironment(
data=data,
symbol="ETH/USDT",
timeframe="1m",
leverage=MAX_LEVERAGE,
initial_balance=INITIAL_BALANCE,
is_demo=args.demo or args.mode != 'live'
)
logger.info(f"Initialized environment with {len(data)} candles")
# Initialize agent # Initialize agent
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device) agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
if args.mode == 'train': if args.mode == 'train':
# Train the agent # Train the agent
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange) logger.info(f"Starting training for {args.episodes} episodes...")
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange, args=args)
elif args.mode == 'continuous': elif args.mode == 'continuous':
# Run in continuous mode - train indefinitely # Run in continuous mode - train indefinitely
@ -2449,17 +2436,13 @@ async def main():
logger.info(f"Starting training batch {episode_counter // batch_size + 1}") logger.info(f"Starting training batch {episode_counter // batch_size + 1}")
# Refresh data at the start of each batch # Refresh data at the start of each batch
if exchange: if exchange and args.refresh_data:
logger.info("Refreshing data for new training batch") logger.info("Refreshing data for new training batch")
new_data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500) await env.fetch_new_data(exchange, "ETH/USDT", "1m", 500)
if new_data: logger.info(f"Updated environment with fresh candles")
# Replace environment data with fresh data
env.data = new_data
env.reset()
logger.info(f"Updated environment with {len(new_data)} fresh candles")
# Train for a batch of episodes # Train for a batch of episodes
stats = await train_agent(agent, env, num_episodes=batch_size, exchange=exchange) stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange, args=args)
# Save model after each batch # Save model after each batch
agent.save(f"models/trading_agent_continuous_{episode_counter}.pt") agent.save(f"models/trading_agent_continuous_{episode_counter}.pt")
@ -2481,6 +2464,7 @@ async def main():
agent.load("models/trading_agent_best_pnl.pt") agent.load("models/trading_agent_best_pnl.pt")
# Evaluate the agent # Evaluate the agent
logger.info("Evaluating agent...")
results = evaluate_agent(agent, env, num_episodes=10) results = evaluate_agent(agent, env, num_episodes=10)
logger.info(f"Evaluation results: {results}") logger.info(f"Evaluation results: {results}")
@ -2490,7 +2474,7 @@ async def main():
# Run live trading # Run live trading
logger.info("Starting live trading...") logger.info("Starting live trading...")
await live_trading(agent, env, exchange) await live_trading(agent, env, exchange, demo=args.demo)
except Exception as e: except Exception as e:
logger.error(f"Error: {e}") logger.error(f"Error: {e}")
@ -2499,7 +2483,11 @@ async def main():
# Close exchange connection # Close exchange connection
if exchange: if exchange:
try: try:
await exchange.client.close() # Some CCXT exchanges have close method, others don't
if hasattr(exchange, 'close'):
await exchange.close()
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
await exchange.client.close()
logger.info("Exchange connection closed") logger.info("Exchange connection closed")
except Exception as e: except Exception as e:
logger.warning(f"Could not properly close exchange connection: {e}") logger.warning(f"Could not properly close exchange connection: {e}")