profitable bot 1
This commit is contained in:
parent
dc5df52292
commit
fc2f834b32
@ -26,12 +26,12 @@ load_dotenv()
|
||||
def convert_timestamp(ts):
|
||||
"""
|
||||
Safely converts a timestamp to a datetime object.
|
||||
If the timestamp is abnormally high (i.e. in milliseconds),
|
||||
If the timestamp is abnormally high (e.g. in milliseconds),
|
||||
it is divided by 1000.
|
||||
"""
|
||||
ts = float(ts)
|
||||
if ts > 1e10: # Likely in milliseconds
|
||||
ts = ts / 1000.0
|
||||
ts /= 1000.0
|
||||
return datetime.fromtimestamp(ts)
|
||||
|
||||
# --- Directories ---
|
||||
@ -42,10 +42,10 @@ os.makedirs(BEST_DIR, exist_ok=True)
|
||||
CACHE_FILE = "candles_cache.json"
|
||||
|
||||
# --- Constants ---
|
||||
NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"]
|
||||
NUM_TIMEFRAMES = 6 # e.g., ["1s", "1m", "5m", "15m", "1h", "1d"]
|
||||
NUM_INDICATORS = 20 # e.g., 20 technical indicators
|
||||
FEATURES_PER_CHANNEL = 7 # Each channel input will have 7 features.
|
||||
ORDER_CHANNELS = 1 # One extra channel for order info.
|
||||
FEATURES_PER_CHANNEL = 7 # Each channel has 7 features.
|
||||
ORDER_CHANNELS = 1 # One extra channel for order information.
|
||||
|
||||
# --- Positional Encoding Module ---
|
||||
class PositionalEncoding(nn.Module):
|
||||
@ -66,6 +66,7 @@ class PositionalEncoding(nn.Module):
|
||||
class TradingModel(nn.Module):
|
||||
def __init__(self, num_channels, num_timeframes, hidden_dim=128):
|
||||
super().__init__()
|
||||
# One branch per channel.
|
||||
self.channel_branches = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
|
||||
@ -78,7 +79,7 @@ class TradingModel(nn.Module):
|
||||
self.pos_encoder = PositionalEncoding(hidden_dim)
|
||||
encoder_layers = TransformerEncoderLayer(
|
||||
d_model=hidden_dim, nhead=4, dim_feedforward=512,
|
||||
dropout=0.1, activation='gelu', batch_first=True # Use batch_first to avoid nested tensor warning.
|
||||
dropout=0.1, activation='gelu', batch_first=True # avoid nested tensor warning
|
||||
)
|
||||
self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
|
||||
self.attn_pool = nn.Linear(hidden_dim, 1)
|
||||
@ -100,8 +101,8 @@ class TradingModel(nn.Module):
|
||||
channel_out = self.channel_branches[i](x[:, i, :])
|
||||
channel_outs.append(channel_out)
|
||||
stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden]
|
||||
tf_embeds = self.timeframe_embed(timeframe_ids) # shape: [num_channels, hidden]
|
||||
stacked = stacked + tf_embeds.unsqueeze(0) # broadcast along batch dimension.
|
||||
tf_embeds = self.timeframe_embed(timeframe_ids) # [num_channels, hidden]
|
||||
stacked = stacked + tf_embeds.unsqueeze(0) # add embeddings (broadcast along batch)
|
||||
transformer_out = self.transformer(stacked)
|
||||
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
|
||||
aggregated = (transformer_out * attn_weights).sum(dim=1)
|
||||
@ -129,13 +130,13 @@ def get_aligned_candle_with_index(candles_list, target_ts):
|
||||
|
||||
def get_features_for_tf(candles_list, index, period=10):
|
||||
candle = candles_list[index]
|
||||
f_open = candle["open"]
|
||||
f_high = candle["high"]
|
||||
f_low = candle["low"]
|
||||
f_close = candle["close"]
|
||||
f_volume = candle["volume"]
|
||||
f_open = candle["open"]
|
||||
f_high = candle["high"]
|
||||
f_low = candle["low"]
|
||||
f_close = candle["close"]
|
||||
f_volume = candle["volume"]
|
||||
sma_close = compute_sma(candles_list, index, period)
|
||||
sma_volume = compute_sma_volume(candles_list, index, period)
|
||||
sma_volume= compute_sma_volume(candles_list, index, period)
|
||||
return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume]
|
||||
|
||||
# --- Caching & Checkpoint Functions ---
|
||||
@ -230,10 +231,10 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
|
||||
# --- Function for Manual Trade Override ---
|
||||
def manual_trade(env):
|
||||
"""
|
||||
When no sufficient action is taken by the model, manually decide the trade.
|
||||
Find the maximum high and minimum low in the remaining window.
|
||||
If maximum occurs before minimum, we short; otherwise we long.
|
||||
The trade is closed at the candle where the chosen extreme occurs.
|
||||
When no sufficient action is taken by the model, use a fallback:
|
||||
Scan the remaining window for the global maximum and minimum.
|
||||
If the maximum occurs before the minimum, simulate a short trade;
|
||||
otherwise simulate a long trade. Closes the trade at the candle where the chosen extreme occurs.
|
||||
"""
|
||||
current_index = env.current_index
|
||||
if current_index >= len(env.candle_window) - 1:
|
||||
@ -252,7 +253,6 @@ def manual_trade(env):
|
||||
if low_j < min_val:
|
||||
min_val = low_j
|
||||
i_min = j
|
||||
# If maximum occurs before minimum, we interpret that as short (price will drop).
|
||||
if i_max < i_min:
|
||||
entry_price = env.candle_window[current_index]["open"]
|
||||
exit_price = env.candle_window[i_min]["open"]
|
||||
@ -278,16 +278,97 @@ def manual_trade(env):
|
||||
env.trade_history.append(trade)
|
||||
env.current_index = trade["exit_index"]
|
||||
|
||||
# --- Live HTML Chart Update ---
|
||||
# --- Simulation for 1s Data Using Local Extrema ---
|
||||
def simulate_trades_1s(env):
|
||||
"""
|
||||
When the main timeframe is 1s, scan the entire remaining window to detect local extrema.
|
||||
If at least two extrema are found, pair consecutive extrema as trades.
|
||||
If none (or too few) are found, fallback to manual_trade.
|
||||
"""
|
||||
n = len(env.candle_window)
|
||||
extrema = []
|
||||
for i in range(env.current_index, n):
|
||||
# Add first and last points.
|
||||
if i == env.current_index or i == n-1:
|
||||
extrema.append(i)
|
||||
else:
|
||||
prev = env.candle_window[i-1]["close"]
|
||||
curr = env.candle_window[i]["close"]
|
||||
nex = env.candle_window[i+1]["close"]
|
||||
# A valley or a peak.
|
||||
if curr < prev and curr < nex:
|
||||
extrema.append(i)
|
||||
elif curr > prev and curr > nex:
|
||||
extrema.append(i)
|
||||
if len(extrema) < 2:
|
||||
manual_trade(env)
|
||||
return
|
||||
# Process consecutive extrema into trades.
|
||||
for j in range(len(extrema)-1):
|
||||
entry_idx = extrema[j]
|
||||
exit_idx = extrema[j+1]
|
||||
entry_price = env.candle_window[entry_idx]["open"]
|
||||
exit_price = env.candle_window[exit_idx]["open"]
|
||||
# If the entry candle’s close is lower than exit candle’s close, this is a long trade.
|
||||
if env.candle_window[entry_idx]["close"] < env.candle_window[exit_idx]["close"]:
|
||||
reward = exit_price - entry_price
|
||||
else:
|
||||
reward = entry_price - exit_price
|
||||
trade = {
|
||||
"entry_index": entry_idx,
|
||||
"entry_price": entry_price,
|
||||
"exit_index": exit_idx,
|
||||
"exit_price": exit_price,
|
||||
"pnl": reward
|
||||
}
|
||||
env.trade_history.append(trade)
|
||||
env.current_index = n - 1
|
||||
|
||||
# --- General Simulation of Trades ---
|
||||
def simulate_trades(model, env, device, args):
|
||||
"""
|
||||
Simulate trades over the current sliding window.
|
||||
If the main timeframe is 1s, use local extrema detection.
|
||||
Otherwise, check if the model's predicted potentials exceed the threshold.
|
||||
- If so, execute the model decision.
|
||||
- Otherwise, call the manual_trade override.
|
||||
"""
|
||||
if args.main_tf == "1s":
|
||||
simulate_trades_1s(env)
|
||||
return
|
||||
env.reset()
|
||||
while True:
|
||||
i = env.current_index
|
||||
if i >= len(env.candle_window) - 1:
|
||||
break
|
||||
state = env.get_state(i)
|
||||
current_open = env.candle_window[i]["open"]
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
||||
timeframe_ids = torch.arange(state.shape[0]).to(device)
|
||||
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
||||
pred_high = pred_high.item()
|
||||
pred_low = pred_low.item()
|
||||
if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
|
||||
if (pred_high - current_open) >= (current_open - pred_low):
|
||||
action = 2 # BUY
|
||||
else:
|
||||
action = 0 # SELL
|
||||
_, _, _, done, _, _ = env.step(action)
|
||||
else:
|
||||
manual_trade(env)
|
||||
if env.current_index >= len(env.candle_window) - 1:
|
||||
break
|
||||
|
||||
# --- Live HTML Chart Update (with Volume) ---
|
||||
def update_live_html(candles, trade_history, epoch):
|
||||
"""
|
||||
Generate a chart image with actual timestamps on the x-axis and cumulative epoch PnL.
|
||||
The chart now also plots volume as a bar chart on a secondary y-axis.
|
||||
The HTML page auto-refreshes every 10 seconds.
|
||||
Generate an HTML page with a live chart.
|
||||
The chart displays price (line) and volume (bar chart on a secondary y-axis),
|
||||
and includes buy/sell markers with dotted lines connecting entries and exits.
|
||||
The page auto-refreshes every 10 seconds.
|
||||
"""
|
||||
from io import BytesIO
|
||||
import base64
|
||||
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
update_live_chart(ax, candles, trade_history)
|
||||
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
|
||||
@ -334,12 +415,11 @@ def update_live_html(candles, trade_history, epoch):
|
||||
f.write(html_content)
|
||||
print("Updated live_chart.html.")
|
||||
|
||||
# --- Chart Drawing Helpers ---
|
||||
# --- Chart Drawing Helpers (with Volume) ---
|
||||
def update_live_chart(ax, candles, trade_history):
|
||||
"""
|
||||
Plot the price chart with proper timestamp conversion.
|
||||
Mark BUY (green) and SELL (red) actions (with dotted lines between),
|
||||
and plot volume as a bar chart on a secondary y-axis.
|
||||
Plot the price chart with actual timestamps and volume on a secondary y-axis.
|
||||
Mark BUY (green) and SELL (red) points and connect them with dotted lines.
|
||||
"""
|
||||
ax.clear()
|
||||
times = [convert_timestamp(candle["timestamp"]) for candle in candles]
|
||||
@ -348,11 +428,9 @@ def update_live_chart(ax, candles, trade_history):
|
||||
ax.set_xlabel("Time")
|
||||
ax.set_ylabel("Price")
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
|
||||
|
||||
# Plot volume on secondary axis.
|
||||
# Plot volume on secondary y-axis.
|
||||
ax2 = ax.twinx()
|
||||
volumes = [candle["volume"] for candle in candles]
|
||||
# Compute bar width in days.
|
||||
if len(times) > 1:
|
||||
times_num = mdates.date2num(times)
|
||||
bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8
|
||||
@ -360,7 +438,15 @@ def update_live_chart(ax, candles, trade_history):
|
||||
bar_width = 0.01
|
||||
ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume")
|
||||
ax2.set_ylabel("Volume")
|
||||
|
||||
# Plot trade markers.
|
||||
for trade in trade_history:
|
||||
entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"])
|
||||
exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"])
|
||||
in_price = trade["entry_price"]
|
||||
out_price = trade["exit_price"]
|
||||
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY")
|
||||
ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL")
|
||||
ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")
|
||||
# Combine legends.
|
||||
lines, labels = ax.get_legend_handles_labels()
|
||||
lines2, labels2 = ax2.get_legend_handles_labels()
|
||||
@ -369,42 +455,11 @@ def update_live_chart(ax, candles, trade_history):
|
||||
fig = ax.get_figure()
|
||||
fig.autofmt_xdate()
|
||||
|
||||
# --- Simulation of Trades for Visualization ---
|
||||
def simulate_trades(model, env, device, args):
|
||||
"""
|
||||
Run a simulation on the current sliding window.
|
||||
If the model produces a sufficiently strong signal (based on threshold), use its action.
|
||||
Otherwise, manually compute the trade by scanning for max/min prices.
|
||||
"""
|
||||
env.reset()
|
||||
while True:
|
||||
i = env.current_index
|
||||
if i >= len(env.candle_window) - 1:
|
||||
break
|
||||
state = env.get_state(i)
|
||||
current_open = env.candle_window[i]["open"]
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
||||
timeframe_ids = torch.arange(state.shape[0]).to(device)
|
||||
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
||||
pred_high = pred_high.item()
|
||||
pred_low = pred_low.item()
|
||||
# If either upward potential or downward potential exceeds the threshold, use model decision.
|
||||
if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
|
||||
if (pred_high - current_open) >= (current_open - pred_low):
|
||||
action = 2 # BUY
|
||||
else:
|
||||
action = 0 # SELL
|
||||
_, _, _, done, _, _ = env.step(action)
|
||||
else:
|
||||
# No significant signal; use manual trade computation.
|
||||
manual_trade(env)
|
||||
if env.current_index >= len(env.candle_window) - 1:
|
||||
break
|
||||
|
||||
# --- Backtest Environment with Sliding Window and Order Info ---
|
||||
class BacktestEnvironment:
|
||||
def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
|
||||
self.candles_dict = candles_dict
|
||||
self.candles_dict = candles_dict # full candles dict across timeframes
|
||||
self.base_tf = base_tf
|
||||
self.timeframes = timeframes
|
||||
self.full_candles = candles_dict[base_tf]
|
||||
@ -459,7 +514,7 @@ class BacktestEnvironment:
|
||||
if action == 2: # BUY (open long)
|
||||
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
|
||||
else:
|
||||
if action == 0: # SELL (close long / exit trade)
|
||||
if action == 0: # SELL (close trade)
|
||||
exit_price = next_candle["open"]
|
||||
reward = exit_price - self.position["entry_price"]
|
||||
trade = {
|
||||
@ -513,7 +568,7 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
|
||||
simulate_trades(model, env, device, args)
|
||||
update_live_html(env.candle_window, env.trade_history, epoch+1)
|
||||
|
||||
# --- Live Plotting Functions (For Live Mode) ---
|
||||
# --- Live Plotting (for Live Mode) ---
|
||||
def live_preview_loop(candles, env):
|
||||
plt.ion()
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
@ -529,12 +584,14 @@ def parse_args():
|
||||
parser.add_argument('--epochs', type=int, default=1000)
|
||||
parser.add_argument('--lr', type=float, default=3e-4)
|
||||
parser.add_argument('--threshold', type=float, default=0.005,
|
||||
help="Minimum predicted move to trigger trade (used in loss; model may override with manual trade).")
|
||||
help="Minimum predicted move to trigger trade (used in loss; model may override manual trades).")
|
||||
parser.add_argument('--lambda_trade', type=float, default=1.0,
|
||||
help="Weight for trade surrogate loss.")
|
||||
help="Weight for the trade surrogate loss.")
|
||||
parser.add_argument('--penalty_noaction', type=float, default=10.0,
|
||||
help="Penalty if no action is taken (used in loss).")
|
||||
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
|
||||
parser.add_argument('--main_tf', type=str, default='1m',
|
||||
help="Desired main timeframe to focus on (e.g., '1s' or '1m').")
|
||||
return parser.parse_args()
|
||||
|
||||
def random_action():
|
||||
@ -545,17 +602,24 @@ async def main():
|
||||
args = parse_args()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print("Using device:", device)
|
||||
timeframes = ["1m", "5m", "15m", "1h", "1d"]
|
||||
# Load cached candles.
|
||||
candles_dict = load_candles_cache(CACHE_FILE)
|
||||
if not candles_dict:
|
||||
print("No historical candle data available for backtesting.")
|
||||
return
|
||||
# Define desired timeframes list; if available, include "1s".
|
||||
default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"]
|
||||
timeframes = [tf for tf in default_timeframes if tf in candles_dict]
|
||||
if args.main_tf not in timeframes:
|
||||
print(f"Desired main timeframe {args.main_tf} is not available. Available: {timeframes}")
|
||||
return
|
||||
base_tf = args.main_tf # Set the main timeframe as the base for the environment.
|
||||
hidden_dim = 128
|
||||
total_channels = NUM_TIMEFRAMES + 1 + NUM_INDICATORS
|
||||
model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
|
||||
# Total channels: number of timeframes + 1 order channel + NUM_INDICATORS.
|
||||
total_channels = len(timeframes) + ORDER_CHANNELS + NUM_INDICATORS
|
||||
model = TradingModel(total_channels, len(timeframes)).to(device)
|
||||
|
||||
if args.mode == 'train':
|
||||
candles_dict = load_candles_cache(CACHE_FILE)
|
||||
if not candles_dict:
|
||||
print("No historical candle data available for backtesting.")
|
||||
return
|
||||
base_tf = "1m"
|
||||
env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
|
||||
start_epoch = 0
|
||||
checkpoint = None
|
||||
@ -583,35 +647,34 @@ async def main():
|
||||
for f in os.listdir(chk_dir):
|
||||
os.remove(os.path.join(chk_dir, f))
|
||||
else:
|
||||
print("No valid optimizer state found in checkpoint; using fresh optimizer state.")
|
||||
print("No valid optimizer state found; using fresh optimizer state.")
|
||||
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
|
||||
|
||||
elif args.mode == 'live':
|
||||
load_best_checkpoint(model)
|
||||
candles_dict = load_candles_cache(CACHE_FILE)
|
||||
if not candles_dict:
|
||||
print("No cached candles available for live preview.")
|
||||
return
|
||||
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
|
||||
env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
|
||||
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
|
||||
preview_thread.start()
|
||||
print("Starting live trading loop. (Using model, with manual override for HOLD actions.)")
|
||||
print("Starting live trading loop. (For main_tf={} using manual override if model signal is weak.)".format(args.main_tf))
|
||||
while True:
|
||||
state = env.get_state(env.current_index)
|
||||
current_open = env.candle_window[env.current_index]["open"]
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
||||
timeframe_ids = torch.arange(state.shape[0]).to(device)
|
||||
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
||||
pred_high = pred_high.item()
|
||||
pred_low = pred_low.item()
|
||||
if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
|
||||
if (pred_high - current_open) >= (current_open - pred_low):
|
||||
action = 2 # BUY
|
||||
else:
|
||||
action = 0 # SELL
|
||||
_, _, _, done, _, _ = env.step(action)
|
||||
if args.main_tf == "1s":
|
||||
simulate_trades_1s(env)
|
||||
else:
|
||||
manual_trade(env)
|
||||
state = env.get_state(env.current_index)
|
||||
current_open = env.candle_window[env.current_index]["open"]
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
||||
timeframe_ids = torch.arange(state.shape[0]).to(device)
|
||||
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
||||
pred_high = pred_high.item()
|
||||
pred_low = pred_low.item()
|
||||
if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
|
||||
if (pred_high - current_open) >= (current_open - pred_low):
|
||||
action = 2
|
||||
else:
|
||||
action = 0
|
||||
_, _, _, done, _, _ = env.step(action)
|
||||
else:
|
||||
manual_trade(env)
|
||||
if env.current_index >= len(env.candle_window)-1:
|
||||
print("Reached end of simulation window; resetting environment.")
|
||||
env.reset()
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user