#!/usr/bin/env python3
import sys
import asyncio
if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
from dotenv import load_dotenv
import os
import time
import json
import ccxt.async_support as ccxt
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
from datetime import datetime
import matplotlib.pyplot as plt
# --- Directories for saving models ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"
# -------------------------------------
# Utility functions for caching candles to file
# -------------------------------------
def load_candles_cache(filename):
    if os.path.exists(filename):
        try:
            with open(filename, "r") as f:
                data = json.load(f)
            print(f"Loaded cached data from {filename}.")
            return data
        except Exception as e:
            print("Error reading cache file:", e)
    return {}
def save_candles_cache(filename, candles_dict):
    try:
        with open(filename, "w") as f:
            json.dump(candles_dict, f)
    except Exception as e:
        print("Error saving cache file:", e)
# -------------------------------------
# Checkpoint Functions
# -------------------------------------
def maintain_checkpoint_directory(directory, max_files=10):
    """Keep at most max_files checkpoints in directory, deleting the oldest first."""
    files = os.listdir(directory)
    if len(files) > max_files:
        full_paths = [os.path.join(directory, f) for f in files]
        full_paths.sort(key=lambda x: os.path.getmtime(x))
        for f in full_paths[: len(files) - max_files]:
            os.remove(f)
def get_best_models(directory):
    """Parse rewards from checkpoint filenames of the form best_<reward>_epoch_<n>_<ts>.pt."""
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            r = float(parts[1])
            best_files.append((r, file))
        except (IndexError, ValueError):
            continue
    return best_files
def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    last_path = os.path.join(last_dir, last_filename)
    torch.save({
        "epoch": epoch,
        "reward": reward,
        "model_state_dict": model.state_dict()
    }, last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)
    # Keep the ten best-reward checkpoints, evicting the current minimum if beaten.
    best_models = get_best_models(best_dir)
    add_to_best = False
    if len(best_models) < 10:
        add_to_best = True
    else:
        min_reward, min_file = min(best_models, key=lambda x: x[0])
        if reward > min_reward:
            add_to_best = True
            os.remove(os.path.join(best_dir, min_file))
    if add_to_best:
        best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
        best_path = os.path.join(best_dir, best_filename)
        torch.save({
            "epoch": epoch,
            "reward": reward,
            "model_state_dict": model.state_dict()
        }, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
def load_best_checkpoint(model, best_dir=BEST_DIR):
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_reward, best_file = max(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint
# -------------------------------------
# Technical Indicator Helper Functions
# -------------------------------------
def compute_sma(candles_list, index, period=10):
    start = max(0, index - period + 1)
    values = [candle["close"] for candle in candles_list[start:index+1]]
    return sum(values) / len(values) if values else 0.0
def compute_sma_volume(candles_list, index, period=10):
    start = max(0, index - period + 1)
    values = [candle["volume"] for candle in candles_list[start:index+1]]
    return sum(values) / len(values) if values else 0.0
def get_aligned_candle_with_index(candles_list, target_ts):
    """Find the candle in the list whose timestamp is the largest that is <= target_ts."""
    best_idx = 0
    for i, candle in enumerate(candles_list):
        if candle["timestamp"] <= target_ts:
            best_idx = i
        else:
            break
    return best_idx, candles_list[best_idx]
def get_features_for_tf(candles_list, index, period=10):
    """Return a vector of 7 features: open, high, low, close, volume, sma_close, sma_volume."""
    candle = candles_list[index]
    f_open = candle["open"]
    f_high = candle["high"]
    f_low = candle["low"]
    f_close = candle["close"]
    f_volume = candle["volume"]
    sma_close = compute_sma(candles_list, index, period)
    sma_volume = compute_sma_volume(candles_list, index, period)
    return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume]
# -------------------------------------
# Neural Network Architecture Definition
# -------------------------------------
class TradingModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(TradingModel, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    def forward(self, x):
        return self.net(x)
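# Shape sanity check (a minimal sketch; the 35-dim input assumes the
# 5-timeframe x 7-feature state built by BacktestEnvironment below):
#   model = TradingModel(input_dim=35, hidden_dim=128, output_dim=3)
#   q = model(torch.zeros(1, 35))  # -> Q-values of shape (1, 3): SELL/HOLD/BUY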
# -------------------------------------
# Replay Buffer for Experience Storage
# -------------------------------------
class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)
    def add(self, experience):
        self.buffer.append(experience)
    def sample(self, batch_size):
        indices = np.random.choice(len(self.buffer), size=batch_size, replace=False)
        return [self.buffer[i] for i in indices]
    def __len__(self):
        return len(self.buffer)
# -------------------------------------
# RL Agent with Q-Learning and Epsilon-Greedy Exploration
# -------------------------------------
class ContinuousRLAgent:
    def __init__(self, model, optimizer, replay_buffer, batch_size=32, gamma=0.99):
        self.model = model
        self.optimizer = optimizer
        self.replay_buffer = replay_buffer
        self.batch_size = batch_size
        self.loss_fn = nn.MSELoss()
        self.gamma = gamma
    def act(self, state, epsilon=0.1):
        # Epsilon-greedy: explore with probability epsilon, otherwise take argmax Q.
        if np.random.rand() < epsilon:
            return np.random.randint(0, 3)
        state_tensor = torch.from_numpy(np.array(state, dtype=np.float32)).unsqueeze(0)
        with torch.no_grad():
            output = self.model(state_tensor)
        return torch.argmax(output, dim=1).item()
    def train_step(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        batch = self.replay_buffer.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states_tensor = torch.from_numpy(np.array(states, dtype=np.float32))
        actions_tensor = torch.tensor(actions, dtype=torch.int64)
        rewards_tensor = torch.from_numpy(np.array(rewards, dtype=np.float32)).unsqueeze(1)
        next_states_tensor = torch.from_numpy(np.array(next_states, dtype=np.float32))
        dones_tensor = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)
        Q_values = self.model(states_tensor)
        current_Q = Q_values.gather(1, actions_tensor.unsqueeze(1))
        # Bellman target computed with the online network (no separate target network).
        with torch.no_grad():
            next_Q_values = self.model(next_states_tensor)
            max_next_Q = next_Q_values.max(1)[0].unsqueeze(1)
            target = rewards_tensor + self.gamma * max_next_Q * (1.0 - dones_tensor)
        loss = self.loss_fn(current_Q, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
# -------------------------------------
# Historical Data Fetching Function (for a given timeframe)
# -------------------------------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
    candles = []
    since_ms = since
    while True:
        try:
            batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size)
        except Exception as e:
            print(f"Error fetching historical data for {timeframe}:", e)
            break
        if not batch:
            break
        for c in batch:
            candle_dict = {
                'timestamp': c[0],
                'open': c[1],
                'high': c[2],
                'low': c[3],
                'close': c[4],
                'volume': c[5]
            }
            candles.append(candle_dict)
        last_timestamp = batch[-1][0]
        if last_timestamp >= end_time:
            break
        # Advance pagination to 1 ms past the last candle to avoid refetching it.
        since_ms = last_timestamp + 1
    print(f"Fetched {len(candles)} candles for timeframe {timeframe}.")
    return candles
# -------------------------------------
# Backtest Environment with Multi-Timeframe State
# -------------------------------------
class BacktestEnvironment:
    def __init__(self, candles_dict, base_tf="1m", timeframes=None):
        self.candles_dict = candles_dict  # dict of timeframe: candles_list
        self.base_tf = base_tf
        if timeframes is None:
            self.timeframes = [base_tf]  # fallback to single timeframe
        else:
            self.timeframes = timeframes
        self.trade_history = []  # record of closed trades
        self.current_index = 0   # index on base_tf candles
        self.position = None     # active position record
    def reset(self, clear_trade_history=True):
        self.current_index = 0
        self.position = None
        if clear_trade_history:
            self.trade_history = []
        return self.get_state(self.current_index)
    def get_state(self, index):
        """Construct the state as the concatenated features of all timeframes.
        For each timeframe, use the candle aligned with the base timeframe's timestamp."""
        state_features = []
        base_candle = self.candles_dict[self.base_tf][index]
        base_ts = base_candle["timestamp"]
        for tf in self.timeframes:
            candles_list = self.candles_dict[tf]
            # Get the candle from this timeframe that is closest to (and <=) base_ts.
            aligned_index, _ = get_aligned_candle_with_index(candles_list, base_ts)
            features = get_features_for_tf(candles_list, aligned_index, period=10)
            state_features.extend(features)
        return np.array(state_features, dtype=np.float32)
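    # State layout note: with the five timeframes configured in main_backtest
    # below, get_state returns a 35-dim vector (5 timeframes x 7 features each).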
    def step(self, action):
        """
        Simulate a trading step based on the base timeframe.
        - If not in a position and action is BUY (2), record entry at next candle's open.
        - If in a position and action is SELL (0), record exit at next candle's open, computing PnL.
        Returns: (current_state, reward, next_state, done)
        """
        base_candles = self.candles_dict[self.base_tf]
        if self.current_index >= len(base_candles) - 1:
            return self.get_state(self.current_index), 0.0, None, True
        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        next_candle = base_candles[next_index]
        reward = 0.0
        # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
        if self.position is None:
            if action == 2:  # BUY signal: enter position at next candle's open.
                entry_price = next_candle["open"]
                self.position = {"entry_price": entry_price, "entry_index": self.current_index}
        else:
            if action == 0:  # SELL signal: close position at next candle's open.
                exit_price = next_candle["open"]
                reward = exit_price - self.position["entry_price"]
                trade = {
                    "entry_index": self.position["entry_index"],
                    "entry_price": self.position["entry_price"],
                    "exit_index": next_index,
                    "exit_price": exit_price,
                    "pnl": reward
                }
                self.trade_history.append(trade)
                self.position = None
        self.current_index = next_index
        done = (self.current_index >= len(base_candles) - 1)
        return current_state, reward, next_state, done
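# Example transition (a sketch): starting flat at index i, env.step(2) opens a
# long at candle i+1's open and returns (state_i, 0.0, state_{i+1}, done); a
# later env.step(0) closes the position and returns the trade PnL as reward.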
# -------------------------------------
# Chart Plotting: Trade History & PnL
# -------------------------------------
def plot_trade_history(candles, trade_history):
    close_prices = [candle["close"] for candle in candles]
    x = list(range(len(close_prices)))
    plt.figure(figsize=(12, 6))
    plt.plot(x, close_prices, label="Close Price", color="black", linewidth=1)
    # Use these flags so that the label "BUY" or "SELL" is only shown once in the legend.
    buy_label_added = False
    sell_label_added = False
    for trade in trade_history:
        in_idx = trade["entry_index"]
        out_idx = trade["exit_index"]
        in_price = trade["entry_price"]
        out_price = trade["exit_price"]
        pnl = trade["pnl"]
        # Plot BUY marker ("IN").
        if not buy_label_added:
            plt.plot(in_idx, in_price, marker="^", color="green", markersize=10, label="BUY (IN)")
            buy_label_added = True
        else:
            plt.plot(in_idx, in_price, marker="^", color="green", markersize=10)
        plt.text(in_idx, in_price, " IN", color="green", fontsize=8, verticalalignment="bottom")
        # Plot SELL marker ("OUT").
        if not sell_label_added:
            plt.plot(out_idx, out_price, marker="v", color="red", markersize=10, label="SELL (OUT)")
            sell_label_added = True
        else:
            plt.plot(out_idx, out_price, marker="v", color="red", markersize=10)
        plt.text(out_idx, out_price, " OUT", color="red", fontsize=8, verticalalignment="top")
        # Annotate the PnL near the SELL marker.
        plt.text(out_idx, out_price, f" {pnl:+.2f}", color="blue", fontsize=8, verticalalignment="bottom")
        # Choose line color based on profitability.
        if pnl > 0:
            line_color = "green"
        elif pnl < 0:
            line_color = "red"
        else:
            line_color = "gray"
        # Draw a dotted line between the buy and sell points.
        plt.plot([in_idx, out_idx], [in_price, out_price], linestyle="dotted", color=line_color)
    plt.title("Trade History with PnL")
    plt.xlabel("Base Candle Index (1m)")
    plt.ylabel("Price")
    plt.legend()
    plt.grid(True)
    plt.show()
# -------------------------------------
# Training Loop: Backtesting Trading Episodes
# -------------------------------------
def train_on_historical_data(env, rl_agent, num_epochs=10, epsilon=0.1):
    for epoch in range(1, num_epochs + 1):
        state = env.reset()  # clear trade history each epoch
        done = False
        total_reward = 0.0
        steps = 0
        while not done:
            action = rl_agent.act(state, epsilon=epsilon)
            prev_state = state
            # env.step returns the pre-step state first; discard it and use
            # next_state below so the agent does not act on a stale observation.
            _, reward, next_state, done = env.step(action)
            if next_state is None:
                next_state = np.zeros_like(prev_state)
            rl_agent.replay_buffer.add((prev_state, action, reward, next_state, done))
            rl_agent.train_step()
            state = next_state
            total_reward += reward
            steps += 1
        print(f"Epoch {epoch}/{num_epochs} completed, total reward: {total_reward:.4f} over {steps} steps.")
        save_checkpoint(rl_agent.model, epoch, total_reward, LAST_DIR, BEST_DIR)
# -------------------------------------
# Main Asynchronous Function for Training & Charting
# -------------------------------------
async def main_backtest():
    symbol = 'BTC/USDT'
    # Define timeframes: we'll use 5 different ones.
    timeframes = ["1m", "5m", "15m", "1h", "1d"]
    now = int(time.time() * 1000)
    # Use a base-timeframe window of 1500 candles; for 1m that is 1500 minutes.
    period_ms = 1500 * 60 * 1000
    since = now - period_ms
    end_time = now
    # Initialize exchange using MEXC (or your preferred exchange).
    mexc_api_key = os.environ.get('MEXC_API_KEY', 'YOUR_API_KEY')
    mexc_api_secret = os.environ.get('MEXC_API_SECRET', 'YOUR_SECRET_KEY')
    exchange = ccxt.mexc({
        'apiKey': mexc_api_key,
        'secret': mexc_api_secret,
        'enableRateLimit': True,
    })
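    # Optional warm start (a sketch): load_candles_cache is defined above but
    # unused in this flow. Assuming the cache file holds the same timeframes,
    # one could populate candles_dict from it and skip the fetch loop below:
    #   cached = load_candles_cache(CACHE_FILE)
    #   if cached and all(tf in cached for tf in timeframes):
    #       candles_dict = cached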
    candles_dict = {}
    for tf in timeframes:
        print(f"Fetching historical data for timeframe {tf}...")
        candles = await fetch_historical_data(exchange, symbol, tf, since, end_time, batch_size=500)
        candles_dict[tf] = candles
    # Optionally, save the multi-timeframe cache.
    save_candles_cache(CACHE_FILE, candles_dict)
    # Create the backtest environment using multi-timeframe data.
    env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
    # Neural network dimensions: each timeframe contributes 7 features.
    input_dim = len(timeframes) * 7  # 7 features * 5 timeframes = 35.
    hidden_dim = 128
    output_dim = 3  # Actions: SELL, HOLD, BUY.
    model = TradingModel(input_dim, hidden_dim, output_dim)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    replay_buffer = ReplayBuffer(capacity=10000)
    rl_agent = ContinuousRLAgent(model, optimizer, replay_buffer, batch_size=32, gamma=0.99)
    # Load best checkpoint if available.
    load_best_checkpoint(model, BEST_DIR)
    # Train the agent over the historical period.
    num_epochs = 10  # Adjust as needed.
    train_on_historical_data(env, rl_agent, num_epochs=num_epochs, epsilon=0.1)
    # Run a final simulation (without exploration) to record trade history.
    state = env.reset(clear_trade_history=True)
    done = False
    cumulative_reward = 0.0
    while not done:
        action = rl_agent.act(state, epsilon=0.0)
        _, reward, next_state, done = env.step(action)
        cumulative_reward += reward
        state = next_state
    print("Final simulation cumulative profit:", cumulative_reward)
    # Evaluate trade performance.
    trades = env.trade_history
    num_trades = len(trades)
    num_wins = sum(1 for trade in trades if trade["pnl"] > 0)
    win_rate = (num_wins / num_trades * 100) if num_trades > 0 else 0.0
    total_profit = sum(trade["pnl"] for trade in trades)
    print(f"Total trades: {num_trades}, Wins: {num_wins}, Win rate: {win_rate:.2f}%, Total Profit: {total_profit:.4f}")
    # Plot chart with buy/sell markers on the base timeframe ("1m").
    plot_trade_history(candles_dict["1m"], trades)
    await exchange.close()
if __name__ == "__main__":
    load_dotenv()
    asyncio.run(main_backtest())
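# Running the script (a sketch, assuming credentials come from a .env file or
# the environment; the variable names match the os.environ lookups above):
#   MEXC_API_KEY=... MEXC_API_SECRET=... python index.py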