fixes
This commit is contained in:
parent
6f78703ba1
commit
cfddc996d7
@ -1703,6 +1703,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000,
|
|||||||
]
|
]
|
||||||
current_stage = 0
|
current_stage = 0
|
||||||
|
|
||||||
|
# Initialize stats dictionary with the correct keys
|
||||||
stats = {
|
stats = {
|
||||||
'episode_rewards': [],
|
'episode_rewards': [],
|
||||||
'episode_profits': [],
|
'episode_profits': [],
|
||||||
@ -2070,9 +2071,9 @@ def plot_training_results(stats):
|
|||||||
plt.xlabel('Episode')
|
plt.xlabel('Episode')
|
||||||
plt.ylabel('Reward')
|
plt.ylabel('Reward')
|
||||||
|
|
||||||
# Plot balance
|
# Plot balance/profits
|
||||||
plt.subplot(3, 2, 2)
|
plt.subplot(3, 2, 2)
|
||||||
plt.plot(stats['episode_profits']) # Changed from 'episode_pnls'
|
plt.plot(stats['episode_profits'])
|
||||||
plt.title('Episode Profits')
|
plt.title('Episode Profits')
|
||||||
plt.xlabel('Episode')
|
plt.xlabel('Episode')
|
||||||
plt.ylabel('Profit ($)')
|
plt.ylabel('Profit ($)')
|
||||||
@ -2086,14 +2087,14 @@ def plot_training_results(stats):
|
|||||||
|
|
||||||
# Plot trade count
|
# Plot trade count
|
||||||
plt.subplot(3, 2, 4)
|
plt.subplot(3, 2, 4)
|
||||||
plt.plot(stats['trade_counts']) # Changed from 'episode_lengths'
|
plt.plot(stats['trade_counts'])
|
||||||
plt.title('Number of Trades')
|
plt.title('Number of Trades')
|
||||||
plt.xlabel('Episode')
|
plt.xlabel('Episode')
|
||||||
plt.ylabel('Trades')
|
plt.ylabel('Trades')
|
||||||
|
|
||||||
# Plot prediction accuracy
|
# Plot prediction accuracy
|
||||||
plt.subplot(3, 2, 5)
|
plt.subplot(3, 2, 5)
|
||||||
plt.plot(stats['prediction_accuracies']) # Changed from 'prediction_accuracy'
|
plt.plot(stats['prediction_accuracies'])
|
||||||
plt.title('Prediction Accuracy')
|
plt.title('Prediction Accuracy')
|
||||||
plt.xlabel('Episode')
|
plt.xlabel('Episode')
|
||||||
plt.ylabel('Accuracy (%)')
|
plt.ylabel('Accuracy (%)')
|
||||||
@ -2102,6 +2103,8 @@ def plot_training_results(stats):
|
|||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig('training_results.png')
|
plt.savefig('training_results.png')
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
logger.info("Training results saved to training_results.png")
|
||||||
|
|
||||||
def evaluate_agent(agent, env, num_episodes=10):
|
def evaluate_agent(agent, env, num_episodes=10):
|
||||||
"""Evaluate the agent on test data"""
|
"""Evaluate the agent on test data"""
|
||||||
@ -2385,8 +2388,9 @@ async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
|
|||||||
async def main():
|
async def main():
|
||||||
"""Main function to run the trading bot"""
|
"""Main function to run the trading bot"""
|
||||||
parser = argparse.ArgumentParser(description='Crypto Trading Bot')
|
parser = argparse.ArgumentParser(description='Crypto Trading Bot')
|
||||||
parser.add_argument('--mode', type=str, default='train', choices=['train', 'evaluate', 'live'],
|
parser.add_argument('--mode', type=str, default='train',
|
||||||
help='Mode to run the bot in')
|
choices=['train', 'evaluate', 'live', 'continuous'],
|
||||||
|
help='Mode to run the bot in (train, evaluate, live, or continuous)')
|
||||||
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
|
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
|
||||||
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
|
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
|
||||||
parser.add_argument('--refresh-data', action='store_true', help='Refresh data during training')
|
parser.add_argument('--refresh-data', action='store_true', help='Refresh data during training')
|
||||||
@ -2398,56 +2402,104 @@ async def main():
|
|||||||
device = get_device(args.device)
|
device = get_device(args.device)
|
||||||
|
|
||||||
exchange = None
|
exchange = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize exchange
|
# Initialize exchange
|
||||||
exchange = await initialize_exchange()
|
exchange_id = 'mexc'
|
||||||
|
exchange_class = getattr(ccxt, exchange_id)
|
||||||
# Create environment
|
exchange = exchange_class({
|
||||||
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=args.demo)
|
'apiKey': MEXC_API_KEY,
|
||||||
|
'secret': MEXC_SECRET_KEY,
|
||||||
|
'enableRateLimit': True,
|
||||||
|
'options': {
|
||||||
|
'defaultType': 'future',
|
||||||
|
}
|
||||||
|
})
|
||||||
|
logger.info(f"Exchange initialized with standard CCXT: {exchange.id}")
|
||||||
|
|
||||||
# Fetch initial data
|
# Fetch initial data
|
||||||
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
|
logger.info("Fetching initial data for ETH/USDT")
|
||||||
|
data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500)
|
||||||
|
|
||||||
# Create agent
|
# Initialize environment
|
||||||
|
env = TradingEnvironment(
|
||||||
|
data=data,
|
||||||
|
symbol="ETH/USDT",
|
||||||
|
timeframe="1m",
|
||||||
|
leverage=MAX_LEVERAGE,
|
||||||
|
initial_balance=INITIAL_BALANCE,
|
||||||
|
is_demo=args.demo or args.mode != 'live'
|
||||||
|
)
|
||||||
|
logger.info(f"Initialized environment with {len(data)} candles")
|
||||||
|
|
||||||
|
# Initialize agent
|
||||||
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
|
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
|
||||||
|
|
||||||
if args.mode == 'train':
|
if args.mode == 'train':
|
||||||
# Train the agent
|
# Train the agent
|
||||||
logger.info(f"Starting training for {args.episodes} episodes...")
|
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange)
|
||||||
|
|
||||||
# Pass exchange to training function if refresh-data is enabled
|
elif args.mode == 'continuous':
|
||||||
if args.refresh_data:
|
# Run in continuous mode - train indefinitely
|
||||||
logger.info("Data refresh enabled during training")
|
logger.info("Starting continuous training mode. Press Ctrl+C to stop.")
|
||||||
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange)
|
episode_counter = 0
|
||||||
else:
|
try:
|
||||||
stats = await train_agent(agent, env, num_episodes=args.episodes)
|
while True: # Run indefinitely until manually stopped
|
||||||
|
# Train for a batch of episodes
|
||||||
|
batch_size = 50 # Train in batches of 50 episodes
|
||||||
|
logger.info(f"Starting training batch {episode_counter // batch_size + 1}")
|
||||||
|
|
||||||
|
# Refresh data at the start of each batch
|
||||||
|
if exchange:
|
||||||
|
logger.info("Refreshing data for new training batch")
|
||||||
|
new_data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500)
|
||||||
|
if new_data:
|
||||||
|
# Replace environment data with fresh data
|
||||||
|
env.data = new_data
|
||||||
|
env.reset()
|
||||||
|
logger.info(f"Updated environment with {len(new_data)} fresh candles")
|
||||||
|
|
||||||
|
# Train for a batch of episodes
|
||||||
|
stats = await train_agent(agent, env, num_episodes=batch_size, exchange=exchange)
|
||||||
|
|
||||||
|
# Save model after each batch
|
||||||
|
agent.save(f"models/trading_agent_continuous_{episode_counter}.pt")
|
||||||
|
|
||||||
|
# Increment counter
|
||||||
|
episode_counter += batch_size
|
||||||
|
|
||||||
|
# Sleep briefly to prevent excessive API calls
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Continuous training stopped by user")
|
||||||
|
# Save final model
|
||||||
|
agent.save("models/trading_agent_continuous_final.pt")
|
||||||
|
logger.info("Final model saved")
|
||||||
|
|
||||||
elif args.mode == 'evaluate':
|
elif args.mode == 'evaluate':
|
||||||
# Load trained model
|
# Load the best model
|
||||||
agent.load("models/trading_agent_best_pnl.pt")
|
agent.load("models/trading_agent_best_pnl.pt")
|
||||||
|
|
||||||
# Evaluate the agent
|
# Evaluate the agent
|
||||||
logger.info("Evaluating agent...")
|
results = evaluate_agent(agent, env, num_episodes=10)
|
||||||
avg_reward, avg_profit, win_rate = evaluate_agent(agent, env)
|
logger.info(f"Evaluation results: {results}")
|
||||||
|
|
||||||
elif args.mode == 'live':
|
elif args.mode == 'live':
|
||||||
# Load trained model
|
# Load the best model
|
||||||
agent.load("models/trading_agent_best_pnl.pt")
|
agent.load("models/trading_agent_best_pnl.pt")
|
||||||
|
|
||||||
# Run live trading
|
# Run live trading
|
||||||
logger.info("Starting live trading...")
|
logger.info("Starting live trading...")
|
||||||
await live_trading(agent, env, exchange, demo=args.demo)
|
await live_trading(agent, env, exchange)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error: {e}")
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
finally:
|
finally:
|
||||||
# Clean up exchange connection - safely close if possible
|
# Close exchange connection
|
||||||
if exchange:
|
if exchange:
|
||||||
try:
|
try:
|
||||||
# Some CCXT exchanges have close method, others don't
|
await exchange.client.close()
|
||||||
if hasattr(exchange, 'close'):
|
|
||||||
await exchange.close()
|
|
||||||
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
|
|
||||||
await exchange.client.close()
|
|
||||||
logger.info("Exchange connection closed")
|
logger.info("Exchange connection closed")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not properly close exchange connection: {e}")
|
logger.warning(f"Could not properly close exchange connection: {e}")
|
||||||
|
@ -6,4 +6,5 @@ python-dotenv>=0.19.0
|
|||||||
ccxt>=2.0.0
|
ccxt>=2.0.0
|
||||||
websockets>=10.0
|
websockets>=10.0
|
||||||
tensorboard>=2.6.0
|
tensorboard>=2.6.0
|
||||||
scikit-learn
|
scikit-learn
|
||||||
|
mplfinance
|
Loading…
x
Reference in New Issue
Block a user