"""
Unit tests for the trading bot.

This file contains tests for various components of the trading bot, including:

1. Periodic candle updates
2. Backtesting on historical data
3. Training on the last 7 days of data
4. Backtesting on a single day of historical data
"""
|
|
|
|
import unittest
|
|
import asyncio
|
|
import os
|
|
import sys
|
|
import logging
|
|
import datetime
|
|
import pandas as pd
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from pathlib import Path
|
|
|
|
# Route INFO-and-above records to a stream handler, stamped with time and level.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
|
|
|
# Import functionality from main.py
|
|
import main
|
|
from main import (
|
|
CandleCache, BacktestCandles, initialize_exchange,
|
|
TradingEnvironment, Agent, train_with_backtesting,
|
|
fetch_multi_timeframe_data, train_agent
|
|
)
|
|
|
|
class TestPeriodicUpdates(unittest.TestCase):
    """Test that candle data is periodically updated during training."""

    async def async_test_periodic_updates(self):
        """Fetch 1m candles twice through one cache, forcing a refresh in between.

        The coroutine initializes an exchange, fetches multi-timeframe data for
        "ETH/USDT", waits five seconds, resets the cache's ``last_updated['1m']``
        entry and fetches again. It asserts that both fetches return non-empty
        '1m' candle lists and that the latest candle of the second fetch has a
        non-None first element (used here as its timestamp).

        The exchange is closed in a ``finally`` clause so a failing assertion or
        fetch does not leave the connection open.
        """
        logging.info("Testing periodic candle updates...")

        # Initialize exchange
        exchange = await initialize_exchange()
        self.assertIsNotNone(exchange, "Failed to initialize exchange")

        try:
            # Create candle cache
            candle_cache = CandleCache()

            # Initial fetch of candle data
            candle_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
            self.assertIsNotNone(candle_data, "Failed to fetch initial candle data")
            self.assertIn('1m', candle_data, "1m candles not found in initial data")

            # Check initial data timestamps
            initial_1m_candles = candle_data['1m']
            self.assertGreater(len(initial_1m_candles), 0, "No 1m candles found in initial data")
            initial_timestamp = initial_1m_candles[-1][0]

            # Wait for update interval to pass
            logging.info("Waiting for update interval to pass (5 seconds for testing)...")
            await asyncio.sleep(5)  # Short wait for testing

            # Force update by setting last_updated to None
            # NOTE(review): presumably the cache treats None as "stale" - confirm in main.py
            candle_cache.last_updated['1m'] = None

            # Fetch updated data
            updated_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
            self.assertIsNotNone(updated_data, "Failed to fetch updated candle data")

            # Check if data was updated
            updated_1m_candles = updated_data['1m']
            self.assertGreater(len(updated_1m_candles), 0, "No 1m candles found in updated data")
            updated_timestamp = updated_1m_candles[-1][0]

            # In a live scenario, this check should pass with real-time updates
            # For testing, we just ensure data was fetched
            logging.info(f"Initial timestamp: {initial_timestamp}, Updated timestamp: {updated_timestamp}")
            self.assertIsNotNone(updated_timestamp, "Updated timestamp is None")
        finally:
            # Close exchange connection even when an assertion above failed
            try:
                await exchange.close()
            except AttributeError:
                # Some exchanges don't have a close method
                pass

        logging.info("Periodic update test completed")

    def test_periodic_updates(self):
        """Run the async test."""
        asyncio.run(self.async_test_periodic_updates())
|
|
|
|
|
|
class TestBacktesting(unittest.TestCase):
    """Test backtesting on historical data."""

    async def async_test_backtesting(self):
        """Fetch historical candles starting 24 hours ago and sanity-check them.

        Builds a ``BacktestCandles`` cache with a millisecond ``since_timestamp``
        one day before now, fetches all timeframes for "ETH/USDT" and asserts
        that the '1m' series is non-empty and that the first element of its
        first candle is not greater than that of its last candle.

        The exchange is closed in a ``finally`` clause so a failing assertion or
        fetch does not leave the connection open.
        """
        logging.info("Testing backtesting with historical data...")

        # Initialize exchange
        exchange = await initialize_exchange()
        self.assertIsNotNone(exchange, "Failed to initialize exchange")

        try:
            # Create a timestamp for 24 hours ago
            now = datetime.datetime.now()
            yesterday = now - datetime.timedelta(days=1)
            since_timestamp = int(yesterday.timestamp() * 1000)  # Convert to milliseconds

            # Create a backtesting candle cache
            backtest_cache = BacktestCandles(since_timestamp=since_timestamp)
            backtest_cache.period_name = "1-day-ago"

            # Fetch historical data
            candle_data = await backtest_cache.fetch_all_timeframes(exchange, "ETH/USDT")
            self.assertIsNotNone(candle_data, "Failed to fetch historical candle data")
            self.assertIn('1m', candle_data, "1m candles not found in historical data")

            # Check historical data timestamps
            minute_candles = candle_data['1m']
            self.assertGreater(len(minute_candles), 0, "No minute candles found in historical data")

            # Check if timestamps are within the requested range
            first_timestamp = minute_candles[0][0]
            last_timestamp = minute_candles[-1][0]

            logging.info(f"Requested since: {since_timestamp}")
            logging.info(f"First timestamp in data: {first_timestamp}")
            logging.info(f"Last timestamp in data: {last_timestamp}")

            # In real tests, this check should compare timestamps precisely
            # For this test, we just ensure data was fetched
            self.assertLessEqual(first_timestamp, last_timestamp, "First timestamp should be before last timestamp")
        finally:
            # Close exchange connection even when an assertion above failed
            try:
                await exchange.close()
            except AttributeError:
                # Some exchanges don't have a close method
                pass

        logging.info("Backtesting fetch test completed")

    def test_backtesting(self):
        """Run the async test."""
        asyncio.run(self.async_test_backtesting())
|
|
|
|
|
|
class TestBacktestingLastSevenDays(unittest.TestCase):
    """Test backtesting on the last 7 days of data."""

    @staticmethod
    def _stats_to_frame(period_name, stats):
        """Return a DataFrame with one row per episode recorded in ``stats``."""
        episode_count = len(stats['episode_rewards'])
        return pd.DataFrame({
            'Period': [period_name] * episode_count,
            'Episode': list(range(1, episode_count + 1)),
            'Reward': stats['episode_rewards'],
            'Balance': stats['balances'],
            'PnL': stats['episode_pnls'],
            'Fees': stats['fees'],
            'Net_PnL': stats['net_pnl_after_fees'],
        })

    @staticmethod
    def _save_results(all_results):
        """Write the combined results to CSV and save a bar chart of Net PnL per period."""
        all_results.to_csv("all_backtest_results.csv", index=False)
        logging.info("Saved all backtest results to all_backtest_results.csv")

        # Create plot of results
        figure = plt.figure(figsize=(12, 8))
        try:
            # Plot Net PnL by period
            all_results.groupby('Period')['Net_PnL'].last().plot(kind='bar')
            plt.title('Net PnL by Training Period (Last Episode)')
            plt.ylabel('Net PnL ($)')
            plt.tight_layout()
            plt.savefig("backtest_results.png")
        finally:
            # Release the figure so repeated runs do not accumulate open figures
            plt.close(figure)
        logging.info("Saved backtest results plot to backtest_results.png")

    async def async_test_seven_days_backtesting(self):
        """Run ``train_with_backtesting`` once per day for the previous 7 days.

        One environment and one agent are shared across all seven one-day
        windows. Periods for which no stats (or no episodes) come back are
        logged and skipped. The per-episode results of the remaining periods
        are combined, written to ``all_backtest_results.csv`` and plotted to
        ``backtest_results.png``.

        The exchange is closed in a ``finally`` clause so an error in any
        period does not leave the connection open.
        """
        logging.info("Testing backtesting on the last 7 days...")

        # Initialize exchange
        exchange = await initialize_exchange()
        self.assertIsNotNone(exchange, "Failed to initialize exchange")

        try:
            # Create environment with small initial balance for testing
            env = TradingEnvironment(
                initial_balance=100,  # Small balance for testing
                leverage=10,  # Lower leverage for testing
                window_size=50,  # Smaller window for faster testing
                commission=0.0004  # Standard commission
            )

            # Create agent
            STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
            ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
            agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)

            # Per-period frames are collected here and concatenated once at the
            # end, instead of re-copying a growing DataFrame on every iteration.
            period_frames = []

            # Run backtesting for the last 7 days, one day at a time
            now = datetime.datetime.now()

            for day_offset in range(1, 8):
                # Calculate time period
                end_day = now - datetime.timedelta(days=day_offset - 1)
                start_day = end_day - datetime.timedelta(days=1)

                # Convert to milliseconds
                since_timestamp = int(start_day.timestamp() * 1000)
                until_timestamp = int(end_day.timestamp() * 1000)

                # Period name
                period_name = f"Day-{day_offset}"

                logging.info(f"Testing backtesting for period: {period_name}")
                logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
                logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")

                # Run backtesting with a small number of episodes for testing
                stats = await train_with_backtesting(
                    agent=agent,
                    env=env,
                    symbol="ETH/USDT",
                    since_timestamp=since_timestamp,
                    until_timestamp=until_timestamp,
                    num_episodes=3,  # Use a small number for testing
                    max_steps_per_episode=200,  # Use a small number for testing
                    period_name=period_name
                )

                # Check if stats were returned
                if stats is None:
                    logging.warning(f"No stats returned for period: {period_name}")
                    continue

                if len(stats['episode_rewards']) == 0:
                    logging.warning(f"No episodes completed for period: {period_name}")
                    continue

                period_frames.append(self._stats_to_frame(period_name, stats))

                logging.info(f"Completed backtesting for period: {period_name}")
                logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
                logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
                logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")

            # Save all results
            if period_frames:
                all_results = pd.concat(period_frames, ignore_index=True)
                if not all_results.empty:
                    self._save_results(all_results)
        finally:
            # Close exchange connection even when a period above failed
            try:
                await exchange.close()
            except AttributeError:
                # Some exchanges don't have a close method
                pass

        logging.info("7-day backtesting test completed")

    def test_seven_days_backtesting(self):
        """Run the async test."""
        asyncio.run(self.async_test_seven_days_backtesting())
|
|
|
|
|
|
class TestSingleDayBacktesting(unittest.TestCase):
    """Test backtesting on a single day of historical data."""

    async def async_test_single_day_backtesting(self):
        """Run ``train_with_backtesting`` over the last 24 hours and check its stats.

        Creates a small test environment and agent, trains for two episodes of
        at most 100 steps on "ETH/USDT" between one day ago and now, and asserts
        that stats were returned and contain at least one episode reward.

        The exchange is closed in a ``finally`` clause so a failing assertion or
        training error does not leave the connection open.
        """
        logging.info("Testing backtesting on a single day...")

        # Initialize exchange
        exchange = await initialize_exchange()
        self.assertIsNotNone(exchange, "Failed to initialize exchange")

        try:
            # Create environment with small initial balance for testing
            env = TradingEnvironment(
                initial_balance=100,  # Small balance for testing
                leverage=10,  # Lower leverage for testing
                window_size=50,  # Smaller window for faster testing
                commission=0.0004  # Standard commission
            )

            # Create agent
            STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
            ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
            agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)

            # Calculate time period for 1 day ago
            now = datetime.datetime.now()
            end_day = now
            start_day = end_day - datetime.timedelta(days=1)

            # Convert to milliseconds
            since_timestamp = int(start_day.timestamp() * 1000)
            until_timestamp = int(end_day.timestamp() * 1000)

            # Period name
            period_name = "Test-Day-1"

            logging.info(f"Testing backtesting for period: {period_name}")
            logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
            logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")

            # Run backtesting with a small number of episodes for testing
            stats = await train_with_backtesting(
                agent=agent,
                env=env,
                symbol="ETH/USDT",
                since_timestamp=since_timestamp,
                until_timestamp=until_timestamp,
                num_episodes=2,  # Very small number for quick testing
                max_steps_per_episode=100,  # Very small number for quick testing
                period_name=period_name
            )

            # Check if stats were returned
            self.assertIsNotNone(stats, "No stats returned from backtesting")

            # Check if episodes were completed
            self.assertGreater(len(stats['episode_rewards']), 0, "No episodes completed")

            # Log results
            logging.info(f"Completed backtesting for period: {period_name}")
            logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
            logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
            logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
        finally:
            # Close exchange connection even when an assertion above failed
            try:
                await exchange.close()
            except AttributeError:
                # Some exchanges don't have a close method
                pass

        logging.info("Single day backtesting test completed")

    def test_single_day_backtesting(self):
        """Run the async test."""
        asyncio.run(self.async_test_single_day_backtesting())
|
|
|
|
|
|
# Script entry point: hand control to the unittest runner when executed directly.
if __name__ == '__main__':
    unittest.main()