#!/usr/bin/env python3
import sys
import asyncio
if sys.platform == 'win32':
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import os
import time
import json
import argparse
import threading
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
import matplotlib.pyplot as plt
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import matplotlib.dates as mdates
from dotenv import load_dotenv
load_dotenv()
# --- Helper Function for Timestamp Conversion ---
def convert_timestamp(ts):
    """
    Turn a Unix timestamp into a naive local-time ``datetime``.

    ``ts`` may be anything ``float()`` accepts (number or numeric string). Values
    greater than 1e10 are assumed to be milliseconds and are scaled to seconds
    before conversion.
    """
    seconds = float(ts)
    # 1e10 seconds is far in the future, so anything larger is taken to be milliseconds.
    return datetime.fromtimestamp(seconds / 1000.0 if seconds > 1e10 else seconds)
# --- Directories ---
# Checkpoint folders; created eagerly at import time.
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
for _checkpoint_dir in (LAST_DIR, BEST_DIR):
    os.makedirs(_checkpoint_dir, exist_ok=True)
del _checkpoint_dir
CACHE_FILE = "candles_cache.json"  # JSON candle cache read by load_candles_cache()
# --- Constants ---
NUM_TIMEFRAMES = 5  # one channel per timeframe, e.g. ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20  # indicator channels (BacktestEnvironment currently fills them with zeros)
FEATURES_PER_CHANNEL = 7  # open, high, low, close, volume, SMA(close), SMA(volume)
ORDER_CHANNELS = 1  # one extra channel describing the open position (main() uses a literal 1)
# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
    """
    Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)`` with
    ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: feature size of the inputs (odd sizes are supported).
        dropout: dropout probability applied after the addition.
        max_len: longest sequence length that can be encoded.
        batch_first: if False (default) inputs are ``[seq_len, batch, d_model]``;
            if True they are ``[batch, seq_len, d_model]``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Keep the [max_len, 1, d_model] buffer layout so existing state dicts still load.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe)`` for the first ``seq_len`` positions."""
        if self.batch_first:
            # [seq, 1, d] -> [1, seq, d] so it broadcasts over the leading batch dim.
            x = x + self.pe[: x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[: x.size(0)]
        return self.dropout(x)
# --- Enhanced Transformer Model ---
class TradingModel(nn.Module):
    """
    Transformer mapping a per-channel feature matrix to (high, low) predictions.

    Each input channel (a row of ``FEATURES_PER_CHANNEL`` values) is projected to
    ``hidden_dim`` by its own small MLP. A learned per-channel embedding is added, the
    channel sequence runs through a 2-layer Transformer encoder, is attention-pooled
    into one vector, and two MLP heads emit the high and low predictions.

    Args:
        num_channels: number of input channels (one projection branch and one
            embedding row per channel).
        num_timeframes: accepted for API compatibility; not used.
        hidden_dim: embedding width; must be divisible by the 4 attention heads.
    """

    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()

        def make_projection():
            return nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
            )

        def make_head():
            half_width = hidden_dim // 2
            return nn.Sequential(
                nn.Linear(hidden_dim, half_width),
                nn.GELU(),
                nn.Linear(half_width, 1),
            )

        # Submodules are created in a fixed order so weight initialisation draws
        # from the RNG in the same sequence on every construction.
        self.channel_branches = nn.ModuleList([make_projection() for _ in range(num_channels)])
        self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
        # NOTE(review): registered (its buffer is part of the state dict) but never
        # applied in forward().
        self.pos_encoder = PositionalEncoding(hidden_dim)
        layer_template = TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=4,
            dim_feedforward=512,
            dropout=0.1,
            activation='gelu',
            batch_first=True,  # inputs are [batch, channels, hidden]
        )
        self.transformer = TransformerEncoder(layer_template, num_layers=2)
        self.attn_pool = nn.Linear(hidden_dim, 1)
        self.high_pred = make_head()
        self.low_pred = make_head()

    def forward(self, x, timeframe_ids):
        """
        Args:
            x: tensor indexed as ``[batch, channel, feature]``.
            timeframe_ids: integer ids, one per channel, looked up in ``timeframe_embed``.

        Returns:
            ``(high, low)``, each ``squeeze()``-d, so a 0-d tensor for batch size 1.
        """
        _, channel_count, _ = x.shape
        per_channel = [self.channel_branches[c](x[:, c, :]) for c in range(channel_count)]
        tokens = torch.stack(per_channel, dim=1)
        # The embedding is [channels, hidden]; unsqueeze broadcasts it over the batch.
        tokens = tokens + self.timeframe_embed(timeframe_ids).unsqueeze(0)
        encoded = self.transformer(tokens)
        pool_weights = torch.softmax(self.attn_pool(encoded), dim=1)
        pooled = torch.sum(encoded * pool_weights, dim=1)
        high = self.high_pred(pooled).squeeze()
        low = self.low_pred(pooled).squeeze()
        return high, low
# --- Technical Indicator Helpers ---
def _trailing_mean(candles_list, index, period, key):
    """Mean of ``candle[key]`` over the up-to-``period`` candles ending at ``index``; 0.0 if none."""
    start = max(0, index - period + 1)
    values = [candle[key] for candle in candles_list[start:index + 1]]
    return sum(values) / len(values) if values else 0.0


def compute_sma(candles_list, index, period=10, key="close"):
    """
    Simple moving average of ``key`` (default "close") ending at ``index``.

    Averages the last ``period`` candles up to and including ``index`` (fewer near
    the start of the list). Returns 0.0 when the range is empty, e.g. ``period <= 0``.
    """
    return _trailing_mean(candles_list, index, period, key)


def compute_sma_volume(candles_list, index, period=10):
    """Simple moving average of "volume" ending at ``index`` (same window rules as compute_sma)."""
    return _trailing_mean(candles_list, index, period, "volume")
def get_aligned_candle_with_index(candles_list, target_ts):
    """
    Find the last candle whose timestamp is at or before ``target_ts``.

    The scan starts at the beginning and stops at the first candle later than
    ``target_ts``, so the list is assumed to be sorted by timestamp. When even the
    first candle is later than ``target_ts``, index 0 is returned. An empty list
    raises IndexError.

    Returns:
        ``(index, candle)``.
    """
    chosen = 0
    cursor = 0
    while cursor < len(candles_list) and candles_list[cursor]["timestamp"] <= target_ts:
        chosen = cursor
        cursor += 1
    return chosen, candles_list[chosen]
def get_features_for_tf(candles_list, index, period=10):
    """
    Build the 7-value feature row for ``candles_list[index]``.

    Returns ``[open, high, low, close, volume, sma_close, sma_volume]``; the two SMAs
    average the up-to-``period`` candles ending at ``index``.
    """
    candle = candles_list[index]
    row = [candle[field] for field in ("open", "high", "low", "close", "volume")]
    row.append(compute_sma(candles_list, index, period))
    row.append(compute_sma_volume(candles_list, index, period))
    return row
# --- Caching & Checkpoint Functions ---
def load_candles_cache(filename):
    """
    Load the JSON candle cache stored at ``filename``.

    Returns the decoded JSON. Returns ``{}`` when the file does not exist, or when
    reading or parsing fails; such errors are printed rather than raised.
    """
    if not os.path.exists(filename):
        return {}
    try:
        with open(filename, "r") as handle:
            cached = json.load(handle)
    except Exception as err:
        print("Error reading cache file:", err)
        return {}
    print(f"Loaded cached data from {filename}.")
    return cached
def save_candles_cache(filename, candles_dict):
    """
    Serialize ``candles_dict`` to ``filename`` as JSON.

    Best effort: any failure (unwritable path, non-serializable value) is printed
    and swallowed.
    """
    try:
        with open(filename, mode="w") as out:
            json.dump(obj=candles_dict, fp=out)
    except Exception as err:
        print("Error saving cache file:", err)
def maintain_checkpoint_directory(directory, max_files=10):
    """
    Keep at most ``max_files`` regular files in ``directory``, deleting the oldest.

    Age is judged by modification time. Sub-directories and other non-file entries
    are neither counted nor deleted, because ``os.remove`` cannot delete a directory.
    """
    entries = [os.path.join(directory, name) for name in os.listdir(directory)]
    files = [path for path in entries if os.path.isfile(path)]
    excess = len(files) - max_files
    if excess <= 0:
        return
    files.sort(key=os.path.getmtime)  # oldest first
    for path in files[:excess]:
        os.remove(path)
def get_best_models(directory):
    """
    List ranked checkpoints in ``directory`` as ``(loss, filename)`` pairs.

    The loss is parsed from the second "_"-separated field of the file name
    (``best_{loss:.4f}_epoch_...`` as written by save_checkpoint). Names without a
    parseable second field are ignored. NaN/inf losses are also ignored: ``min``/``max``
    over values that include NaN give order-dependent results, so a diverged
    checkpoint could otherwise be selected as the best one.
    """
    best_files = []
    for name in os.listdir(directory):
        parts = name.split("_")
        try:
            loss = float(parts[1])
        except (IndexError, ValueError):
            continue
        if math.isfinite(loss):
            best_files.append((loss, name))
    return best_files
def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """
    Save a rolling "last" checkpoint and, if it ranks, a "best" checkpoint.

    Every call writes ``model_last_epoch_{epoch}_{timestamp}.pt`` to ``last_dir`` and
    prunes that folder to its 10 most recently modified files. A copy named
    ``best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt`` goes to ``best_dir`` when fewer
    than 10 ranked checkpoints exist, or when ``loss`` beats the worst one (which is
    then deleted). Non-finite losses (NaN/inf) are never ranked as best.

    Each file stores a dict with "epoch", "loss", "model_state_dict" and
    "optimizer_state_dict".
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    payload = {
        "epoch": epoch,
        "loss": loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    last_path = os.path.join(last_dir, f"model_last_epoch_{epoch}_{timestamp}.pt")
    torch.save(payload, last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)

    add_to_best = False
    if math.isfinite(loss):
        best_models = get_best_models(best_dir)
        if len(best_models) < 10:
            add_to_best = True
        else:
            worst_loss, worst_file = max(best_models, key=lambda item: item[0])
            if loss < worst_loss:
                add_to_best = True
                os.remove(os.path.join(best_dir, worst_file))
    if add_to_best:
        best_path = os.path.join(best_dir, f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt")
        torch.save(payload, best_path)
        # Safety net for stray files in best_dir; this prunes by age, not by loss.
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")
def load_best_checkpoint(model, best_dir=BEST_DIR, map_location="cpu"):
    """
    Load the lowest-loss ranked checkpoint from ``best_dir`` into ``model``.

    Tensors are first loaded onto ``map_location`` (CPU by default, so checkpoints
    written on a GPU machine also load on CPU-only hosts). ``load_state_dict`` then
    copies them onto the model's own device.

    If the saved ``timeframe_embed.weight`` has a different number of rows than the
    model's, the overlapping rows are copied and any extra model rows keep their
    current values, so both growing and shrinking the channel count load. Loading uses
    ``strict=False``: missing or unexpected keys are tolerated, but other shape
    mismatches still raise.

    Returns:
        The raw checkpoint dict, or None when no ranked checkpoint exists.
    """
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_loss, best_file = min(best_models, key=lambda item: item[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    checkpoint = torch.load(path, map_location=map_location)
    old_state = checkpoint["model_state_dict"]
    new_state = model.state_dict()
    embed_key = "timeframe_embed.weight"
    if embed_key in old_state and embed_key in new_state:
        old_embed = old_state[embed_key]
        new_embed = new_state[embed_key]
        if old_embed.shape != new_embed.shape and old_embed.shape[1:] == new_embed.shape[1:]:
            # Build a fresh tensor instead of writing into the live parameter.
            merged = new_embed.detach().clone()
            rows = min(old_embed.shape[0], new_embed.shape[0])
            merged[:rows] = old_embed[:rows].to(device=merged.device, dtype=merged.dtype)
            old_state[embed_key] = merged
    model.load_state_dict(old_state, strict=False)
    return checkpoint
# --- Function for Manual Trade Override ---
def manual_trade(env):
    """
    Record a hindsight trade when the model's signal is too weak.

    Scans every candle after ``env.current_index`` in ``env.candle_window`` and finds
    the first candle with the highest high and the first with the lowest low. Then:

    * If the high comes strictly before the low, the trade is a short that exits at
      the low candle, with ``pnl = entry - exit``.
    * Otherwise it is a long that exits at the high candle, with ``pnl = exit - entry``.

    Entry and exit are both priced at the candle's ``open``. The trade is appended to
    ``env.trade_history`` and ``env.current_index`` jumps to the exit candle. This uses
    future candles of the window (look-ahead).

    If ``env.current_index`` is already at or past the last candle, it is clamped to
    the last index and nothing is recorded.
    """
    window = env.candle_window
    last = len(window) - 1
    start = env.current_index
    if start >= last:
        env.current_index = last
        return

    top, bottom = -float('inf'), float('inf')
    top_at = bottom_at = start
    for j in range(start + 1, len(window)):
        high, low = window[j]["high"], window[j]["low"]
        if high > top:
            top, top_at = high, j
        if low < bottom:
            bottom, bottom_at = low, j

    is_short = top_at < bottom_at
    exit_at = bottom_at if is_short else top_at
    entry_price = window[start]["open"]
    exit_price = window[exit_at]["open"]
    pnl = entry_price - exit_price if is_short else exit_price - entry_price
    env.trade_history.append({
        "entry_index": start,
        "entry_price": entry_price,
        "exit_index": exit_at,
        "exit_price": exit_price,
        "pnl": pnl,
    })
    env.current_index = exit_at
# --- Live HTML Chart Update ---
def update_live_html(candles, trade_history, epoch, output_path="live_chart.html"):
    """
    Render the trading chart to PNG and write a self-refreshing HTML page embedding it.

    The chart (drawn by ``update_live_chart``) is titled with the epoch number and
    the summed ``pnl`` of ``trade_history``. The PNG is inlined as a base64 data URI,
    and the page reloads itself every 10 seconds.

    Args:
        candles: candle dicts to plot.
        trade_history: trade dicts, each with a "pnl" entry.
        epoch: epoch number shown in the title.
        output_path: destination HTML file (default ``live_chart.html``).
    """
    from io import BytesIO
    import base64

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        update_live_chart(ax, candles, trade_history)
        epoch_pnl = sum(trade["pnl"] for trade in trade_history)
        ax.set_title(f"Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}")
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every open figure alive; always release this one.
        plt.close(fig)
    image_base64 = base64.b64encode(buf.getvalue()).decode('ascii')
    html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta http-equiv="refresh" content="10">
<title>Live Trading Chart - Epoch {epoch}</title>
</head>
<body>
<h1>Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}</h1>
<img src="data:image/png;base64,{image_base64}" alt="Live trading chart">
</body>
</html>
"""
    # Write then swap, so an auto-refreshing browser never reads a half-written page.
    tmp_path = output_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write(html_content)
    os.replace(tmp_path, output_path)
    print(f"Updated {output_path}.")
# --- Chart Drawing Helpers ---
def _draw_trades(ax, times, trade_history):
    """Plot each trade as a green ^ at entry, a red v at exit, joined by a dotted line."""
    count = len(times)
    entry_label, exit_label = "Entry", "Exit"
    for trade in trade_history:
        entry_idx, exit_idx = trade["entry_index"], trade["exit_index"]
        # Indices refer to positions in the plotted candles; skip any that do not fit.
        if not (0 <= entry_idx < count and 0 <= exit_idx < count):
            continue
        t_in, t_out = times[entry_idx], times[exit_idx]
        p_in, p_out = trade["entry_price"], trade["exit_price"]
        ax.plot([t_in, t_out], [p_in, p_out], linestyle=":", color="grey", linewidth=1)
        ax.scatter([t_in], [p_in], marker="^", color="green", label=entry_label, zorder=3)
        ax.scatter([t_out], [p_out], marker="v", color="red", label=exit_label, zorder=3)
        # Labels starting with "_" are hidden from the legend; label each kind once.
        entry_label = exit_label = "_nolegend_"


def update_live_chart(ax, candles, trade_history):
    """
    Draw close price, trade markers and volume bars onto ``ax``.

    Candle timestamps go through ``convert_timestamp`` and are shown as HH:MM:SS.

    Each trade is marked with a green "^" at its entry and a red "v" at its exit,
    joined by a dotted line. ``entry_index``/``exit_index`` are positions in
    ``candles``. Volume is plotted as grey bars on a secondary y-axis.

    The function is safe to call repeatedly on the same ``ax``: the volume axis added
    by a previous call is removed first, instead of a new twin axis being stacked on
    every refresh.
    """
    previous_volume_ax = getattr(ax, "_volume_twin_ax", None)
    if previous_volume_ax is not None:
        previous_volume_ax.remove()
        ax._volume_twin_ax = None
    ax.clear()
    times = [convert_timestamp(candle["timestamp"]) for candle in candles]
    close_prices = [candle["close"] for candle in candles]
    ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
    _draw_trades(ax, times, trade_history)
    ax.set_xlabel("Time")
    ax.set_ylabel("Price")
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
    # Plot volume on a secondary axis, remembered so the next call can remove it.
    ax2 = ax.twinx()
    ax._volume_twin_ax = ax2
    volumes = [candle["volume"] for candle in candles]
    # Bar width is in matplotlib date units (days): roughly 80% of the candle spacing.
    if len(times) > 1:
        times_num = mdates.date2num(times)
        bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8
    else:
        bar_width = 0.01
    ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume")
    ax2.set_ylabel("Volume")
    # Combine legends from both axes.
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines + lines2, labels + labels2)
    ax.grid(True)
    ax.get_figure().autofmt_xdate()
# --- Simulation of Trades for Visualization ---
def simulate_trades(model, env, device, args):
    """
    Replay a fresh random window of ``env`` with the model to fill ``env.trade_history``.

    At each candle the model predicts the next high and low. If the implied upside
    (``pred_high - open``) or downside (``open - pred_low``) exceeds
    ``args.threshold``, the larger of the two picks the action passed to ``env.step``
    (2 = buy, 0 = sell). Otherwise ``manual_trade`` picks a trade by scanning the rest
    of the window.

    The model runs in eval mode under ``torch.no_grad()``, so dropout does not
    randomise the replay and no autograd graph is built. Its previous train/eval mode
    is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        env.reset()
        with torch.no_grad():
            while env.current_index < len(env.candle_window) - 1:
                i = env.current_index
                state = env.get_state(i)
                current_open = env.candle_window[i]["open"]
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                timeframe_ids = torch.arange(state.shape[0]).to(device)
                pred_high, pred_low = model(state_tensor, timeframe_ids)
                upside = pred_high.item() - current_open
                downside = current_open - pred_low.item()
                if upside > args.threshold or downside > args.threshold:
                    env.step(2 if upside >= downside else 0)  # 2 = BUY, 0 = SELL
                else:
                    # No significant signal; fall back to the hindsight trade.
                    manual_trade(env)
    finally:
        model.train(was_training)
# --- Backtest Environment with Sliding Window and Order Info ---
class BacktestEnvironment:
    """
    Sliding-window backtester over cached candles for several timeframes.

    ``candles_dict`` maps a timeframe name to a list of candle dicts. The keys read
    here and in the feature helpers are "timestamp", "open", "high", "low", "close"
    and "volume". Each ``reset`` picks a random contiguous window of ``window_size``
    base-timeframe candles.

    ``step`` advances one candle at a time and manages a single long position:
    action 2 opens it, action 0 closes it, and any other action does nothing. Fills
    happen at the next candle's open.
    """

    def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
        self.candles_dict = candles_dict
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.full_candles = candles_dict[base_tf]
        available = len(self.full_candles)
        if available == 0:
            raise ValueError(f"No candles available for base timeframe {base_tf!r}.")
        if window_size is None:
            window_size = 100
        # Never request a window longer than the data: random.randint in reset()
        # raises ValueError when its upper bound is negative.
        self.window_size = max(1, min(window_size, available))
        self.reset()

    def reset(self):
        """Pick a new random window, rewind to its start, clear trades/position; return the first state."""
        self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
        self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size]
        self.current_index = 0
        self.trade_history = []
        self.position = None
        return self.get_state(self.current_index)

    def __len__(self):
        return self.window_size

    def get_order_features(self, index):
        """
        Order-channel features for window position ``index``.

        Returns all zeros when flat. With an open position, returns
        ``[1.0, (open - entry_price) / open, 0, ...]``, padded to ``FEATURES_PER_CHANNEL``.
        """
        candle = self.candle_window[index]
        if self.position is None:
            return [0.0] * FEATURES_PER_CHANNEL
        diff = (candle["open"] - self.position["entry_price"]) / candle["open"]
        return [1.0, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2)

    def get_state(self, index):
        """
        Build the float32 state matrix for window position ``index``.

        The rows are, in order:

        * one feature row per timeframe, in ``self.timeframes`` order;
        * one order-info row;
        * ``NUM_INDICATORS`` zero rows reserved for indicators.
        """
        state_features = []
        base_ts = self.candle_window[index]["timestamp"]
        for tf in self.timeframes:
            if tf == self.base_tf:
                # Index into the full history so the SMA columns average real
                # preceding candles (a one-candle list made them equal close/volume).
                features = get_features_for_tf(self.full_candles, self.start_index + index)
            else:
                # NOTE(review): this picks the last candle with timestamp <= base_ts.
                # If timestamps mark candle *open* times, that candle's high/low/close
                # may include data after base_ts (look-ahead). Confirm the convention.
                aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
                features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
            state_features.append(features)
        state_features.append(self.get_order_features(index))
        state_features.extend([0.0] * FEATURES_PER_CHANNEL for _ in range(NUM_INDICATORS))
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """
        Apply ``action`` and advance one candle.

        Returns:
            ``(current_state, reward, next_state, done, next_high, next_low)``. At the
            last candle it returns ``(state, 0.0, None, True, 0.0, 0.0)`` without moving.
            ``reward`` is the realised pnl when a long is closed, otherwise 0.0.
        """
        base = self.candle_window
        if self.current_index >= len(base) - 1:
            current_state = self.get_state(self.current_index)
            return current_state, 0.0, None, True, 0.0, 0.0
        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_candle = base[next_index]
        reward = 0.0
        if self.position is None:
            if action == 2:  # BUY (open long), filled at the next candle's open
                self.position = {"entry_price": next_candle["open"], "entry_index": next_index}
        elif action == 0:  # SELL (close long / exit trade)
            exit_price = next_candle["open"]
            reward = exit_price - self.position["entry_price"]
            self.trade_history.append({
                "entry_index": self.position["entry_index"],
                "entry_price": self.position["entry_price"],
                "exit_index": next_index,
                "exit_price": exit_price,
                "pnl": reward,
            })
            self.position = None
        self.current_index = next_index
        # Built after the position update so the order row reflects this action.
        next_state = self.get_state(next_index)
        done = self.current_index >= len(base) - 1
        return current_state, reward, next_state, done, next_candle["high"], next_candle["low"]
# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
    """
    Train ``model`` to predict the next candle's high and low, one sample at a time.

    For each epoch from ``start_epoch`` up to ``args.epochs``, a new random window is
    drawn (``env.reset()``). For every candle except the last, the loss is built from
    ``signal = max(pred_high - open, open - pred_low)`` as:

        |pred_high - next_high| + |pred_low - next_low|
        - lambda_trade * signal
        + penalty_noaction * max(threshold - signal, 0)

    It is minimised with gradient clipping at norm 1.0.

    After each epoch:

    * the LR scheduler steps;
    * a checkpoint is saved with the mean per-step loss;
    * trades are simulated on a fresh window;
    * the live HTML chart is refreshed.

    Raises:
        ValueError: if the environment window holds fewer than 2 candles.
    """
    lambda_trade = args.lambda_trade
    steps = len(env) - 1
    if steps < 1:
        raise ValueError("Training needs an environment window of at least 2 candles.")
    for epoch in range(start_epoch, args.epochs):
        env.reset()
        # The end-of-epoch simulation runs the model for visualisation; make sure
        # dropout etc. are active again for training.
        model.train()
        loss_accum = 0.0
        for i in range(steps):
            state = env.get_state(i)
            current_open = env.candle_window[i]["open"]
            actual_high = env.candle_window[i + 1]["high"]
            actual_low = env.candle_window[i + 1]["low"]
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            timeframe_ids = torch.arange(state.shape[0]).to(device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)
            L_pred = (torch.abs(pred_high - torch.tensor(actual_high, device=device))
                      + torch.abs(pred_low - torch.tensor(actual_low, device=device)))
            # Largest move the prediction implies, long or short.
            signal_strength = torch.max(pred_high - current_open, current_open - pred_low)
            L_trade = -signal_strength
            penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
            loss = L_pred + lambda_trade * L_trade + penalty_term
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            loss_accum += loss.item()
        scheduler.step()
        epoch_loss = loss_accum / steps
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
        # Rank and label checkpoints with the same mean loss that is printed above.
        save_checkpoint(model, optimizer, epoch, epoch_loss)
        simulate_trades(model, env, device, args)
        update_live_html(env.candle_window, env.trade_history, epoch+1)
# --- Live Plotting Functions (For Live Mode) ---
def live_preview_loop(candles, env):
    """
    Redraw an interactive matplotlib chart of ``env``'s window and trades once per second.

    The window is re-read from ``env`` on every tick, because ``env.reset()`` rebinds
    ``env.candle_window`` and ``env.trade_history`` indices refer to the current
    window. ``candles`` is only used when ``env`` is None. The loop ends when the
    figure window is closed.

    NOTE(review): main() runs this in a background thread; most matplotlib GUI
    backends expect to be driven from the main thread, so confirm with the backend
    in use.
    """
    plt.ion()
    fig, ax = plt.subplots(figsize=(12, 6))
    while plt.fignum_exists(fig.number):
        window = env.candle_window if env is not None else candles
        history = env.trade_history if env is not None else []
        update_live_chart(ax, window, history)
        fig.canvas.draw_idle()
        plt.pause(1)
# --- Argument Parsing ---
def parse_args():
    """
    Parse the command line for the trainer / live simulator.

    Options: ``--mode`` (train | live | inference, default train), ``--epochs``,
    ``--lr``, ``--threshold``, ``--lambda_trade``, ``--penalty_noaction`` and the
    ``--start_fresh`` flag.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
    # (flag, type, default, help) for each numeric option, registered in this order.
    numeric_options = (
        ('--epochs', int, 1000, None),
        ('--lr', float, 3e-4, None),
        ('--threshold', float, 0.005,
         "Minimum predicted move to trigger trade (used in loss; model may override with manual trade)."),
        ('--lambda_trade', float, 1.0, "Weight for trade surrogate loss."),
        ('--penalty_noaction', float, 10.0, "Penalty if no action is taken (used in loss)."),
    )
    for flag, value_type, default_value, help_text in numeric_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
    return parser.parse_args()
def random_action():
    """
    Return a uniformly random action id from {0, 1, 2}.

    ``BacktestEnvironment.step`` treats 2 as "open long" and 0 as "close long"; 1
    matches neither branch there, so it behaves as a no-op.
    """
    return random.randrange(3)
# --- Main Function ---
def _make_optimizer(model, args, start_epoch):
    """Create the AdamW optimizer and cosine-annealing LR schedule for a training run."""
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
    # Keep T_max positive even when resuming at or beyond --epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, args.epochs - start_epoch)
    )
    return optimizer, scheduler


def _clear_checkpoints():
    """Delete every regular file in the last/best checkpoint directories."""
    for chk_dir in (LAST_DIR, BEST_DIR):
        for name in os.listdir(chk_dir):
            path = os.path.join(chk_dir, name)
            if os.path.isfile(path):
                os.remove(path)


def _run_training(args, model, device, timeframes, total_channels):
    """Train on cached candles, resuming from the best checkpoint unless --start_fresh."""
    candles_dict = load_candles_cache(CACHE_FILE)
    if not candles_dict:
        print("No historical candle data available for backtesting.")
        return
    env = BacktestEnvironment(candles_dict, "1m", timeframes, window_size=100)
    start_epoch = 0
    checkpoint = None
    if args.start_fresh:
        print("Starting training from scratch as requested.")
    else:
        checkpoint = load_best_checkpoint(model)
        if checkpoint is not None:
            start_epoch = checkpoint.get("epoch", 0) + 1
            print(f"Resuming training from epoch {start_epoch}.")
        else:
            print("No checkpoint found. Starting training from scratch.")
    optimizer, scheduler = _make_optimizer(model, args, start_epoch)
    if checkpoint is not None:
        optim_state = checkpoint.get("optimizer_state_dict", None)
        if optim_state is not None and "param_groups" in optim_state:
            try:
                optimizer.load_state_dict(optim_state)
                print("Loaded optimizer state from checkpoint.")
            except Exception as e:
                print("Failed to load optimizer state due to:", e)
                print("Deleting all checkpoints and starting fresh.")
                _clear_checkpoints()
                # Really start fresh: drop the loaded weights and the resumed epoch.
                model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
                start_epoch = 0
                optimizer, scheduler = _make_optimizer(model, args, start_epoch)
        else:
            print("No valid optimizer state found in checkpoint; using fresh optimizer state.")
    train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)


async def _run_live(args, model, device, timeframes):
    """Replay cached candles through the model once per second while a preview thread plots them."""
    load_best_checkpoint(model)
    candles_dict = load_candles_cache(CACHE_FILE)
    if not candles_dict:
        print("No cached candles available for live preview.")
        return
    env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
    # NOTE(review): matplotlib GUI backends generally expect the main thread.
    preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
    preview_thread.start()
    print("Starting live trading loop. (Using model, with manual override for HOLD actions.)")
    # Inference only: disable dropout; autograd is disabled per step below.
    model.eval()
    while True:
        with torch.no_grad():
            state = env.get_state(env.current_index)
            current_open = env.candle_window[env.current_index]["open"]
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            timeframe_ids = torch.arange(state.shape[0]).to(device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)
        upside = pred_high.item() - current_open
        downside = current_open - pred_low.item()
        if upside > args.threshold or downside > args.threshold:
            env.step(2 if upside >= downside else 0)  # 2 = BUY, 0 = SELL
        else:
            manual_trade(env)
        if env.current_index >= len(env.candle_window) - 1:
            print("Reached end of simulation window; resetting environment.")
            env.reset()
        await asyncio.sleep(1)


async def main():
    """
    Entry point: parse options, build the model on the best available device and dispatch on ``--mode``.

    * ``train``: train on cached candles, optionally resuming from the best checkpoint.
    * ``live``: replay cached candles through the model with a manual-trade fallback.
    * ``inference``: loads the best checkpoint; no inference logic is implemented yet.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    timeframes = ["1m", "5m", "15m", "1h", "1d"]
    # One channel per timeframe, one order-info channel, plus indicator channels.
    total_channels = NUM_TIMEFRAMES + 1 + NUM_INDICATORS
    model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
    if args.mode == 'train':
        _run_training(args, model, device, timeframes, total_channels)
    elif args.mode == 'live':
        await _run_live(args, model, device, timeframes)
    elif args.mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
        # Inference logic goes here.
    else:
        print("Invalid mode specified.")


if __name__ == "__main__":
    asyncio.run(main())