better train algo

This commit is contained in:
Dobromir Popov 2025-02-04 22:10:24 +02:00
parent 967363378b
commit 615579d456

View File

@ -18,6 +18,7 @@ from datetime import datetime
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import math import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer from torch.nn import TransformerEncoder, TransformerEncoderLayer
import matplotlib.dates as mdates
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
@ -31,7 +32,10 @@ CACHE_FILE = "candles_cache.json"
# --- Constants --- # --- Constants ---
NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"] NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20 # e.g., 20 technical indicators NUM_INDICATORS = 20 # e.g., 20 technical indicators
FEATURES_PER_CHANNEL = 7 # e.g., [open, high, low, close, volume, sma_close, sma_volume] # Each channel input will have 7 features.
FEATURES_PER_CHANNEL = 7
# We add one extra channel for order information.
ORDER_CHANNELS = 1
# --- Positional Encoding Module --- # --- Positional Encoding Module ---
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):
@ -52,7 +56,7 @@ class PositionalEncoding(nn.Module):
class TradingModel(nn.Module): class TradingModel(nn.Module):
def __init__(self, num_channels, num_timeframes, hidden_dim=128): def __init__(self, num_channels, num_timeframes, hidden_dim=128):
super().__init__() super().__init__()
# One branch per channel # Create one branch per channel.
self.channel_branches = nn.ModuleList([ self.channel_branches = nn.ModuleList([
nn.Sequential( nn.Sequential(
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim), nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@ -61,6 +65,7 @@ class TradingModel(nn.Module):
nn.Dropout(0.1) nn.Dropout(0.1)
) for _ in range(num_channels) ) for _ in range(num_channels)
]) ])
# Embedding for channels 0..num_channels-1.
self.timeframe_embed = nn.Embedding(num_channels, hidden_dim) self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
self.pos_encoder = PositionalEncoding(hidden_dim) self.pos_encoder = PositionalEncoding(hidden_dim)
encoder_layers = TransformerEncoderLayer( encoder_layers = TransformerEncoderLayer(
@ -86,8 +91,8 @@ class TradingModel(nn.Module):
for i in range(num_channels): for i in range(num_channels):
channel_out = self.channel_branches[i](x[:, i, :]) channel_out = self.channel_branches[i](x[:, i, :])
channel_outs.append(channel_out) channel_outs.append(channel_out)
stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden] stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden]
stacked = stacked.permute(1, 0, 2) # [channels, batch, hidden] stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden]
tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1) tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
stacked = stacked + tf_embeds stacked = stacked + tf_embeds
src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device) src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
@ -211,15 +216,17 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
# --- Live HTML Chart Update --- # --- Live HTML Chart Update ---
def update_live_html(candles, trade_history, epoch): def update_live_html(candles, trade_history, epoch):
""" """
Generate a chart image with buy/sell markers and dotted lines between entry and exit, Generate a chart image that uses actual timestamps on the x-axis and shows a cumulative epoch PnL.
then embed it in an auto-refreshing HTML page. The chart (with buy/sell markers and dotted lines) is embedded in an HTML page that auto-refreshes.
""" """
from io import BytesIO from io import BytesIO
import base64 import base64
fig, ax = plt.subplots(figsize=(12, 6)) fig, ax = plt.subplots(figsize=(12, 6))
update_live_chart(ax, candles, trade_history) update_live_chart(ax, candles, trade_history)
ax.set_title(f"Live Trading Chart - Epoch {epoch}") # Compute cumulative epoch PnL.
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
ax.set_title(f"Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}")
buf = BytesIO() buf = BytesIO()
fig.savefig(buf, format='png') fig.savefig(buf, format='png')
plt.close(fig) plt.close(fig)
@ -252,7 +259,7 @@ def update_live_html(candles, trade_history, epoch):
</head> </head>
<body> <body>
<div class="chart-container"> <div class="chart-container">
<h2>Live Trading Chart - Epoch {epoch}</h2> <h2>Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}</h2>
<img src="data:image/png;base64,{image_base64}" alt="Live Chart"/> <img src="data:image/png;base64,{image_base64}" alt="Live Chart"/>
</div> </div>
</body> </body>
@ -265,42 +272,51 @@ def update_live_html(candles, trade_history, epoch):
# --- Chart Drawing Helpers --- # --- Chart Drawing Helpers ---
def update_live_chart(ax, candles, trade_history):
    """
    Plot close prices against candle timestamps and overlay the executed trades.

    Each trade is drawn as a green "^" marker at its entry, a red "v" marker at
    its exit, and a dotted blue line joining the two points.

    Args:
        ax: Matplotlib axes to draw on; it is cleared first.
        candles: Sequence of dicts providing "timestamp" and "close".
            NOTE(review): timestamps are passed to datetime.fromtimestamp, so
            they are assumed to be epoch seconds - confirm against the loader.
        trade_history: Sequence of dicts providing "entry_index", "exit_index",
            "entry_price" and "exit_price"; the indices refer to `candles`.
    """
    ax.clear()

    # Convert every candle timestamp once; trades below reuse these values.
    times = [datetime.fromtimestamp(candle["timestamp"]) for candle in candles]
    close_prices = [candle["close"] for candle in candles]
    ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))

    for trade_number, trade in enumerate(trade_history):
        entry_time = times[trade["entry_index"]]
        exit_time = times[trade["exit_index"]]
        in_price = trade["entry_price"]
        out_price = trade["exit_price"]
        # Only the first trade carries legend labels, so BUY/SELL appear once.
        is_first = trade_number == 0
        buy_label = {"label": "BUY"} if is_first else {}
        sell_label = {"label": "SELL"} if is_first else {}
        ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, **buy_label)
        ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, **sell_label)
        ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")

    ax.set_xlabel("Time")
    ax.set_ylabel("Price")
    ax.legend()
    ax.grid(True)
    # Slant the date tick labels so neighbouring timestamps do not overlap.
    ax.get_figure().autofmt_xdate()
# --- Simulation of Trades for Visualization --- # --- Simulation of Trades for Visualization ---
def simulate_trades(model, env, device, args): def simulate_trades(model, env, device, args):
""" """
Run a complete simulation on the current sliding window using a decision rule based on model outputs. Run a simulation on the current sliding window using the model's outputs and a decision rule.
This simulation (which updates env.trade_history) is used only for visualization. This simulation updates env.trade_history and is used for visualization only.
""" """
env.reset() # resets the sliding window and index env.reset() # resets the window and index
while True: while True:
i = env.current_index i = env.current_index
state = env.get_state(i) state = env.get_state(i)
@ -310,7 +326,7 @@ def simulate_trades(model, env, device, args):
pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high, pred_low = model(state_tensor, timeframe_ids)
pred_high = pred_high.item() pred_high = pred_high.item()
pred_low = pred_low.item() pred_low = pred_low.item()
# Decision rule: if upward move larger than downward and above threshold, BUY; if downward is larger, SELL; else HOLD. # Simple decision rule based on predicted move.
if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold: if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold:
action = 2 # BUY action = 2 # BUY
elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold: elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold:
@ -321,21 +337,20 @@ def simulate_trades(model, env, device, args):
if done: if done:
break break
# --- Backtest Environment with Sliding Window --- # --- Backtest Environment with Sliding Window and Order Info ---
class BacktestEnvironment: class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
    """
    Set up a backtest over a sliding window of the base-timeframe candles.

    Args:
        candles_dict: Mapping of timeframe name to its list of candles.
        base_tf: Key in `candles_dict` whose candles drive the simulation.
        timeframes: Timeframe names used when building state features.
        window_size: Number of base candles per window. Defaults to 100.
            It is always capped at the number of available base candles.
    """
    self.candles_dict = candles_dict  # full candles dict across timeframes
    self.base_tf = base_tf
    self.timeframes = timeframes
    self.full_candles = candles_dict[base_tf]
    if window_size is None:
        window_size = 100
    # reset() draws random.randint(0, len(full_candles) - window_size); a window
    # larger than the data would make that range empty and raise ValueError.
    self.window_size = min(window_size, len(self.full_candles))
    self.reset()
def reset(self): def reset(self):
# Pick a random sliding window from the full dataset. # Randomly select a sliding window from the full dataset.
self.start_index = random.randint(0, len(self.full_candles) - self.window_size) self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size] self.candle_window = self.full_candles[self.start_index: self.start_index + self.window_size]
self.current_index = 0 self.current_index = 0
@ -346,7 +361,29 @@ class BacktestEnvironment:
def __len__(self):
    """Return the number of candles in the sliding window (`window_size`)."""
    return self.window_size
def get_order_features(self, index):
    """
    Return the order-channel feature vector for the candle at `index`.

    The result always has FEATURES_PER_CHANNEL entries:
      - no open position: all zeros;
      - open position: [1.0, (open - entry_price) / open, 0.0, ...], where
        `open` is the open price of the candle at `index`.

    Args:
        index: Position of the candle inside `self.candle_window`.

    Returns:
        A list of FEATURES_PER_CHANNEL floats.
    """
    candle = self.candle_window[index]
    if self.position is None:
        return [0.0] * FEATURES_PER_CHANNEL
    current_open = candle["open"]
    # A zero open price would raise ZeroDivisionError; report no relative
    # move for such a candle instead of aborting the whole run.
    if current_open:
        diff = (current_open - self.position["entry_price"]) / current_open
    else:
        diff = 0.0
    return [1.0, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2)
def get_state(self, index): def get_state(self, index):
"""
Build state features from:
- For each timeframe: features from the aligned candle.
- One extra channel: current order information.
- NUM_INDICATORS channels of zeros.
Each channel is a vector of length FEATURES_PER_CHANNEL.
"""
state_features = [] state_features = []
base_ts = self.candle_window[index]["timestamp"] base_ts = self.candle_window[index]["timestamp"]
for tf in self.timeframes: for tf in self.timeframes:
@ -357,15 +394,19 @@ class BacktestEnvironment:
aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts) aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
features = get_features_for_tf(self.candles_dict[tf], aligned_idx) features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
state_features.append(features) state_features.append(features)
# Append order channel.
order_features = self.get_order_features(index)
state_features.append(order_features)
# Append technical indicator channels.
for _ in range(NUM_INDICATORS): for _ in range(NUM_INDICATORS):
state_features.append([0.0] * FEATURES_PER_CHANNEL) state_features.append([0.0] * FEATURES_PER_CHANNEL)
return np.array(state_features, dtype=np.float32) return np.array(state_features, dtype=np.float32)
def step(self, action): def step(self, action):
""" """
Discrete simulation step. Execute one step in the environment:
- Action: 0 (SELL), 1 (HOLD), 2 (BUY). - action: 0 => SELL, 1 => HOLD, 2 => BUY.
- Trades are recorded when a BUY is followed by a SELL. - Trades recorded when a BUY is followed by a SELL.
""" """
base = self.candle_window base = self.candle_window
if self.current_index >= len(base) - 1: if self.current_index >= len(base) - 1:
@ -378,7 +419,6 @@ class BacktestEnvironment:
next_candle = base[next_index] next_candle = base[next_index]
reward = 0.0 reward = 0.0
# Simple trading logic (only one position allowed at a time)
if self.position is None: if self.position is None:
if action == 2: # BUY signal: enter at next open. if action == 2: # BUY signal: enter at next open.
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
@ -404,29 +444,25 @@ class BacktestEnvironment:
# --- Enhanced Training Loop --- # --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler): def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
# Weighting factor for trade surrogate loss. lambda_trade = args.lambda_trade # Weight for the surrogate profit loss.
lambda_trade = 1.0
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):
# Reset sliding window for each epoch. env.reset() # Resets the sliding window.
env.reset()
loss_accum = 0.0 loss_accum = 0.0
steps = len(env) - 1 # we use pairs of consecutive candles steps = len(env) - 1 # We assume steps over consecutive candle pairs.
for i in range(steps): for i in range(steps):
state = env.get_state(i) state = env.get_state(i)
current_open = env.candle_window[i]["open"] current_open = env.candle_window[i]["open"]
# Next candle's actual values serve as targets.
actual_high = env.candle_window[i+1]["high"] actual_high = env.candle_window[i+1]["high"]
actual_low = env.candle_window[i+1]["low"] actual_low = env.candle_window[i+1]["low"]
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device) timeframe_ids = torch.arange(state.shape[0]).to(device)
pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high, pred_low = model(state_tensor, timeframe_ids)
# Compute prediction loss (L1) # Prediction loss (L1 error).
L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \ L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
torch.abs(pred_low - torch.tensor(actual_low, device=device)) torch.abs(pred_low - torch.tensor(actual_low, device=device))
# Compute surrogate profit (differentiable estimate) # Surrogate profit loss:
profit_buy = pred_high - current_open # potential long gain profit_buy = pred_high - current_open # potential long gain
profit_sell = current_open - pred_low # potential short gain profit_sell = current_open - pred_low # potential short gain
# Here we reward a higher potential move by subtracting it.
L_trade = - torch.max(profit_buy, profit_sell) L_trade = - torch.max(profit_buy, profit_sell)
loss = L_pred + lambda_trade * L_trade loss = L_pred + lambda_trade * L_trade
optimizer.zero_grad() optimizer.zero_grad()
@ -438,7 +474,6 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
epoch_loss = loss_accum / steps epoch_loss = loss_accum / steps
print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}") print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
save_checkpoint(model, optimizer, epoch, loss_accum) save_checkpoint(model, optimizer, epoch, loss_accum)
# Update the trade simulation (for visualization) using the current model on the same window.
simulate_trades(model, env, device, args) simulate_trades(model, env, device, args)
update_live_html(env.candle_window, env.trade_history, epoch+1) update_live_html(env.candle_window, env.trade_history, epoch+1)
@ -458,10 +493,13 @@ def parse_args():
parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.") parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.")
parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for the trade surrogate loss.") parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for trade surrogate loss.")
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
return parser.parse_args() return parser.parse_args()
def random_action():
    """Return a uniformly random integer action code from 0 to 2 inclusive."""
    return random.randint(0, 2)
# --- Main Function --- # --- Main Function ---
async def main(): async def main():
args = parse_args() args = parse_args()
@ -469,7 +507,8 @@ async def main():
print("Using device:", device) print("Using device:", device)
timeframes = ["1m", "5m", "15m", "1h", "1d"] timeframes = ["1m", "5m", "15m", "1h", "1d"]
hidden_dim = 128 hidden_dim = 128
total_channels = NUM_TIMEFRAMES + NUM_INDICATORS # Total channels: NUM_TIMEFRAMES + 1 (order info) + NUM_INDICATORS.
total_channels = NUM_TIMEFRAMES + 1 + NUM_INDICATORS
model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device) model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
if args.mode == 'train': if args.mode == 'train':
@ -478,7 +517,6 @@ async def main():
print("No historical candle data available for backtesting.") print("No historical candle data available for backtesting.")
return return
base_tf = "1m" base_tf = "1m"
# Use a sliding window of up to 100 candles (if available)
env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100) env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
start_epoch = 0 start_epoch = 0
checkpoint = None checkpoint = None
@ -513,7 +551,6 @@ async def main():
preview_thread.start() preview_thread.start()
print("Starting live trading loop. (Using model-based decision rule.)") print("Starting live trading loop. (Using model-based decision rule.)")
while True: while True:
# In live mode, we use the simulation decision rule.
state = env.get_state(env.current_index) state = env.get_state(env.current_index)
current_open = env.candle_window[env.current_index]["open"] current_open = env.candle_window[env.current_index]["open"]
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
@ -535,7 +572,7 @@ async def main():
elif args.mode == 'inference': elif args.mode == 'inference':
load_best_checkpoint(model) load_best_checkpoint(model)
print("Running inference...") print("Running inference...")
# Inference logic can use a similar decision rule as in live mode. # Your inference logic goes here.
else: else:
print("Invalid mode specified.") print("Invalid mode specified.")