gogo2/crypto/gogo/main.py
Dobromir Popov c8b0f77d32 suggestions
2025-02-12 01:38:05 +02:00

303 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import torch
import torch.nn as nn
import torch.optim as optim
from data.live_data import LiveDataManager
from model.transformer import Transformer
from training.train import train
from data.data_utils import preprocess_data # Import preprocess_data
import ccxt.async_support as ccxt
import time
import os
import numpy as np
import matplotlib.pyplot as plt
from model.trading_model import TradingModel
from training.rl_agent import ContinuousRLAgent, ReplayBuffer
from training.train_historical import train_on_historical_data, load_best_checkpoint, save_candles_cache, CACHE_FILE, BEST_DIR
from data.data_utils import get_aligned_candle_with_index, get_features_for_tf
import argparse
async def main_training():
    """Train the price-prediction Transformer on live BTC/USDT data.

    Builds the model and an AdamW optimizer, selects CUDA when available and
    delegates the training loop to ``train`` (10 epochs). The live data
    manager is closed on every exit path, including a failure while building
    the model and manual interruption.
    """
    symbol = 'BTC/USDT'
    data_manager = LiveDataManager(symbol)
    try:
        # EMA look-back periods used as extra per-candle features.
        ema_periods = [5, 10, 20, 60, 120, 200]
        # 6 base columns plus one column per EMA period.
        # NOTE(review): OHLCV is only 5 values; confirm what the 6th base
        # column is and that this matches the features train() feeds in.
        input_dim = 6 + len(ema_periods)
        # Architecture hyper-parameters. A standard Transformer of this size is
        # in the tens of millions of parameters, far short of ~1B; scale
        # d_model / num_layers / d_ff up to grow it (verify with
        # sum(p.numel() for p in model.parameters())).
        d_model = 512
        num_heads = 8
        num_layers = 6
        d_ff = 2048
        dropout = 0.1

        # Pick the device first so the model is moved there *before* the
        # optimizer is constructed, as recommended by the torch.optim docs.
        if torch.cuda.is_available():
            device = torch.device('cuda')
            print("Using CUDA")
        else:
            device = torch.device('cpu')
            print("Using CPU")

        model = Transformer(input_dim, d_model, num_heads, num_layers, d_ff, dropout).to(device)
        optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
        # MSE losses handed to train(); presumably one per prediction target
        # (candle, volume, ticks) — confirm against training.train.
        criterion_candle = nn.MSELoss()
        criterion_volume = nn.MSELoss()  # Consider a different loss for volume if needed
        criterion_ticks = nn.MSELoss()

        await train(model, data_manager, optimizer, criterion_candle, criterion_volume, criterion_ticks, num_epochs=10, device=device)
    except KeyboardInterrupt:
        print("Training stopped manually.")
    except asyncio.CancelledError:
        # Since Python 3.11, asyncio.run() handles Ctrl-C by cancelling the main
        # task, so the interrupt arrives here instead of as KeyboardInterrupt.
        # Report it, then re-raise so the cancellation completes normally.
        print("Training stopped manually.")
        raise
    finally:
        await data_manager.close()
# -------------------------------------
# Main Asynchronous Function for Training & Charting
# -------------------------------------
async def main_backtest():
    """Backtest the RL trading agent on recent multi-timeframe MEXC history.

    Steps:
      1. Fetch OHLCV candles for each timeframe over the last 1500 minutes and
         save them to the candle cache.
      2. Train the RL agent on a BacktestEnvironment built from that data,
         starting from the best saved checkpoint if one exists.
      3. Replay one greedy pass (epsilon=0) to record trades, print summary
         statistics, and plot the trades on the 1m chart.

    The exchange connection is closed as soon as fetching finishes (or fails),
    so it is not held open during training or the blocking plot window.
    """
    symbol = 'BTC/USDT'
    # Define timeframes: we'll use 5 different ones.
    timeframes = ["1m", "5m", "15m", "1h", "1d"]
    base_tf = "1m"
    now = int(time.time() * 1000)
    # Look-back window of 1500 base (1m) candles, i.e. 1500 minutes. Coarser
    # timeframes cover the same wall-clock window and so yield far fewer
    # candles (only about one or two for "1d").
    period_ms = 1500 * 60 * 1000
    since = now - period_ms
    end_time = now

    # Initialize exchange using MEXC (or your preferred exchange).
    mexc_api_key = os.environ.get('MEXC_API_KEY', 'YOUR_API_KEY')
    mexc_api_secret = os.environ.get('MEXC_API_SECRET', 'YOUR_SECRET_KEY')
    exchange = ccxt.mexc({
        'apiKey': mexc_api_key,
        'secret': mexc_api_secret,
        'enableRateLimit': True,
    })
    candles_dict = {}
    try:
        for tf in timeframes:
            print(f"Fetching historical data for timeframe {tf}...")
            candles_dict[tf] = await fetch_historical_data(exchange, symbol, tf, since, end_time, batch_size=500)
    finally:
        # ccxt async exchanges hold an HTTP session that must be closed
        # explicitly; nothing after this point needs the exchange.
        await exchange.close()

    # Optionally, save the multi-timeframe cache.
    save_candles_cache(CACHE_FILE, candles_dict)

    if not candles_dict[base_tf]:
        # The environment indexes base candles directly; bail out cleanly
        # instead of failing with an IndexError during training.
        print(f"No {base_tf} candles fetched; aborting backtest.")
        return

    # Create the backtest environment using multi-timeframe data.
    env = BacktestEnvironment(candles_dict, base_tf=base_tf, timeframes=timeframes)
    # Each timeframe is assumed to contribute 7 features to the state
    # (7 * 5 timeframes = 35). Must match get_features_for_tf's output length.
    input_dim = len(timeframes) * 7
    hidden_dim = 128
    output_dim = 3  # Actions: SELL, HOLD, BUY.
    model = TradingModel(input_dim, hidden_dim, output_dim)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    replay_buffer = ReplayBuffer(capacity=10000)
    rl_agent = ContinuousRLAgent(model, optimizer, replay_buffer, batch_size=32, gamma=0.99)

    # Load best checkpoint if available.
    load_best_checkpoint(model, BEST_DIR)
    # Train the agent over the historical period.
    num_epochs = 10  # Adjust as needed.
    train_on_historical_data(env, rl_agent, num_epochs=num_epochs, epsilon=0.1)

    # Run a final simulation (without exploration) to record trade history.
    state = env.reset(clear_trade_history=True)
    done = False
    cumulative_reward = 0.0
    while not done:
        action = rl_agent.act(state, epsilon=0.0)
        _, reward, next_state, done = env.step(action)
        cumulative_reward += reward
        state = next_state
    print("Final simulation cumulative profit:", cumulative_reward)

    # Evaluate trade performance.
    trades = env.trade_history
    num_trades = len(trades)
    num_wins = sum(1 for trade in trades if trade["pnl"] > 0)
    win_rate = (num_wins / num_trades * 100) if num_trades > 0 else 0.0
    total_profit = sum(trade["pnl"] for trade in trades)
    print(f"Total trades: {num_trades}, Wins: {num_wins}, Win rate: {win_rate:.2f}%, Total Profit: {total_profit:.4f}")

    # Plot chart with buy/sell markers on the base timeframe ("1m").
    plot_trade_history(candles_dict[base_tf], trades)
# -------------------------------------
# Historical Data Fetching Function (for a given timeframe)
# -------------------------------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
    """Page through ``exchange.fetch_ohlcv`` and collect candles in a time window.

    Requests batches of up to ``batch_size`` candles starting at ``since``,
    advancing past the last returned timestamp each round. Stops when any of
    these happens:
      - a batch reaches ``end_time``;
      - the exchange returns nothing;
      - the exchange stops making progress (its last timestamp does not move
        past the requested ``since``), which would otherwise loop forever;
      - a request fails. The error is printed and the candles gathered so far
        are returned (deliberate best-effort).

    Args:
        exchange: ccxt async exchange instance.
        symbol: market symbol, e.g. 'BTC/USDT'.
        timeframe: ccxt timeframe string, e.g. '1m'.
        since: start of the window, in milliseconds since the epoch.
        end_time: inclusive end of the window, in milliseconds since the epoch.
        batch_size: ``limit`` passed to each ``fetch_ohlcv`` call.

    Returns:
        List of dicts with keys 'timestamp', 'open', 'high', 'low', 'close' and
        'volume'. Timestamps are strictly increasing, so there are no
        duplicates from overlapping batches, and no candle lies after
        ``end_time``.
    """
    candles = []
    since_ms = since
    last_kept_ts = None  # timestamp of the newest candle appended so far
    while True:
        try:
            batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size)
        except Exception as e:
            print(f"Error fetching historical data for {timeframe}:", e)
            break
        if not batch:
            break
        for c in batch:
            ts = c[0]
            if ts > end_time:
                continue  # outside the requested window
            if last_kept_ts is not None and ts <= last_kept_ts:
                continue  # overlap with an earlier batch
            candles.append({
                'timestamp': ts,
                'open': c[1],
                'high': c[2],
                'low': c[3],
                'close': c[4],
                'volume': c[5],
            })
            last_kept_ts = ts
        last_timestamp = batch[-1][0]
        if last_timestamp >= end_time:
            break
        if last_timestamp < since_ms:
            # The exchange ignored `since` and returned old data again;
            # continuing would re-request the same page forever.
            break
        since_ms = last_timestamp + 1
    print(f"Fetched {len(candles)} candles for timeframe {timeframe}.")
    return candles
# -------------------------------------
# Backtest Environment with Multi-Timeframe State
# -------------------------------------
class BacktestEnvironment:
    """Long-only, single-position backtest over multi-timeframe candle data.

    The simulation walks the base-timeframe candles one index at a time. The
    observation for an index is the concatenation of per-timeframe feature
    vectors, each taken from the candle of that timeframe aligned to the base
    candle's timestamp. Orders fill at the open of the candle after the
    decision, and a closed trade's PnL is the raw price difference
    (exit open - entry open).

    Attributes:
        candles_dict: timeframe -> list of candle dicts. Base candles must
            have at least "timestamp" and "open"; other keys are read by
            get_features_for_tf.
        base_tf: timeframe whose candles drive the simulation clock.
        timeframes: timeframes whose features make up the state, in order.
        trade_history: closed trades, in the order they were closed.
        current_index: current index into candles_dict[base_tf].
        position: the open position dict, or None when flat.
    """

    # Discrete action ids accepted by step().
    SELL = 0
    HOLD = 1
    BUY = 2

    def __init__(self, candles_dict, base_tf="1m", timeframes=None):
        self.candles_dict = candles_dict  # timeframe -> list of candle dicts
        self.base_tf = base_tf
        # Without an explicit list, observe only the base timeframe.
        self.timeframes = [base_tf] if timeframes is None else timeframes
        self.trade_history = []  # record of closed trades
        self.current_index = 0  # index on base_tf candles
        self.position = None  # active position record

    def reset(self, clear_trade_history=True):
        """Rewind to the first base candle, go flat, and return the initial state.

        Args:
            clear_trade_history: when False, trades from earlier runs are kept.
        """
        self.current_index = 0
        self.position = None
        if clear_trade_history:
            self.trade_history = []
        return self.get_state(self.current_index)

    def get_state(self, index):
        """Return the state for base candle ``index`` as a float32 vector.

        For each timeframe in ``self.timeframes``, the candle closest to (and at
        or before) the base candle's timestamp is located, and its features
        (look-back period 10) are appended in timeframe order.
        """
        state_features = []
        base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
        for tf in self.timeframes:
            candles_list = self.candles_dict[tf]
            aligned_index, _ = get_aligned_candle_with_index(candles_list, base_ts)
            state_features.extend(get_features_for_tf(candles_list, aligned_index, period=10))
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """Apply ``action`` at the current base candle and advance one candle.

        - Flat and action is BUY (2): open a position at the next candle's open.
        - In a position and action is SELL (0): close it at the next candle's
          open, append the trade to ``trade_history``, and reward its PnL.
        - Anything else (HOLD, SELL while flat, BUY while long) does nothing.

        Both entry and exit are stamped with the index of the candle whose open
        they filled at (the next candle), so index and price always match.

        Returns:
            (current_state, reward, next_state, done). When called at or past
            the last base candle, returns (state, 0.0, None, True) without
            changing anything.
        """
        base_candles = self.candles_dict[self.base_tf]
        last_index = len(base_candles) - 1
        if self.current_index >= last_index:
            return self.get_state(self.current_index), 0.0, None, True

        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        fill_price = base_candles[next_index]["open"]
        reward = 0.0

        if self.position is None:
            if action == self.BUY:
                self.position = {"entry_price": fill_price, "entry_index": next_index}
        elif action == self.SELL:
            reward = fill_price - self.position["entry_price"]
            self.trade_history.append({
                "entry_index": self.position["entry_index"],
                "entry_price": self.position["entry_price"],
                "exit_index": next_index,
                "exit_price": fill_price,
                "pnl": reward,
            })
            self.position = None

        self.current_index = next_index
        done = self.current_index >= last_index
        return current_state, reward, next_state, done
# -------------------------------------
# Chart Plotting: Trade History & PnL
# -------------------------------------
def plot_trade_history(candles, trade_history):
    """Draw close prices with every closed trade overlaid, then show the figure.

    Each trade gets:
      - a green "^" entry marker tagged " IN";
      - a red "v" exit marker tagged " OUT", plus the signed PnL in blue;
      - a dotted connector that is green for profit, red for loss and gray
        otherwise.

    Only the first entry and the first exit marker carry legend labels.
    Blocks on ``plt.show()``.

    Args:
        candles: list of candle dicts; only "close" is read. The x axis is the
            list index.
        trade_history: list of trade dicts with "entry_index", "exit_index",
            "entry_price", "exit_price" and "pnl".
    """
    closes = [bar["close"] for bar in candles]
    positions = list(range(len(closes)))
    plt.figure(figsize=(12, 6))
    plt.plot(positions, closes, label="Close Price", color="black", linewidth=1)

    buy_legend_done = False
    sell_legend_done = False
    for record in trade_history:
        entry_x, exit_x = record["entry_index"], record["exit_index"]
        entry_y, exit_y = record["entry_price"], record["exit_price"]
        profit = record["pnl"]

        # Entry marker: only the first one gets a legend label.
        entry_extra = {} if buy_legend_done else {"label": "BUY (IN)"}
        plt.plot(entry_x, entry_y, marker="^", color="green", markersize=10, **entry_extra)
        buy_legend_done = True
        plt.text(entry_x, entry_y, " IN", color="green", fontsize=8, verticalalignment="bottom")

        # Exit marker: likewise labelled once, then annotated with the PnL.
        exit_extra = {} if sell_legend_done else {"label": "SELL (OUT)"}
        plt.plot(exit_x, exit_y, marker="v", color="red", markersize=10, **exit_extra)
        sell_legend_done = True
        plt.text(exit_x, exit_y, " OUT", color="red", fontsize=8, verticalalignment="top")
        plt.text(exit_x, exit_y, f" {profit:+.2f}", color="blue", fontsize=8, verticalalignment="bottom")

        connector = "green" if profit > 0 else ("red" if profit < 0 else "gray")
        plt.plot([entry_x, exit_x], [entry_y, exit_y], linestyle="dotted", color=connector)

    plt.title("Trade History with PnL")
    plt.xlabel("Base Candle Index (1m)")
    plt.ylabel("Price")
    plt.legend()
    plt.grid(True)
    plt.show()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Trading Bot Modes')
    parser.add_argument('--mode', type=str, default='backtest', choices=['train', 'backtest'], help='Choose mode: train or backtest')
    args = parser.parse_args()
    # argparse's `choices` guarantees args.mode is one of these keys.
    entry_points = {'train': main_training, 'backtest': main_backtest}
    try:
        asyncio.run(entry_points[args.mode]())
    except KeyboardInterrupt:
        # On Ctrl-C, asyncio.run() cancels the running coroutine (its finally
        # blocks still run and release resources) and then re-raises
        # KeyboardInterrupt here. Exit quietly instead of dumping a traceback.
        print("Interrupted by user.")