Dobromir Popov 2025-02-05 11:39:41 +02:00
parent 0e7997d50a
commit 79707c997c
3 changed files with 37 additions and 44 deletions

File diff suppressed because one or more lines are too long


@@ -22,21 +22,14 @@ import matplotlib.dates as mdates
from dotenv import load_dotenv
load_dotenv()
# Define global constants FIRST.
CACHE_FILE = "candles_cache.json"
TRAINING_CACHE_FILE = "training_cache.json"
# --- Helper Function for Timestamp Conversion ---
def convert_timestamp(ts):
"""
Safely converts a timestamp to a datetime object.
If the timestamp is abnormally high (e.g. in milliseconds),
it is divided by 1000.
"""
ts = float(ts)
if ts > 1e10: # Likely in milliseconds
if ts > 1e10:
ts /= 1000.0
return datetime.fromtimestamp(ts)
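For reference, a minimal standalone sketch of how convert_timestamp behaves (the sample values below are illustrative): second-based and millisecond-based inputs resolve to the same datetime.

```python
from datetime import datetime

def convert_timestamp(ts):
    """Normalize a Unix timestamp (seconds or milliseconds) to a datetime."""
    ts = float(ts)
    if ts > 1e10:  # values this large are assumed to be milliseconds
        ts /= 1000.0
    return datetime.fromtimestamp(ts)

# Both calls refer to the same instant.
print(convert_timestamp(1738750781))     # seconds
print(convert_timestamp(1738750781000))  # milliseconds
```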
@@ -44,10 +37,6 @@ def convert_timestamp(ts):
# Historical Data Fetching Functions (Using CCXT)
# -------------------------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
"""
Fetch historical OHLCV data for the given symbol and timeframe.
"since" and "end_time" are in milliseconds.
"""
candles = []
since_ms = since
while True:
@@ -90,7 +79,6 @@ async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time,
candles = await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size)
return candles
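The paging loop in fetch_historical_data is only partially visible in this hunk; the sketch below shows the general CCXT pagination pattern it follows, assuming a ccxt.async_support exchange and the candle-dict layout implied by the cache (timestamp, open, high, low, close, volume).

```python
import time
import asyncio
import ccxt.async_support as ccxt

async def fetch_ohlcv_batches(exchange, symbol, timeframe, since_ms, end_ms, batch_size=500):
    candles = []
    while since_ms < end_ms:
        batch = await exchange.fetch_ohlcv(symbol, timeframe, since=since_ms, limit=batch_size)
        if not batch:
            break
        candles.extend(
            {"timestamp": ts, "open": o, "high": h, "low": l, "close": c, "volume": v}
            for ts, o, h, l, c, v in batch
        )
        since_ms = batch[-1][0] + 1  # advance past the last candle returned
    return candles

async def demo():
    exchange = ccxt.binance({'enableRateLimit': True})
    try:
        now_ms = int(time.time() * 1000)
        candles = await fetch_ohlcv_batches(exchange, "BTC/USDT", "1m", now_ms - 60 * 60 * 1000, now_ms)
        print(f"Fetched {len(candles)} candles")
    finally:
        await exchange.close()

# asyncio.run(demo())
```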
# -------------------------------
# Cache and Training Cache Helpers
# -------------------------------
@@ -130,8 +118,6 @@ def save_training_cache(filename, cache):
except Exception as e:
print("Error saving training cache:", e)
TRAINING_CACHE_FILE = "training_cache.json"
# -------------------------------
# Checkpoint Functions
# -------------------------------
@@ -218,7 +204,7 @@ class PositionalEncoding(nn.Module):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0)/d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
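The div_term expression implements the standard sinusoidal positional encoding (sin on even feature dimensions, cos on odd ones). A small shape sketch, with illustrative dimensions:

```python
import math
import torch

d_model, max_len = 8, 16
position = torch.arange(max_len).unsqueeze(1)                                        # (max_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))   # (d_model/2,)
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)  # even feature dims
pe[:, 0, 1::2] = torch.cos(position * div_term)  # odd feature dims
print(pe.shape)  # torch.Size([16, 1, 8])
```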
@@ -230,7 +216,6 @@ class PositionalEncoding(nn.Module):
class TradingModel(nn.Module):
def __init__(self, num_channels, num_timeframes, hidden_dim=128):
super().__init__()
# One branch per channel.
self.channel_branches = nn.ModuleList([
nn.Sequential(
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@@ -248,14 +233,14 @@ class TradingModel(nn.Module):
self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
self.attn_pool = nn.Linear(hidden_dim, 1)
self.high_pred = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.Linear(hidden_dim, hidden_dim//2),
nn.GELU(),
nn.Linear(hidden_dim // 2, 1)
nn.Linear(hidden_dim//2, 1)
)
self.low_pred = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.Linear(hidden_dim, hidden_dim//2),
nn.GELU(),
nn.Linear(hidden_dim // 2, 1)
nn.Linear(hidden_dim//2, 1)
)
def forward(self, x, timeframe_ids):
batch_size, num_channels, _ = x.shape
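For clarity, the two regression heads defined in the hunk above share the same shape: a hidden vector of size hidden_dim is reduced to a single value for the predicted high and for the predicted low. A standalone shape check (batch size and hidden_dim are illustrative):

```python
import torch
import torch.nn as nn

hidden_dim = 128
high_pred = nn.Sequential(nn.Linear(hidden_dim, hidden_dim // 2), nn.GELU(), nn.Linear(hidden_dim // 2, 1))
low_pred = nn.Sequential(nn.Linear(hidden_dim, hidden_dim // 2), nn.GELU(), nn.Linear(hidden_dim // 2, 1))

pooled = torch.randn(4, hidden_dim)  # e.g. an attention-pooled batch of 4 sequences
print(high_pred(pooled).shape, low_pred(pooled).shape)  # torch.Size([4, 1]) twice
```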
@@ -309,7 +294,7 @@ def get_features_for_tf(candles_list, index, period=10):
# -------------------------------
class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
self.candles_dict = candles_dict # full candles dict across timeframes
self.candles_dict = candles_dict
self.base_tf = base_tf
self.timeframes = timeframes
self.full_candles = candles_dict[base_tf]
@@ -361,10 +346,10 @@ class BacktestEnvironment:
next_candle = base[next_index]
reward = 0.0
if self.position is None:
if action == 2: # BUY (open long)
if action == 2:
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
else:
if action == 0: # SELL (close trade)
if action == 0:
exit_price = next_candle["close"]
reward = exit_price - self.position["entry_price"]
trade = {
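The step logic in the hunk above encodes a simple long-only convention: action 2 opens a position at the next candle's open, action 0 closes it at the next candle's close, and the reward is the raw price difference (action 1 is presumably HOLD; that branch is not shown in the hunk). A reduced sketch of that reward rule:

```python
def step_reward(position, action, next_candle):
    """Reduced version of the reward rule above; returns (new_position, reward)."""
    reward = 0.0
    if position is None:
        if action == 2:  # BUY: open a long at the next candle's open
            position = {"entry_price": next_candle["open"]}
    else:
        if action == 0:  # SELL: close the long at the next candle's close
            reward = next_candle["close"] - position["entry_price"]
            position = None
    return position, reward

# Open at 100.0, later close at 103.5 -> reward of 3.5
pos, _ = step_reward(None, 2, {"open": 100.0, "close": 100.2})
pos, r = step_reward(pos, 0, {"open": 103.0, "close": 103.5})
print(r)  # 3.5
```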
@@ -755,11 +740,6 @@ NUM_INDICATORS = 20
FEATURES_PER_CHANNEL = 7
ORDER_CHANNELS = 1
# -------------------------------
# Backtest Environment with Sliding Window and Order Info (Already Defined Above)
# [See BacktestEnvironment class above]
# -------------------------------
# -------------------------------
# General Simulation of Trades Function
# -------------------------------
@@ -807,7 +787,10 @@ def parse_args():
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
parser.add_argument('--main_tf', type=str, default='1m',
help="Desired main timeframe to focus on (e.g., '1s' or '1m').")
parser.add_argument('--fetch', action='store_true', help="Fetch fresh data from exchange on start.")
# Instead of --fetch, we now provide a --no-fetch flag that will override the default behavior.
parser.add_argument('--no-fetch', dest='fetch', action='store_false',
help="Do NOT fetch fresh data from exchange on start.")
parser.set_defaults(fetch=True)
parser.add_argument('--symbol', type=str, default='BTC/USDT', help="Trading pair symbol.")
return parser.parse_args()
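The new --no-fetch flag relies on argparse's store_false action with dest='fetch' plus set_defaults(fetch=True), so fetching is opt-out rather than opt-in. A standalone sketch of that behavior:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-fetch', dest='fetch', action='store_false',
                    help="Do NOT fetch fresh data from exchange on start.")
parser.set_defaults(fetch=True)

print(parser.parse_args([]).fetch)              # True: fetch by default
print(parser.parse_args(['--no-fetch']).fetch)  # False: explicitly opted out
```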
@@ -822,33 +805,28 @@ async def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# If --fetch flag is provided, top-up cached OHLCV data with fresh data from exchange.
# With fetch defaulting to True, live mode will always try to top-up the cache.
if args.fetch:
import ccxt.async_support as ccxt
exchange = ccxt.binance({'enableRateLimit': True})
now_ms = int(time.time()*1000)
# Determine default "since" time based on cache.
cached = load_candles_cache(CACHE_FILE)
if cached and args.main_tf in cached and len(cached[args.main_tf]) > 0:
last_ts = cached[args.main_tf][-1]['timestamp']
since = last_ts + 1
else:
# Default: fetch candles from the last 2 days.
since = now_ms - 2*24*60*60*1000
# Top-up data for the main timeframe.
print(f"Fetching fresh data for {args.symbol} on timeframe {args.main_tf} from {since} to {now_ms}...")
fresh_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf, since, now_ms)
# Update cache (for simplicity, we store only the main timeframe here).
candles_dict = {args.main_tf: fresh_candles}
save_candles_cache(CACHE_FILE, candles_dict)
await exchange.close()
else:
candles_dict = load_candles_cache(CACHE_FILE)
if not candles_dict:
print("No cached data available. Run with --fetch to load fresh data from the exchange.")
print("No cached data available. Run without --no-fetch (default) to load fresh data from the exchange.")
return
# Define desired timeframes list.
default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"]
timeframes = [tf for tf in default_timeframes if tf in candles_dict]
if args.main_tf not in timeframes:
@@ -890,11 +868,27 @@ async def main():
print("No valid optimizer state found; using fresh optimizer state.")
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
elif args.mode == 'live':
import ccxt.async_support as ccxt
exchange = ccxt.binance({'enableRateLimit': True})
POLL_INTERVAL = 60 # seconds
async def update_live_candles():
nonlocal exchange, args, candles_dict
while True:
now_ms = int(time.time()*1000)
new_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf,
since=now_ms - 2*60*1000, end_time=now_ms)
if args.main_tf in candles_dict:
candles_dict[args.main_tf] = new_candles
else:
candles_dict[args.main_tf] = new_candles
print("Live candles updated.")
await asyncio.sleep(POLL_INTERVAL)
asyncio.create_task(update_live_candles())
load_best_checkpoint(model)
env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
preview_thread.start()
print("Starting live trading loop. (For main_tf={} using manual override if model signal is weak.)".format(args.main_tf))
print("Starting live trading loop. (Using live updated data now.)")
while True:
if args.main_tf == "1s":
simulate_trades_1s(env)
@@ -914,14 +908,13 @@ async def main():
_, _, _, done, _, _ = env.step(action)
else:
manual_trade(env)
if env.current_index >= len(env.candle_window)-1:
if env.current_index >= len(env.candle_window) - 1:
print("Reached end of simulation window; resetting environment.")
env.reset()
await asyncio.sleep(1)
elif args.mode == 'inference':
load_best_checkpoint(model)
print("Running inference...")
# Inference logic goes here.
else:
print("Invalid mode specified.")

File diff suppressed because one or more lines are too long