#!/usr/bin/env python3
import argparse
import asyncio
import json
import math
import os
import sys
import time
from collections import deque
from datetime import datetime

# Kept ahead of the third-party imports, matching the original statement order.
if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import ccxt.async_support as ccxt
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from dotenv import load_dotenv

load_dotenv()

# --- New Constants ---
NUM_TIMEFRAMES = 5        # Example: ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20       # Example: 20 technical indicators
FEATURES_PER_CHANNEL = 7  # HLOC + SMA_close + SMA_volume


# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence-first tensor.

    Args:
        d_model: Size of the last (feature) dimension of the input.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Longest sequence length the precomputed table supports.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one fewer cosine column than sine
        # columns, so trim the frequencies to avoid a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer: moves with the module across devices but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encoding for its first ``x.size(0)`` positions.

        ``x`` is indexed along dim 0 as the sequence axis; the table has a
        singleton middle axis that broadcasts over dim 1.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


# --- Enhanced Transformer Model ---
class TradingModel(nn.Module):
    """Per-channel encoders + Transformer over channels + high/low heads.

    Args:
        num_channels: Number of input channels; one linear branch is created
            per channel.
        num_timeframes: Number of timeframe ids the caller may use.
        hidden_dim: Width of every internal representation.
    """

    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()
        self.channel_branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
            )
            for _ in range(num_channels)
        ])

        # Callers pass one id per channel (0 .. num_channels - 1), so the
        # table needs a row for every channel, not only for the timeframes;
        # otherwise the lookup raises an index-out-of-range error.
        self.timeframe_embed = nn.Embedding(
            max(num_timeframes, num_channels), hidden_dim
        )
        # NOTE(review): pos_encoder is constructed but not applied in
        # forward(); kept so the module's attributes stay unchanged.
        self.pos_encoder = PositionalEncoding(hidden_dim)

        # Transformer
        encoder_layers = TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=4,
            dim_feedforward=512,
            dropout=0.1,
            activation='gelu',
            batch_first=False,
        )
        self.transformer = TransformerEncoder(encoder_layers, num_layers=2)

        # Attention Pooling
        self.attn_pool = nn.Linear(hidden_dim, 1)

        # Prediction Heads
        self.high_pred = self._make_head(hidden_dim)
        self.low_pred = self._make_head(hidden_dim)

    @staticmethod
    def _make_head(hidden_dim):
        """Build one scalar regression head (hidden -> hidden/2 -> 1)."""
        return nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1),
        )

    def forward(self, x, timeframe_ids):
        """Predict the high and low values for each sample in the batch.

        Args:
            x: Tensor of shape [batch_size, num_channels, FEATURES_PER_CHANNEL].
            timeframe_ids: Integer tensor of embedding ids, one per channel
                (length ``num_channels``); its embeddings are broadcast over
                the batch axis.

        Returns:
            Tuple ``(high, low)``; each tensor has shape [batch_size].
        """
        # Process each channel through its own branch and stack the results
        # channel-first: [channels, batch, hidden] (batch_first=False layout).
        channel_outs = [
            self.channel_branches[i](channel)
            for i, channel in enumerate(x.unbind(dim=1))
        ]
        stacked = torch.stack(channel_outs, dim=0)

        # Add the per-channel embeddings: [channels, 1, hidden] broadcasts
        # over the batch axis.
        tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
        stacked = stacked + tf_embeds

        # Upper-triangular boolean mask (True = blocked): channel i attends
        # only to channels <= i. Built directly on the input's device.
        # NOTE(review): a causal mask over channels is unusual — confirm intent.
        seq_len = stacked.size(0)
        src_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()
        # TransformerEncoder.forward names this argument `mask`; `src_mask`
        # is the TransformerEncoderLayer keyword and raises TypeError here.
        transformer_out = self.transformer(stacked, mask=src_mask)

        # Attention pooling over channels (softmax along the channel axis).
        attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
        aggregated = (transformer_out * attn_weights).sum(dim=0)

        # squeeze(-1) drops only the feature axis, so a batch of one still
        # yields shape [1] instead of a 0-dim tensor.
        high = self.high_pred(aggregated).squeeze(-1)
        low = self.low_pred(aggregated).squeeze(-1)
        return high, low


# --- Enhanced Data Processing ---
# Here you need to have the helper functions get_aligned_candle_with_index and get_features_for_tf
# They must be defined elsewhere in your code.
class BacktestEnvironment:
    """Replay cached candles one base-timeframe step at a time.

    Args:
        candles_dict: Mapping from timeframe name to a sequence of candle
            dicts; the code reads the "timestamp", "high" and "low" keys.
        base_tf: Key of the timeframe whose candles drive the step pointer.
        timeframes: Timeframe keys that each contribute one state channel.
    """

    def __init__(self, candles_dict, base_tf, timeframes):
        self.candles_dict = candles_dict
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.current_index = 0  # Initialize step pointer

    def reset(self):
        """Rewind to the first base candle and return its state."""
        self.current_index = 0
        return self.get_state(self.current_index)

    def get_state(self, index):
        """Returns state as an array of shape [num_channels, FEATURES_PER_CHANNEL]."""
        base_ts = self.candles_dict[self.base_tf][index]["timestamp"]

        # Timeframe channels
        # NOTE(review): get_aligned_candle_with_index and get_features_for_tf
        # are not defined in this file — they must be provided elsewhere.
        state_features = []
        for tf in self.timeframes:
            candles = self.candles_dict[tf]
            aligned_idx, _ = get_aligned_candle_with_index(candles, base_ts)
            state_features.append(get_features_for_tf(candles, aligned_idx))

        # Indicator channels (placeholder - implement your indicators)
        state_features.extend(
            [0.0] * FEATURES_PER_CHANNEL for _ in range(NUM_INDICATORS)
        )
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """Advance the environment by one base candle.

        ``action`` is accepted for interface symmetry but is not used.

        Returns:
            Tuple ``(current_candle, reward, next_state, done, actual_high,
            actual_low)``. The reward is always 0.0 and the targets are the
            next base candle's "high" and "low".

        Raises:
            IndexError: If there is no next candle to step to.
        """
        base_candles = self.candles_dict[self.base_tf]
        next_index = self.current_index + 1
        if next_index >= len(base_candles):
            raise IndexError(
                "step() called with no next candle available; "
                "call reset() or provide at least two base candles"
            )

        # The episode ends once the next candle is the last one available.
        done = self.current_index >= len(base_candles) - 2
        current_candle = base_candles[self.current_index]
        # For example, take the next candle's high/low as targets
        next_candle = base_candles[next_index]

        self.current_index = next_index
        next_state = self.get_state(next_index)
        return (
            current_candle,
            0.0,
            next_state,
            done,
            next_candle["high"],
            next_candle["low"],
        )

    def __len__(self):
        return len(self.candles_dict[self.base_tf])


# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args):
    """Train ``model`` by replaying ``env`` once per epoch (batch size 1).

    Args:
        env: Environment exposing ``reset()`` and ``step(action)`` with the
            return layout of ``BacktestEnvironment``.
        model: Module called as ``model(state_batch, timeframe_ids)`` that
            returns ``(pred_high, pred_low)``.
        device: Torch device on which tensors are created.
        args: Namespace providing ``lr`` and ``epochs``.

    Returns:
        List with the mean per-step loss of every epoch.
    """
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs
    )
    epoch_losses = []

    for epoch in range(args.epochs):
        state = env.reset()
        model.train()
        total_loss = 0.0
        steps = 0

        # One id per channel; built once per epoch instead of every step.
        # NOTE(review): ids run 0 .. num_channels - 1, so the model's
        # embedding table must have at least that many rows — confirm.
        timeframe_ids = torch.arange(state.shape[0], device=device)

        while True:
            # Prepare batch (here batch size is 1 for simplicity)
            state_tensor = torch.as_tensor(
                state, dtype=torch.float32, device=device
            ).unsqueeze(0)

            # Forward pass
            pred_high, pred_low = model(state_tensor, timeframe_ids)

            # Get target values from next candle (dummy action)
            _, _, next_state, done, actual_high, actual_low = env.step(None)
            target_high = torch.tensor(
                [actual_high], dtype=torch.float32, device=device
            )
            target_low = torch.tensor(
                [actual_low], dtype=torch.float32, device=device
            )

            # Custom loss: use absolute error scaled by 2
            high_loss = torch.abs(pred_high - target_high) * 2
            low_loss = torch.abs(pred_low - target_low) * 2
            loss = (high_loss + low_loss).mean()

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            steps += 1
            if done:
                break
            state = next_state

        scheduler.step()
        # Average over the steps actually taken (one fewer than len(env)
        # for BacktestEnvironment), not over the number of candles.
        mean_loss = total_loss / steps
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1} Loss: {mean_loss:.4f}")
        save_checkpoint(model, epoch, total_loss)

    return epoch_losses


# --- Mode Handling and Argument Parsing ---
def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Optional list of argument strings; ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--threshold', type=float, default=0.005)
    return parser.parse_args(argv)


def load_best_checkpoint(model, best_dir="models/best"):
    """Placeholder: only prints a message; no weights are loaded."""
    # Dummy implementation for loading the best checkpoint.
    # In real usage, check your saved checkpoints.
    print("Loading best checkpoint (dummy implementation)")
    # torch.load(...) can be invoked here.
    return


def save_checkpoint(model, epoch, reward, last_dir="models/last", best_dir="models/best"):
    """Placeholder: only prints the epoch and reward; nothing is written."""
    # Dummy implementation for saving checkpoints.
    print(f"Saving checkpoint for epoch {epoch}, reward: {reward:.4f}")


# --- Main Function ---
async def main():
    """Build the model and dispatch on ``--mode`` (train, live, inference)."""
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model
    model = TradingModel(
        num_channels=NUM_TIMEFRAMES + NUM_INDICATORS,
        num_timeframes=NUM_TIMEFRAMES,
    ).to(device)

    if args.mode == 'train':
        # Load historical candle data for backtesting
        # NOTE(review): load_candles_cache is not defined in this file — it
        # must be provided elsewhere.
        candles_dict = load_candles_cache("candles_cache.json")
        if not candles_dict:
            print("No historical candle data available for backtesting.")
            return
        base_tf = "1m"  # Base timeframe
        timeframes = ["1m", "5m", "15m", "1h", "1d"]
        env = BacktestEnvironment(candles_dict, base_tf, timeframes)
        train_on_historical_data(env, model, device, args)

    elif args.mode == 'live':
        # Load model and connect to live data
        load_best_checkpoint(model)
        while True:
            # Process live data: fetch live candles, make predictions, execute trades
            print("Processing live data...")
            await asyncio.sleep(1)

    elif args.mode == 'inference':
        # Load model and run inference
        load_best_checkpoint(model)
        print("Running inference...")
        # Add your inference logic here


if __name__ == "__main__":
    asyncio.run(main())