# File: gogo2/crypto/brian/index-deep-new.py
# Last commit: 967363378b by Dobromir Popov ("even better"), 2025-02-04 22:09:13 +02:00
# Original listing: 543 lines, 22 KiB, Python

#!/usr/bin/env python3
import sys
import asyncio
if sys.platform == 'win32':
    # On Windows, switch asyncio to the selector-based event loop instead of the
    # platform default. NOTE(review): presumably needed by an async networking
    # dependency; nothing visible in this file requires it beyond asyncio.sleep —
    # confirm before removing.
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import os
import time
import json
import argparse
import threading
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
import matplotlib.pyplot as plt
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from dotenv import load_dotenv
load_dotenv()

# --- Directories ---
# LAST_DIR keeps the most recent checkpoints, BEST_DIR the lowest-loss ones.
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
# Both folders are created eagerly at import time so checkpoint saves can write into them.
for _ckpt_dir in (LAST_DIR, BEST_DIR):
    os.makedirs(_ckpt_dir, exist_ok=True)
del _ckpt_dir
CACHE_FILE = "candles_cache.json"  # JSON cache of historical candles, keyed by timeframe name

# --- Constants ---
NUM_TIMEFRAMES = 5         # e.g., ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20        # e.g., 20 technical indicators (zero-filled placeholder rows in get_state)
FEATURES_PER_CHANNEL = 7   # e.g., [open, high, low, close, volume, sma_close, sma_volume]
# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a sequence-first tensor.

    The table is computed once for ``max_len`` positions and stored as the
    buffer ``pe`` with shape ``(max_len, 1, d_model)``, so ``pe[:seq_len]``
    broadcasts over the second (batch) axis of the input. Even columns hold
    sines and odd columns cosines, so ``d_model`` must be even.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(max_len).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = positions * inv_freq
        table = torch.zeros(max_len, 1, d_model)
        table[:, 0, 0::2] = torch.sin(angles)
        table[:, 0, 1::2] = torch.cos(angles)
        # A buffer is saved in state_dict and follows .to(device), but is not trained.
        self.register_buffer('pe', table)

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``; axis 0 of ``x`` is the position axis."""
        seq_len = x.size(0)
        return self.dropout(x + self.pe[:seq_len])
# --- Enhanced Transformer Model ---
class TradingModel(nn.Module):
    """Predict the next candle's high and low from per-channel feature vectors.

    Each input channel has its own projection branch. The projected channels
    form a sequence for the encoder. Each channel gets a learned per-channel
    embedding added, then passes through a causally masked Transformer
    encoder. The result is pooled with learned attention weights and fed to
    two regression heads.

    Args:
        num_channels: number of input channels (the encoder's sequence length).
        num_timeframes: accepted for interface compatibility; not used here.
        hidden_dim: width of the projections and the encoder; must be
            divisible by the 4 attention heads.
    """

    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()
        # Submodules are created in the same order and under the same names as
        # before, so seeded initialisation and state_dict keys are unchanged.
        branches = []
        for _ in range(num_channels):
            branches.append(nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
            ))
        self.channel_branches = nn.ModuleList(branches)
        # One learned vector per channel id (looked up with ``timeframe_ids`` in forward).
        self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
        # NOTE(review): constructed (its ``pe`` buffer is in state_dict) but never applied in forward.
        self.pos_encoder = PositionalEncoding(hidden_dim)
        layer = TransformerEncoderLayer(
            d_model=hidden_dim, nhead=4, dim_feedforward=512,
            dropout=0.1, activation='gelu', batch_first=False,
        )
        self.transformer = TransformerEncoder(layer, num_layers=2)
        self.attn_pool = nn.Linear(hidden_dim, 1)
        self.high_pred = self._regression_head(hidden_dim)
        self.low_pred = self._regression_head(hidden_dim)

    @staticmethod
    def _regression_head(hidden_dim):
        """Two-layer MLP mapping a pooled hidden vector to a single value."""
        return nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, x, timeframe_ids):
        """Return ``(pred_high, pred_low)``.

        Args:
            x: tensor of shape ``(batch, num_channels, FEATURES_PER_CHANNEL)``.
            timeframe_ids: integer ids, one per channel, indexing ``timeframe_embed``.

        Both outputs are ``squeeze()``-d, so a batch of one yields 0-d tensors.
        """
        _, n_channels, _ = x.shape
        per_channel = [self.channel_branches[c](x[:, c, :]) for c in range(n_channels)]
        # (batch, channels, hidden) -> (channels, batch, hidden): the encoder is sequence-first.
        seq = torch.stack(per_channel, dim=1).transpose(0, 1)
        seq = seq + self.timeframe_embed(timeframe_ids).unsqueeze(1)
        # True above the diagonal: channel i may only attend to channels 0..i.
        causal = torch.triu(torch.ones(n_channels, n_channels), diagonal=1).bool().to(x.device)
        encoded = self.transformer(seq, mask=causal)
        # Softmax over the channel axis gives one pooling weight per channel.
        weights = torch.softmax(self.attn_pool(encoded), dim=0)
        pooled = (encoded * weights).sum(dim=0)
        return self.high_pred(pooled).squeeze(), self.low_pred(pooled).squeeze()
# --- Technical Indicator Helpers ---
def _trailing_mean(candles_list, index, period, field):
    """Mean of ``candle[field]`` over the up-to-``period`` candles ending at ``index`` (inclusive).

    The window is truncated (not padded) near the start of the list. Returns
    0.0 when the window is empty, e.g. for ``period <= 0``.
    """
    start = max(0, index - period + 1)
    window = [candle[field] for candle in candles_list[start:index + 1]]
    return sum(window) / len(window) if window else 0.0


def compute_sma(candles_list, index, period=10):
    """Simple moving average of the ``"close"`` field over ``period`` candles ending at ``index``."""
    return _trailing_mean(candles_list, index, period, "close")


def compute_sma_volume(candles_list, index, period=10):
    """Simple moving average of the ``"volume"`` field, windowed exactly like compute_sma."""
    return _trailing_mean(candles_list, index, period, "volume")
def get_aligned_candle_with_index(candles_list, target_ts):
    """Return ``(index, candle)`` for the last candle whose ``"timestamp"`` is <= ``target_ts``.

    ``candles_list`` must be sorted by ascending timestamp. A binary search is
    used, so each lookup is O(log n). This matters because the environment
    calls it for every non-base timeframe on every step.

    If every candle is later than ``target_ts``, index 0 is returned as a
    fallback. NOTE(review): that candle then lies after ``target_ts``.

    Raises:
        IndexError: if ``candles_list`` is empty.
    """
    if not candles_list:
        raise IndexError("get_aligned_candle_with_index: candles_list is empty")
    lo, hi = 0, len(candles_list)
    # Invariant: candles before ``lo`` are <= target_ts; candles from ``hi`` on are > target_ts.
    while lo < hi:
        mid = (lo + hi) // 2
        if candles_list[mid]["timestamp"] <= target_ts:
            lo = mid + 1
        else:
            hi = mid
    best_idx = max(lo - 1, 0)
    return best_idx, candles_list[best_idx]
def get_features_for_tf(candles_list, index, period=10):
    """Build the 7-value feature row for ``candles_list[index]``.

    Order: open, high, low, close, volume, then the ``period``-candle SMA of
    close and the ``period``-candle SMA of volume ending at ``index``.
    """
    candle = candles_list[index]
    ohlcv = [candle[key] for key in ("open", "high", "low", "close", "volume")]
    averages = [
        compute_sma(candles_list, index, period),
        compute_sma_volume(candles_list, index, period),
    ]
    return ohlcv + averages
# --- Caching & Checkpoint Functions ---
def load_candles_cache(filename):
    """Load the JSON candle cache stored at ``filename``.

    Returns the parsed JSON on success. Returns an empty dict when the file
    does not exist or cannot be read or parsed; such errors are printed,
    not raised.
    """
    if not os.path.exists(filename):
        return {}
    try:
        with open(filename, "r") as handle:
            cached = json.load(handle)
        print(f"Loaded cached data from {filename}.")
        return cached
    except Exception as exc:
        print("Error reading cache file:", exc)
        return {}
def save_candles_cache(filename, candles_dict):
    """Write ``candles_dict`` to ``filename`` as JSON, best-effort.

    The JSON is written to ``filename + ".tmp"`` first and then moved over the
    target with ``os.replace``. Opening the target directly with mode "w"
    would truncate it before serialisation, so a value json cannot encode, or
    an interrupted write, would wipe the existing cache. With the temp file,
    the previous cache survives any failure. Errors are printed, not raised.
    """
    tmp_path = f"{filename}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(candles_dict, f)
        os.replace(tmp_path, filename)
    except Exception as e:
        print("Error saving cache file:", e)
        # Best-effort cleanup of the partial temp file; the real cache is untouched.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
def maintain_checkpoint_directory(directory, max_files=10):
    """Delete the oldest entries (by modification time) until at most ``max_files`` remain."""
    entries = os.listdir(directory)
    excess = len(entries) - max_files
    if excess <= 0:
        return
    oldest_first = sorted(
        (os.path.join(directory, name) for name in entries),
        key=os.path.getmtime,
    )
    for path in oldest_first[:excess]:
        os.remove(path)
def get_best_models(directory):
    """Return ``(loss, filename)`` pairs for entries of ``directory`` named ``<prefix>_<loss>_...``.

    The loss is the float between the first and second underscore, matching
    ``best_{loss:.4f}_epoch_...`` as written by save_checkpoint. Names without
    a second field, or whose second field is not a float, are skipped.
    Results follow ``os.listdir`` order.
    """
    results = []
    for name in os.listdir(directory):
        fields = name.split("_")
        if len(fields) < 2:
            continue
        try:
            loss_value = float(fields[1])
        except ValueError:
            continue
        results.append((loss_value, name))
    return results
def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Save a rolling "last" checkpoint and, when it ranks, a "best" checkpoint.

    - Always writes ``model_last_epoch_<epoch>_<timestamp>.pt`` to ``last_dir``
      and prunes that folder to its 10 newest entries.
    - Also writes ``best_<loss:.4f>_epoch_<epoch>_<timestamp>.pt`` to
      ``best_dir`` in either of two cases: fewer than 10 best checkpoints
      exist, or ``loss`` is below the worst one (which is deleted first).

    Each file holds epoch, loss and the model/optimizer state dicts.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    def payload():
        # Built fresh for each save so each torch.save sees current state dicts.
        return {
            "epoch": epoch,
            "loss": loss,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }

    torch.save(payload(), os.path.join(last_dir, f"model_last_epoch_{epoch}_{stamp}.pt"))
    maintain_checkpoint_directory(last_dir, max_files=10)

    ranked = get_best_models(best_dir)
    qualifies = len(ranked) < 10
    if not qualifies:
        worst_loss, worst_file = max(ranked, key=lambda item: item[0])
        if loss < worst_loss:
            qualifies = True
            os.remove(os.path.join(best_dir, worst_file))
    if qualifies:
        torch.save(payload(), os.path.join(best_dir, f"best_{loss:.4f}_epoch_{epoch}_{stamp}.pt"))
        # NOTE(review): this prunes by mtime, not by loss; it only removes files
        # if best_dir also holds entries that get_best_models cannot parse.
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")
def load_best_checkpoint(model, best_dir=BEST_DIR, map_location=None):
    """Load the lowest-loss checkpoint found in ``best_dir`` into ``model``.

    Args:
        model: module whose ``load_state_dict`` receives the saved weights.
        best_dir: folder scanned with get_best_models.
        map_location: forwarded to ``torch.load``. Defaults to the device of
            ``model``'s first parameter (CPU if it has none). Without it,
            ``torch.load`` restores tensors to the device they were saved
            from, which fails for GPU-saved checkpoints on CPU-only machines.

    Returns:
        The checkpoint dict (epoch, loss, model/optimizer state dicts), or
        ``None`` when no parseable checkpoint exists.
    """
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_loss, best_file = min(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else "cpu"
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint
# --- Live HTML Chart Update ---
def update_live_html(candles, trade_history, epoch):
    """
    Generate a chart image with buy/sell markers and dotted lines between entry and exit,
    then embed it in an auto-refreshing HTML page written to ``live_chart.html``.

    The page reloads itself every 10 seconds. It is therefore written to a
    temporary file and swapped in with ``os.replace``, so a refresh never sees
    a half-written page. The figure is always closed, even if drawing fails,
    so repeated calls do not accumulate open figures.
    """
    from io import BytesIO
    import base64
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        update_live_chart(ax, candles, trade_history)
        ax.set_title(f"Live Trading Chart - Epoch {epoch}")
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    html_content = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta http-equiv="refresh" content="10">
<title>Live Trading Chart - Epoch {epoch}</title>
<style>
body {{
margin: 0;
padding: 0;
display: flex;
justify-content: center;
align-items: center;
background-color: #f4f4f4;
}}
.chart-container {{
text-align: center;
}}
img {{
max-width: 100%;
height: auto;
}}
</style>
</head>
<body>
<div class="chart-container">
<h2>Live Trading Chart - Epoch {epoch}</h2>
<img src="data:image/png;base64,{image_base64}" alt="Live Chart"/>
</div>
</body>
</html>
"""
    tmp_path = "live_chart.html.tmp"
    # Encoding matches the page's declared <meta charset="utf-8">.
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(html_content)
    os.replace(tmp_path, "live_chart.html")
    print("Updated live_chart.html.")
# --- Chart Drawing Helpers ---
def update_live_chart(ax, candles, trade_history):
    """
    Redraw ``ax`` with the close-price line plus one marker pair per trade:
    a green "^" at the entry, a red "v" at the exit, joined by a dotted blue line.
    Only the first trade's BUY and SELL markers carry legend labels.
    """
    ax.clear()
    closes = [c["close"] for c in candles]
    ax.plot(list(range(len(closes))), closes, label="Close Price", color="black", linewidth=1)
    entry_style = {"marker": "^", "color": "green", "markersize": 10}
    exit_style = {"marker": "v", "color": "red", "markersize": 10}
    for position, trade in enumerate(trade_history):
        in_idx, out_idx = trade["entry_index"], trade["exit_index"]
        in_price, out_price = trade["entry_price"], trade["exit_price"]
        is_first = position == 0
        buy_label = {"label": "BUY"} if is_first else {}
        sell_label = {"label": "SELL"} if is_first else {}
        ax.plot(in_idx, in_price, **entry_style, **buy_label)
        ax.plot(out_idx, out_price, **exit_style, **sell_label)
        ax.plot([in_idx, out_idx], [in_price, out_price], linestyle="dotted", color="blue")
    ax.set_xlabel("Candle Index")
    ax.set_ylabel("Price")
    ax.legend()
    ax.grid(True)
# --- Simulation of Trades for Visualization ---
def simulate_trades(model, env, device, args):
    """
    Replay the environment's current sliding window with the model's decision rule.

    Used only for visualisation: it rewrites ``env.trade_history``. Only the
    cursor, open position and trade history are cleared; the window itself is
    kept, so the chart shows the candles the epoch just trained on. Calling
    ``env.reset()`` here would sample a new random window.

    The model runs in eval mode (dropout off) under ``torch.no_grad()``, and
    its previous train/eval mode is restored afterwards.

    Decision rule per candle, with ``up = pred_high - open`` and
    ``down = open - pred_low``:
      BUY (2)  if up >= down and up > args.threshold;
      SELL (0) if down > up and down > args.threshold;
      HOLD (1) otherwise.
    """
    # Rewind without re-sampling the window (the non-window part of env.reset()).
    env.current_index = 0
    env.trade_history = []
    env.position = None

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            while True:
                i = env.current_index
                state = env.get_state(i)
                current_open = env.candle_window[i]["open"]
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                timeframe_ids = torch.arange(state.shape[0]).to(device)
                pred_high, pred_low = model(state_tensor, timeframe_ids)
                up = pred_high.item() - current_open
                down = current_open - pred_low.item()
                if up >= down and up > args.threshold:
                    action = 2  # BUY
                elif down > up and down > args.threshold:
                    action = 0  # SELL
                else:
                    action = 1  # HOLD
                done = env.step(action)[3]
                if done:
                    break
    finally:
        model.train(was_training)
# --- Backtest Environment with Sliding Window ---
class BacktestEnvironment:
    """Discrete long-only backtest over a random contiguous window of base-timeframe candles.

    ``candles_dict`` maps timeframe names (e.g. "1m") to lists of candle
    dicts. Each candle is read for "timestamp", "open", "high", "low", "close"
    and "volume". Lists are presumably sorted by ascending timestamp, which
    the alignment helper relies on — confirm for the cache source.
    """

    def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
        self.candles_dict = candles_dict  # full candles dict for all timeframes
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.full_candles = candles_dict[base_tf]
        if not self.full_candles:
            raise ValueError(f"No candles available for base timeframe {base_tf!r}.")
        if window_size is None:
            window_size = 100
        # Clamp to the available data: a larger explicit window made random.randint() fail in reset().
        self.window_size = max(1, min(window_size, len(self.full_candles)))
        self.hint_penalty = 0.001  # unused: the training loss does not read it
        self.reset()

    def reset(self):
        """Sample a new random window, rewind to its first candle, clear trades; return the first state."""
        self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
        self.candle_window = self.full_candles[self.start_index:self.start_index + self.window_size]
        self.current_index = 0
        self.trade_history = []
        self.position = None
        return self.get_state(self.current_index)

    def __len__(self):
        """Number of candles in the window."""
        return self.window_size

    def get_state(self, index):
        """Return the float32 model input for window position ``index``.

        Shape is ``(len(timeframes) + NUM_INDICATORS, FEATURES_PER_CHANNEL)``:
        one feature row per timeframe, then NUM_INDICATORS zero rows
        (indicator placeholders).
        """
        state_features = []
        base_ts = self.candle_window[index]["timestamp"]
        for tf in self.timeframes:
            if tf == self.base_tf:
                # Index into the full history (candle_window[i] == full_candles[start_index + i])
                # so the SMA columns average preceding candles. On a one-candle list
                # they just equalled that candle's own close and volume.
                features = get_features_for_tf(self.full_candles, self.start_index + index)
            else:
                # NOTE(review): the aligned candle is the last one with timestamp <= base_ts.
                # If timestamps mark candle opens, its high/low/close extend beyond
                # base_ts (look-ahead into the target) — confirm.
                aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
                features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
            state_features.append(features)
        for _ in range(NUM_INDICATORS):
            state_features.append([0.0] * FEATURES_PER_CHANNEL)
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """
        Advance one candle.

        Action: 0 (SELL), 1 (HOLD), 2 (BUY). At most one long position is held.
        BUY opens it and SELL closes it, and both fill at the next candle's open.
        A closed trade is appended to ``trade_history`` and its pnl is the reward.

        Returns ``(state, reward, next_state, done, next_high, next_low)``.
        On the last candle it returns ``(state, 0.0, None, True, 0.0, 0.0)``
        without moving.
        """
        base = self.candle_window
        current_state = self.get_state(self.current_index)
        if self.current_index >= len(base) - 1:
            return current_state, 0.0, None, True, 0.0, 0.0
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        next_candle = base[next_index]
        reward = 0.0
        if self.position is None:
            if action == 2:  # BUY: enter at the next open.
                # Record the candle whose open is the fill price, the same way exits are recorded.
                self.position = {"entry_price": next_candle["open"], "entry_index": next_index}
        elif action == 0:  # SELL: exit at the next open.
            exit_price = next_candle["open"]
            reward = exit_price - self.position["entry_price"]
            self.trade_history.append({
                "entry_index": self.position["entry_index"],
                "entry_price": self.position["entry_price"],
                "exit_index": next_index,
                "exit_price": exit_price,
                "pnl": reward,
            })
            self.position = None
        self.current_index = next_index
        done = self.current_index >= len(base) - 1
        return current_state, reward, next_state, done, next_candle["high"], next_candle["low"]
# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
    """Train ``model`` for epochs ``start_epoch .. args.epochs - 1`` on random windows of ``env``.

    Each epoch re-samples the window and takes one optimizer step per pair of
    consecutive candles, minimising

        |pred_high - next_high| + |pred_low - next_low|
        - args.lambda_trade * max(pred_high - open, open - pred_low)

    The second term rewards predicting a larger move. NOTE(review): with
    ``lambda_trade >= 1`` it cancels or outweighs the L1 gradient for
    predictions beyond the targets, so predictions may drift apart.

    After each epoch:
    - the scheduler steps once;
    - the mean step loss is printed and passed to save_checkpoint;
    - the trade simulation and live HTML chart are refreshed.

    Raises:
        ValueError: if the window has fewer than 2 candles (no input/target pair).
    """
    # Honour --lambda_trade; fall back to 1.0 for argument objects without it.
    lambda_trade = getattr(args, "lambda_trade", 1.0)
    for epoch in range(start_epoch, args.epochs):
        model.train()  # the per-epoch simulation may change the mode
        # Reset sliding window for each epoch.
        env.reset()
        loss_accum = 0.0
        steps = len(env) - 1  # one step per (candle, next candle) pair
        if steps < 1:
            raise ValueError("Training needs a window of at least 2 candles.")
        for i in range(steps):
            state = env.get_state(i)
            current_open = env.candle_window[i]["open"]
            # The next candle's actual high/low are the regression targets.
            next_candle = env.candle_window[i + 1]
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            timeframe_ids = torch.arange(state.shape[0]).to(device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)
            target_high = torch.tensor(next_candle["high"], device=device, dtype=pred_high.dtype)
            target_low = torch.tensor(next_candle["low"], device=device, dtype=pred_low.dtype)
            # Prediction loss (L1).
            L_pred = torch.abs(pred_high - target_high) + torch.abs(pred_low - target_low)
            # Differentiable surrogate: the larger of the potential long and short moves.
            profit_buy = pred_high - current_open
            profit_sell = current_open - pred_low
            L_trade = -torch.max(profit_buy, profit_sell)
            loss = L_pred + lambda_trade * L_trade
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            loss_accum += loss.item()
        scheduler.step()
        epoch_loss = loss_accum / steps
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
        # Rank checkpoints by the printed mean loss, not the per-window sum.
        save_checkpoint(model, optimizer, epoch, epoch_loss)
        # Refresh the trade visualisation with the current model.
        simulate_trades(model, env, device, args)
        update_live_html(env.candle_window, env.trade_history, epoch + 1)
# --- Live Plotting Functions (For Live Mode) ---
def live_preview_loop(candles, env):
    """Redraw the live chart about once per second until its window is closed.

    Candles are re-read from ``env.candle_window`` on every frame, falling back
    to ``candles`` if ``env`` has no such attribute. The environment replaces
    its window on reset, so a snapshot taken at start-up would stop matching
    the indices in ``env.trade_history``. The loop ends once the figure is
    closed; otherwise ``plt.draw()`` would open a fresh empty figure.

    NOTE(review): main() runs this in a background thread, but most matplotlib
    GUI backends expect the main thread — confirm on the target platform.
    """
    plt.ion()
    fig, ax = plt.subplots(figsize=(12, 6))
    while plt.fignum_exists(fig.number):
        current_candles = getattr(env, "candle_window", candles)
        update_live_chart(ax, current_candles, env.trade_history)
        plt.draw()
        plt.pause(1)
# --- Argument Parsing ---
def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: argument list to parse. ``None`` (the default) uses
            ``sys.argv[1:]``. Passing a list lets tests and other callers
            reuse the parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train',
                        help="What to run.")
    parser.add_argument('--epochs', type=int, default=100, help="Total number of training epochs.")
    parser.add_argument('--lr', type=float, default=3e-4, help="Initial learning rate.")
    parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.")
    parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for the trade surrogate loss.")
    parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
    return parser.parse_args(argv)
# --- Main Function ---
def _choose_action(pred_high, pred_low, current_open, threshold):
    """Map predicted high/low around ``current_open`` to 2 (BUY), 0 (SELL) or 1 (HOLD)."""
    up = pred_high - current_open
    down = current_open - pred_low
    if up >= down and up > threshold:
        return 2
    if down > up and down > threshold:
        return 0
    return 1


def _run_training(args, model, device, timeframes):
    """Train on cached candles, resuming from the best checkpoint unless --start_fresh is given."""
    candles_dict = load_candles_cache(CACHE_FILE)
    if not candles_dict:
        print("No historical candle data available for backtesting.")
        return
    base_tf = "1m"
    # Use a sliding window of up to 100 candles (if available)
    env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
    start_epoch = 0
    checkpoint = None
    if not args.start_fresh:
        checkpoint = load_best_checkpoint(model)
        if checkpoint is not None:
            start_epoch = checkpoint.get("epoch", 0) + 1
            print(f"Resuming training from epoch {start_epoch}.")
        else:
            print("No checkpoint found. Starting training from scratch.")
    else:
        print("Starting training from scratch as requested.")
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
    if checkpoint is not None:
        optim_state = checkpoint.get("optimizer_state_dict", None)
        if optim_state is not None and "param_groups" in optim_state:
            optimizer.load_state_dict(optim_state)
            print("Loaded optimizer state from checkpoint.")
        else:
            print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
    # Restart the LR schedule from --lr. Loaded param_groups carry the previous
    # run's already-decayed lr; cosine annealing would otherwise keep decaying
    # from that value, which may be close to zero.
    for group in optimizer.param_groups:
        group["lr"] = args.lr
        group.pop("initial_lr", None)
    remaining_epochs = max(1, args.epochs - start_epoch)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=remaining_epochs)
    train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)


async def _run_live(args, model, device, timeframes):
    """Step through cached candles once per second with the best model, showing a live chart."""
    load_best_checkpoint(model)
    candles_dict = load_candles_cache(CACHE_FILE)
    if not candles_dict:
        print("No cached candles available for live preview.")
        return
    env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
    # NOTE(review): matplotlib GUI work from a background thread is backend-dependent.
    preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
    preview_thread.start()
    print("Starting live trading loop. (Using model-based decision rule.)")
    # Inference only: dropout off, no autograd bookkeeping.
    model.eval()
    while True:
        state = env.get_state(env.current_index)
        current_open = env.candle_window[env.current_index]["open"]
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        timeframe_ids = torch.arange(state.shape[0]).to(device)
        with torch.no_grad():
            pred_high, pred_low = model(state_tensor, timeframe_ids)
        action = _choose_action(pred_high.item(), pred_low.item(), current_open, args.threshold)
        _, _, _, done, _, _ = env.step(action)
        if done:
            print("Reached end of simulation window; resetting environment.")
            env.reset()
        await asyncio.sleep(1)


async def main():
    """Entry point: build the model and dispatch on ``--mode`` (train / live / inference)."""
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    # Must hold NUM_TIMEFRAMES entries: get_state emits one row per timeframe
    # plus NUM_INDICATORS rows, which must match the model's channel count.
    timeframes = ["1m", "5m", "15m", "1h", "1d"]
    hidden_dim = 128
    total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
    model = TradingModel(total_channels, NUM_TIMEFRAMES, hidden_dim=hidden_dim).to(device)
    if args.mode == 'train':
        _run_training(args, model, device, timeframes)
    elif args.mode == 'live':
        await _run_live(args, model, device, timeframes)
    elif args.mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
        # TODO: inference is not implemented; only the best checkpoint is loaded.
    else:
        print("Invalid mode specified.")
if __name__ == "__main__":
    # Run the async entry point only when executed as a script, not on import.
    asyncio.run(main())