better training algo
parent 10ff22eb42
commit 375aebee88
@@ -32,7 +32,7 @@ CACHE_FILE = "candles_cache.json"
 # --- Constants ---
 NUM_TIMEFRAMES = 5        # e.g., ["1m", "5m", "15m", "1h", "1d"]
 NUM_INDICATORS = 20       # e.g., 20 technical indicators
-FEATURES_PER_CHANNEL = 7  # H, L, O, C, Volume, SMA_close, SMA_volume
+FEATURES_PER_CHANNEL = 7  # e.g., H, L, O, C, Volume, SMA_close, SMA_volume
 
 # --- Positional Encoding Module ---
 class PositionalEncoding(nn.Module):
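
With these constants the model consumes NUM_TIMEFRAMES + NUM_INDICATORS = 25 input channels, each carrying FEATURES_PER_CHANNEL = 7 values, so one state is a 25 x 7 array (see get_state further down). A minimal sanity check of that shape, as a standalone sketch:

import numpy as np

NUM_TIMEFRAMES = 5
NUM_INDICATORS = 20
FEATURES_PER_CHANNEL = 7

# One row per channel (5 timeframe rows + 20 indicator rows), one column per feature.
state = np.zeros((NUM_TIMEFRAMES + NUM_INDICATORS, FEATURES_PER_CHANNEL), dtype=np.float32)
print(state.shape)  # (25, 7)
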
@@ -53,7 +53,7 @@ class PositionalEncoding(nn.Module):
 class TradingModel(nn.Module):
     def __init__(self, num_channels, num_timeframes, hidden_dim=128):
         super().__init__()
-        # Create branch for each channel
+        # Create one branch per channel (each channel input has FEATURES_PER_CHANNEL features)
         self.channel_branches = nn.ModuleList([
             nn.Sequential(
                 nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@@ -82,14 +82,14 @@ class TradingModel(nn.Module):
             nn.Linear(hidden_dim // 2, 1)
         )
     def forward(self, x, timeframe_ids):
-        # x: [batch_size, num_channels, FEATURES_PER_CHANNEL]
+        # x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
         batch_size, num_channels, _ = x.shape
         channel_outs = []
         for i in range(num_channels):
             channel_out = self.channel_branches[i](x[:, i, :])
             channel_outs.append(channel_out)
         stacked = torch.stack(channel_outs, dim=1)  # shape: [batch, channels, hidden]
-        stacked = stacked.permute(1, 0, 2)  # shape: [channels, batch, hidden]
+        stacked = stacked.permute(1, 0, 2)  # shape: [channels, batch, hidden]
         # Add embedding for each channel.
         tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
         stacked = stacked + tf_embeds
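
For reference, the tensor shapes through this forward pass, as a self-contained sketch. The branch and embedding definitions below are simplified stand-ins for the real self.channel_branches / self.timeframe_embed (in particular, sizing the embedding by the channel count is an assumption based on how timeframe_ids is built in the training loop), and the prediction heads are omitted:

import torch
import torch.nn as nn

batch_size, num_channels, hidden_dim = 4, 25, 128
branches = nn.ModuleList([nn.Linear(7, hidden_dim) for _ in range(num_channels)])
timeframe_embed = nn.Embedding(num_channels, hidden_dim)

x = torch.randn(batch_size, num_channels, 7)                   # [batch, channels, features]
outs = [branches[i](x[:, i, :]) for i in range(num_channels)]  # each [batch, hidden]
stacked = torch.stack(outs, dim=1)                             # [batch, channels, hidden]
stacked = stacked.permute(1, 0, 2)                             # [channels, batch, hidden]
tf_embeds = timeframe_embed(torch.arange(num_channels)).unsqueeze(1)  # [channels, 1, hidden]
stacked = stacked + tf_embeds                                  # broadcasts over the batch dim
print(stacked.shape)                                           # torch.Size([25, 4, 128])
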
@@ -182,7 +182,6 @@ def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=B
     maintain_checkpoint_directory(last_dir, max_files=10)
     best_models = get_best_models(best_dir)
     add_to_best = False
-    # Update best pool if fewer than 10 or if the new loss is lower than the worst saved loss.
     if len(best_models) < 10:
         add_to_best = True
     else:
@@ -206,7 +205,6 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
     best_models = get_best_models(best_dir)
     if not best_models:
         return None
-    # Choose the checkpoint with the lowest loss.
    best_loss, best_file = min(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
@@ -217,7 +215,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
 # --- Live HTML Chart Update ---
 def update_live_html(candles, trade_history, epoch):
     """
-    Generate a chart image with buy/sell markers and a dotted line between open and close,
+    Generate a chart image with buy/sell markers and a dotted line between open/close positions,
     then embed it in a simple HTML page that auto-refreshes every 10 seconds.
     """
     from io import BytesIO
@@ -271,7 +269,7 @@ def update_live_html(candles, trade_history, epoch):
 # --- Chart Drawing Helpers (used by both live preview and HTML update) ---
 def update_live_chart(ax, candles, trade_history):
     """
-    Plot the chart with close price, buy and sell markers, and dotted lines joining buy/sell entry/exit.
+    Plot the chart with close price, buy/sell markers, and dotted lines joining entry/exit.
     """
     ax.clear()
     close_prices = [candle["close"] for candle in candles]
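
A minimal, self-contained sketch of the kind of plot this helper produces. The marker styles and the trade-record keys used here are illustrative assumptions; the real trade dict is built inside BacktestEnvironment.step():

import matplotlib.pyplot as plt

candles = [{"close": c} for c in (100, 101, 103, 102, 105, 104)]
# Hypothetical trade record for illustration only.
trades = [{"entry_index": 1, "entry_price": 101, "exit_index": 4, "exit_price": 105}]

fig, ax = plt.subplots()
ax.plot([c["close"] for c in candles], label="close")
for t in trades:
    ax.plot(t["entry_index"], t["entry_price"], "g^", label="buy")
    ax.plot(t["exit_index"], t["exit_price"], "rv", label="sell")
    # Dotted line joining entry and exit, as in the docstring above.
    ax.plot([t["entry_index"], t["exit_index"]],
            [t["entry_price"], t["exit_price"]], "k:", linewidth=1)
ax.legend()
ax.grid(True)
plt.show()
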
@@ -300,13 +298,14 @@ def update_live_chart(ax, candles, trade_history):
     ax.legend()
     ax.grid(True)
 
-# --- Forced Action Helper ---
+# --- Forced Action & Optimal Hint Helpers ---
 def get_forced_action(env):
     """
-    Force at least one trade per episode:
-      - At the very first step, force a BUY (action 2) if no position is open.
-      - At the penultimate step, if a position is open, force a SELL (action 0).
-      - Otherwise, default to HOLD (action 1).
+    When simulating streaming data, we force a trade at strategic moments:
+      - At the very first step: force BUY.
+      - At the penultimate step: if a position is open, force SELL.
+      - Otherwise, default to HOLD.
+    (The environment will also apply a penalty if the chosen action does not match the optimal hint.)
     """
     total = len(env)
     if env.current_index == 0:
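
The body of this helper is only partially visible in the hunk; a minimal sketch consistent with the docstring and the visible lines (the exact penultimate-step condition is an assumption):

def get_forced_action(env):
    total = len(env)
    if env.current_index == 0:
        return 2  # BUY at the very first step of the window
    if env.current_index == total - 2 and env.position is not None:
        return 0  # SELL at the penultimate step if still holding
    return 1      # HOLD otherwise
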
@@ -319,53 +318,98 @@ def get_forced_action(env):
     else:
         return 1  # HOLD
 
-# --- Backtest Environment ---
+# --- Backtest Environment with Sliding Window and Hints ---
 class BacktestEnvironment:
-    def __init__(self, candles_dict, base_tf, timeframes):
-        self.candles_dict = candles_dict  # dict: timeframe -> list of candles
+    def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
+        self.candles_dict = candles_dict  # full dictionary of timeframe candles
         self.base_tf = base_tf
         self.timeframes = timeframes
-        self.current_index = 0
-        self.trade_history = []
-        self.position = None
+        # Use maximum allowed candles for the base timeframe.
+        self.full_candles = candles_dict[base_tf]
+        # Determine sliding window size:
+        if window_size is None:
+            window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles)
+        self.window_size = window_size
+        self.hint_penalty = 0.001  # Penalty coefficient (multiplied by open price)
+        self.reset()
 
     def reset(self):
+        # Pick a random sliding window from the full dataset.
+        self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
+        self.candle_window = self.full_candles[self.start_index:self.start_index+self.window_size]
         self.current_index = 0
+        self.position = None
         self.trade_history = []
         self.position = None
         return self.get_state(self.current_index)
 
+    def __len__(self):
+        return self.window_size
+
     def get_state(self, index):
+        """
+        Build state features by taking the candle at the current index for the base timeframe
+        (from the sliding window) and aligning candles for other timeframes.
+        Then append zeros for technical indicators.
+        """
         state_features = []
-        base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
+        base_ts = self.candle_window[index]["timestamp"]
         for tf in self.timeframes:
-            aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
-            features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
+            if tf == self.base_tf:
+                # For base timeframe, use the sliding window candle.
+                candle = self.candle_window[index]
+                features = get_features_for_tf([candle], 0)  # List of one element
+            else:
+                aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
+                features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
             state_features.append(features)
         for _ in range(NUM_INDICATORS):
             state_features.append([0.0] * FEATURES_PER_CHANNEL)
         return np.array(state_features, dtype=np.float32)
 
 
+    def compute_optimal_hint(self, horizon=10, threshold=0.005):
+        """
+        Using a lookahead window from the sliding window (future candles)
+        determine an optimal action hint:
+          2: BUY if price is expected to rise at least by threshold.
+          0: SELL if expected to drop by threshold.
+          1: HOLD otherwise.
+        """
+        base = self.candle_window
+        if self.current_index >= len(base) - 1:
+            return 1  # Hold
+        current_candle = base[self.current_index]
+        open_price = current_candle["open"]
+        future_slice = base[self.current_index+1: min(self.current_index+1+horizon, len(base))]
+        if not future_slice:
+            return 1
+        max_future = max(candle["high"] for candle in future_slice)
+        min_future = min(candle["low"] for candle in future_slice)
+        if (max_future - open_price) / open_price >= threshold:
+            return 2  # BUY
+        elif (open_price - min_future) / open_price >= threshold:
+            return 0  # SELL
+        else:
+            return 1  # HOLD
+
     def step(self, action):
-        base_candles = self.candles_dict[self.base_tf]
-        # End-of-data: return dummy high/low targets
-        if self.current_index >= len(base_candles) - 1:
+        base = self.candle_window
+        if self.current_index >= len(base) - 1:
             current_state = self.get_state(self.current_index)
             return current_state, 0.0, None, True, 0.0, 0.0
 
         current_state = self.get_state(self.current_index)
         next_index = self.current_index + 1
         next_state = self.get_state(next_index)
-        next_candle = base_candles[next_index]
+        next_candle = base[next_index]
         reward = 0.0
 
-        # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
+        # Trade logic (0: SELL, 1: HOLD, 2: BUY)
         if self.position is None:
-            if action == 2:  # BUY signal: enter at next candle's open.
+            if action == 2:  # BUY: enter at next candle's open.
                 entry_price = next_candle["open"]
                 self.position = {"entry_price": entry_price, "entry_index": self.current_index}
         else:
-            if action == 0:  # SELL signal: exit at next candle's open.
+            if action == 0:  # SELL: exit at next candle's open.
                 exit_price = next_candle["open"]
                 reward = exit_price - self.position["entry_price"]
                 trade = {
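
To make the hint thresholds concrete, here is a small standalone check of the same lookahead rule on toy candles (the candle values are invented for illustration):

candles = [
    {"open": 100.0, "high": 100.2, "low": 99.8},
    {"open": 100.1, "high": 100.7, "low": 99.9},   # high is +0.7% above the first open
    {"open": 100.5, "high": 100.6, "low": 100.2},
]

open_price = candles[0]["open"]
future = candles[1:]                           # the lookahead window
max_future = max(c["high"] for c in future)    # 100.7
min_future = min(c["low"] for c in future)     # 99.9
threshold = 0.005

if (max_future - open_price) / open_price >= threshold:
    hint = 2  # BUY: the expected +0.7% rise clears the 0.5% threshold
elif (open_price - min_future) / open_price >= threshold:
    hint = 0  # SELL
else:
    hint = 1  # HOLD
print(hint)  # 2
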
@@ -379,27 +423,30 @@ class BacktestEnvironment:
                 self.position = None
 
         self.current_index = next_index
-        done = (self.current_index >= len(base_candles) - 1)
+        done = (self.current_index >= len(base) - 1)
         actual_high = next_candle["high"]
         actual_low = next_candle["low"]
+
+        # Compute optimal action hint and apply a penalty if action deviates.
+        optimal_hint = self.compute_optimal_hint(horizon=10, threshold=0.005)
+        if action != optimal_hint:
+            reward -= self.hint_penalty * next_candle["open"]
+
         return current_state, reward, next_state, done, actual_high, actual_low
 
-    def __len__(self):
-        return len(self.candles_dict[self.base_tf])
-
 # --- Enhanced Training Loop ---
-def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler, base_candles):
+def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
     for epoch in range(start_epoch, args.epochs):
         state = env.reset()
-        total_loss = 0
+        total_loss = 0.0
         model.train()
         while True:
-            # Use forced action policy to guarantee at least one trade per episode
+            # Use forced-action policy for trading (guaranteeing at least one trade per episode)
             action = get_forced_action(env)
             state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
             timeframe_ids = torch.arange(state.shape[0]).to(device)
             pred_high, pred_low = model(state_tensor, timeframe_ids)
-            # Use the forced action in the environment step.
+            # Use our forced action in the environment step.
             _, reward, next_state, done, actual_high, actual_low = env.step(action)
             target_high = torch.FloatTensor([actual_high]).to(device)
             target_low = torch.FloatTensor([actual_low]).to(device)
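
As a concrete illustration of the penalty term added above (the prices and raw reward are invented): with hint_penalty = 0.001, disagreeing with the hint on a candle that opens at 2500.0 subtracts 2.5 from that step's reward, which can easily outweigh a small trade profit.

hint_penalty = 0.001
next_open = 2500.0   # hypothetical next candle open
reward = 1.8         # hypothetical raw reward from a closed trade
reward -= hint_penalty * next_open
print(reward)        # 1.8 - 2.5 = -0.7
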
@@ -418,8 +465,8 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
         epoch_loss = total_loss / len(env)
         print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
         save_checkpoint(model, optimizer, epoch, total_loss)
-        # Update the live HTML file with the current epoch chart.
-        update_live_html(base_candles, env.trade_history, epoch+1)
+        # Update live HTML chart to display the current sliding window
+        update_live_html(env.candle_window, env.trade_history, epoch+1)
 
 # --- Live Plotting Functions (For live mode) ---
 def live_preview_loop(candles, env):
@@ -461,7 +508,8 @@ async def main():
            print("No historical candle data available for backtesting.")
            return
        base_tf = "1m"
-        env = BacktestEnvironment(candles_dict, base_tf, timeframes)
+        # Create the environment with a sliding window (simulate streaming data)
+        env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
 
        start_epoch = 0
        checkpoint = None
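
A minimal offline usage sketch of the sliding-window environment driven by the forced-action policy. candles_dict is assumed to already hold lists of candle dicts keyed by timeframe (loading it from the cache is outside this hunk):

timeframes = ["1m", "5m", "15m", "1h", "1d"]
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)

state = env.reset()          # state would normally feed the model
episode_reward = 0.0
while True:
    action = get_forced_action(env)   # 2 = BUY, 1 = HOLD, 0 = SELL
    _, reward, next_state, done, actual_high, actual_low = env.step(action)
    episode_reward += reward
    if done:
        break
    state = next_state
print(f"episode reward: {episode_reward:.2f}, trades: {len(env.trade_history)}")
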
@@ -475,7 +523,6 @@ async def main():
        else:
            print("Starting training from scratch as requested.")
 
-        # Create optimizer and scheduler in main
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
        if checkpoint is not None:
@@ -485,8 +532,7 @@ async def main():
                print("Loaded optimizer state from checkpoint.")
            else:
                print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
-        # Pass the base timeframe candles for the live HTML chart update.
-        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler, candles_dict[base_tf])
+        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
 
    elif args.mode == 'live':
        load_best_checkpoint(model)
@@ -494,22 +540,21 @@ async def main():
        if not candles_dict:
            print("No cached candles available for live preview.")
            return
-        env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
-        preview_thread = threading.Thread(target=live_preview_loop, args=(candles_dict["1m"], env), daemon=True)
+        env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
+        preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
        preview_thread.start()
-        print("Starting live trading loop. (Using forced action policy for simulation.)")
-        # Here we use the forced-action policy as in training.
+        print("Starting live trading loop. (Using forced-action policy for simulation.)")
        while True:
            action = get_forced_action(env)
            state, reward, next_state, done, _, _ = env.step(action)
            if done:
-                print("Reached end of simulated data, resetting environment.")
+                print("Reached end of simulation window, resetting environment.")
                state = env.reset()
            await asyncio.sleep(1)
    elif args.mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
-        # Here you can apply a similar forced-action policy or use a learned policy.
+        # Apply a similar (or learned) policy as needed.
    else:
        print("Invalid mode specified.")
 