diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py index ab50293..e6f8b70 100644 --- a/crypto/brian/index-deep-new.py +++ b/crypto/brian/index-deep-new.py @@ -16,9 +16,8 @@ import torch.nn as nn import torch.optim as optim from datetime import datetime import matplotlib.pyplot as plt -import ccxt.async_support as ccxt -from torch.nn import TransformerEncoder, TransformerEncoderLayer import math +from torch.nn import TransformerEncoder, TransformerEncoderLayer from dotenv import load_dotenv load_dotenv() @@ -30,9 +29,9 @@ os.makedirs(BEST_DIR, exist_ok=True) CACHE_FILE = "candles_cache.json" # --- Constants --- -NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"] -NUM_INDICATORS = 20 # e.g., 20 technical indicators -FEATURES_PER_CHANNEL = 7 # e.g., H, L, O, C, Volume, SMA_close, SMA_volume +NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"] +NUM_INDICATORS = 20 # e.g., 20 technical indicators +FEATURES_PER_CHANNEL = 7 # e.g., [open, high, low, close, volume, sma_close, sma_volume] # --- Positional Encoding Module --- class PositionalEncoding(nn.Module): @@ -53,7 +52,7 @@ class PositionalEncoding(nn.Module): class TradingModel(nn.Module): def __init__(self, num_channels, num_timeframes, hidden_dim=128): super().__init__() - # Create one branch per channel (each channel input has FEATURES_PER_CHANNEL features) + # One branch per channel self.channel_branches = nn.ModuleList([ nn.Sequential( nn.Linear(FEATURES_PER_CHANNEL, hidden_dim), @@ -62,7 +61,6 @@ class TradingModel(nn.Module): nn.Dropout(0.1) ) for _ in range(num_channels) ]) - # Embedding for channels 0..num_channels-1. self.timeframe_embed = nn.Embedding(num_channels, hidden_dim) self.pos_encoder = PositionalEncoding(hidden_dim) encoder_layers = TransformerEncoderLayer( @@ -82,15 +80,14 @@ class TradingModel(nn.Module): nn.Linear(hidden_dim // 2, 1) ) def forward(self, x, timeframe_ids): - # x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL] + # x: [batch_size, num_channels, FEATURES_PER_CHANNEL] batch_size, num_channels, _ = x.shape channel_outs = [] for i in range(num_channels): channel_out = self.channel_branches[i](x[:, i, :]) channel_outs.append(channel_out) - stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden] - stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden] - # Add embedding for each channel. + stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden] + stacked = stacked.permute(1, 0, 2) # [channels, batch, hidden] tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1) stacked = stacked + tf_embeds src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device) @@ -103,12 +100,12 @@ class TradingModel(nn.Module): def compute_sma(candles_list, index, period=10): start = max(0, index - period + 1) values = [candle["close"] for candle in candles_list[start:index+1]] - return sum(values) / len(values) if values else 0.0 + return sum(values)/len(values) if values else 0.0 def compute_sma_volume(candles_list, index, period=10): start = max(0, index - period + 1) values = [candle["volume"] for candle in candles_list[start:index+1]] - return sum(values) / len(values) if values else 0.0 + return sum(values)/len(values) if values else 0.0 def get_aligned_candle_with_index(candles_list, target_ts): best_idx = 0 @@ -123,7 +120,7 @@ def get_features_for_tf(candles_list, index, period=10): candle = candles_list[index] f_open = candle["open"] f_high = candle["high"] - f_low = candle["low"] + f_low = candle["low"] f_close = candle["close"] f_volume = candle["volume"] sma_close = compute_sma(candles_list, index, period) @@ -154,7 +151,7 @@ def maintain_checkpoint_directory(directory, max_files=10): if len(files) > max_files: full_paths = [os.path.join(directory, f) for f in files] full_paths.sort(key=lambda x: os.path.getmtime(x)) - for f in full_paths[: len(files) - max_files]: + for f in full_paths[:len(files)-max_files]: os.remove(f) def get_best_models(directory): @@ -162,7 +159,6 @@ def get_best_models(directory): for file in os.listdir(directory): parts = file.split("_") try: - # parts[1] is the recorded loss loss = float(parts[1]) best_files.append((loss, file)) except Exception: @@ -174,10 +170,10 @@ def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=B last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt" last_path = os.path.join(last_dir, last_filename) torch.save({ - "epoch": epoch, - "loss": loss, - "model_state_dict": model.state_dict(), - "optimizer_state_dict": optimizer.state_dict() + "epoch": epoch, + "loss": loss, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict() }, last_path) maintain_checkpoint_directory(last_dir, max_files=10) best_models = get_best_models(best_dir) @@ -215,8 +211,8 @@ def load_best_checkpoint(model, best_dir=BEST_DIR): # --- Live HTML Chart Update --- def update_live_html(candles, trade_history, epoch): """ - Generate a chart image with buy/sell markers and a dotted line between open/close positions, - then embed it in a simple HTML page that auto-refreshes every 10 seconds. + Generate a chart image with buy/sell markers and dotted lines between entry and exit, + then embed it in an auto-refreshing HTML page. """ from io import BytesIO import base64 @@ -266,10 +262,10 @@ def update_live_html(candles, trade_history, epoch): f.write(html_content) print("Updated live_chart.html.") -# --- Chart Drawing Helpers (used by both live preview and HTML update) --- +# --- Chart Drawing Helpers --- def update_live_chart(ax, candles, trade_history): """ - Plot the chart with close price, buy/sell markers, and dotted lines joining entry/exit. + Draw the price chart with close prices and mark BUY (green) and SELL (red) actions. """ ax.clear() close_prices = [candle["close"] for candle in candles] @@ -298,39 +294,44 @@ def update_live_chart(ax, candles, trade_history): ax.legend() ax.grid(True) -# --- Forced Action & Optimal Hint Helpers --- -def get_forced_action(env): +# --- Simulation of Trades for Visualization --- +def simulate_trades(model, env, device, args): """ - When simulating streaming data, we force a trade at strategic moments: - - At the very first step: force BUY. - - At the penultimate step: if a position is open, force SELL. - - Otherwise, default to HOLD. - (The environment will also apply a penalty if the chosen action does not match the optimal hint.) + Run a complete simulation on the current sliding window using a decision rule based on model outputs. + This simulation (which updates env.trade_history) is used only for visualization. """ - total = len(env) - if env.current_index == 0: - return 2 # BUY - elif env.current_index >= total - 2: - if env.position is not None: - return 0 # SELL + env.reset() # resets the sliding window and index + while True: + i = env.current_index + state = env.get_state(i) + current_open = env.candle_window[i]["open"] + state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) + timeframe_ids = torch.arange(state.shape[0]).to(device) + pred_high, pred_low = model(state_tensor, timeframe_ids) + pred_high = pred_high.item() + pred_low = pred_low.item() + # Decision rule: if upward move larger than downward and above threshold, BUY; if downward is larger, SELL; else HOLD. + if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold: + action = 2 # BUY + elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold: + action = 0 # SELL else: - return 1 # HOLD - else: - return 1 # HOLD + action = 1 # HOLD + _, _, _, done, _, _ = env.step(action) + if done: + break -# --- Backtest Environment with Sliding Window and Hints --- +# --- Backtest Environment with Sliding Window --- class BacktestEnvironment: def __init__(self, candles_dict, base_tf, timeframes, window_size=None): - self.candles_dict = candles_dict # full dictionary of timeframe candles + self.candles_dict = candles_dict # full candles dict for all timeframes self.base_tf = base_tf self.timeframes = timeframes - # Use maximum allowed candles for the base timeframe. self.full_candles = candles_dict[base_tf] - # Determine sliding window size: if window_size is None: window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles) self.window_size = window_size - self.hint_penalty = 0.001 # Penalty coefficient (multiplied by open price) + self.hint_penalty = 0.001 # not used in the revised loss below self.reset() def reset(self): @@ -346,52 +347,26 @@ class BacktestEnvironment: return self.window_size def get_state(self, index): - """ - Build state features by taking the candle at the current index for the base timeframe - (from the sliding window) and aligning candles for other timeframes. - Then append zeros for technical indicators. - """ state_features = [] base_ts = self.candle_window[index]["timestamp"] for tf in self.timeframes: if tf == self.base_tf: - # For base timeframe, use the sliding window candle. candle = self.candle_window[index] - features = get_features_for_tf([candle], 0) # List of one element + features = get_features_for_tf([candle], 0) else: aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts) features = get_features_for_tf(self.candles_dict[tf], aligned_idx) state_features.append(features) for _ in range(NUM_INDICATORS): - state_features.append([0.0] * FEATURES_PER_CHANNEL) + state_features.append([0.0]*FEATURES_PER_CHANNEL) return np.array(state_features, dtype=np.float32) - def compute_optimal_hint(self, horizon=10, threshold=0.005): - """ - Using a lookahead window from the sliding window (future candles) - determine an optimal action hint: - 2: BUY if price is expected to rise at least by threshold. - 0: SELL if expected to drop by threshold. - 1: HOLD otherwise. - """ - base = self.candle_window - if self.current_index >= len(base) - 1: - return 1 # Hold - current_candle = base[self.current_index] - open_price = current_candle["open"] - future_slice = base[self.current_index+1: min(self.current_index+1+horizon, len(base))] - if not future_slice: - return 1 - max_future = max(candle["high"] for candle in future_slice) - min_future = min(candle["low"] for candle in future_slice) - if (max_future - open_price) / open_price >= threshold: - return 2 # BUY - elif (open_price - min_future) / open_price >= threshold: - return 0 # SELL - else: - return 1 # HOLD - def step(self, action): + """ + Discrete simulation step. + - Action: 0 (SELL), 1 (HOLD), 2 (BUY). + - Trades are recorded when a BUY is followed by a SELL. + """ base = self.candle_window if self.current_index >= len(base) - 1: current_state = self.get_state(self.current_index) @@ -403,13 +378,12 @@ class BacktestEnvironment: next_candle = base[next_index] reward = 0.0 - # Trade logic (0: SELL, 1: HOLD, 2: BUY) + # Simple trading logic (only one position allowed at a time) if self.position is None: - if action == 2: # BUY: enter at next candle's open. - entry_price = next_candle["open"] - self.position = {"entry_price": entry_price, "entry_index": self.current_index} + if action == 2: # BUY signal: enter at next open. + self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} else: - if action == 0: # SELL: exit at next candle's open. + if action == 0: # SELL signal: exit at next open. exit_price = next_candle["open"] reward = exit_price - self.position["entry_price"] trade = { @@ -426,49 +400,49 @@ class BacktestEnvironment: done = (self.current_index >= len(base) - 1) actual_high = next_candle["high"] actual_low = next_candle["low"] - - # Compute optimal action hint and apply a penalty if action deviates. - optimal_hint = self.compute_optimal_hint(horizon=10, threshold=0.005) - if action != optimal_hint: - reward -= self.hint_penalty * next_candle["open"] - return current_state, reward, next_state, done, actual_high, actual_low # --- Enhanced Training Loop --- def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler): + # Weighting factor for trade surrogate loss. + lambda_trade = 1.0 for epoch in range(start_epoch, args.epochs): - state = env.reset() - total_loss = 0.0 - model.train() - while True: - # Use forced-action policy for trading (guaranteeing at least one trade per episode) - action = get_forced_action(env) + # Reset sliding window for each epoch. + env.reset() + loss_accum = 0.0 + steps = len(env) - 1 # we use pairs of consecutive candles + for i in range(steps): + state = env.get_state(i) + current_open = env.candle_window[i]["open"] + # Next candle's actual values serve as targets. + actual_high = env.candle_window[i+1]["high"] + actual_low = env.candle_window[i+1]["low"] state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) timeframe_ids = torch.arange(state.shape[0]).to(device) pred_high, pred_low = model(state_tensor, timeframe_ids) - # Use our forced action in the environment step. - _, reward, next_state, done, actual_high, actual_low = env.step(action) - target_high = torch.FloatTensor([actual_high]).to(device) - target_low = torch.FloatTensor([actual_low]).to(device) - high_loss = torch.abs(pred_high - target_high) * 2 - low_loss = torch.abs(pred_low - target_low) * 2 - loss = (high_loss + low_loss).mean() + # Compute prediction loss (L1) + L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \ + torch.abs(pred_low - torch.tensor(actual_low, device=device)) + # Compute surrogate profit (differentiable estimate) + profit_buy = pred_high - current_open # potential long gain + profit_sell = current_open - pred_low # potential short gain + # Here we reward a higher potential move by subtracting it. + L_trade = - torch.max(profit_buy, profit_sell) + loss = L_pred + lambda_trade * L_trade optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() - total_loss += loss.item() - if done: - break - state = next_state + loss_accum += loss.item() scheduler.step() - epoch_loss = total_loss / len(env) + epoch_loss = loss_accum / steps print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}") - save_checkpoint(model, optimizer, epoch, total_loss) - # Update live HTML chart to display the current sliding window + save_checkpoint(model, optimizer, epoch, loss_accum) + # Update the trade simulation (for visualization) using the current model on the same window. + simulate_trades(model, env, device, args) update_live_html(env.candle_window, env.trade_history, epoch+1) -# --- Live Plotting Functions (For live mode) --- +# --- Live Plotting Functions (For Live Mode) --- def live_preview_loop(candles, env): plt.ion() fig, ax = plt.subplots(figsize=(12, 6)) @@ -480,21 +454,17 @@ def live_preview_loop(candles, env): # --- Argument Parsing --- def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--mode', choices=['train','live','inference'], default='train') + parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train') parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--lr', type=float, default=3e-4) - parser.add_argument('--threshold', type=float, default=0.005) - # If set, training starts from scratch (ignoring saved checkpoints) - parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch.') + parser.add_argument('--threshold', type=float, default=0.005, help="Minimum predicted move to trigger trade.") + parser.add_argument('--lambda_trade', type=float, default=1.0, help="Weight for the trade surrogate loss.") + parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") return parser.parse_args() -def random_action(): - return random.randint(0, 2) - # --- Main Function --- async def main(): args = parse_args() - # Use GPU if available; else CPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Using device:", device) timeframes = ["1m", "5m", "15m", "1h", "1d"] @@ -508,9 +478,8 @@ async def main(): print("No historical candle data available for backtesting.") return base_tf = "1m" - # Create the environment with a sliding window (simulate streaming data) + # Use a sliding window of up to 100 candles (if available) env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100) - start_epoch = 0 checkpoint = None if not args.start_fresh: @@ -522,7 +491,6 @@ async def main(): print("No checkpoint found. Starting training from scratch.") else: print("Starting training from scratch as requested.") - optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch) if checkpoint is not None: @@ -543,18 +511,31 @@ async def main(): env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes, window_size=100) preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True) preview_thread.start() - print("Starting live trading loop. (Using forced-action policy for simulation.)") + print("Starting live trading loop. (Using model-based decision rule.)") while True: - action = get_forced_action(env) - state, reward, next_state, done, _, _ = env.step(action) + # In live mode, we use the simulation decision rule. + state = env.get_state(env.current_index) + current_open = env.candle_window[env.current_index]["open"] + state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) + timeframe_ids = torch.arange(state.shape[0]).to(device) + pred_high, pred_low = model(state_tensor, timeframe_ids) + pred_high = pred_high.item() + pred_low = pred_low.item() + if (pred_high - current_open) >= (current_open - pred_low) and (pred_high - current_open) > args.threshold: + action = 2 + elif (current_open - pred_low) > (pred_high - current_open) and (current_open - pred_low) > args.threshold: + action = 0 + else: + action = 1 + _, _, _, done, _, _ = env.step(action) if done: - print("Reached end of simulation window, resetting environment.") - state = env.reset() + print("Reached end of simulation window; resetting environment.") + env.reset() await asyncio.sleep(1) elif args.mode == 'inference': load_best_checkpoint(model) print("Running inference...") - # Apply a similar (or learned) policy as needed. + # Inference logic can use a similar decision rule as in live mode. else: print("Invalid mode specified.") diff --git a/crypto/brian/live_chart.html b/crypto/brian/live_chart.html index 0512b35..da26419 100644 --- a/crypto/brian/live_chart.html +++ b/crypto/brian/live_chart.html @@ -4,7 +4,7 @@ - Live Trading Chart - Epoch 14 + Live Trading Chart - Epoch 100