This commit is contained in:
Dobromir Popov 2025-02-04 19:04:44 +02:00
parent 2ec75e66cb
commit c8043a9dcd

View File

@ -7,6 +7,9 @@ if sys.platform == 'win32':
import os import os
import time import time
import json import json
import argparse
import threading
import random
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -15,20 +18,22 @@ from collections import deque
from datetime import datetime from datetime import datetime
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import ccxt.async_support as ccxt import ccxt.async_support as ccxt
import argparse
from torch.nn import TransformerEncoder, TransformerEncoderLayer from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math import math
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
# --- Directories ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"
# --- New Constants --- # --- Constants ---
NUM_TIMEFRAMES = 5 # Example: ["1m", "5m", "15m", "1h", "1d"] NUM_TIMEFRAMES = 5 # Example: ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20 # Example: 20 technical indicators NUM_INDICATORS = 20 # Example: 20 technical indicators
FEATURES_PER_CHANNEL = 7 # HLOC + SMA_close + SMA_volume FEATURES_PER_CHANNEL = 7 # e.g. HLOC, SMA_close, SMA_volume
# --- Positional Encoding Module --- # --- Positional Encoding Module ---
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):
@ -58,107 +63,207 @@ class TradingModel(nn.Module):
nn.Dropout(0.1) nn.Dropout(0.1)
) for _ in range(num_channels) ) for _ in range(num_channels)
]) ])
self.timeframe_embed = nn.Embedding(num_timeframes, hidden_dim) self.timeframe_embed = nn.Embedding(num_timeframes, hidden_dim)
self.pos_encoder = PositionalEncoding(hidden_dim) self.pos_encoder = PositionalEncoding(hidden_dim)
# Transformer
encoder_layers = TransformerEncoderLayer( encoder_layers = TransformerEncoderLayer(
d_model=hidden_dim, nhead=4, dim_feedforward=512, d_model=hidden_dim, nhead=4, dim_feedforward=512,
dropout=0.1, activation='gelu', batch_first=False dropout=0.1, activation='gelu', batch_first=False
) )
self.transformer = TransformerEncoder(encoder_layers, num_layers=2) self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
# Attention Pooling
self.attn_pool = nn.Linear(hidden_dim, 1) self.attn_pool = nn.Linear(hidden_dim, 1)
# Prediction Heads
self.high_pred = nn.Sequential( self.high_pred = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim//2), nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(), nn.GELU(),
nn.Linear(hidden_dim//2, 1) nn.Linear(hidden_dim // 2, 1)
) )
self.low_pred = nn.Sequential( self.low_pred = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim//2), nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(), nn.GELU(),
nn.Linear(hidden_dim//2, 1) nn.Linear(hidden_dim // 2, 1)
) )
def forward(self, x, timeframe_ids): def forward(self, x, timeframe_ids):
# x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL] # x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
batch_size, num_channels, _ = x.shape batch_size, num_channels, _ = x.shape
channel_outs = [self.channel_branches[i](x[:, i, :]) for i in range(num_channels)]
# Process each channel through its branch
channel_outs = []
for i in range(num_channels):
channel_out = self.channel_branches[i](x[:, i, :])
channel_outs.append(channel_out)
# Stack and add timeframe embeddings
stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden] stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden]
stacked = stacked.permute(1, 0, 2) # [channels, batch, hidden] stacked = stacked.permute(1, 0, 2) # [channels, batch, hidden]
# Add timeframe embeddings to each channel
tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1) tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
stacked = stacked + tf_embeds stacked = stacked + tf_embeds
# Apply Transformer
src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device) src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
transformer_out = self.transformer(stacked, src_mask=src_mask) transformer_out = self.transformer(stacked, src_mask=src_mask)
# Attention Pooling over channels
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0) attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
aggregated = (transformer_out * attn_weights).sum(dim=0) aggregated = (transformer_out * attn_weights).sum(dim=0)
return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze() return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()
# --- Enhanced Data Processing --- # --- Technical Indicator Helper Functions ---
# Here you need to have the helper functions get_aligned_candle_with_index and get_features_for_tf def compute_sma(candles_list, index, period=10):
# They must be defined elsewhere in your code. start = max(0, index - period + 1)
values = [candle["close"] for candle in candles_list[start:index+1]]
return sum(values) / len(values) if values else 0.0
def compute_sma_volume(candles_list, index, period=10):
start = max(0, index - period + 1)
values = [candle["volume"] for candle in candles_list[start:index+1]]
return sum(values) / len(values) if values else 0.0
def get_aligned_candle_with_index(candles_list, target_ts):
best_idx = 0
for i, candle in enumerate(candles_list):
if candle["timestamp"] <= target_ts:
best_idx = i
else:
break
return best_idx, candles_list[best_idx]
def get_features_for_tf(candles_list, index, period=10):
candle = candles_list[index]
f_open = candle["open"]
f_high = candle["high"]
f_low = candle["low"]
f_close = candle["close"]
f_volume = candle["volume"]
sma_close = compute_sma(candles_list, index, period)
sma_volume = compute_sma_volume(candles_list, index, period)
return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume]
# --- Caching and Checkpoint Functions ---
def load_candles_cache(filename):
if os.path.exists(filename):
try:
with open(filename, "r") as f:
data = json.load(f)
print(f"Loaded cached data from {filename}.")
return data
except Exception as e:
print("Error reading cache file:", e)
return {}
def save_candles_cache(filename, candles_dict):
try:
with open(filename, "w") as f:
json.dump(candles_dict, f)
except Exception as e:
print("Error saving cache file:", e)
def maintain_checkpoint_directory(directory, max_files=10):
files = os.listdir(directory)
if len(files) > max_files:
full_paths = [os.path.join(directory, f) for f in files]
full_paths.sort(key=lambda x: os.path.getmtime(x))
for f in full_paths[: len(files) - max_files]:
os.remove(f)
def get_best_models(directory):
best_files = []
for file in os.listdir(directory):
parts = file.split("_")
try:
r = float(parts[1])
best_files.append((r, file))
except Exception:
continue
return best_files
def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
last_path = os.path.join(last_dir, last_filename)
torch.save({
"epoch": epoch,
"reward": reward,
"model_state_dict": model.state_dict()
}, last_path)
maintain_checkpoint_directory(last_dir, max_files=10)
best_models = get_best_models(best_dir)
add_to_best = False
if len(best_models) < 10:
add_to_best = True
else:
min_reward, min_file = min(best_models, key=lambda x: x[0])
if reward > min_reward:
add_to_best = True
os.remove(os.path.join(best_dir, min_file))
if add_to_best:
best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
best_path = os.path.join(best_dir, best_filename)
torch.save({
"epoch": epoch,
"reward": reward,
"model_state_dict": model.state_dict()
}, best_path)
maintain_checkpoint_directory(best_dir, max_files=10)
print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
def load_best_checkpoint(model, best_dir=BEST_DIR):
best_models = get_best_models(best_dir)
if not best_models:
return None
best_reward, best_file = max(best_models, key=lambda x: x[0])
path = os.path.join(best_dir, best_file)
print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
return checkpoint
# --- Backtest Environment ---
class BacktestEnvironment: class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes): def __init__(self, candles_dict, base_tf, timeframes):
self.candles_dict = candles_dict self.candles_dict = candles_dict # dict of timeframe: candles_list
self.base_tf = base_tf self.base_tf = base_tf
self.timeframes = timeframes self.timeframes = timeframes
self.current_index = 0 # Initialize step pointer self.current_index = 0
self.trade_history = []
self.position = None
def reset(self): def reset(self):
self.current_index = 0 self.current_index = 0
self.position = None
self.trade_history = []
return self.get_state(self.current_index) return self.get_state(self.current_index)
def get_state(self, index): def get_state(self, index):
"""Returns state as an array of shape [num_channels, FEATURES_PER_CHANNEL]."""
state_features = [] state_features = []
base_ts = self.candles_dict[self.base_tf][index]["timestamp"] base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
# Timeframe channels
for tf in self.timeframes: for tf in self.timeframes:
aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts) aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
features = get_features_for_tf(self.candles_dict[tf], aligned_idx) features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
state_features.append(features) state_features.append(features)
# Indicator channels (placeholder - implement your indicators)
for _ in range(NUM_INDICATORS): for _ in range(NUM_INDICATORS):
state_features.append([0.0] * FEATURES_PER_CHANNEL) state_features.append([0.0] * FEATURES_PER_CHANNEL)
return np.array(state_features, dtype=np.float32) return np.array(state_features, dtype=np.float32)
def step(self, action): def step(self, action):
""" base_candles = self.candles_dict[self.base_tf]
Advance the environment by one step. if self.current_index >= len(base_candles) - 1:
Since this is for backtesting, action isn't used here. return self.get_state(self.current_index), 0.0, None, True
Returns: current candle info, reward, next state, done, actual high, actual low. current_state = self.get_state(self.current_index)
""" next_index = self.current_index + 1
# Dummy implementation: you would generate targets based on your backtest logic. next_state = self.get_state(next_index)
done = (self.current_index >= len(self.candles_dict[self.base_tf]) - 2) current_candle = base_candles[self.current_index]
current_candle = self.candles_dict[self.base_tf][self.current_index] next_candle = base_candles[next_index]
# For example, take the next candle's high/low as targets reward = 0.0
next_candle = self.candles_dict[self.base_tf][self.current_index + 1] # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
actual_high = next_candle["high"] if self.position is None:
actual_low = next_candle["low"] if action == 2: # BUY signal
self.current_index += 1 entry_price = next_candle["open"]
next_state = self.get_state(self.current_index) self.position = {"entry_price": entry_price, "entry_index": self.current_index}
return current_candle, 0.0, next_state, done, actual_high, actual_low else:
if action == 0: # SELL signal
exit_price = next_candle["open"]
reward = exit_price - self.position["entry_price"]
trade = {
"entry_index": self.position["entry_index"],
"entry_price": self.position["entry_price"],
"exit_index": next_index,
"exit_price": exit_price,
"pnl": reward
}
self.trade_history.append(trade)
self.position = None
self.current_index = next_index
done = (self.current_index >= len(base_candles) - 1)
return current_state, reward, next_state, done
def __len__(self): def __len__(self):
return len(self.candles_dict[self.base_tf]) return len(self.candles_dict[self.base_tf])
@ -167,99 +272,126 @@ class BacktestEnvironment:
def train_on_historical_data(env, model, device, args): def train_on_historical_data(env, model, device, args):
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5) optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
for epoch in range(args.epochs): for epoch in range(args.epochs):
state = env.reset() state = env.reset()
total_loss = 0 total_loss = 0
model.train() model.train()
while True: while True:
# Prepare batch (here batch size is 1 for simplicity)
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device) timeframe_ids = torch.arange(state.shape[0]).to(device)
# Forward pass
pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high, pred_low = model(state_tensor, timeframe_ids)
# Here we use dummy targets extracted from the next candle's high/low
# Get target values from next candle (dummy targets from environment) _, _, next_state, done, actual_high, actual_low = env.step(None)
_, _, next_state, done, actual_high, actual_low = env.step(None) # Dummy action
target_high = torch.FloatTensor([actual_high]).to(device) target_high = torch.FloatTensor([actual_high]).to(device)
target_low = torch.FloatTensor([actual_low]).to(device) target_low = torch.FloatTensor([actual_low]).to(device)
# Custom loss: use absolute error scaled by 2
high_loss = torch.abs(pred_high - target_high) * 2 high_loss = torch.abs(pred_high - target_high) * 2
low_loss = torch.abs(pred_low - target_low) * 2 low_loss = torch.abs(pred_low - target_low) * 2
loss = (high_loss + low_loss).mean() loss = (high_loss + low_loss).mean()
# Backpropagation
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step() optimizer.step()
total_loss += loss.item() total_loss += loss.item()
if done: if done:
break break
state = next_state state = next_state
scheduler.step() scheduler.step()
print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}") print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
save_checkpoint(model, epoch, total_loss) save_checkpoint(model, epoch, total_loss)
# --- Mode Handling and Argument Parsing --- # --- Live Plotting Functions ---
def update_live_chart(ax, candles, trade_history):
ax.clear()
close_prices = [candle["close"] for candle in candles]
x = list(range(len(close_prices)))
ax.plot(x, close_prices, label="Close Price", color="black", linewidth=1)
buy_label_added = False
sell_label_added = False
for trade in trade_history:
in_idx = trade["entry_index"]
out_idx = trade["exit_index"]
in_price = trade["entry_price"]
out_price = trade["exit_price"]
if not buy_label_added:
ax.plot(in_idx, in_price, marker="^", color="green", markersize=10, label="BUY")
buy_label_added = True
else:
ax.plot(in_idx, in_price, marker="^", color="green", markersize=10)
if not sell_label_added:
ax.plot(out_idx, out_price, marker="v", color="red", markersize=10, label="SELL")
sell_label_added = True
else:
ax.plot(out_idx, out_price, marker="v", color="red", markersize=10)
ax.plot([in_idx, out_idx], [in_price, out_price], linestyle="dotted", color="blue")
ax.set_title("Live Trading Chart")
ax.set_xlabel("Candle Index")
ax.set_ylabel("Price")
ax.legend()
ax.grid(True)
def live_preview_loop(candles, env):
plt.ion()
fig, ax = plt.subplots(figsize=(12, 6))
while True:
update_live_chart(ax, candles, env.trade_history)
plt.draw()
plt.pause(1) # Update every second
# --- Argument Parsing ---
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train') parser.add_argument('--mode', choices=['train','live','inference'], default='train')
parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--threshold', type=float, default=0.005) parser.add_argument('--threshold', type=float, default=0.005)
return parser.parse_args() return parser.parse_args()
def load_best_checkpoint(model, best_dir="models/best"): def random_action():
# Dummy implementation for loading the best checkpoint. return random.randint(0, 2)
# In real usage, check your saved checkpoints.
print("Loading best checkpoint (dummy implementation)")
# torch.load(...) can be invoked here.
return
def save_checkpoint(model, epoch, reward, last_dir="models/last", best_dir="models/best"):
# Dummy implementation for saving checkpoints.
print(f"Saving checkpoint for epoch {epoch}, reward: {reward:.4f}")
# --- Main Function --- # --- Main Function ---
async def main(): async def main():
args = parse_args() args = parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define timeframes; these must match your data and expected state dimensions.
# Initialize model timeframes = ["1m", "5m", "15m", "1h", "1d"]
model = TradingModel( input_dim = len(timeframes) * 7 # 7 features per timeframe.
num_channels=NUM_TIMEFRAMES + NUM_INDICATORS, hidden_dim = 128
num_timeframes=NUM_TIMEFRAMES output_dim = 3 # Actions: SELL, HOLD, BUY.
).to(device) # For the Transformer model, we set number of channels = NUM_TIMEFRAMES + NUM_INDICATORS.
model = TradingModel(NUM_TIMEFRAMES + NUM_INDICATORS, NUM_TIMEFRAMES).to(device)
if args.mode == 'train': if args.mode == 'train':
# Load historical candle data for backtesting candles_dict = load_candles_cache(CACHE_FILE)
candles_dict = load_candles_cache("candles_cache.json")
if not candles_dict: if not candles_dict:
print("No historical candle data available for backtesting.") print("No historical candle data available for backtesting.")
return return
base_tf = "1m" # Base timeframe base_tf = "1m"
timeframes = ["1m", "5m", "15m", "1h", "1d"]
env = BacktestEnvironment(candles_dict, base_tf, timeframes) env = BacktestEnvironment(candles_dict, base_tf, timeframes)
train_on_historical_data(env, model, device, args) train_on_historical_data(env, model, device, args)
elif args.mode == 'live': elif args.mode == 'live':
# Load model and connect to live data
load_best_checkpoint(model) load_best_checkpoint(model)
candles_dict = load_candles_cache(CACHE_FILE)
if not candles_dict:
print("No cached candles available for live preview.")
return
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
# Start the live preview in a separate daemon thread.
preview_thread = threading.Thread(target=live_preview_loop, args=(candles_dict["1m"], env), daemon=True)
preview_thread.start()
print("Starting live trading loop. (Using random actions for simulation.)")
while True: while True:
# Process live data: fetch live candles, make predictions, execute trades state, reward, next_state, done = env.step(random_action())
print("Processing live data...") if done:
await asyncio.sleep(1) print("Reached end of simulated data, resetting environment.")
state = env.reset(clear_trade_history=False)
await asyncio.sleep(1) # Simulate one candle per second.
elif args.mode == 'inference': elif args.mode == 'inference':
# Load model and run inference
load_best_checkpoint(model) load_best_checkpoint(model)
print("Running inference...") print("Running inference...")
# Add your inference logic here # Implement your inference loop here.
else:
print("Invalid mode specified.")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main())