wip
This commit is contained in:
parent
0e7997d50a
commit
79707c997c
File diff suppressed because one or more lines are too long
@ -22,21 +22,14 @@ import matplotlib.dates as mdates
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
# Define global constants FIRST.
|
# Define global constants FIRST.
|
||||||
CACHE_FILE = "candles_cache.json"
|
CACHE_FILE = "candles_cache.json"
|
||||||
TRAINING_CACHE_FILE = "training_cache.json"
|
TRAINING_CACHE_FILE = "training_cache.json"
|
||||||
|
|
||||||
|
|
||||||
# --- Helper Function for Timestamp Conversion ---
|
# --- Helper Function for Timestamp Conversion ---
|
||||||
def convert_timestamp(ts):
|
def convert_timestamp(ts):
|
||||||
"""
|
|
||||||
Safely converts a timestamp to a datetime object.
|
|
||||||
If the timestamp is abnormally high (e.g. in milliseconds),
|
|
||||||
it is divided by 1000.
|
|
||||||
"""
|
|
||||||
ts = float(ts)
|
ts = float(ts)
|
||||||
if ts > 1e10: # Likely in milliseconds
|
if ts > 1e10:
|
||||||
ts /= 1000.0
|
ts /= 1000.0
|
||||||
return datetime.fromtimestamp(ts)
|
return datetime.fromtimestamp(ts)
|
||||||
|
|
||||||
@ -44,10 +37,6 @@ def convert_timestamp(ts):
|
|||||||
# Historical Data Fetching Functions (Using CCXT)
|
# Historical Data Fetching Functions (Using CCXT)
|
||||||
# -------------------------------
|
# -------------------------------
|
||||||
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
|
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
|
||||||
"""
|
|
||||||
Fetch historical OHLCV data for the given symbol and timeframe.
|
|
||||||
"since" and "end_time" are in milliseconds.
|
|
||||||
"""
|
|
||||||
candles = []
|
candles = []
|
||||||
since_ms = since
|
since_ms = since
|
||||||
while True:
|
while True:
|
||||||
@ -90,7 +79,6 @@ async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time,
|
|||||||
candles = await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size)
|
candles = await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size)
|
||||||
return candles
|
return candles
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------
|
# -------------------------------
|
||||||
# Cache and Training Cache Helpers
|
# Cache and Training Cache Helpers
|
||||||
# -------------------------------
|
# -------------------------------
|
||||||
@ -130,8 +118,6 @@ def save_training_cache(filename, cache):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error saving training cache:", e)
|
print("Error saving training cache:", e)
|
||||||
|
|
||||||
TRAINING_CACHE_FILE = "training_cache.json"
|
|
||||||
|
|
||||||
# -------------------------------
|
# -------------------------------
|
||||||
# Checkpoint Functions
|
# Checkpoint Functions
|
||||||
# -------------------------------
|
# -------------------------------
|
||||||
@ -218,7 +204,7 @@ class PositionalEncoding(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.dropout = nn.Dropout(p=dropout)
|
self.dropout = nn.Dropout(p=dropout)
|
||||||
position = torch.arange(max_len).unsqueeze(1)
|
position = torch.arange(max_len).unsqueeze(1)
|
||||||
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0)/d_model))
|
||||||
pe = torch.zeros(max_len, 1, d_model)
|
pe = torch.zeros(max_len, 1, d_model)
|
||||||
pe[:, 0, 0::2] = torch.sin(position * div_term)
|
pe[:, 0, 0::2] = torch.sin(position * div_term)
|
||||||
pe[:, 0, 1::2] = torch.cos(position * div_term)
|
pe[:, 0, 1::2] = torch.cos(position * div_term)
|
||||||
@ -230,7 +216,6 @@ class PositionalEncoding(nn.Module):
|
|||||||
class TradingModel(nn.Module):
|
class TradingModel(nn.Module):
|
||||||
def __init__(self, num_channels, num_timeframes, hidden_dim=128):
|
def __init__(self, num_channels, num_timeframes, hidden_dim=128):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# One branch per channel.
|
|
||||||
self.channel_branches = nn.ModuleList([
|
self.channel_branches = nn.ModuleList([
|
||||||
nn.Sequential(
|
nn.Sequential(
|
||||||
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
|
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
|
||||||
@ -248,14 +233,14 @@ class TradingModel(nn.Module):
|
|||||||
self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
|
self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
|
||||||
self.attn_pool = nn.Linear(hidden_dim, 1)
|
self.attn_pool = nn.Linear(hidden_dim, 1)
|
||||||
self.high_pred = nn.Sequential(
|
self.high_pred = nn.Sequential(
|
||||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
nn.Linear(hidden_dim, hidden_dim//2),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(hidden_dim // 2, 1)
|
nn.Linear(hidden_dim//2, 1)
|
||||||
)
|
)
|
||||||
self.low_pred = nn.Sequential(
|
self.low_pred = nn.Sequential(
|
||||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
nn.Linear(hidden_dim, hidden_dim//2),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(hidden_dim // 2, 1)
|
nn.Linear(hidden_dim//2, 1)
|
||||||
)
|
)
|
||||||
def forward(self, x, timeframe_ids):
|
def forward(self, x, timeframe_ids):
|
||||||
batch_size, num_channels, _ = x.shape
|
batch_size, num_channels, _ = x.shape
|
||||||
@ -309,7 +294,7 @@ def get_features_for_tf(candles_list, index, period=10):
|
|||||||
# -------------------------------
|
# -------------------------------
|
||||||
class BacktestEnvironment:
|
class BacktestEnvironment:
|
||||||
def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
|
def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
|
||||||
self.candles_dict = candles_dict # full candles dict across timeframes
|
self.candles_dict = candles_dict
|
||||||
self.base_tf = base_tf
|
self.base_tf = base_tf
|
||||||
self.timeframes = timeframes
|
self.timeframes = timeframes
|
||||||
self.full_candles = candles_dict[base_tf]
|
self.full_candles = candles_dict[base_tf]
|
||||||
@ -361,10 +346,10 @@ class BacktestEnvironment:
|
|||||||
next_candle = base[next_index]
|
next_candle = base[next_index]
|
||||||
reward = 0.0
|
reward = 0.0
|
||||||
if self.position is None:
|
if self.position is None:
|
||||||
if action == 2: # BUY (open long)
|
if action == 2:
|
||||||
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
|
self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
|
||||||
else:
|
else:
|
||||||
if action == 0: # SELL (close trade)
|
if action == 0:
|
||||||
exit_price = next_candle["close"]
|
exit_price = next_candle["close"]
|
||||||
reward = exit_price - self.position["entry_price"]
|
reward = exit_price - self.position["entry_price"]
|
||||||
trade = {
|
trade = {
|
||||||
@ -755,11 +740,6 @@ NUM_INDICATORS = 20
|
|||||||
FEATURES_PER_CHANNEL = 7
|
FEATURES_PER_CHANNEL = 7
|
||||||
ORDER_CHANNELS = 1
|
ORDER_CHANNELS = 1
|
||||||
|
|
||||||
# -------------------------------
|
|
||||||
# Backtest Environment with Sliding Window and Order Info (Already Defined Above)
|
|
||||||
# [See BacktestEnvironment class above]
|
|
||||||
# -------------------------------
|
|
||||||
|
|
||||||
# -------------------------------
|
# -------------------------------
|
||||||
# General Simulation of Trades Function
|
# General Simulation of Trades Function
|
||||||
# -------------------------------
|
# -------------------------------
|
||||||
@ -807,7 +787,10 @@ def parse_args():
|
|||||||
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
|
parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
|
||||||
parser.add_argument('--main_tf', type=str, default='1m',
|
parser.add_argument('--main_tf', type=str, default='1m',
|
||||||
help="Desired main timeframe to focus on (e.g., '1s' or '1m').")
|
help="Desired main timeframe to focus on (e.g., '1s' or '1m').")
|
||||||
parser.add_argument('--fetch', action='store_true', help="Fetch fresh data from exchange on start.")
|
# Instead of --fetch, we now provide a --no-fetch flag that will override the default behavior.
|
||||||
|
parser.add_argument('--no-fetch', dest='fetch', action='store_false',
|
||||||
|
help="Do NOT fetch fresh data from exchange on start.")
|
||||||
|
parser.set_defaults(fetch=True)
|
||||||
parser.add_argument('--symbol', type=str, default='BTC/USDT', help="Trading pair symbol.")
|
parser.add_argument('--symbol', type=str, default='BTC/USDT', help="Trading pair symbol.")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -822,33 +805,28 @@ async def main():
|
|||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
print("Using device:", device)
|
print("Using device:", device)
|
||||||
|
|
||||||
# If --fetch flag is provided, top-up cached OHLCV data with fresh data from exchange.
|
# With fetch defaulting to True, live mode will always try to top-up the cache.
|
||||||
if args.fetch:
|
if args.fetch:
|
||||||
import ccxt.async_support as ccxt
|
import ccxt.async_support as ccxt
|
||||||
exchange = ccxt.binance({'enableRateLimit': True})
|
exchange = ccxt.binance({'enableRateLimit': True})
|
||||||
now_ms = int(time.time()*1000)
|
now_ms = int(time.time()*1000)
|
||||||
# Determine default "since" time based on cache.
|
|
||||||
cached = load_candles_cache(CACHE_FILE)
|
cached = load_candles_cache(CACHE_FILE)
|
||||||
if cached and args.main_tf in cached and len(cached[args.main_tf]) > 0:
|
if cached and args.main_tf in cached and len(cached[args.main_tf]) > 0:
|
||||||
last_ts = cached[args.main_tf][-1]['timestamp']
|
last_ts = cached[args.main_tf][-1]['timestamp']
|
||||||
since = last_ts + 1
|
since = last_ts + 1
|
||||||
else:
|
else:
|
||||||
# Default: fetch candles from the last 2 days.
|
|
||||||
since = now_ms - 2*24*60*60*1000
|
since = now_ms - 2*24*60*60*1000
|
||||||
# Top-up data for the main timeframe.
|
|
||||||
print(f"Fetching fresh data for {args.symbol} on timeframe {args.main_tf} from {since} to {now_ms}...")
|
print(f"Fetching fresh data for {args.symbol} on timeframe {args.main_tf} from {since} to {now_ms}...")
|
||||||
fresh_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf, since, now_ms)
|
fresh_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf, since, now_ms)
|
||||||
# Update cache (for simplicity, we store only the main timeframe here).
|
|
||||||
candles_dict = {args.main_tf: fresh_candles}
|
candles_dict = {args.main_tf: fresh_candles}
|
||||||
save_candles_cache(CACHE_FILE, candles_dict)
|
save_candles_cache(CACHE_FILE, candles_dict)
|
||||||
await exchange.close()
|
await exchange.close()
|
||||||
else:
|
else:
|
||||||
candles_dict = load_candles_cache(CACHE_FILE)
|
candles_dict = load_candles_cache(CACHE_FILE)
|
||||||
if not candles_dict:
|
if not candles_dict:
|
||||||
print("No cached data available. Run with --fetch to load fresh data from the exchange.")
|
print("No cached data available. Run without --no-fetch (default) to load fresh data from the exchange.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Define desired timeframes list.
|
|
||||||
default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"]
|
default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"]
|
||||||
timeframes = [tf for tf in default_timeframes if tf in candles_dict]
|
timeframes = [tf for tf in default_timeframes if tf in candles_dict]
|
||||||
if args.main_tf not in timeframes:
|
if args.main_tf not in timeframes:
|
||||||
@ -890,11 +868,27 @@ async def main():
|
|||||||
print("No valid optimizer state found; using fresh optimizer state.")
|
print("No valid optimizer state found; using fresh optimizer state.")
|
||||||
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
|
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
|
||||||
elif args.mode == 'live':
|
elif args.mode == 'live':
|
||||||
|
import ccxt.async_support as ccxt
|
||||||
|
exchange = ccxt.binance({'enableRateLimit': True})
|
||||||
|
POLL_INTERVAL = 60 # seconds
|
||||||
|
async def update_live_candles():
|
||||||
|
nonlocal exchange, args, candles_dict
|
||||||
|
while True:
|
||||||
|
now_ms = int(time.time()*1000)
|
||||||
|
new_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf,
|
||||||
|
since=now_ms - 2*60*1000, end_time=now_ms)
|
||||||
|
if args.main_tf in candles_dict:
|
||||||
|
candles_dict[args.main_tf] = new_candles
|
||||||
|
else:
|
||||||
|
candles_dict[args.main_tf] = new_candles
|
||||||
|
print("Live candles updated.")
|
||||||
|
await asyncio.sleep(POLL_INTERVAL)
|
||||||
|
asyncio.create_task(update_live_candles())
|
||||||
load_best_checkpoint(model)
|
load_best_checkpoint(model)
|
||||||
env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
|
env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
|
||||||
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
|
preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
|
||||||
preview_thread.start()
|
preview_thread.start()
|
||||||
print("Starting live trading loop. (For main_tf={} using manual override if model signal is weak.)".format(args.main_tf))
|
print("Starting live trading loop. (Using live updated data now.)")
|
||||||
while True:
|
while True:
|
||||||
if args.main_tf == "1s":
|
if args.main_tf == "1s":
|
||||||
simulate_trades_1s(env)
|
simulate_trades_1s(env)
|
||||||
@ -914,14 +908,13 @@ async def main():
|
|||||||
_, _, _, done, _, _ = env.step(action)
|
_, _, _, done, _, _ = env.step(action)
|
||||||
else:
|
else:
|
||||||
manual_trade(env)
|
manual_trade(env)
|
||||||
if env.current_index >= len(env.candle_window)-1:
|
if env.current_index >= len(env.candle_window) - 1:
|
||||||
print("Reached end of simulation window; resetting environment.")
|
print("Reached end of simulation window; resetting environment.")
|
||||||
env.reset()
|
env.reset()
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
elif args.mode == 'inference':
|
elif args.mode == 'inference':
|
||||||
load_best_checkpoint(model)
|
load_best_checkpoint(model)
|
||||||
print("Running inference...")
|
print("Running inference...")
|
||||||
# Inference logic goes here.
|
|
||||||
else:
|
else:
|
||||||
print("Invalid mode specified.")
|
print("Invalid mode specified.")
|
||||||
|
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user