added volume. better training

This commit is contained in:
Dobromir Popov 2025-02-04 22:33:44 +02:00
parent 75c4d6602a
commit dc5df52292
2 changed files with 119 additions and 68 deletions

View File

@ -44,10 +44,8 @@ CACHE_FILE = "candles_cache.json"
# --- Constants --- # --- Constants ---
NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"] NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20 # e.g., 20 technical indicators NUM_INDICATORS = 20 # e.g., 20 technical indicators
# Each channel input will have 7 features. FEATURES_PER_CHANNEL = 7 # Each channel input will have 7 features.
FEATURES_PER_CHANNEL = 7 ORDER_CHANNELS = 1 # One extra channel for order info.
# We add one extra channel for order information.
ORDER_CHANNELS = 1
# --- Positional Encoding Module --- # --- Positional Encoding Module ---
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):
@ -68,7 +66,6 @@ class PositionalEncoding(nn.Module):
class TradingModel(nn.Module): class TradingModel(nn.Module):
def __init__(self, num_channels, num_timeframes, hidden_dim=128): def __init__(self, num_channels, num_timeframes, hidden_dim=128):
super().__init__() super().__init__()
# Create one branch per channel.
self.channel_branches = nn.ModuleList([ self.channel_branches = nn.ModuleList([
nn.Sequential( nn.Sequential(
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim), nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@ -77,13 +74,11 @@ class TradingModel(nn.Module):
nn.Dropout(0.1) nn.Dropout(0.1)
) for _ in range(num_channels) ) for _ in range(num_channels)
]) ])
# Embedding for channels 0..num_channels-1.
self.timeframe_embed = nn.Embedding(num_channels, hidden_dim) self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
self.pos_encoder = PositionalEncoding(hidden_dim) self.pos_encoder = PositionalEncoding(hidden_dim)
# Set batch_first=True to avoid the nested tensor warning.
encoder_layers = TransformerEncoderLayer( encoder_layers = TransformerEncoderLayer(
d_model=hidden_dim, nhead=4, dim_feedforward=512, d_model=hidden_dim, nhead=4, dim_feedforward=512,
dropout=0.1, activation='gelu', batch_first=True dropout=0.1, activation='gelu', batch_first=True # Use batch_first to avoid nested tensor warning.
) )
self.transformer = TransformerEncoder(encoder_layers, num_layers=2) self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
self.attn_pool = nn.Linear(hidden_dim, 1) self.attn_pool = nn.Linear(hidden_dim, 1)
@ -104,11 +99,9 @@ class TradingModel(nn.Module):
for i in range(num_channels): for i in range(num_channels):
channel_out = self.channel_branches[i](x[:, i, :]) channel_out = self.channel_branches[i](x[:, i, :])
channel_outs.append(channel_out) channel_outs.append(channel_out)
stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden] stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden]
# With batch_first=True, the expected input is [batch, seq_len, hidden]
tf_embeds = self.timeframe_embed(timeframe_ids) # shape: [num_channels, hidden] tf_embeds = self.timeframe_embed(timeframe_ids) # shape: [num_channels, hidden]
# Expand tf_embeds to match the batch dimension. stacked = stacked + tf_embeds.unsqueeze(0) # broadcast along batch dimension.
stacked = stacked + tf_embeds.unsqueeze(0)
transformer_out = self.transformer(stacked) transformer_out = self.transformer(stacked)
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1) attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
aggregated = (transformer_out * attn_weights).sum(dim=1) aggregated = (transformer_out * attn_weights).sum(dim=1)
@ -225,32 +218,78 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
checkpoint = torch.load(path) checkpoint = torch.load(path)
old_state = checkpoint["model_state_dict"] old_state = checkpoint["model_state_dict"]
new_state = model.state_dict() new_state = model.state_dict()
# Fix the size mismatch for timeframe_embed.weight.
if "timeframe_embed.weight" in old_state: if "timeframe_embed.weight" in old_state:
old_embed = old_state["timeframe_embed.weight"] old_embed = old_state["timeframe_embed.weight"]
new_embed = new_state["timeframe_embed.weight"] new_embed = new_state["timeframe_embed.weight"]
if old_embed.shape[0] < new_embed.shape[0]: if old_embed.shape[0] < new_embed.shape[0]:
new_embed[:old_embed.shape[0]] = old_embed new_embed[:old_embed.shape[0]] = old_embed
old_state["timeframe_embed.weight"] = new_embed old_state["timeframe_embed.weight"] = new_embed
# For channel_branches, missing keys are handled by strict=False.
model.load_state_dict(old_state, strict=False) model.load_state_dict(old_state, strict=False)
return checkpoint return checkpoint
# --- Function for Manual Trade Override ---
def manual_trade(env):
"""
When no sufficient action is taken by the model, manually decide the trade.
Find the maximum high and minimum low in the remaining window.
If maximum occurs before minimum, we short; otherwise we long.
The trade is closed at the candle where the chosen extreme occurs.
"""
current_index = env.current_index
if current_index >= len(env.candle_window) - 1:
env.current_index = len(env.candle_window) - 1
return
max_val = -float('inf')
min_val = float('inf')
i_max = current_index
i_min = current_index
for j in range(current_index + 1, len(env.candle_window)):
high_j = env.candle_window[j]["high"]
low_j = env.candle_window[j]["low"]
if high_j > max_val:
max_val = high_j
i_max = j
if low_j < min_val:
min_val = low_j
i_min = j
# If maximum occurs before minimum, we interpret that as short (price will drop).
if i_max < i_min:
entry_price = env.candle_window[current_index]["open"]
exit_price = env.candle_window[i_min]["open"]
reward = entry_price - exit_price
trade = {
"entry_index": current_index,
"entry_price": entry_price,
"exit_index": i_min,
"exit_price": exit_price,
"pnl": reward
}
else:
entry_price = env.candle_window[current_index]["open"]
exit_price = env.candle_window[i_max]["open"]
reward = exit_price - entry_price
trade = {
"entry_index": current_index,
"entry_price": entry_price,
"exit_index": i_max,
"exit_price": exit_price,
"pnl": reward
}
env.trade_history.append(trade)
env.current_index = trade["exit_index"]
# --- Live HTML Chart Update --- # --- Live HTML Chart Update ---
def update_live_html(candles, trade_history, epoch): def update_live_html(candles, trade_history, epoch):
""" """
Generate a chart image that uses actual timestamps on the x-axis Generate a chart image with actual timestamps on the x-axis and cumulative epoch PnL.
and shows a cumulative epoch PnL. The chart (with buy/sell markers and dotted lines) The chart now also plots volume as a bar chart on a secondary y-axis.
is embedded in an HTML page that auto-refreshes every 1 seconds. The HTML page auto-refreshes every 10 seconds.
""" """
from io import BytesIO from io import BytesIO
import base64 import base64
fig, ax = plt.subplots(figsize=(12, 6)) fig, ax = plt.subplots(figsize=(12, 6))
update_live_chart(ax, candles, trade_history) update_live_chart(ax, candles, trade_history)
# Compute cumulative epoch PnL.
epoch_pnl = sum(trade["pnl"] for trade in trade_history) epoch_pnl = sum(trade["pnl"] for trade in trade_history)
ax.set_title(f"Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}") ax.set_title(f"Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}")
buf = BytesIO() buf = BytesIO()
@ -263,7 +302,7 @@ def update_live_html(candles, trade_history, epoch):
<html> <html>
<head> <head>
<meta charset="utf-8"> <meta charset="utf-8">
<meta http-equiv="refresh" content="1"> <meta http-equiv="refresh" content="10">
<title>Live Trading Chart - Epoch {epoch}</title> <title>Live Trading Chart - Epoch {epoch}</title>
<style> <style>
body {{ body {{
@ -298,27 +337,34 @@ def update_live_html(candles, trade_history, epoch):
# --- Chart Drawing Helpers --- # --- Chart Drawing Helpers ---
def update_live_chart(ax, candles, trade_history): def update_live_chart(ax, candles, trade_history):
""" """
Plot the price chart with actual timestamps on the x-axis. Plot the price chart with proper timestamp conversion.
Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit. Mark BUY (green) and SELL (red) actions (with dotted lines between),
and plot volume as a bar chart on a secondary y-axis.
""" """
ax.clear() ax.clear()
# Use the helper to convert timestamps safely.
times = [convert_timestamp(candle["timestamp"]) for candle in candles] times = [convert_timestamp(candle["timestamp"]) for candle in candles]
close_prices = [candle["close"] for candle in candles] close_prices = [candle["close"] for candle in candles]
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
# Format x-axis date labels.
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
for trade in trade_history:
entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"])
exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"])
in_price = trade["entry_price"]
out_price = trade["exit_price"]
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY")
ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL")
ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")
ax.set_xlabel("Time") ax.set_xlabel("Time")
ax.set_ylabel("Price") ax.set_ylabel("Price")
ax.legend() ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
# Plot volume on secondary axis.
ax2 = ax.twinx()
volumes = [candle["volume"] for candle in candles]
# Compute bar width in days.
if len(times) > 1:
times_num = mdates.date2num(times)
bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8
else:
bar_width = 0.01
ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume")
ax2.set_ylabel("Volume")
# Combine legends.
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines + lines2, labels + labels2)
ax.grid(True) ax.grid(True)
fig = ax.get_figure() fig = ax.get_figure()
fig.autofmt_xdate() fig.autofmt_xdate()
@ -326,14 +372,15 @@ def update_live_chart(ax, candles, trade_history):
# --- Simulation of Trades for Visualization --- # --- Simulation of Trades for Visualization ---
def simulate_trades(model, env, device, args): def simulate_trades(model, env, device, args):
""" """
Run a simulation on the current sliding window using the model's outputs and a decision rule. Run a simulation on the current sliding window.
Here we force the simulation to always take an action by comparing the predicted potentials, If the model produces a sufficiently strong signal (based on threshold), use its action.
ensuring that the model is forced to trade (either BUY or SELL) rather than HOLD. Otherwise, manually compute the trade by scanning for max/min prices.
This simulation updates env.trade_history and is used for visualization only.
""" """
env.reset() # resets the window and index env.reset()
while True: while True:
i = env.current_index i = env.current_index
if i >= len(env.candle_window) - 1:
break
state = env.get_state(i) state = env.get_state(i)
current_open = env.candle_window[i]["open"] current_open = env.candle_window[i]["open"]
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
@ -341,19 +388,23 @@ def simulate_trades(model, env, device, args):
pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high, pred_low = model(state_tensor, timeframe_ids)
pred_high = pred_high.item() pred_high = pred_high.item()
pred_low = pred_low.item() pred_low = pred_low.item()
# Force a trade: choose BUY if predicted up-move is higher (or equal), else SELL. # If either upward potential or downward potential exceeds the threshold, use model decision.
if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
if (pred_high - current_open) >= (current_open - pred_low): if (pred_high - current_open) >= (current_open - pred_low):
action = 2 # BUY action = 2 # BUY
else: else:
action = 0 # SELL action = 0 # SELL
_, _, _, done, _, _ = env.step(action) _, _, _, done, _, _ = env.step(action)
if done: else:
# No significant signal; use manual trade computation.
manual_trade(env)
if env.current_index >= len(env.candle_window) - 1:
break break
# --- Backtest Environment with Sliding Window and Order Info --- # --- Backtest Environment with Sliding Window and Order Info ---
class BacktestEnvironment: class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes, window_size=None): def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
self.candles_dict = candles_dict # full candles dict across timeframes self.candles_dict = candles_dict
self.base_tf = base_tf self.base_tf = base_tf
self.timeframes = timeframes self.timeframes = timeframes
self.full_candles = candles_dict[base_tf] self.full_candles = candles_dict[base_tf]
@ -361,7 +412,6 @@ class BacktestEnvironment:
window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles) window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles)
self.window_size = window_size self.window_size = window_size
self.reset() self.reset()
def reset(self): def reset(self):
self.start_index = random.randint(0, len(self.full_candles) - self.window_size) self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size] self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size]
@ -369,10 +419,8 @@ class BacktestEnvironment:
self.trade_history = [] self.trade_history = []
self.position = None self.position = None
return self.get_state(self.current_index) return self.get_state(self.current_index)
def __len__(self): def __len__(self):
return self.window_size return self.window_size
def get_order_features(self, index): def get_order_features(self, index):
candle = self.candle_window[index] candle = self.candle_window[index]
if self.position is None: if self.position is None:
@ -381,7 +429,6 @@ class BacktestEnvironment:
flag = 1.0 flag = 1.0
diff = (candle["open"] - self.position["entry_price"]) / candle["open"] diff = (candle["open"] - self.position["entry_price"]) / candle["open"]
return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2) return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2)
def get_state(self, index): def get_state(self, index):
state_features = [] state_features = []
base_ts = self.candle_window[index]["timestamp"] base_ts = self.candle_window[index]["timestamp"]
@ -398,7 +445,6 @@ class BacktestEnvironment:
for _ in range(NUM_INDICATORS): for _ in range(NUM_INDICATORS):
state_features.append([0.0] * FEATURES_PER_CHANNEL) state_features.append([0.0] * FEATURES_PER_CHANNEL)
return np.array(state_features, dtype=np.float32) return np.array(state_features, dtype=np.float32)
def step(self, action): def step(self, action):
base = self.candle_window base = self.candle_window
if self.current_index >= len(base) - 1: if self.current_index >= len(base) - 1:
@ -410,10 +456,10 @@ class BacktestEnvironment:
next_candle = base[next_index] next_candle = base[next_index]
reward = 0.0 reward = 0.0
if self.position is None: if self.position is None:
if action == 2: if action == 2: # BUY (open long)
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
else: else:
if action == 0: if action == 0: # SELL (close long / exit trade)
exit_price = next_candle["open"] exit_price = next_candle["open"]
reward = exit_price - self.position["entry_price"] reward = exit_price - self.position["entry_price"]
trade = { trade = {
@ -480,11 +526,14 @@ def live_preview_loop(candles, env):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train') parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade (used in loss).") parser.add_argument('--threshold', type=float, default=0.005,
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.") help="Minimum predicted move to trigger trade (used in loss; model may override with manual trade).")
parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty if no action is taken (used in loss).") parser.add_argument('--lambda_trade', type=float, default=1.0,
help="Weight for trade surrogate loss.")
parser.add_argument('--penalty_noaction', type=float, default=10.0,
help="Penalty if no action is taken (used in loss).")
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
return parser.parse_args() return parser.parse_args()
@ -546,7 +595,7 @@ async def main():
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100) env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True) preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
preview_thread.start() preview_thread.start()
print("Starting live trading loop. (Forcing trade actions based on highest potential.)") print("Starting live trading loop. (Using model, with manual override for HOLD actions.)")
while True: while True:
state = env.get_state(env.current_index) state = env.get_state(env.current_index)
current_open = env.candle_window[env.current_index]["open"] current_open = env.candle_window[env.current_index]["open"]
@ -555,13 +604,15 @@ async def main():
pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high, pred_low = model(state_tensor, timeframe_ids)
pred_high = pred_high.item() pred_high = pred_high.item()
pred_low = pred_low.item() pred_low = pred_low.item()
# Force a trade (choose BUY if upward potential >= downward, else SELL) if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
if (pred_high - current_open) >= (current_open - pred_low): if (pred_high - current_open) >= (current_open - pred_low):
action = 2 action = 2 # BUY
else: else:
action = 0 action = 0 # SELL
_, _, _, done, _, _ = env.step(action) _, _, _, done, _, _ = env.step(action)
if done: else:
manual_trade(env)
if env.current_index >= len(env.candle_window)-1:
print("Reached end of simulation window; resetting environment.") print("Reached end of simulation window; resetting environment.")
env.reset() env.reset()
await asyncio.sleep(1) await asyncio.sleep(1)

File diff suppressed because one or more lines are too long