requirements
This commit is contained in:
File diff suppressed because one or more lines are too long
@ -22,6 +22,13 @@ import matplotlib.dates as mdates
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
|
||||
|
||||
import torch
|
||||
print(torch.cuda.is_available())
|
||||
|
||||
|
||||
|
||||
# Define global constants FIRST.
|
||||
CACHE_FILE = "candles_cache.json"
|
||||
TRAINING_CACHE_FILE = "training_cache.json"
|
||||
|
910
crypto/brian/index-gem.py
Normal file
910
crypto/brian/index-gem.py
Normal file
@ -0,0 +1,910 @@
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
if sys.platform == 'win32':
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import argparse
|
||||
import threading
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from datetime import datetime
|
||||
import matplotlib.pyplot as plt
|
||||
import math
|
||||
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
||||
import matplotlib.dates as mdates
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Define global constants FIRST.
|
||||
CACHE_FILE = "candles_cache.json"
|
||||
TRAINING_CACHE_FILE = "training_cache.json"
|
||||
|
||||
# --- Helper Function for Timestamp Conversion ---
|
||||
def convert_timestamp(ts):
|
||||
ts = float(ts)
|
||||
if ts > 1e10: # Handle milliseconds
|
||||
ts /= 1000.0
|
||||
return datetime.fromtimestamp(ts)
|
||||
|
||||
# -------------------------------
|
||||
# Historical Data Fetching Functions (Using CCXT)
|
||||
# -------------------------------
|
||||
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
|
||||
candles = []
|
||||
since_ms = since
|
||||
while True:
|
||||
try:
|
||||
batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size)
|
||||
except Exception as e:
|
||||
print("Error fetching historical data:", e)
|
||||
break
|
||||
if not batch:
|
||||
break
|
||||
for c in batch:
|
||||
candle_dict = {
|
||||
'timestamp': c[0],
|
||||
'open': c[1],
|
||||
'high': c[2],
|
||||
'low': c[3],
|
||||
'close': c[4],
|
||||
'volume': c[5]
|
||||
}
|
||||
candles.append(candle_dict)
|
||||
last_timestamp = batch[-1][0]
|
||||
if last_timestamp >= end_time:
|
||||
break
|
||||
since_ms = last_timestamp + 1
|
||||
print(f"Fetched {len(candles)} candles for timeframe {timeframe}.")
|
||||
|
||||
return candles
|
||||
|
||||
async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time, cache_file=CACHE_FILE, batch_size=500):
|
||||
cached_candles = load_candles_cache(cache_file)
|
||||
|
||||
if cached_candles and timeframe in cached_candles:
|
||||
last_ts = cached_candles[timeframe][-1]['timestamp']
|
||||
if last_ts < end_time:
|
||||
print("Fetching new candles to update cache...")
|
||||
new_candles = await fetch_historical_data(exchange, symbol, timeframe, last_ts + 1, end_time, batch_size)
|
||||
cached_candles[timeframe].extend(new_candles)
|
||||
else:
|
||||
print("Cache covers the requested period.")
|
||||
return cached_candles[timeframe]
|
||||
else:
|
||||
candles = await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size)
|
||||
return candles
|
||||
|
||||
# -------------------------------
|
||||
# Cache and Training Cache Helpers
|
||||
# -------------------------------
|
||||
def load_candles_cache(filename):
|
||||
if os.path.exists(filename):
|
||||
try:
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
print(f"Loaded cached data from {filename}.")
|
||||
return data
|
||||
except Exception as e:
|
||||
print("Error reading cache file:", e)
|
||||
return {} # Return empty dict if no cache
|
||||
|
||||
|
||||
def save_candles_cache(filename, candles_dict):
|
||||
try:
|
||||
with open(filename, "w") as f:
|
||||
json.dump(candles_dict, f)
|
||||
except Exception as e:
|
||||
print("Error saving cache file:", e)
|
||||
|
||||
def load_training_cache(filename):
|
||||
if os.path.exists(filename):
|
||||
try:
|
||||
with open(filename, "r") as f:
|
||||
cache = json.load(f)
|
||||
print(f"Loaded training cache from {filename}.")
|
||||
return cache
|
||||
except Exception as e:
|
||||
print("Error loading training cache:", e)
|
||||
return {"total_pnl": 0.0} # Initialize if not found
|
||||
|
||||
|
||||
def save_training_cache(filename, cache):
|
||||
try:
|
||||
with open(filename, "w") as f:
|
||||
json.dump(cache, f)
|
||||
except Exception as e:
|
||||
print("Error saving training cache:", e)
|
||||
|
||||
# -------------------------------
|
||||
# Checkpoint Functions
|
||||
# -------------------------------
|
||||
LAST_DIR = os.path.join("models", "last")
|
||||
BEST_DIR = os.path.join("models", "best")
|
||||
os.makedirs(LAST_DIR, exist_ok=True)
|
||||
os.makedirs(BEST_DIR, exist_ok=True)
|
||||
|
||||
def maintain_checkpoint_directory(directory, max_files=10):
|
||||
files = os.listdir(directory)
|
||||
if len(files) > max_files:
|
||||
full_paths = [os.path.join(directory, f) for f in files]
|
||||
full_paths.sort(key=lambda x: os.path.getmtime(x))
|
||||
for f in full_paths[:len(files) - max_files]:
|
||||
os.remove(f)
|
||||
|
||||
def get_best_models(directory):
|
||||
best_files = []
|
||||
for file in os.listdir(directory):
|
||||
parts = file.split("_")
|
||||
try:
|
||||
loss = float(parts[1]) # Get loss from filename
|
||||
best_files.append((loss, file))
|
||||
except Exception:
|
||||
continue
|
||||
return best_files
|
||||
|
||||
def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
|
||||
last_path = os.path.join(last_dir, last_filename)
|
||||
torch.save({
|
||||
"epoch": epoch,
|
||||
"loss": loss,
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict()
|
||||
}, last_path)
|
||||
maintain_checkpoint_directory(last_dir, max_files=10)
|
||||
|
||||
best_models = get_best_models(best_dir)
|
||||
add_to_best = False
|
||||
if len(best_models) < 10:
|
||||
add_to_best = True
|
||||
else:
|
||||
worst_loss, worst_file = max(best_models, key=lambda x: x[0])
|
||||
if loss < worst_loss: # Save if better than worst
|
||||
add_to_best = True
|
||||
os.remove(os.path.join(best_dir, worst_file)) # Remove worst
|
||||
|
||||
if add_to_best:
|
||||
best_filename = f"best_{loss:.4f}_epoch_{epoch}_{timestamp}.pt" # Include loss in name
|
||||
best_path = os.path.join(best_dir, best_filename)
|
||||
torch.save({
|
||||
"epoch": epoch,
|
||||
"loss": loss,
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict()
|
||||
}, best_path)
|
||||
maintain_checkpoint_directory(best_dir, max_files=10)
|
||||
|
||||
print(f"Saved checkpoint for epoch {epoch} with loss {loss:.4f}")
|
||||
|
||||
def load_best_checkpoint(model, best_dir=BEST_DIR):
|
||||
best_models = get_best_models(best_dir)
|
||||
if not best_models:
|
||||
return None
|
||||
best_loss, best_file = min(best_models, key=lambda x: x[0]) # Load best (lowest loss)
|
||||
path = os.path.join(best_dir, best_file)
|
||||
print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
|
||||
checkpoint = torch.load(path)
|
||||
|
||||
# Handle potential embedding size mismatch
|
||||
old_state = checkpoint["model_state_dict"]
|
||||
new_state = model.state_dict()
|
||||
|
||||
if "timeframe_embed.weight" in old_state:
|
||||
old_embed = old_state["timeframe_embed.weight"]
|
||||
new_embed = new_state["timeframe_embed.weight"]
|
||||
|
||||
if old_embed.shape[0] < new_embed.shape[0]:
|
||||
# Copy old embeddings to the new embedding, handling size increase
|
||||
new_embed[:old_embed.shape[0]] = old_embed
|
||||
old_state["timeframe_embed.weight"] = new_embed
|
||||
|
||||
model.load_state_dict(old_state, strict=False) # Allow for size differences
|
||||
return checkpoint
|
||||
|
||||
# -------------------------------
|
||||
# Positional Encoding and Transformer-Based Model
|
||||
# -------------------------------
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, dropout=0.1, max_len=5000):
|
||||
super().__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
position = torch.arange(max_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
||||
pe = torch.zeros(max_len, 1, d_model)
|
||||
pe[:, 0, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 0, 1::2] = torch.cos(position * div_term)
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: Tensor, shape [seq_len, batch_size, embedding_dim]
|
||||
"""
|
||||
x = x + self.pe[:x.size(0)]
|
||||
return self.dropout(x)
|
||||
|
||||
class TradingModel(nn.Module):
|
||||
def __init__(self, num_channels, num_timeframes, hidden_dim=128):
|
||||
super().__init__()
|
||||
self.channel_branches = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.1)
|
||||
) for _ in range(num_channels)
|
||||
])
|
||||
# Embedding for each timeframe
|
||||
self.timeframe_embed = nn.Embedding(num_timeframes, hidden_dim)
|
||||
self.pos_encoder = PositionalEncoding(hidden_dim)
|
||||
# Increased number of layers and heads for larger model
|
||||
encoder_layers = TransformerEncoderLayer(
|
||||
d_model=hidden_dim, nhead=8, dim_feedforward=2048, # Increased nhead and dim_feedforward
|
||||
dropout=0.1, activation='gelu', batch_first=True
|
||||
)
|
||||
self.transformer = TransformerEncoder(encoder_layers, num_layers=6) # More layers
|
||||
|
||||
# Attention pooling to aggregate channel outputs
|
||||
self.attn_pool = nn.Linear(hidden_dim, 1)
|
||||
|
||||
# Separate prediction heads for high and low
|
||||
self.high_pred = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim // 2, 1)
|
||||
)
|
||||
self.low_pred = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.GELU(),
|
||||
nn.Linear(hidden_dim // 2, 1)
|
||||
)
|
||||
|
||||
def forward(self, x, timeframe_ids):
|
||||
# x shape: (batch, num_channels, features_per_channel)
|
||||
batch_size, num_channels, _ = x.shape
|
||||
channel_outs = []
|
||||
|
||||
# Process each channel through its branch
|
||||
for i in range(num_channels):
|
||||
channel_out = self.channel_branches[i](x[:, i, :])
|
||||
channel_outs.append(channel_out)
|
||||
|
||||
# Stack channel outputs
|
||||
stacked = torch.stack(channel_outs, dim=1) # (batch, num_channels, hidden_dim)
|
||||
|
||||
# Add timeframe embeddings
|
||||
tf_embeds = self.timeframe_embed(timeframe_ids) # (num_timeframes, hidden_dim)
|
||||
stacked = stacked + tf_embeds.unsqueeze(0) # Add to each item in batch
|
||||
|
||||
# Transformer
|
||||
transformer_out = self.transformer(stacked) # (batch, num_channels, hidden_dim)
|
||||
|
||||
# Attention pooling
|
||||
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=1) # (batch, num_channels, 1)
|
||||
aggregated = (transformer_out * attn_weights).sum(dim=1) # (batch, hidden_dim)
|
||||
|
||||
# Predict high and low
|
||||
return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()
|
||||
|
||||
# -------------------------------
|
||||
# Technical Indicator Helpers
|
||||
# -------------------------------
|
||||
def compute_sma(candles_list, index, period=10):
|
||||
start = max(0, index - period + 1)
|
||||
values = [candle["close"] for candle in candles_list[start:index + 1]]
|
||||
return sum(values) / len(values) if values else 0.0
|
||||
|
||||
def compute_sma_volume(candles_list, index, period=10):
|
||||
start = max(0, index - period + 1)
|
||||
values = [candle["volume"] for candle in candles_list[start:index + 1]]
|
||||
return sum(values) / len(values) if values else 0.0
|
||||
|
||||
def get_aligned_candle_with_index(candles_list, target_ts):
|
||||
"""Find the candle in the list whose timestamp is the largest that is <= target_ts."""
|
||||
best_idx = 0
|
||||
for i, candle in enumerate(candles_list):
|
||||
if candle["timestamp"] <= target_ts:
|
||||
best_idx = i
|
||||
else:
|
||||
break # Stop once we go past the target
|
||||
return best_idx, candles_list[best_idx]
|
||||
|
||||
def get_features_for_tf(candles_list, index, period=10):
|
||||
"""Return a vector of 7 features: open, high, low, close, volume, sma_close, sma_volume."""
|
||||
candle = candles_list[index]
|
||||
f_open = candle["open"]
|
||||
f_high = candle["high"]
|
||||
f_low = candle["low"]
|
||||
f_close = candle["close"]
|
||||
f_volume = candle["volume"]
|
||||
sma_close = compute_sma(candles_list, index, period)
|
||||
sma_volume = compute_sma_volume(candles_list, index, period)
|
||||
return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume]
|
||||
|
||||
# -------------------------------
|
||||
# Backtest Environment Class
|
||||
# -------------------------------
|
||||
class BacktestEnvironment:
|
||||
def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
|
||||
self.candles_dict = candles_dict
|
||||
self.base_tf = base_tf
|
||||
self.timeframes = timeframes
|
||||
self.full_candles = candles_dict[base_tf] # All candles for base timeframe
|
||||
|
||||
# Define window size (or use a reasonable default if not provided)
|
||||
if window_size is None:
|
||||
window_size = 100 if len(self.full_candles) >= 100 else len(self.full_candles) # Use 100 or total length
|
||||
self.window_size = window_size
|
||||
|
||||
self.reset() # Initialize
|
||||
|
||||
def reset(self):
|
||||
# Randomly select a starting point for the window
|
||||
self.start_index = random.randint(0, len(self.full_candles) - self.window_size)
|
||||
self.candle_window = self.full_candles[self.start_index:self.start_index + self.window_size]
|
||||
self.current_index = 0
|
||||
self.trade_history = []
|
||||
self.position = None # Track if we're in a trade: None, or {"entry_price": ..., "entry_index": ...}
|
||||
return self.get_state(self.current_index) # Return initial state
|
||||
|
||||
def __len__(self):
|
||||
return self.window_size # Length of the environment is the window size
|
||||
|
||||
def get_order_features(self, index):
|
||||
"""Get features related to the current order (if any)."""
|
||||
candle = self.candle_window[index]
|
||||
if self.position is None:
|
||||
# No position: all zeros
|
||||
return [0.0] * FEATURES_PER_CHANNEL # 7 zeros
|
||||
else:
|
||||
# In a position: [1.0, price_diff, 0, 0, 0, 0, 0]
|
||||
flag = 1.0
|
||||
diff = (candle["open"] - self.position["entry_price"]) / candle["open"] # Relative difference
|
||||
return [flag, diff] + [0.0] * (FEATURES_PER_CHANNEL - 2)
|
||||
|
||||
def get_state(self, index):
|
||||
"""Construct state for the given index."""
|
||||
state_features = []
|
||||
base_ts = self.candle_window[index]["timestamp"]
|
||||
|
||||
# Get features for each timeframe
|
||||
for tf in self.timeframes:
|
||||
if tf == self.base_tf:
|
||||
# For the base timeframe, use the candle directly from the window
|
||||
candle = self.candle_window[index]
|
||||
features = get_features_for_tf([candle], 0) # Pass as a list with single candle
|
||||
else:
|
||||
# For other timeframes, align with the base timestamp
|
||||
aligned_idx, _ = get_aligned_candle_with_index(self.candles_dict[tf], base_ts)
|
||||
features = get_features_for_tf(self.candles_dict[tf], aligned_idx)
|
||||
state_features.append(features)
|
||||
|
||||
# Add order features
|
||||
order_features = self.get_order_features(index)
|
||||
state_features.append(order_features)
|
||||
|
||||
# Add placeholder channels for additional indicators (if needed)
|
||||
for _ in range(NUM_INDICATORS):
|
||||
state_features.append([0.0] * FEATURES_PER_CHANNEL)
|
||||
|
||||
|
||||
return np.array(state_features, dtype=np.float32)
|
||||
|
||||
def step(self, action):
|
||||
"""Take a step, given an action."""
|
||||
base = self.candle_window # Shorter name
|
||||
if self.current_index >= len(base) - 1:
|
||||
current_state = self.get_state(self.current_index)
|
||||
return current_state, 0.0, None, True, 0.0, 0.0 # No reward at very end, and done
|
||||
|
||||
current_state = self.get_state(self.current_index)
|
||||
next_index = self.current_index + 1
|
||||
next_state = self.get_state(next_index)
|
||||
next_candle = base[next_index] # Next candle for open, high, low
|
||||
|
||||
reward = 0.0
|
||||
|
||||
# Handle actions (simplified for clarity)
|
||||
if self.position is None:
|
||||
if action == 2: # BUY
|
||||
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
|
||||
else:
|
||||
if action == 0: # SELL
|
||||
exit_price = next_candle["close"]
|
||||
reward = exit_price - self.position["entry_price"] # PnL is reward
|
||||
trade = {
|
||||
"entry_index": self.position["entry_index"],
|
||||
"entry_price": self.position["entry_price"],
|
||||
"exit_index": next_index,
|
||||
"exit_price": exit_price,
|
||||
"pnl": reward
|
||||
}
|
||||
self.trade_history.append(trade)
|
||||
self.position = None # Exit position
|
||||
|
||||
self.current_index = next_index
|
||||
done = (self.current_index >= len(base) - 1) # Done if at end of window
|
||||
actual_high = next_candle["high"]
|
||||
actual_low = next_candle["low"]
|
||||
return current_state, reward, next_state, done, actual_high, actual_low # Return next high/low
|
||||
|
||||
# -------------------------------
|
||||
# Enhanced Training Loop
|
||||
# -------------------------------
|
||||
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
|
||||
|
||||
lambda_trade = args.lambda_trade # Weight for trade loss
|
||||
|
||||
# Load training cache (for total PnL tracking)
|
||||
training_cache = load_training_cache(TRAINING_CACHE_FILE)
|
||||
total_pnl = training_cache.get("total_pnl", 0.0)
|
||||
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
env.reset() # Reset environment for each epoch
|
||||
loss_accum = 0.0
|
||||
steps = len(env) - 1 # Number of steps in the episode
|
||||
|
||||
for i in range(steps): # Iterate through the episode
|
||||
state = env.get_state(i)
|
||||
current_open = env.candle_window[i]["open"] # Current candle's open
|
||||
actual_high = env.candle_window[i + 1]["high"] # Next candle's high
|
||||
actual_low = env.candle_window[i + 1]["low"] # Next candle's low
|
||||
|
||||
# Forward pass
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) # Add batch dimension
|
||||
timeframe_ids = torch.arange(state.shape[0]).to(device) # Create timeframe IDs
|
||||
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
||||
|
||||
# Calculate prediction loss (L_pred)
|
||||
L_pred = torch.abs(pred_high - torch.tensor(actual_high, device=device)) + \
|
||||
torch.abs(pred_low - torch.tensor(actual_low, device=device))
|
||||
|
||||
# Calculate trade surrogate loss (L_trade)
|
||||
profit_buy = pred_high - current_open # Potential profit if buying
|
||||
profit_sell = current_open - pred_low # Potential profit if selling
|
||||
L_trade = - torch.max(profit_buy, profit_sell) # Minimize negative profit
|
||||
|
||||
# Calculate no-action penalty (encourage taking action when profitable)
|
||||
current_open_tensor = torch.tensor(current_open, device=device)
|
||||
signal_strength = torch.max(pred_high - current_open_tensor, current_open_tensor - pred_low)
|
||||
penalty_term = args.penalty_noaction * torch.clamp(args.threshold - signal_strength, min=0)
|
||||
|
||||
# Total loss
|
||||
loss = L_pred + lambda_trade * L_trade + penalty_term
|
||||
|
||||
# Backpropagation
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # Gradient clipping
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
loss_accum += loss.item()
|
||||
|
||||
epoch_loss = loss_accum / steps # Average loss per step
|
||||
if len(env.trade_history) == 0:
|
||||
epoch_loss *= 3
|
||||
epoch_pnl = sum(trade["pnl"] for trade in env.trade_history) # PnL for the epoch
|
||||
total_pnl += epoch_pnl # Update total PnL
|
||||
|
||||
print(f"Epoch {epoch + 1} Loss: {epoch_loss:.4f} | Epoch PnL: {epoch_pnl:.2f} | Total PnL: {total_pnl:.2f}")
|
||||
|
||||
save_checkpoint(model, optimizer, epoch, loss_accum) # Save with accumulated loss
|
||||
simulate_trades(model, env, device, args) # Simulate trades after each epoch
|
||||
update_live_html(env.candle_window, env.trade_history, epoch + 1, epoch_loss, total_pnl) # Update HTML visualization
|
||||
|
||||
# Save training cache (for total PnL tracking)
|
||||
training_cache["total_pnl"] = total_pnl
|
||||
save_training_cache(TRAINING_CACHE_FILE, training_cache)
|
||||
|
||||
# -------------------------------
|
||||
# Live Plotting (for Live Mode)
|
||||
# -------------------------------
|
||||
def live_preview_loop(candles, env):
|
||||
plt.ion() # Interactive mode
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
|
||||
while True:
|
||||
update_live_chart(ax, candles, env.trade_history)
|
||||
plt.draw()
|
||||
plt.pause(1) # Update every second
|
||||
|
||||
# -------------------------------
|
||||
# Live HTML Chart Update (with Volume and Loss)
|
||||
# -------------------------------
|
||||
def update_live_html(candles, trade_history, epoch, loss, total_pnl):
|
||||
from io import BytesIO
|
||||
import base64
|
||||
# Create a new figure and axes for each update
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
|
||||
# Draw the chart
|
||||
update_live_chart(ax, candles, trade_history)
|
||||
|
||||
epoch_pnl = sum(trade["pnl"] for trade in trade_history) # PnL for this window
|
||||
ax.set_title(f"Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}")
|
||||
|
||||
# Save the plot to a BytesIO buffer
|
||||
buf = BytesIO()
|
||||
fig.savefig(buf, format='png')
|
||||
plt.close(fig) # Close the figure to free memory
|
||||
buf.seek(0)
|
||||
image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
|
||||
|
||||
# Generate HTML content
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta http-equiv="refresh" content="1"> <!-- Refresh every second -->
|
||||
<title>Live Trading Chart - Epoch {epoch}</title>
|
||||
<style>
|
||||
body {{
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
background-color: #f4f4f4;
|
||||
}}
|
||||
.chart-container {{
|
||||
text-align: center;
|
||||
}}
|
||||
img {{
|
||||
max-width: 100%;
|
||||
height: auto;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="chart-container">
|
||||
<h2>Epoch {epoch} | Loss: {loss:.4f} | PnL: {epoch_pnl:.2f}| Total PnL: {total_pnl:.2f}</h2>
|
||||
<img src="data:image/png;base64,{image_base64}" alt="Live Chart"/>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
with open("live_chart.html", "w") as f:
|
||||
f.write(html_content)
|
||||
print("Updated live_chart.html.")
|
||||
|
||||
# -------------------------------
|
||||
# Chart Drawing Helpers (with Volume and Date+Time)
|
||||
# -------------------------------
|
||||
def update_live_chart(ax, candles, trade_history):
|
||||
ax.clear() # Clear previous data
|
||||
|
||||
# Extract data for plotting
|
||||
times = [convert_timestamp(candle["timestamp"]) for candle in candles]
|
||||
close_prices = [candle["close"] for candle in candles]
|
||||
|
||||
# Plot close prices
|
||||
ax.plot(times, close_prices, label="Close Price", color="black", linewidth=1)
|
||||
ax.set_xlabel("Time")
|
||||
ax.set_ylabel("Price")
|
||||
|
||||
# Format x-axis to show dates and times
|
||||
ax.xaxis.set_major_formatter(mdates.DateFormatter('%m-%d %H:%M'))
|
||||
|
||||
# Create a second y-axis for volume
|
||||
ax2 = ax.twinx()
|
||||
volumes = [candle["volume"] for candle in candles]
|
||||
if len(times) > 1:
|
||||
times_num = mdates.date2num(times)
|
||||
bar_width = (times_num[-1] - times_num[0]) / len(times) * 0.8 # Relative width
|
||||
else:
|
||||
bar_width = 0.01
|
||||
|
||||
ax2.bar(times, volumes, width=bar_width, alpha=0.3, color="grey", label="Volume")
|
||||
ax2.set_ylabel("Volume")
|
||||
|
||||
# Plot trade markers (buy/sell)
|
||||
for trade in trade_history:
|
||||
entry_time = convert_timestamp(candles[trade["entry_index"]]["timestamp"])
|
||||
exit_time = convert_timestamp(candles[trade["exit_index"]]["timestamp"])
|
||||
in_price = trade["entry_price"]
|
||||
out_price = trade["exit_price"]
|
||||
ax.plot(entry_time, in_price, marker="^", color="green", markersize=10, label="BUY") # Buy marker
|
||||
ax.plot(exit_time, out_price, marker="v", color="red", markersize=10, label="SELL") # Sell marker
|
||||
ax.plot([entry_time, exit_time], [in_price, out_price], linestyle="dotted", color="blue") # Dotted line
|
||||
|
||||
# Combine legends from both axes
|
||||
lines, labels = ax.get_legend_handles_labels()
|
||||
lines2, labels2 = ax2.get_legend_handles_labels()
|
||||
ax.legend(lines + lines2, labels + labels2)
|
||||
ax.grid(True)
|
||||
|
||||
# Auto-format x-axis labels for better readability
|
||||
fig = ax.get_figure()
|
||||
fig.autofmt_xdate()
|
||||
|
||||
# -------------------------------
|
||||
# Global Constants for Features
|
||||
# -------------------------------
|
||||
NUM_INDICATORS = 20 # Number of additional indicator channels
|
||||
FEATURES_PER_CHANNEL = 7
|
||||
ORDER_CHANNELS = 1
|
||||
|
||||
# -------------------------------
|
||||
# General Simulation of Trades Function
|
||||
# -------------------------------
|
||||
def simulate_trades(model, env, device, args):
|
||||
if args.main_tf == "1s":
|
||||
simulate_trades_1s(env)
|
||||
return
|
||||
env.reset() # Reset to a random starting point
|
||||
while True:
|
||||
i = env.current_index
|
||||
if i >= len(env.candle_window) - 1:
|
||||
break # Exit if at the end of the window
|
||||
|
||||
state = env.get_state(i)
|
||||
current_open = env.candle_window[i]["open"] # Get current open
|
||||
|
||||
# Make predictions
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) # Add batch dimension
|
||||
timeframe_ids = torch.arange(state.shape[0]).to(device) # IDs for timeframes
|
||||
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
||||
pred_high = pred_high.item() # Convert to Python number
|
||||
pred_low = pred_low.item()
|
||||
|
||||
# Decide on action (simplified for clarity)
|
||||
if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
|
||||
if (pred_high - current_open) >= (current_open - pred_low):
|
||||
action = 2 # BUY
|
||||
else:
|
||||
action = 0 # SELL
|
||||
_, _, _, done, _, _ = env.step(action) # Take the step
|
||||
else:
|
||||
manual_trade(env)
|
||||
if env.current_index >= len(env.candle_window) - 1:
|
||||
break
|
||||
|
||||
def simulate_trades_1s(env):
|
||||
# Ensure main_tf is 1s
|
||||
if env.base_tf != "1s":
|
||||
raise ValueError("simulate_trades_1s can only be used with base_tf='1s'")
|
||||
|
||||
env.reset()
|
||||
|
||||
current_second = env.candle_window[env.current_index]['timestamp'] // 1000
|
||||
|
||||
# Simulate trading for the entire window
|
||||
while True:
|
||||
|
||||
if env.current_index >= len(env.candle_window) - 1:
|
||||
break # Break if end of the window is reached.
|
||||
|
||||
# Check if a new second has started
|
||||
next_second = env.candle_window[env.current_index]['timestamp'] // 1000
|
||||
|
||||
if next_second != current_second: # A new second has started. Take random action (buy/sell/hold).
|
||||
action = random_action()
|
||||
_, _, _, done, _, _ = env.step(action) # Take the step
|
||||
current_second = next_second # Update the current second
|
||||
|
||||
else:
|
||||
# Still the same second, hold position
|
||||
manual_trade(env)
|
||||
|
||||
if env.current_index >= len(env.candle_window) - 1:
|
||||
break
|
||||
|
||||
def manual_trade(env):
|
||||
#If we are in a position, hold it. Otherwise, do nothing
|
||||
if env.position is not None:
|
||||
#In position, take 'HOLD' action implicitly by doing nothing
|
||||
pass
|
||||
|
||||
# -------------------------------
|
||||
# Argument Parsing
|
||||
# -------------------------------
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train',
|
||||
help="Operating mode: 'train', 'live', or 'inference'.")
|
||||
parser.add_argument('--epochs', type=int, default=1000,
|
||||
help="Number of training epochs.")
|
||||
parser.add_argument('--lr', type=float, default=3e-4,
|
||||
help="Learning rate.")
|
||||
parser.add_argument('--threshold', type=float, default=0.005,
|
||||
help="Minimum predicted move to trigger trade (used in loss; model may override manual trades).")
|
||||
parser.add_argument('--lambda_trade', type=float, default=1.0,
|
||||
help="Weight for the trade surrogate loss.")
|
||||
parser.add_argument('--penalty_noaction', type=float, default=10.0,
|
||||
help="Penalty if no action is taken (used in loss).")
|
||||
parser.add_argument('--start_fresh', action='store_true',
|
||||
help="Start training from scratch.")
|
||||
parser.add_argument('--main_tf', type=str, default='1m',
|
||||
help="Desired main timeframe to focus on (e.g., '1s' or '1m').")
|
||||
# Instead of --fetch, we now provide a --no-fetch flag that will override the default behavior.
|
||||
parser.add_argument('--no-fetch', dest='fetch', action='store_false',
|
||||
help="Do NOT fetch fresh data from exchange on start.")
|
||||
parser.set_defaults(fetch=True)
|
||||
parser.add_argument('--symbol', type=str, default='BTC/USDT', help="Trading pair symbol.")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
def random_action():
|
||||
return random.randint(0, 2) # 0: SELL, 1: HOLD, 2: BUY
|
||||
|
||||
# -------------------------------
|
||||
# Main Function
|
||||
# -------------------------------
|
||||
async def main():
|
||||
args = parse_args()
|
||||
|
||||
# Use GPU if available
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print("Using device:", device)
|
||||
|
||||
# Fetch data (if not --no-fetch)
|
||||
# With fetch defaulting to True, live mode will always try to top-up the cache.
|
||||
if args.fetch:
|
||||
import ccxt.async_support as ccxt
|
||||
exchange = ccxt.binance({'enableRateLimit': True}) # Use Binance
|
||||
now_ms = int(time.time() * 1000)
|
||||
|
||||
# Check if we have cached data. If so, only fetch what we need to update the cache.
|
||||
cached = load_candles_cache(CACHE_FILE)
|
||||
if cached and args.main_tf in cached and len(cached[args.main_tf]) > 0:
|
||||
last_ts = cached[args.main_tf][-1]['timestamp']
|
||||
since = last_ts + 1
|
||||
else:
|
||||
# Fetch a reasonable amount of historical data initially (e.g., last 2 days)
|
||||
since = now_ms - 2 * 24 * 60 * 60 * 1000
|
||||
|
||||
print(f"Fetching fresh data for {args.symbol} on timeframe {args.main_tf} from {since} to {now_ms}...")
|
||||
|
||||
fresh_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf, since, now_ms)
|
||||
candles_dict = {args.main_tf: fresh_candles} # Initially, only main timeframe
|
||||
|
||||
# Save (or update) the cache
|
||||
save_candles_cache(CACHE_FILE, candles_dict)
|
||||
await exchange.close()
|
||||
else:
|
||||
# Load from cache
|
||||
candles_dict = load_candles_cache(CACHE_FILE)
|
||||
if not candles_dict:
|
||||
print("No cached data available. Run without --no-fetch (default) to load fresh data from the exchange.")
|
||||
return
|
||||
|
||||
# --- Timeframe Setup ---
|
||||
default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"] # All supported timeframes
|
||||
timeframes = [tf for tf in default_timeframes if tf in candles_dict] # Use available timeframes
|
||||
if args.main_tf not in timeframes:
|
||||
print(f"Desired main timeframe {args.main_tf} is not available. Available: {timeframes}")
|
||||
return
|
||||
base_tf = args.main_tf
|
||||
|
||||
# --- Model Initialization ---
|
||||
hidden_dim = 128 # Hidden dimension for the Transformer
|
||||
total_channels = len(timeframes) + ORDER_CHANNELS + NUM_INDICATORS # Input channels
|
||||
model = TradingModel(total_channels, len(timeframes)).to(device)
|
||||
|
||||
if args.mode == 'train':
|
||||
# --- Training Setup ---
|
||||
env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
|
||||
start_epoch = 0
|
||||
checkpoint = None
|
||||
|
||||
# Load checkpoint (if not starting fresh)
|
||||
if not args.start_fresh:
|
||||
checkpoint = load_best_checkpoint(model)
|
||||
if checkpoint is not None:
|
||||
start_epoch = checkpoint.get("epoch", 0) + 1 # Start from next epoch
|
||||
print(f"Resuming training from epoch {start_epoch}.")
|
||||
else:
|
||||
print("No checkpoint found. Starting training from scratch.")
|
||||
else:
|
||||
print("Starting training from scratch as requested.")
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5) # AdamW optimizer
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch) # Cosine annealing
|
||||
|
||||
# Load optimizer state (if checkpoint exists)
|
||||
if checkpoint is not None:
|
||||
optim_state = checkpoint.get("optimizer_state_dict", None)
|
||||
if optim_state is not None and "param_groups" in optim_state:
|
||||
try:
|
||||
optimizer.load_state_dict(optim_state)
|
||||
print("Loaded optimizer state from checkpoint.")
|
||||
except Exception as e:
|
||||
print("Failed to load optimizer state due to:", e)
|
||||
print("Deleting all checkpoints and starting fresh.")
|
||||
for chk_dir in [LAST_DIR, BEST_DIR]:
|
||||
for f in os.listdir(chk_dir):
|
||||
os.remove(os.path.join(chk_dir, f))
|
||||
else:
|
||||
print("No valid optimizer state found; using fresh optimizer state.")
|
||||
|
||||
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
|
||||
|
||||
elif args.mode == 'live':
|
||||
import ccxt.async_support as ccxt
|
||||
exchange = ccxt.binance({'enableRateLimit': True}) # Use Binance
|
||||
POLL_INTERVAL = 60 # seconds
|
||||
|
||||
async def update_live_candles():
|
||||
nonlocal exchange, args, candles_dict
|
||||
|
||||
while True:
|
||||
now_ms = int(time.time() * 1000)
|
||||
# Fetch just the most recent candles
|
||||
new_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf,
|
||||
since=now_ms - 2 * 60 * 1000, end_time=now_ms) # Fetch last 2 mins
|
||||
if args.main_tf in candles_dict:
|
||||
candles_dict[args.main_tf] = new_candles # Update
|
||||
else:
|
||||
candles_dict[args.main_tf] = new_candles # Or add if not present
|
||||
|
||||
print("Live candles updated.")
|
||||
await asyncio.sleep(POLL_INTERVAL)
|
||||
|
||||
# Start the candle update task
|
||||
asyncio.create_task(update_live_candles())
|
||||
|
||||
load_best_checkpoint(model) # Load the best model for live trading
|
||||
env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
|
||||
|
||||
# Start live preview (optional)
|
||||
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
|
||||
preview_thread.start()
|
||||
|
||||
print("Starting live trading loop. (Using live updated data now.)")
|
||||
while True:
|
||||
if args.main_tf == "1s":
|
||||
simulate_trades_1s(env) # Run 1s trading loop
|
||||
else:
|
||||
# Get the current state
|
||||
state = env.get_state(env.current_index)
|
||||
current_open = env.candle_window[env.current_index]["open"] # Get current open
|
||||
|
||||
# Make predictions
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
||||
timeframe_ids = torch.arange(state.shape[0]).to(device) # Create timeframe IDs
|
||||
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
||||
pred_high = pred_high.item() # Convert to Python number
|
||||
pred_low = pred_low.item()
|
||||
|
||||
|
||||
# Decide on action
|
||||
if (pred_high - current_open) > args.threshold or (current_open - pred_low) > args.threshold:
|
||||
if (pred_high - current_open) >= (current_open - pred_low):
|
||||
action = 2 # BUY
|
||||
else:
|
||||
action = 0 # SELL
|
||||
_, _, _, done, _, _ = env.step(action) # Take the step
|
||||
else:
|
||||
manual_trade(env)
|
||||
|
||||
if env.current_index >= len(env.candle_window) - 1:
|
||||
print("Reached end of simulation window; resetting environment.")
|
||||
env.reset() # Reset environment when reaching end
|
||||
|
||||
await asyncio.sleep(1) # Short delay for live mode
|
||||
|
||||
|
||||
elif args.mode == 'inference':
|
||||
load_best_checkpoint(model) # Load the best model for inference
|
||||
print("Running inference...")
|
||||
# Add inference logic here (e.g., load data, make predictions, print results)
|
||||
else:
|
||||
print("Invalid mode specified.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
@ -5,6 +5,15 @@ run: >conda activate gpt-gpu
|
||||
python .\index.py
|
||||
|
||||
|
||||
conda create --name my_trading_env python=3.9
|
||||
conda activate my_trading_env
|
||||
pip install -r requirements.txt
|
||||
|
||||
|
||||
# --------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
Usage:
|
||||
Run the script with a command-line argument — for example:
|
||||
• python index-deep-new.py --mode train
|
||||
@ -52,4 +61,10 @@ existing (running but unfinished ) code:
|
||||
implement these suggestions into our code and add arguments for easy switching of modes:
|
||||
- train (only): pool latest data and use it for backtesting with RL to learn to detect peaks/valleys
|
||||
- live: load best checkpoint and latest HLOCv data to actively generate trade signals, but calculate and back propagate errors when closing positions. optimize for profit in the reward function
|
||||
- inference: optimize model loading for inference only - load historical data and periodically append new live data and generate siganls but without active RL
|
||||
- inference: optimize model loading for inference only - load historical data and periodically append new live data and generate siganls but without active RL
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
does this code implement retrospective RL (when we detect with live data the valley/top sequence we should do a parameter optimisation (back propagation or other modern technique) with the best move we should have taken in retrospective). this will allow us to learn and become optimized over time. we should also incorporate MoE architecture so we will have one expert specialized at fastly tuning to the current data stream/market conditions.
|
7
crypto/brian/requirements.txt
Normal file
7
crypto/brian/requirements.txt
Normal file
@ -0,0 +1,7 @@
|
||||
ccxt==4.1.97
|
||||
numpy==1.26.3
|
||||
torch==2.1.2
|
||||
torchaudio==2.1.2
|
||||
torchvision==0.16.2
|
||||
matplotlib==3.8.2
|
||||
python-dotenv==1.0.0
|
Reference in New Issue
Block a user