backtseting support

This commit is contained in:
Dobromir Popov 2025-03-17 19:21:43 +02:00
parent 5e9e6360af
commit 2e7a242ac7
5 changed files with 1581 additions and 612 deletions

File diff suppressed because it is too large Load Diff

34
crypto/gogo2/run_demo.py Normal file
View File

@ -0,0 +1,34 @@
#!/usr/bin/env python
import asyncio
import logging
from main import live_trading, setup_logging
# Set up logging
setup_logging()
logger = logging.getLogger(__name__)
async def main():
"""Run a simplified demo trading session with mock data"""
logger.info("Starting simplified demo trading session")
# Run live trading in demo mode with simplified parameters
await live_trading(
symbol="ETH/USDT",
timeframe="1m",
model_path="models/trading_agent_best_pnl.pt",
demo=True,
initial_balance=1000,
update_interval=10, # Update every 10 seconds for faster feedback
max_position_size=0.1,
risk_per_trade=0.02,
stop_loss_pct=0.02,
take_profit_pct=0.04,
)
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Demo trading stopped by user")
except Exception as e:
logger.error(f"Error in demo trading: {e}")

View File

@ -1,477 +1,40 @@
import os #!/usr/bin/env python
import sys
import asyncio import asyncio
import logging
import argparse import argparse
import numpy as np import logging
import pandas as pd from main import live_trading, setup_logging
import random
import datetime
import torch
import matplotlib.pyplot as plt
import io
from PIL import Image
from dotenv import load_dotenv
# Configure logging # Set up logging
logging.basicConfig( setup_logging()
level=logging.INFO, logger = logging.getLogger(__name__)
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("live_trading.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger("live_trading")
def generate_mock_data(symbol, timeframe, limit=1000): async def main():
"""Generate mock OHLCV data for demo mode""" parser = argparse.ArgumentParser(description='Run live trading in demo mode')
logger.info(f"Generating mock data for {symbol} ({timeframe})")
# Set seed for reproducibility
np.random.seed(42)
# Generate timestamps
end_time = datetime.datetime.now()
start_time = end_time - datetime.timedelta(minutes=limit)
timestamps = [start_time + datetime.timedelta(minutes=i) for i in range(limit)]
# Convert to milliseconds
timestamps_ms = [int(ts.timestamp() * 1000) for ts in timestamps]
# Generate price data with realistic patterns
base_price = 3000.0 # Starting price
price_data = []
current_price = base_price
for i in range(limit):
# Random walk with momentum and volatility clusters
momentum = np.random.normal(0, 1)
volatility = 0.5 + 0.5 * np.sin(i / 100) # Cyclical volatility
# Price change with momentum and volatility
price_change = momentum * volatility * current_price * 0.005
# Add some trends and patterns
if i % 200 < 100: # Uptrend for 100 candles, then downtrend
price_change += current_price * 0.001
else:
price_change -= current_price * 0.0008
# Update current price
current_price += price_change
# Generate OHLCV data
open_price = current_price
close_price = current_price + np.random.normal(0, 1) * current_price * 0.002
high_price = max(open_price, close_price) + abs(np.random.normal(0, 1)) * current_price * 0.003
low_price = min(open_price, close_price) - abs(np.random.normal(0, 1)) * current_price * 0.003
volume = np.random.gamma(2, 100) * (1 + 0.5 * np.sin(i / 50)) # Cyclical volume
# Store data
price_data.append({
'timestamp': timestamps_ms[i],
'open': open_price,
'high': high_price,
'low': low_price,
'close': close_price,
'volume': volume
})
logger.info(f"Generated {len(price_data)} mock candles")
return price_data
async def generate_mock_live_candles(initial_data, symbol, timeframe):
"""Generate mock live candles based on initial data"""
last_candle = initial_data[-1].copy()
last_timestamp = last_candle['timestamp']
while True:
# Wait for next candle
await asyncio.sleep(5)
# Update timestamp
if timeframe == '1m':
last_timestamp += 60 * 1000 # 1 minute in milliseconds
elif timeframe == '5m':
last_timestamp += 5 * 60 * 1000
elif timeframe == '15m':
last_timestamp += 15 * 60 * 1000
elif timeframe == '1h':
last_timestamp += 60 * 60 * 1000
else:
last_timestamp += 60 * 1000 # Default to 1 minute
# Generate new candle
last_price = last_candle['close']
price_change = np.random.normal(0, 1) * last_price * 0.002
# Add some persistence
if last_candle['close'] > last_candle['open']:
# Previous candle was green, more likely to continue up
price_change += last_price * 0.0005
else:
# Previous candle was red, more likely to continue down
price_change -= last_price * 0.0005
# Generate OHLCV data
open_price = last_price
close_price = last_price + price_change
high_price = max(open_price, close_price) + abs(np.random.normal(0, 1)) * last_price * 0.001
low_price = min(open_price, close_price) - abs(np.random.normal(0, 1)) * last_price * 0.001
volume = np.random.gamma(2, 100)
# Create new candle
new_candle = {
'timestamp': last_timestamp,
'open': open_price,
'high': high_price,
'low': low_price,
'close': close_price,
'volume': volume
}
# Update last candle
last_candle = new_candle.copy()
yield new_candle
class MockExchange:
"""Mock exchange for demo mode"""
def __init__(self):
self.name = "MockExchange"
self.id = "mock"
async def fetch_ohlcv(self, symbol, timeframe, limit=1000):
"""Mock method to fetch OHLCV data"""
# Generate mock data
mock_data = generate_mock_data(symbol, timeframe, limit)
# Convert to CCXT format
ohlcv = []
for candle in mock_data:
ohlcv.append([
candle['timestamp'],
candle['open'],
candle['high'],
candle['low'],
candle['close'],
candle['volume']
])
return ohlcv
async def close(self):
"""Mock method to close exchange connection"""
pass
def get_model_info(model_path):
"""Extract model architecture information from saved model file"""
try:
# Load checkpoint with weights_only=False to get all information
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
# Extract model parameters
state_size = checkpoint['policy_net']['fc1.weight'].shape[1]
action_size = checkpoint['policy_net']['advantage_stream.bias'].shape[0]
hidden_size = checkpoint['policy_net']['fc1.weight'].shape[0]
# Try to extract LSTM layers and attention heads
lstm_layers = 2 # Default
attention_heads = 4 # Default
# Check if these parameters are stored in the checkpoint
if 'lstm_layers' in checkpoint:
lstm_layers = checkpoint['lstm_layers']
if 'attention_heads' in checkpoint:
attention_heads = checkpoint['attention_heads']
logger.info(f"Extracted model architecture: state_size={state_size}, action_size={action_size}, "
f"hidden_size={hidden_size}, lstm_layers={lstm_layers}, attention_heads={attention_heads}")
return state_size, action_size, hidden_size, lstm_layers, attention_heads
except Exception as e:
logger.error(f"Failed to extract model info: {str(e)}")
logger.warning("Using default model architecture")
return 40, 3, 384, 2, 4 # Default values
async def run_live_demo():
"""Run the trading bot in live demo mode with enhanced error handling"""
parser = argparse.ArgumentParser(description='Live Trading Demo')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol') parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
parser.add_argument('--timeframe', type=str, default='1m', help='Candle timeframe') parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for trading')
parser.add_argument('--model', type=str, default='models/trading_agent_best_pnl.pt', help='Path to model file') parser.add_argument('--model_path', type=str, default='data/best_model.pth', help='Path to the trained model')
parser.add_argument('--mock', action='store_true', help='Use mock data instead of real exchange data') parser.add_argument('--initial_balance', type=float, default=1000, help='Initial balance')
parser.add_argument('--update_interval', type=int, default=30, help='Interval to update data in seconds')
args = parser.parse_args() args = parser.parse_args()
try: logger.info(f"Starting live trading demo with {args.symbol} on {args.timeframe} timeframe")
# Import main module
import main # Run live trading in demo mode
await live_trading(
# Load environment variables symbol=args.symbol,
load_dotenv() timeframe=args.timeframe,
model_path=args.model_path,
# Create directories if they don't exist demo=True, # Always use demo mode in this script
os.makedirs("trade_logs", exist_ok=True) initial_balance=args.initial_balance,
os.makedirs("runs", exist_ok=True) update_interval=args.update_interval,
# Using default values for other parameters
# Check if model file exists )
if not os.path.exists(args.model):
logger.error(f"Model file not found: {args.model}")
return 1
logger.info(f"Starting live trading demo for {args.symbol} on {args.timeframe} timeframe")
logger.info(f"Using model: {args.model}")
# Check API keys
api_key = os.getenv('MEXC_API_KEY')
secret_key = os.getenv('MEXC_SECRET_KEY')
use_mock = args.mock or not api_key or api_key == "your_api_key_here"
if use_mock:
logger.info("Using mock data for demo mode (no API keys required)")
exchange = MockExchange()
else:
# Initialize real exchange
exchange = await main.initialize_exchange()
# Initialize environment
env = main.TradingEnvironment(
initial_balance=float(os.getenv('INITIAL_BALANCE', 1000)),
window_size=30,
demo=True # Always use demo mode in this script
)
# Fetch initial data
if use_mock:
# Use mock data
mock_data = generate_mock_data(args.symbol, args.timeframe, 1000)
env.data = mock_data
env._initialize_features()
success = True
else:
# Fetch real data
success = await env.fetch_initial_data(
exchange,
symbol=args.symbol,
timeframe=args.timeframe,
limit=1000
)
if not success:
logger.error("Failed to fetch initial data. Exiting.")
return 1
# Get model architecture from saved model
state_size, action_size, hidden_size, lstm_layers, attention_heads = get_model_info(args.model)
# Initialize agent with the correct architecture
agent = main.Agent(
state_size=state_size,
action_size=action_size,
hidden_size=hidden_size,
lstm_layers=lstm_layers,
attention_heads=attention_heads
)
# Load model with weights_only=False to handle numpy types
try:
# First try with weights_only=True (safer)
agent.load(args.model)
except Exception as e:
logger.warning(f"Failed to load model with weights_only=True: {str(e)}")
# Try with safe_globals
try:
import torch.serialization
with torch.serialization.safe_globals(['numpy._core.multiarray.scalar', 'numpy.dtype']):
agent.load(args.model)
except Exception as e:
logger.warning(f"Failed with safe_globals: {str(e)}")
# Last resort: try with weights_only=False
try:
# Monkey patch the load method temporarily
original_load = main.Agent.load
def patched_load(self, path):
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
self.policy_net.load_state_dict(checkpoint['policy_net'])
self.target_net.load_state_dict(checkpoint['target_net'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.epsilon = checkpoint.get('epsilon', 0.05)
logger.info(f"Model loaded from {path}")
# Apply the patch
main.Agent.load = patched_load
# Try loading
agent.load(args.model)
# Restore original method
main.Agent.load = original_load
except Exception as e:
logger.error(f"All loading attempts failed: {str(e)}")
return 1
logger.info(f"Model loaded successfully")
# Initialize TensorBoard writer
from torch.utils.tensorboard import SummaryWriter
agent.writer = SummaryWriter(f'runs/live_demo_{args.symbol.replace("/", "_")}_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}')
# Track performance metrics
trades_count = 0
winning_trades = 0
total_profit = 0
max_drawdown = 0
peak_balance = env.balance
step_counter = 0
prev_position = 'flat'
# Create trade log file
os.makedirs('trade_logs', exist_ok=True)
trade_log_path = f'trade_logs/trades_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}.csv'
with open(trade_log_path, 'w') as f:
f.write("timestamp,action,price,position_size,balance,pnl\n")
logger.info(f"Starting live trading simulation...")
try:
# Set up mock live data generator if using mock data
if use_mock:
live_candle_generator = generate_mock_live_candles(env.data, args.symbol, args.timeframe)
# Main trading loop
while True:
try:
# Get latest candle
if use_mock:
# Get mock candle
candle = await anext(live_candle_generator)
else:
# Get real candle
candle = await main.get_latest_candle(exchange, args.symbol)
if candle is None:
logger.warning("Failed to fetch latest candle, retrying in 5 seconds...")
await asyncio.sleep(5)
continue
# Add new data to environment
env.add_data(candle)
# Get current state and select action
state = env.get_state()
action = agent.select_action(state, training=False)
# Update environment with action (simulated)
next_state, reward, done = env.step(action)
# Create info dictionary (missing in the step function)
info = {
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
'price': env.current_price,
'balance': env.balance,
'position': env.position,
'pnl': env.total_pnl
}
# Log trade if position changed
if env.position != prev_position:
trades_count += 1
if env.last_trade_profit > 0:
winning_trades += 1
total_profit += env.last_trade_profit
# Log trade details
with open(trade_log_path, 'a') as f:
f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
# Update performance metrics
if env.balance > peak_balance:
peak_balance = env.balance
current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
if current_drawdown > max_drawdown:
max_drawdown = current_drawdown
# Update TensorBoard metrics
step_counter += 1
agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)
# Update chart visualization
if step_counter % 5 == 0 or env.position != prev_position:
agent.add_chart_to_tensorboard(env, step_counter)
# Log performance summary
if trades_count > 0:
win_rate = (winning_trades / trades_count) * 100
agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)
performance_text = f"""
**Live Trading Performance**
Balance: ${env.balance:.2f}
Total PnL: ${env.total_pnl:.2f}
Trades: {trades_count}
Win Rate: {win_rate:.1f}%
Max Drawdown: {max_drawdown*100:.1f}%
"""
agent.writer.add_text('Performance', performance_text, step_counter)
prev_position = env.position
# Wait for next candle
await asyncio.sleep(1) # Faster updates in demo mode
except Exception as e:
logger.error(f"Error in live trading loop: {str(e)}")
import traceback
logger.error(traceback.format_exc())
logger.info("Continuing after error...")
await asyncio.sleep(5)
except KeyboardInterrupt:
logger.info("Live trading stopped by user")
# Final performance report
if trades_count > 0:
win_rate = (winning_trades / trades_count) * 100
logger.info(f"Trading session summary:")
logger.info(f"Total trades: {trades_count}")
logger.info(f"Win rate: {win_rate:.1f}%")
logger.info(f"Final balance: ${env.balance:.2f}")
logger.info(f"Total profit: ${total_profit:.2f}")
logger.info(f"Maximum drawdown: {max_drawdown*100:.1f}%")
logger.info(f"Trade log saved to: {trade_log_path}")
return 0
except KeyboardInterrupt:
logger.info("Live trading stopped by user")
return 0
except Exception as e:
logger.error(f"Error in live trading: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return 1
if __name__ == "__main__": if __name__ == "__main__":
# Set environment variable to indicate we're in demo mode try:
os.environ['DEMO_MODE'] = 'true' asyncio.run(main())
except KeyboardInterrupt:
# Print banner logger.info("Live trading demo stopped by user")
print("\n" + "="*60) except Exception as e:
print("🤖 TRADING BOT - LIVE DEMO MODE 🤖") logger.error(f"Error in live trading demo: {e}")
print("="*60)
print("This is a DEMO mode with simulated trading (no real trades)")
print("Press Ctrl+C to stop the bot at any time")
print("="*60 + "\n")
# Run the async main function
exit_code = asyncio.run(run_live_demo())
sys.exit(exit_code)

77
crypto/gogo2/run_tests.py Normal file
View File

@ -0,0 +1,77 @@
#!/usr/bin/env python
"""
Run unit tests for the trading bot.
This script runs the unit tests defined in tests.py and displays the results.
It can run a single test or all tests.
Usage:
python run_tests.py [test_name]
If test_name is provided, only that test will be run.
Otherwise, all tests will be run.
Example:
python run_tests.py TestPeriodicUpdates
python run_tests.py TestBacktesting
python run_tests.py TestBacktestingLastSevenDays
python run_tests.py TestSingleDayBacktesting
python run_tests.py
"""
import sys
import unittest
import logging
from tests import (
TestPeriodicUpdates,
TestBacktesting,
TestBacktestingLastSevenDays,
TestSingleDayBacktesting
)
if __name__ == "__main__":
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()])
# Get the test name from the command line
test_name = sys.argv[1] if len(sys.argv) > 1 else None
# Run the specified test or all tests
if test_name:
logging.info(f"Running test: {test_name}")
if test_name == "TestPeriodicUpdates":
suite = unittest.TestLoader().loadTestsFromTestCase(TestPeriodicUpdates)
elif test_name == "TestBacktesting":
suite = unittest.TestLoader().loadTestsFromTestCase(TestBacktesting)
elif test_name == "TestBacktestingLastSevenDays":
suite = unittest.TestLoader().loadTestsFromTestCase(TestBacktestingLastSevenDays)
elif test_name == "TestSingleDayBacktesting":
suite = unittest.TestLoader().loadTestsFromTestCase(TestSingleDayBacktesting)
else:
logging.error(f"Unknown test: {test_name}")
logging.info("Available tests: TestPeriodicUpdates, TestBacktesting, TestBacktestingLastSevenDays, TestSingleDayBacktesting")
sys.exit(1)
else:
# Run all tests
logging.info("Running all tests")
suite = unittest.TestSuite()
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestPeriodicUpdates))
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestBacktesting))
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestBacktestingLastSevenDays))
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestSingleDayBacktesting))
# Run the tests
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
# Print summary
print("\nTest Summary:")
print(f" Ran {result.testsRun} tests")
print(f" Errors: {len(result.errors)}")
print(f" Failures: {len(result.failures)}")
print(f" Skipped: {len(result.skipped)}")
# Exit with non-zero status if any tests failed
sys.exit(len(result.errors) + len(result.failures))

337
crypto/gogo2/tests.py Normal file
View File

@ -0,0 +1,337 @@
"""
Unit tests for the trading bot.
This file contains tests for various components of the trading bot, including:
1. Periodic candle updates
2. Backtesting on historical data
3. Training on the last 7 days of data
"""
import unittest
import asyncio
import os
import sys
import logging
import datetime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()])
# Import functionality from main.py
import main
from main import (
CandleCache, BacktestCandles, initialize_exchange,
TradingEnvironment, Agent, train_with_backtesting,
fetch_multi_timeframe_data, train_agent
)
class TestPeriodicUpdates(unittest.TestCase):
"""Test that candle data is periodically updated during training."""
async def async_test_periodic_updates(self):
"""Test that candle data is periodically updated during training."""
logging.info("Testing periodic candle updates...")
# Initialize exchange
exchange = await initialize_exchange()
self.assertIsNotNone(exchange, "Failed to initialize exchange")
# Create candle cache
candle_cache = CandleCache()
# Initial fetch of candle data
candle_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
self.assertIsNotNone(candle_data, "Failed to fetch initial candle data")
self.assertIn('1m', candle_data, "1m candles not found in initial data")
# Check initial data timestamps
initial_1m_candles = candle_data['1m']
self.assertGreater(len(initial_1m_candles), 0, "No 1m candles found in initial data")
initial_timestamp = initial_1m_candles[-1][0]
# Wait for update interval to pass
logging.info("Waiting for update interval to pass (5 seconds for testing)...")
await asyncio.sleep(5) # Short wait for testing
# Force update by setting last_updated to None
candle_cache.last_updated['1m'] = None
# Fetch updated data
updated_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
self.assertIsNotNone(updated_data, "Failed to fetch updated candle data")
# Check if data was updated
updated_1m_candles = updated_data['1m']
self.assertGreater(len(updated_1m_candles), 0, "No 1m candles found in updated data")
updated_timestamp = updated_1m_candles[-1][0]
# In a live scenario, this check should pass with real-time updates
# For testing, we just ensure data was fetched
logging.info(f"Initial timestamp: {initial_timestamp}, Updated timestamp: {updated_timestamp}")
self.assertIsNotNone(updated_timestamp, "Updated timestamp is None")
# Close exchange connection
try:
await exchange.close()
except AttributeError:
# Some exchanges don't have a close method
pass
logging.info("Periodic update test completed")
def test_periodic_updates(self):
"""Run the async test."""
asyncio.run(self.async_test_periodic_updates())
class TestBacktesting(unittest.TestCase):
"""Test backtesting on historical data."""
async def async_test_backtesting(self):
"""Test backtesting on a specific time period."""
logging.info("Testing backtesting with historical data...")
# Initialize exchange
exchange = await initialize_exchange()
self.assertIsNotNone(exchange, "Failed to initialize exchange")
# Create a timestamp for 24 hours ago
now = datetime.datetime.now()
yesterday = now - datetime.timedelta(days=1)
since_timestamp = int(yesterday.timestamp() * 1000) # Convert to milliseconds
# Create a backtesting candle cache
backtest_cache = BacktestCandles(since_timestamp=since_timestamp)
backtest_cache.period_name = "1-day-ago"
# Fetch historical data
candle_data = await backtest_cache.fetch_all_timeframes(exchange, "ETH/USDT")
self.assertIsNotNone(candle_data, "Failed to fetch historical candle data")
self.assertIn('1m', candle_data, "1m candles not found in historical data")
# Check historical data timestamps
minute_candles = candle_data['1m']
self.assertGreater(len(minute_candles), 0, "No minute candles found in historical data")
# Check if timestamps are within the requested range
first_timestamp = minute_candles[0][0]
last_timestamp = minute_candles[-1][0]
logging.info(f"Requested since: {since_timestamp}")
logging.info(f"First timestamp in data: {first_timestamp}")
logging.info(f"Last timestamp in data: {last_timestamp}")
# In real tests, this check should compare timestamps precisely
# For this test, we just ensure data was fetched
self.assertLessEqual(first_timestamp, last_timestamp, "First timestamp should be before last timestamp")
# Close exchange connection
try:
await exchange.close()
except AttributeError:
# Some exchanges don't have a close method
pass
logging.info("Backtesting fetch test completed")
def test_backtesting(self):
"""Run the async test."""
asyncio.run(self.async_test_backtesting())
class TestBacktestingLastSevenDays(unittest.TestCase):
"""Test backtesting on the last 7 days of data."""
async def async_test_seven_days_backtesting(self):
"""Test backtesting on the last 7 days."""
logging.info("Testing backtesting on the last 7 days...")
# Initialize exchange
exchange = await initialize_exchange()
self.assertIsNotNone(exchange, "Failed to initialize exchange")
# Create environment with small initial balance for testing
env = TradingEnvironment(
initial_balance=100, # Small balance for testing
leverage=10, # Lower leverage for testing
window_size=50, # Smaller window for faster testing
commission=0.0004 # Standard commission
)
# Create agent
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)
# Initialize empty results dataframe
all_results = pd.DataFrame()
# Run backtesting for the last 7 days, one day at a time
now = datetime.datetime.now()
for day_offset in range(1, 8):
# Calculate time period
end_day = now - datetime.timedelta(days=day_offset-1)
start_day = end_day - datetime.timedelta(days=1)
# Convert to milliseconds
since_timestamp = int(start_day.timestamp() * 1000)
until_timestamp = int(end_day.timestamp() * 1000)
# Period name
period_name = f"Day-{day_offset}"
logging.info(f"Testing backtesting for period: {period_name}")
logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")
# Run backtesting with a small number of episodes for testing
stats = await train_with_backtesting(
agent=agent,
env=env,
symbol="ETH/USDT",
since_timestamp=since_timestamp,
until_timestamp=until_timestamp,
num_episodes=3, # Use a small number for testing
max_steps_per_episode=200, # Use a small number for testing
period_name=period_name
)
# Check if stats were returned
if stats is None:
logging.warning(f"No stats returned for period: {period_name}")
continue
# Create a dataframe from stats
if len(stats['episode_rewards']) > 0:
df = pd.DataFrame({
'Period': [period_name] * len(stats['episode_rewards']),
'Episode': list(range(1, len(stats['episode_rewards']) + 1)),
'Reward': stats['episode_rewards'],
'Balance': stats['balances'],
'PnL': stats['episode_pnls'],
'Fees': stats['fees'],
'Net_PnL': stats['net_pnl_after_fees']
})
# Append to all results
all_results = pd.concat([all_results, df], ignore_index=True)
logging.info(f"Completed backtesting for period: {period_name}")
logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
else:
logging.warning(f"No episodes completed for period: {period_name}")
# Save all results
if not all_results.empty:
all_results.to_csv("all_backtest_results.csv", index=False)
logging.info("Saved all backtest results to all_backtest_results.csv")
# Create plot of results
plt.figure(figsize=(12, 8))
# Plot Net PnL by period
all_results.groupby('Period')['Net_PnL'].last().plot(kind='bar')
plt.title('Net PnL by Training Period (Last Episode)')
plt.ylabel('Net PnL ($)')
plt.tight_layout()
plt.savefig("backtest_results.png")
logging.info("Saved backtest results plot to backtest_results.png")
# Close exchange connection
try:
await exchange.close()
except AttributeError:
# Some exchanges don't have a close method
pass
logging.info("7-day backtesting test completed")
def test_seven_days_backtesting(self):
"""Run the async test."""
asyncio.run(self.async_test_seven_days_backtesting())
class TestSingleDayBacktesting(unittest.TestCase):
"""Test backtesting on a single day of historical data."""
async def async_test_single_day_backtesting(self):
"""Test backtesting on a single day."""
logging.info("Testing backtesting on a single day...")
# Initialize exchange
exchange = await initialize_exchange()
self.assertIsNotNone(exchange, "Failed to initialize exchange")
# Create environment with small initial balance for testing
env = TradingEnvironment(
initial_balance=100, # Small balance for testing
leverage=10, # Lower leverage for testing
window_size=50, # Smaller window for faster testing
commission=0.0004 # Standard commission
)
# Create agent
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)
# Calculate time period for 1 day ago
now = datetime.datetime.now()
end_day = now
start_day = end_day - datetime.timedelta(days=1)
# Convert to milliseconds
since_timestamp = int(start_day.timestamp() * 1000)
until_timestamp = int(end_day.timestamp() * 1000)
# Period name
period_name = "Test-Day-1"
logging.info(f"Testing backtesting for period: {period_name}")
logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")
# Run backtesting with a small number of episodes for testing
stats = await train_with_backtesting(
agent=agent,
env=env,
symbol="ETH/USDT",
since_timestamp=since_timestamp,
until_timestamp=until_timestamp,
num_episodes=2, # Very small number for quick testing
max_steps_per_episode=100, # Very small number for quick testing
period_name=period_name
)
# Check if stats were returned
self.assertIsNotNone(stats, "No stats returned from backtesting")
# Check if episodes were completed
self.assertGreater(len(stats['episode_rewards']), 0, "No episodes completed")
# Log results
logging.info(f"Completed backtesting for period: {period_name}")
logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
# Close exchange connection
try:
await exchange.close()
except AttributeError:
# Some exchanges don't have a close method
pass
logging.info("Single day backtesting test completed")
def test_single_day_backtesting(self):
"""Run the async test."""
asyncio.run(self.async_test_single_day_backtesting())
if __name__ == '__main__':
unittest.main()