This commit is contained in:
Dobromir Popov 2025-03-10 14:53:21 +02:00
parent 6f78703ba1
commit cfddc996d7
2 changed files with 86 additions and 33 deletions

View File

@ -1703,6 +1703,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
]
current_stage = 0
# Initialize stats dictionary with the correct keys
stats = {
'episode_rewards': [],
'episode_profits': [],
@ -2070,9 +2071,9 @@ def plot_training_results(stats):
plt.xlabel('Episode')
plt.ylabel('Reward')
# Plot balance
# Plot balance/profits
plt.subplot(3, 2, 2)
plt.plot(stats['episode_profits']) # Changed from 'episode_pnls'
plt.plot(stats['episode_profits'])
plt.title('Episode Profits')
plt.xlabel('Episode')
plt.ylabel('Profit ($)')
@ -2086,14 +2087,14 @@ def plot_training_results(stats):
# Plot trade count
plt.subplot(3, 2, 4)
plt.plot(stats['trade_counts']) # Changed from 'episode_lengths'
plt.plot(stats['trade_counts'])
plt.title('Number of Trades')
plt.xlabel('Episode')
plt.ylabel('Trades')
# Plot prediction accuracy
plt.subplot(3, 2, 5)
plt.plot(stats['prediction_accuracies']) # Changed from 'prediction_accuracy'
plt.plot(stats['prediction_accuracies'])
plt.title('Prediction Accuracy')
plt.xlabel('Episode')
plt.ylabel('Accuracy (%)')
@ -2103,6 +2104,8 @@ def plot_training_results(stats):
plt.savefig('training_results.png')
plt.close()
logger.info("Training results saved to training_results.png")
def evaluate_agent(agent, env, num_episodes=10):
"""Evaluate the agent on test data"""
total_reward = 0
@ -2385,8 +2388,9 @@ async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
async def main():
"""Main function to run the trading bot"""
parser = argparse.ArgumentParser(description='Crypto Trading Bot')
parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live'],
help='Mode to run the bot in')
parser.add_argument('--mode', type=str, default='train',
choices=['train', 'evaluate', 'live', 'continuous'],
help='Mode to run the bot in (train, evaluate, live, or continuous)')
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
parser.add_argument('--refresh-data', action='store_true', help='Refresh data during training')
@ -2398,56 +2402,104 @@ async def main():
device = get_device(args.device)
exchange = None
try:
# Initialize exchange
exchange = await initialize_exchange()
# Create environment
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=args.demo)
exchange_id = 'mexc'
exchange_class = getattr(ccxt, exchange_id)
exchange = exchange_class({
'apiKey': MEXC_API_KEY,
'secret': MEXC_SECRET_KEY,
'enableRateLimit': True,
'options': {
'defaultType': 'future',
}
})
logger.info(f"Exchange initialized with standard CCXT: {exchange.id}")
# Fetch initial data
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
logger.info("Fetching initial data for ETH/USDT")
data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500)
# Create agent
# Initialize environment
env = TradingEnvironment(
data=data,
symbol="ETH/USDT",
timeframe="1m",
leverage=MAX_LEVERAGE,
initial_balance=INITIAL_BALANCE,
is_demo=args.demo or args.mode != 'live'
)
logger.info(f"Initialized environment with {len(data)} candles")
# Initialize agent
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
if args.mode == 'train':
# Train the agent
logger.info(f"Starting training for {args.episodes} episodes...")
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange)
# Pass exchange to training function if refresh-data is enabled
if args.refresh_data:
logger.info("Data refresh enabled during training")
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange)
else:
stats = await train_agent(agent, env, num_episodes=args.episodes)
elif args.mode == 'continuous':
# Run in continuous mode - train indefinitely
logger.info("Starting continuous training mode. Press Ctrl+C to stop.")
episode_counter = 0
try:
while True: # Run indefinitely until manually stopped
# Train for a batch of episodes
batch_size = 50 # Train in batches of 50 episodes
logger.info(f"Starting training batch {episode_counter // batch_size + 1}")
# Refresh data at the start of each batch
if exchange:
logger.info("Refreshing data for new training batch")
new_data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500)
if new_data:
# Replace environment data with fresh data
env.data = new_data
env.reset()
logger.info(f"Updated environment with {len(new_data)} fresh candles")
# Train for a batch of episodes
stats = await train_agent(agent, env, num_episodes=batch_size, exchange=exchange)
# Save model after each batch
agent.save(f"models/trading_agent_continuous_{episode_counter}.pt")
# Increment counter
episode_counter += batch_size
# Sleep briefly to prevent excessive API calls
await asyncio.sleep(5)
except KeyboardInterrupt:
logger.info("Continuous training stopped by user")
# Save final model
agent.save("models/trading_agent_continuous_final.pt")
logger.info("Final model saved")
elif args.mode == 'evaluate':
# Load trained model
# Load the best model
agent.load("models/trading_agent_best_pnl.pt")
# Evaluate the agent
logger.info("Evaluating agent...")
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env)
results = evaluate_agent(agent, env, num_episodes=10)
logger.info(f"Evaluation results: {results}")
elif args.mode == 'live':
# Load trained model
# Load the best model
agent.load("models/trading_agent_best_pnl.pt")
# Run live trading
logger.info("Starting live trading...")
await live_trading(agent, env, exchange, demo=args.demo)
await live_trading(agent, env, exchange)
except Exception as e:
logger.error(f"Error: {e}")
logger.error(f"Traceback: {traceback.format_exc()}")
finally:
# Clean up exchange connection - safely close if possible
# Close exchange connection
if exchange:
try:
# Some CCXT exchanges have close method, others don't
if hasattr(exchange, 'close'):
await exchange.close()
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
await exchange.client.close()
await exchange.client.close()
logger.info("Exchange connection closed")
except Exception as e:
logger.warning(f"Could not properly close exchange connection: {e}")

View File

@ -7,3 +7,4 @@ ccxt>=2.0.0
websockets>=10.0
tensorboard>=2.6.0
scikit-learn
mplfinance