wip
This commit is contained in:
parent 0e7997d50a · commit 79707c997c
@@ -22,21 +22,14 @@ import matplotlib.dates as mdates
from dotenv import load_dotenv

load_dotenv()


# Define global constants FIRST.
CACHE_FILE = "candles_cache.json"
TRAINING_CACHE_FILE = "training_cache.json"


# --- Helper Function for Timestamp Conversion ---
def convert_timestamp(ts):
    """
    Safely converts a timestamp to a datetime object.
    If the timestamp is abnormally high (e.g. in milliseconds),
    it is divided by 1000.
    """
    ts = float(ts)
-    if ts > 1e10:  # Likely in milliseconds
+    if ts > 1e10:
        ts /= 1000.0
    return datetime.fromtimestamp(ts)
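The 1e10 cutoff works because current Unix timestamps in seconds are around 1.7e9, while the same instants in milliseconds are around 1.7e12. A minimal standalone check (the input values below are illustrative, not from the file):

from datetime import datetime

def convert_timestamp(ts):
    ts = float(ts)
    if ts > 1e10:
        ts /= 1000.0
    return datetime.fromtimestamp(ts)

# Both inputs refer to the same instant; the second is in milliseconds
# and gets divided by 1000 before conversion.
print(convert_timestamp(1700000000))     # e.g. a datetime in November 2023
print(convert_timestamp(1700000000000))  # same datetime as above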
@@ -44,10 +37,6 @@ def convert_timestamp(ts):
# Historical Data Fetching Functions (Using CCXT)
# -------------------------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
    """
    Fetch historical OHLCV data for the given symbol and timeframe.
    "since" and "end_time" are in milliseconds.
    """
    candles = []
    since_ms = since
    while True:
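The body of the fetch loop is cut off by this hunk. The sketch below shows only the generic ccxt pagination pattern such a loop typically follows (the exit condition and timestamp advance are assumptions, not the file's exact code):

import ccxt.async_support as ccxt

async def fetch_ohlcv_range(exchange, symbol, timeframe, since, end_time, batch_size=500):
    # Page through fetch_ohlcv until the exchange returns no more data
    # or the last candle passes end_time (all times in milliseconds).
    candles = []
    since_ms = since
    while True:
        batch = await exchange.fetch_ohlcv(symbol, timeframe, since=since_ms, limit=batch_size)
        if not batch:
            break
        candles.extend(batch)
        since_ms = batch[-1][0] + 1   # advance past the last returned timestamp
        if batch[-1][0] >= end_time:
            break
    return candles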
@@ -90,7 +79,6 @@ async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time,
    candles = await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size)
    return candles


# -------------------------------
# Cache and Training Cache Helpers
# -------------------------------
@@ -130,8 +118,6 @@ def save_training_cache(filename, cache):
    except Exception as e:
        print("Error saving training cache:", e)

-TRAINING_CACHE_FILE = "training_cache.json"

# -------------------------------
# Checkpoint Functions
# -------------------------------
@@ -218,7 +204,7 @@ class PositionalEncoding(nn.Module):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0)/d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
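The div_term line (with or without the spaces) is the usual sinusoidal-encoding trick: exp(2i * (-ln 10000)/d_model) equals 10000^(-2i/d_model), so the sin/cos assignments below it compute sin(pos / 10000^(2i/d_model)) and cos(pos / 10000^(2i/d_model)). A small standalone check of that equivalence (d_model chosen arbitrarily):

import math
import torch

d_model = 8
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
expected = 10000.0 ** (-torch.arange(0, d_model, 2).float() / d_model)
print(torch.allclose(div_term, expected))  # True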
@@ -230,7 +216,6 @@ class PositionalEncoding(nn.Module):
class TradingModel(nn.Module):
    def __init__(self, num_channels, num_timeframes, hidden_dim=128):
        super().__init__()
-        # One branch per channel.
        self.channel_branches = nn.ModuleList([
            nn.Sequential(
                nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@@ -248,14 +233,14 @@ class TradingModel(nn.Module):
        self.transformer = TransformerEncoder(encoder_layers, num_layers=2)
        self.attn_pool = nn.Linear(hidden_dim, 1)
        self.high_pred = nn.Sequential(
-            nn.Linear(hidden_dim, hidden_dim // 2),
+            nn.Linear(hidden_dim, hidden_dim//2),
            nn.GELU(),
-            nn.Linear(hidden_dim // 2, 1)
+            nn.Linear(hidden_dim//2, 1)
        )
        self.low_pred = nn.Sequential(
-            nn.Linear(hidden_dim, hidden_dim // 2),
+            nn.Linear(hidden_dim, hidden_dim//2),
            nn.GELU(),
-            nn.Linear(hidden_dim // 2, 1)
+            nn.Linear(hidden_dim//2, 1)
        )

    def forward(self, x, timeframe_ids):
        batch_size, num_channels, _ = x.shape
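The forward pass is not visible in this hunk; the snippet below is only a generic illustration of how a Linear(hidden_dim, 1) attention-pooling head like self.attn_pool is commonly applied over per-channel embeddings (the shapes are assumed, not taken from the file):

import torch
import torch.nn as nn

hidden_dim, batch_size, num_channels = 128, 2, 5
x = torch.randn(batch_size, num_channels, hidden_dim)  # per-channel embeddings
attn_pool = nn.Linear(hidden_dim, 1)

scores = attn_pool(x)                   # (batch, channels, 1) unnormalized scores
weights = torch.softmax(scores, dim=1)  # normalize across channels
pooled = (weights * x).sum(dim=1)       # weighted sum -> (batch, hidden_dim)
print(pooled.shape)                     # torch.Size([2, 128])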
@@ -309,7 +294,7 @@ def get_features_for_tf(candles_list, index, period=10):
# -------------------------------
class BacktestEnvironment:
    def __init__(self, candles_dict, base_tf, timeframes, window_size=None):
-        self.candles_dict = candles_dict  # full candles dict across timeframes
+        self.candles_dict = candles_dict
        self.base_tf = base_tf
        self.timeframes = timeframes
        self.full_candles = candles_dict[base_tf]
@@ -361,10 +346,10 @@ class BacktestEnvironment:
        next_candle = base[next_index]
        reward = 0.0
        if self.position is None:
-            if action == 2:  # BUY (open long)
+            if action == 2:
                self.position = {"entry_price": next_candle["open"], "entry_index": self.current_index}
        else:
-            if action == 0:  # SELL (close trade)
+            if action == 0:
                exit_price = next_candle["close"]
                reward = exit_price - self.position["entry_price"]
                trade = {
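Per the removed comments, action 2 opens a long at the next candle's open and action 0 closes it at the next candle's close, with the reward being the raw price difference. A throwaway numerical example (the prices are made up):

entry_price = 100.0  # next_candle["open"] when the BUY (action == 2) is taken
exit_price = 103.5   # next_candle["close"] when the SELL (action == 0) is taken
reward = exit_price - entry_price
print(reward)  # 3.5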
@@ -755,11 +740,6 @@ NUM_INDICATORS = 20
FEATURES_PER_CHANNEL = 7
ORDER_CHANNELS = 1

-# -------------------------------
-# Backtest Environment with Sliding Window and Order Info (Already Defined Above)
-# [See BacktestEnvironment class above]
-# -------------------------------

# -------------------------------
# General Simulation of Trades Function
# -------------------------------
@@ -807,7 +787,10 @@ def parse_args():
    parser.add_argument('--start_fresh', action='store_true', help="Start training from scratch.")
    parser.add_argument('--main_tf', type=str, default='1m',
                        help="Desired main timeframe to focus on (e.g., '1s' or '1m').")
-    parser.add_argument('--fetch', action='store_true', help="Fetch fresh data from exchange on start.")
+    # Instead of --fetch, we now provide a --no-fetch flag that will override the default behavior.
+    parser.add_argument('--no-fetch', dest='fetch', action='store_false',
+                        help="Do NOT fetch fresh data from exchange on start.")
+    parser.set_defaults(fetch=True)
    parser.add_argument('--symbol', type=str, default='BTC/USDT', help="Trading pair symbol.")
    return parser.parse_args()
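With this pattern, args.fetch stays True unless --no-fetch is passed, because store_false flips the destination that set_defaults initializes to True. A minimal standalone check of just that flag (the other arguments are omitted here):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-fetch', dest='fetch', action='store_false',
                    help="Do NOT fetch fresh data from exchange on start.")
parser.set_defaults(fetch=True)

print(parser.parse_args([]).fetch)              # True  (default: fetch on start)
print(parser.parse_args(['--no-fetch']).fetch)  # False (fetching disabled)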
@@ -822,33 +805,28 @@ async def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

-    # If --fetch flag is provided, top-up cached OHLCV data with fresh data from exchange.
+    # With fetch defaulting to True, live mode will always try to top-up the cache.
    if args.fetch:
        import ccxt.async_support as ccxt
        exchange = ccxt.binance({'enableRateLimit': True})
        now_ms = int(time.time()*1000)
        # Determine default "since" time based on cache.
        cached = load_candles_cache(CACHE_FILE)
        if cached and args.main_tf in cached and len(cached[args.main_tf]) > 0:
            last_ts = cached[args.main_tf][-1]['timestamp']
            since = last_ts + 1
        else:
            # Default: fetch candles from the last 2 days.
            since = now_ms - 2*24*60*60*1000
        # Top-up data for the main timeframe.
        print(f"Fetching fresh data for {args.symbol} on timeframe {args.main_tf} from {since} to {now_ms}...")
        fresh_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf, since, now_ms)
        # Update cache (for simplicity, we store only the main timeframe here).
        candles_dict = {args.main_tf: fresh_candles}
        save_candles_cache(CACHE_FILE, candles_dict)
        await exchange.close()
    else:
        candles_dict = load_candles_cache(CACHE_FILE)
        if not candles_dict:
-            print("No cached data available. Run with --fetch to load fresh data from the exchange.")
+            print("No cached data available. Run without --no-fetch (default) to load fresh data from the exchange.")
            return

    # Define desired timeframes list.
    default_timeframes = ["1s", "1m", "5m", "15m", "1h", "1d"]
    timeframes = [tf for tf in default_timeframes if tf in candles_dict]
    if args.main_tf not in timeframes:
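The top-up logic resumes one millisecond after the newest cached candle, and otherwise backfills two days (2*24*60*60*1000 = 172,800,000 ms). A compact sketch of just that decision, with a stand-in cache in place of load_candles_cache output:

import time

now_ms = int(time.time() * 1000)
main_tf = "1m"
cached = {main_tf: [{"timestamp": now_ms - 60_000}]}  # stand-in: one candle, 1 minute old

if cached and main_tf in cached and len(cached[main_tf]) > 0:
    since = cached[main_tf][-1]["timestamp"] + 1   # resume just past the last cached candle
else:
    since = now_ms - 2 * 24 * 60 * 60 * 1000       # no cache: backfill the last 2 days
print(now_ms - since)  # 59,999 ms with the stand-in cache above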
@@ -890,11 +868,27 @@ async def main():
            print("No valid optimizer state found; using fresh optimizer state.")
        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
    elif args.mode == 'live':
+        import ccxt.async_support as ccxt
+        exchange = ccxt.binance({'enableRateLimit': True})
+        POLL_INTERVAL = 60  # seconds
+        async def update_live_candles():
+            nonlocal exchange, args, candles_dict
+            while True:
+                now_ms = int(time.time()*1000)
+                new_candles = await get_cached_or_fetch_data(exchange, args.symbol, args.main_tf,
+                                                             since=now_ms - 2*60*1000, end_time=now_ms)
+                if args.main_tf in candles_dict:
+                    candles_dict[args.main_tf] = new_candles
+                else:
+                    candles_dict[args.main_tf] = new_candles
+                print("Live candles updated.")
+                await asyncio.sleep(POLL_INTERVAL)
+        asyncio.create_task(update_live_candles())
        load_best_checkpoint(model)
        env = BacktestEnvironment(candles_dict, base_tf, timeframes, window_size=100)
        preview_thread = threading.Thread(target=live_preview_loop, args=(env.candle_window, env), daemon=True)
        preview_thread.start()
-        print("Starting live trading loop. (For main_tf={} using manual override if model signal is weak.)".format(args.main_tf))
+        print("Starting live trading loop. (Using live updated data now.)")
        while True:
            if args.main_tf == "1s":
                simulate_trades_1s(env)
@@ -914,14 +908,13 @@ async def main():
                _, _, _, done, _, _ = env.step(action)
            else:
                manual_trade(env)
-            if env.current_index >= len(env.candle_window)-1:
+            if env.current_index >= len(env.candle_window) - 1:
                print("Reached end of simulation window; resetting environment.")
                env.reset()
            await asyncio.sleep(1)
    elif args.mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
        # Inference logic goes here.
    else:
        print("Invalid mode specified.")