better training algo

Dobromir Popov 2025-02-04 22:05:41 +02:00
parent 10ff22eb42
commit 375aebee88


@ -32,7 +32,7 @@ CACHE_FILE = "candles_cache.json"
# --- Constants ---
NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20 # e.g., 20 technical indicators
FEATURES_PER_CHANNEL = 7 # H, L, O, C, Volume, SMA_close, SMA_volume
FEATURES_PER_CHANNEL = 7 # e.g., H, L, O, C, Volume, SMA_close, SMA_volume
# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
@ -53,7 +53,7 @@ class PositionalEncoding(nn.Module):
class TradingModel(nn.Module):
def __init__(self, num_channels, num_timeframes, hidden_dim=128):
super().__init__()
# Create branch for each channel
# Create one branch per channel (each channel input has FEATURES_PER_CHANNEL features)
self.channel_branches = nn.ModuleList([
nn.Sequential(
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@ -82,14 +82,14 @@ class TradingModel(nn.Module):
nn.Linear(hidden_dim // 2, 1)
)
def forward(self, x, timeframe_ids):
# x: [batch_size, num_channels, FEATURES_PER_CHANNEL]
# x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
batch_size, num_channels, _ = x.shape
channel_outs = []
for i in range(num_channels):
channel_out = self.channel_branches[i](x[:, i, :])
channel_outs.append(channel_out)
stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden]
stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden]
stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden]
# Add embedding for each channel.
tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
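# tf_embeds has shape [num_channels, 1, hidden_dim] after unsqueeze, so the addition below broadcasts the per-channel embedding across the batch dimension of stacked ([channels, batch, hidden]).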
stacked = stacked + tf_embeds
@ -182,7 +182,6 @@ def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=B
maintain_checkpoint_directory(last_dir, max_files=10)
best_models = get_best_models(best_dir)
add_to_best = False
# Update best pool if fewer than 10 or if the new loss is lower than the worst saved loss.
if len(best_models) < 10:
add_to_best = True
else:
@ -206,7 +205,6 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
best_models = get_best_models(best_dir)
if not best_models:
return None
# Choose the checkpoint with the lowest loss.
best_loss, best_file = min(best_models, key=lambda x: x[0])
path = os.path.join(best_dir, best_file)
print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
@ -217,7 +215,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
# --- Live HTML Chart Update ---
def update_live_html(candles, trade_history, epoch):
"""
Generate a chart image with buy/sell markers and a dotted line between open and close,
Generate a chart image with buy/sell markers and a dotted line between open/close positions,
then embed it in a simple HTML page that auto-refreshes every 10 seconds.
"""
from io import BytesIO
@ -271,7 +269,7 @@ def update_live_html(candles, trade_history, epoch):
# --- Chart Drawing Helpers (used by both live preview and HTML update) ---
def update_live_chart(ax, candles, trade_history):
"""
Plot the chart with close price, buy and sell markers, and dotted lines joining buy/sell entry/exit.
Plot the chart with close price, buy/sell markers, and dotted lines joining entry/exit.
"""
ax.clear()
close_prices = [candle["close"] for candle in candles]
@ -300,13 +298,14 @@ def update_live_chart(ax, candles, trade_history):
ax.legend()
ax.grid(True)
# --- Forced Action Helper ---
# --- Forced Action & Optimal Hint Helpers ---
def get_forced_action(env):
"""
Force at least one trade per episode:
- At the very first step, force a BUY (action 2) if no position is open.
- At the penultimate step, if a position is open, force a SELL (action 0).
- Otherwise, default to HOLD (action 1).
When simulating streaming data, we force a trade at strategic moments:
- At the very first step: force BUY.
- At the penultimate step: if a position is open, force SELL.
- Otherwise, default to HOLD.
(The environment will also apply a penalty if the chosen action does not match the optimal hint.)
"""
total = len(env)
if env.current_index == 0:
@ -319,53 +318,98 @@ def get_forced_action(env):
else:
return 1 # HOLD
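# Illustrative trace (assuming a 100-candle window): index 0 -> BUY (2); intermediate steps -> HOLD (1); the penultimate index -> SELL (0) if a position is still open.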
# --- Backtest Environment ---
# --- Backtest Environment with Sliding Window and Hints ---
class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes):
self.candles_dict = candles_dict # dict: timeframe -> list of candles
def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
self.candles_dict = candles_dict # full dictionary of timeframe candles
self.base_tf = base_tf
self.timeframes = timeframes
self.current_index = 0
self.trade_history = []
self.position = None
# Keep the full candle history for the base timeframe; sliding windows are sampled from it.
self.full_candles = candles_dict[base_tf]
# Determine sliding window size:
if window_size is None:
window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles)
self.window_size = window_size
self.hint_penalty = 0.001 # Penalty coefficient (multiplied by open price)
self.reset()
def reset(self):
# Pick a random sliding window from the full dataset.
self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
self.candle_window = self.full_candles[self.start_index:self.start_index+self.window_size]
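# e.g., with 1,000 cached base candles and window_size=100, start_index is drawn uniformly from [0, 900] and the window covers candles [start_index, start_index + 100).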
self.current_index = 0
self.position = None
self.trade_history = []
self.position = None
return self.get_state(self.current_index)
def __len__(self):
return self.window_size
def get_state(self, index):
"""
Build the state features: take the candle at the current index of the sliding
window for the base timeframe, align the other timeframes' candles to its
timestamp, and append zero placeholders for the technical-indicator channels.
"""
state_features = []
base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
base_ts = self.candle_window[index]["timestamp"]
for tf in self.timeframes:
aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
if tf == self.base_tf:
# For the base timeframe, use the candle from the sliding window.
candle = self.candle_window[index]
features = get_features_for_tf([candle], 0)  # pass a one-element candle list
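# Note: with a one-element list, any rolling features (e.g., SMA_close, SMA_volume) presumably collapse to that single candle's own values.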
else:
aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
state_features.append(features)
for _ in range(NUM_INDICATORS):
state_features.append([0.0] * FEATURES_PER_CHANNEL)
return np.array(state_features, dtype=np.float32)
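# Resulting state shape: [NUM_TIMEFRAMES + NUM_INDICATORS, FEATURES_PER_CHANNEL] = [25, 7] with the defaults above.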
def compute_optimal_hint(self, horizon=10, threshold=0.005):
"""
Look ahead up to `horizon` candles within the sliding window and derive an
optimal action hint:
2: BUY if the price is expected to rise by at least `threshold`.
0: SELL if the price is expected to drop by at least `threshold`.
1: HOLD otherwise.
"""
base = self.candle_window
if self.current_index >= len(base) - 1:
return 1 # Hold
current_candle = base[self.current_index]
open_price = current_candle["open"]
future_slice = base[self.current_index+1: min(self.current_index+1+horizon, len(base))]
if not future_slice:
return 1
max_future = max(candle["high"] for candle in future_slice)
min_future = min(candle["low"] for candle in future_slice)
if (max_future - open_price) / open_price >= threshold:
return 2 # BUY
elif (open_price - min_future) / open_price >= threshold:
return 0 # SELL
else:
return 1 # HOLD
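# Worked example: open = 2000.0, max future high = 2012.0, min future low = 1995.0;
# upside (2012 - 2000) / 2000 = 0.006 >= 0.005, so the hint is BUY (2).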
def step(self, action):
base_candles = self.candles_dict[self.base_tf]
# End-of-data: return dummy high/low targets
if self.current_index >= len(base_candles) - 1:
base = self.candle_window
if self.current_index >= len(base) - 1:
current_state = self.get_state(self.current_index)
return current_state, 0.0, None, True, 0.0, 0.0
current_state = self.get_state(self.current_index)
next_index = self.current_index + 1
next_state = self.get_state(next_index)
next_candle = base_candles[next_index]
next_candle = base[next_index]
reward = 0.0
# Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
# Trade logic (0: SELL, 1: HOLD, 2: BUY)
if self.position is None:
if action == 2: # BUY signal: enter at next candle's open.
if action == 2: # BUY: enter at next candle's open.
entry_price = next_candle["open"]
self.position = {"entry_price": entry_price, "entry_index": self.current_index}
else:
if action == 0: # SELL signal: exit at next candle's open.
if action == 0: # SELL: exit at next candle's open.
exit_price = next_candle["open"]
reward = exit_price - self.position["entry_price"]
trade = {
@ -379,27 +423,30 @@ class BacktestEnvironment:
self.position = None
self.current_index = next_index
done = (self.current_index >= len(base_candles) - 1)
done = (self.current_index >= len(base) - 1)
actual_high = next_candle["high"]
actual_low = next_candle["low"]
# Compute the optimal action hint and penalize the reward if the chosen action deviates from it.
optimal_hint = self.compute_optimal_hint(horizon=10, threshold=0.005)
if action != optimal_hint:
reward -= self.hint_penalty * next_candle["open"]
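# e.g., next open = 1800.0 -> mismatch penalty = 0.001 * 1800.0 = 1.8 subtracted from the reward.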
return current_state, reward, next_state, done, actual_high, actual_low
def __len__(self):
return len(self.candles_dict[self.base_tf])
# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler, base_candles):
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
for epoch in range(start_epoch, args.epochs):
state = env.reset()
total_loss = 0
total_loss = 0.0
model.train()
while True:
# Use forced action policy to guarantee at least one trade per episode
# Use forced-action policy for trading (guaranteeing at least one trade per episode)
action = get_forced_action(env)
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device)
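# One embedding id per channel: state.shape[0] == NUM_TIMEFRAMES + NUM_INDICATORS (25 with the defaults).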
pred_high, pred_low = model(state_tensor, timeframe_ids)
# Use the forced action in the environment step.
# Use our forced action in the environment step.
_, reward, next_state, done, actual_high, actual_low = env.step(action)
target_high = torch.FloatTensor([actual_high]).to(device)
target_low = torch.FloatTensor([actual_low]).to(device)
@ -418,8 +465,8 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
epoch_loss = total_loss / len(env)
print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
save_checkpoint(model, optimizer, epoch, total_loss)
# Update the live HTML file with the current epoch chart.
update_live_html(base_candles, env.trade_history, epoch+1)
# Update live HTML chart to display the current sliding window
update_live_html(env.candle_window, env.trade_history, epoch+1)
# --- Live Plotting Functions (For live mode) ---
def live_preview_loop(candles, env):
@ -461,7 +508,8 @@ async def main():
print("No historical candle data available for backtesting.")
return
base_tf = "1m"
env = BacktestEnvironment(candles_dict, base_tf, timeframes)
# Create the environment with a sliding window to simulate streaming data.
env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
start_epoch = 0
checkpoint = None
@ -475,7 +523,6 @@ async def main():
else:
print("Starting training from scratch as requested.")
# Create optimizer and scheduler in main
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
if checkpoint is not None:
@ -485,8 +532,7 @@ async def main():
print("Loaded optimizer state from checkpoint.")
else:
print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
# Pass the base timeframe candles for the live HTML chart update.
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler, candles_dict[base_tf])
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
elif args.mode == 'live':
load_best_checkpoint(model)
@ -494,22 +540,21 @@ async def main():
if not candles_dict:
print("No cached candles available for live preview.")
return
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
preview_thread = threading.Thread(target=live_preview_loop, args=(candles_dict["1m"], env), daemon=True)
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
preview_thread.start()
print("Starting live trading loop. (Using forced action policy for simulation.)")
# Here we use the forced-action policy as in training.
print("Starting live trading loop. (Using forced-action policy for simulation.)")
while True:
action = get_forced_action(env)
state, reward, next_state, done, _, _ = env.step(action)
if done:
print("Reached end of simulated data, resetting environment.")
print("Reached end of simulation window, resetting environment.")
state = env.reset()
await asyncio.sleep(1)
elif args.mode == 'inference':
load_best_checkpoint(model)
print("Running inference...")
# Here you can apply a similar forced-action policy or use a learned policy.
# Apply a similar (or learned) policy as needed.
else:
print("Invalid mode specified.")