# gogo2/tests.py
# Dobromir Popov 3871afd4b8 init
# 2025-03-18 09:23:09 +02:00
#
# 337 lines
# 14 KiB
# Python

"""
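Unit tests for the trading bot.
This file contains tests for various components of the trading bot, including:
1. Periodic candle updates
2. Fetching historical candle data for backtesting
3. Backtest training over the last 7 days, one day at a time
4. Backtest training on a single day of historical data

These tests presumably need network access: each one obtains an exchange from
``main.initialize_exchange`` and fetches ETH/USDT data through it.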
"""
import unittest
import asyncio
import os
import sys
import logging
import datetime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
# Configure root logging at import time: INFO and above, timestamped,
# written by a StreamHandler (stderr by default).
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler()],
)
# Import functionality from main.py
import main
from main import (
CandleCache, BacktestCandles, initialize_exchange,
TradingEnvironment, Agent, train_with_backtesting,
fetch_multi_timeframe_data, train_agent
)
class TestPeriodicUpdates(unittest.TestCase):
    """Test that candle data is periodically updated during training."""

    @staticmethod
    async def _close_exchange(exchange):
        """Close *exchange*, tolerating exchange objects without a ``close`` method."""
        try:
            await exchange.close()
        except AttributeError:
            # Some exchanges don't have a close method
            pass

    async def async_test_periodic_updates(self):
        """Fetch candles twice through one ``CandleCache`` and check the refetch.

        The second fetch is forced by clearing the cache's ``'1m'`` entry in
        ``last_updated``. The latest 1m candle after the refetch must not be
        older than the one seen before it. The exchange connection is closed
        in a ``finally`` so a failing assertion does not leak it.
        """
        logging.info("Testing periodic candle updates...")
        # Initialize exchange
        exchange = await initialize_exchange()
        self.assertIsNotNone(exchange, "Failed to initialize exchange")
        try:
            candle_cache = CandleCache()

            # Initial fetch of candle data
            candle_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
            self.assertIsNotNone(candle_data, "Failed to fetch initial candle data")
            self.assertIn('1m', candle_data, "1m candles not found in initial data")

            initial_1m_candles = candle_data['1m']
            self.assertGreater(len(initial_1m_candles), 0, "No 1m candles found in initial data")
            # assumes element 0 of each candle is its timestamp — TODO confirm
            initial_timestamp = initial_1m_candles[-1][0]

            # Wait for update interval to pass
            logging.info("Waiting for update interval to pass (5 seconds for testing)...")
            await asyncio.sleep(5)  # Short wait for testing

            # Force update by setting last_updated to None
            candle_cache.last_updated['1m'] = None

            updated_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
            self.assertIsNotNone(updated_data, "Failed to fetch updated candle data")
            self.assertIn('1m', updated_data, "1m candles not found in updated data")

            updated_1m_candles = updated_data['1m']
            self.assertGreater(len(updated_1m_candles), 0, "No 1m candles found in updated data")
            updated_timestamp = updated_1m_candles[-1][0]

            logging.info(f"Initial timestamp: {initial_timestamp}, Updated timestamp: {updated_timestamp}")
            self.assertIsNotNone(updated_timestamp, "Updated timestamp is None")
            # A strict '>' could flake if no new candle closed during the wait,
            # but a refetch should never move the latest candle backwards.
            self.assertGreaterEqual(
                updated_timestamp, initial_timestamp,
                "Updated candle data is older than the initial data",
            )
        finally:
            await self._close_exchange(exchange)
        logging.info("Periodic update test completed")

    def test_periodic_updates(self):
        """Run the async test."""
        asyncio.run(self.async_test_periodic_updates())
class TestBacktesting(unittest.TestCase):
    """Test backtesting on historical data."""

    @staticmethod
    async def _close_exchange(exchange):
        """Close *exchange*, tolerating exchange objects without a ``close`` method."""
        try:
            await exchange.close()
        except AttributeError:
            # Some exchanges don't have a close method
            pass

    async def async_test_backtesting(self):
        """Fetch about one day of historical candles through ``BacktestCandles``.

        Checks that 1m candles are returned and that the first candle's
        timestamp is not after the last one. The exchange connection is
        closed in a ``finally`` so a failing assertion does not leak it.
        """
        logging.info("Testing backtesting with historical data...")
        # Initialize exchange
        exchange = await initialize_exchange()
        self.assertIsNotNone(exchange, "Failed to initialize exchange")
        try:
            # Start of the window: one day before now, in epoch milliseconds.
            yesterday = datetime.datetime.now() - datetime.timedelta(days=1)
            since_timestamp = int(yesterday.timestamp() * 1000)

            backtest_cache = BacktestCandles(since_timestamp=since_timestamp)
            backtest_cache.period_name = "1-day-ago"

            # Fetch historical data
            candle_data = await backtest_cache.fetch_all_timeframes(exchange, "ETH/USDT")
            self.assertIsNotNone(candle_data, "Failed to fetch historical candle data")
            self.assertIn('1m', candle_data, "1m candles not found in historical data")

            minute_candles = candle_data['1m']
            self.assertGreater(len(minute_candles), 0, "No minute candles found in historical data")

            # assumes element 0 of each candle is its timestamp in ms — TODO confirm
            first_timestamp = minute_candles[0][0]
            last_timestamp = minute_candles[-1][0]
            logging.info(f"Requested since: {since_timestamp}")
            logging.info(f"First timestamp in data: {first_timestamp}")
            logging.info(f"Last timestamp in data: {last_timestamp}")

            # Only ordering is checked here; the data is not compared against
            # since_timestamp because the exchange may align the first candle.
            self.assertLessEqual(first_timestamp, last_timestamp, "First timestamp should be before last timestamp")
        finally:
            await self._close_exchange(exchange)
        logging.info("Backtesting fetch test completed")

    def test_backtesting(self):
        """Run the async test."""
        asyncio.run(self.async_test_backtesting())
class TestBacktestingLastSevenDays(unittest.TestCase):
    """Test backtesting on the last 7 days of data."""

    @staticmethod
    async def _close_exchange(exchange):
        """Close *exchange*, tolerating exchange objects without a ``close`` method."""
        try:
            await exchange.close()
        except AttributeError:
            # Some exchanges don't have a close method
            pass

    @staticmethod
    def _stats_to_frame(stats, period_name):
        """Build a per-episode DataFrame from one period's ``train_with_backtesting`` stats."""
        num_episodes = len(stats['episode_rewards'])
        return pd.DataFrame({
            'Period': [period_name] * num_episodes,
            'Episode': list(range(1, num_episodes + 1)),
            'Reward': stats['episode_rewards'],
            'Balance': stats['balances'],
            'PnL': stats['episode_pnls'],
            'Fees': stats['fees'],
            'Net_PnL': stats['net_pnl_after_fees'],
        })

    @staticmethod
    def _save_results(all_results):
        """Write *all_results* to CSV and save a bar chart of each period's last-episode Net PnL."""
        all_results.to_csv("all_backtest_results.csv", index=False)
        logging.info("Saved all backtest results to all_backtest_results.csv")

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            # Plot Net PnL by period (value from the last episode of each period)
            all_results.groupby('Period')['Net_PnL'].last().plot(kind='bar', ax=ax)
            ax.set_title('Net PnL by Training Period (Last Episode)')
            ax.set_ylabel('Net PnL ($)')
            fig.tight_layout()
            fig.savefig("backtest_results.png")
        finally:
            # Release the figure so pyplot does not keep it open after the test.
            plt.close(fig)
        logging.info("Saved backtest results plot to backtest_results.png")

    async def async_test_seven_days_backtesting(self):
        """Run ``train_with_backtesting`` over 7 consecutive one-day windows.

        Day-1 is the day ending now, Day-2 the day before it, and so on. For
        each period that yields episodes, one per-episode row set is
        collected. The combined results are saved to
        ``all_backtest_results.csv`` and plotted to ``backtest_results.png``.
        The exchange connection is closed in a ``finally``.
        """
        logging.info("Testing backtesting on the last 7 days...")
        # Initialize exchange
        exchange = await initialize_exchange()
        self.assertIsNotNone(exchange, "Failed to initialize exchange")
        try:
            env = TradingEnvironment(
                initial_balance=100,  # Small balance for testing
                leverage=10,  # Lower leverage for testing
                window_size=50,  # Smaller window for faster testing
                commission=0.0004,  # Standard commission
            )
            # Fall back to fixed sizes when the environment does not expose them.
            state_size = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
            action_size = env.action_space.n if hasattr(env.action_space, 'n') else 4
            agent = Agent(state_size=state_size, action_size=action_size)

            # One DataFrame per period, concatenated once after the loop.
            period_frames = []
            now = datetime.datetime.now()
            for day_offset in range(1, 8):
                end_day = now - datetime.timedelta(days=day_offset - 1)
                start_day = end_day - datetime.timedelta(days=1)
                since_timestamp = int(start_day.timestamp() * 1000)
                until_timestamp = int(end_day.timestamp() * 1000)

                period_name = f"Day-{day_offset}"
                logging.info(f"Testing backtesting for period: {period_name}")
                logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
                logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")

                # The same agent and env objects are passed for every period
                # (presumably so learning carries over between periods).
                stats = await train_with_backtesting(
                    agent=agent,
                    env=env,
                    symbol="ETH/USDT",
                    since_timestamp=since_timestamp,
                    until_timestamp=until_timestamp,
                    num_episodes=3,  # Use a small number for testing
                    max_steps_per_episode=200,  # Use a small number for testing
                    period_name=period_name,
                )

                if stats is None:
                    logging.warning(f"No stats returned for period: {period_name}")
                    continue
                # len() rather than truthiness: the stats values may be arrays.
                if len(stats['episode_rewards']) == 0:
                    logging.warning(f"No episodes completed for period: {period_name}")
                    continue

                period_frames.append(self._stats_to_frame(stats, period_name))
                logging.info(f"Completed backtesting for period: {period_name}")
                logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
                logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
                logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")

            if period_frames:
                self._save_results(pd.concat(period_frames, ignore_index=True))
        finally:
            await self._close_exchange(exchange)
        logging.info("7-day backtesting test completed")

    def test_seven_days_backtesting(self):
        """Run the async test."""
        asyncio.run(self.async_test_seven_days_backtesting())
class TestSingleDayBacktesting(unittest.TestCase):
    """Test backtesting on a single day of historical data."""

    @staticmethod
    async def _close_exchange(exchange):
        """Close *exchange*, tolerating exchange objects without a ``close`` method."""
        try:
            await exchange.close()
        except AttributeError:
            # Some exchanges don't have a close method
            pass

    async def async_test_single_day_backtesting(self):
        """Run ``train_with_backtesting`` on the one-day window ending now.

        Asserts that stats are returned and that at least one episode was
        completed. The exchange connection is closed in a ``finally`` so a
        failing assertion does not leak it.
        """
        logging.info("Testing backtesting on a single day...")
        # Initialize exchange
        exchange = await initialize_exchange()
        self.assertIsNotNone(exchange, "Failed to initialize exchange")
        try:
            env = TradingEnvironment(
                initial_balance=100,  # Small balance for testing
                leverage=10,  # Lower leverage for testing
                window_size=50,  # Smaller window for faster testing
                commission=0.0004,  # Standard commission
            )
            # Fall back to fixed sizes when the environment does not expose them.
            state_size = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
            action_size = env.action_space.n if hasattr(env.action_space, 'n') else 4
            agent = Agent(state_size=state_size, action_size=action_size)

            # Window: one day ending now, in epoch milliseconds.
            end_day = datetime.datetime.now()
            start_day = end_day - datetime.timedelta(days=1)
            since_timestamp = int(start_day.timestamp() * 1000)
            until_timestamp = int(end_day.timestamp() * 1000)

            period_name = "Test-Day-1"
            logging.info(f"Testing backtesting for period: {period_name}")
            logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
            logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")

            stats = await train_with_backtesting(
                agent=agent,
                env=env,
                symbol="ETH/USDT",
                since_timestamp=since_timestamp,
                until_timestamp=until_timestamp,
                num_episodes=2,  # Very small number for quick testing
                max_steps_per_episode=100,  # Very small number for quick testing
                period_name=period_name,
            )

            self.assertIsNotNone(stats, "No stats returned from backtesting")
            self.assertGreater(len(stats['episode_rewards']), 0, "No episodes completed")

            logging.info(f"Completed backtesting for period: {period_name}")
            logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
            logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
            logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
        finally:
            await self._close_exchange(exchange)
        logging.info("Single day backtesting test completed")

    def test_single_day_backtesting(self):
        """Run the async test."""
        asyncio.run(self.async_test_single_day_backtesting())
if __name__ == "__main__":
    # Script entry point: discover and run every TestCase in this module.
    unittest.main()