more efficient traning
This commit is contained in:
parent
d2686b31b7
commit
39ce152391
@ -68,9 +68,10 @@ class TradingModel(nn.Module):
|
||||
# Embedding for channels 0..num_channels-1.
|
||||
self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
|
||||
self.pos_encoder = PositionalEncoding(hidden_dim)
|
||||
# Set batch_first=True to avoid the nested tensor warning.
|
||||
encoder_layers = TransformerEncoderLayer(
|
||||
d_model=hidden_dim, nhead=4, dim_feedforward=512,
|
||||
dropout=0.1, activation='gelu', batch_first=False
|
||||
dropout=0.1, activation='gelu', batch_first=True
|
||||
)
|
||||
self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
|
||||
self.attn_pool = nn.Linear(hidden_dim, 1)
|
||||
@ -92,13 +93,13 @@ class TradingModel(nn.Module):
|
||||
channel_out = self.channel_branches[i](x[:, i, :])
|
||||
channel_outs.append(channel_out)
|
||||
stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden]
|
||||
stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden]
|
||||
tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
|
||||
stacked = stacked + tf_embeds
|
||||
src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
|
||||
transformer_out = self.transformer(stacked, mask=src_mask)
|
||||
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
|
||||
aggregated = (transformer_out * attn_weights).sum(dim=0)
|
||||
# Notice that with batch_first=True, we want shape [batch, channels, hidden]
|
||||
tf_embeds = self.timeframe_embed(timeframe_ids)
|
||||
stacked = stacked + tf_embeds.unsqueeze(0) # add embedding to each sample in batch
|
||||
# The Transformer expects input of shape [batch, seq_len, hidden] when batch_first=True.
|
||||
transformer_out = self.transformer(stacked)
|
||||
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
|
||||
aggregated = (transformer_out * attn_weights).sum(dim=1)
|
||||
return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()
|
||||
|
||||
# --- Technical Indicator Helpers ---
|
||||
@ -210,7 +211,21 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
|
||||
path = os.path.join(best_dir, best_file)
|
||||
print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
|
||||
checkpoint = torch.load(path)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
old_state = checkpoint["model_state_dict"]
|
||||
new_state = model.state_dict()
|
||||
|
||||
# Fix the size mismatch for timeframe_embed.weight.
|
||||
if "timeframe_embed.weight" in old_state:
|
||||
old_embed = old_state["timeframe_embed.weight"]
|
||||
new_embed = new_state["timeframe_embed.weight"]
|
||||
if old_embed.shape[0] < new_embed.shape[0]:
|
||||
# Copy the available rows and keep the remaining as initialized.
|
||||
new_embed[:old_embed.shape[0]] = old_embed
|
||||
old_state["timeframe_embed.weight"] = new_embed
|
||||
|
||||
# (For channel_branches, if the checkpoint has fewer branches than your new model expects,
|
||||
# missing branches will be left at their randomly initialized values.)
|
||||
model.load_state_dict(old_state, strict=False)
|
||||
return checkpoint
|
||||
|
||||
# --- Live HTML Chart Update ---
|
||||
@ -283,24 +298,13 @@ def update_live_chart(ax, candles, trade_history):
|
||||
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
|
||||
# Format x-axis date labels.
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
|
||||
# Plot each trade.
|
||||
buy_label_added = False
|
||||
sell_label_added = False
|
||||
for trade in trade_history:
|
||||
entry_time = datetime.fromtimestamp(candles[trade["entry_index"]]["timestamp"])
|
||||
exit_time = datetime.fromtimestamp(candles[trade["exit_index"]]["timestamp"])
|
||||
in_price = trade["entry_price"]
|
||||
out_price = trade["exit_price"]
|
||||
if not buy_label_added:
|
||||
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY")
|
||||
buy_label_added = True
|
||||
else:
|
||||
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10)
|
||||
if not sell_label_added:
|
||||
ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL")
|
||||
sell_label_added = True
|
||||
else:
|
||||
ax.plot(exit_time, out_price, marker="v", color="red", markersize=10)
|
||||
ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")
|
||||
ax.set_xlabel("Time")
|
||||
ax.set_ylabel("Price")
|
||||
@ -325,7 +329,6 @@ def simulate_trades(model, env, device, args):
|
||||
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
||||
pred_high = pred_high.item()
|
||||
pred_low = pred_low.item()
|
||||
# Simple decision rule based on predicted move.
|
||||
if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold:
|
||||
action = 2 # BUY
|
||||
elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold:
|
||||
@ -349,7 +352,6 @@ class BacktestEnvironment:
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
# Randomly select a sliding window from the full dataset.
|
||||
self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
|
||||
self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size]
|
||||
self.current_index = 0
|
||||
@ -361,12 +363,6 @@ class BacktestEnvironment:
|
||||
return self.window_size
|
||||
|
||||
def get_order_features(self, index):
|
||||
"""
|
||||
Returns a list of 7 features for the order channel.
|
||||
If an order is open, the first element is 1.0 and the second is the normalized difference:
|
||||
(current open - entry_price) / current open.
|
||||
Otherwise, returns zeros.
|
||||
"""
|
||||
candle = self.candle_window[index]
|
||||
if self.position is None:
|
||||
return [0.0] * FEATURES_PER_CHANNEL
|
||||
@ -376,13 +372,6 @@ class BacktestEnvironment:
|
||||
return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2)
|
||||
|
||||
def get_state(self, index):
|
||||
"""
|
||||
Build state features from:
|
||||
- For each timeframe: features from the aligned candle.
|
||||
- One extra channel: current order information.
|
||||
- NUM_INDICATORS channels of zeros.
|
||||
Each channel is a vector of length FEATURES_PER_CHANNEL.
|
||||
"""
|
||||
state_features = []
|
||||
base_ts = self.candle_window[index]["timestamp"]
|
||||
for tf in self.timeframes:
|
||||
@ -393,36 +382,27 @@ class BacktestEnvironment:
|
||||
aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
|
||||
features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
|
||||
state_features.append(features)
|
||||
# Append order channel.
|
||||
order_features = self.get_order_features(index)
|
||||
state_features.append(order_features)
|
||||
# Append technical indicator channels.
|
||||
for _ in range(NUM_INDICATORS):
|
||||
state_features.append([0.0] * FEATURES_PER_CHANNEL)
|
||||
return np.array(state_features, dtype=np.float32)
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
Execute one step in the environment:
|
||||
- action: 0 => SELL, 1 => HOLD, 2 => BUY.
|
||||
- Trades are recorded when a BUY is followed by a SELL.
|
||||
"""
|
||||
base = self.candle_window
|
||||
if self.current_index >= len(base) - 1:
|
||||
current_state = self.get_state(self.current_index)
|
||||
return current_state, 0.0, None, True, 0.0, 0.0
|
||||
|
||||
current_state = self.get_state(self.current_index)
|
||||
next_index = self.current_index + 1
|
||||
next_state = self.get_state(next_index)
|
||||
next_candle = base[next_index]
|
||||
reward = 0.0
|
||||
|
||||
if self.position is None:
|
||||
if action == 2: # BUY signal: enter at next open.
|
||||
if action == 2:
|
||||
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
|
||||
else:
|
||||
if action == 0: # SELL signal: exit at next open.
|
||||
if action == 0:
|
||||
exit_price = next_candle["open"]
|
||||
reward = exit_price - self.position["entry_price"]
|
||||
trade = {
|
||||
@ -434,7 +414,6 @@ class BacktestEnvironment:
|
||||
}
|
||||
self.trade_history.append(trade)
|
||||
self.position = None
|
||||
|
||||
self.current_index = next_index
|
||||
done = (self.current_index >= len(base) - 1)
|
||||
actual_high = next_candle["high"]
|
||||
@ -443,11 +422,11 @@ class BacktestEnvironment:
|
||||
|
||||
# --- Enhanced Training Loop ---
|
||||
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
|
||||
lambda_trade = args.lambda_trade # Weight for the surrogate profit loss.
|
||||
lambda_trade = args.lambda_trade
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
env.reset() # Resets the sliding window.
|
||||
env.reset()
|
||||
loss_accum = 0.0
|
||||
steps = len(env) - 1 # We assume steps over consecutive candle pairs.
|
||||
steps = len(env) - 1
|
||||
for i in range(steps):
|
||||
state = env.get_state(i)
|
||||
current_open = env.candle_window[i]["open"]
|
||||
@ -456,14 +435,11 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
||||
timeframe_ids = torch.arange(state.shape[0]).to(device)
|
||||
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
||||
# Prediction loss (L1 error).
|
||||
L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
|
||||
torch.abs(pred_low - torch.tensor(actual_low, device=device))
|
||||
# Surrogate profit loss.
|
||||
profit_buy = pred_high - current_open # potential long gain
|
||||
profit_sell = current_open - pred_low # potential short gain
|
||||
profit_buy = pred_high - current_open
|
||||
profit_sell = current_open - pred_low
|
||||
L_trade = - torch.max(profit_buy, profit_sell)
|
||||
# Additional penalty if no strong signal is produced.
|
||||
current_open_tensor = torch.tensor(current_open, device=device)
|
||||
signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low)
|
||||
penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
|
||||
@ -497,7 +473,7 @@ def parse_args():
|
||||
parser.add_argument('--lr', type=float, default=3e-4)
|
||||
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.")
|
||||
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.")
|
||||
parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty for not taking an action (if predicted move is below threshold).")
|
||||
parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty if no action is taken.")
|
||||
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
|
||||
return parser.parse_args()
|
||||
|
||||
@ -511,7 +487,6 @@ async def main():
|
||||
print("Using device:", device)
|
||||
timeframes = ["1m", "5m", "15m", "1h", "1d"]
|
||||
hidden_dim = 128
|
||||
# Total channels: NUM_TIMEFRAMES + 1 (order info) + NUM_INDICATORS.
|
||||
total_channels = NUM_TIMEFRAMES + 1 + NUM_INDICATORS
|
||||
model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
|
||||
|
||||
@ -541,7 +516,7 @@ async def main():
|
||||
optimizer.load_state_dict(optim_state)
|
||||
print("Loaded optimizer state from checkpoint.")
|
||||
else:
|
||||
print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
|
||||
print("No valid optimizer state found; using fresh optimizer state.")
|
||||
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
|
||||
|
||||
elif args.mode == 'live':
|
||||
@ -576,7 +551,7 @@ async def main():
|
||||
elif args.mode == 'inference':
|
||||
load_best_checkpoint(model)
|
||||
print("Running inference...")
|
||||
# Your inference logic goes here.
|
||||
# Inference logic goes here.
|
||||
else:
|
||||
print("Invalid mode specified.")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user