even better

commit 967363378b (parent 375aebee88)
@@ -16,9 +16,8 @@ import torch.nn as nn
 import torch.optim as optim
 from datetime import datetime
 import matplotlib.pyplot as plt
-import ccxt.async_support as ccxt
-from torch.nn import TransformerEncoder, TransformerEncoderLayer
 import math
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
 from dotenv import load_dotenv
 load_dotenv()
 
@@ -30,9 +29,9 @@ os.makedirs(BEST_DIR, exist_ok=True)
 CACHE_FILE = "candles_cache.json"
 
 # --- Constants ---
 NUM_TIMEFRAMES = 5   # e.g., ["1m", "5m", "15m", "1h", "1d"]
 NUM_INDICATORS = 20  # e.g., 20 technical indicators
-FEATURES_PER_CHANNEL = 7  # e.g., H, L, O, C, Volume, SMA_close, SMA_volume
+FEATURES_PER_CHANNEL = 7  # e.g., [open, high, low, close, volume, sma_close, sma_volume]
 
 # --- Positional Encoding Module ---
 class PositionalEncoding(nn.Module):
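The three constants above fix the model's input geometry: each state has NUM_TIMEFRAMES + NUM_INDICATORS channels, each holding FEATURES_PER_CHANNEL values. A minimal sketch of the resulting shape (illustrative only, not part of the diff):

import numpy as np

NUM_TIMEFRAMES = 5
NUM_INDICATORS = 20
FEATURES_PER_CHANNEL = 7

# One feature row per timeframe, plus zero-filled rows reserved for the indicator channels.
state = np.zeros((NUM_TIMEFRAMES + NUM_INDICATORS, FEATURES_PER_CHANNEL), dtype=np.float32)
print(state.shape)  # (25, 7) -- the [num_channels, FEATURES_PER_CHANNEL] layout fed to TradingModel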
@@ -53,7 +52,7 @@ class PositionalEncoding(nn.Module):
 class TradingModel(nn.Module):
     def __init__(self, num_channels, num_timeframes, hidden_dim=128):
         super().__init__()
-        # Create one branch per channel (each channel input has FEATURES_PER_CHANNEL features)
+        # One branch per channel
         self.channel_branches = nn.ModuleList([
             nn.Sequential(
                 nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@@ -62,7 +61,6 @@ class TradingModel(nn.Module):
                 nn.Dropout(0.1)
             ) for _ in range(num_channels)
         ])
-        # Embedding for channels 0..num_channels-1.
         self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
         self.pos_encoder = PositionalEncoding(hidden_dim)
         encoder_layers = TransformerEncoderLayer(
@@ -82,15 +80,14 @@ class TradingModel(nn.Module):
             nn.Linear(hidden_dim // 2, 1)
         )
     def forward(self, x, timeframe_ids):
-        # x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
+        # x: [batch_size, num_channels, FEATURES_PER_CHANNEL]
         batch_size, num_channels, _ = x.shape
         channel_outs = []
         for i in range(num_channels):
             channel_out = self.channel_branches[i](x[:, i, :])
             channel_outs.append(channel_out)
-        stacked = torch.stack(channel_outs, dim=1)  # shape: [batch, channels, hidden]
-        stacked = stacked.permute(1, 0, 2)  # shape: [channels, batch, hidden]
-        # Add embedding for each channel.
+        stacked = torch.stack(channel_outs, dim=1)  # [batch, channels, hidden]
+        stacked = stacked.permute(1, 0, 2)  # [channels, batch, hidden]
         tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
         stacked = stacked + tf_embeds
         src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
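The stacking, permutation, and timeframe-embedding steps in forward can be checked on their own. A standalone shape sketch, assuming hidden_dim=128 and 25 channels (values not taken from the diff):

import torch
import torch.nn as nn

batch, channels, hidden = 1, 25, 128
channel_outs = [torch.randn(batch, hidden) for _ in range(channels)]

stacked = torch.stack(channel_outs, dim=1)   # [batch, channels, hidden]
stacked = stacked.permute(1, 0, 2)           # [channels, batch, hidden]

timeframe_embed = nn.Embedding(channels, hidden)
timeframe_ids = torch.arange(channels)
tf_embeds = timeframe_embed(timeframe_ids).unsqueeze(1)  # [channels, 1, hidden]
stacked = stacked + tf_embeds                # broadcasts over the batch dimension

src_mask = torch.triu(torch.ones(channels, channels), diagonal=1).bool()
print(stacked.shape, src_mask.shape)         # torch.Size([25, 1, 128]) torch.Size([25, 25])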
@@ -103,12 +100,12 @@ class TradingModel(nn.Module):
 def compute_sma(candles_list, index, period=10):
     start = max(0, index - period + 1)
     values = [candle["close"] for candle in candles_list[start:index+1]]
-    return sum(values) / len(values) if values else 0.0
+    return sum(values)/len(values) if values else 0.0
 
 def compute_sma_volume(candles_list, index, period=10):
     start = max(0, index - period + 1)
     values = [candle["volume"] for candle in candles_list[start:index+1]]
-    return sum(values) / len(values) if values else 0.0
+    return sum(values)/len(values) if values else 0.0
 
 def get_aligned_candle_with_index(candles_list, target_ts):
     best_idx = 0
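A worked example of the SMA helpers' windowing, using hypothetical candles: at index 2 with period=10 the slice is clipped to the start of the list, so the average covers only the three available closes.

candles = [{"close": 100.0}, {"close": 102.0}, {"close": 104.0}]
index, period = 2, 10
start = max(0, index - period + 1)              # -> 0: the window is clipped at the list start
values = [c["close"] for c in candles[start:index+1]]
print(sum(values)/len(values))                  # 102.0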
@@ -123,7 +120,7 @@ def get_features_for_tf(candles_list, index, period=10):
     candle = candles_list[index]
     f_open = candle["open"]
     f_high = candle["high"]
     f_low = candle["low"]
     f_close = candle["close"]
     f_volume = candle["volume"]
     sma_close = compute_sma(candles_list, index, period)
@@ -154,7 +151,7 @@ def maintain_checkpoint_directory(directory, max_files=10):
     if len(files) > max_files:
         full_paths = [os.path.join(directory, f) for f in files]
         full_paths.sort(key=lambda x: os.path.getmtime(x))
-        for f in full_paths[: len(files) - max_files]:
+        for f in full_paths[:len(files)-max_files]:
             os.remove(f)
 
 def get_best_models(directory):
@@ -162,7 +159,6 @@ def get_best_models(directory):
     for file in os.listdir(directory):
         parts = file.split("_")
         try:
-            # parts[1] is the recorded loss
             loss = float(parts[1])
             best_files.append((loss, file))
         except Exception:
@@ -174,10 +170,10 @@ def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
     last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
     last_path = os.path.join(last_dir, last_filename)
     torch.save({
         "epoch": epoch,
         "loss": loss,
         "model_state_dict": model.state_dict(),
         "optimizer_state_dict": optimizer.state_dict()
     }, last_path)
     maintain_checkpoint_directory(last_dir, max_files=10)
     best_models = get_best_models(best_dir)
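The dict written by torch.save above can be restored symmetrically. A minimal loading sketch, assuming a checkpoint path and an already-constructed model and optimizer (restore_checkpoint is a hypothetical helper, not defined in this file):

import torch

def restore_checkpoint(path, model, optimizer, device):
    # Load the dict produced by save_checkpoint and push the weights back into the model.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"], checkpoint["loss"]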
@@ -215,8 +211,8 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
 # --- Live HTML Chart Update ---
 def update_live_html(candles, trade_history, epoch):
     """
-    Generate a chart image with buy/sell markers and a dotted line between open/close positions,
-    then embed it in a simple HTML page that auto-refreshes every 10 seconds.
+    Generate a chart image with buy/sell markers and dotted lines between entry and exit,
+    then embed it in an auto-refreshing HTML page.
     """
     from io import BytesIO
     import base64
@@ -266,10 +262,10 @@ def update_live_html(candles, trade_history, epoch):
         f.write(html_content)
     print("Updated live_chart.html.")
 
-# --- Chart Drawing Helpers (used by both live preview and HTML update) ---
+# --- Chart Drawing Helpers ---
 def update_live_chart(ax, candles, trade_history):
     """
-    Plot the chart with close price, buy/sell markers, and dotted lines joining entry/exit.
+    Draw the price chart with close prices and mark BUY (green) and SELL (red) actions.
     """
     ax.clear()
     close_prices = [candle["close"] for candle in candles]
@@ -298,39 +294,44 @@ def update_live_chart(ax, candles, trade_history):
     ax.legend()
     ax.grid(True)
 
-# --- Forced Action & Optimal Hint Helpers ---
-def get_forced_action(env):
+# --- Simulation of Trades for Visualization ---
+def simulate_trades(model, env, device, args):
     """
-    When simulating streaming data, we force a trade at strategic moments:
-      - At the very first step: force BUY.
-      - At the penultimate step: if a position is open, force SELL.
-      - Otherwise, default to HOLD.
-    (The environment will also apply a penalty if the chosen action does not match the optimal hint.)
+    Run a complete simulation on the current sliding window using a decision rule based on model outputs.
+    This simulation (which updates env.trade_history) is used only for visualization.
     """
-    total = len(env)
-    if env.current_index == 0:
-        return 2  # BUY
-    elif env.current_index >= total - 2:
-        if env.position is not None:
-            return 0  # SELL
+    env.reset()  # resets the sliding window and index
+    while True:
+        i = env.current_index
+        state = env.get_state(i)
+        current_open = env.candle_window[i]["open"]
+        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
+        timeframe_ids = torch.arange(state.shape[0]).to(device)
+        pred_high, pred_low = model(state_tensor, timeframe_ids)
+        pred_high = pred_high.item()
+        pred_low = pred_low.item()
+        # Decision rule: if the predicted upward move is larger than the downward move and above the threshold, BUY; if the downward move is larger, SELL; else HOLD.
+        if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold:
+            action = 2  # BUY
+        elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold:
+            action = 0  # SELL
         else:
-            return 1  # HOLD
-    else:
-        return 1  # HOLD
+            action = 1  # HOLD
+        _, _, _, done, _, _ = env.step(action)
+        if done:
+            break
 
-# --- Backtest Environment with Sliding Window and Hints ---
+# --- Backtest Environment with Sliding Window ---
 class BacktestEnvironment:
     def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
-        self.candles_dict = candles_dict  # full dictionary of timeframe candles
+        self.candles_dict = candles_dict  # full candles dict for all timeframes
         self.base_tf = base_tf
         self.timeframes = timeframes
-        # Use maximum allowed candles for the base timeframe.
         self.full_candles = candles_dict[base_tf]
-        # Determine sliding window size:
         if window_size is None:
             window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles)
         self.window_size = window_size
-        self.hint_penalty = 0.001  # Penalty coefficient (multiplied by open price)
+        self.hint_penalty = 0.001  # not used in the revised loss below
         self.reset()
 
     def reset(self):
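The BUY/SELL/HOLD comparison used by simulate_trades (and again in the live loop near the end of the file) can be read in isolation. A small sketch of the same rule, where decide_action is a name introduced here only for illustration and, as in the diff, the threshold is in absolute price units:

def decide_action(current_open, pred_high, pred_low, threshold):
    # Compare the predicted upward move against the predicted downward move.
    up_move = pred_high - current_open
    down_move = current_open - pred_low
    if up_move >= down_move and up_move > threshold:
        return 2  # BUY
    if down_move > up_move and down_move > threshold:
        return 0  # SELL
    return 1      # HOLD

print(decide_action(100.0, 100.8, 99.9, threshold=0.5))  # 2 (predicted upside 0.8 beats downside 0.1)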
@@ -346,52 +347,26 @@ class BacktestEnvironment:
         return self.window_size
 
     def get_state(self, index):
-        """
-        Build state features by taking the candle at the current index for the base timeframe
-        (from the sliding window) and aligning candles for other timeframes.
-        Then append zeros for technical indicators.
-        """
         state_features = []
         base_ts = self.candle_window[index]["timestamp"]
         for tf in self.timeframes:
             if tf == self.base_tf:
-                # For base timeframe, use the sliding window candle.
                 candle = self.candle_window[index]
-                features = get_features_for_tf([candle], 0)  # List of one element
+                features = get_features_for_tf([candle], 0)
             else:
                 aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
                 features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
             state_features.append(features)
         for _ in range(NUM_INDICATORS):
-            state_features.append([0.0] * FEATURES_PER_CHANNEL)
+            state_features.append([0.0]*FEATURES_PER_CHANNEL)
         return np.array(state_features, dtype=np.float32)
-
-    def compute_optimal_hint(self, horizon=10, threshold=0.005):
-        """
-        Using a lookahead window from the sliding window (future candles)
-        determine an optimal action hint:
-          2: BUY if price is expected to rise at least by threshold.
-          0: SELL if expected to drop by threshold.
-          1: HOLD otherwise.
-        """
-        base = self.candle_window
-        if self.current_index >= len(base) - 1:
-            return 1  # HOLD
-        current_candle = base[self.current_index]
-        open_price = current_candle["open"]
-        future_slice = base[self.current_index+1: min(self.current_index+1+horizon, len(base))]
-        if not future_slice:
-            return 1
-        max_future = max(candle["high"] for candle in future_slice)
-        min_future = min(candle["low"] for candle in future_slice)
-        if (max_future - open_price) / open_price >= threshold:
-            return 2  # BUY
-        elif (open_price - min_future) / open_price >= threshold:
-            return 0  # SELL
-        else:
-            return 1  # HOLD
 
     def step(self, action):
+        """
+        Discrete simulation step.
+          - Action: 0 (SELL), 1 (HOLD), 2 (BUY).
+          - Trades are recorded when a BUY is followed by a SELL.
+        """
         base = self.candle_window
         if self.current_index >= len(base) - 1:
             current_state = self.get_state(self.current_index)
@@ -403,13 +378,12 @@ class BacktestEnvironment:
         next_candle = base[next_index]
         reward = 0.0
 
-        # Trade logic (0: SELL, 1: HOLD, 2: BUY)
+        # Simple trading logic (only one position allowed at a time)
         if self.position is None:
-            if action == 2:  # BUY: enter at next candle's open.
-                entry_price = next_candle["open"]
-                self.position = {"entry_price": entry_price, "entry_index": self.current_index}
+            if action == 2:  # BUY signal: enter at next open.
+                self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
         else:
-            if action == 0:  # SELL: exit at next candle's open.
+            if action == 0:  # SELL signal: exit at next open.
                 exit_price = next_candle["open"]
                 reward = exit_price - self.position["entry_price"]
                 trade = {
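Since a position stores its entry open and the reward on a SELL is simply the next open minus that entry, a hypothetical round trip looks like this (made-up prices):

position = {"entry_price": 100.0, "entry_index": 10}   # set when a BUY is taken
exit_price = 103.5                                     # next candle's open when SELL is chosen
reward = exit_price - position["entry_price"]
print(reward)  # 3.5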
@@ -426,49 +400,49 @@ class BacktestEnvironment:
         done = (self.current_index >= len(base) - 1)
         actual_high = next_candle["high"]
         actual_low = next_candle["low"]
-
-        # Compute optimal action hint and apply a penalty if action deviates.
-        optimal_hint = self.compute_optimal_hint(horizon=10, threshold=0.005)
-        if action != optimal_hint:
-            reward -= self.hint_penalty * next_candle["open"]
-
         return current_state, reward, next_state, done, actual_high, actual_low
 
 # --- Enhanced Training Loop ---
 def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
+    # Weighting factor for the trade surrogate loss.
+    lambda_trade = 1.0
     for epoch in range(start_epoch, args.epochs):
-        state = env.reset()
-        total_loss = 0.0
-        model.train()
-        while True:
-            # Use forced-action policy for trading (guaranteeing at least one trade per episode)
-            action = get_forced_action(env)
+        # Reset the sliding window for each epoch.
+        env.reset()
+        loss_accum = 0.0
+        steps = len(env) - 1  # we use pairs of consecutive candles
+        for i in range(steps):
+            state = env.get_state(i)
+            current_open = env.candle_window[i]["open"]
+            # Next candle's actual values serve as targets.
+            actual_high = env.candle_window[i+1]["high"]
+            actual_low = env.candle_window[i+1]["low"]
             state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
             timeframe_ids = torch.arange(state.shape[0]).to(device)
             pred_high, pred_low = model(state_tensor, timeframe_ids)
-            # Use our forced action in the environment step.
-            _, reward, next_state, done, actual_high, actual_low = env.step(action)
-            target_high = torch.FloatTensor([actual_high]).to(device)
-            target_low = torch.FloatTensor([actual_low]).to(device)
-            high_loss = torch.abs(pred_high - target_high) * 2
-            low_loss = torch.abs(pred_low - target_low) * 2
-            loss = (high_loss + low_loss).mean()
+            # Compute prediction loss (L1)
+            L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
+                     torch.abs(pred_low - torch.tensor(actual_low, device=device))
+            # Compute surrogate profit (differentiable estimate)
+            profit_buy = pred_high - current_open    # potential long gain
+            profit_sell = current_open - pred_low    # potential short gain
+            # Reward a larger potential move by subtracting it from the loss.
+            L_trade = -torch.max(profit_buy, profit_sell)
+            loss = L_pred + lambda_trade * L_trade
             optimizer.zero_grad()
             loss.backward()
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
-            total_loss += loss.item()
-            if done:
-                break
-            state = next_state
+            loss_accum += loss.item()
         scheduler.step()
-        epoch_loss = total_loss / len(env)
+        epoch_loss = loss_accum / steps
         print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
-        save_checkpoint(model, optimizer, epoch, total_loss)
-        # Update live HTML chart to display the current sliding window
+        save_checkpoint(model, optimizer, epoch, loss_accum)
+        # Update the trade simulation (for visualization) using the current model on the same window.
+        simulate_trades(model, env, device, args)
         update_live_html(env.candle_window, env.trade_history, epoch+1)
 
-# --- Live Plotting Functions (For live mode) ---
+# --- Live Plotting Functions (For Live Mode) ---
 def live_preview_loop(candles, env):
     plt.ion()
     fig, ax = plt.subplots(figsize=(12, 6))
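Per training step the objective is L = |pred_high - actual_high| + |pred_low - actual_low| - lambda_trade * max(pred_high - open, open - pred_low). A standalone sketch with made-up numbers (not taken from the diff):

import torch

pred_high = torch.tensor(101.2, requires_grad=True)
pred_low = torch.tensor(99.1, requires_grad=True)
actual_high, actual_low, current_open, lambda_trade = 101.0, 99.4, 100.0, 1.0

L_pred = torch.abs(pred_high - actual_high) + torch.abs(pred_low - actual_low)
L_trade = -torch.max(pred_high - current_open, current_open - pred_low)  # negative of the larger predicted move
loss = L_pred + lambda_trade * L_trade
loss.backward()     # gradients flow through both the prediction error and the surrogate profit term
print(loss.item())  # 0.2 + 0.3 - max(1.2, 0.9) = -0.7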
@@ -480,21 +454,17 @@ def live_preview_loop(candles, env):
 # --- Argument Parsing ---
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--mode', choices=['train','live','inference'], default='train')
+    parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
     parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--lr', type=float, default=3e-4)
-    parser.add_argument('--threshold', type=float, default=0.005)
-    # If set, training starts from scratch (ignoring saved checkpoints)
-    parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch.')
+    parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move required to trigger a trade.")
+    parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for the trade surrogate loss.")
+    parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
     return parser.parse_args()
 
-def random_action():
-    return random.randint(0, 2)
-
 # --- Main Function ---
 async def main():
     args = parse_args()
-    # Use GPU if available; else CPU
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print("Using device:", device)
     timeframes = ["1m", "5m", "15m", "1h", "1d"]
@@ -508,9 +478,8 @@ async def main():
         print("No historical candle data available for backtesting.")
         return
     base_tf = "1m"
-    # Create the environment with a sliding window (simulate streaming data)
+    # Use a sliding window of up to 100 candles (if available)
     env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
-
     start_epoch = 0
     checkpoint = None
     if not args.start_fresh:
@@ -522,7 +491,6 @@ async def main():
             print("No checkpoint found. Starting training from scratch.")
     else:
         print("Starting training from scratch as requested.")
-
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
     if checkpoint is not None:
@@ -543,18 +511,31 @@ async def main():
         env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
         preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
         preview_thread.start()
-        print("Starting live trading loop. (Using forced-action policy for simulation.)")
+        print("Starting live trading loop. (Using model-based decision rule.)")
         while True:
-            action = get_forced_action(env)
-            state, reward, next_state, done, _, _ = env.step(action)
+            # In live mode, we use the same decision rule as in simulate_trades.
+            state = env.get_state(env.current_index)
+            current_open = env.candle_window[env.current_index]["open"]
+            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
+            timeframe_ids = torch.arange(state.shape[0]).to(device)
+            pred_high, pred_low = model(state_tensor, timeframe_ids)
+            pred_high = pred_high.item()
+            pred_low = pred_low.item()
+            if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold:
+                action = 2
+            elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold:
+                action = 0
+            else:
+                action = 1
+            _, _, _, done, _, _ = env.step(action)
             if done:
-                print("Reached end of simulation window, resetting environment.")
-                state = env.reset()
+                print("Reached end of simulation window; resetting environment.")
+                env.reset()
             await asyncio.sleep(1)
     elif args.mode == 'inference':
         load_best_checkpoint(model)
         print("Running inference...")
-        # Apply a similar (or learned) policy as needed.
+        # Inference logic can use a similar decision rule as in live mode.
     else:
         print("Invalid mode specified.")
 
File diff suppressed because one or more lines are too long