#!/usr/bin/env python3
import sys
import asyncio
if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import os
import time
import json
import argparse
import threading
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
import matplotlib.pyplot as plt
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import matplotlib.dates as mdates
from dotenv import load_dotenv
load_dotenv()
# Define global constants FIRST.
CACHE_FILE = "candles_cache.json"
TRAINING_CACHE_FILE = "training_cache.json"
# --- Helper Function for Timestamp Conversion ---
def convert_timestamp(ts):
    ts = float(ts)
    if ts > 1e10:  # Heuristic: values this large are in milliseconds
        ts /= 1000.0
    return datetime.fromtimestamp(ts)
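# Usage note: because of the 1e10 cutoff, second- and millisecond-based
# inputs resolve to the same instant, e.g.:
#   convert_timestamp(1_700_000_000) == convert_timestamp(1_700_000_000_000)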
# -------------------------------
# Historical Data Fetching Functions (Using CCXT)
# -------------------------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
    candles = []
    since_ms = since
    while True:
        try:
            batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size)
        except Exception as e:
            print("Error fetching historical data:", e)
            break
        if not batch:
            break
        for c in batch:
            candle_dict = {
                'timestamp': c[0],
                'open': c[1],
                'high': c[2],
                'low': c[3],
                'close': c[4],
                'volume': c[5]
            }
            candles.append(candle_dict)
        last_timestamp = batch[-1][0]
        if last_timestamp >= end_time:
            break
        since_ms = last_timestamp + 1
    print(f"Fetched {len(candles)} candles for timeframe {timeframe}.")
    return candles
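# Minimal usage sketch (not part of the script's flow; assumes ccxt is
# installed and this is awaited inside a running event loop):
#
#   import ccxt.async_support as ccxt
#   exchange = ccxt.binance({'enableRateLimit': True})
#   now_ms = int(time.time() * 1000)
#   candles = await fetch_historical_data(
#       exchange, 'BTC/USDT', '1m', now_ms - 60 * 60 * 1000, now_ms)
#   await exchange.close()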
async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time, cache_file=CACHE_FILE, batch_size=500):
    cached_candles = load_candles_cache(cache_file)
    if cached_candles and timeframe in cached_candles:
        last_ts = cached_candles[timeframe][-1]['timestamp']
        if last_ts < end_time:
            print("Fetching new candles to update cache...")
            new_candles = await fetch_historical_data(exchange, symbol, timeframe, last_ts + 1, end_time, batch_size)
            cached_candles[timeframe].extend(new_candles)
        else:
            print("Cache covers the requested period.")
        return cached_candles[timeframe]
    else:
        candles = await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size)
        return candles
# -------------------------------
# Cache and Training Cache Helpers
# -------------------------------
def load_candles_cache(filename):
    if os.path.exists(filename):
        try:
            with open(filename, "r") as f:
                data = json.load(f)
            print(f"Loaded cached data from {filename}.")
            return data
        except Exception as e:
            print("Error reading cache file:", e)
    return {}  # Empty dict if there is no usable cache

def save_candles_cache(filename, candles_dict):
    try:
        with open(filename, "w") as f:
            json.dump(candles_dict, f)
    except Exception as e:
        print("Error saving cache file:", e)

def load_training_cache(filename):
    if os.path.exists(filename):
        try:
            with open(filename, "r") as f:
                cache = json.load(f)
            print(f"Loaded training cache from {filename}.")
            return cache
        except Exception as e:
            print("Error loading training cache:", e)
    return {"total_pnl": 0.0}  # Initialize if not found

def save_training_cache(filename, cache):
    try:
        with open(filename, "w") as f:
            json.dump(cache, f)
    except Exception as e:
        print("Error saving training cache:", e)
# -------------------------------
# Checkpoint Functions
# -------------------------------
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)
def maintain_checkpoint_directory(directory, max_files=10):
    files = os.listdir(directory)
    if len(files) > max_files:
        full_paths = [os.path.join(directory, f) for f in files]
        full_paths.sort(key=lambda x: os.path.getmtime(x))
        # Remove the oldest files until only max_files remain
        for f in full_paths[:len(files) - max_files]:
            os.remove(f)

def get_best_models(directory):
    """Collect (loss, filename) pairs from files named best_<loss>_..."""
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            loss = float(parts[1])  # Loss is encoded in the filename
            best_files.append((loss, file))
        except Exception:
            continue
    return best_files
def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    last_path = os.path.join(last_dir, last_filename)
    torch.save({
        "epoch": epoch,
        "loss": loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict()
    }, last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)
    best_models = get_best_models(best_dir)
    add_to_best = False
    if len(best_models) < 10:
        add_to_best = True
    else:
        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
        if loss < worst_loss:  # Save only if better than the current worst
            add_to_best = True
            os.remove(os.path.join(best_dir, worst_file))  # Evict the worst
    if add_to_best:
        best_filename = f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt"  # Loss encoded in the name
        best_path = os.path.join(best_dir, best_filename)
        torch.save({
            "epoch": epoch,
            "loss": loss,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()
        }, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")
def load_best_checkpoint(model, best_dir=BEST_DIR):
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_loss, best_file = min(best_models, key=lambda x: x[0])  # Lowest loss wins
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    checkpoint = torch.load(path)
    # Handle a potential embedding-size mismatch between checkpoints
    old_state = checkpoint["model_state_dict"]
    new_state = model.state_dict()
    if "timeframe_embed.weight" in old_state:
        old_embed = old_state["timeframe_embed.weight"]
        new_embed = new_state["timeframe_embed.weight"]
        if old_embed.shape[0] < new_embed.shape[0]:
            # Copy the old rows into the (larger) new embedding table
            new_embed[:old_embed.shape[0]] = old_embed
            old_state["timeframe_embed.weight"] = new_embed
    model.load_state_dict(old_state, strict=False)  # Tolerate size differences
    return checkpoint
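# Checkpoint naming used above (best files embed the loss so that
# get_best_models() can rank them by parsing parts[1]; "last" files fail that
# float() parse and are skipped):
#   models/last/model_last_epoch_<epoch>_<YYYYmmdd_HHMMSS>.pt
#   models/best/best_<loss>_epoch_<epoch>_<YYYYmmdd_HHMMSS>.pt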
# -------------------------------
# Positional Encoding and Transformer-Based Model
# -------------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
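# Note: PositionalEncoding expects sequence-first input, while the transformer
# in TradingModel is built with batch_first=True. As written, the model's
# forward() never calls self.pos_encoder, so the module is currently inert;
# wiring it in would require transposing to (channels, batch, hidden) first.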
class TradingModel(nn.Module):
    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()
        self.channel_branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.1)
            ) for _ in range(num_channels)
        ])
        # Embedding for each timeframe/channel
        self.timeframe_embed = nn.Embedding(num_timeframes, hidden_dim)
        self.pos_encoder = PositionalEncoding(hidden_dim)
        # More layers and heads for a larger model
        encoder_layers = TransformerEncoderLayer(
            d_model=hidden_dim, nhead=8, dim_feedforward=2048,
            dropout=0.1, activation='gelu', batch_first=True
        )
        self.transformer = TransformerEncoder(encoder_layers, num_layers=6)
        # Attention pooling to aggregate channel outputs
        self.attn_pool = nn.Linear(hidden_dim, 1)
        # Separate prediction heads for high and low
        self.high_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.low_pred = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, timeframe_ids):
        # x shape: (batch, num_channels, features_per_channel)
        batch_size, num_channels, _ = x.shape
        channel_outs = []
        # Process each channel through its own branch
        for i in range(num_channels):
            channel_out = self.channel_branches[i](x[:, i, :])
            channel_outs.append(channel_out)
        # Stack channel outputs
        stacked = torch.stack(channel_outs, dim=1)  # (batch, num_channels, hidden_dim)
        # Add per-channel embeddings; timeframe_ids must have length num_channels
        tf_embeds = self.timeframe_embed(timeframe_ids)  # (num_channels, hidden_dim)
        stacked = stacked + tf_embeds.unsqueeze(0)  # Broadcast over the batch
        # Transformer over the channel axis
        transformer_out = self.transformer(stacked)  # (batch, num_channels, hidden_dim)
        # Attention pooling
        attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1)  # (batch, num_channels, 1)
        aggregated = (transformer_out * attn_weights).sum(dim=1)  # (batch, hidden_dim)
        # Predict high and low
        return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()
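# Shape smoke test (sketch; the channel/feature counts match the globals
# defined further down in this file):
#   channels = 6 + ORDER_CHANNELS + NUM_INDICATORS       # 6 timeframes -> 27
#   model = TradingModel(channels, channels)
#   x = torch.randn(4, channels, FEATURES_PER_CHANNEL)   # (batch, channels, 7)
#   hi, lo = model(x, torch.arange(channels))            # each of shape (4,)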
# -------------------------------
# Technical Indicator Helpers
# -------------------------------
def compute_sma(candles_list, index, period=10):
    start = max(0, index - period + 1)
    values = [candle["close"] for candle in candles_list[start:index + 1]]
    return sum(values) / len(values) if values else 0.0

def compute_sma_volume(candles_list, index, period=10):
    start = max(0, index - period + 1)
    values = [candle["volume"] for candle in candles_list[start:index + 1]]
    return sum(values) / len(values) if values else 0.0

def get_aligned_candle_with_index(candles_list, target_ts):
    """Find the candle whose timestamp is the largest one that is <= target_ts."""
    best_idx = 0
    for i, candle in enumerate(candles_list):
        if candle["timestamp"] <= target_ts:
            best_idx = i
        else:
            break  # Stop once we pass the target
    return best_idx, candles_list[best_idx]
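# For long candle lists the linear scan above can be swapped for a binary
# search. A sketch of an (assumed drop-in) equivalent using the standard
# library's bisect module:
#
#   import bisect
#   def get_aligned_candle_with_index(candles_list, target_ts):
#       stamps = [c["timestamp"] for c in candles_list]
#       idx = max(bisect.bisect_right(stamps, target_ts) - 1, 0)
#       return idx, candles_list[idx]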
def get_features_for_tf(candles_list, index, period=10):
    """Return a vector of 7 features: open, high, low, close, volume, sma_close, sma_volume."""
    candle = candles_list[index]
    f_open = candle["open"]
    f_high = candle["high"]
    f_low = candle["low"]
    f_close = candle["close"]
    f_volume = candle["volume"]
    sma_close = compute_sma(candles_list, index, period)
    sma_volume = compute_sma_volume(candles_list, index, period)
    return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume]
# -------------------------------
# Backtest Environment Class
# -------------------------------
class BacktestEnvironment:
    def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
        self.candles_dict = candles_dict
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.full_candles = candles_dict[base_tf]  # All candles for the base timeframe
        # Use the given window size, or a reasonable default
        if window_size is None:
            window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles)
        self.window_size = window_size
        self.reset()  # Initialize

    def reset(self):
        # Randomly select a starting point for the window
        self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
        self.candle_window = self.full_candles[self.start_index:self.start_index + self.window_size]
        self.current_index = 0
        self.trade_history = []
        self.position = None  # None, or {"entry_price": ..., "entry_index": ...}
        return self.get_state(self.current_index)  # Return the initial state

    def __len__(self):
        return self.window_size  # The environment's length is the window size

    def get_order_features(self, index):
        """Get features describing the open order (if any)."""
        candle = self.candle_window[index]
        if self.position is None:
            # No position: all zeros
            return [0.0] * FEATURES_PER_CHANNEL
        else:
            # In a position: [1.0, relative price diff, 0, 0, 0, 0, 0]
            flag = 1.0
            diff = (candle["open"] - self.position["entry_price"]) / candle["open"]
            return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2)

    def get_state(self, index):
        """Construct the state for the given index."""
        state_features = []
        base_ts = self.candle_window[index]["timestamp"]
        # Features for each timeframe
        for tf in self.timeframes:
            if tf == self.base_tf:
                # For the base timeframe, use the candle directly from the window
                candle = self.candle_window[index]
                features = get_features_for_tf([candle], 0)  # Single-candle list
            else:
                # For other timeframes, align with the base timestamp
                aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
                features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
            state_features.append(features)
        # Order features
        order_features = self.get_order_features(index)
        state_features.append(order_features)
        # Placeholder channels for additional indicators
        for _ in range(NUM_INDICATORS):
            state_features.append([0.0] * FEATURES_PER_CHANNEL)
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        """Take a step for the given action (0=SELL, 1=HOLD, 2=BUY)."""
        base = self.candle_window
        if self.current_index >= len(base) - 1:
            current_state = self.get_state(self.current_index)
            return current_state, 0.0, None, True, 0.0, 0.0  # At the very end: no reward, done
        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        next_candle = base[next_index]  # Next candle for open, high, low
        reward = 0.0
        if self.position is None:
            if action == 2:  # BUY
                self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
        else:
            if action == 0:  # SELL
                exit_price = next_candle["close"]
                reward = exit_price - self.position["entry_price"]  # PnL is the reward
                trade = {
                    "entry_index": self.position["entry_index"],
                    "entry_price": self.position["entry_price"],
                    "exit_index": next_index,
                    "exit_price": exit_price,
                    "pnl": reward
                }
                self.trade_history.append(trade)
                self.position = None  # Exit the position
        self.current_index = next_index
        done = (self.current_index >= len(base) - 1)  # Done at the end of the window
        actual_high = next_candle["high"]
        actual_low = next_candle["low"]
        return current_state, reward, next_state, done, actual_high, actual_low
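# Usage sketch (assumes candles_dict was loaded via load_candles_cache):
#   env = BacktestEnvironment(candles_dict, "1m", ["1m"], window_size=100)
#   state = env.reset()                    # (num_channels, 7) float32 array
#   s, r, s2, done, hi, lo = env.step(2)   # open a position (BUY)
#   s, r, s2, done, hi, lo = env.step(0)   # close it (SELL); r is the PnL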
# -------------------------------
# Enhanced Training Loop
# -------------------------------
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
    lambda_trade = args.lambda_trade  # Weight for the trade surrogate loss
    # Load the training cache (for total PnL tracking)
    training_cache = load_training_cache(TRAINING_CACHE_FILE)
    total_pnl = training_cache.get("total_pnl", 0.0)
    for epoch in range(start_epoch, args.epochs):
        env.reset()  # Reset the environment for each epoch
        loss_accum = 0.0
        steps = len(env) - 1  # Number of steps in the episode
        for i in range(steps):
            state = env.get_state(i)
            current_open = env.candle_window[i]["open"]     # Current candle's open
            actual_high = env.candle_window[i + 1]["high"]  # Next candle's high
            actual_low = env.candle_window[i + 1]["low"]    # Next candle's low
            # Forward pass
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)  # Add batch dimension
            timeframe_ids = torch.arange(state.shape[0]).to(device)          # One ID per channel
            pred_high, pred_low = model(state_tensor, timeframe_ids)
            # Prediction loss (L_pred): absolute error on the next high/low
            L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
                     torch.abs(pred_low - torch.tensor(actual_low, device=device))
            # Trade surrogate loss (L_trade): reward the better of the two moves
            profit_buy = pred_high - current_open   # Potential profit if buying
            profit_sell = current_open - pred_low   # Potential profit if selling
            L_trade = -torch.max(profit_buy, profit_sell)  # Minimize negative profit
            # No-action penalty (encourage taking action when profitable)
            current_open_tensor = torch.tensor(current_open, device=device)
            signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low)
            penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
            # Total loss
            loss = L_pred + lambda_trade * L_trade + penalty_term
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
            optimizer.step()
            scheduler.step()
            loss_accum += loss.item()
        epoch_loss = loss_accum / steps  # Average loss per step
        # Simulate trades first so that env.trade_history is populated for this
        # epoch -- the supervised loop above never calls env.step(), so reading
        # the history before simulating would always see an empty list.
        simulate_trades(model, env, device, args)
        if len(env.trade_history) == 0:
            epoch_loss *= 3  # Penalize epochs that produce no trades
        epoch_pnl = sum(trade["pnl"] for trade in env.trade_history)  # PnL for the epoch
        total_pnl += epoch_pnl
        print(f"Epoch {epoch + 1} Loss: {epoch_loss:.4f} | Epoch PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}")
        save_checkpoint(model, optimizer, epoch, loss_accum)  # Save with the accumulated loss
        update_live_html(env.candle_window, env.trade_history, epoch + 1, epoch_loss, total_pnl)
    # Persist the total PnL across runs
    training_cache["total_pnl"] = total_pnl
    save_training_cache(TRAINING_CACHE_FILE, training_cache)
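# Per-step objective assembled above, for reference:
#   loss = |pred_high - high| + |pred_low - low|
#          - lambda_trade * max(pred_high - open, open - pred_low)
#          + penalty_noaction * max(threshold - signal_strength, 0)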
# -------------------------------
# Live Plotting (for Live Mode)
# -------------------------------
def live_preview_loop(candles, env):
    plt.ion()  # Interactive mode
    fig, ax = plt.subplots(figsize=(12, 6))
    while True:
        update_live_chart(ax, candles, env.trade_history)
        plt.draw()
        plt.pause(1)  # Update every second
# -------------------------------
# Live HTML Chart Update (with Volume and Loss)
# -------------------------------
def update_live_html(candles, trade_history, epoch, loss, total_pnl):
    from io import BytesIO
    import base64
    # Create a new figure and axes for each update
    fig, ax = plt.subplots(figsize=(12, 6))
    # Draw the chart
    update_live_chart(ax, candles, trade_history)
    epoch_pnl = sum(trade["pnl"] for trade in trade_history)  # PnL for this window
    ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}")
    # Save the plot to a BytesIO buffer
    buf = BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # Close the figure to free memory
    buf.seek(0)
    image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    # Generate the HTML content
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <meta http-equiv="refresh" content="1"> <!-- Refresh every second -->
        <title>Live Trading Chart - Epoch {epoch}</title>
        <style>
            body {{
                margin: 0;
                padding: 0;
                display: flex;
                justify-content: center;
                align-items: center;
                background-color: #f4f4f4;
            }}
            .chart-container {{
                text-align: center;
            }}
            img {{
                max-width: 100%;
                height: auto;
            }}
        </style>
    </head>
    <body>
        <div class="chart-container">
            <h2>Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}</h2>
            <img src="data:image/png;base64,{image_base64}" alt="Live Chart"/>
        </div>
    </body>
    </html>
    """
    with open("live_chart.html", "w") as f:
        f.write(html_content)
    print("Updated live_chart.html.")
# -------------------------------
# Chart Drawing Helpers (with Volume and Date+Time)
# -------------------------------
def update_live_chart(ax, candles, trade_history):
    ax.clear()  # Clear previous data
    # Extract data for plotting
    times = [convert_timestamp(candle["timestamp"]) for candle in candles]
    close_prices = [candle["close"] for candle in candles]
    # Plot close prices
    ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
    ax.set_xlabel("Time")
    ax.set_ylabel("Price")
    # Format the x-axis to show dates and times
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
    # Second y-axis for volume
    ax2 = ax.twinx()
    volumes = [candle["volume"] for candle in candles]
    if len(times) > 1:
        times_num = mdates.date2num(times)
        bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8  # Relative width
    else:
        bar_width = 0.01
    ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume")
    ax2.set_ylabel("Volume")
    # Plot trade markers; label only the first trade so the legend does not
    # repeat BUY/SELL once per trade
    for n, trade in enumerate(trade_history):
        entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"])
        exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"])
        in_price = trade["entry_price"]
        out_price = trade["exit_price"]
        ax.plot(entry_time, in_price, marker="^", color="green", markersize=10,
                label="BUY" if n == 0 else "_nolegend_")
        ax.plot(exit_time, out_price, marker="v", color="red", markersize=10,
                label="SELL" if n == 0 else "_nolegend_")
        ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue")
    # Combine legends from both axes
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines + lines2, labels + labels2)
    ax.grid(True)
    # Auto-format the x-axis labels for readability
    fig = ax.get_figure()
    fig.autofmt_xdate()
# -------------------------------
# Global Constants for Features
# -------------------------------
NUM_INDICATORS = 20 # Number of additional indicator channels
FEATURES_PER_CHANNEL = 7
ORDER_CHANNELS = 1
# -------------------------------
# General Simulation of Trades Function
# -------------------------------
def simulate_trades(model, env, device, args):
    if args.main_tf == "1s":
        simulate_trades_1s(env)
        return
    env.reset()  # Reset to a random starting point
    while True:
        i = env.current_index
        if i >= len(env.candle_window) - 1:
            break  # End of the window reached
        state = env.get_state(i)
        current_open = env.candle_window[i]["open"]
        # Make predictions
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)  # Add batch dimension
        timeframe_ids = torch.arange(state.shape[0]).to(device)  # One ID per channel
        pred_high, pred_low = model(state_tensor, timeframe_ids)
        pred_high = pred_high.item()  # Convert to Python numbers
        pred_low = pred_low.item()
        # Trade only when the predicted move exceeds the threshold
        if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
            if (pred_high - current_open) >= (current_open - pred_low):
                action = 2  # BUY
            else:
                action = 0  # SELL
            env.step(action)
        else:
            manual_trade(env)  # HOLD (advances one step; see manual_trade)
        if env.current_index >= len(env.candle_window) - 1:
            break
def simulate_trades_1s(env):
    # Ensure the base timeframe really is 1s
    if env.base_tf != "1s":
        raise ValueError("simulate_trades_1s can only be used with base_tf='1s'")
    env.reset()
    current_second = env.candle_window[env.current_index]['timestamp'] // 1000
    # Simulate trading across the entire window
    while True:
        if env.current_index >= len(env.candle_window) - 1:
            break  # End of the window reached
        # Check whether a new second has started
        next_second = env.candle_window[env.current_index]['timestamp'] // 1000
        if next_second != current_second:
            # A new second: take a random action (buy/sell/hold)
            env.step(random_action())
            current_second = next_second  # Update the current second
        else:
            # Still the same second: hold (advances one step)
            manual_trade(env)
        if env.current_index >= len(env.candle_window) - 1:
            break
def manual_trade(env):
    # No trading signal: take the HOLD action. Stepping (rather than doing
    # nothing) is what advances env.current_index -- without it, the calling
    # loops above would never make progress when no signal fires.
    env.step(1)  # 1 = HOLD: keeps any open position and advances one candle
# -------------------------------
# Argument Parsing
# -------------------------------
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train',
                        help="Operating mode: 'train', 'live', or 'inference'.")
    parser.add_argument('--epochs', type=int, default=1000,
                        help="Number of training epochs.")
    parser.add_argument('--lr', type=float, default=3e-4,
                        help="Learning rate.")
    parser.add_argument('--threshold', type=float, default=0.005,
                        help="Minimum predicted move to trigger a trade (used in the loss; the model may override manual trades).")
    parser.add_argument('--lambda_trade', type=float, default=1.0,
                        help="Weight for the trade surrogate loss.")
    parser.add_argument('--penalty_noaction', type=float, default=10.0,
                        help="Penalty if no action is taken (used in the loss).")
    parser.add_argument('--start_fresh', action='store_true',
                        help="Start training from scratch.")
    parser.add_argument('--main_tf', type=str, default='1m',
                        help="Main timeframe to focus on (e.g., '1s' or '1m').")
    # Fetching is the default; --no-fetch overrides it.
    parser.add_argument('--no-fetch', dest='fetch', action='store_false',
                        help="Do NOT fetch fresh data from the exchange on start.")
    parser.set_defaults(fetch=True)
    parser.add_argument('--symbol', type=str, default='BTC/USDT', help="Trading pair symbol.")
    return parser.parse_args()
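# Example invocations (sketch):
#   python index-gem.py --mode train --main_tf 1m --epochs 100
#   python index-gem.py --mode live --symbol BTC/USDT --no-fetch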
def random_action():
    return random.randint(0, 2)  # 0: SELL, 1: HOLD, 2: BUY
# -------------------------------
# Main Function
# -------------------------------
async def main():
    args = parse_args()
    # Use the GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    # Fetch data unless --no-fetch was given. Since fetch defaults to True,
    # live mode will always try to top up the cache.
    if args.fetch:
        import ccxt.async_support as ccxt
        exchange = ccxt.binance({'enableRateLimit': True})  # Use Binance
        now_ms = int(time.time() * 1000)
        # If cached data exists, only fetch what is needed to update it.
        cached = load_candles_cache(CACHE_FILE)
        if cached and args.main_tf in cached and len(cached[args.main_tf]) > 0:
            last_ts = cached[args.main_tf][-1]['timestamp']
            since = last_ts + 1
        else:
            # Fetch a reasonable amount of history initially (last 2 days)
            since = now_ms - 2 * 24 * 60 * 60 * 1000
        print(f"Fetching fresh data for {args.symbol} on timeframe {args.main_tf} from {since} to {now_ms}...")
        fresh_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf, since, now_ms)
        candles_dict = {args.main_tf: fresh_candles}  # Initially, only the main timeframe
        # Save (or update) the cache
        save_candles_cache(CACHE_FILE, candles_dict)
        await exchange.close()
    else:
        # Load from the cache
        candles_dict = load_candles_cache(CACHE_FILE)
        if not candles_dict:
            print("No cached data available. Run without --no-fetch (the default) to load fresh data from the exchange.")
            return
    # --- Timeframe Setup ---
    default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"]  # All supported timeframes
    timeframes = [tf for tf in default_timeframes if tf in candles_dict]  # Keep only those available
    if args.main_tf not in timeframes:
        print(f"Desired main timeframe {args.main_tf} is not available. Available: {timeframes}")
        return
    base_tf = args.main_tf
    # --- Model Initialization ---
    hidden_dim = 128  # Hidden dimension for the Transformer
    total_channels = len(timeframes) + ORDER_CHANNELS + NUM_INDICATORS  # Input channels
    # The embedding table needs one row per *channel*, not per timeframe:
    # every caller builds timeframe_ids as torch.arange(total_channels), so a
    # smaller table would raise an index-out-of-range error in the embedding.
    model = TradingModel(total_channels, total_channels, hidden_dim).to(device)
    if args.mode == 'train':
        # --- Training Setup ---
        env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
        start_epoch = 0
        checkpoint = None
        # Load a checkpoint (unless starting fresh)
        if not args.start_fresh:
            checkpoint = load_best_checkpoint(model)
            if checkpoint is not None:
                start_epoch = checkpoint.get("epoch", 0) + 1  # Resume from the next epoch
                print(f"Resuming training from epoch {start_epoch}.")
            else:
                print("No checkpoint found. Starting training from scratch.")
        else:
            print("Starting training from scratch as requested.")
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
        # Restore the optimizer state (if a checkpoint exists)
        if checkpoint is not None:
            optim_state = checkpoint.get("optimizer_state_dict", None)
            if optim_state is not None and "param_groups" in optim_state:
                try:
                    optimizer.load_state_dict(optim_state)
                    print("Loaded optimizer state from checkpoint.")
                except Exception as e:
                    print("Failed to load optimizer state due to:", e)
                    print("Deleting all checkpoints and starting fresh.")
                    for chk_dir in [LAST_DIR, BEST_DIR]:
                        for f in os.listdir(chk_dir):
                            os.remove(os.path.join(chk_dir, f))
            else:
                print("No valid optimizer state found; using fresh optimizer state.")
        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
    elif args.mode == 'live':
        import ccxt.async_support as ccxt
        exchange = ccxt.binance({'enableRateLimit': True})  # Use Binance
        POLL_INTERVAL = 60  # seconds

        async def update_live_candles():
            while True:
                now_ms = int(time.time() * 1000)
                # Fetch just the most recent candles (last 2 minutes)
                new_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf,
                                                             since=now_ms - 2 * 60 * 1000, end_time=now_ms)
                candles_dict[args.main_tf] = new_candles  # Update (or add) the main timeframe
                print("Live candles updated.")
                await asyncio.sleep(POLL_INTERVAL)

        # Start the candle update task
        asyncio.create_task(update_live_candles())
        load_best_checkpoint(model)  # Load the best model for live trading
        env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
        # Start the live preview (optional)
        preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
        preview_thread.start()
        print("Starting live trading loop. (Using live updated data now.)")
        while True:
            if args.main_tf == "1s":
                simulate_trades_1s(env)  # Run the 1s trading loop
            else:
                # Get the current state
                state = env.get_state(env.current_index)
                current_open = env.candle_window[env.current_index]["open"]
                # Make predictions
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                timeframe_ids = torch.arange(state.shape[0]).to(device)  # One ID per channel
                pred_high, pred_low = model(state_tensor, timeframe_ids)
                pred_high = pred_high.item()  # Convert to Python numbers
                pred_low = pred_low.item()
                # Decide on an action
                if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
                    if (pred_high - current_open) >= (current_open - pred_low):
                        action = 2  # BUY
                    else:
                        action = 0  # SELL
                    env.step(action)
                else:
                    manual_trade(env)  # HOLD: advance one step
            if env.current_index >= len(env.candle_window) - 1:
                print("Reached end of simulation window; resetting environment.")
                env.reset()  # Reset the environment when reaching the end
            await asyncio.sleep(1)  # Short delay for live mode
    elif args.mode == 'inference':
        load_best_checkpoint(model)  # Load the best model for inference
        print("Running inference...")
        # Inference logic (loading data, predicting, reporting) goes here.
    else:
        print("Invalid mode specified.")

if __name__ == "__main__":
    asyncio.run(main())