added volume. better training

This commit is contained in:
Dobromir Popov 2025-02-04 22:33:44 +02:00
parent 75c4d6602a
commit dc5df52292
2 changed files with 119 additions and 68 deletions

View File

@ -44,10 +44,8 @@ CACHE_FILE = "candles_cache.json"
# --- Constants ---
NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20 # e.g., 20 technical indicators
# Each channel input will have 7 features.
FEATURES_PER_CHANNEL = 7
# We add one extra channel for order information.
ORDER_CHANNELS = 1
FEATURES_PER_CHANNEL = 7 # Each channel input will have 7 features.
ORDER_CHANNELS = 1 # One extra channel for order info.
# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
@ -68,7 +66,6 @@ class PositionalEncoding(nn.Module):
class TradingModel(nn.Module):
def __init__(self, num_channels, num_timeframes, hidden_dim=128):
super().__init__()
# Create one branch per channel.
self.channel_branches = nn.ModuleList([
nn.Sequential(
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@ -77,13 +74,11 @@ class TradingModel(nn.Module):
nn.Dropout(0.1)
) for _ in range(num_channels)
])
# Embedding for channels 0..num_channels-1.
self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
self.pos_encoder = PositionalEncoding(hidden_dim)
# Set batch_first=True to avoid the nested tensor warning.
encoder_layers = TransformerEncoderLayer(
d_model=hidden_dim, nhead=4, dim_feedforward=512,
dropout=0.1, activation='gelu', batch_first=True
dropout=0.1, activation='gelu', batch_first=True # Use batch_first to avoid nested tensor warning.
)
self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
self.attn_pool = nn.Linear(hidden_dim, 1)
@ -104,11 +99,9 @@ class TradingModel(nn.Module):
for i in range(num_channels):
channel_out = self.channel_branches[i](x[:, i, :])
channel_outs.append(channel_out)
stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden]
# With batch_first=True, the expected input is [batch, seq_len, hidden]
stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden]
tf_embeds = self.timeframe_embed(timeframe_ids) # shape: [num_channels, hidden]
# Expand tf_embeds to match the batch dimension.
stacked = stacked + tf_embeds.unsqueeze(0)
stacked = stacked + tf_embeds.unsqueeze(0) # broadcast along batch dimension.
transformer_out = self.transformer(stacked)
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
aggregated = (transformer_out * attn_weights).sum(dim=1)
@ -225,32 +218,78 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
checkpoint = torch.load(path)
old_state = checkpoint["model_state_dict"]
new_state = model.state_dict()
# Fix the size mismatch for timeframe_embed.weight.
if "timeframe_embed.weight" in old_state:
old_embed = old_state["timeframe_embed.weight"]
new_embed = new_state["timeframe_embed.weight"]
if old_embed.shape[0] < new_embed.shape[0]:
new_embed[:old_embed.shape[0]] = old_embed
old_state["timeframe_embed.weight"] = new_embed
# For channel_branches, missing keys are handled by strict=False.
model.load_state_dict(old_state, strict=False)
return checkpoint
# --- Function for Manual Trade Override ---
def manual_trade(env):
"""
When no sufficient action is taken by the model, manually decide the trade.
Find the maximum high and minimum low in the remaining window.
If maximum occurs before minimum, we short; otherwise we long.
The trade is closed at the candle where the chosen extreme occurs.
"""
current_index = env.current_index
if current_index >= len(env.candle_window) - 1:
env.current_index = len(env.candle_window) - 1
return
max_val = -float('inf')
min_val = float('inf')
i_max = current_index
i_min = current_index
for j in range(current_index + 1, len(env.candle_window)):
high_j = env.candle_window[j]["high"]
low_j = env.candle_window[j]["low"]
if high_j > max_val:
max_val = high_j
i_max = j
if low_j < min_val:
min_val = low_j
i_min = j
# If maximum occurs before minimum, we interpret that as short (price will drop).
if i_max < i_min:
entry_price = env.candle_window[current_index]["open"]
exit_price = env.candle_window[i_min]["open"]
reward = entry_price - exit_price
trade = {
"entry_index": current_index,
"entry_price": entry_price,
"exit_index": i_min,
"exit_price": exit_price,
"pnl": reward
}
else:
entry_price = env.candle_window[current_index]["open"]
exit_price = env.candle_window[i_max]["open"]
reward = exit_price - entry_price
trade = {
"entry_index": current_index,
"entry_price": entry_price,
"exit_index": i_max,
"exit_price": exit_price,
"pnl": reward
}
env.trade_history.append(trade)
env.current_index = trade["exit_index"]
# --- Live HTML Chart Update ---
def update_live_html(candles, trade_history, epoch):
"""
Generate a chart image that uses actual timestamps on the x-axis
and shows a cumulative epoch PnL. The chart (with buy/sell markers and dotted lines)
is embedded in an HTML page that auto-refreshes every 1 seconds.
Generate a chart image with actual timestamps on the x-axis and cumulative epoch PnL.
The chart now also plots volume as a bar chart on a secondary y-axis.
The HTML page auto-refreshes every 10 seconds.
"""
from io import BytesIO
import base64
fig, ax = plt.subplots(figsize=(12, 6))
update_live_chart(ax, candles, trade_history)
# Compute cumulative epoch PnL.
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
ax.set_title(f"Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}")
buf = BytesIO()
@ -263,7 +302,7 @@ def update_live_html(candles, trade_history, epoch):
<html>
<head>
<meta charset="utf-8">
<meta http-equiv="refresh" content="1">
<meta http-equiv="refresh" content="10">
<title>Live Trading Chart - Epoch {epoch}</title>
<style>
body {{
@ -292,33 +331,40 @@ def update_live_html(candles, trade_history, epoch):
</html>
"""
with open("live_chart.html", "w") as f:
f.write(html_content)
f.write(html_content)
print("Updated live_chart.html.")
# --- Chart Drawing Helpers ---
def update_live_chart(ax, candles, trade_history):
"""
Plot the price chart with actual timestamps on the x-axis.
Mark BUY (green) and SELL (red) actions, and draw dotted lines between entry and exit.
Plot the price chart with proper timestamp conversion.
Mark BUY (green) and SELL (red) actions (with dotted lines between),
and plot volume as a bar chart on a secondary y-axis.
"""
ax.clear()
# Use the helper to convert timestamps safely.
times = [convert_timestamp(candle["timestamp"]) for candle in candles]
close_prices = [candle["close"] for candle in candles]
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
# Format x-axis date labels.
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
for trade in trade_history:
entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"])
exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"])
in_price = trade["entry_price"]
out_price = trade["exit_price"]
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY")
ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL")
ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")
ax.set_xlabel("Time")
ax.set_ylabel("Price")
ax.legend()
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
# Plot volume on secondary axis.
ax2 = ax.twinx()
volumes = [candle["volume"] for candle in candles]
# Compute bar width in days.
if len(times) > 1:
times_num = mdates.date2num(times)
bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8
else:
bar_width = 0.01
ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume")
ax2.set_ylabel("Volume")
# Combine legends.
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines + lines2, labels + labels2)
ax.grid(True)
fig = ax.get_figure()
fig.autofmt_xdate()
@ -326,14 +372,15 @@ def update_live_chart(ax, candles, trade_history):
# --- Simulation of Trades for Visualization ---
def simulate_trades(model, env, device, args):
"""
Run a simulation on the current sliding window using the model's outputs and a decision rule.
Here we force the simulation to always take an action by comparing the predicted potentials,
ensuring that the model is forced to trade (either BUY or SELL) rather than HOLD.
This simulation updates env.trade_history and is used for visualization only.
Run a simulation on the current sliding window.
If the model produces a sufficiently strong signal (based on threshold), use its action.
Otherwise, manually compute the trade by scanning for max/min prices.
"""
env.reset() # resets the window and index
env.reset()
while True:
i = env.current_index
if i >= len(env.candle_window) - 1:
break
state = env.get_state(i)
current_open = env.candle_window[i]["open"]
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
@ -341,19 +388,23 @@ def simulate_trades(model, env, device, args):
pred_high, pred_low = model(state_tensor, timeframe_ids)
pred_high = pred_high.item()
pred_low = pred_low.item()
# Force a trade: choose BUY if predicted up-move is higher (or equal), else SELL.
if (pred_high - current_open) >= (current_open - pred_low):
action = 2 # BUY
# If either upward potential or downward potential exceeds the threshold, use model decision.
if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
if (pred_high - current_open) >= (current_open - pred_low):
action = 2 # BUY
else:
action = 0 # SELL
_, _, _, done, _, _ = env.step(action)
else:
action = 0 # SELL
_, _, _, done, _, _ = env.step(action)
if done:
# No significant signal; use manual trade computation.
manual_trade(env)
if env.current_index >= len(env.candle_window) - 1:
break
# --- Backtest Environment with Sliding Window and Order Info ---
class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
self.candles_dict = candles_dict # full candles dict across timeframes
self.candles_dict = candles_dict
self.base_tf = base_tf
self.timeframes = timeframes
self.full_candles = candles_dict[base_tf]
@ -361,7 +412,6 @@ class BacktestEnvironment:
window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles)
self.window_size = window_size
self.reset()
def reset(self):
self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size]
@ -369,10 +419,8 @@ class BacktestEnvironment:
self.trade_history = []
self.position = None
return self.get_state(self.current_index)
def __len__(self):
return self.window_size
def get_order_features(self, index):
candle = self.candle_window[index]
if self.position is None:
@ -381,7 +429,6 @@ class BacktestEnvironment:
flag = 1.0
diff = (candle["open"] - self.position["entry_price"]) / candle["open"]
return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2)
def get_state(self, index):
state_features = []
base_ts = self.candle_window[index]["timestamp"]
@ -398,7 +445,6 @@ class BacktestEnvironment:
for _ in range(NUM_INDICATORS):
state_features.append([0.0] * FEATURES_PER_CHANNEL)
return np.array(state_features, dtype=np.float32)
def step(self, action):
base = self.candle_window
if self.current_index >= len(base) - 1:
@ -410,10 +456,10 @@ class BacktestEnvironment:
next_candle = base[next_index]
reward = 0.0
if self.position is None:
if action == 2:
if action == 2: # BUY (open long)
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
else:
if action == 0:
if action == 0: # SELL (close long / exit trade)
exit_price = next_candle["open"]
reward = exit_price - self.position["entry_price"]
trade = {
@ -480,11 +526,14 @@ def live_preview_loop(candles, env):
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade (used in loss).")
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.")
parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty if no action is taken (used in loss).")
parser.add_argument('--threshold', type=float, default=0.005,
help="Minimum predicted move to trigger trade (used in loss; model may override with manual trade).")
parser.add_argument('--lambda_trade', type=float, default=1.0,
help="Weight for trade surrogate loss.")
parser.add_argument('--penalty_noaction', type=float, default=10.0,
help="Penalty if no action is taken (used in loss).")
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
return parser.parse_args()
@ -546,7 +595,7 @@ async def main():
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
preview_thread.start()
print("Starting live trading loop. (Forcing trade actions based on highest potential.)")
print("Starting live trading loop. (Using model, with manual override for HOLD actions.)")
while True:
state = env.get_state(env.current_index)
current_open = env.candle_window[env.current_index]["open"]
@ -555,13 +604,15 @@ async def main():
pred_high, pred_low = model(state_tensor, timeframe_ids)
pred_high = pred_high.item()
pred_low = pred_low.item()
# Force a trade (choose BUY if upward potential >= downward, else SELL)
if (pred_high - current_open) >= (current_open - pred_low):
action = 2
if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
if (pred_high - current_open) >= (current_open - pred_low):
action = 2 # BUY
else:
action = 0 # SELL
_, _, _, done, _, _ = env.step(action)
else:
action = 0
_, _, _, done, _, _ = env.step(action)
if done:
manual_trade(env)
if env.current_index >= len(env.candle_window)-1:
print("Reached end of simulation window; resetting environment.")
env.reset()
await asyncio.sleep(1)

File diff suppressed because one or more lines are too long