This commit is contained in:
Dobromir Popov 2025-02-04 20:53:00 +02:00
parent c8043a9dcd
commit 20d6542d2c
2 changed files with 38 additions and 26 deletions

View File

@ -31,9 +31,9 @@ os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"
# --- Constants ---
NUM_TIMEFRAMES = 5 # Example: ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20 # Example: 20 technical indicators
FEATURES_PER_CHANNEL = 7 # e.g. HLOC, SMA_close, SMA_volume
NUM_TIMEFRAMES = 5 # e.g., ["1m", "5m", "15m", "1h", "1d"]
NUM_INDICATORS = 20 # e.g., 20 technical indicators
FEATURES_PER_CHANNEL = 7 # H, L, O, C, Volume, SMA_close, SMA_volume
# --- Positional Encoding Module ---
class PositionalEncoding(nn.Module):
@ -46,7 +46,6 @@ class PositionalEncoding(nn.Module):
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:x.size(0)]
return self.dropout(x)
@ -55,6 +54,7 @@ class PositionalEncoding(nn.Module):
class TradingModel(nn.Module):
def __init__(self, num_channels, num_timeframes, hidden_dim=128):
super().__init__()
# Create a branch for each channel (each channel input has FEATURES_PER_CHANNEL features)
self.channel_branches = nn.ModuleList([
nn.Sequential(
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@ -63,7 +63,8 @@ class TradingModel(nn.Module):
nn.Dropout(0.1)
) for _ in range(num_channels)
])
self.timeframe_embed = nn.Embedding(num_timeframes, hidden_dim)
# IMPORTANT FIX: Use num_channels (total channels) instead of num_timeframes.
self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
self.pos_encoder = PositionalEncoding(hidden_dim)
encoder_layers = TransformerEncoderLayer(
d_model=hidden_dim, nhead=4, dim_feedforward=512,
@ -81,22 +82,25 @@ class TradingModel(nn.Module):
nn.GELU(),
nn.Linear(hidden_dim // 2, 1)
)
def forward(self, x, timeframe_ids):
# x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
batch_size, num_channels, _ = x.shape
channel_outs = [self.channel_branches[i](x[:, i, :]) for i in range(num_channels)]
stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden]
stacked = stacked.permute(1, 0, 2) # [channels, batch, hidden]
channel_outs = []
for i in range(num_channels):
channel_out = self.channel_branches[i](x[:, i, :])
channel_outs.append(channel_out)
stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden]
stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden]
# Use embedding for each channel (indices 0 to num_channels-1)
tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
stacked = stacked + tf_embeds
src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
transformer_out = self.transformer(stacked, src_mask=src_mask)
transformer_out = self.transformer(stacked, mask=src_mask)
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
aggregated = (transformer_out * attn_weights).sum(dim=0)
return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()
# --- Technical Indicator Helper Functions ---
# --- Technical Indicator Helpers ---
def compute_sma(candles_list, index, period=10):
start = max(0, index - period + 1)
values = [candle["close"] for candle in candles_list[start:index+1]]
@ -127,7 +131,7 @@ def get_features_for_tf(candles_list, index, period=10):
sma_volume = compute_sma_volume(candles_list, index, period)
return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume]
# --- Caching and Checkpoint Functions ---
# --- Caching & Checkpoint Functions ---
def load_candles_cache(filename):
if os.path.exists(filename):
try:
@ -209,7 +213,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
# --- Backtest Environment ---
class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes):
self.candles_dict = candles_dict # dict of timeframe: candles_list
self.candles_dict = candles_dict # dict: timeframe -> list of candles
self.base_tf = base_tf
self.timeframes = timeframes
self.current_index = 0
@ -245,11 +249,11 @@ class BacktestEnvironment:
reward = 0.0
# Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
if self.position is None:
if action == 2: # BUY signal
if action == 2: # BUY signal: enter at next candle's open.
entry_price = next_candle["open"]
self.position = {"entry_price": entry_price, "entry_index": self.current_index}
else:
if action == 0: # SELL signal
if action == 0: # SELL signal: exit at next candle's open.
exit_price = next_candle["open"]
reward = exit_price - self.position["entry_price"]
trade = {
@ -278,9 +282,9 @@ def train_on_historical_data(env, model, device, args):
model.train()
while True:
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device) # Expect shape[0]==num_channels
pred_high, pred_low = model(state_tensor, timeframe_ids)
# Here we use dummy targets extracted from the next candle's high/low
# Dummy targets from next candle's high/low
_, _, next_state, done, actual_high, actual_low = env.step(None)
target_high = torch.FloatTensor([actual_high]).to(device)
target_low = torch.FloatTensor([actual_low]).to(device)
@ -335,7 +339,7 @@ def live_preview_loop(candles, env):
while True:
update_live_chart(ax, candles, env.trade_history)
plt.draw()
plt.pause(1) # Update every second
plt.pause(1)
# --- Argument Parsing ---
def parse_args():
@ -353,13 +357,13 @@ def random_action():
async def main():
args = parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define timeframes; these must match your data and expected state dimensions.
timeframes = ["1m", "5m", "15m", "1h", "1d"]
input_dim = len(timeframes) * 7 # 7 features per timeframe.
input_dim = len(timeframes) * 7 # 7 features per timeframe
hidden_dim = 128
output_dim = 3 # Actions: SELL, HOLD, BUY.
# For the Transformer model, we set number of channels = NUM_TIMEFRAMES + NUM_INDICATORS.
model = TradingModel(NUM_TIMEFRAMES + NUM_INDICATORS, NUM_TIMEFRAMES).to(device)
output_dim = 3 # SELL, HOLD, BUY
# Set total number of channels = NUM_TIMEFRAMES + NUM_INDICATORS.
total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
if args.mode == 'train':
candles_dict = load_candles_cache(CACHE_FILE)
@ -376,7 +380,6 @@ async def main():
print("No cached candles available for live preview.")
return
env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
# Start the live preview in a separate daemon thread.
preview_thread = threading.Thread(target=live_preview_loop, args=(candles_dict["1m"], env), daemon=True)
preview_thread.start()
print("Starting live trading loop. (Using random actions for simulation.)")
@ -385,11 +388,11 @@ async def main():
if done:
print("Reached end of simulated data, resetting environment.")
state = env.reset(clear_trade_history=False)
await asyncio.sleep(1) # Simulate one candle per second.
await asyncio.sleep(1)
elif args.mode == 'inference':
load_best_checkpoint(model)
print("Running inference...")
# Implement your inference loop here.
# Place your inference logic here.
else:
print("Invalid mode specified.")

View File

@ -4,6 +4,15 @@ pip install ccxt torch numpy
run: >conda activate gpt-gpu
python .\index.py
Usage:
Run the script with a command-line argument — for example:
• python index-deep-new.py --mode train
• python index-deep-new.py --mode live
• python index-deep-new.py --mode inference
prompts:
1.
create a 8b neural network (ai) that will consume live and historical HLOCv (candle sticks) data with a specific time window and in different time periods (1s, 1m 15m, 1h, 1d) and perform buy/sell operations. It will be based on the latest RL unsupervised training techniques, will continiously and retrospectively improve itself (without entering separate modes for training/inference) and the info it can digest will be able to be extendable and dynamic. for example, we should be able to feed sentiment analysis on current X feeds or news. We will also prepare/ calculte various indicators on top of the incomming HLOCV data (stocastic, rsi, etc - all most popular). we should be able to support up to 100 indicators (additional data) channels. The signals of the NN will be used by a bot first to trade on Solana using jupiter api.