diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py index 5e60a3f..f910b96 100644 --- a/crypto/brian/index-deep-new.py +++ b/crypto/brian/index-deep-new.py @@ -44,10 +44,8 @@ CACHE_FILE = "candles_cache.json" # --- Constants --- NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"] NUM_INDICATORS = 20 # e.g., 20 technical indicators -# Each channel input will have 7 features. -FEATURES_PER_CHANNEL = 7 -# We add one extra channel for order information. -ORDER_CHANNELS = 1 +FEATURES_PER_CHANNEL = 7 # Each channel input will have 7 features. +ORDER_CHANNELS = 1 # One extra channel for order info. # --- Positional Encoding Module --- class PositionalEncoding(nn.Module): @@ -68,7 +66,6 @@ class PositionalEncoding(nn.Module): class TradingModel(nn.Module): def __init__(self, num_channels, num_timeframes, hidden_dim=128): super().__init__() - # Create one branch per channel. self.channel_branches = nn.ModuleList([ nn.Sequential( nn.Linear(FEATURES_PER_CHANNEL, hidden_dim), @@ -77,13 +74,11 @@ class TradingModel(nn.Module): nn.Dropout(0.1) ) for _ in range(num_channels) ]) - # Embedding for channels 0..num_channels-1. self.timeframe_embed = nn.Embedding(num_channels, hidden_dim) self.pos_encoder = PositionalEncoding(hidden_dim) - # Set batch_first=True to avoid the nested tensor warning. encoder_layers = TransformerEncoderLayer( d_model=hidden_dim, nhead=4, dim_feedforward=512, - dropout=0.1, activation='gelu', batch_first=True + dropout=0.1, activation='gelu', batch_first=True # Use batch_first to avoid nested tensor warning. ) self.transformer = TransformerEncoder(encoder_layers, num_layers=2) self.attn_pool = nn.Linear(hidden_dim, 1) @@ -104,11 +99,9 @@ class TradingModel(nn.Module): for i in range(num_channels): channel_out = self.channel_branches[i](x[:, i, :]) channel_outs.append(channel_out) - stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden] - # With batch_first=True, the expected input is [batch, seq_len, hidden] + stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden] tf_embeds = self.timeframe_embed(timeframe_ids) # shape: [num_channels, hidden] - # Expand tf_embeds to match the batch dimension. - stacked = stacked + tf_embeds.unsqueeze(0) + stacked = stacked + tf_embeds.unsqueeze(0) # broadcast along batch dimension. transformer_out = self.transformer(stacked) attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1) aggregated = (transformer_out * attn_weights).sum(dim=1) @@ -225,32 +218,78 @@ def load_best_checkpoint(model, best_dir=BEST_DIR): checkpoint = torch.load(path) old_state = checkpoint["model_state_dict"] new_state = model.state_dict() - - # Fix the size mismatch for timeframe_embed.weight. if "timeframe_embed.weight" in old_state: old_embed = old_state["timeframe_embed.weight"] new_embed = new_state["timeframe_embed.weight"] if old_embed.shape[0] < new_embed.shape[0]: new_embed[:old_embed.shape[0]] = old_embed old_state["timeframe_embed.weight"] = new_embed - - # For channel_branches, missing keys are handled by strict=False. model.load_state_dict(old_state, strict=False) return checkpoint +# --- Function for Manual Trade Override --- +def manual_trade(env): + """ + When no sufficient action is taken by the model, manually decide the trade. + Find the maximum high and minimum low in the remaining window. + If maximum occurs before minimum, we short; otherwise we long. + The trade is closed at the candle where the chosen extreme occurs. + """ + current_index = env.current_index + if current_index >= len(env.candle_window) - 1: + env.current_index = len(env.candle_window) - 1 + return + max_val = -float('inf') + min_val = float('inf') + i_max = current_index + i_min = current_index + for j in range(current_index + 1, len(env.candle_window)): + high_j = env.candle_window[j]["high"] + low_j = env.candle_window[j]["low"] + if high_j > max_val: + max_val = high_j + i_max = j + if low_j < min_val: + min_val = low_j + i_min = j + # If maximum occurs before minimum, we interpret that as short (price will drop). + if i_max < i_min: + entry_price = env.candle_window[current_index]["open"] + exit_price = env.candle_window[i_min]["open"] + reward = entry_price - exit_price + trade = { + "entry_index": current_index, + "entry_price": entry_price, + "exit_index": i_min, + "exit_price": exit_price, + "pnl": reward + } + else: + entry_price = env.candle_window[current_index]["open"] + exit_price = env.candle_window[i_max]["open"] + reward = exit_price - entry_price + trade = { + "entry_index": current_index, + "entry_price": entry_price, + "exit_index": i_max, + "exit_price": exit_price, + "pnl": reward + } + env.trade_history.append(trade) + env.current_index = trade["exit_index"] + # --- Live HTML Chart Update --- def update_live_html(candles, trade_history, epoch): """ - Generate a chart image that uses actual timestamps on the x-axis - and shows a cumulative epoch PnL. The chart (with buy/sell markers and dotted lines) - is embedded in an HTML page that auto-refreshes every 1 seconds. + Generate a chart image with actual timestamps on the x-axis and cumulative epoch PnL. + The chart now also plots volume as a bar chart on a secondary y-axis. + The HTML page auto-refreshes every 10 seconds. """ from io import BytesIO import base64 fig, ax = plt.subplots(figsize=(12, 6)) update_live_chart(ax, candles, trade_history) - # Compute cumulative epoch PnL. epoch_pnl = sum(trade["pnl"] for trade in trade_history) ax.set_title(f"Live Trading Chart - Epoch {epoch} | PnL: {epoch_pnl:.2f}") buf = BytesIO() @@ -263,7 +302,7 @@ def update_live_html(candles, trade_history, epoch): - + Live Trading Chart - Epoch {epoch}