wip
This commit is contained in:
parent
c8043a9dcd
commit
20d6542d2c
@ -31,9 +31,9 @@ os.makedirs(BEST_DIR, exist_ok=True)
|
|||||||
CACHE_FILE = "candles_cache.json"
|
CACHE_FILE = "candles_cache.json"
|
||||||
|
|
||||||
# --- Constants ---
|
# --- Constants ---
|
||||||
NUM_TIMEFRAMES = 5 # Example: ["1m", "5m", "15m", "1h", "1d"]
|
NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"]
|
||||||
NUM_INDICATORS = 20 # Example: 20 technical indicators
|
NUM_INDICATORS = 20 # e.g., 20 technical indicators
|
||||||
FEATURES_PER_CHANNEL = 7 # e.g. HLOC, SMA_close, SMA_volume
|
FEATURES_PER_CHANNEL = 7 # H, L, O, C, Volume, SMA_close, SMA_volume
|
||||||
|
|
||||||
# --- Positional Encoding Module ---
|
# --- Positional Encoding Module ---
|
||||||
class PositionalEncoding(nn.Module):
|
class PositionalEncoding(nn.Module):
|
||||||
@ -46,7 +46,6 @@ class PositionalEncoding(nn.Module):
|
|||||||
pe[:, 0, 0::2] = torch.sin(position * div_term)
|
pe[:, 0, 0::2] = torch.sin(position * div_term)
|
||||||
pe[:, 0, 1::2] = torch.cos(position * div_term)
|
pe[:, 0, 1::2] = torch.cos(position * div_term)
|
||||||
self.register_buffer('pe', pe)
|
self.register_buffer('pe', pe)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = x + self.pe[:x.size(0)]
|
x = x + self.pe[:x.size(0)]
|
||||||
return self.dropout(x)
|
return self.dropout(x)
|
||||||
@ -55,6 +54,7 @@ class PositionalEncoding(nn.Module):
|
|||||||
class TradingModel(nn.Module):
|
class TradingModel(nn.Module):
|
||||||
def __init__(self, num_channels, num_timeframes, hidden_dim=128):
|
def __init__(self, num_channels, num_timeframes, hidden_dim=128):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
# Create a branch for each channel (each channel input has FEATURES_PER_CHANNEL features)
|
||||||
self.channel_branches = nn.ModuleList([
|
self.channel_branches = nn.ModuleList([
|
||||||
nn.Sequential(
|
nn.Sequential(
|
||||||
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
|
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
|
||||||
@ -63,7 +63,8 @@ class TradingModel(nn.Module):
|
|||||||
nn.Dropout(0.1)
|
nn.Dropout(0.1)
|
||||||
) for _ in range(num_channels)
|
) for _ in range(num_channels)
|
||||||
])
|
])
|
||||||
self.timeframe_embed = nn.Embedding(num_timeframes, hidden_dim)
|
# IMPORTANT FIX: Use num_channels (total channels) instead of num_timeframes.
|
||||||
|
self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
|
||||||
self.pos_encoder = PositionalEncoding(hidden_dim)
|
self.pos_encoder = PositionalEncoding(hidden_dim)
|
||||||
encoder_layers = TransformerEncoderLayer(
|
encoder_layers = TransformerEncoderLayer(
|
||||||
d_model=hidden_dim, nhead=4, dim_feedforward=512,
|
d_model=hidden_dim, nhead=4, dim_feedforward=512,
|
||||||
@ -81,22 +82,25 @@ class TradingModel(nn.Module):
|
|||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Linear(hidden_dim // 2, 1)
|
nn.Linear(hidden_dim // 2, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, timeframe_ids):
|
def forward(self, x, timeframe_ids):
|
||||||
# x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
|
# x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
|
||||||
batch_size, num_channels, _ = x.shape
|
batch_size, num_channels, _ = x.shape
|
||||||
channel_outs = [self.channel_branches[i](x[:, i, :]) for i in range(num_channels)]
|
channel_outs = []
|
||||||
stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden]
|
for i in range(num_channels):
|
||||||
stacked = stacked.permute(1, 0, 2) # [channels, batch, hidden]
|
channel_out = self.channel_branches[i](x[:, i, :])
|
||||||
|
channel_outs.append(channel_out)
|
||||||
|
stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden]
|
||||||
|
stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden]
|
||||||
|
# Use embedding for each channel (indices 0 to num_channels-1)
|
||||||
tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
|
tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
|
||||||
stacked = stacked + tf_embeds
|
stacked = stacked + tf_embeds
|
||||||
src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
|
src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
|
||||||
transformer_out = self.transformer(stacked, src_mask=src_mask)
|
transformer_out = self.transformer(stacked, mask=src_mask)
|
||||||
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
|
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
|
||||||
aggregated = (transformer_out * attn_weights).sum(dim=0)
|
aggregated = (transformer_out * attn_weights).sum(dim=0)
|
||||||
return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()
|
return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()
|
||||||
|
|
||||||
# --- Technical Indicator Helper Functions ---
|
# --- Technical Indicator Helpers ---
|
||||||
def compute_sma(candles_list, index, period=10):
|
def compute_sma(candles_list, index, period=10):
|
||||||
start = max(0, index - period + 1)
|
start = max(0, index - period + 1)
|
||||||
values = [candle["close"] for candle in candles_list[start:index+1]]
|
values = [candle["close"] for candle in candles_list[start:index+1]]
|
||||||
@ -127,7 +131,7 @@ def get_features_for_tf(candles_list, index, period=10):
|
|||||||
sma_volume = compute_sma_volume(candles_list, index, period)
|
sma_volume = compute_sma_volume(candles_list, index, period)
|
||||||
return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume]
|
return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume]
|
||||||
|
|
||||||
# --- Caching and Checkpoint Functions ---
|
# --- Caching & Checkpoint Functions ---
|
||||||
def load_candles_cache(filename):
|
def load_candles_cache(filename):
|
||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
try:
|
try:
|
||||||
@ -209,7 +213,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
|
|||||||
# --- Backtest Environment ---
|
# --- Backtest Environment ---
|
||||||
class BacktestEnvironment:
|
class BacktestEnvironment:
|
||||||
def __init__(self, candles_dict, base_tf, timeframes):
|
def __init__(self, candles_dict, base_tf, timeframes):
|
||||||
self.candles_dict = candles_dict # dict of timeframe: candles_list
|
self.candles_dict = candles_dict # dict: timeframe -> list of candles
|
||||||
self.base_tf = base_tf
|
self.base_tf = base_tf
|
||||||
self.timeframes = timeframes
|
self.timeframes = timeframes
|
||||||
self.current_index = 0
|
self.current_index = 0
|
||||||
@ -245,11 +249,11 @@ class BacktestEnvironment:
|
|||||||
reward = 0.0
|
reward = 0.0
|
||||||
# Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
|
# Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
|
||||||
if self.position is None:
|
if self.position is None:
|
||||||
if action == 2: # BUY signal
|
if action == 2: # BUY signal: enter at next candle's open.
|
||||||
entry_price = next_candle["open"]
|
entry_price = next_candle["open"]
|
||||||
self.position = {"entry_price": entry_price, "entry_index": self.current_index}
|
self.position = {"entry_price": entry_price, "entry_index": self.current_index}
|
||||||
else:
|
else:
|
||||||
if action == 0: # SELL signal
|
if action == 0: # SELL signal: exit at next candle's open.
|
||||||
exit_price = next_candle["open"]
|
exit_price = next_candle["open"]
|
||||||
reward = exit_price - self.position["entry_price"]
|
reward = exit_price - self.position["entry_price"]
|
||||||
trade = {
|
trade = {
|
||||||
@ -278,9 +282,9 @@ def train_on_historical_data(env, model, device, args):
|
|||||||
model.train()
|
model.train()
|
||||||
while True:
|
while True:
|
||||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
||||||
timeframe_ids = torch.arange(state.shape[0]).to(device)
|
timeframe_ids = torch.arange(state.shape[0]).to(device) # Expect shape[0]==num_channels
|
||||||
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
pred_high, pred_low = model(state_tensor, timeframe_ids)
|
||||||
# Here we use dummy targets extracted from the next candle's high/low
|
# Dummy targets from next candle's high/low
|
||||||
_, _, next_state, done, actual_high, actual_low = env.step(None)
|
_, _, next_state, done, actual_high, actual_low = env.step(None)
|
||||||
target_high = torch.FloatTensor([actual_high]).to(device)
|
target_high = torch.FloatTensor([actual_high]).to(device)
|
||||||
target_low = torch.FloatTensor([actual_low]).to(device)
|
target_low = torch.FloatTensor([actual_low]).to(device)
|
||||||
@ -335,7 +339,7 @@ def live_preview_loop(candles, env):
|
|||||||
while True:
|
while True:
|
||||||
update_live_chart(ax, candles, env.trade_history)
|
update_live_chart(ax, candles, env.trade_history)
|
||||||
plt.draw()
|
plt.draw()
|
||||||
plt.pause(1) # Update every second
|
plt.pause(1)
|
||||||
|
|
||||||
# --- Argument Parsing ---
|
# --- Argument Parsing ---
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -353,13 +357,13 @@ def random_action():
|
|||||||
async def main():
|
async def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
# Define timeframes; these must match your data and expected state dimensions.
|
|
||||||
timeframes = ["1m", "5m", "15m", "1h", "1d"]
|
timeframes = ["1m", "5m", "15m", "1h", "1d"]
|
||||||
input_dim = len(timeframes) * 7 # 7 features per timeframe.
|
input_dim = len(timeframes) * 7 # 7 features per timeframe
|
||||||
hidden_dim = 128
|
hidden_dim = 128
|
||||||
output_dim = 3 # Actions: SELL, HOLD, BUY.
|
output_dim = 3 # SELL, HOLD, BUY
|
||||||
# For the Transformer model, we set number of channels = NUM_TIMEFRAMES + NUM_INDICATORS.
|
# Set total number of channels = NUM_TIMEFRAMES + NUM_INDICATORS.
|
||||||
model = TradingModel(NUM_TIMEFRAMES + NUM_INDICATORS, NUM_TIMEFRAMES).to(device)
|
total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
|
||||||
|
model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
|
||||||
|
|
||||||
if args.mode == 'train':
|
if args.mode == 'train':
|
||||||
candles_dict = load_candles_cache(CACHE_FILE)
|
candles_dict = load_candles_cache(CACHE_FILE)
|
||||||
@ -376,7 +380,6 @@ async def main():
|
|||||||
print("No cached candles available for live preview.")
|
print("No cached candles available for live preview.")
|
||||||
return
|
return
|
||||||
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
|
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
|
||||||
# Start the live preview in a separate daemon thread.
|
|
||||||
preview_thread = threading.Thread(target=live_preview_loop, args=(candles_dict["1m"], env), daemon=True)
|
preview_thread = threading.Thread(target=live_preview_loop, args=(candles_dict["1m"], env), daemon=True)
|
||||||
preview_thread.start()
|
preview_thread.start()
|
||||||
print("Starting live trading loop. (Using random actions for simulation.)")
|
print("Starting live trading loop. (Using random actions for simulation.)")
|
||||||
@ -385,11 +388,11 @@ async def main():
|
|||||||
if done:
|
if done:
|
||||||
print("Reached end of simulated data, resetting environment.")
|
print("Reached end of simulated data, resetting environment.")
|
||||||
state = env.reset(clear_trade_history=False)
|
state = env.reset(clear_trade_history=False)
|
||||||
await asyncio.sleep(1) # Simulate one candle per second.
|
await asyncio.sleep(1)
|
||||||
elif args.mode == 'inference':
|
elif args.mode == 'inference':
|
||||||
load_best_checkpoint(model)
|
load_best_checkpoint(model)
|
||||||
print("Running inference...")
|
print("Running inference...")
|
||||||
# Implement your inference loop here.
|
# Place your inference logic here.
|
||||||
else:
|
else:
|
||||||
print("Invalid mode specified.")
|
print("Invalid mode specified.")
|
||||||
|
|
||||||
|
@ -4,6 +4,15 @@ pip install ccxt torch numpy
|
|||||||
run: >conda activate gpt-gpu
|
run: >conda activate gpt-gpu
|
||||||
python .\index.py
|
python .\index.py
|
||||||
|
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
Run the script with a command-line argument — for example:
|
||||||
|
• python index-deep-new.py --mode train
|
||||||
|
• python index-deep-new.py --mode live
|
||||||
|
• python index-deep-new.py --mode inference
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
prompts:
|
prompts:
|
||||||
1.
|
1.
|
||||||
create a 8b neural network (ai) that will consume live and historical HLOCv (candle sticks) data with a specific time window and in different time periods (1s, 1m 15m, 1h, 1d) and perform buy/sell operations. It will be based on the latest RL unsupervised training techniques, will continiously and retrospectively improve itself (without entering separate modes for training/inference) and the info it can digest will be able to be extendable and dynamic. for example, we should be able to feed sentiment analysis on current X feeds or news. We will also prepare/ calculte various indicators on top of the incomming HLOCV data (stocastic, rsi, etc - all most popular). we should be able to support up to 100 indicators (additional data) channels. The signals of the NN will be used by a bot first to trade on Solana using jupiter api.
|
create a 8b neural network (ai) that will consume live and historical HLOCv (candle sticks) data with a specific time window and in different time periods (1s, 1m 15m, 1h, 1d) and perform buy/sell operations. It will be based on the latest RL unsupervised training techniques, will continiously and retrospectively improve itself (without entering separate modes for training/inference) and the info it can digest will be able to be extendable and dynamic. for example, we should be able to feed sentiment analysis on current X feeds or news. We will also prepare/ calculte various indicators on top of the incomming HLOCV data (stocastic, rsi, etc - all most popular). we should be able to support up to 100 indicators (additional data) channels. The signals of the NN will be used by a bot first to trade on Solana using jupiter api.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user