profitable bot 1

This commit is contained in:
Dobromir Popov 2025-02-05 09:16:16 +02:00
parent dc5df52292
commit fc2f834b32
2 changed files with 163 additions and 100 deletions

View File

@ -26,12 +26,12 @@ load_dotenv()
def convert_timestamp(ts):
"""
Safely converts a timestamp to a datetime object.
If the timestamp is abnormally high (i.e. in milliseconds),
If the timestamp is abnormally high (e.g. in milliseconds),
it is divided by 1000.
"""
ts = float(ts)
if ts > 1e10: # Likely in milliseconds
ts = ts / 1000.0
ts /= 1000.0
return datetime.fromtimestamp(ts)
# --- Directories ---
@ -42,10 +42,10 @@ os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"
# --- Constants ---
NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"]
NUM_TIMEFRAMES = 6 # e.g., ["1s", "1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20 # e.g., 20 technical indicators
FEATURES_PER_CHANNEL = 7 # Each channel input will have 7 features.
ORDER_CHANNELS = 1 # One extra channel for order info.
FEATURES_PER_CHANNEL = 7 # Each channel has 7 features.
ORDER_CHANNELS = 1 # One extra channel for order information.
# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
@ -66,6 +66,7 @@ class PositionalEncoding(nn.Module):
class TradingModel(nn.Module):
def __init__(self, num_channels, num_timeframes, hidden_dim=128):
super().__init__()
# One branch per channel.
self.channel_branches = nn.ModuleList([
nn.Sequential(
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@ -78,7 +79,7 @@ class TradingModel(nn.Module):
self.pos_encoder = PositionalEncoding(hidden_dim)
encoder_layers = TransformerEncoderLayer(
d_model=hidden_dim, nhead=4, dim_feedforward=512,
dropout=0.1, activation='gelu', batch_first=True # Use batch_first to avoid nested tensor warning.
dropout=0.1, activation='gelu', batch_first=True # avoid nested tensor warning
)
self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
self.attn_pool = nn.Linear(hidden_dim, 1)
@ -100,8 +101,8 @@ class TradingModel(nn.Module):
channel_out = self.channel_branches[i](x[:, i, :])
channel_outs.append(channel_out)
stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden]
tf_embeds = self.timeframe_embed(timeframe_ids) # shape: [num_channels, hidden]
stacked = stacked + tf_embeds.unsqueeze(0) # broadcast along batch dimension.
tf_embeds = self.timeframe_embed(timeframe_ids) # [num_channels, hidden]
stacked = stacked + tf_embeds.unsqueeze(0) # add embeddings (broadcast along batch)
transformer_out = self.transformer(stacked)
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)
aggregated = (transformer_out * attn_weights).sum(dim=1)
@ -135,7 +136,7 @@ def get_features_for_tf(candles_list, index, period=10):
f_close = candle["close"]
f_volume = candle["volume"]
sma_close = compute_sma(candles_list, index, period)
sma_volume = compute_sma_volume(candles_list, index, period)
sma_volume= compute_sma_volume(candles_list, index, period)
return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume]
# --- Caching & Checkpoint Functions ---
@ -230,10 +231,10 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
# --- Function for Manual Trade Override ---
def manual_trade(env):
"""
When no sufficient action is taken by the model, manually decide the trade.
Find the maximum high and minimum low in the remaining window.
If maximum occurs before minimum, we short; otherwise we long.
The trade is closed at the candle where the chosen extreme occurs.
When no sufficient action is taken by the model, use a fallback:
Scan the remaining window for the global maximum and minimum.
If the maximum occurs before the minimum, simulate a short trade;
otherwise simulate a long trade. Closes the trade at the candle where the chosen extreme occurs.
"""
current_index = env.current_index
if current_index >= len(env.candle_window) - 1:
@ -252,7 +253,6 @@ def manual_trade(env):
if low_j < min_val:
min_val = low_j
i_min = j
# If maximum occurs before minimum, we interpret that as short (price will drop).
if i_max < i_min:
entry_price = env.candle_window[current_index]["open"]
exit_price = env.candle_window[i_min]["open"]
@ -278,16 +278,97 @@ def manual_trade(env):
env.trade_history.append(trade)
env.current_index = trade["exit_index"]
# --- Live HTML Chart Update ---
# --- Simulation for 1s Data Using Local Extrema ---
def simulate_trades_1s(env):
"""
When the main timeframe is 1s, scan the entire remaining window to detect local extrema.
If at least two extrema are found, pair consecutive extrema as trades.
If none (or too few) are found, fallback to manual_trade.
"""
n = len(env.candle_window)
extrema = []
for i in range(env.current_index, n):
# Add first and last points.
if i == env.current_index or i == n-1:
extrema.append(i)
else:
prev = env.candle_window[i-1]["close"]
curr = env.candle_window[i]["close"]
nex = env.candle_window[i+1]["close"]
# A valley or a peak.
if curr < prev and curr < nex:
extrema.append(i)
elif curr > prev and curr > nex:
extrema.append(i)
if len(extrema) < 2:
manual_trade(env)
return
# Process consecutive extrema into trades.
for j in range(len(extrema)-1):
entry_idx = extrema[j]
exit_idx = extrema[j+1]
entry_price = env.candle_window[entry_idx]["open"]
exit_price = env.candle_window[exit_idx]["open"]
# If the entry candles close is lower than exit candles close, this is a long trade.
if env.candle_window[entry_idx]["close"] < env.candle_window[exit_idx]["close"]:
reward = exit_price - entry_price
else:
reward = entry_price - exit_price
trade = {
"entry_index": entry_idx,
"entry_price": entry_price,
"exit_index": exit_idx,
"exit_price": exit_price,
"pnl": reward
}
env.trade_history.append(trade)
env.current_index = n - 1
# --- General Simulation of Trades ---
def simulate_trades(model, env, device, args):
"""
Simulate trades over the current sliding window.
If the main timeframe is 1s, use local extrema detection.
Otherwise, check if the model's predicted potentials exceed the threshold.
- If so, execute the model decision.
- Otherwise, call the manual_trade override.
"""
if args.main_tf == "1s":
simulate_trades_1s(env)
return
env.reset()
while True:
i = env.current_index
if i >= len(env.candle_window) - 1:
break
state = env.get_state(i)
current_open = env.candle_window[i]["open"]
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device)
pred_high, pred_low = model(state_tensor, timeframe_ids)
pred_high = pred_high.item()
pred_low = pred_low.item()
if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
if (pred_high - current_open) >= (current_open - pred_low):
action = 2 # BUY
else:
action = 0 # SELL
_, _, _, done, _, _ = env.step(action)
else:
manual_trade(env)
if env.current_index >= len(env.candle_window) - 1:
break
# --- Live HTML Chart Update (with Volume) ---
def update_live_html(candles, trade_history, epoch):
"""
Generate a chart image with actual timestamps on the x-axis and cumulative epoch PnL.
The chart now also plots volume as a bar chart on a secondary y-axis.
The HTML page auto-refreshes every 10 seconds.
Generate an HTML page with a live chart.
The chart displays price (line) and volume (bar chart on a secondary y-axis),
and includes buy/sell markers with dotted lines connecting entries and exits.
The page auto-refreshes every 10 seconds.
"""
from io import BytesIO
import base64
fig, ax = plt.subplots(figsize=(12, 6))
update_live_chart(ax, candles, trade_history)
epoch_pnl = sum(trade["pnl"] for trade in trade_history)
@ -334,12 +415,11 @@ def update_live_html(candles, trade_history, epoch):
f.write(html_content)
print("Updated live_chart.html.")
# --- Chart Drawing Helpers ---
# --- Chart Drawing Helpers (with Volume) ---
def update_live_chart(ax, candles, trade_history):
"""
Plot the price chart with proper timestamp conversion.
Mark BUY (green) and SELL (red) actions (with dotted lines between),
and plot volume as a bar chart on a secondary y-axis.
Plot the price chart with actual timestamps and volume on a secondary y-axis.
Mark BUY (green) and SELL (red) points and connect them with dotted lines.
"""
ax.clear()
times = [convert_timestamp(candle["timestamp"]) for candle in candles]
@ -348,11 +428,9 @@ def update_live_chart(ax, candles, trade_history):
ax.set_xlabel("Time")
ax.set_ylabel("Price")
ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
# Plot volume on secondary axis.
# Plot volume on secondary y-axis.
ax2 = ax.twinx()
volumes = [candle["volume"] for candle in candles]
# Compute bar width in days.
if len(times) > 1:
times_num = mdates.date2num(times)
bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8
@ -360,7 +438,15 @@ def update_live_chart(ax, candles, trade_history):
bar_width = 0.01
ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume")
ax2.set_ylabel("Volume")
# Plot trade markers.
for trade in trade_history:
entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"])
exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"])
in_price = trade["entry_price"]
out_price = trade["exit_price"]
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY")
ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL")
ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")
# Combine legends.
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
@ -369,42 +455,11 @@ def update_live_chart(ax, candles, trade_history):
fig = ax.get_figure()
fig.autofmt_xdate()
# --- Simulation of Trades for Visualization ---
def simulate_trades(model, env, device, args):
"""
Run a simulation on the current sliding window.
If the model produces a sufficiently strong signal (based on threshold), use its action.
Otherwise, manually compute the trade by scanning for max/min prices.
"""
env.reset()
while True:
i = env.current_index
if i >= len(env.candle_window) - 1:
break
state = env.get_state(i)
current_open = env.candle_window[i]["open"]
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device)
pred_high, pred_low = model(state_tensor, timeframe_ids)
pred_high = pred_high.item()
pred_low = pred_low.item()
# If either upward potential or downward potential exceeds the threshold, use model decision.
if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
if (pred_high - current_open) >= (current_open - pred_low):
action = 2 # BUY
else:
action = 0 # SELL
_, _, _, done, _, _ = env.step(action)
else:
# No significant signal; use manual trade computation.
manual_trade(env)
if env.current_index >= len(env.candle_window) - 1:
break
# --- Backtest Environment with Sliding Window and Order Info ---
class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
self.candles_dict = candles_dict
self.candles_dict = candles_dict # full candles dict across timeframes
self.base_tf = base_tf
self.timeframes = timeframes
self.full_candles = candles_dict[base_tf]
@ -459,7 +514,7 @@ class BacktestEnvironment:
if action == 2: # BUY (open long)
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
else:
if action == 0: # SELL (close long / exit trade)
if action == 0: # SELL (close trade)
exit_price = next_candle["open"]
reward = exit_price - self.position["entry_price"]
trade = {
@ -513,7 +568,7 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
simulate_trades(model, env, device, args)
update_live_html(env.candle_window, env.trade_history, epoch+1)
# --- Live Plotting Functions (For Live Mode) ---
# --- Live Plotting (for Live Mode) ---
def live_preview_loop(candles, env):
plt.ion()
fig, ax = plt.subplots(figsize=(12, 6))
@ -529,12 +584,14 @@ def parse_args():
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--threshold', type=float, default=0.005,
help="Minimum predicted move to trigger trade (used in loss; model may override with manual trade).")
help="Minimum predicted move to trigger trade (used in loss; model may override manual trades).")
parser.add_argument('--lambda_trade', type=float, default=1.0,
help="Weight for trade surrogate loss.")
help="Weight for the trade surrogate loss.")
parser.add_argument('--penalty_noaction', type=float, default=10.0,
help="Penalty if no action is taken (used in loss).")
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
parser.add_argument('--main_tf', type=str, default='1m',
help="Desired main timeframe to focus on (e.g., '1s' or '1m').")
return parser.parse_args()
def random_action():
@ -545,17 +602,24 @@ async def main():
args = parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
timeframes = ["1m", "5m", "15m", "1h", "1d"]
hidden_dim = 128
total_channels = NUM_TIMEFRAMES + 1 + NUM_INDICATORS
model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
if args.mode == 'train':
# Load cached candles.
candles_dict = load_candles_cache(CACHE_FILE)
if not candles_dict:
print("No historical candle data available for backtesting.")
return
base_tf = "1m"
# Define desired timeframes list; if available, include "1s".
default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"]
timeframes = [tf for tf in default_timeframes if tf in candles_dict]
if args.main_tf not in timeframes:
print(f"Desired main timeframe {args.main_tf} is not available. Available: {timeframes}")
return
base_tf = args.main_tf # Set the main timeframe as the base for the environment.
hidden_dim = 128
# Total channels: number of timeframes + 1 order channel + NUM_INDICATORS.
total_channels = len(timeframes) + ORDER_CHANNELS + NUM_INDICATORS
model = TradingModel(total_channels, len(timeframes)).to(device)
if args.mode == 'train':
env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
start_epoch = 0
checkpoint = None
@ -583,20 +647,19 @@ async def main():
for f in os.listdir(chk_dir):
os.remove(os.path.join(chk_dir, f))
else:
print("No valid optimizer state found in checkpoint; using fresh optimizer state.")
print("No valid optimizer state found; using fresh optimizer state.")
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
elif args.mode == 'live':
load_best_checkpoint(model)
candles_dict = load_candles_cache(CACHE_FILE)
if not candles_dict:
print("No cached candles available for live preview.")
return
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100)
env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
preview_thread.start()
print("Starting live trading loop. (Using model, with manual override for HOLD actions.)")
print("Starting live trading loop. (For main_tf={} using manual override if model signal is weak.)".format(args.main_tf))
while True:
if args.main_tf == "1s":
simulate_trades_1s(env)
else:
state = env.get_state(env.current_index)
current_open = env.candle_window[env.current_index]["open"]
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
@ -606,9 +669,9 @@ async def main():
pred_low = pred_low.item()
if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
if (pred_high - current_open) >= (current_open - pred_low):
action = 2 # BUY
action = 2
else:
action = 0 # SELL
action = 0
_, _, _, done, _, _ = env.step(action)
else:
manual_trade(env)

File diff suppressed because one or more lines are too long