backtseting support
This commit is contained in:
parent
5e9e6360af
commit
2e7a242ac7
1246
crypto/gogo2/main.py
1246
crypto/gogo2/main.py
File diff suppressed because it is too large
Load Diff
34
crypto/gogo2/run_demo.py
Normal file
34
crypto/gogo2/run_demo.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from main import live_trading, setup_logging
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
setup_logging()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Run a simplified demo trading session with mock data"""
|
||||||
|
logger.info("Starting simplified demo trading session")
|
||||||
|
|
||||||
|
# Run live trading in demo mode with simplified parameters
|
||||||
|
await live_trading(
|
||||||
|
symbol="ETH/USDT",
|
||||||
|
timeframe="1m",
|
||||||
|
model_path="models/trading_agent_best_pnl.pt",
|
||||||
|
demo=True,
|
||||||
|
initial_balance=1000,
|
||||||
|
update_interval=10, # Update every 10 seconds for faster feedback
|
||||||
|
max_position_size=0.1,
|
||||||
|
risk_per_trade=0.02,
|
||||||
|
stop_loss_pct=0.02,
|
||||||
|
take_profit_pct=0.04,
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Demo trading stopped by user")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in demo trading: {e}")
|
@ -1,477 +1,40 @@
|
|||||||
import os
|
#!/usr/bin/env python
|
||||||
import sys
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import logging
|
||||||
import pandas as pd
|
from main import live_trading, setup_logging
|
||||||
import random
|
|
||||||
import datetime
|
|
||||||
import torch
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import io
|
|
||||||
from PIL import Image
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
# Configure logging
|
# Set up logging
|
||||||
logging.basicConfig(
|
setup_logging()
|
||||||
level=logging.INFO,
|
logger = logging.getLogger(__name__)
|
||||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
|
||||||
handlers=[
|
|
||||||
logging.FileHandler("live_trading.log"),
|
|
||||||
logging.StreamHandler()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
logger = logging.getLogger("live_trading")
|
|
||||||
|
|
||||||
def generate_mock_data(symbol, timeframe, limit=1000):
|
async def main():
|
||||||
"""Generate mock OHLCV data for demo mode"""
|
parser = argparse.ArgumentParser(description='Run live trading in demo mode')
|
||||||
logger.info(f"Generating mock data for {symbol} ({timeframe})")
|
|
||||||
|
|
||||||
# Set seed for reproducibility
|
|
||||||
np.random.seed(42)
|
|
||||||
|
|
||||||
# Generate timestamps
|
|
||||||
end_time = datetime.datetime.now()
|
|
||||||
start_time = end_time - datetime.timedelta(minutes=limit)
|
|
||||||
timestamps = [start_time + datetime.timedelta(minutes=i) for i in range(limit)]
|
|
||||||
|
|
||||||
# Convert to milliseconds
|
|
||||||
timestamps_ms = [int(ts.timestamp() * 1000) for ts in timestamps]
|
|
||||||
|
|
||||||
# Generate price data with realistic patterns
|
|
||||||
base_price = 3000.0 # Starting price
|
|
||||||
price_data = []
|
|
||||||
current_price = base_price
|
|
||||||
|
|
||||||
for i in range(limit):
|
|
||||||
# Random walk with momentum and volatility clusters
|
|
||||||
momentum = np.random.normal(0, 1)
|
|
||||||
volatility = 0.5 + 0.5 * np.sin(i / 100) # Cyclical volatility
|
|
||||||
|
|
||||||
# Price change with momentum and volatility
|
|
||||||
price_change = momentum * volatility * current_price * 0.005
|
|
||||||
|
|
||||||
# Add some trends and patterns
|
|
||||||
if i % 200 < 100: # Uptrend for 100 candles, then downtrend
|
|
||||||
price_change += current_price * 0.001
|
|
||||||
else:
|
|
||||||
price_change -= current_price * 0.0008
|
|
||||||
|
|
||||||
# Update current price
|
|
||||||
current_price += price_change
|
|
||||||
|
|
||||||
# Generate OHLCV data
|
|
||||||
open_price = current_price
|
|
||||||
close_price = current_price + np.random.normal(0, 1) * current_price * 0.002
|
|
||||||
high_price = max(open_price, close_price) + abs(np.random.normal(0, 1)) * current_price * 0.003
|
|
||||||
low_price = min(open_price, close_price) - abs(np.random.normal(0, 1)) * current_price * 0.003
|
|
||||||
volume = np.random.gamma(2, 100) * (1 + 0.5 * np.sin(i / 50)) # Cyclical volume
|
|
||||||
|
|
||||||
# Store data
|
|
||||||
price_data.append({
|
|
||||||
'timestamp': timestamps_ms[i],
|
|
||||||
'open': open_price,
|
|
||||||
'high': high_price,
|
|
||||||
'low': low_price,
|
|
||||||
'close': close_price,
|
|
||||||
'volume': volume
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info(f"Generated {len(price_data)} mock candles")
|
|
||||||
return price_data
|
|
||||||
|
|
||||||
async def generate_mock_live_candles(initial_data, symbol, timeframe):
|
|
||||||
"""Generate mock live candles based on initial data"""
|
|
||||||
last_candle = initial_data[-1].copy()
|
|
||||||
last_timestamp = last_candle['timestamp']
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# Wait for next candle
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
|
|
||||||
# Update timestamp
|
|
||||||
if timeframe == '1m':
|
|
||||||
last_timestamp += 60 * 1000 # 1 minute in milliseconds
|
|
||||||
elif timeframe == '5m':
|
|
||||||
last_timestamp += 5 * 60 * 1000
|
|
||||||
elif timeframe == '15m':
|
|
||||||
last_timestamp += 15 * 60 * 1000
|
|
||||||
elif timeframe == '1h':
|
|
||||||
last_timestamp += 60 * 60 * 1000
|
|
||||||
else:
|
|
||||||
last_timestamp += 60 * 1000 # Default to 1 minute
|
|
||||||
|
|
||||||
# Generate new candle
|
|
||||||
last_price = last_candle['close']
|
|
||||||
price_change = np.random.normal(0, 1) * last_price * 0.002
|
|
||||||
|
|
||||||
# Add some persistence
|
|
||||||
if last_candle['close'] > last_candle['open']:
|
|
||||||
# Previous candle was green, more likely to continue up
|
|
||||||
price_change += last_price * 0.0005
|
|
||||||
else:
|
|
||||||
# Previous candle was red, more likely to continue down
|
|
||||||
price_change -= last_price * 0.0005
|
|
||||||
|
|
||||||
# Generate OHLCV data
|
|
||||||
open_price = last_price
|
|
||||||
close_price = last_price + price_change
|
|
||||||
high_price = max(open_price, close_price) + abs(np.random.normal(0, 1)) * last_price * 0.001
|
|
||||||
low_price = min(open_price, close_price) - abs(np.random.normal(0, 1)) * last_price * 0.001
|
|
||||||
volume = np.random.gamma(2, 100)
|
|
||||||
|
|
||||||
# Create new candle
|
|
||||||
new_candle = {
|
|
||||||
'timestamp': last_timestamp,
|
|
||||||
'open': open_price,
|
|
||||||
'high': high_price,
|
|
||||||
'low': low_price,
|
|
||||||
'close': close_price,
|
|
||||||
'volume': volume
|
|
||||||
}
|
|
||||||
|
|
||||||
# Update last candle
|
|
||||||
last_candle = new_candle.copy()
|
|
||||||
|
|
||||||
yield new_candle
|
|
||||||
|
|
||||||
class MockExchange:
|
|
||||||
"""Mock exchange for demo mode"""
|
|
||||||
def __init__(self):
|
|
||||||
self.name = "MockExchange"
|
|
||||||
self.id = "mock"
|
|
||||||
|
|
||||||
async def fetch_ohlcv(self, symbol, timeframe, limit=1000):
|
|
||||||
"""Mock method to fetch OHLCV data"""
|
|
||||||
# Generate mock data
|
|
||||||
mock_data = generate_mock_data(symbol, timeframe, limit)
|
|
||||||
|
|
||||||
# Convert to CCXT format
|
|
||||||
ohlcv = []
|
|
||||||
for candle in mock_data:
|
|
||||||
ohlcv.append([
|
|
||||||
candle['timestamp'],
|
|
||||||
candle['open'],
|
|
||||||
candle['high'],
|
|
||||||
candle['low'],
|
|
||||||
candle['close'],
|
|
||||||
candle['volume']
|
|
||||||
])
|
|
||||||
|
|
||||||
return ohlcv
|
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""Mock method to close exchange connection"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_model_info(model_path):
|
|
||||||
"""Extract model architecture information from saved model file"""
|
|
||||||
try:
|
|
||||||
# Load checkpoint with weights_only=False to get all information
|
|
||||||
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
|
|
||||||
|
|
||||||
# Extract model parameters
|
|
||||||
state_size = checkpoint['policy_net']['fc1.weight'].shape[1]
|
|
||||||
action_size = checkpoint['policy_net']['advantage_stream.bias'].shape[0]
|
|
||||||
hidden_size = checkpoint['policy_net']['fc1.weight'].shape[0]
|
|
||||||
|
|
||||||
# Try to extract LSTM layers and attention heads
|
|
||||||
lstm_layers = 2 # Default
|
|
||||||
attention_heads = 4 # Default
|
|
||||||
|
|
||||||
# Check if these parameters are stored in the checkpoint
|
|
||||||
if 'lstm_layers' in checkpoint:
|
|
||||||
lstm_layers = checkpoint['lstm_layers']
|
|
||||||
if 'attention_heads' in checkpoint:
|
|
||||||
attention_heads = checkpoint['attention_heads']
|
|
||||||
|
|
||||||
logger.info(f"Extracted model architecture: state_size={state_size}, action_size={action_size}, "
|
|
||||||
f"hidden_size={hidden_size}, lstm_layers={lstm_layers}, attention_heads={attention_heads}")
|
|
||||||
|
|
||||||
return state_size, action_size, hidden_size, lstm_layers, attention_heads
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to extract model info: {str(e)}")
|
|
||||||
logger.warning("Using default model architecture")
|
|
||||||
return 40, 3, 384, 2, 4 # Default values
|
|
||||||
|
|
||||||
async def run_live_demo():
|
|
||||||
"""Run the trading bot in live demo mode with enhanced error handling"""
|
|
||||||
parser = argparse.ArgumentParser(description='Live Trading Demo')
|
|
||||||
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
|
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
|
||||||
parser.add_argument('--timeframe', type=str, default='1m', help='Candle timeframe')
|
parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for trading')
|
||||||
parser.add_argument('--model', type=str, default='models/trading_agent_best_pnl.pt', help='Path to model file')
|
parser.add_argument('--model_path', type=str, default='data/best_model.pth', help='Path to the trained model')
|
||||||
parser.add_argument('--mock', action='store_true', help='Use mock data instead of real exchange data')
|
parser.add_argument('--initial_balance', type=float, default=1000, help='Initial balance')
|
||||||
|
parser.add_argument('--update_interval', type=int, default=30, help='Interval to update data in seconds')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
logger.info(f"Starting live trading demo with {args.symbol} on {args.timeframe} timeframe")
|
||||||
# Import main module
|
|
||||||
import main
|
# Run live trading in demo mode
|
||||||
|
await live_trading(
|
||||||
# Load environment variables
|
symbol=args.symbol,
|
||||||
load_dotenv()
|
timeframe=args.timeframe,
|
||||||
|
model_path=args.model_path,
|
||||||
# Create directories if they don't exist
|
demo=True, # Always use demo mode in this script
|
||||||
os.makedirs("trade_logs", exist_ok=True)
|
initial_balance=args.initial_balance,
|
||||||
os.makedirs("runs", exist_ok=True)
|
update_interval=args.update_interval,
|
||||||
|
# Using default values for other parameters
|
||||||
# Check if model file exists
|
)
|
||||||
if not os.path.exists(args.model):
|
|
||||||
logger.error(f"Model file not found: {args.model}")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
logger.info(f"Starting live trading demo for {args.symbol} on {args.timeframe} timeframe")
|
|
||||||
logger.info(f"Using model: {args.model}")
|
|
||||||
|
|
||||||
# Check API keys
|
|
||||||
api_key = os.getenv('MEXC_API_KEY')
|
|
||||||
secret_key = os.getenv('MEXC_SECRET_KEY')
|
|
||||||
use_mock = args.mock or not api_key or api_key == "your_api_key_here"
|
|
||||||
|
|
||||||
if use_mock:
|
|
||||||
logger.info("Using mock data for demo mode (no API keys required)")
|
|
||||||
exchange = MockExchange()
|
|
||||||
else:
|
|
||||||
# Initialize real exchange
|
|
||||||
exchange = await main.initialize_exchange()
|
|
||||||
|
|
||||||
# Initialize environment
|
|
||||||
env = main.TradingEnvironment(
|
|
||||||
initial_balance=float(os.getenv('INITIAL_BALANCE', 1000)),
|
|
||||||
window_size=30,
|
|
||||||
demo=True # Always use demo mode in this script
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fetch initial data
|
|
||||||
if use_mock:
|
|
||||||
# Use mock data
|
|
||||||
mock_data = generate_mock_data(args.symbol, args.timeframe, 1000)
|
|
||||||
env.data = mock_data
|
|
||||||
env._initialize_features()
|
|
||||||
success = True
|
|
||||||
else:
|
|
||||||
# Fetch real data
|
|
||||||
success = await env.fetch_initial_data(
|
|
||||||
exchange,
|
|
||||||
symbol=args.symbol,
|
|
||||||
timeframe=args.timeframe,
|
|
||||||
limit=1000
|
|
||||||
)
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
logger.error("Failed to fetch initial data. Exiting.")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
# Get model architecture from saved model
|
|
||||||
state_size, action_size, hidden_size, lstm_layers, attention_heads = get_model_info(args.model)
|
|
||||||
|
|
||||||
# Initialize agent with the correct architecture
|
|
||||||
agent = main.Agent(
|
|
||||||
state_size=state_size,
|
|
||||||
action_size=action_size,
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
lstm_layers=lstm_layers,
|
|
||||||
attention_heads=attention_heads
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load model with weights_only=False to handle numpy types
|
|
||||||
try:
|
|
||||||
# First try with weights_only=True (safer)
|
|
||||||
agent.load(args.model)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to load model with weights_only=True: {str(e)}")
|
|
||||||
|
|
||||||
# Try with safe_globals
|
|
||||||
try:
|
|
||||||
import torch.serialization
|
|
||||||
with torch.serialization.safe_globals(['numpy._core.multiarray.scalar', 'numpy.dtype']):
|
|
||||||
agent.load(args.model)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed with safe_globals: {str(e)}")
|
|
||||||
|
|
||||||
# Last resort: try with weights_only=False
|
|
||||||
try:
|
|
||||||
# Monkey patch the load method temporarily
|
|
||||||
original_load = main.Agent.load
|
|
||||||
|
|
||||||
def patched_load(self, path):
|
|
||||||
checkpoint = torch.load(path, map_location=self.device, weights_only=False)
|
|
||||||
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
|
||||||
self.target_net.load_state_dict(checkpoint['target_net'])
|
|
||||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
|
||||||
self.epsilon = checkpoint.get('epsilon', 0.05)
|
|
||||||
logger.info(f"Model loaded from {path}")
|
|
||||||
|
|
||||||
# Apply the patch
|
|
||||||
main.Agent.load = patched_load
|
|
||||||
|
|
||||||
# Try loading
|
|
||||||
agent.load(args.model)
|
|
||||||
|
|
||||||
# Restore original method
|
|
||||||
main.Agent.load = original_load
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"All loading attempts failed: {str(e)}")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
logger.info(f"Model loaded successfully")
|
|
||||||
|
|
||||||
# Initialize TensorBoard writer
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
agent.writer = SummaryWriter(f'runs/live_demo_{args.symbol.replace("/", "_")}_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}')
|
|
||||||
|
|
||||||
# Track performance metrics
|
|
||||||
trades_count = 0
|
|
||||||
winning_trades = 0
|
|
||||||
total_profit = 0
|
|
||||||
max_drawdown = 0
|
|
||||||
peak_balance = env.balance
|
|
||||||
step_counter = 0
|
|
||||||
prev_position = 'flat'
|
|
||||||
|
|
||||||
# Create trade log file
|
|
||||||
os.makedirs('trade_logs', exist_ok=True)
|
|
||||||
trade_log_path = f'trade_logs/trades_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}.csv'
|
|
||||||
with open(trade_log_path, 'w') as f:
|
|
||||||
f.write("timestamp,action,price,position_size,balance,pnl\n")
|
|
||||||
|
|
||||||
logger.info(f"Starting live trading simulation...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Set up mock live data generator if using mock data
|
|
||||||
if use_mock:
|
|
||||||
live_candle_generator = generate_mock_live_candles(env.data, args.symbol, args.timeframe)
|
|
||||||
|
|
||||||
# Main trading loop
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
# Get latest candle
|
|
||||||
if use_mock:
|
|
||||||
# Get mock candle
|
|
||||||
candle = await anext(live_candle_generator)
|
|
||||||
else:
|
|
||||||
# Get real candle
|
|
||||||
candle = await main.get_latest_candle(exchange, args.symbol)
|
|
||||||
|
|
||||||
if candle is None:
|
|
||||||
logger.warning("Failed to fetch latest candle, retrying in 5 seconds...")
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Add new data to environment
|
|
||||||
env.add_data(candle)
|
|
||||||
|
|
||||||
# Get current state and select action
|
|
||||||
state = env.get_state()
|
|
||||||
action = agent.select_action(state, training=False)
|
|
||||||
|
|
||||||
# Update environment with action (simulated)
|
|
||||||
next_state, reward, done = env.step(action)
|
|
||||||
|
|
||||||
# Create info dictionary (missing in the step function)
|
|
||||||
info = {
|
|
||||||
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
|
|
||||||
'price': env.current_price,
|
|
||||||
'balance': env.balance,
|
|
||||||
'position': env.position,
|
|
||||||
'pnl': env.total_pnl
|
|
||||||
}
|
|
||||||
|
|
||||||
# Log trade if position changed
|
|
||||||
if env.position != prev_position:
|
|
||||||
trades_count += 1
|
|
||||||
if env.last_trade_profit > 0:
|
|
||||||
winning_trades += 1
|
|
||||||
total_profit += env.last_trade_profit
|
|
||||||
|
|
||||||
# Log trade details
|
|
||||||
with open(trade_log_path, 'a') as f:
|
|
||||||
f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
|
|
||||||
|
|
||||||
logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
|
|
||||||
|
|
||||||
# Update performance metrics
|
|
||||||
if env.balance > peak_balance:
|
|
||||||
peak_balance = env.balance
|
|
||||||
current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
|
|
||||||
if current_drawdown > max_drawdown:
|
|
||||||
max_drawdown = current_drawdown
|
|
||||||
|
|
||||||
# Update TensorBoard metrics
|
|
||||||
step_counter += 1
|
|
||||||
agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
|
|
||||||
agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
|
|
||||||
agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter)
|
|
||||||
|
|
||||||
# Update chart visualization
|
|
||||||
if step_counter % 5 == 0 or env.position != prev_position:
|
|
||||||
agent.add_chart_to_tensorboard(env, step_counter)
|
|
||||||
|
|
||||||
# Log performance summary
|
|
||||||
if trades_count > 0:
|
|
||||||
win_rate = (winning_trades / trades_count) * 100
|
|
||||||
agent.writer.add_scalar('Live/WinRate', win_rate, step_counter)
|
|
||||||
|
|
||||||
performance_text = f"""
|
|
||||||
**Live Trading Performance**
|
|
||||||
Balance: ${env.balance:.2f}
|
|
||||||
Total PnL: ${env.total_pnl:.2f}
|
|
||||||
Trades: {trades_count}
|
|
||||||
Win Rate: {win_rate:.1f}%
|
|
||||||
Max Drawdown: {max_drawdown*100:.1f}%
|
|
||||||
"""
|
|
||||||
agent.writer.add_text('Performance', performance_text, step_counter)
|
|
||||||
|
|
||||||
prev_position = env.position
|
|
||||||
|
|
||||||
# Wait for next candle
|
|
||||||
await asyncio.sleep(1) # Faster updates in demo mode
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in live trading loop: {str(e)}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
logger.info("Continuing after error...")
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("Live trading stopped by user")
|
|
||||||
|
|
||||||
# Final performance report
|
|
||||||
if trades_count > 0:
|
|
||||||
win_rate = (winning_trades / trades_count) * 100
|
|
||||||
logger.info(f"Trading session summary:")
|
|
||||||
logger.info(f"Total trades: {trades_count}")
|
|
||||||
logger.info(f"Win rate: {win_rate:.1f}%")
|
|
||||||
logger.info(f"Final balance: ${env.balance:.2f}")
|
|
||||||
logger.info(f"Total profit: ${total_profit:.2f}")
|
|
||||||
logger.info(f"Maximum drawdown: {max_drawdown*100:.1f}%")
|
|
||||||
logger.info(f"Trade log saved to: {trade_log_path}")
|
|
||||||
|
|
||||||
return 0
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("Live trading stopped by user")
|
|
||||||
return 0
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in live trading: {str(e)}")
|
|
||||||
import traceback
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return 1
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Set environment variable to indicate we're in demo mode
|
try:
|
||||||
os.environ['DEMO_MODE'] = 'true'
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
# Print banner
|
logger.info("Live trading demo stopped by user")
|
||||||
print("\n" + "="*60)
|
except Exception as e:
|
||||||
print("🤖 TRADING BOT - LIVE DEMO MODE 🤖")
|
logger.error(f"Error in live trading demo: {e}")
|
||||||
print("="*60)
|
|
||||||
print("This is a DEMO mode with simulated trading (no real trades)")
|
|
||||||
print("Press Ctrl+C to stop the bot at any time")
|
|
||||||
print("="*60 + "\n")
|
|
||||||
|
|
||||||
# Run the async main function
|
|
||||||
exit_code = asyncio.run(run_live_demo())
|
|
||||||
sys.exit(exit_code)
|
|
77
crypto/gogo2/run_tests.py
Normal file
77
crypto/gogo2/run_tests.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Run unit tests for the trading bot.
|
||||||
|
|
||||||
|
This script runs the unit tests defined in tests.py and displays the results.
|
||||||
|
It can run a single test or all tests.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python run_tests.py [test_name]
|
||||||
|
|
||||||
|
If test_name is provided, only that test will be run.
|
||||||
|
Otherwise, all tests will be run.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
python run_tests.py TestPeriodicUpdates
|
||||||
|
python run_tests.py TestBacktesting
|
||||||
|
python run_tests.py TestBacktestingLastSevenDays
|
||||||
|
python run_tests.py TestSingleDayBacktesting
|
||||||
|
python run_tests.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
import logging
|
||||||
|
from tests import (
|
||||||
|
TestPeriodicUpdates,
|
||||||
|
TestBacktesting,
|
||||||
|
TestBacktestingLastSevenDays,
|
||||||
|
TestSingleDayBacktesting
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[logging.StreamHandler()])
|
||||||
|
|
||||||
|
# Get the test name from the command line
|
||||||
|
test_name = sys.argv[1] if len(sys.argv) > 1 else None
|
||||||
|
|
||||||
|
# Run the specified test or all tests
|
||||||
|
if test_name:
|
||||||
|
logging.info(f"Running test: {test_name}")
|
||||||
|
if test_name == "TestPeriodicUpdates":
|
||||||
|
suite = unittest.TestLoader().loadTestsFromTestCase(TestPeriodicUpdates)
|
||||||
|
elif test_name == "TestBacktesting":
|
||||||
|
suite = unittest.TestLoader().loadTestsFromTestCase(TestBacktesting)
|
||||||
|
elif test_name == "TestBacktestingLastSevenDays":
|
||||||
|
suite = unittest.TestLoader().loadTestsFromTestCase(TestBacktestingLastSevenDays)
|
||||||
|
elif test_name == "TestSingleDayBacktesting":
|
||||||
|
suite = unittest.TestLoader().loadTestsFromTestCase(TestSingleDayBacktesting)
|
||||||
|
else:
|
||||||
|
logging.error(f"Unknown test: {test_name}")
|
||||||
|
logging.info("Available tests: TestPeriodicUpdates, TestBacktesting, TestBacktestingLastSevenDays, TestSingleDayBacktesting")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
# Run all tests
|
||||||
|
logging.info("Running all tests")
|
||||||
|
suite = unittest.TestSuite()
|
||||||
|
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestPeriodicUpdates))
|
||||||
|
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestBacktesting))
|
||||||
|
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestBacktestingLastSevenDays))
|
||||||
|
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestSingleDayBacktesting))
|
||||||
|
|
||||||
|
# Run the tests
|
||||||
|
runner = unittest.TextTestRunner(verbosity=2)
|
||||||
|
result = runner.run(suite)
|
||||||
|
|
||||||
|
# Print summary
|
||||||
|
print("\nTest Summary:")
|
||||||
|
print(f" Ran {result.testsRun} tests")
|
||||||
|
print(f" Errors: {len(result.errors)}")
|
||||||
|
print(f" Failures: {len(result.failures)}")
|
||||||
|
print(f" Skipped: {len(result.skipped)}")
|
||||||
|
|
||||||
|
# Exit with non-zero status if any tests failed
|
||||||
|
sys.exit(len(result.errors) + len(result.failures))
|
337
crypto/gogo2/tests.py
Normal file
337
crypto/gogo2/tests.py
Normal file
@ -0,0 +1,337 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the trading bot.
|
||||||
|
This file contains tests for various components of the trading bot, including:
|
||||||
|
1. Periodic candle updates
|
||||||
|
2. Backtesting on historical data
|
||||||
|
3. Training on the last 7 days of data
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import datetime
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[logging.StreamHandler()])
|
||||||
|
|
||||||
|
# Import functionality from main.py
|
||||||
|
import main
|
||||||
|
from main import (
|
||||||
|
CandleCache, BacktestCandles, initialize_exchange,
|
||||||
|
TradingEnvironment, Agent, train_with_backtesting,
|
||||||
|
fetch_multi_timeframe_data, train_agent
|
||||||
|
)
|
||||||
|
|
||||||
|
class TestPeriodicUpdates(unittest.TestCase):
|
||||||
|
"""Test that candle data is periodically updated during training."""
|
||||||
|
|
||||||
|
async def async_test_periodic_updates(self):
|
||||||
|
"""Test that candle data is periodically updated during training."""
|
||||||
|
logging.info("Testing periodic candle updates...")
|
||||||
|
|
||||||
|
# Initialize exchange
|
||||||
|
exchange = await initialize_exchange()
|
||||||
|
self.assertIsNotNone(exchange, "Failed to initialize exchange")
|
||||||
|
|
||||||
|
# Create candle cache
|
||||||
|
candle_cache = CandleCache()
|
||||||
|
|
||||||
|
# Initial fetch of candle data
|
||||||
|
candle_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
|
||||||
|
self.assertIsNotNone(candle_data, "Failed to fetch initial candle data")
|
||||||
|
self.assertIn('1m', candle_data, "1m candles not found in initial data")
|
||||||
|
|
||||||
|
# Check initial data timestamps
|
||||||
|
initial_1m_candles = candle_data['1m']
|
||||||
|
self.assertGreater(len(initial_1m_candles), 0, "No 1m candles found in initial data")
|
||||||
|
initial_timestamp = initial_1m_candles[-1][0]
|
||||||
|
|
||||||
|
# Wait for update interval to pass
|
||||||
|
logging.info("Waiting for update interval to pass (5 seconds for testing)...")
|
||||||
|
await asyncio.sleep(5) # Short wait for testing
|
||||||
|
|
||||||
|
# Force update by setting last_updated to None
|
||||||
|
candle_cache.last_updated['1m'] = None
|
||||||
|
|
||||||
|
# Fetch updated data
|
||||||
|
updated_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
|
||||||
|
self.assertIsNotNone(updated_data, "Failed to fetch updated candle data")
|
||||||
|
|
||||||
|
# Check if data was updated
|
||||||
|
updated_1m_candles = updated_data['1m']
|
||||||
|
self.assertGreater(len(updated_1m_candles), 0, "No 1m candles found in updated data")
|
||||||
|
updated_timestamp = updated_1m_candles[-1][0]
|
||||||
|
|
||||||
|
# In a live scenario, this check should pass with real-time updates
|
||||||
|
# For testing, we just ensure data was fetched
|
||||||
|
logging.info(f"Initial timestamp: {initial_timestamp}, Updated timestamp: {updated_timestamp}")
|
||||||
|
self.assertIsNotNone(updated_timestamp, "Updated timestamp is None")
|
||||||
|
|
||||||
|
# Close exchange connection
|
||||||
|
try:
|
||||||
|
await exchange.close()
|
||||||
|
except AttributeError:
|
||||||
|
# Some exchanges don't have a close method
|
||||||
|
pass
|
||||||
|
logging.info("Periodic update test completed")
|
||||||
|
|
||||||
|
def test_periodic_updates(self):
|
||||||
|
"""Run the async test."""
|
||||||
|
asyncio.run(self.async_test_periodic_updates())
|
||||||
|
|
||||||
|
|
||||||
|
class TestBacktesting(unittest.TestCase):
|
||||||
|
"""Test backtesting on historical data."""
|
||||||
|
|
||||||
|
async def async_test_backtesting(self):
|
||||||
|
"""Test backtesting on a specific time period."""
|
||||||
|
logging.info("Testing backtesting with historical data...")
|
||||||
|
|
||||||
|
# Initialize exchange
|
||||||
|
exchange = await initialize_exchange()
|
||||||
|
self.assertIsNotNone(exchange, "Failed to initialize exchange")
|
||||||
|
|
||||||
|
# Create a timestamp for 24 hours ago
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
yesterday = now - datetime.timedelta(days=1)
|
||||||
|
since_timestamp = int(yesterday.timestamp() * 1000) # Convert to milliseconds
|
||||||
|
|
||||||
|
# Create a backtesting candle cache
|
||||||
|
backtest_cache = BacktestCandles(since_timestamp=since_timestamp)
|
||||||
|
backtest_cache.period_name = "1-day-ago"
|
||||||
|
|
||||||
|
# Fetch historical data
|
||||||
|
candle_data = await backtest_cache.fetch_all_timeframes(exchange, "ETH/USDT")
|
||||||
|
self.assertIsNotNone(candle_data, "Failed to fetch historical candle data")
|
||||||
|
self.assertIn('1m', candle_data, "1m candles not found in historical data")
|
||||||
|
|
||||||
|
# Check historical data timestamps
|
||||||
|
minute_candles = candle_data['1m']
|
||||||
|
self.assertGreater(len(minute_candles), 0, "No minute candles found in historical data")
|
||||||
|
|
||||||
|
# Check if timestamps are within the requested range
|
||||||
|
first_timestamp = minute_candles[0][0]
|
||||||
|
last_timestamp = minute_candles[-1][0]
|
||||||
|
|
||||||
|
logging.info(f"Requested since: {since_timestamp}")
|
||||||
|
logging.info(f"First timestamp in data: {first_timestamp}")
|
||||||
|
logging.info(f"Last timestamp in data: {last_timestamp}")
|
||||||
|
|
||||||
|
# In real tests, this check should compare timestamps precisely
|
||||||
|
# For this test, we just ensure data was fetched
|
||||||
|
self.assertLessEqual(first_timestamp, last_timestamp, "First timestamp should be before last timestamp")
|
||||||
|
|
||||||
|
# Close exchange connection
|
||||||
|
try:
|
||||||
|
await exchange.close()
|
||||||
|
except AttributeError:
|
||||||
|
# Some exchanges don't have a close method
|
||||||
|
pass
|
||||||
|
logging.info("Backtesting fetch test completed")
|
||||||
|
|
||||||
|
def test_backtesting(self):
|
||||||
|
"""Run the async test."""
|
||||||
|
asyncio.run(self.async_test_backtesting())
|
||||||
|
|
||||||
|
|
||||||
|
class TestBacktestingLastSevenDays(unittest.TestCase):
|
||||||
|
"""Test backtesting on the last 7 days of data."""
|
||||||
|
|
||||||
|
async def async_test_seven_days_backtesting(self):
|
||||||
|
"""Test backtesting on the last 7 days."""
|
||||||
|
logging.info("Testing backtesting on the last 7 days...")
|
||||||
|
|
||||||
|
# Initialize exchange
|
||||||
|
exchange = await initialize_exchange()
|
||||||
|
self.assertIsNotNone(exchange, "Failed to initialize exchange")
|
||||||
|
|
||||||
|
# Create environment with small initial balance for testing
|
||||||
|
env = TradingEnvironment(
|
||||||
|
initial_balance=100, # Small balance for testing
|
||||||
|
leverage=10, # Lower leverage for testing
|
||||||
|
window_size=50, # Smaller window for faster testing
|
||||||
|
commission=0.0004 # Standard commission
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
|
||||||
|
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
|
||||||
|
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)
|
||||||
|
|
||||||
|
# Initialize empty results dataframe
|
||||||
|
all_results = pd.DataFrame()
|
||||||
|
|
||||||
|
# Run backtesting for the last 7 days, one day at a time
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
|
||||||
|
for day_offset in range(1, 8):
|
||||||
|
# Calculate time period
|
||||||
|
end_day = now - datetime.timedelta(days=day_offset-1)
|
||||||
|
start_day = end_day - datetime.timedelta(days=1)
|
||||||
|
|
||||||
|
# Convert to milliseconds
|
||||||
|
since_timestamp = int(start_day.timestamp() * 1000)
|
||||||
|
until_timestamp = int(end_day.timestamp() * 1000)
|
||||||
|
|
||||||
|
# Period name
|
||||||
|
period_name = f"Day-{day_offset}"
|
||||||
|
|
||||||
|
logging.info(f"Testing backtesting for period: {period_name}")
|
||||||
|
logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
|
||||||
|
# Run backtesting with a small number of episodes for testing
|
||||||
|
stats = await train_with_backtesting(
|
||||||
|
agent=agent,
|
||||||
|
env=env,
|
||||||
|
symbol="ETH/USDT",
|
||||||
|
since_timestamp=since_timestamp,
|
||||||
|
until_timestamp=until_timestamp,
|
||||||
|
num_episodes=3, # Use a small number for testing
|
||||||
|
max_steps_per_episode=200, # Use a small number for testing
|
||||||
|
period_name=period_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if stats were returned
|
||||||
|
if stats is None:
|
||||||
|
logging.warning(f"No stats returned for period: {period_name}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create a dataframe from stats
|
||||||
|
if len(stats['episode_rewards']) > 0:
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'Period': [period_name] * len(stats['episode_rewards']),
|
||||||
|
'Episode': list(range(1, len(stats['episode_rewards']) + 1)),
|
||||||
|
'Reward': stats['episode_rewards'],
|
||||||
|
'Balance': stats['balances'],
|
||||||
|
'PnL': stats['episode_pnls'],
|
||||||
|
'Fees': stats['fees'],
|
||||||
|
'Net_PnL': stats['net_pnl_after_fees']
|
||||||
|
})
|
||||||
|
|
||||||
|
# Append to all results
|
||||||
|
all_results = pd.concat([all_results, df], ignore_index=True)
|
||||||
|
|
||||||
|
logging.info(f"Completed backtesting for period: {period_name}")
|
||||||
|
logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
|
||||||
|
logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
|
||||||
|
logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
|
||||||
|
else:
|
||||||
|
logging.warning(f"No episodes completed for period: {period_name}")
|
||||||
|
|
||||||
|
# Save all results
|
||||||
|
if not all_results.empty:
|
||||||
|
all_results.to_csv("all_backtest_results.csv", index=False)
|
||||||
|
logging.info("Saved all backtest results to all_backtest_results.csv")
|
||||||
|
|
||||||
|
# Create plot of results
|
||||||
|
plt.figure(figsize=(12, 8))
|
||||||
|
|
||||||
|
# Plot Net PnL by period
|
||||||
|
all_results.groupby('Period')['Net_PnL'].last().plot(kind='bar')
|
||||||
|
plt.title('Net PnL by Training Period (Last Episode)')
|
||||||
|
plt.ylabel('Net PnL ($)')
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig("backtest_results.png")
|
||||||
|
logging.info("Saved backtest results plot to backtest_results.png")
|
||||||
|
|
||||||
|
# Close exchange connection
|
||||||
|
try:
|
||||||
|
await exchange.close()
|
||||||
|
except AttributeError:
|
||||||
|
# Some exchanges don't have a close method
|
||||||
|
pass
|
||||||
|
logging.info("7-day backtesting test completed")
|
||||||
|
|
||||||
|
def test_seven_days_backtesting(self):
|
||||||
|
"""Run the async test."""
|
||||||
|
asyncio.run(self.async_test_seven_days_backtesting())
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingleDayBacktesting(unittest.TestCase):
|
||||||
|
"""Test backtesting on a single day of historical data."""
|
||||||
|
|
||||||
|
async def async_test_single_day_backtesting(self):
|
||||||
|
"""Test backtesting on a single day."""
|
||||||
|
logging.info("Testing backtesting on a single day...")
|
||||||
|
|
||||||
|
# Initialize exchange
|
||||||
|
exchange = await initialize_exchange()
|
||||||
|
self.assertIsNotNone(exchange, "Failed to initialize exchange")
|
||||||
|
|
||||||
|
# Create environment with small initial balance for testing
|
||||||
|
env = TradingEnvironment(
|
||||||
|
initial_balance=100, # Small balance for testing
|
||||||
|
leverage=10, # Lower leverage for testing
|
||||||
|
window_size=50, # Smaller window for faster testing
|
||||||
|
commission=0.0004 # Standard commission
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
|
||||||
|
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
|
||||||
|
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)
|
||||||
|
|
||||||
|
# Calculate time period for 1 day ago
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
end_day = now
|
||||||
|
start_day = end_day - datetime.timedelta(days=1)
|
||||||
|
|
||||||
|
# Convert to milliseconds
|
||||||
|
since_timestamp = int(start_day.timestamp() * 1000)
|
||||||
|
until_timestamp = int(end_day.timestamp() * 1000)
|
||||||
|
|
||||||
|
# Period name
|
||||||
|
period_name = "Test-Day-1"
|
||||||
|
|
||||||
|
logging.info(f"Testing backtesting for period: {period_name}")
|
||||||
|
logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
|
||||||
|
# Run backtesting with a small number of episodes for testing
|
||||||
|
stats = await train_with_backtesting(
|
||||||
|
agent=agent,
|
||||||
|
env=env,
|
||||||
|
symbol="ETH/USDT",
|
||||||
|
since_timestamp=since_timestamp,
|
||||||
|
until_timestamp=until_timestamp,
|
||||||
|
num_episodes=2, # Very small number for quick testing
|
||||||
|
max_steps_per_episode=100, # Very small number for quick testing
|
||||||
|
period_name=period_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if stats were returned
|
||||||
|
self.assertIsNotNone(stats, "No stats returned from backtesting")
|
||||||
|
|
||||||
|
# Check if episodes were completed
|
||||||
|
self.assertGreater(len(stats['episode_rewards']), 0, "No episodes completed")
|
||||||
|
|
||||||
|
# Log results
|
||||||
|
logging.info(f"Completed backtesting for period: {period_name}")
|
||||||
|
logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
|
||||||
|
logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
|
||||||
|
logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
|
||||||
|
|
||||||
|
# Close exchange connection
|
||||||
|
try:
|
||||||
|
await exchange.close()
|
||||||
|
except AttributeError:
|
||||||
|
# Some exchanges don't have a close method
|
||||||
|
pass
|
||||||
|
logging.info("Single day backtesting test completed")
|
||||||
|
|
||||||
|
def test_single_day_backtesting(self):
|
||||||
|
"""Run the async test."""
|
||||||
|
asyncio.run(self.async_test_single_day_backtesting())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user