diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py index f910b96..7c04c5d 100644 --- a/crypto/brian/index-deep-new.py +++ b/crypto/brian/index-deep-new.py @@ -26,12 +26,12 @@ load_dotenv() def convert_timestamp(ts): """ Safely converts a timestamp to a datetime object. - If the timestamp is abnormally high (i.e. in milliseconds), + If the timestamp is abnormally high (e.g. in milliseconds), it is divided by 1000. """ ts = float(ts) if ts > 1e10: # Likely in milliseconds - ts = ts / 1000.0 + ts /= 1000.0 return datetime.fromtimestamp(ts) # --- Directories --- @@ -42,10 +42,10 @@ os.makedirs(BEST_DIR, exist_ok=True) CACHE_FILE = "candles_cache.json" # --- Constants --- -NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"] +NUM_TIMEFRAMES = 6 # e.g., ["1s", "1m", "5m", "15m", "1h", "1d"] NUM_INDICATORS = 20 # e.g., 20 technical indicators -FEATURES_PER_CHANNEL = 7 # Each channel input will have 7 features. -ORDER_CHANNELS = 1 # One extra channel for order info. +FEATURES_PER_CHANNEL = 7 # Each channel has 7 features. +ORDER_CHANNELS = 1 # One extra channel for order information. # --- Positional Encoding Module --- class PositionalEncoding(nn.Module): @@ -66,6 +66,7 @@ class PositionalEncoding(nn.Module): class TradingModel(nn.Module): def __init__(self, num_channels, num_timeframes, hidden_dim=128): super().__init__() + # One branch per channel. self.channel_branches = nn.ModuleList([ nn.Sequential( nn.Linear(FEATURES_PER_CHANNEL, hidden_dim), @@ -78,7 +79,7 @@ class TradingModel(nn.Module): self.pos_encoder = PositionalEncoding(hidden_dim) encoder_layers = TransformerEncoderLayer( d_model=hidden_dim, nhead=4, dim_feedforward=512, - dropout=0.1, activation='gelu', batch_first=True # Use batch_first to avoid nested tensor warning. + dropout=0.1, activation='gelu', batch_first=True # avoid nested tensor warning ) self.transformer = TransformerEncoder(encoder_layers, num_layers=2) self.attn_pool = nn.Linear(hidden_dim, 1) @@ -100,8 +101,8 @@ class TradingModel(nn.Module): channel_out = self.channel_branches[i](x[:, i, :]) channel_outs.append(channel_out) stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden] - tf_embeds = self.timeframe_embed(timeframe_ids) # shape: [num_channels, hidden] - stacked = stacked + tf_embeds.unsqueeze(0) # broadcast along batch dimension. + tf_embeds = self.timeframe_embed(timeframe_ids) # [num_channels, hidden] + stacked = stacked + tf_embeds.unsqueeze(0) # add embeddings (broadcast along batch) transformer_out = self.transformer(stacked) attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1) aggregated = (transformer_out * attn_weights).sum(dim=1) @@ -129,13 +130,13 @@ def get_aligned_candle_with_index(candles_list, target_ts): def get_features_for_tf(candles_list, index, period=10): candle = candles_list[index] - f_open = candle["open"] - f_high = candle["high"] - f_low = candle["low"] - f_close = candle["close"] - f_volume = candle["volume"] + f_open = candle["open"] + f_high = candle["high"] + f_low = candle["low"] + f_close = candle["close"] + f_volume = candle["volume"] sma_close = compute_sma(candles_list, index, period) - sma_volume = compute_sma_volume(candles_list, index, period) + sma_volume= compute_sma_volume(candles_list, index, period) return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume] # --- Caching & Checkpoint Functions --- @@ -230,10 +231,10 @@ def load_best_checkpoint(model, best_dir=BEST_DIR): # --- Function for Manual Trade Override --- def manual_trade(env): """ - When no sufficient action is taken by the model, manually decide the trade. - Find the maximum high and minimum low in the remaining window. - If maximum occurs before minimum, we short; otherwise we long. - The trade is closed at the candle where the chosen extreme occurs. + When no sufficient action is taken by the model, use a fallback: + Scan the remaining window for the global maximum and minimum. + If the maximum occurs before the minimum, simulate a short trade; + otherwise simulate a long trade. Closes the trade at the candle where the chosen extreme occurs. """ current_index = env.current_index if current_index >= len(env.candle_window) - 1: @@ -252,7 +253,6 @@ def manual_trade(env): if low_j < min_val: min_val = low_j i_min = j - # If maximum occurs before minimum, we interpret that as short (price will drop). if i_max < i_min: entry_price = env.candle_window[current_index]["open"] exit_price = env.candle_window[i_min]["open"] @@ -278,16 +278,97 @@ def manual_trade(env): env.trade_history.append(trade) env.current_index = trade["exit_index"] -# --- Live HTML Chart Update --- +# --- Simulation for 1s Data Using Local Extrema --- +def simulate_trades_1s(env): + """ + When the main timeframe is 1s, scan the entire remaining window to detect local extrema. + If at least two extrema are found, pair consecutive extrema as trades. + If none (or too few) are found, fallback to manual_trade. + """ + n = len(env.candle_window) + extrema = [] + for i in range(env.current_index, n): + # Add first and last points. + if i == env.current_index or i == n-1: + extrema.append(i) + else: + prev = env.candle_window[i-1]["close"] + curr = env.candle_window[i]["close"] + nex = env.candle_window[i+1]["close"] + # A valley or a peak. + if curr < prev and curr < nex: + extrema.append(i) + elif curr > prev and curr > nex: + extrema.append(i) + if len(extrema) < 2: + manual_trade(env) + return + # Process consecutive extrema into trades. + for j in range(len(extrema)-1): + entry_idx = extrema[j] + exit_idx = extrema[j+1] + entry_price = env.candle_window[entry_idx]["open"] + exit_price = env.candle_window[exit_idx]["open"] + # If the entry candle’s close is lower than exit candle’s close, this is a long trade. + if env.candle_window[entry_idx]["close"] < env.candle_window[exit_idx]["close"]: + reward = exit_price - entry_price + else: + reward = entry_price - exit_price + trade = { + "entry_index": entry_idx, + "entry_price": entry_price, + "exit_index": exit_idx, + "exit_price": exit_price, + "pnl": reward + } + env.trade_history.append(trade) + env.current_index = n - 1 + +# --- General Simulation of Trades --- +def simulate_trades(model, env, device, args): + """ + Simulate trades over the current sliding window. + If the main timeframe is 1s, use local extrema detection. + Otherwise, check if the model's predicted potentials exceed the threshold. + - If so, execute the model decision. + - Otherwise, call the manual_trade override. + """ + if args.main_tf == "1s": + simulate_trades_1s(env) + return + env.reset() + while True: + i = env.current_index + if i >= len(env.candle_window) - 1: + break + state = env.get_state(i) + current_open = env.candle_window[i]["open"] + state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) + timeframe_ids = torch.arange(state.shape[0]).to(device) + pred_high, pred_low = model(state_tensor, timeframe_ids) + pred_high = pred_high.item() + pred_low = pred_low.item() + if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold: + if (pred_high - current_open) >= (current_open - pred_low): + action = 2 # BUY + else: + action = 0 # SELL + _, _, _, done, _, _ = env.step(action) + else: + manual_trade(env) + if env.current_index >= len(env.candle_window) - 1: + break + +# --- Live HTML Chart Update (with Volume) --- def update_live_html(candles, trade_history, epoch): """ - Generate a chart image with actual timestamps on the x-axis and cumulative epoch PnL. - The chart now also plots volume as a bar chart on a secondary y-axis. - The HTML page auto-refreshes every 10 seconds. + Generate an HTML page with a live chart. + The chart displays price (line) and volume (bar chart on a secondary y-axis), + and includes buy/sell markers with dotted lines connecting entries and exits. + The page auto-refreshes every 10 seconds. """ from io import BytesIO import base64 - fig, ax = plt.subplots(figsize=(12, 6)) update_live_chart(ax, candles, trade_history) epoch_pnl = sum(trade["pnl"] for trade in trade_history) @@ -334,12 +415,11 @@ def update_live_html(candles, trade_history, epoch): f.write(html_content) print("Updated live_chart.html.") -# --- Chart Drawing Helpers --- +# --- Chart Drawing Helpers (with Volume) --- def update_live_chart(ax, candles, trade_history): """ - Plot the price chart with proper timestamp conversion. - Mark BUY (green) and SELL (red) actions (with dotted lines between), - and plot volume as a bar chart on a secondary y-axis. + Plot the price chart with actual timestamps and volume on a secondary y-axis. + Mark BUY (green) and SELL (red) points and connect them with dotted lines. """ ax.clear() times = [convert_timestamp(candle["timestamp"]) for candle in candles] @@ -348,11 +428,9 @@ def update_live_chart(ax, candles, trade_history): ax.set_xlabel("Time") ax.set_ylabel("Price") ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) - - # Plot volume on secondary axis. + # Plot volume on secondary y-axis. ax2 = ax.twinx() volumes = [candle["volume"] for candle in candles] - # Compute bar width in days. if len(times) > 1: times_num = mdates.date2num(times) bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8 @@ -360,7 +438,15 @@ def update_live_chart(ax, candles, trade_history): bar_width = 0.01 ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume") ax2.set_ylabel("Volume") - + # Plot trade markers. + for trade in trade_history: + entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"]) + exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"]) + in_price = trade["entry_price"] + out_price = trade["exit_price"] + ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") + ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") + ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") # Combine legends. lines, labels = ax.get_legend_handles_labels() lines2, labels2 = ax2.get_legend_handles_labels() @@ -369,42 +455,11 @@ def update_live_chart(ax, candles, trade_history): fig = ax.get_figure() fig.autofmt_xdate() -# --- Simulation of Trades for Visualization --- -def simulate_trades(model, env, device, args): - """ - Run a simulation on the current sliding window. - If the model produces a sufficiently strong signal (based on threshold), use its action. - Otherwise, manually compute the trade by scanning for max/min prices. - """ - env.reset() - while True: - i = env.current_index - if i >= len(env.candle_window) - 1: - break - state = env.get_state(i) - current_open = env.candle_window[i]["open"] - state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) - timeframe_ids = torch.arange(state.shape[0]).to(device) - pred_high, pred_low = model(state_tensor, timeframe_ids) - pred_high = pred_high.item() - pred_low = pred_low.item() - # If either upward potential or downward potential exceeds the threshold, use model decision. - if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold: - if (pred_high - current_open) >= (current_open - pred_low): - action = 2 # BUY - else: - action = 0 # SELL - _, _, _, done, _, _ = env.step(action) - else: - # No significant signal; use manual trade computation. - manual_trade(env) - if env.current_index >= len(env.candle_window) - 1: - break # --- Backtest Environment with Sliding Window and Order Info --- class BacktestEnvironment: def __init__(self, candles_dict, base_tf, timeframes, window_size=None): - self.candles_dict = candles_dict + self.candles_dict = candles_dict # full candles dict across timeframes self.base_tf = base_tf self.timeframes = timeframes self.full_candles = candles_dict[base_tf] @@ -459,7 +514,7 @@ class BacktestEnvironment: if action == 2: # BUY (open long) self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} else: - if action == 0: # SELL (close long / exit trade) + if action == 0: # SELL (close trade) exit_price = next_candle["open"] reward = exit_price - self.position["entry_price"] trade = { @@ -513,7 +568,7 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s simulate_trades(model, env, device, args) update_live_html(env.candle_window, env.trade_history, epoch+1) -# --- Live Plotting Functions (For Live Mode) --- +# --- Live Plotting (for Live Mode) --- def live_preview_loop(candles, env): plt.ion() fig, ax = plt.subplots(figsize=(12, 6)) @@ -528,13 +583,15 @@ def parse_args(): parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train') parser.add_argument('--epochs', type=int, default=1000) parser.add_argument('--lr', type=float, default=3e-4) - parser.add_argument('--threshold', type=float, default=0.005, - help="Minimum predicted move to trigger trade (used in loss; model may override with manual trade).") + parser.add_argument('--threshold', type=float, default=0.005, + help="Minimum predicted move to trigger trade (used in loss; model may override manual trades).") parser.add_argument('--lambda_trade', type=float, default=1.0, - help="Weight for trade surrogate loss.") + help="Weight for the trade surrogate loss.") parser.add_argument('--penalty_noaction', type=float, default=10.0, help="Penalty if no action is taken (used in loss).") parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") + parser.add_argument('--main_tf', type=str, default='1m', + help="Desired main timeframe to focus on (e.g., '1s' or '1m').") return parser.parse_args() def random_action(): @@ -545,17 +602,24 @@ async def main(): args = parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Using device:", device) - timeframes = ["1m", "5m", "15m", "1h", "1d"] + # Load cached candles. + candles_dict = load_candles_cache(CACHE_FILE) + if not candles_dict: + print("No historical candle data available for backtesting.") + return + # Define desired timeframes list; if available, include "1s". + default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"] + timeframes = [tf for tf in default_timeframes if tf in candles_dict] + if args.main_tf not in timeframes: + print(f"Desired main timeframe {args.main_tf} is not available. Available: {timeframes}") + return + base_tf = args.main_tf # Set the main timeframe as the base for the environment. hidden_dim = 128 - total_channels = NUM_TIMEFRAMES + 1 + NUM_INDICATORS - model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device) - + # Total channels: number of timeframes + 1 order channel + NUM_INDICATORS. + total_channels = len(timeframes) + ORDER_CHANNELS + NUM_INDICATORS + model = TradingModel(total_channels, len(timeframes)).to(device) + if args.mode == 'train': - candles_dict = load_candles_cache(CACHE_FILE) - if not candles_dict: - print("No historical candle data available for backtesting.") - return - base_tf = "1m" env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100) start_epoch = 0 checkpoint = None @@ -583,35 +647,34 @@ async def main(): for f in os.listdir(chk_dir): os.remove(os.path.join(chk_dir, f)) else: - print("No valid optimizer state found in checkpoint; using fresh optimizer state.") + print("No valid optimizer state found; using fresh optimizer state.") train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler) elif args.mode == 'live': load_best_checkpoint(model) - candles_dict = load_candles_cache(CACHE_FILE) - if not candles_dict: - print("No cached candles available for live preview.") - return - env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100) + env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100) preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True) preview_thread.start() - print("Starting live trading loop. (Using model, with manual override for HOLD actions.)") + print("Starting live trading loop. (For main_tf={} using manual override if model signal is weak.)".format(args.main_tf)) while True: - state = env.get_state(env.current_index) - current_open = env.candle_window[env.current_index]["open"] - state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) - timeframe_ids = torch.arange(state.shape[0]).to(device) - pred_high, pred_low = model(state_tensor, timeframe_ids) - pred_high = pred_high.item() - pred_low = pred_low.item() - if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold: - if (pred_high - current_open) >= (current_open - pred_low): - action = 2 # BUY - else: - action = 0 # SELL - _, _, _, done, _, _ = env.step(action) + if args.main_tf == "1s": + simulate_trades_1s(env) else: - manual_trade(env) + state = env.get_state(env.current_index) + current_open = env.candle_window[env.current_index]["open"] + state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) + timeframe_ids = torch.arange(state.shape[0]).to(device) + pred_high, pred_low = model(state_tensor, timeframe_ids) + pred_high = pred_high.item() + pred_low = pred_low.item() + if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold: + if (pred_high - current_open) >= (current_open - pred_low): + action = 2 + else: + action = 0 + _, _, _, done, _, _ = env.step(action) + else: + manual_trade(env) if env.current_index >= len(env.candle_window)-1: print("Reached end of simulation window; resetting environment.") env.reset() diff --git a/crypto/brian/live_chart.html b/crypto/brian/live_chart.html index f9789a2..d9e07f6 100644 --- a/crypto/brian/live_chart.html +++ b/crypto/brian/live_chart.html @@ -26,7 +26,7 @@

Live Trading Chart - Epoch 100 | PnL: 0.00

- Live Chart + Live Chart