gogo2/crypto/brian/index-deep-new.py
2025-02-04 20:59:24 +02:00

407 lines
16 KiB
Python

#!/usr/bin/env python3
import sys
import asyncio
# On Windows, install the selector-based asyncio event loop policy at import time,
# i.e. before asyncio.run() creates the loop.
# NOTE(review): presumably needed by the async exchange client — confirm.
if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import os
import time
import json
import argparse
import threading
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from datetime import datetime
import matplotlib.pyplot as plt
import ccxt.async_support as ccxt
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math
from dotenv import load_dotenv
load_dotenv()
# --- Directories ---
# Rolling "latest" checkpoints and the top-scoring checkpoints live in separate folders.
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
# Both folders are created at import time so checkpoint saves never hit a missing path.
for _checkpoint_dir in (LAST_DIR, BEST_DIR):
    os.makedirs(_checkpoint_dir, exist_ok=True)
CACHE_FILE = "candles_cache.json"  # JSON cache of candles, read in train and live modes
# --- Constants ---
NUM_TIMEFRAMES = 5  # e.g., ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20  # extra indicator channels appended after the timeframe channels
FEATURES_PER_CHANNEL = 7  # open, high, low, close, volume, SMA_close, SMA_volume
# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor, then apply dropout.

    Args:
        d_model: Size of the last (feature) dimension of the inputs. May be odd.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions precomputed in the ``pe`` buffer; inputs
            longer than this along dim 0 cannot be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # The odd (cosine) columns number floor(d_model / 2); for an odd d_model that is
        # one fewer than the sine columns, so trim the angles to avoid a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer: moves with .to(device) and is saved in the state dict,
        # but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])``.

        Args:
            x: Tensor whose dim 0 is the sequence position and whose last dim is
                ``d_model``; the encoding broadcasts over dim 1.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
# --- Enhanced Transformer Model ---
class TradingModel(nn.Module):
    """Transformer over per-channel feature vectors that predicts a high and a low value.

    Each channel (a timeframe or an indicator slot) is projected by its own small
    MLP, tagged with a learned per-channel embedding, passed through a transformer
    encoder with the channels acting as the sequence, attention-pooled, and fed
    to two regression heads.

    Args:
        num_channels: Total number of input channels; one branch and one
            embedding row are created per channel.
        num_timeframes: Accepted for interface compatibility; not used by the
            current architecture.
        hidden_dim: Width of the channel projections and of the transformer.
    """

    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()
        # One independent projection per channel; every channel carries
        # FEATURES_PER_CHANNEL input features.
        self.channel_branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1)
            ) for _ in range(num_channels)
        ])
        # The embedding is indexed by channel id (0..num_channels-1), so it must be
        # sized by the total channel count, not by the number of timeframes.
        self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
        # Constructed but not applied in forward(); kept so that checkpoints that
        # contain its `pe` buffer continue to load.
        self.pos_encoder = PositionalEncoding(hidden_dim)
        encoder_layers = TransformerEncoderLayer(
            d_model=hidden_dim, nhead=4, dim_feedforward=512,
            dropout=0.1, activation='gelu', batch_first=False
        )
        self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
        self.attn_pool = nn.Linear(hidden_dim, 1)
        self.high_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.low_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, timeframe_ids):
        """Predict (high, low) for each sample in the batch.

        Args:
            x: Tensor of shape [batch_size, num_channels, FEATURES_PER_CHANNEL].
            timeframe_ids: 1-D integer tensor with one embedding index per channel.

        Returns:
            Tuple ``(high, low)`` of tensors, each of shape [batch_size].
        """
        _, num_channels, _ = x.shape
        channel_outs = [
            self.channel_branches[i](x[:, i, :]) for i in range(num_channels)
        ]
        # [batch, channels, hidden] -> [channels, batch, hidden] (encoder is sequence-first).
        stacked = torch.stack(channel_outs, dim=1).permute(1, 0, 2)
        # [channels, hidden] -> [channels, 1, hidden] so it broadcasts over the batch.
        tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
        stacked = stacked + tf_embeds
        seq_len = stacked.size(0)
        # Boolean mask, True strictly above the diagonal: channel i may attend only to
        # channels 0..i. Built directly on the input's device to avoid a host copy.
        src_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()
        transformer_out = self.transformer(stacked, mask=src_mask)
        # Softmax over the channel axis gives one pooling weight per channel.
        attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
        aggregated = (transformer_out * attn_weights).sum(dim=0)
        # squeeze(-1) drops only the trailing size-1 head dimension; a bare squeeze()
        # would also drop the batch dimension when batch_size == 1.
        return self.high_pred(aggregated).squeeze(-1), self.low_pred(aggregated).squeeze(-1)
# --- Technical Indicator Helpers ---
def compute_sma(candles_list, index, period=10):
    """Return the mean ``close`` of the candles in the window ending at ``index``.

    The window holds at most ``period`` candles, ends at ``index`` (inclusive)
    and is clipped at the start of the list. An empty window yields 0.0.
    """
    first = max(0, index - period + 1)
    window = candles_list[first:index + 1]
    if not window:
        return 0.0
    return sum(item["close"] for item in window) / len(window)
def compute_sma_volume(candles_list, index, period=10):
    """Return the mean ``volume`` of the candles in the window ending at ``index``.

    The window holds at most ``period`` candles, ends at ``index`` (inclusive)
    and is clipped at the start of the list. An empty window yields 0.0.
    """
    first = max(0, index - period + 1)
    window = candles_list[first:index + 1]
    if not window:
        return 0.0
    return sum(item["volume"] for item in window) / len(window)
def get_aligned_candle_with_index(candles_list, target_ts):
    """Find the latest candle whose timestamp does not exceed ``target_ts``.

    ``candles_list`` must be sorted by ascending ``"timestamp"``. A binary search
    is used because this lookup runs for every timeframe on every environment
    step; a linear scan here makes a full pass over the data quadratic.

    Args:
        candles_list: Non-empty list of candle dicts with a ``"timestamp"`` key.
        target_ts: Timestamp to align to.

    Returns:
        ``(index, candle)``. When every candle is newer than ``target_ts`` the
        first candle (index 0) is returned. With duplicate timestamps the last
        matching candle is chosen.

    Raises:
        IndexError: If ``candles_list`` is empty.
    """
    lo, hi = 0, len(candles_list)
    # Invariant: everything before `lo` is <= target_ts, everything from `hi` on is > target_ts.
    while lo < hi:
        mid = (lo + hi) // 2
        if candles_list[mid]["timestamp"] <= target_ts:
            lo = mid + 1
        else:
            hi = mid
    # `lo` is the count of candles at or before target_ts; clamp so that a target
    # earlier than all data falls back to the first candle.
    best_idx = max(lo - 1, 0)
    return best_idx, candles_list[best_idx]
def get_features_for_tf(candles_list, index, period=10):
    """Build the feature vector for the candle at ``index``.

    Returns a list in the order
    ``[open, high, low, close, volume, sma_close, sma_volume]``, where the two
    SMA values are computed over the window of up to ``period`` candles ending
    at ``index``.
    """
    candle = candles_list[index]
    raw_fields = ("open", "high", "low", "close", "volume")
    features = [candle[field] for field in raw_fields]
    features.append(compute_sma(candles_list, index, period))
    features.append(compute_sma_volume(candles_list, index, period))
    return features
# --- Caching & Checkpoint Functions ---
def load_candles_cache(filename):
    """Load the candle cache from a JSON file.

    Returns the parsed JSON content, or an empty dict when the file does not
    exist or cannot be read or parsed (the error is printed, not raised).
    """
    if not os.path.exists(filename):
        return {}
    try:
        with open(filename, "r") as cache_file:
            cached = json.load(cache_file)
        print(f"Loaded cached data from {filename}.")
        return cached
    except Exception as e:
        # Best effort: a broken cache is treated the same as a missing one.
        print("Error reading cache file:", e)
    return {}
def save_candles_cache(filename, candles_dict):
    """Write ``candles_dict`` to ``filename`` as JSON.

    Failures (unwritable path, non-serializable content) are printed and
    swallowed; the function always returns None.
    """
    try:
        with open(filename, "w") as cache_file:
            json.dump(candles_dict, cache_file)
    except Exception as e:
        # Best effort: caching must never abort the caller.
        print("Error saving cache file:", e)
def maintain_checkpoint_directory(directory, max_files=10):
    """Delete the oldest entries of ``directory`` until at most ``max_files`` remain.

    Age is determined by modification time; nothing is deleted when the
    directory already holds ``max_files`` entries or fewer.
    """
    entries = os.listdir(directory)
    excess = len(entries) - max_files
    if excess <= 0:
        return
    oldest_first = sorted(
        (os.path.join(directory, name) for name in entries),
        key=os.path.getmtime,
    )
    for stale_path in oldest_first[:excess]:
        os.remove(stale_path)
def get_best_models(directory):
    """List the scored checkpoint files found in ``directory``.

    The score is the second ``_``-separated part of the file name, parsed as a
    float (e.g. ``best_<score>_...``). Files without such a part are skipped.

    Returns:
        List of ``(score, filename)`` tuples in directory-listing order.
    """
    scored = []
    for name in os.listdir(directory):
        try:
            score = float(name.split("_")[1])
        except Exception:
            # Not a scored checkpoint name; ignore it.
            continue
        scored.append((score, name))
    return scored
def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Save a rolling checkpoint and, when it scores well enough, a "best" checkpoint.

    A checkpoint is always written to ``last_dir`` (pruned to the 10 newest
    files). It is additionally written to ``best_dir`` when that directory
    holds fewer than 10 scored checkpoints, or when ``reward`` beats the lowest
    score there, in which case the lowest-scored file is removed first.

    Args:
        model: Module whose ``state_dict()`` is saved.
        epoch: Epoch number stored in the checkpoint and in the file name.
        reward: Score stored in the checkpoint; higher is treated as better.
        last_dir: Directory for rolling checkpoints.
        best_dir: Directory for best checkpoints.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    def build_payload():
        # Rebuilt for each save so every file gets its own state-dict snapshot.
        return {
            "epoch": epoch,
            "reward": reward,
            "model_state_dict": model.state_dict()
        }

    last_path = os.path.join(last_dir, f"model_last_epoch_{epoch}_{timestamp}.pt")
    torch.save(build_payload(), last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)

    scored = get_best_models(best_dir)
    if len(scored) < 10:
        add_to_best = True
    else:
        min_reward, min_file = min(scored, key=lambda entry: entry[0])
        add_to_best = reward > min_reward
        if add_to_best:
            # Evict the weakest entry to make room for the new one.
            os.remove(os.path.join(best_dir, min_file))

    if add_to_best:
        best_path = os.path.join(best_dir, f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt")
        torch.save(build_payload(), best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
def load_best_checkpoint(model, best_dir=BEST_DIR):
    """Load the highest-scored checkpoint from ``best_dir`` into ``model``.

    Args:
        model: Module that receives the saved ``model_state_dict``.
        best_dir: Directory scanned for scored checkpoint files.

    Returns:
        The loaded checkpoint dict, or None when ``best_dir`` holds no scored
        checkpoint (``model`` is then left untouched).
    """
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_reward, best_file = max(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
    # Map tensors to CPU while unpickling so that a checkpoint written on a GPU
    # machine still loads on a CPU-only host; load_state_dict() then copies the
    # values into the model's parameters on whatever device they live on.
    # NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint
# --- Backtest Environment ---
class BacktestEnvironment:
    """Step through cached candles one base-timeframe candle at a time.

    The environment tracks a single long position: a BUY signal opens it at the
    next candle's open, a SELL signal closes it at the next candle's open, and
    each closed trade is appended to ``trade_history``.

    Args:
        candles_dict: Mapping of timeframe name to its list of candle dicts.
        base_tf: Timeframe whose candles drive the stepping.
        timeframes: Timeframes included in every state, in channel order.
    """

    def __init__(self, candles_dict, base_tf, timeframes):
        self.candles_dict = candles_dict  # dict: timeframe -> list of candles
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.current_index = 0
        self.trade_history = []
        self.position = None

    def reset(self, clear_trade_history=True):
        """Rewind to the first candle, drop any open position, return the initial state.

        Args:
            clear_trade_history: When True (default) the recorded trades are
                discarded; pass False to keep them across resets, e.g. so a
                chart keeps showing earlier trades.
        """
        self.current_index = 0
        self.position = None
        if clear_trade_history:
            # Rebind rather than clear in place so holders of the old list keep their data.
            self.trade_history = []
        return self.get_state(self.current_index)

    def get_state(self, index):
        """Return the state array for base candle ``index``.

        The result is a float32 array of shape
        ``[len(timeframes) + NUM_INDICATORS, FEATURES_PER_CHANNEL]``: one row per
        timeframe (features of the candle aligned to the base timestamp)
        followed by zero-filled indicator rows.
        """
        state_features = []
        base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
        for tf in self.timeframes:
            aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
            features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
            state_features.append(features)
        # Indicator channels are placeholders for now: all zeros.
        for _ in range(NUM_INDICATORS):
            state_features.append([0.0] * FEATURES_PER_CHANNEL)
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """Advance one base candle, applying ``action``.

        Args:
            action: 0 = SELL, 1 = HOLD, 2 = BUY. Any other value (including
                None) leaves the position unchanged.

        Returns:
            Tuple ``(current_state, reward, next_state, done, actual_high,
            actual_low)``. ``reward`` is the price difference of a trade closed
            on this step, else 0.0. ``actual_high``/``actual_low`` are the next
            candle's high and low. When there is no next candle the tuple is
            ``(current_state, 0.0, None, True, 0.0, 0.0)``; the trailing zeros
            are placeholders, not market data.
        """
        base_candles = self.candles_dict[self.base_tf]
        # Handle end-of-data scenario: return dummy high/low values if needed.
        if self.current_index >= len(base_candles) - 1:
            current_state = self.get_state(self.current_index)
            return current_state, 0.0, None, True, 0.0, 0.0
        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        next_candle = base_candles[next_index]
        reward = 0.0
        # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
        if self.position is None:
            if action == 2:  # BUY signal: enter at next candle's open.
                entry_price = next_candle["open"]
                # The fill happens on the next candle, so record that index; this
                # keeps entry_index consistent with entry_price and with exit_index.
                self.position = {"entry_price": entry_price, "entry_index": next_index}
        else:
            if action == 0:  # SELL signal: exit at next candle's open.
                exit_price = next_candle["open"]
                reward = exit_price - self.position["entry_price"]
                trade = {
                    "entry_index": self.position["entry_index"],
                    "entry_price": self.position["entry_price"],
                    "exit_index": next_index,
                    "exit_price": exit_price,
                    "pnl": reward
                }
                self.trade_history.append(trade)
                self.position = None
        self.current_index = next_index
        done = (self.current_index >= len(base_candles) - 1)
        # Return the high and low of the next candle as the targets.
        actual_high = next_candle["high"]
        actual_low = next_candle["low"]
        return current_state, reward, next_state, done, actual_high, actual_low

    def __len__(self):
        """Return the number of candles in the base timeframe."""
        return len(self.candles_dict[self.base_tf])
# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args):
    """Train ``model`` to predict the next candle's high and low from the current state.

    One optimizer step is taken per environment step. After every epoch the
    learning-rate schedule advances and a checkpoint is saved.

    Args:
        env: Backtest environment providing ``reset()`` and ``step()``.
        model: Module called as ``model(state, channel_ids)`` that returns
            ``(pred_high, pred_low)``.
        device: Torch device on which tensors are created.
        args: Namespace with ``lr`` and ``epochs``.
    """
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    for epoch in range(args.epochs):
        state = env.reset()
        total_loss = 0.0
        steps = 0
        model.train()
        # One id per channel row of the state; constant for the whole epoch.
        timeframe_ids = torch.arange(state.shape[0], device=device)
        while True:
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)
            # Targets are the next candle's high/low; no trading action is taken.
            _, _, next_state, done, actual_high, actual_low = env.step(None)
            if next_state is None:
                # The environment had no next candle, so the returned high/low are
                # placeholder zeros; training on them would corrupt the model.
                break
            target_high = torch.FloatTensor([actual_high]).to(device)
            target_low = torch.FloatTensor([actual_low]).to(device)
            high_loss = torch.abs(pred_high - target_high) * 2
            low_loss = torch.abs(pred_low - target_low) * 2
            loss = (high_loss + low_loss).mean()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            steps += 1
            if done:
                break
            state = next_state
        scheduler.step()
        # Average over the optimizer steps actually taken (one fewer than len(env)).
        avg_loss = total_loss / max(steps, 1)
        print(f"Epoch {epoch+1} Loss: {avg_loss:.4f}")
        # Checkpoint ranking treats a higher score as better, so store the negated
        # loss; passing the raw loss would keep the worst models as "best".
        # NOTE: scores saved by earlier runs (raw, positive losses) are not
        # comparable with these; clear the best directory when upgrading.
        save_checkpoint(model, epoch, -avg_loss)
# --- Live Plotting Functions ---
def update_live_chart(ax, candles, trade_history):
    """Redraw the close-price line and all recorded trades on ``ax``.

    Each trade is drawn as a green up-triangle at its entry, a red
    down-triangle at its exit and a dotted blue line joining the two. Only the
    first trade carries the "BUY"/"SELL" legend labels, so the legend shows
    each marker once.
    """
    ax.clear()
    closes = [candle["close"] for candle in candles]
    ax.plot(list(range(len(closes))), closes, label="Close Price", color="black", linewidth=1)
    for position, trade in enumerate(trade_history):
        entry_x, entry_y = trade["entry_index"], trade["entry_price"]
        exit_x, exit_y = trade["exit_index"], trade["exit_price"]
        # Attach legend labels to the first trade only.
        buy_extra = {"label": "BUY"} if position == 0 else {}
        sell_extra = {"label": "SELL"} if position == 0 else {}
        ax.plot(entry_x, entry_y, marker="^", color="green", markersize=10, **buy_extra)
        ax.plot(exit_x, exit_y, marker="v", color="red", markersize=10, **sell_extra)
        ax.plot([entry_x, exit_x], [entry_y, exit_y], linestyle="dotted", color="blue")
    ax.set_title("Live Trading Chart")
    ax.set_xlabel("Candle Index")
    ax.set_ylabel("Price")
    ax.legend()
    ax.grid(True)
def live_preview_loop(candles, env):
    """Redraw the live chart forever, roughly once per second.

    Opens an interactive figure and, on every iteration, repaints it from
    ``candles`` and the current ``env.trade_history``. Never returns.
    NOTE(review): the caller runs this in a background thread; whether the
    active matplotlib backend supports drawing off the main thread should be
    confirmed.
    """
    plt.ion()
    _figure, axes = plt.subplots(figsize=(12, 6))
    while True:
        # Read env.trade_history each pass so newly closed trades show up.
        update_live_chart(axes, candles, env.trade_history)
        plt.draw()
        plt.pause(1)
# --- Argument Parsing ---
def parse_args():
    """Parse the command-line options and return the resulting namespace.

    Options: ``--mode`` (train | live | inference, default train),
    ``--epochs`` (int, default 100), ``--lr`` (float, default 3e-4) and
    ``--threshold`` (float, default 0.005).
    """
    option_specs = (
        ('--mode', {'choices': ['train', 'live', 'inference'], 'default': 'train'}),
        ('--epochs', {'type': int, 'default': 100}),
        ('--lr', {'type': float, 'default': 3e-4}),
        ('--threshold', {'type': float, 'default': 0.005}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
def random_action():
    """Return a uniformly random action id: 0 (SELL), 1 (HOLD) or 2 (BUY)."""
    action_count = 3
    return random.randrange(action_count)
# --- Main Function ---
async def main():
    """Entry point: build the model, then run the mode selected by ``--mode``.

    * ``train``: backtest-train on the cached candles.
    * ``live``: load the best checkpoint and simulate trading with random
      actions while a background thread draws the chart. Runs until interrupted.
    * ``inference``: load the best checkpoint (inference logic not implemented).
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    timeframes = ["1m", "5m", "15m", "1h", "1d"]
    base_tf = "1m"
    # Set total number of channels = NUM_TIMEFRAMES + NUM_INDICATORS.
    total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
    model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
    if args.mode == 'train':
        candles_dict = load_candles_cache(CACHE_FILE)
        if not candles_dict:
            print("No historical candle data available for backtesting.")
            return
        # Every timeframe is read on each environment step; fail early with a
        # clear message instead of a KeyError/IndexError deep in the loop.
        missing = [tf for tf in timeframes if not candles_dict.get(tf)]
        if missing:
            print(f"Cached candle data is missing timeframes: {missing}")
            return
        env = BacktestEnvironment(candles_dict, base_tf, timeframes)
        train_on_historical_data(env, model, device, args)
    elif args.mode == 'live':
        load_best_checkpoint(model)
        candles_dict = load_candles_cache(CACHE_FILE)
        if not candles_dict:
            print("No cached candles available for live preview.")
            return
        missing = [tf for tf in timeframes if not candles_dict.get(tf)]
        if missing:
            print(f"Cached candle data is missing timeframes: {missing}")
            return
        env = BacktestEnvironment(candles_dict, base_tf=base_tf, timeframes=timeframes)
        preview_thread = threading.Thread(
            target=live_preview_loop, args=(candles_dict[base_tf], env), daemon=True
        )
        preview_thread.start()
        print("Starting live trading loop. (Using random actions for simulation.)")
        while True:
            # step() returns (state, reward, next_state, done, high, low); only the
            # done flag is needed here, so index it instead of unpacking.
            step_result = env.step(random_action())
            done = step_result[3]
            if done:
                print("Reached end of simulated data, resetting environment.")
                # reset() starts a fresh trade list; restore the old one so the
                # chart keeps showing trades from earlier passes over the data.
                trades_so_far = env.trade_history
                env.reset()
                env.trade_history = trades_so_far
            await asyncio.sleep(1)
    elif args.mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
        # Place your inference logic here.
    else:
        # Defensive only: argparse already restricts --mode to the choices above.
        print("Invalid mode specified.")
# Run the async entry point only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())