# gogo2/crypto/brian/index-deep-new.py
# 2025-02-04 21:52:11 +02:00

# 517 lines
# 20 KiB
# Python

#!/usr/bin/env python3
import sys
import asyncio
# On Windows, switch asyncio to the selector-based event loop before any loop
# is created. NOTE(review): presumably required by the async exchange client
# imported below — confirm.
if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import os
import time
import json
import argparse
import threading
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
import matplotlib.pyplot as plt
import ccxt.async_support as ccxt
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math
from dotenv import load_dotenv
load_dotenv()
# --- Directories ---
# LAST_DIR holds the most recent checkpoints, BEST_DIR the lowest-loss ones.
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
for _checkpoint_dir in (LAST_DIR, BEST_DIR):
    os.makedirs(_checkpoint_dir, exist_ok=True)
CACHE_FILE = "candles_cache.json"

# --- Constants ---
NUM_TIMEFRAMES = 5  # e.g., ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20  # extra channels; BacktestEnvironment.get_state fills them with zeros
# Feature order per channel, as produced by get_features_for_tf:
# open, high, low, close, volume, SMA(close), SMA(volume)
FEATURES_PER_CHANNEL = 7
# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    The encoding table is registered as the non-trainable buffer ``pe`` with
    shape ``[max_len, 1, d_model]``: even feature columns hold sines and odd
    columns hold cosines of the position scaled by geometrically spaced
    frequencies.

    Args:
        d_model: Size of the feature dimension (even or odd).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions stored in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one fewer cosine column than sine
        # column, so trim the frequencies to avoid a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions.

        ``x`` is indexed along dim 0 by position; the table broadcasts over
        dim 1, so the expected layout is ``[seq_len, batch, d_model]``.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
# --- Enhanced Transformer Model ---
class TradingModel(nn.Module):
    """Transformer over per-channel embeddings that predicts a high and a low.

    Every input channel is projected by its own small feed-forward branch,
    tagged with a learned per-channel embedding, passed through a two-layer
    Transformer encoder and attention-pooled into a single vector that feeds
    two regression heads.

    Args:
        num_channels: Number of input channels (one branch and one embedding
            row per channel).
        num_timeframes: Accepted for interface compatibility; not used here.
        hidden_dim: Width of the channel embeddings and the encoder.
    """

    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()

        # One independent projection per channel.
        branches = []
        for _ in range(num_channels):
            branches.append(
                nn.Sequential(
                    nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.GELU(),
                    nn.Dropout(0.1),
                )
            )
        self.channel_branches = nn.ModuleList(branches)

        # Embedding for channels 0..num_channels-1.
        self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
        # Constructed (and part of the state dict) but not applied in forward().
        self.pos_encoder = PositionalEncoding(hidden_dim)

        layer = TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=4,
            dim_feedforward=512,
            dropout=0.1,
            activation='gelu',
            batch_first=False,
        )
        self.transformer = TransformerEncoder(layer, num_layers=2)
        self.attn_pool = nn.Linear(hidden_dim, 1)

        half_dim = hidden_dim // 2
        self.high_pred = nn.Sequential(
            nn.Linear(hidden_dim, half_dim),
            nn.GELU(),
            nn.Linear(half_dim, 1),
        )
        self.low_pred = nn.Sequential(
            nn.Linear(hidden_dim, half_dim),
            nn.GELU(),
            nn.Linear(half_dim, 1),
        )

    def forward(self, x, timeframe_ids):
        """Return ``(high_prediction, low_prediction)`` for a batch of states.

        Args:
            x: Tensor of shape [batch_size, num_channels, FEATURES_PER_CHANNEL].
            timeframe_ids: Integer tensor with one embedding index per channel.

        Both outputs have every size-1 dimension squeezed away.
        """
        _, num_channels, _ = x.shape
        projected = [self.channel_branches[i](x[:, i, :]) for i in range(num_channels)]

        # [batch, channels, hidden] -> [channels, batch, hidden]
        tokens = torch.stack(projected, dim=1).permute(1, 0, 2)
        # Per-channel embedding, broadcast over the batch dimension.
        tokens = tokens + self.timeframe_embed(timeframe_ids).unsqueeze(1)

        # Boolean mask that is True strictly above the diagonal, so channel i
        # is only allowed to attend to channels 0..i.
        seq_len = tokens.size(0)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
        encoded = self.transformer(tokens, mask=causal_mask)

        # Softmax over the channel axis gives the pooling weights.
        pool_weights = torch.softmax(self.attn_pool(encoded), dim=0)
        pooled = (encoded * pool_weights).sum(dim=0)
        return self.high_pred(pooled).squeeze(), self.low_pred(pooled).squeeze()
# --- Technical Indicator Helpers ---
def compute_sma(candles_list, index, period=10):
    """Return the mean ``close`` of up to ``period`` candles ending at ``index``.

    The window is clipped at the start of the list; an empty window yields 0.0.
    """
    first = max(0, index - period + 1)
    window = candles_list[first:index + 1]
    if not window:
        return 0.0
    return sum(item["close"] for item in window) / len(window)
def compute_sma_volume(candles_list, index, period=10):
    """Return the mean ``volume`` of up to ``period`` candles ending at ``index``.

    The window is clipped at the start of the list; an empty window yields 0.0.
    """
    first = max(0, index - period + 1)
    window = candles_list[first:index + 1]
    if not window:
        return 0.0
    return sum(item["volume"] for item in window) / len(window)
def get_aligned_candle_with_index(candles_list, target_ts):
    """Return ``(index, candle)`` for the last leading candle at or before ``target_ts``.

    The scan walks from the start and stops at the first candle whose
    timestamp exceeds ``target_ts``. If even the first candle is later than
    the target, index 0 is returned. An empty list raises ``IndexError``.
    """
    aligned = 0
    cursor = 0
    count = len(candles_list)
    while cursor < count and candles_list[cursor]["timestamp"] <= target_ts:
        aligned = cursor
        cursor += 1
    return aligned, candles_list[aligned]
def get_features_for_tf(candles_list, index, period=10):
    """Build the feature vector for the candle at ``index``.

    Returns a list in the order open, high, low, close, volume, SMA of close,
    SMA of volume, where both averages use ``period`` candles ending at
    ``index``.
    """
    candle = candles_list[index]
    raw_values = [candle[field] for field in ("open", "high", "low", "close", "volume")]
    averages = [
        compute_sma(candles_list, index, period),
        compute_sma_volume(candles_list, index, period),
    ]
    return raw_values + averages
# --- Caching & Checkpoint Functions ---
def load_candles_cache(filename):
    """Load cached candle data from a JSON file.

    Returns the parsed JSON content, or an empty dict when the file does not
    exist or cannot be read/parsed (the error is printed, not raised).
    """
    if not os.path.exists(filename):
        return {}
    try:
        with open(filename, "r") as cache_file:
            cached = json.load(cache_file)
        print(f"Loaded cached data from {filename}.")
        return cached
    except Exception as err:
        # Best effort: a broken cache is treated the same as a missing one.
        print("Error reading cache file:", err)
        return {}
def save_candles_cache(filename, candles_dict):
    """Write ``candles_dict`` to ``filename`` as JSON without corrupting an existing cache.

    The data is first serialised to a sibling temporary file and then moved
    over the target with ``os.replace``. If serialisation or writing fails,
    the previous cache file is left untouched, the temporary file is removed
    and the error is printed (not raised), keeping the save best-effort.

    Args:
        filename: Destination path of the JSON cache.
        candles_dict: JSON-serialisable mapping of timeframe -> candles.
    """
    tmp_path = f"{filename}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(candles_dict, f)
        # Atomic swap: readers see either the old or the complete new file.
        os.replace(tmp_path, filename)
    except Exception as e:
        print("Error saving cache file:", e)
        # Drop the partial temp file, if one was created.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
def maintain_checkpoint_directory(directory, max_files=10):
    """Delete the oldest entries of ``directory`` until at most ``max_files`` remain.

    Age is determined by modification time; nothing happens when the
    directory already holds ``max_files`` entries or fewer.
    """
    names = os.listdir(directory)
    surplus = len(names) - max_files
    if surplus <= 0:
        return
    oldest_first = sorted(
        (os.path.join(directory, name) for name in names),
        key=os.path.getmtime,
    )
    for stale_path in oldest_first[:surplus]:
        os.remove(stale_path)
def get_best_models(directory):
    """List ``(loss, filename)`` pairs for the checkpoints in ``directory``.

    The loss is parsed from the second underscore-separated field of each
    file name (e.g. ``best_0.1234_epoch_...``); names that do not follow that
    pattern are skipped.
    """
    scored = []
    for name in os.listdir(directory):
        fields = name.split("_")
        if len(fields) < 2:
            continue
        try:
            # fields[1] is the recorded loss
            recorded_loss = float(fields[1])
        except ValueError:
            continue
        scored.append((recorded_loss, name))
    return scored
def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Save a training checkpoint to the "last" pool and, if it qualifies, the "best" pool.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch number recorded in the checkpoint and in the file name.
        loss: Loss recorded in the checkpoint; also used to rank "best" files.
        last_dir: Directory holding the 10 most recent checkpoints.
        best_dir: Directory holding the 10 lowest-loss checkpoints.

    A checkpoint enters the best pool when the pool has fewer than 10 ranked
    files or when its loss is lower than the worst ranked file, which is then
    removed. Non-finite losses (NaN/inf) are never added to the best pool.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Built once and written to both pools.
    payload = {
        "epoch": epoch,
        "loss": loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }

    last_path = os.path.join(last_dir, f"model_last_epoch_{epoch}_{timestamp}.pt")
    torch.save(payload, last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)

    # A NaN/inf loss cannot be meaningfully ranked: it would end up in the
    # file name and make later min()/max() selection order-dependent.
    if math.isfinite(loss):
        best_models = get_best_models(best_dir)
        evicted_file = None
        if len(best_models) < 10:
            add_to_best = True
        else:
            worst_loss, worst_file = max(best_models, key=lambda x: x[0])
            add_to_best = loss < worst_loss
            if add_to_best:
                evicted_file = worst_file
        if add_to_best:
            best_filename = f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt"
            torch.save(payload, os.path.join(best_dir, best_filename))
            # Remove the displaced file only after the new one is safely on
            # disk, so a failed save never shrinks the pool.
            if evicted_file is not None and evicted_file != best_filename:
                os.remove(os.path.join(best_dir, evicted_file))
            maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")
def load_best_checkpoint(model, best_dir=BEST_DIR):
    """Load the lowest-loss checkpoint from ``best_dir`` into ``model``.

    Args:
        model: Module that receives the stored ``model_state_dict``.
        best_dir: Directory scanned for ``best_<loss>_...`` checkpoint files.

    Returns:
        The loaded checkpoint dict, or ``None`` when no ranked checkpoint
        exists in ``best_dir``.
    """
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    # Choose the checkpoint with the lowest loss.
    best_loss, best_file = min(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    # Map stored tensors onto the model's current device so a checkpoint
    # written on a GPU machine can still be loaded on a CPU-only host.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=target_device)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint
# --- Live HTML Chart Update ---
def update_live_html(candles, trade_history, epoch):
    """
    Render the trade chart for ``epoch`` to a PNG held in memory and write it,
    base64-embedded, into ``live_chart.html`` in the current directory.
    The page carries a meta refresh of 10 seconds.
    """
    import base64
    from io import BytesIO

    figure, axis = plt.subplots(figsize=(12, 6))
    update_live_chart(axis, candles, trade_history)
    axis.set_title(f"Live Trading Chart - Epoch {epoch}")

    # Render to memory and release the figure right away.
    png_buffer = BytesIO()
    figure.savefig(png_buffer, format='png')
    plt.close(figure)
    png_buffer.seek(0)
    encoded_png = base64.b64encode(png_buffer.getvalue())
    image_base64 = encoded_png.decode('utf-8')

    html_content = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta http-equiv="refresh" content="10">
<title>Live Trading Chart - Epoch {epoch}</title>
<style>
body {{
margin: 0;
padding: 0;
display: flex;
justify-content: center;
align-items: center;
background-color: #f4f4f4;
}}
.chart-container {{
text-align: center;
}}
img {{
max-width: 100%;
height: auto;
}}
</style>
</head>
<body>
<div class="chart-container">
<h2>Live Trading Chart - Epoch {epoch}</h2>
<img src="data:image/png;base64,{image_base64}" alt="Live Chart"/>
</div>
</body>
</html>
"""
    with open("live_chart.html", "w") as page:
        page.write(html_content)
    print("Updated live_chart.html.")
# --- Chart Drawing Helpers (used by both live preview and HTML update) ---
def update_live_chart(ax, candles, trade_history):
    """
    Redraw ``ax`` with the close-price line plus, for every trade, a green
    entry marker, a red exit marker and a dotted blue line joining the two.
    Only the first trade's markers carry the "BUY"/"SELL" legend labels.
    """
    ax.clear()
    closes = [candle["close"] for candle in candles]
    ax.plot(list(range(len(closes))), closes, label="Close Price", color="black", linewidth=1)

    for position, trade in enumerate(trade_history):
        in_idx = trade["entry_index"]
        out_idx = trade["exit_index"]
        in_price = trade["entry_price"]
        out_price = trade["exit_price"]

        # Label only the first marker of each kind so the legend has one
        # entry per marker type.
        is_first = position == 0
        buy_label = {"label": "BUY"} if is_first else {}
        sell_label = {"label": "SELL"} if is_first else {}

        ax.plot(in_idx, in_price, marker="^", color="green", markersize=10, **buy_label)
        ax.plot(out_idx, out_price, marker="v", color="red", markersize=10, **sell_label)
        ax.plot([in_idx, out_idx], [in_price, out_price], linestyle="dotted", color="blue")

    ax.set_xlabel("Candle Index")
    ax.set_ylabel("Price")
    ax.legend()
    ax.grid(True)
# --- Forced Action Helper ---
def get_forced_action(env):
    """
    Scripted policy that guarantees a trade per episode.

    Returns 2 (BUY) when the environment is at index 0, 0 (SELL) when the
    index is within the last two candles and a position is open, and
    1 (HOLD) in every other case.
    """
    total = len(env)
    if env.current_index == 0:
        return 2  # BUY
    closing_window = env.current_index >= total - 2
    if closing_window and env.position is not None:
        return 0  # SELL
    return 1  # HOLD
# --- Backtest Environment ---
class BacktestEnvironment:
    """Step-by-step replay of cached candles with a single long position.

    Args:
        candles_dict: Mapping of timeframe -> list of candle dicts.
        base_tf: Timeframe whose candles drive the step index.
        timeframes: Timeframes included, in order, in every state.

    Attributes:
        current_index: Index of the current base-timeframe candle.
        position: ``None`` or a dict with ``entry_price`` and ``entry_index``.
        trade_history: Closed trades (entry/exit index and price, pnl).
    """

    def __init__(self, candles_dict, base_tf, timeframes):
        self.candles_dict = candles_dict  # dict: timeframe -> list of candles
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.current_index = 0
        self.trade_history = []
        self.position = None

    def reset(self):
        """Rewind to the first candle, drop any position/trades and return the first state."""
        self.current_index = 0
        self.position = None
        self.trade_history = []
        return self.get_state(self.current_index)

    def get_state(self, index):
        """Return the float32 state array for base candle ``index``.

        The array has one row per timeframe (features of the candle aligned
        to the base candle's timestamp) followed by ``NUM_INDICATORS`` rows of
        zeros; every row has ``FEATURES_PER_CHANNEL`` values.
        """
        state_features = []
        base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
        for tf in self.timeframes:
            aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
            state_features.append(get_features_for_tf(self.candles_dict[tf], aligned_idx))
        # Indicator channels are placeholders for now.
        for _ in range(NUM_INDICATORS):
            state_features.append([0.0] * FEATURES_PER_CHANNEL)
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """Apply ``action`` and advance one base candle.

        Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY. Orders fill at the
        next candle's open; BUY is ignored while a position is open and SELL
        is ignored while flat.

        Returns:
            ``(current_state, reward, next_state, done, actual_high,
            actual_low)`` where the high/low belong to the next candle and
            ``reward`` is the realised price difference of a closed trade.
            When called on the last candle it returns
            ``(current_state, 0.0, None, True, 0.0, 0.0)``.
        """
        base_candles = self.candles_dict[self.base_tf]
        # End-of-data: return dummy high/low targets
        if self.current_index >= len(base_candles) - 1:
            current_state = self.get_state(self.current_index)
            return current_state, 0.0, None, True, 0.0, 0.0

        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        next_candle = base_candles[next_index]
        reward = 0.0

        if self.position is None:
            if action == 2:  # BUY signal: enter at next candle's open.
                # The fill happens on the next candle, so record that index;
                # this matches how exits are indexed and keeps chart markers
                # on the candle whose open price is used.
                self.position = {
                    "entry_price": next_candle["open"],
                    "entry_index": next_index,
                }
        elif action == 0:  # SELL signal: exit at next candle's open.
            exit_price = next_candle["open"]
            reward = exit_price - self.position["entry_price"]
            self.trade_history.append({
                "entry_index": self.position["entry_index"],
                "entry_price": self.position["entry_price"],
                "exit_index": next_index,
                "exit_price": exit_price,
                "pnl": reward,
            })
            self.position = None

        self.current_index = next_index
        done = (self.current_index >= len(base_candles) - 1)
        actual_high = next_candle["high"]
        actual_low = next_candle["low"]
        return current_state, reward, next_state, done, actual_high, actual_low

    def __len__(self):
        """Return the number of base-timeframe candles."""
        return len(self.candles_dict[self.base_tf])
# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler, base_candles):
    """Train ``model`` to predict the next candle's high and low over replayed data.

    For each epoch in ``range(start_epoch, args.epochs)`` the environment is
    reset and walked to its end with the forced-action policy. Every step
    performs one optimizer update on the L1 error between the predicted and
    actual next high/low. After each epoch the scheduler is stepped, a
    checkpoint is saved and the live HTML chart is regenerated.

    Args:
        env: Backtest environment providing ``reset()``, ``step()`` and
            ``trade_history``.
        model: Network called as ``model(state_tensor, timeframe_ids)``.
        device: Torch device for the input and target tensors.
        args: Parsed arguments; only ``args.epochs`` is read here.
        start_epoch: First epoch index to run (0-based).
        optimizer: Optimizer updating ``model``.
        scheduler: LR scheduler stepped once per epoch.
        base_candles: Base-timeframe candles used for the chart.
    """
    for epoch in range(start_epoch, args.epochs):
        state = env.reset()
        total_loss = 0.0
        num_steps = 0
        model.train()
        while True:
            # Use forced action policy to guarantee at least one trade per episode
            action = get_forced_action(env)
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            timeframe_ids = torch.arange(state.shape[0]).to(device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)
            # Use the forced action in the environment step.
            _, _reward, next_state, done, actual_high, actual_low = env.step(action)
            if next_state is None:
                # The environment was already on its last candle, so the
                # returned targets are 0.0 placeholders; do not train on them.
                break
            target_high = torch.FloatTensor([actual_high]).to(device)
            target_low = torch.FloatTensor([actual_low]).to(device)
            high_loss = torch.abs(pred_high - target_high) * 2
            low_loss = torch.abs(pred_low - target_low) * 2
            loss = (high_loss + low_loss).mean()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            num_steps += 1
            if done:
                break
            state = next_state
        scheduler.step()
        # Average over the updates actually performed (one fewer than the
        # number of candles), guarding against an episode with no updates.
        epoch_loss = total_loss / max(num_steps, 1)
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
        # Rank checkpoints by the same per-step loss that is logged, so the
        # ranking does not depend on how many candles the episode contained.
        # Checkpoints saved earlier with a summed loss will rank as worse.
        save_checkpoint(model, optimizer, epoch, epoch_loss)
        # Update the live HTML file with the current epoch chart.
        update_live_html(base_candles, env.trade_history, epoch+1)
# --- Live Plotting Functions (For live mode) ---
def live_preview_loop(candles, env):
    """Redraw the trade chart from ``env.trade_history`` about once per second, forever.

    Runs in interactive matplotlib mode and never returns.
    NOTE(review): main() starts this in a background thread; whether GUI
    drawing works off the main thread depends on the backend — confirm.
    """
    plt.ion()
    _figure, axis = plt.subplots(figsize=(12, 6))
    while True:
        update_live_chart(axis, candles, env.trade_history)
        plt.draw()
        plt.pause(1)
# --- Argument Parsing ---
def parse_args():
    """Parse the command line and return the resulting namespace.

    Options: ``--mode`` (train/live/inference), ``--epochs``, ``--lr``,
    ``--threshold`` and the ``--start_fresh`` flag.
    """
    option_specs = (
        ('--mode', {"choices": ['train', 'live', 'inference'], "default": 'train'}),
        ('--epochs', {"type": int, "default": 100}),
        ('--lr', {"type": float, "default": 3e-4}),
        ('--threshold', {"type": float, "default": 0.005}),
        # If set, training starts from scratch (ignoring saved checkpoints)
        ('--start_fresh', {"action": 'store_true', "help": 'Start training from scratch.'}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
def random_action():
    """Return a uniformly random action id: 0 (SELL), 1 (HOLD) or 2 (BUY)."""
    return random.randrange(0, 3)
# --- Main Function ---
async def main():
    """Entry point: build the model and run the mode selected on the command line.

    * ``train``: replay cached candles, optionally resuming from the best
      checkpoint, and train the model.
    * ``live``: replay cached candles once per second with the forced-action
      policy while a background thread redraws the chart.
    * ``inference``: load the best checkpoint (no further processing yet).
    """
    args = parse_args()
    # Use GPU if available; else CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    timeframes = ["1m", "5m", "15m", "1h", "1d"]
    base_tf = "1m"
    total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
    model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)

    if args.mode == 'train':
        candles_dict = load_candles_cache(CACHE_FILE)
        if not candles_dict:
            print("No historical candle data available for backtesting.")
            return
        # Every timeframe is indexed when a state is built; fail early with a
        # clear message instead of a KeyError/IndexError mid-training.
        missing = [tf for tf in timeframes if not candles_dict.get(tf)]
        if missing:
            print("Cached candle data is missing timeframes:", ", ".join(missing))
            return
        env = BacktestEnvironment(candles_dict, base_tf, timeframes)
        start_epoch = 0
        checkpoint = None
        if not args.start_fresh:
            checkpoint = load_best_checkpoint(model)
            if checkpoint is not None:
                start_epoch = checkpoint.get("epoch", 0) + 1
                print(f"Resuming training from epoch {start_epoch}.")
            else:
                print("No checkpoint found. Starting training from scratch.")
        else:
            print("Starting training from scratch as requested.")
        if start_epoch >= args.epochs:
            print(f"Checkpoint epoch {start_epoch} already reaches --epochs {args.epochs}; nothing to train.")
        # Create optimizer and scheduler in main
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
        # Keep T_max positive even when the resumed epoch is at or beyond
        # the requested number of epochs.
        remaining_epochs = max(1, args.epochs - start_epoch)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=remaining_epochs)
        if checkpoint is not None:
            optim_state = checkpoint.get("optimizer_state_dict", None)
            if optim_state is not None and "param_groups" in optim_state:
                optimizer.load_state_dict(optim_state)
                print("Loaded optimizer state from checkpoint.")
            else:
                print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
        # Pass the base timeframe candles for the live HTML chart update.
        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler, candles_dict[base_tf])

    elif args.mode == 'live':
        load_best_checkpoint(model)
        candles_dict = load_candles_cache(CACHE_FILE)
        if not candles_dict:
            print("No cached candles available for live preview.")
            return
        missing = [tf for tf in timeframes if not candles_dict.get(tf)]
        if missing:
            print("Cached candle data is missing timeframes:", ", ".join(missing))
            return
        env = BacktestEnvironment(candles_dict, base_tf=base_tf, timeframes=timeframes)
        # NOTE(review): the preview draws with matplotlib from a background
        # thread; whether that works depends on the GUI backend — confirm.
        preview_thread = threading.Thread(
            target=live_preview_loop, args=(candles_dict[base_tf], env), daemon=True
        )
        preview_thread.start()
        print("Starting live trading loop. (Using forced action policy for simulation.)")
        # Here we use the forced-action policy as in training.
        while True:
            action = get_forced_action(env)
            _, _, _, done, _, _ = env.step(action)
            if done:
                print("Reached end of simulated data, resetting environment.")
                env.reset()
            await asyncio.sleep(1)

    elif args.mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
        # Here you can apply a similar forced-action policy or use a learned policy.

    else:
        # Not reachable through parse_args (argparse restricts --mode); kept
        # as a safeguard.
        print("Invalid mode specified.")


if __name__ == "__main__":
    asyncio.run(main())