gogo2/crypto/brian/index-deep-new.py
Dobromir Popov 2ec75e66cb compile fix
2025-02-04 18:00:58 +02:00

265 lines
10 KiB
Python

#!/usr/bin/env python3
import sys
import asyncio
if sys.platform == 'win32':
    # Install the selector-based event-loop policy before any loop is created.
    # NOTE(review): presumably required by the async networking stack used via
    # ccxt.async_support on Windows — confirm.
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import os
import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from datetime import datetime
import matplotlib.pyplot as plt
import ccxt.async_support as ccxt
import argparse
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math
from dotenv import load_dotenv
load_dotenv()
# --- New Constants ---
# Number of candle timeframes the model is configured for.
# Example: ["1m", "5m", "15m", "1h", "1d"]
NUM_TIMEFRAMES = 5
# Number of technical-indicator channels; BacktestEnvironment.get_state appends
# this many (currently all-zero placeholder) rows after the timeframe rows.
NUM_INDICATORS = 20
# Length of each channel's feature vector (input width of every channel branch).
# NOTE(review): the original note read "HLOC + SMA_close + SMA_volume", which
# names only 6 values; presumably raw volume is the 7th — confirm against
# get_features_for_tf.
FEATURES_PER_CHANNEL = 7
# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence-first tensor.

    A table ``pe`` of shape ``[max_len, 1, d_model]`` is precomputed once and
    registered as a (non-trainable) buffer. Even feature columns hold sines and
    odd feature columns hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the last (feature) dimension of the inputs. Both even
            and odd values are supported.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions precomputed; inputs passed to
            ``forward`` must not be longer than this along dimension 0.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the angle table; without this the assignment raises a shape
        # mismatch. For even d_model the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:len(x)])``.

        Args:
            x: Tensor whose dimension 0 is the sequence position and whose
                last dimension is ``d_model``; the middle dimension of ``pe``
                is 1 so it broadcasts over dimension 1 of ``x``.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
# --- Enhanced Transformer Model ---
class TradingModel(nn.Module):
    """Transformer that regresses a (high, low) pair from per-channel features.

    Every channel owns a small feed-forward branch that maps its
    ``FEATURES_PER_CHANNEL`` inputs to ``hidden_dim``. The channel vectors are
    then treated as a sequence with one token per channel, offset by a learned
    id embedding, run through a 2-layer Transformer encoder, pooled with
    learned attention weights and passed to two regression heads.

    Args:
        num_channels: Number of input channels; one branch is built per channel.
        num_timeframes: Minimum number of rows in the id-embedding table.
        hidden_dim: Token width. It is split across 4 attention heads, so it
            must be divisible by 4.
    """

    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()
        self.channel_branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1),
            ) for _ in range(num_channels)
        ])
        # forward() looks up one id per channel, and the training loop passes
        # arange(num_channels). A table with only num_timeframes rows is
        # therefore indexed out of range whenever num_channels > num_timeframes,
        # so size it to cover both.
        self.timeframe_embed = nn.Embedding(max(num_timeframes, num_channels), hidden_dim)
        # NOTE(review): constructed (and part of the state dict) but not
        # applied anywhere in forward().
        self.pos_encoder = PositionalEncoding(hidden_dim)
        # Transformer (sequence-first layout: [seq, batch, hidden])
        encoder_layers = TransformerEncoderLayer(
            d_model=hidden_dim, nhead=4, dim_feedforward=512,
            dropout=0.1, activation='gelu', batch_first=False
        )
        self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
        # Attention pooling: one scalar score per token
        self.attn_pool = nn.Linear(hidden_dim, 1)
        # Prediction heads
        self.high_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.low_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, timeframe_ids):
        """Predict the high and low targets for a batch of states.

        Args:
            x: Float tensor of shape
                ``[batch_size, num_channels, FEATURES_PER_CHANNEL]``.
            timeframe_ids: 1-D integer tensor with one id per channel; its
                embeddings are broadcast over the batch dimension.

        Returns:
            ``(high, low)``: two tensors of shape ``[batch_size]``.
        """
        _, num_channels, _ = x.shape
        # Process each channel through its own branch
        channel_outs = [
            self.channel_branches[i](x[:, i, :]) for i in range(num_channels)
        ]
        stacked = torch.stack(channel_outs, dim=1)  # [batch, channels, hidden]
        stacked = stacked.permute(1, 0, 2)  # [channels, batch, hidden]
        # Add the id embedding of each channel: [channels, 1, hidden]
        tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
        stacked = stacked + tf_embeds
        # Boolean mask, True above the diagonal: channel i may only attend to
        # channels 0..i. Built directly on the input's device.
        seq_len = stacked.size(0)
        src_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()
        # nn.TransformerEncoder.forward names this argument `mask`; `src_mask`
        # is the keyword of a single TransformerEncoderLayer and raises
        # TypeError here.
        transformer_out = self.transformer(stacked, mask=src_mask)
        # Attention pooling over channels
        attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
        aggregated = (transformer_out * attn_weights).sum(dim=0)  # [batch, hidden]
        # squeeze(-1) drops only the trailing size-1 output dimension, so the
        # batch dimension survives even when batch_size == 1.
        return self.high_pred(aggregated).squeeze(-1), self.low_pred(aggregated).squeeze(-1)
# --- Enhanced Data Processing ---
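# BacktestEnvironment.get_state relies on two helper functions that are not defined
# in this section and must be provided elsewhere in your code:
#   get_aligned_candle_with_index(candles, timestamp) -> a 2-tuple whose first item
#       is the candle index that is then used for feature extraction
#   get_features_for_tf(candles, index) -> FEATURES_PER_CHANNEL numeric features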
class BacktestEnvironment:
    """Walk through cached candles one base-timeframe candle at a time.

    Attributes are set from the constructor arguments:
      candles_dict  -- mapping from timeframe name to its list of candles
      base_tf       -- key in candles_dict whose candles drive the stepping
      timeframes    -- timeframe names that each contribute one state row
      current_index -- position of the current candle in the base series
    """

    def __init__(self, candles_dict, base_tf, timeframes):
        self.candles_dict = candles_dict
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.current_index = 0  # step pointer into the base series

    def __len__(self):
        """Number of candles in the base timeframe."""
        return len(self.candles_dict[self.base_tf])

    def reset(self):
        """Rewind to the first base candle and return its state."""
        self.current_index = 0
        return self.get_state(0)

    def get_state(self, index):
        """Returns state as an array of shape [num_channels, FEATURES_PER_CHANNEL].

        One row is produced per entry in ``self.timeframes`` (via the external
        helpers), followed by NUM_INDICATORS all-zero placeholder rows.
        """
        anchor_ts = self.candles_dict[self.base_tf][index]["timestamp"]

        def channel_for(tf):
            # Locate the candle of this timeframe for the base timestamp,
            # then extract its feature vector.
            series = self.candles_dict[tf]
            position, _ = get_aligned_candle_with_index(series, anchor_ts)
            return get_features_for_tf(series, position)

        rows = [channel_for(tf) for tf in self.timeframes]
        # Indicator channels (placeholder - implement your indicators)
        rows.extend([0.0] * FEATURES_PER_CHANNEL for _ in range(NUM_INDICATORS))
        return np.array(rows, dtype=np.float32)

    def step(self, action):
        """
        Move forward by one base candle; ``action`` is ignored.

        Returns a 6-tuple: the candle that was current before the step, a
        reward that is always 0.0, the state at the new position, the done
        flag, and the "high" and "low" of the candle stepped onto.
        ``done`` is True once the pointer was at or past the second-to-last
        candle before this step.
        """
        base_series = self.candles_dict[self.base_tf]
        here = self.current_index
        done = here >= len(base_series) - 2
        current_candle = base_series[here]
        # Targets come from the candle that follows the current one
        upcoming = base_series[here + 1]
        actual_high, actual_low = upcoming["high"], upcoming["low"]
        self.current_index = here + 1
        next_state = self.get_state(self.current_index)
        return current_candle, 0.0, next_state, done, actual_high, actual_low
# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args):
    """Train ``model`` by replaying ``env`` once per epoch, one sample per step.

    For every step the current state is fed to the model, the environment is
    advanced, and the high/low values returned by ``env.step`` are used as
    regression targets. After each epoch the learning-rate scheduler is
    stepped, the mean step loss is printed and ``save_checkpoint`` is called.

    Args:
        env: Environment providing ``reset()`` and ``step(action)`` as
            implemented by BacktestEnvironment.
        model: Module called as ``model(state_batch, timeframe_ids)`` and
            returning ``(pred_high, pred_low)``.
        device: Torch device on which input and target tensors are created.
        args: Namespace providing ``lr`` and ``epochs``.
    """
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    for epoch in range(args.epochs):
        state = env.reset()
        total_loss = 0.0
        num_steps = 0
        model.train()
        while True:
            # Prepare batch (here batch size is 1 for simplicity)
            state_tensor = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
            # One id per state row (channel)
            timeframe_ids = torch.arange(state.shape[0], device=device)
            # Forward pass
            pred_high, pred_low = model(state_tensor, timeframe_ids)
            # Advance the environment; the action argument is unused by it
            _, _, next_state, done, actual_high, actual_low = env.step(None)
            target_high = torch.tensor([actual_high], dtype=torch.float32, device=device)
            target_low = torch.tensor([actual_low], dtype=torch.float32, device=device)
            # Custom loss: use absolute error scaled by 2
            high_loss = torch.abs(pred_high - target_high) * 2
            low_loss = torch.abs(pred_low - target_low) * 2
            loss = (high_loss + low_loss).mean()
            # Backpropagation with gradient-norm clipping
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            num_steps += 1
            if done:
                break
            state = next_state
        scheduler.step()
        # Average over the steps actually taken. Dividing by len(env) would
        # under-report the mean, because an epoch performs one step fewer than
        # there are candles. The loop body always runs at least once, so
        # num_steps >= 1.
        print(f"Epoch {epoch+1} Loss: {total_loss/num_steps:.4f}")
        # The summed epoch loss is passed in save_checkpoint's `reward` slot.
        save_checkpoint(model, epoch, total_loss)
# --- Mode Handling and Argument Parsing ---
def parse_args(argv=None):
    """Parse the command-line options of the script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing ``parse_args()``
            callers are unaffected; passing a list allows programmatic use.

    Returns:
        argparse.Namespace with ``mode``, ``epochs``, ``lr`` and ``threshold``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train',
                        help="which branch of main() to run")
    parser.add_argument('--epochs', type=int, default=100,
                        help="number of training epochs")
    parser.add_argument('--lr', type=float, default=3e-4,
                        help="optimizer learning rate")
    # NOTE(review): threshold is parsed but not read by any code visible in
    # this file.
    parser.add_argument('--threshold', type=float, default=0.005)
    return parser.parse_args(argv)
def load_best_checkpoint(model, best_dir="models/best"):
    """Placeholder for restoring the best saved weights into ``model``.

    Currently nothing is loaded: the function only prints a notice and
    returns ``None``. Both arguments are accepted but unused. A real
    implementation would call torch.load(...) on a checkpoint under
    ``best_dir``.
    """
    notice = "Loading best checkpoint (dummy implementation)"
    print(notice)
def save_checkpoint(model, epoch, reward, last_dir="models/last", best_dir="models/best"):
    """Placeholder for persisting a checkpoint.

    Nothing is written to disk yet: the function prints one status line with
    ``epoch`` and ``reward`` (formatted to four decimals) and returns ``None``.
    ``model``, ``last_dir`` and ``best_dir`` are accepted but unused.
    """
    status_line = f"Saving checkpoint for epoch {epoch}, reward: {reward:.4f}"
    print(status_line)
# --- Main Function ---
async def main():
    """Entry coroutine: build the model and run the mode chosen on the CLI.

    * ``train``     -- load cached candles and train on them; prints a message
                       and returns if no candle data is available.
    * ``live``      -- call load_best_checkpoint, then loop forever printing a
                       status line once per second (placeholder).
    * ``inference`` -- call load_best_checkpoint and print a status line
                       (placeholder).

    NOTE(review): load_candles_cache is not defined in the visible part of
    this file; it must be provided elsewhere.
    """
    args = parse_args()
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # One channel per timeframe plus one per indicator
    total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
    model = TradingModel(
        num_channels=total_channels,
        num_timeframes=NUM_TIMEFRAMES,
    ).to(device)

    mode = args.mode
    if mode == 'train':
        # Historical candle data for backtesting
        candles_dict = load_candles_cache("candles_cache.json")
        if candles_dict:
            # "1m" is the base timeframe that drives the stepping
            env = BacktestEnvironment(candles_dict, "1m", ["1m", "5m", "15m", "1h", "1d"])
            train_on_historical_data(env, model, device, args)
        else:
            print("No historical candle data available for backtesting.")
        return

    if mode == 'live':
        load_best_checkpoint(model)
        while True:
            # Placeholder for: fetch live candles, predict, execute trades
            print("Processing live data...")
            await asyncio.sleep(1)

    if mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
        # Inference logic goes here
if __name__ == "__main__":
    # Run the async entry point only when executed as a script, not on import.
    asyncio.run(main())