backtseting support
This commit is contained in:
parent
5e9e6360af
commit
2e7a242ac7
1222
crypto/gogo2/main.py
1222
crypto/gogo2/main.py
File diff suppressed because it is too large
Load Diff
34
crypto/gogo2/run_demo.py
Normal file
34
crypto/gogo2/run_demo.py
Normal file
@ -0,0 +1,34 @@
|
||||
#!/usr/bin/env python
|
||||
import asyncio
|
||||
import logging
|
||||
from main import live_trading, setup_logging
|
||||
|
||||
# Set up logging
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def main():
|
||||
"""Run a simplified demo trading session with mock data"""
|
||||
logger.info("Starting simplified demo trading session")
|
||||
|
||||
# Run live trading in demo mode with simplified parameters
|
||||
await live_trading(
|
||||
symbol="ETH/USDT",
|
||||
timeframe="1m",
|
||||
model_path="models/trading_agent_best_pnl.pt",
|
||||
demo=True,
|
||||
initial_balance=1000,
|
||||
update_interval=10, # Update every 10 seconds for faster feedback
|
||||
max_position_size=0.1,
|
||||
risk_per_trade=0.02,
|
||||
stop_loss_pct=0.02,
|
||||
take_profit_pct=0.04,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Demo trading stopped by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in demo trading: {e}")
|
@ -1,477 +1,40 @@
|
||||
import os
|
||||
import sys
|
||||
#!/usr/bin/env python
|
||||
import asyncio
|
||||
import logging
|
||||
import argparse
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import random
|
||||
import datetime
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import io
|
||||
from PIL import Image
|
||||
from dotenv import load_dotenv
|
||||
import logging
|
||||
from main import live_trading, setup_logging
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler("live_trading.log"),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger("live_trading")
|
||||
# Set up logging
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def generate_mock_data(symbol, timeframe, limit=1000):
|
||||
"""Generate mock OHLCV data for demo mode"""
|
||||
logger.info(f"Generating mock data for {symbol} ({timeframe})")
|
||||
|
||||
# Set seed for reproducibility
|
||||
np.random.seed(42)
|
||||
|
||||
# Generate timestamps
|
||||
end_time = datetime.datetime.now()
|
||||
start_time = end_time - datetime.timedelta(minutes=limit)
|
||||
timestamps = [start_time + datetime.timedelta(minutes=i) for i in range(limit)]
|
||||
|
||||
# Convert to milliseconds
|
||||
timestamps_ms = [int(ts.timestamp() * 1000) for ts in timestamps]
|
||||
|
||||
# Generate price data with realistic patterns
|
||||
base_price = 3000.0 # Starting price
|
||||
price_data = []
|
||||
current_price = base_price
|
||||
|
||||
for i in range(limit):
|
||||
# Random walk with momentum and volatility clusters
|
||||
momentum = np.random.normal(0, 1)
|
||||
volatility = 0.5 + 0.5 * np.sin(i / 100) # Cyclical volatility
|
||||
|
||||
# Price change with momentum and volatility
|
||||
price_change = momentum * volatility * current_price * 0.005
|
||||
|
||||
# Add some trends and patterns
|
||||
if i % 200 < 100: # Uptrend for 100 candles, then downtrend
|
||||
price_change += current_price * 0.001
|
||||
else:
|
||||
price_change -= current_price * 0.0008
|
||||
|
||||
# Update current price
|
||||
current_price += price_change
|
||||
|
||||
# Generate OHLCV data
|
||||
open_price = current_price
|
||||
close_price = current_price + np.random.normal(0, 1) * current_price * 0.002
|
||||
high_price = max(open_price, close_price) + abs(np.random.normal(0, 1)) * current_price * 0.003
|
||||
low_price = min(open_price, close_price) - abs(np.random.normal(0, 1)) * current_price * 0.003
|
||||
volume = np.random.gamma(2, 100) * (1 + 0.5 * np.sin(i / 50)) # Cyclical volume
|
||||
|
||||
# Store data
|
||||
price_data.append({
|
||||
'timestamp': timestamps_ms[i],
|
||||
'open': open_price,
|
||||
'high': high_price,
|
||||
'low': low_price,
|
||||
'close': close_price,
|
||||
'volume': volume
|
||||
})
|
||||
|
||||
logger.info(f"Generated {len(price_data)} mock candles")
|
||||
return price_data
|
||||
|
||||
async def generate_mock_live_candles(initial_data, symbol, timeframe):
|
||||
"""Generate mock live candles based on initial data"""
|
||||
last_candle = initial_data[-1].copy()
|
||||
last_timestamp = last_candle['timestamp']
|
||||
|
||||
while True:
|
||||
# Wait for next candle
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Update timestamp
|
||||
if timeframe == '1m':
|
||||
last_timestamp += 60 * 1000 # 1 minute in milliseconds
|
||||
elif timeframe == '5m':
|
||||
last_timestamp += 5 * 60 * 1000
|
||||
elif timeframe == '15m':
|
||||
last_timestamp += 15 * 60 * 1000
|
||||
elif timeframe == '1h':
|
||||
last_timestamp += 60 * 60 * 1000
|
||||
else:
|
||||
last_timestamp += 60 * 1000 # Default to 1 minute
|
||||
|
||||
# Generate new candle
|
||||
last_price = last_candle['close']
|
||||
price_change = np.random.normal(0, 1) * last_price * 0.002
|
||||
|
||||
# Add some persistence
|
||||
if last_candle['close'] > last_candle['open']:
|
||||
# Previous candle was green, more likely to continue up
|
||||
price_change += last_price * 0.0005
|
||||
else:
|
||||
# Previous candle was red, more likely to continue down
|
||||
price_change -= last_price * 0.0005
|
||||
|
||||
# Generate OHLCV data
|
||||
open_price = last_price
|
||||
close_price = last_price + price_change
|
||||
high_price = max(open_price, close_price) + abs(np.random.normal(0, 1)) * last_price * 0.001
|
||||
low_price = min(open_price, close_price) - abs(np.random.normal(0, 1)) * last_price * 0.001
|
||||
volume = np.random.gamma(2, 100)
|
||||
|
||||
# Create new candle
|
||||
new_candle = {
|
||||
'timestamp': last_timestamp,
|
||||
'open': open_price,
|
||||
'high': high_price,
|
||||
'low': low_price,
|
||||
'close': close_price,
|
||||
'volume': volume
|
||||
}
|
||||
|
||||
# Update last candle
|
||||
last_candle = new_candle.copy()
|
||||
|
||||
yield new_candle
|
||||
|
||||
class MockExchange:
|
||||
"""Mock exchange for demo mode"""
|
||||
def __init__(self):
|
||||
self.name = "MockExchange"
|
||||
self.id = "mock"
|
||||
|
||||
async def fetch_ohlcv(self, symbol, timeframe, limit=1000):
|
||||
"""Mock method to fetch OHLCV data"""
|
||||
# Generate mock data
|
||||
mock_data = generate_mock_data(symbol, timeframe, limit)
|
||||
|
||||
# Convert to CCXT format
|
||||
ohlcv = []
|
||||
for candle in mock_data:
|
||||
ohlcv.append([
|
||||
candle['timestamp'],
|
||||
candle['open'],
|
||||
candle['high'],
|
||||
candle['low'],
|
||||
candle['close'],
|
||||
candle['volume']
|
||||
])
|
||||
|
||||
return ohlcv
|
||||
|
||||
async def close(self):
|
||||
"""Mock method to close exchange connection"""
|
||||
pass
|
||||
|
||||
def get_model_info(model_path):
|
||||
"""Extract model architecture information from saved model file"""
|
||||
try:
|
||||
# Load checkpoint with weights_only=False to get all information
|
||||
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
|
||||
|
||||
# Extract model parameters
|
||||
state_size = checkpoint['policy_net']['fc1.weight'].shape[1]
|
||||
action_size = checkpoint['policy_net']['advantage_stream.bias'].shape[0]
|
||||
hidden_size = checkpoint['policy_net']['fc1.weight'].shape[0]
|
||||
|
||||
# Try to extract LSTM layers and attention heads
|
||||
lstm_layers = 2 # Default
|
||||
attention_heads = 4 # Default
|
||||
|
||||
# Check if these parameters are stored in the checkpoint
|
||||
if 'lstm_layers' in checkpoint:
|
||||
lstm_layers = checkpoint['lstm_layers']
|
||||
if 'attention_heads' in checkpoint:
|
||||
attention_heads = checkpoint['attention_heads']
|
||||
|
||||
logger.info(f"Extracted model architecture: state_size={state_size}, action_size={action_size}, "
|
||||
f"hidden_size={hidden_size}, lstm_layers={lstm_layers}, attention_heads={attention_heads}")
|
||||
|
||||
return state_size, action_size, hidden_size, lstm_layers, attention_heads
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract model info: {str(e)}")
|
||||
logger.warning("Using default model architecture")
|
||||
return 40, 3, 384, 2, 4 # Default values
|
||||
|
||||
async def run_live_demo():
|
||||
"""Run the trading bot in live demo mode with enhanced error handling"""
|
||||
parser = argparse.ArgumentParser(description='Live Trading Demo')
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(description='Run live trading in demo mode')
|
||||
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
|
||||
parser.add_argument('--timeframe', type=str, default='1m', help='Candle timeframe')
|
||||
parser.add_argument('--model', type=str, default='models/trading_agent_best_pnl.pt', help='Path to model file')
|
||||
parser.add_argument('--mock', action='store_true', help='Use mock data instead of real exchange data')
|
||||
parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for trading')
|
||||
parser.add_argument('--model_path', type=str, default='data/best_model.pth', help='Path to the trained model')
|
||||
parser.add_argument('--initial_balance', type=float, default=1000, help='Initial balance')
|
||||
parser.add_argument('--update_interval', type=int, default=30, help='Interval to update data in seconds')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Import main module
|
||||
import main
|
||||
logger.info(f"Starting live trading demo with {args.symbol} on {args.timeframe} timeframe")
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Create directories if they don't exist
|
||||
os.makedirs("trade_logs", exist_ok=True)
|
||||
os.makedirs("runs", exist_ok=True)
|
||||
|
||||
# Check if model file exists
|
||||
if not os.path.exists(args.model):
|
||||
logger.error(f"Model file not found: {args.model}")
|
||||
return 1
|
||||
|
||||
logger.info(f"Starting live trading demo for {args.symbol} on {args.timeframe} timeframe")
|
||||
logger.info(f"Using model: {args.model}")
|
||||
|
||||
# Check API keys
|
||||
api_key = os.getenv('MEXC_API_KEY')
|
||||
secret_key = os.getenv('MEXC_SECRET_KEY')
|
||||
use_mock = args.mock or not api_key or api_key == "your_api_key_here"
|
||||
|
||||
if use_mock:
|
||||
logger.info("Using mock data for demo mode (no API keys required)")
|
||||
exchange = MockExchange()
|
||||
else:
|
||||
# Initialize real exchange
|
||||
exchange = await main.initialize_exchange()
|
||||
|
||||
# Initialize environment
|
||||
env = main.TradingEnvironment(
|
||||
initial_balance=float(os.getenv('INITIAL_BALANCE', 1000)),
|
||||
window_size=30,
|
||||
demo=True # Always use demo mode in this script
|
||||
)
|
||||
|
||||
# Fetch initial data
|
||||
if use_mock:
|
||||
# Use mock data
|
||||
mock_data = generate_mock_data(args.symbol, args.timeframe, 1000)
|
||||
env.data = mock_data
|
||||
env._initialize_features()
|
||||
success = True
|
||||
else:
|
||||
# Fetch real data
|
||||
success = await env.fetch_initial_data(
|
||||
exchange,
|
||||
# Run live trading in demo mode
|
||||
await live_trading(
|
||||
symbol=args.symbol,
|
||||
timeframe=args.timeframe,
|
||||
limit=1000
|
||||
model_path=args.model_path,
|
||||
demo=True, # Always use demo mode in this script
|
||||
initial_balance=args.initial_balance,
|
||||
update_interval=args.update_interval,
|
||||
# Using default values for other parameters
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error("Failed to fetch initial data. Exiting.")
|
||||
return 1
|
||||
|
||||
# Get model architecture from saved model
|
||||
state_size, action_size, hidden_size, lstm_layers, attention_heads = get_model_info(args.model)
|
||||
|
||||
# Initialize agent with the correct architecture
|
||||
agent = main.Agent(
|
||||
state_size=state_size,
|
||||
action_size=action_size,
|
||||
hidden_size=hidden_size,
|
||||
lstm_layers=lstm_layers,
|
||||
attention_heads=attention_heads
|
||||
)
|
||||
|
||||
# Load model with weights_only=False to handle numpy types
|
||||
try:
|
||||
# First try with weights_only=True (safer)
|
||||
agent.load(args.model)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load model with weights_only=True: {str(e)}")
|
||||
|
||||
# Try with safe_globals
|
||||
try:
|
||||
import torch.serialization
|
||||
with torch.serialization.safe_globals(['numpy._core.multiarray.scalar', 'numpy.dtype']):
|
||||
agent.load(args.model)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed with safe_globals: {str(e)}")
|
||||
|
||||
# Last resort: try with weights_only=False
|
||||
try:
|
||||
# Monkey patch the load method temporarily
|
||||
original_load = main.Agent.load
|
||||
|
||||
def patched_load(self, path):
|
||||
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
|
||||
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
||||
self.target_net.load_state_dict(checkpoint['target_net'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
self.epsilon = checkpoint.get('epsilon', 0.05)
|
||||
logger.info(f"Model loaded from {path}")
|
||||
|
||||
# Apply the patch
|
||||
main.Agent.load = patched_load
|
||||
|
||||
# Try loading
|
||||
agent.load(args.model)
|
||||
|
||||
# Restore original method
|
||||
main.Agent.load = original_load
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"All loading attempts failed: {str(e)}")
|
||||
return 1
|
||||
|
||||
logger.info(f"Model loaded successfully")
|
||||
|
||||
# Initialize TensorBoard writer
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
agent.writer = SummaryWriter(f'runs/live_demo_{args.symbol.replace("/", "_")}_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}')
|
||||
|
||||
# Track performance metrics
|
||||
trades_count = 0
|
||||
winning_trades = 0
|
||||
total_profit = 0
|
||||
max_drawdown = 0
|
||||
peak_balance = env.balance
|
||||
step_counter = 0
|
||||
prev_position = 'flat'
|
||||
|
||||
# Create trade log file
|
||||
os.makedirs('trade_logs', exist_ok=True)
|
||||
trade_log_path = f'trade_logs/trades_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}.csv'
|
||||
with open(trade_log_path, 'w') as f:
|
||||
f.write("timestamp,action,price,position_size,balance,pnl\n")
|
||||
|
||||
logger.info(f"Starting live trading simulation...")
|
||||
|
||||
try:
|
||||
# Set up mock live data generator if using mock data
|
||||
if use_mock:
|
||||
live_candle_generator = generate_mock_live_candles(env.data, args.symbol, args.timeframe)
|
||||
|
||||
# Main trading loop
|
||||
while True:
|
||||
try:
|
||||
# Get latest candle
|
||||
if use_mock:
|
||||
# Get mock candle
|
||||
candle = await anext(live_candle_generator)
|
||||
else:
|
||||
# Get real candle
|
||||
candle = await main.get_latest_candle(exchange, args.symbol)
|
||||
|
||||
if candle is None:
|
||||
logger.warning("Failed to fetch latest candle, retrying in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
|
||||
# Add new data to environment
|
||||
env.add_data(candle)
|
||||
|
||||
# Get current state and select action
|
||||
state = env.get_state()
|
||||
action = agent.select_action(state, training=False)
|
||||
|
||||
# Update environment with action (simulated)
|
||||
next_state, reward, done = env.step(action)
|
||||
|
||||
# Create info dictionary (missing in the step function)
|
||||
info = {
|
||||
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
|
||||
'price': env.current_price,
|
||||
'balance': env.balance,
|
||||
'position': env.position,
|
||||
'pnl': env.total_pnl
|
||||
}
|
||||
|
||||
# Log trade if position changed
|
||||
if env.position != prev_position:
|
||||
trades_count += 1
|
||||
if env.last_trade_profit > 0:
|
||||
winning_trades += 1
|
||||
total_profit += env.last_trade_profit
|
||||
|
||||
# Log trade details
|
||||
with open(trade_log_path, 'a') as f:
|
||||
f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
|
||||
|
||||
logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
|
||||
|
||||
# Update performance metrics
|
||||
if env.balance > peak_balance:
|
||||
peak_balance = env.balance
|
||||
current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
|
||||
if current_drawdown > max_drawdown:
|
||||
max_drawdown = current_drawdown
|
||||
|
||||
# Update TensorBoard metrics
|
||||
step_counter += 1
|
||||
agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
|
||||
agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
|
||||
agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)
|
||||
|
||||
# Update chart visualization
|
||||
if step_counter % 5 == 0 or env.position != prev_position:
|
||||
agent.add_chart_to_tensorboard(env, step_counter)
|
||||
|
||||
# Log performance summary
|
||||
if trades_count > 0:
|
||||
win_rate = (winning_trades / trades_count) * 100
|
||||
agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)
|
||||
|
||||
performance_text = f"""
|
||||
**Live Trading Performance**
|
||||
Balance: ${env.balance:.2f}
|
||||
Total PnL: ${env.total_pnl:.2f}
|
||||
Trades: {trades_count}
|
||||
Win Rate: {win_rate:.1f}%
|
||||
Max Drawdown: {max_drawdown*100:.1f}%
|
||||
"""
|
||||
agent.writer.add_text('Performance', performance_text, step_counter)
|
||||
|
||||
prev_position = env.position
|
||||
|
||||
# Wait for next candle
|
||||
await asyncio.sleep(1) # Faster updates in demo mode
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in live trading loop: {str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
logger.info("Continuing after error...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Live trading stopped by user")
|
||||
|
||||
# Final performance report
|
||||
if trades_count > 0:
|
||||
win_rate = (winning_trades / trades_count) * 100
|
||||
logger.info(f"Trading session summary:")
|
||||
logger.info(f"Total trades: {trades_count}")
|
||||
logger.info(f"Win rate: {win_rate:.1f}%")
|
||||
logger.info(f"Final balance: ${env.balance:.2f}")
|
||||
logger.info(f"Total profit: ${total_profit:.2f}")
|
||||
logger.info(f"Maximum drawdown: {max_drawdown*100:.1f}%")
|
||||
logger.info(f"Trade log saved to: {trade_log_path}")
|
||||
|
||||
return 0
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Live trading stopped by user")
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error in live trading: {str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Set environment variable to indicate we're in demo mode
|
||||
os.environ['DEMO_MODE'] = 'true'
|
||||
|
||||
# Print banner
|
||||
print("\n" + "="*60)
|
||||
print("🤖 TRADING BOT - LIVE DEMO MODE 🤖")
|
||||
print("="*60)
|
||||
print("This is a DEMO mode with simulated trading (no real trades)")
|
||||
print("Press Ctrl+C to stop the bot at any time")
|
||||
print("="*60 + "\n")
|
||||
|
||||
# Run the async main function
|
||||
exit_code = asyncio.run(run_live_demo())
|
||||
sys.exit(exit_code)
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Live trading demo stopped by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in live trading demo: {e}")
|
77
crypto/gogo2/run_tests.py
Normal file
77
crypto/gogo2/run_tests.py
Normal file
@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Run unit tests for the trading bot.
|
||||
|
||||
This script runs the unit tests defined in tests.py and displays the results.
|
||||
It can run a single test or all tests.
|
||||
|
||||
Usage:
|
||||
python run_tests.py [test_name]
|
||||
|
||||
If test_name is provided, only that test will be run.
|
||||
Otherwise, all tests will be run.
|
||||
|
||||
Example:
|
||||
python run_tests.py TestPeriodicUpdates
|
||||
python run_tests.py TestBacktesting
|
||||
python run_tests.py TestBacktestingLastSevenDays
|
||||
python run_tests.py TestSingleDayBacktesting
|
||||
python run_tests.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import logging
|
||||
from tests import (
|
||||
TestPeriodicUpdates,
|
||||
TestBacktesting,
|
||||
TestBacktestingLastSevenDays,
|
||||
TestSingleDayBacktesting
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.StreamHandler()])
|
||||
|
||||
# Get the test name from the command line
|
||||
test_name = sys.argv[1] if len(sys.argv) > 1 else None
|
||||
|
||||
# Run the specified test or all tests
|
||||
if test_name:
|
||||
logging.info(f"Running test: {test_name}")
|
||||
if test_name == "TestPeriodicUpdates":
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestPeriodicUpdates)
|
||||
elif test_name == "TestBacktesting":
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestBacktesting)
|
||||
elif test_name == "TestBacktestingLastSevenDays":
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestBacktestingLastSevenDays)
|
||||
elif test_name == "TestSingleDayBacktesting":
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestSingleDayBacktesting)
|
||||
else:
|
||||
logging.error(f"Unknown test: {test_name}")
|
||||
logging.info("Available tests: TestPeriodicUpdates, TestBacktesting, TestBacktestingLastSevenDays, TestSingleDayBacktesting")
|
||||
sys.exit(1)
|
||||
else:
|
||||
# Run all tests
|
||||
logging.info("Running all tests")
|
||||
suite = unittest.TestSuite()
|
||||
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestPeriodicUpdates))
|
||||
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestBacktesting))
|
||||
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestBacktestingLastSevenDays))
|
||||
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestSingleDayBacktesting))
|
||||
|
||||
# Run the tests
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(suite)
|
||||
|
||||
# Print summary
|
||||
print("\nTest Summary:")
|
||||
print(f" Ran {result.testsRun} tests")
|
||||
print(f" Errors: {len(result.errors)}")
|
||||
print(f" Failures: {len(result.failures)}")
|
||||
print(f" Skipped: {len(result.skipped)}")
|
||||
|
||||
# Exit with non-zero status if any tests failed
|
||||
sys.exit(len(result.errors) + len(result.failures))
|
337
crypto/gogo2/tests.py
Normal file
337
crypto/gogo2/tests.py
Normal file
@ -0,0 +1,337 @@
|
||||
"""
|
||||
Unit tests for the trading bot.
|
||||
This file contains tests for various components of the trading bot, including:
|
||||
1. Periodic candle updates
|
||||
2. Backtesting on historical data
|
||||
3. Training on the last 7 days of data
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import datetime
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.StreamHandler()])
|
||||
|
||||
# Import functionality from main.py
|
||||
import main
|
||||
from main import (
|
||||
CandleCache, BacktestCandles, initialize_exchange,
|
||||
TradingEnvironment, Agent, train_with_backtesting,
|
||||
fetch_multi_timeframe_data, train_agent
|
||||
)
|
||||
|
||||
class TestPeriodicUpdates(unittest.TestCase):
|
||||
"""Test that candle data is periodically updated during training."""
|
||||
|
||||
async def async_test_periodic_updates(self):
|
||||
"""Test that candle data is periodically updated during training."""
|
||||
logging.info("Testing periodic candle updates...")
|
||||
|
||||
# Initialize exchange
|
||||
exchange = await initialize_exchange()
|
||||
self.assertIsNotNone(exchange, "Failed to initialize exchange")
|
||||
|
||||
# Create candle cache
|
||||
candle_cache = CandleCache()
|
||||
|
||||
# Initial fetch of candle data
|
||||
candle_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
|
||||
self.assertIsNotNone(candle_data, "Failed to fetch initial candle data")
|
||||
self.assertIn('1m', candle_data, "1m candles not found in initial data")
|
||||
|
||||
# Check initial data timestamps
|
||||
initial_1m_candles = candle_data['1m']
|
||||
self.assertGreater(len(initial_1m_candles), 0, "No 1m candles found in initial data")
|
||||
initial_timestamp = initial_1m_candles[-1][0]
|
||||
|
||||
# Wait for update interval to pass
|
||||
logging.info("Waiting for update interval to pass (5 seconds for testing)...")
|
||||
await asyncio.sleep(5) # Short wait for testing
|
||||
|
||||
# Force update by setting last_updated to None
|
||||
candle_cache.last_updated['1m'] = None
|
||||
|
||||
# Fetch updated data
|
||||
updated_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
|
||||
self.assertIsNotNone(updated_data, "Failed to fetch updated candle data")
|
||||
|
||||
# Check if data was updated
|
||||
updated_1m_candles = updated_data['1m']
|
||||
self.assertGreater(len(updated_1m_candles), 0, "No 1m candles found in updated data")
|
||||
updated_timestamp = updated_1m_candles[-1][0]
|
||||
|
||||
# In a live scenario, this check should pass with real-time updates
|
||||
# For testing, we just ensure data was fetched
|
||||
logging.info(f"Initial timestamp: {initial_timestamp}, Updated timestamp: {updated_timestamp}")
|
||||
self.assertIsNotNone(updated_timestamp, "Updated timestamp is None")
|
||||
|
||||
# Close exchange connection
|
||||
try:
|
||||
await exchange.close()
|
||||
except AttributeError:
|
||||
# Some exchanges don't have a close method
|
||||
pass
|
||||
logging.info("Periodic update test completed")
|
||||
|
||||
def test_periodic_updates(self):
|
||||
"""Run the async test."""
|
||||
asyncio.run(self.async_test_periodic_updates())
|
||||
|
||||
|
||||
class TestBacktesting(unittest.TestCase):
|
||||
"""Test backtesting on historical data."""
|
||||
|
||||
async def async_test_backtesting(self):
|
||||
"""Test backtesting on a specific time period."""
|
||||
logging.info("Testing backtesting with historical data...")
|
||||
|
||||
# Initialize exchange
|
||||
exchange = await initialize_exchange()
|
||||
self.assertIsNotNone(exchange, "Failed to initialize exchange")
|
||||
|
||||
# Create a timestamp for 24 hours ago
|
||||
now = datetime.datetime.now()
|
||||
yesterday = now - datetime.timedelta(days=1)
|
||||
since_timestamp = int(yesterday.timestamp() * 1000) # Convert to milliseconds
|
||||
|
||||
# Create a backtesting candle cache
|
||||
backtest_cache = BacktestCandles(since_timestamp=since_timestamp)
|
||||
backtest_cache.period_name = "1-day-ago"
|
||||
|
||||
# Fetch historical data
|
||||
candle_data = await backtest_cache.fetch_all_timeframes(exchange, "ETH/USDT")
|
||||
self.assertIsNotNone(candle_data, "Failed to fetch historical candle data")
|
||||
self.assertIn('1m', candle_data, "1m candles not found in historical data")
|
||||
|
||||
# Check historical data timestamps
|
||||
minute_candles = candle_data['1m']
|
||||
self.assertGreater(len(minute_candles), 0, "No minute candles found in historical data")
|
||||
|
||||
# Check if timestamps are within the requested range
|
||||
first_timestamp = minute_candles[0][0]
|
||||
last_timestamp = minute_candles[-1][0]
|
||||
|
||||
logging.info(f"Requested since: {since_timestamp}")
|
||||
logging.info(f"First timestamp in data: {first_timestamp}")
|
||||
logging.info(f"Last timestamp in data: {last_timestamp}")
|
||||
|
||||
# In real tests, this check should compare timestamps precisely
|
||||
# For this test, we just ensure data was fetched
|
||||
self.assertLessEqual(first_timestamp, last_timestamp, "First timestamp should be before last timestamp")
|
||||
|
||||
# Close exchange connection
|
||||
try:
|
||||
await exchange.close()
|
||||
except AttributeError:
|
||||
# Some exchanges don't have a close method
|
||||
pass
|
||||
logging.info("Backtesting fetch test completed")
|
||||
|
||||
def test_backtesting(self):
|
||||
"""Run the async test."""
|
||||
asyncio.run(self.async_test_backtesting())
|
||||
|
||||
|
||||
class TestBacktestingLastSevenDays(unittest.TestCase):
|
||||
"""Test backtesting on the last 7 days of data."""
|
||||
|
||||
async def async_test_seven_days_backtesting(self):
|
||||
"""Test backtesting on the last 7 days."""
|
||||
logging.info("Testing backtesting on the last 7 days...")
|
||||
|
||||
# Initialize exchange
|
||||
exchange = await initialize_exchange()
|
||||
self.assertIsNotNone(exchange, "Failed to initialize exchange")
|
||||
|
||||
# Create environment with small initial balance for testing
|
||||
env = TradingEnvironment(
|
||||
initial_balance=100, # Small balance for testing
|
||||
leverage=10, # Lower leverage for testing
|
||||
window_size=50, # Smaller window for faster testing
|
||||
commission=0.0004 # Standard commission
|
||||
)
|
||||
|
||||
# Create agent
|
||||
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
|
||||
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
|
||||
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)
|
||||
|
||||
# Initialize empty results dataframe
|
||||
all_results = pd.DataFrame()
|
||||
|
||||
# Run backtesting for the last 7 days, one day at a time
|
||||
now = datetime.datetime.now()
|
||||
|
||||
for day_offset in range(1, 8):
|
||||
# Calculate time period
|
||||
end_day = now - datetime.timedelta(days=day_offset-1)
|
||||
start_day = end_day - datetime.timedelta(days=1)
|
||||
|
||||
# Convert to milliseconds
|
||||
since_timestamp = int(start_day.timestamp() * 1000)
|
||||
until_timestamp = int(end_day.timestamp() * 1000)
|
||||
|
||||
# Period name
|
||||
period_name = f"Day-{day_offset}"
|
||||
|
||||
logging.info(f"Testing backtesting for period: {period_name}")
|
||||
logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
# Run backtesting with a small number of episodes for testing
|
||||
stats = await train_with_backtesting(
|
||||
agent=agent,
|
||||
env=env,
|
||||
symbol="ETH/USDT",
|
||||
since_timestamp=since_timestamp,
|
||||
until_timestamp=until_timestamp,
|
||||
num_episodes=3, # Use a small number for testing
|
||||
max_steps_per_episode=200, # Use a small number for testing
|
||||
period_name=period_name
|
||||
)
|
||||
|
||||
# Check if stats were returned
|
||||
if stats is None:
|
||||
logging.warning(f"No stats returned for period: {period_name}")
|
||||
continue
|
||||
|
||||
# Create a dataframe from stats
|
||||
if len(stats['episode_rewards']) > 0:
|
||||
df = pd.DataFrame({
|
||||
'Period': [period_name] * len(stats['episode_rewards']),
|
||||
'Episode': list(range(1, len(stats['episode_rewards']) + 1)),
|
||||
'Reward': stats['episode_rewards'],
|
||||
'Balance': stats['balances'],
|
||||
'PnL': stats['episode_pnls'],
|
||||
'Fees': stats['fees'],
|
||||
'Net_PnL': stats['net_pnl_after_fees']
|
||||
})
|
||||
|
||||
# Append to all results
|
||||
all_results = pd.concat([all_results, df], ignore_index=True)
|
||||
|
||||
logging.info(f"Completed backtesting for period: {period_name}")
|
||||
logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
|
||||
logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
|
||||
logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
|
||||
else:
|
||||
logging.warning(f"No episodes completed for period: {period_name}")
|
||||
|
||||
# Save all results
|
||||
if not all_results.empty:
|
||||
all_results.to_csv("all_backtest_results.csv", index=False)
|
||||
logging.info("Saved all backtest results to all_backtest_results.csv")
|
||||
|
||||
# Create plot of results
|
||||
plt.figure(figsize=(12, 8))
|
||||
|
||||
# Plot Net PnL by period
|
||||
all_results.groupby('Period')['Net_PnL'].last().plot(kind='bar')
|
||||
plt.title('Net PnL by Training Period (Last Episode)')
|
||||
plt.ylabel('Net PnL ($)')
|
||||
plt.tight_layout()
|
||||
plt.savefig("backtest_results.png")
|
||||
logging.info("Saved backtest results plot to backtest_results.png")
|
||||
|
||||
# Close exchange connection
|
||||
try:
|
||||
await exchange.close()
|
||||
except AttributeError:
|
||||
# Some exchanges don't have a close method
|
||||
pass
|
||||
logging.info("7-day backtesting test completed")
|
||||
|
||||
def test_seven_days_backtesting(self):
|
||||
"""Run the async test."""
|
||||
asyncio.run(self.async_test_seven_days_backtesting())
|
||||
|
||||
|
||||
class TestSingleDayBacktesting(unittest.TestCase):
|
||||
"""Test backtesting on a single day of historical data."""
|
||||
|
||||
async def async_test_single_day_backtesting(self):
|
||||
"""Test backtesting on a single day."""
|
||||
logging.info("Testing backtesting on a single day...")
|
||||
|
||||
# Initialize exchange
|
||||
exchange = await initialize_exchange()
|
||||
self.assertIsNotNone(exchange, "Failed to initialize exchange")
|
||||
|
||||
# Create environment with small initial balance for testing
|
||||
env = TradingEnvironment(
|
||||
initial_balance=100, # Small balance for testing
|
||||
leverage=10, # Lower leverage for testing
|
||||
window_size=50, # Smaller window for faster testing
|
||||
commission=0.0004 # Standard commission
|
||||
)
|
||||
|
||||
# Create agent
|
||||
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
|
||||
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
|
||||
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)
|
||||
|
||||
# Calculate time period for 1 day ago
|
||||
now = datetime.datetime.now()
|
||||
end_day = now
|
||||
start_day = end_day - datetime.timedelta(days=1)
|
||||
|
||||
# Convert to milliseconds
|
||||
since_timestamp = int(start_day.timestamp() * 1000)
|
||||
until_timestamp = int(end_day.timestamp() * 1000)
|
||||
|
||||
# Period name
|
||||
period_name = "Test-Day-1"
|
||||
|
||||
logging.info(f"Testing backtesting for period: {period_name}")
|
||||
logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
# Run backtesting with a small number of episodes for testing
|
||||
stats = await train_with_backtesting(
|
||||
agent=agent,
|
||||
env=env,
|
||||
symbol="ETH/USDT",
|
||||
since_timestamp=since_timestamp,
|
||||
until_timestamp=until_timestamp,
|
||||
num_episodes=2, # Very small number for quick testing
|
||||
max_steps_per_episode=100, # Very small number for quick testing
|
||||
period_name=period_name
|
||||
)
|
||||
|
||||
# Check if stats were returned
|
||||
self.assertIsNotNone(stats, "No stats returned from backtesting")
|
||||
|
||||
# Check if episodes were completed
|
||||
self.assertGreater(len(stats['episode_rewards']), 0, "No episodes completed")
|
||||
|
||||
# Log results
|
||||
logging.info(f"Completed backtesting for period: {period_name}")
|
||||
logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
|
||||
logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
|
||||
logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
|
||||
|
||||
# Close exchange connection
|
||||
try:
|
||||
await exchange.close()
|
||||
except AttributeError:
|
||||
# Some exchanges don't have a close method
|
||||
pass
|
||||
logging.info("Single day backtesting test completed")
|
||||
|
||||
def test_single_day_backtesting(self):
|
||||
"""Run the async test."""
|
||||
asyncio.run(self.async_test_single_day_backtesting())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user