even better

Dobromir Popov 2025-02-04 22:09:13 +02:00
parent 375aebee88
commit 967363378b
2 changed files with 111 additions and 130 deletions

@@ -16,9 +16,8 @@ import torch.nn as nn
 import torch.optim as optim
 from datetime import datetime
 import matplotlib.pyplot as plt
-import ccxt.async_support as ccxt
-from torch.nn import TransformerEncoder, TransformerEncoderLayer
 import math
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
 from dotenv import load_dotenv
 load_dotenv()
@@ -30,9 +29,9 @@ os.makedirs(BEST_DIR, exist_ok=True)
 CACHE_FILE = "candles_cache.json"

 # --- Constants ---
 NUM_TIMEFRAMES = 5        # e.g., ["1m", "5m", "15m", "1h", "1d"]
 NUM_INDICATORS = 20       # e.g., 20 technical indicators
-FEATURES_PER_CHANNEL = 7  # e.g., H, L, O, C, Volume, SMA_close, SMA_volume
+FEATURES_PER_CHANNEL = 7  # e.g., [open, high, low, close, volume, sma_close, sma_volume]

 # --- Positional Encoding Module ---
 class PositionalEncoding(nn.Module):
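For reference, the per-channel layout implied by the new FEATURES_PER_CHANNEL comment (and by get_features_for_tf() further down) is a 7-element vector per candle. A minimal sketch with made-up candle values and stand-in SMAs:

    candle = {"open": 100.0, "high": 101.2, "low": 99.5, "close": 100.8, "volume": 350.0}
    sma_close, sma_volume = 100.4, 342.0   # 10-period SMAs, stand-in values
    features = [candle["open"], candle["high"], candle["low"], candle["close"],
                candle["volume"], sma_close, sma_volume]
    assert len(features) == 7              # == FEATURES_PER_CHANNEL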
@@ -53,7 +52,7 @@ class PositionalEncoding(nn.Module):
 class TradingModel(nn.Module):
     def __init__(self, num_channels, num_timeframes, hidden_dim=128):
         super().__init__()
-        # Create one branch per channel (each channel input has FEATURES_PER_CHANNEL features)
+        # One branch per channel
         self.channel_branches = nn.ModuleList([
             nn.Sequential(
                 nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@@ -62,7 +61,6 @@ class TradingModel(nn.Module):
                 nn.Dropout(0.1)
             ) for _ in range(num_channels)
         ])
-        # Embedding for channels 0..num_channels-1.
         self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
         self.pos_encoder = PositionalEncoding(hidden_dim)
         encoder_layers = TransformerEncoderLayer(
@@ -82,15 +80,14 @@ class TradingModel(nn.Module):
             nn.Linear(hidden_dim // 2, 1)
         )

     def forward(self, x, timeframe_ids):
-        # x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
+        # x: [batch_size, num_channels, FEATURES_PER_CHANNEL]
         batch_size, num_channels, _ = x.shape
         channel_outs = []
         for i in range(num_channels):
             channel_out = self.channel_branches[i](x[:, i, :])
             channel_outs.append(channel_out)
-        stacked = torch.stack(channel_outs, dim=1)  # shape: [batch, channels, hidden]
-        stacked = stacked.permute(1, 0, 2)  # shape: [channels, batch, hidden]
-        # Add embedding for each channel.
+        stacked = torch.stack(channel_outs, dim=1)  # [batch, channels, hidden]
+        stacked = stacked.permute(1, 0, 2)  # [channels, batch, hidden]
         tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
         stacked = stacked + tf_embeds
         src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
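A quick, self-contained illustration of the tensor shapes and the upper-triangular mask used in forward(). The batch size is arbitrary, and num_channels = 25 assumes the model is built with NUM_TIMEFRAMES + NUM_INDICATORS channels (the constructor call itself is not shown in this diff):

    import torch

    batch, channels, hidden = 4, 25, 128            # 25 = NUM_TIMEFRAMES + NUM_INDICATORS (assumed)
    stacked = torch.randn(channels, batch, hidden)  # shape after the permute(1, 0, 2) above
    src_mask = torch.triu(torch.ones(channels, channels), diagonal=1).bool()
    # src_mask[i, j] is True for j > i, so channel i cannot attend to later channels
    # when this mask is passed to the TransformerEncoder.
    print(stacked.shape, src_mask.shape)            # torch.Size([25, 4, 128]) torch.Size([25, 25])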
@@ -103,12 +100,12 @@ class TradingModel(nn.Module):
 def compute_sma(candles_list, index, period=10):
     start = max(0, index - period + 1)
     values = [candle["close"] for candle in candles_list[start:index+1]]
-    return sum(values) / len(values) if values else 0.0
+    return sum(values)/len(values) if values else 0.0

 def compute_sma_volume(candles_list, index, period=10):
     start = max(0, index - period + 1)
     values = [candle["volume"] for candle in candles_list[start:index+1]]
-    return sum(values) / len(values) if values else 0.0
+    return sum(values)/len(values) if values else 0.0

 def get_aligned_candle_with_index(candles_list, target_ts):
     best_idx = 0
@@ -123,7 +120,7 @@ def get_features_for_tf(candles_list, index, period=10):
     candle = candles_list[index]
     f_open = candle["open"]
     f_high = candle["high"]
     f_low = candle["low"]
     f_close = candle["close"]
     f_volume = candle["volume"]
     sma_close = compute_sma(candles_list, index, period)
@@ -154,7 +151,7 @@ def maintain_checkpoint_directory(directory, max_files=10):
     if len(files) > max_files:
         full_paths = [os.path.join(directory, f) for f in files]
         full_paths.sort(key=lambda x: os.path.getmtime(x))
-        for f in full_paths[: len(files) - max_files]:
+        for f in full_paths[:len(files)-max_files]:
             os.remove(f)

 def get_best_models(directory):
@@ -162,7 +159,6 @@ def get_best_models(directory):
     for file in os.listdir(directory):
         parts = file.split("_")
         try:
-            # parts[1] is the recorded loss
            loss = float(parts[1])
            best_files.append((loss, file))
        except Exception:
@@ -174,10 +170,10 @@ def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=B
     last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
     last_path = os.path.join(last_dir, last_filename)
     torch.save({
         "epoch": epoch,
         "loss": loss,
         "model_state_dict": model.state_dict(),
         "optimizer_state_dict": optimizer.state_dict()
     }, last_path)
     maintain_checkpoint_directory(last_dir, max_files=10)
     best_models = get_best_models(best_dir)
@@ -215,8 +211,8 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
 # --- Live HTML Chart Update ---
 def update_live_html(candles, trade_history, epoch):
     """
-    Generate a chart image with buy/sell markers and a dotted line between open/close positions,
-    then embed it in a simple HTML page that auto-refreshes every 10 seconds.
+    Generate a chart image with buy/sell markers and dotted lines between entry and exit,
+    then embed it in an auto-refreshing HTML page.
     """
     from io import BytesIO
     import base64
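The HTML template itself lies outside this hunk. The old docstring says the page auto-refreshes every 10 seconds, which is usually done with a meta refresh tag; a minimal sketch of that kind of page (variable names here are placeholders, not necessarily the ones used in the file):

    img_b64 = "..."  # base64-encoded PNG of the rendered chart
    html_content = f"""<html>
      <head><meta http-equiv="refresh" content="10"></head>
      <body><img src="data:image/png;base64,{img_b64}"/></body>
    </html>"""
    with open("live_chart.html", "w") as f:
        f.write(html_content)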
@@ -266,10 +262,10 @@ def update_live_html(candles, trade_history, epoch):
         f.write(html_content)
     print("Updated live_chart.html.")

-# --- Chart Drawing Helpers (used by both live preview and HTML update) ---
+# --- Chart Drawing Helpers ---
 def update_live_chart(ax, candles, trade_history):
     """
-    Plot the chart with close price, buy/sell markers, and dotted lines joining entry/exit.
+    Draw the price chart with close prices and mark BUY (green) and SELL (red) actions.
     """
     ax.clear()
     close_prices = [candle["close"] for candle in candles]
@@ -298,39 +294,44 @@ def update_live_chart(ax, candles, trade_history):
     ax.legend()
     ax.grid(True)

-# --- Forced Action & Optimal Hint Helpers ---
-def get_forced_action(env):
-    """
-    When simulating streaming data, we force a trade at strategic moments:
-      - At the very first step: force BUY.
-      - At the penultimate step: if a position is open, force SELL.
-      - Otherwise, default to HOLD.
-    (The environment will also apply a penalty if the chosen action does not match the optimal hint.)
-    """
-    total = len(env)
-    if env.current_index == 0:
-        return 2  # BUY
-    elif env.current_index >= total - 2:
-        if env.position is not None:
-            return 0  # SELL
-        else:
-            return 1  # HOLD
-    else:
-        return 1  # HOLD
+# --- Simulation of Trades for Visualization ---
+def simulate_trades(model, env, device, args):
+    """
+    Run a complete simulation on the current sliding window using a decision rule based on model outputs.
+    This simulation (which updates env.trade_history) is used only for visualization.
+    """
+    env.reset()  # resets the sliding window and index
+    while True:
+        i = env.current_index
+        state = env.get_state(i)
+        current_open = env.candle_window[i]["open"]
+        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
+        timeframe_ids = torch.arange(state.shape[0]).to(device)
+        pred_high, pred_low = model(state_tensor, timeframe_ids)
+        pred_high = pred_high.item()
+        pred_low = pred_low.item()
+        # Decision rule: if upward move larger than downward and above threshold, BUY; if downward is larger, SELL; else HOLD.
+        if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold:
+            action = 2  # BUY
+        elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold:
+            action = 0  # SELL
+        else:
+            action = 1  # HOLD
+        _, _, _, done, _, _ = env.step(action)
+        if done:
+            break

-# --- Backtest Environment with Sliding Window and Hints ---
+# --- Backtest Environment with Sliding Window ---
 class BacktestEnvironment:
     def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
-        self.candles_dict = candles_dict  # full dictionary of timeframe candles
+        self.candles_dict = candles_dict  # full candles dict for all timeframes
         self.base_tf = base_tf
         self.timeframes = timeframes
-        # Use maximum allowed candles for the base timeframe.
         self.full_candles = candles_dict[base_tf]
-        # Determine sliding window size:
         if window_size is None:
             window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles)
         self.window_size = window_size
-        self.hint_penalty = 0.001  # Penalty coefficient (multiplied by open price)
+        self.hint_penalty = 0.001  # not used in the revised loss below
         self.reset()

     def reset(self):
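The new decision rule in simulate_trades() (and again in live mode below) reduces to comparing the predicted upward and downward moves against args.threshold. A standalone sketch with a worked numeric example (prices are made up):

    def decide(pred_high, pred_low, current_open, threshold):
        """Return 2 (BUY), 0 (SELL) or 1 (HOLD), following the rule used in simulate_trades()."""
        up = pred_high - current_open      # predicted upward move
        down = current_open - pred_low     # predicted downward move
        if up >= down and up > threshold:
            return 2
        if down > up and down > threshold:
            return 0
        return 1

    # open = 100.0, predicted high = 100.8, predicted low = 99.7:
    # up = 0.8 beats down = 0.3 and exceeds the threshold -> BUY
    print(decide(100.8, 99.7, 100.0, 0.005))  # 2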
@@ -346,52 +347,26 @@ class BacktestEnvironment:
         return self.window_size

     def get_state(self, index):
-        """
-        Build state features by taking the candle at the current index for the base timeframe
-        (from the sliding window) and aligning candles for other timeframes.
-        Then append zeros for technical indicators.
-        """
         state_features = []
         base_ts = self.candle_window[index]["timestamp"]
         for tf in self.timeframes:
             if tf == self.base_tf:
-                # For base timeframe, use the sliding window candle.
                 candle = self.candle_window[index]
-                features = get_features_for_tf([candle], 0)  # List of one element
+                features = get_features_for_tf([candle], 0)
             else:
                 aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
                 features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
             state_features.append(features)
         for _ in range(NUM_INDICATORS):
-            state_features.append([0.0] * FEATURES_PER_CHANNEL)
+            state_features.append([0.0]*FEATURES_PER_CHANNEL)
         return np.array(state_features, dtype=np.float32)
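After this change, get_state() still returns one row per channel: NUM_TIMEFRAMES rows of real features followed by NUM_INDICATORS zero rows, each FEATURES_PER_CHANNEL wide. A shape check with dummy rows:

    import numpy as np

    NUM_TIMEFRAMES, NUM_INDICATORS, FEATURES_PER_CHANNEL = 5, 20, 7
    tf_rows = [np.random.rand(FEATURES_PER_CHANNEL) for _ in range(NUM_TIMEFRAMES)]  # stand-ins for get_features_for_tf()
    ind_rows = [[0.0] * FEATURES_PER_CHANNEL for _ in range(NUM_INDICATORS)]         # indicator channels are zero-padded
    state = np.array(tf_rows + ind_rows, dtype=np.float32)
    print(state.shape)   # (25, 7): one row per channel fed into TradingModel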
-    def compute_optimal_hint(self, horizon=10, threshold=0.005):
-        """
-        Using a lookahead window from the sliding window (future candles)
-        determine an optimal action hint:
-          2: BUY if price is expected to rise at least by threshold.
-          0: SELL if expected to drop by threshold.
-          1: HOLD otherwise.
-        """
-        base = self.candle_window
-        if self.current_index >= len(base) - 1:
-            return 1  # Hold
-        current_candle = base[self.current_index]
-        open_price = current_candle["open"]
-        future_slice = base[self.current_index+1: min(self.current_index+1+horizon, len(base))]
-        if not future_slice:
-            return 1
-        max_future = max(candle["high"] for candle in future_slice)
-        min_future = min(candle["low"] for candle in future_slice)
-        if (max_future - open_price) / open_price >= threshold:
-            return 2  # BUY
-        elif (open_price - min_future) / open_price >= threshold:
-            return 0  # SELL
-        else:
-            return 1  # HOLD

     def step(self, action):
-        """
-        Discrete simulation step.
-          - Action: 0 (SELL), 1 (HOLD), 2 (BUY).
-          - Trades are recorded when a BUY is followed by a SELL.
-        """
         base = self.candle_window
         if self.current_index >= len(base) - 1:
             current_state = self.get_state(self.current_index)
@@ -403,13 +378,12 @@ class BacktestEnvironment:
         next_candle = base[next_index]
         reward = 0.0

-        # Trade logic (0: SELL, 1: HOLD, 2: BUY)
+        # Simple trading logic (only one position allowed at a time)
         if self.position is None:
-            if action == 2:  # BUY: enter at next candle's open.
-                entry_price = next_candle["open"]
-                self.position = {"entry_price": entry_price, "entry_index": self.current_index}
+            if action == 2:  # BUY signal: enter at next open.
+                self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
         else:
-            if action == 0:  # SELL: exit at next candle's open.
+            if action == 0:  # SELL signal: exit at next open.
                 exit_price = next_candle["open"]
                 reward = exit_price - self.position["entry_price"]
                 trade = {
@@ -426,49 +400,49 @@ class BacktestEnvironment:
         done = (self.current_index >= len(base) - 1)
         actual_high = next_candle["high"]
         actual_low = next_candle["low"]
-        # Compute optimal action hint and apply a penalty if action deviates.
-        optimal_hint = self.compute_optimal_hint(horizon=10, threshold=0.005)
-        if action != optimal_hint:
-            reward -= self.hint_penalty * next_candle["open"]
         return current_state, reward, next_state, done, actual_high, actual_low
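The reward bookkeeping in step() is unchanged apart from dropping the hint penalty: a BUY opens a position at the next candle's open, and a later SELL closes it at that candle's open, crediting the difference as the reward. A tiny worked example with made-up prices:

    entry_open = 101.50   # next candle's open when the BUY action is taken
    exit_open = 103.25    # next candle's open when the SELL action closes the position
    reward = exit_open - entry_open
    print(reward)         # 1.75, the reward credited when the SELL closes the trade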
 # --- Enhanced Training Loop ---
 def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
+    # Weighting factor for trade surrogate loss.
+    lambda_trade = 1.0
     for epoch in range(start_epoch, args.epochs):
-        state = env.reset()  # Reset sliding window for each epoch.
-        total_loss = 0.0
-        model.train()
-        while True:
-            # Use forced-action policy for trading (guaranteeing at least one trade per episode)
-            action = get_forced_action(env)
+        env.reset()
+        loss_accum = 0.0
+        steps = len(env) - 1  # we use pairs of consecutive candles
+        for i in range(steps):
+            state = env.get_state(i)
+            current_open = env.candle_window[i]["open"]
+            # Next candle's actual values serve as targets.
+            actual_high = env.candle_window[i+1]["high"]
+            actual_low = env.candle_window[i+1]["low"]
             state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
             timeframe_ids = torch.arange(state.shape[0]).to(device)
             pred_high, pred_low = model(state_tensor, timeframe_ids)
-            # Use our forced action in the environment step.
-            _, reward, next_state, done, actual_high, actual_low = env.step(action)
-            target_high = torch.FloatTensor([actual_high]).to(device)
-            target_low = torch.FloatTensor([actual_low]).to(device)
-            high_loss = torch.abs(pred_high - target_high) * 2
-            low_loss = torch.abs(pred_low - target_low) * 2
-            loss = (high_loss + low_loss).mean()
+            # Compute prediction loss (L1)
+            L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
+                     torch.abs(pred_low - torch.tensor(actual_low, device=device))
+            # Compute surrogate profit (differentiable estimate)
+            profit_buy = pred_high - current_open    # potential long gain
+            profit_sell = current_open - pred_low    # potential short gain
+            # Here we reward a higher potential move by subtracting it.
+            L_trade = - torch.max(profit_buy, profit_sell)
+            loss = L_pred + lambda_trade * L_trade
             optimizer.zero_grad()
             loss.backward()
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
-            total_loss += loss.item()
-            if done:
-                break
-            state = next_state
+            loss_accum += loss.item()
         scheduler.step()
-        epoch_loss = total_loss / len(env)
+        epoch_loss = loss_accum / steps
         print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
-        save_checkpoint(model, optimizer, epoch, total_loss)
-        # Update live HTML chart to display the current sliding window
+        save_checkpoint(model, optimizer, epoch, loss_accum)
+        # Update the trade simulation (for visualization) using the current model on the same window.
+        simulate_trades(model, env, device, args)
         update_live_html(env.candle_window, env.trade_history, epoch+1)
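To make the new loss concrete, here is a single step with scalar stand-ins: L_pred penalizes the absolute errors of the high/low predictions, while L_trade rewards (by subtracting) the larger of the two potential moves implied by the predictions. All values below are made up:

    import torch

    pred_high = torch.tensor(105.0, requires_grad=True)
    pred_low = torch.tensor(99.0, requires_grad=True)
    actual_high, actual_low, current_open, lambda_trade = 104.0, 98.5, 100.0, 1.0

    L_pred = torch.abs(pred_high - actual_high) + torch.abs(pred_low - actual_low)  # 1.0 + 0.5
    L_trade = -torch.max(pred_high - current_open, current_open - pred_low)         # -max(5.0, 1.0)
    loss = L_pred + lambda_trade * L_trade                                          # 1.5 - 5.0 = -3.5
    loss.backward()  # gradients flow through both the accuracy and surrogate-profit terms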
-# --- Live Plotting Functions (For live mode) ---
+# --- Live Plotting Functions (For Live Mode) ---
 def live_preview_loop(candles, env):
     plt.ion()
     fig, ax = plt.subplots(figsize=(12, 6))
@@ -480,21 +454,17 @@ def live_preview_loop(candles, env):
 # --- Argument Parsing ---
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--mode', choices=['train','live','inference'], default='train')
+    parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
     parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--lr', type=float, default=3e-4)
-    parser.add_argument('--threshold', type=float, default=0.005)
-    # If set, training starts from scratch (ignoring saved checkpoints)
-    parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch.')
+    parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.")
+    parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for the trade surrogate loss.")
+    parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
     return parser.parse_args()

-def random_action():
-    return random.randint(0, 2)

 # --- Main Function ---
 async def main():
     args = parse_args()
-    # Use GPU if available; else CPU
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print("Using device:", device)
     timeframes = ["1m", "5m", "15m", "1h", "1d"]
@@ -508,9 +478,8 @@ async def main():
         print("No historical candle data available for backtesting.")
         return
     base_tf = "1m"
-    # Create the environment with a sliding window (simulate streaming data)
+    # Use a sliding window of up to 100 candles (if available)
     env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
     start_epoch = 0
     checkpoint = None
     if not args.start_fresh:
@@ -522,7 +491,6 @@ async def main():
             print("No checkpoint found. Starting training from scratch.")
     else:
         print("Starting training from scratch as requested.")
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
     if checkpoint is not None:
@@ -543,18 +511,31 @@ async def main():
         env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
         preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
         preview_thread.start()
-        print("Starting live trading loop. (Using forced-action policy for simulation.)")
+        print("Starting live trading loop. (Using model-based decision rule.)")
         while True:
-            action = get_forced_action(env)
-            state, reward, next_state, done, _, _ = env.step(action)
+            # In live mode, we use the simulation decision rule.
+            state = env.get_state(env.current_index)
+            current_open = env.candle_window[env.current_index]["open"]
+            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
+            timeframe_ids = torch.arange(state.shape[0]).to(device)
+            pred_high, pred_low = model(state_tensor, timeframe_ids)
+            pred_high = pred_high.item()
+            pred_low = pred_low.item()
+            if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold:
+                action = 2
+            elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold:
+                action = 0
+            else:
+                action = 1
+            _, _, _, done, _, _ = env.step(action)
             if done:
-                print("Reached end of simulation window, resetting environment.")
-                state = env.reset()
+                print("Reached end of simulation window; resetting environment.")
+                env.reset()
             await asyncio.sleep(1)
     elif args.mode == 'inference':
         load_best_checkpoint(model)
         print("Running inference...")
-        # Apply a similar (or learned) policy as needed.
+        # Inference logic can use a similar decision rule as in live mode.
     else:
         print("Invalid mode specified.")

File diff suppressed because one or more lines are too long