init
This commit is contained in:
337
tests.py
Normal file
337
tests.py
Normal file
@ -0,0 +1,337 @@
|
||||
"""
|
||||
Unit tests for the trading bot.
|
||||
This file contains tests for various components of the trading bot, including:
|
||||
1. Periodic candle updates
|
||||
2. Backtesting on historical data
|
||||
3. Training on the last 7 days of data
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import datetime
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.StreamHandler()])
|
||||
|
||||
# Import functionality from main.py
|
||||
import main
|
||||
from main import (
|
||||
CandleCache, BacktestCandles, initialize_exchange,
|
||||
TradingEnvironment, Agent, train_with_backtesting,
|
||||
fetch_multi_timeframe_data, train_agent
|
||||
)
|
||||
|
||||
class TestPeriodicUpdates(unittest.TestCase):
|
||||
"""Test that candle data is periodically updated during training."""
|
||||
|
||||
async def async_test_periodic_updates(self):
|
||||
"""Test that candle data is periodically updated during training."""
|
||||
logging.info("Testing periodic candle updates...")
|
||||
|
||||
# Initialize exchange
|
||||
exchange = await initialize_exchange()
|
||||
self.assertIsNotNone(exchange, "Failed to initialize exchange")
|
||||
|
||||
# Create candle cache
|
||||
candle_cache = CandleCache()
|
||||
|
||||
# Initial fetch of candle data
|
||||
candle_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
|
||||
self.assertIsNotNone(candle_data, "Failed to fetch initial candle data")
|
||||
self.assertIn('1m', candle_data, "1m candles not found in initial data")
|
||||
|
||||
# Check initial data timestamps
|
||||
initial_1m_candles = candle_data['1m']
|
||||
self.assertGreater(len(initial_1m_candles), 0, "No 1m candles found in initial data")
|
||||
initial_timestamp = initial_1m_candles[-1][0]
|
||||
|
||||
# Wait for update interval to pass
|
||||
logging.info("Waiting for update interval to pass (5 seconds for testing)...")
|
||||
await asyncio.sleep(5) # Short wait for testing
|
||||
|
||||
# Force update by setting last_updated to None
|
||||
candle_cache.last_updated['1m'] = None
|
||||
|
||||
# Fetch updated data
|
||||
updated_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
|
||||
self.assertIsNotNone(updated_data, "Failed to fetch updated candle data")
|
||||
|
||||
# Check if data was updated
|
||||
updated_1m_candles = updated_data['1m']
|
||||
self.assertGreater(len(updated_1m_candles), 0, "No 1m candles found in updated data")
|
||||
updated_timestamp = updated_1m_candles[-1][0]
|
||||
|
||||
# In a live scenario, this check should pass with real-time updates
|
||||
# For testing, we just ensure data was fetched
|
||||
logging.info(f"Initial timestamp: {initial_timestamp}, Updated timestamp: {updated_timestamp}")
|
||||
self.assertIsNotNone(updated_timestamp, "Updated timestamp is None")
|
||||
|
||||
# Close exchange connection
|
||||
try:
|
||||
await exchange.close()
|
||||
except AttributeError:
|
||||
# Some exchanges don't have a close method
|
||||
pass
|
||||
logging.info("Periodic update test completed")
|
||||
|
||||
def test_periodic_updates(self):
|
||||
"""Run the async test."""
|
||||
asyncio.run(self.async_test_periodic_updates())
|
||||
|
||||
|
||||
class TestBacktesting(unittest.TestCase):
|
||||
"""Test backtesting on historical data."""
|
||||
|
||||
async def async_test_backtesting(self):
|
||||
"""Test backtesting on a specific time period."""
|
||||
logging.info("Testing backtesting with historical data...")
|
||||
|
||||
# Initialize exchange
|
||||
exchange = await initialize_exchange()
|
||||
self.assertIsNotNone(exchange, "Failed to initialize exchange")
|
||||
|
||||
# Create a timestamp for 24 hours ago
|
||||
now = datetime.datetime.now()
|
||||
yesterday = now - datetime.timedelta(days=1)
|
||||
since_timestamp = int(yesterday.timestamp() * 1000) # Convert to milliseconds
|
||||
|
||||
# Create a backtesting candle cache
|
||||
backtest_cache = BacktestCandles(since_timestamp=since_timestamp)
|
||||
backtest_cache.period_name = "1-day-ago"
|
||||
|
||||
# Fetch historical data
|
||||
candle_data = await backtest_cache.fetch_all_timeframes(exchange, "ETH/USDT")
|
||||
self.assertIsNotNone(candle_data, "Failed to fetch historical candle data")
|
||||
self.assertIn('1m', candle_data, "1m candles not found in historical data")
|
||||
|
||||
# Check historical data timestamps
|
||||
minute_candles = candle_data['1m']
|
||||
self.assertGreater(len(minute_candles), 0, "No minute candles found in historical data")
|
||||
|
||||
# Check if timestamps are within the requested range
|
||||
first_timestamp = minute_candles[0][0]
|
||||
last_timestamp = minute_candles[-1][0]
|
||||
|
||||
logging.info(f"Requested since: {since_timestamp}")
|
||||
logging.info(f"First timestamp in data: {first_timestamp}")
|
||||
logging.info(f"Last timestamp in data: {last_timestamp}")
|
||||
|
||||
# In real tests, this check should compare timestamps precisely
|
||||
# For this test, we just ensure data was fetched
|
||||
self.assertLessEqual(first_timestamp, last_timestamp, "First timestamp should be before last timestamp")
|
||||
|
||||
# Close exchange connection
|
||||
try:
|
||||
await exchange.close()
|
||||
except AttributeError:
|
||||
# Some exchanges don't have a close method
|
||||
pass
|
||||
logging.info("Backtesting fetch test completed")
|
||||
|
||||
def test_backtesting(self):
|
||||
"""Run the async test."""
|
||||
asyncio.run(self.async_test_backtesting())
|
||||
|
||||
|
||||
class TestBacktestingLastSevenDays(unittest.TestCase):
|
||||
"""Test backtesting on the last 7 days of data."""
|
||||
|
||||
async def async_test_seven_days_backtesting(self):
|
||||
"""Test backtesting on the last 7 days."""
|
||||
logging.info("Testing backtesting on the last 7 days...")
|
||||
|
||||
# Initialize exchange
|
||||
exchange = await initialize_exchange()
|
||||
self.assertIsNotNone(exchange, "Failed to initialize exchange")
|
||||
|
||||
# Create environment with small initial balance for testing
|
||||
env = TradingEnvironment(
|
||||
initial_balance=100, # Small balance for testing
|
||||
leverage=10, # Lower leverage for testing
|
||||
window_size=50, # Smaller window for faster testing
|
||||
commission=0.0004 # Standard commission
|
||||
)
|
||||
|
||||
# Create agent
|
||||
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
|
||||
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
|
||||
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)
|
||||
|
||||
# Initialize empty results dataframe
|
||||
all_results = pd.DataFrame()
|
||||
|
||||
# Run backtesting for the last 7 days, one day at a time
|
||||
now = datetime.datetime.now()
|
||||
|
||||
for day_offset in range(1, 8):
|
||||
# Calculate time period
|
||||
end_day = now - datetime.timedelta(days=day_offset-1)
|
||||
start_day = end_day - datetime.timedelta(days=1)
|
||||
|
||||
# Convert to milliseconds
|
||||
since_timestamp = int(start_day.timestamp() * 1000)
|
||||
until_timestamp = int(end_day.timestamp() * 1000)
|
||||
|
||||
# Period name
|
||||
period_name = f"Day-{day_offset}"
|
||||
|
||||
logging.info(f"Testing backtesting for period: {period_name}")
|
||||
logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
# Run backtesting with a small number of episodes for testing
|
||||
stats = await train_with_backtesting(
|
||||
agent=agent,
|
||||
env=env,
|
||||
symbol="ETH/USDT",
|
||||
since_timestamp=since_timestamp,
|
||||
until_timestamp=until_timestamp,
|
||||
num_episodes=3, # Use a small number for testing
|
||||
max_steps_per_episode=200, # Use a small number for testing
|
||||
period_name=period_name
|
||||
)
|
||||
|
||||
# Check if stats were returned
|
||||
if stats is None:
|
||||
logging.warning(f"No stats returned for period: {period_name}")
|
||||
continue
|
||||
|
||||
# Create a dataframe from stats
|
||||
if len(stats['episode_rewards']) > 0:
|
||||
df = pd.DataFrame({
|
||||
'Period': [period_name] * len(stats['episode_rewards']),
|
||||
'Episode': list(range(1, len(stats['episode_rewards']) + 1)),
|
||||
'Reward': stats['episode_rewards'],
|
||||
'Balance': stats['balances'],
|
||||
'PnL': stats['episode_pnls'],
|
||||
'Fees': stats['fees'],
|
||||
'Net_PnL': stats['net_pnl_after_fees']
|
||||
})
|
||||
|
||||
# Append to all results
|
||||
all_results = pd.concat([all_results, df], ignore_index=True)
|
||||
|
||||
logging.info(f"Completed backtesting for period: {period_name}")
|
||||
logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
|
||||
logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
|
||||
logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
|
||||
else:
|
||||
logging.warning(f"No episodes completed for period: {period_name}")
|
||||
|
||||
# Save all results
|
||||
if not all_results.empty:
|
||||
all_results.to_csv("all_backtest_results.csv", index=False)
|
||||
logging.info("Saved all backtest results to all_backtest_results.csv")
|
||||
|
||||
# Create plot of results
|
||||
plt.figure(figsize=(12, 8))
|
||||
|
||||
# Plot Net PnL by period
|
||||
all_results.groupby('Period')['Net_PnL'].last().plot(kind='bar')
|
||||
plt.title('Net PnL by Training Period (Last Episode)')
|
||||
plt.ylabel('Net PnL ($)')
|
||||
plt.tight_layout()
|
||||
plt.savefig("backtest_results.png")
|
||||
logging.info("Saved backtest results plot to backtest_results.png")
|
||||
|
||||
# Close exchange connection
|
||||
try:
|
||||
await exchange.close()
|
||||
except AttributeError:
|
||||
# Some exchanges don't have a close method
|
||||
pass
|
||||
logging.info("7-day backtesting test completed")
|
||||
|
||||
def test_seven_days_backtesting(self):
|
||||
"""Run the async test."""
|
||||
asyncio.run(self.async_test_seven_days_backtesting())
|
||||
|
||||
|
||||
class TestSingleDayBacktesting(unittest.TestCase):
|
||||
"""Test backtesting on a single day of historical data."""
|
||||
|
||||
async def async_test_single_day_backtesting(self):
|
||||
"""Test backtesting on a single day."""
|
||||
logging.info("Testing backtesting on a single day...")
|
||||
|
||||
# Initialize exchange
|
||||
exchange = await initialize_exchange()
|
||||
self.assertIsNotNone(exchange, "Failed to initialize exchange")
|
||||
|
||||
# Create environment with small initial balance for testing
|
||||
env = TradingEnvironment(
|
||||
initial_balance=100, # Small balance for testing
|
||||
leverage=10, # Lower leverage for testing
|
||||
window_size=50, # Smaller window for faster testing
|
||||
commission=0.0004 # Standard commission
|
||||
)
|
||||
|
||||
# Create agent
|
||||
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
|
||||
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
|
||||
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)
|
||||
|
||||
# Calculate time period for 1 day ago
|
||||
now = datetime.datetime.now()
|
||||
end_day = now
|
||||
start_day = end_day - datetime.timedelta(days=1)
|
||||
|
||||
# Convert to milliseconds
|
||||
since_timestamp = int(start_day.timestamp() * 1000)
|
||||
until_timestamp = int(end_day.timestamp() * 1000)
|
||||
|
||||
# Period name
|
||||
period_name = "Test-Day-1"
|
||||
|
||||
logging.info(f"Testing backtesting for period: {period_name}")
|
||||
logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
# Run backtesting with a small number of episodes for testing
|
||||
stats = await train_with_backtesting(
|
||||
agent=agent,
|
||||
env=env,
|
||||
symbol="ETH/USDT",
|
||||
since_timestamp=since_timestamp,
|
||||
until_timestamp=until_timestamp,
|
||||
num_episodes=2, # Very small number for quick testing
|
||||
max_steps_per_episode=100, # Very small number for quick testing
|
||||
period_name=period_name
|
||||
)
|
||||
|
||||
# Check if stats were returned
|
||||
self.assertIsNotNone(stats, "No stats returned from backtesting")
|
||||
|
||||
# Check if episodes were completed
|
||||
self.assertGreater(len(stats['episode_rewards']), 0, "No episodes completed")
|
||||
|
||||
# Log results
|
||||
logging.info(f"Completed backtesting for period: {period_name}")
|
||||
logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
|
||||
logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
|
||||
logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
|
||||
|
||||
# Close exchange connection
|
||||
try:
|
||||
await exchange.close()
|
||||
except AttributeError:
|
||||
# Some exchanges don't have a close method
|
||||
pass
|
||||
logging.info("Single day backtesting test completed")
|
||||
|
||||
def test_single_day_backtesting(self):
|
||||
"""Run the async test."""
|
||||
asyncio.run(self.async_test_single_day_backtesting())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Reference in New Issue
Block a user