more efficient training

This commit is contained in:
Dobromir Popov 2025-02-04 22:16:39 +02:00
parent d2686b31b7
commit 39ce152391

View File

@ -68,9 +68,10 @@ class TradingModel(nn.Module):
# Embedding for channels 0..num_channels-1. # Embedding for channels 0..num_channels-1.
self.timeframe_embed = nn.Embedding(num_channels, hidden_dim) self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
self.pos_encoder = PositionalEncoding(hidden_dim) self.pos_encoder = PositionalEncoding(hidden_dim)
# Set batch_first=True to avoid the nested tensor warning.
encoder_layers = TransformerEncoderLayer( encoder_layers = TransformerEncoderLayer(
d_model=hidden_dim, nhead=4, dim_feedforward=512, d_model=hidden_dim, nhead=4, dim_feedforward=512,
dropout=0.1, activation='gelu', batch_first=False dropout=0.1, activation='gelu', batch_first=True
) )
self.transformer = TransformerEncoder(encoder_layers, num_layers=2) self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
self.attn_pool = nn.Linear(hidden_dim, 1) self.attn_pool = nn.Linear(hidden_dim, 1)
@ -92,13 +93,13 @@ class TradingModel(nn.Module):
channel_out = self.channel_branches[i](x[:, i, :]) channel_out = self.channel_branches[i](x[:, i, :])
channel_outs.append(channel_out) channel_outs.append(channel_out)
stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden] stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden]
stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden] # Notice that with batch_first=True, we want shape [batch, channels, hidden]
tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1) tf_embeds = self.timeframe_embed(timeframe_ids)
stacked = stacked + tf_embeds stacked = stacked + tf_embeds.unsqueeze(0) # add embedding to each sample in batch
src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device) # The Transformer expects input of shape [batch, seq_len, hidden] when batch_first=True.
transformer_out = self.transformer(stacked, mask=src_mask) transformer_out = self.transformer(stacked)
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0) attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
aggregated = (transformer_out * attn_weights).sum(dim=0) aggregated = (transformer_out * attn_weights).sum(dim=1)
return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze() return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()
# --- Technical Indicator Helpers --- # --- Technical Indicator Helpers ---
@ -210,7 +211,21 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
path = os.path.join(best_dir, best_file) path = os.path.join(best_dir, best_file)
print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}") print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
checkpoint = torch.load(path) checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"]) old_state = checkpoint["model_state_dict"]
new_state = model.state_dict()
# Fix the size mismatch for timeframe_embed.weight.
if "timeframe_embed.weight" in old_state:
old_embed = old_state["timeframe_embed.weight"]
new_embed = new_state["timeframe_embed.weight"]
if old_embed.shape[0] < new_embed.shape[0]:
# Copy the available rows and keep the remaining as initialized.
new_embed[:old_embed.shape[0]] = old_embed
old_state["timeframe_embed.weight"] = new_embed
# (For channel_branches, if the checkpoint has fewer branches than your new model expects,
# missing branches will be left at their randomly initialized values.)
model.load_state_dict(old_state, strict=False)
return checkpoint return checkpoint
# --- Live HTML Chart Update --- # --- Live HTML Chart Update ---
@ -283,24 +298,13 @@ def update_live_chart(ax, candles, trade_history):
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1) ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
# Format x-axis date labels. # Format x-axis date labels.
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
# Plot each trade.
buy_label_added = False
sell_label_added = False
for trade in trade_history: for trade in trade_history:
entry_time = datetime.fromtimestamp(candles[trade["entry_index"]]["timestamp"]) entry_time = datetime.fromtimestamp(candles[trade["entry_index"]]["timestamp"])
exit_time = datetime.fromtimestamp(candles[trade["exit_index"]]["timestamp"]) exit_time = datetime.fromtimestamp(candles[trade["exit_index"]]["timestamp"])
in_price = trade["entry_price"] in_price = trade["entry_price"]
out_price = trade["exit_price"] out_price = trade["exit_price"]
if not buy_label_added: ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY")
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL")
buy_label_added = True
else:
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10)
if not sell_label_added:
ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL")
sell_label_added = True
else:
ax.plot(exit_time, out_price, marker="v", color="red", markersize=10)
ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")
ax.set_xlabel("Time") ax.set_xlabel("Time")
ax.set_ylabel("Price") ax.set_ylabel("Price")
@ -325,7 +329,6 @@ def simulate_trades(model, env, device, args):
pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high, pred_low = model(state_tensor, timeframe_ids)
pred_high = pred_high.item() pred_high = pred_high.item()
pred_low = pred_low.item() pred_low = pred_low.item()
# Simple decision rule based on predicted move.
if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold: if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold:
action = 2 # BUY action = 2 # BUY
elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold: elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold:
@ -349,7 +352,6 @@ class BacktestEnvironment:
self.reset() self.reset()
def reset(self): def reset(self):
# Randomly select a sliding window from the full dataset.
self.start_index = random.randint(0, len(self.full_candles) - self.window_size) self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size] self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size]
self.current_index = 0 self.current_index = 0
@ -361,12 +363,6 @@ class BacktestEnvironment:
return self.window_size return self.window_size
def get_order_features(self, index): def get_order_features(self, index):
"""
Returns a list of 7 features for the order channel.
If an order is open, the first element is 1.0 and the second is the normalized difference:
(current open - entry_price) / current open.
Otherwise, returns zeros.
"""
candle = self.candle_window[index] candle = self.candle_window[index]
if self.position is None: if self.position is None:
return [0.0] * FEATURES_PER_CHANNEL return [0.0] * FEATURES_PER_CHANNEL
@ -376,13 +372,6 @@ class BacktestEnvironment:
return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2) return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2)
def get_state(self, index): def get_state(self, index):
"""
Build state features from:
- For each timeframe: features from the aligned candle.
- One extra channel: current order information.
- NUM_INDICATORS channels of zeros.
Each channel is a vector of length FEATURES_PER_CHANNEL.
"""
state_features = [] state_features = []
base_ts = self.candle_window[index]["timestamp"] base_ts = self.candle_window[index]["timestamp"]
for tf in self.timeframes: for tf in self.timeframes:
@ -393,36 +382,27 @@ class BacktestEnvironment:
aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts) aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
features = get_features_for_tf(self.candles_dict[tf], aligned_idx) features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
state_features.append(features) state_features.append(features)
# Append order channel.
order_features = self.get_order_features(index) order_features = self.get_order_features(index)
state_features.append(order_features) state_features.append(order_features)
# Append technical indicator channels.
for _ in range(NUM_INDICATORS): for _ in range(NUM_INDICATORS):
state_features.append([0.0] * FEATURES_PER_CHANNEL) state_features.append([0.0] * FEATURES_PER_CHANNEL)
return np.array(state_features, dtype=np.float32) return np.array(state_features, dtype=np.float32)
def step(self, action): def step(self, action):
"""
Execute one step in the environment:
- action: 0 => SELL, 1 => HOLD, 2 => BUY.
- Trades are recorded when a BUY is followed by a SELL.
"""
base = self.candle_window base = self.candle_window
if self.current_index >= len(base) - 1: if self.current_index >= len(base) - 1:
current_state = self.get_state(self.current_index) current_state = self.get_state(self.current_index)
return current_state, 0.0, None, True, 0.0, 0.0 return current_state, 0.0, None, True, 0.0, 0.0
current_state = self.get_state(self.current_index) current_state = self.get_state(self.current_index)
next_index = self.current_index + 1 next_index = self.current_index + 1
next_state = self.get_state(next_index) next_state = self.get_state(next_index)
next_candle = base[next_index] next_candle = base[next_index]
reward = 0.0 reward = 0.0
if self.position is None: if self.position is None:
if action == 2: # BUY signal: enter at next open. if action == 2:
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
else: else:
if action == 0: # SELL signal: exit at next open. if action == 0:
exit_price = next_candle["open"] exit_price = next_candle["open"]
reward = exit_price - self.position["entry_price"] reward = exit_price - self.position["entry_price"]
trade = { trade = {
@ -434,7 +414,6 @@ class BacktestEnvironment:
} }
self.trade_history.append(trade) self.trade_history.append(trade)
self.position = None self.position = None
self.current_index = next_index self.current_index = next_index
done = (self.current_index >= len(base) - 1) done = (self.current_index >= len(base) - 1)
actual_high = next_candle["high"] actual_high = next_candle["high"]
@ -443,11 +422,11 @@ class BacktestEnvironment:
# --- Enhanced Training Loop --- # --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler): def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
lambda_trade = args.lambda_trade # Weight for the surrogate profit loss. lambda_trade = args.lambda_trade
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):
env.reset() # Resets the sliding window. env.reset()
loss_accum = 0.0 loss_accum = 0.0
steps = len(env) - 1 # We assume steps over consecutive candle pairs. steps = len(env) - 1
for i in range(steps): for i in range(steps):
state = env.get_state(i) state = env.get_state(i)
current_open = env.candle_window[i]["open"] current_open = env.candle_window[i]["open"]
@ -456,14 +435,11 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device) timeframe_ids = torch.arange(state.shape[0]).to(device)
pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high, pred_low = model(state_tensor, timeframe_ids)
# Prediction loss (L1 error).
L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \ L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
torch.abs(pred_low - torch.tensor(actual_low, device=device)) torch.abs(pred_low - torch.tensor(actual_low, device=device))
# Surrogate profit loss. profit_buy = pred_high - current_open
profit_buy = pred_high - current_open # potential long gain profit_sell = current_open - pred_low
profit_sell = current_open - pred_low # potential short gain
L_trade = - torch.max(profit_buy, profit_sell) L_trade = - torch.max(profit_buy, profit_sell)
# Additional penalty if no strong signal is produced.
current_open_tensor = torch.tensor(current_open, device=device) current_open_tensor = torch.tensor(current_open, device=device)
signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low) signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low)
penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0) penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
@ -497,7 +473,7 @@ def parse_args():
parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.") parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.")
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.") parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.")
parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty for not taking an action (if predicted move is below threshold).") parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty if no action is taken.")
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
return parser.parse_args() return parser.parse_args()
@ -511,7 +487,6 @@ async def main():
print("Using device:", device) print("Using device:", device)
timeframes = ["1m", "5m", "15m", "1h", "1d"] timeframes = ["1m", "5m", "15m", "1h", "1d"]
hidden_dim = 128 hidden_dim = 128
# Total channels: NUM_TIMEFRAMES + 1 (order info) + NUM_INDICATORS.
total_channels = NUM_TIMEFRAMES + 1 + NUM_INDICATORS total_channels = NUM_TIMEFRAMES + 1 + NUM_INDICATORS
model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device) model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
@ -541,7 +516,7 @@ async def main():
optimizer.load_state_dict(optim_state) optimizer.load_state_dict(optim_state)
print("Loaded optimizer state from checkpoint.") print("Loaded optimizer state from checkpoint.")
else: else:
print("No valid optimizer state found in checkpoint; starting fresh optimizer state.") print("No valid optimizer state found; using fresh optimizer state.")
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler) train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
elif args.mode == 'live': elif args.mode == 'live':
@ -576,7 +551,7 @@ async def main():
elif args.mode == 'inference': elif args.mode == 'inference':
load_best_checkpoint(model) load_best_checkpoint(model)
print("Running inference...") print("Running inference...")
# Your inference logic goes here. # Inference logic goes here.
else: else:
print("Invalid mode specified.") print("Invalid mode specified.")