This commit is contained in:
Dobromir Popov 2025-02-05 11:39:41 +02:00
parent 0e7997d50a
commit 79707c997c
3 changed files with 37 additions and 44 deletions

File diff suppressed because one or more lines are too long

View File

@ -22,21 +22,14 @@ import matplotlib.dates as mdates
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
# Define global constants FIRST. # Define global constants FIRST.
CACHE_FILE = "candles_cache.json" CACHE_FILE = "candles_cache.json"
TRAINING_CACHE_FILE = "training_cache.json" TRAINING_CACHE_FILE = "training_cache.json"
# --- Helper Function for Timestamp Conversion --- # --- Helper Function for Timestamp Conversion ---
def convert_timestamp(ts): def convert_timestamp(ts):
"""
Safely converts a timestamp to a datetime object.
If the timestamp is abnormally high (e.g. in milliseconds),
it is divided by 1000.
"""
ts = float(ts) ts = float(ts)
if ts > 1e10: # Likely in milliseconds if ts > 1e10:
ts /= 1000.0 ts /= 1000.0
return datetime.fromtimestamp(ts) return datetime.fromtimestamp(ts)
@ -44,10 +37,6 @@ def convert_timestamp(ts):
# Historical Data Fetching Functions (Using CCXT) # Historical Data Fetching Functions (Using CCXT)
# ------------------------------- # -------------------------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500): async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
"""
Fetch historical OHLCV data for the given symbol and timeframe.
"since" and "end_time" are in milliseconds.
"""
candles = [] candles = []
since_ms = since since_ms = since
while True: while True:
@ -90,7 +79,6 @@ async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time,
candles = await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size) candles = await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size)
return candles return candles
# ------------------------------- # -------------------------------
# Cache and Training Cache Helpers # Cache and Training Cache Helpers
# ------------------------------- # -------------------------------
@ -130,8 +118,6 @@ def save_training_cache(filename, cache):
except Exception as e: except Exception as e:
print("Error saving training cache:", e) print("Error saving training cache:", e)
TRAINING_CACHE_FILE = "training_cache.json"
# ------------------------------- # -------------------------------
# Checkpoint Functions # Checkpoint Functions
# ------------------------------- # -------------------------------
@ -230,7 +216,6 @@ class PositionalEncoding(nn.Module):
class TradingModel(nn.Module): class TradingModel(nn.Module):
def __init__(self, num_channels, num_timeframes, hidden_dim=128): def __init__(self, num_channels, num_timeframes, hidden_dim=128):
super().__init__() super().__init__()
# One branch per channel.
self.channel_branches = nn.ModuleList([ self.channel_branches = nn.ModuleList([
nn.Sequential( nn.Sequential(
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim), nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@ -309,7 +294,7 @@ def get_features_for_tf(candles_list, index, period=10):
# ------------------------------- # -------------------------------
class BacktestEnvironment: class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes, window_size=None): def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
self.candles_dict = candles_dict # full candles dict across timeframes self.candles_dict = candles_dict
self.base_tf = base_tf self.base_tf = base_tf
self.timeframes = timeframes self.timeframes = timeframes
self.full_candles = candles_dict[base_tf] self.full_candles = candles_dict[base_tf]
@ -361,10 +346,10 @@ class BacktestEnvironment:
next_candle = base[next_index] next_candle = base[next_index]
reward = 0.0 reward = 0.0
if self.position is None: if self.position is None:
if action == 2: # BUY (open long) if action == 2:
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index} self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
else: else:
if action == 0: # SELL (close trade) if action == 0:
exit_price = next_candle["close"] exit_price = next_candle["close"]
reward = exit_price - self.position["entry_price"] reward = exit_price - self.position["entry_price"]
trade = { trade = {
@ -755,11 +740,6 @@ NUM_INDICATORS = 20
FEATURES_PER_CHANNEL = 7 FEATURES_PER_CHANNEL = 7
ORDER_CHANNELS = 1 ORDER_CHANNELS = 1
# -------------------------------
# Backtest Environment with Sliding Window and Order Info (Already Defined Above)
# [See BacktestEnvironment class above]
# -------------------------------
# ------------------------------- # -------------------------------
# General Simulation of Trades Function # General Simulation of Trades Function
# ------------------------------- # -------------------------------
@ -807,7 +787,10 @@ def parse_args():
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.") parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
parser.add_argument('--main_tf', type=str, default='1m', parser.add_argument('--main_tf', type=str, default='1m',
help="Desired main timeframe to focus on (e.g., '1s' or '1m').") help="Desired main timeframe to focus on (e.g., '1s' or '1m').")
parser.add_argument('--fetch', action='store_true', help="Fetch fresh data from exchange on start.") # Instead of --fetch, we now provide a --no-fetch flag that will override the default behavior.
parser.add_argument('--no-fetch', dest='fetch', action='store_false',
help="Do NOT fetch fresh data from exchange on start.")
parser.set_defaults(fetch=True)
parser.add_argument('--symbol', type=str, default='BTC/USDT', help="Trading pair symbol.") parser.add_argument('--symbol', type=str, default='BTC/USDT', help="Trading pair symbol.")
return parser.parse_args() return parser.parse_args()
@ -822,33 +805,28 @@ async def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device) print("Using device:", device)
# If --fetch flag is provided, top-up cached OHLCV data with fresh data from exchange. # With fetch defaulting to True, live mode will always try to top-up the cache.
if args.fetch: if args.fetch:
import ccxt.async_support as ccxt import ccxt.async_support as ccxt
exchange = ccxt.binance({'enableRateLimit': True}) exchange = ccxt.binance({'enableRateLimit': True})
now_ms = int(time.time()*1000) now_ms = int(time.time()*1000)
# Determine default "since" time based on cache.
cached = load_candles_cache(CACHE_FILE) cached = load_candles_cache(CACHE_FILE)
if cached and args.main_tf in cached and len(cached[args.main_tf]) > 0: if cached and args.main_tf in cached and len(cached[args.main_tf]) > 0:
last_ts = cached[args.main_tf][-1]['timestamp'] last_ts = cached[args.main_tf][-1]['timestamp']
since = last_ts + 1 since = last_ts + 1
else: else:
# Default: fetch candles from the last 2 days.
since = now_ms - 2*24*60*60*1000 since = now_ms - 2*24*60*60*1000
# Top-up data for the main timeframe.
print(f"Fetching fresh data for {args.symbol} on timeframe {args.main_tf} from {since} to {now_ms}...") print(f"Fetching fresh data for {args.symbol} on timeframe {args.main_tf} from {since} to {now_ms}...")
fresh_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf, since, now_ms) fresh_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf, since, now_ms)
# Update cache (for simplicity, we store only the main timeframe here).
candles_dict = {args.main_tf: fresh_candles} candles_dict = {args.main_tf: fresh_candles}
save_candles_cache(CACHE_FILE, candles_dict) save_candles_cache(CACHE_FILE, candles_dict)
await exchange.close() await exchange.close()
else: else:
candles_dict = load_candles_cache(CACHE_FILE) candles_dict = load_candles_cache(CACHE_FILE)
if not candles_dict: if not candles_dict:
print("No cached data available. Run with --fetch to load fresh data from the exchange.") print("No cached data available. Run without --no-fetch (default) to load fresh data from the exchange.")
return return
# Define desired timeframes list.
default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"] default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"]
timeframes = [tf for tf in default_timeframes if tf in candles_dict] timeframes = [tf for tf in default_timeframes if tf in candles_dict]
if args.main_tf not in timeframes: if args.main_tf not in timeframes:
@ -890,11 +868,27 @@ async def main():
print("No valid optimizer state found; using fresh optimizer state.") print("No valid optimizer state found; using fresh optimizer state.")
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler) train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
elif args.mode == 'live': elif args.mode == 'live':
import ccxt.async_support as ccxt
exchange = ccxt.binance({'enableRateLimit': True})
POLL_INTERVAL = 60 # seconds
async def update_live_candles():
nonlocal exchange, args, candles_dict
while True:
now_ms = int(time.time()*1000)
new_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf,
since=now_ms - 2*60*1000, end_time=now_ms)
if args.main_tf in candles_dict:
candles_dict[args.main_tf] = new_candles
else:
candles_dict[args.main_tf] = new_candles
print("Live candles updated.")
await asyncio.sleep(POLL_INTERVAL)
asyncio.create_task(update_live_candles())
load_best_checkpoint(model) load_best_checkpoint(model)
env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100) env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True) preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
preview_thread.start() preview_thread.start()
print("Starting live trading loop. (For main_tf={} using manual override if model signal is weak.)".format(args.main_tf)) print("Starting live trading loop. (Using live updated data now.)")
while True: while True:
if args.main_tf == "1s": if args.main_tf == "1s":
simulate_trades_1s(env) simulate_trades_1s(env)
@ -921,7 +915,6 @@ async def main():
elif args.mode == 'inference': elif args.mode == 'inference':
load_best_checkpoint(model) load_best_checkpoint(model)
print("Running inference...") print("Running inference...")
# Inference logic goes here.
else: else:
print("Invalid mode specified.") print("Invalid mode specified.")

File diff suppressed because one or more lines are too long