compile fix

Dobromir Popov 2025-02-04 18:00:58 +02:00
parent f5b1692a82
commit 2ec75e66cb


@@ -85,37 +85,49 @@ class TradingModel(nn.Module):
)
def forward(self, x, timeframe_ids):
# x shape: [batch_size, num_channels, features]
# x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
batch_size, num_channels, _ = x.shape
# Process each channel
# Process each channel through its branch
channel_outs = []
for i in range(num_channels):
channel_out = self.channel_branches[i](x[:,i,:])
channel_out = self.channel_branches[i](x[:, i, :])
channel_outs.append(channel_out)
# Stack and add embeddings
stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden]
stacked = stacked.permute(1, 0, 2) # [channels, batch, hidden]
# Stack and add timeframe embeddings
stacked = torch.stack(channel_outs, dim=1) # [batch, channels, hidden]
stacked = stacked.permute(1, 0, 2) # [channels, batch, hidden]
# Add timeframe embeddings
# Add timeframe embeddings to each channel
tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
stacked = stacked + tf_embeds
# Transformer
src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool()
transformer_out = self.transformer(stacked, src_mask=src_mask.to(x.device))
# Apply Transformer
src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
transformer_out = self.transformer(stacked, src_mask=src_mask)
# Attention Pool
# Attention Pooling over channels
attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
aggregated = (transformer_out * attn_weights).sum(dim=0)
return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()
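# A minimal smoke test of the forward pass (not part of this commit; it assumes
# the module-level constants FEATURES_PER_CHANNEL, NUM_TIMEFRAMES and
# NUM_INDICATORS used elsewhere in this file):
def _smoke_test_trading_model(batch_size=8):
    num_channels = NUM_TIMEFRAMES + NUM_INDICATORS
    model = TradingModel(num_channels=num_channels, num_timeframes=NUM_TIMEFRAMES)
    x = torch.randn(batch_size, num_channels, FEATURES_PER_CHANNEL)
    timeframe_ids = torch.arange(num_channels)  # mirrors the training-loop usage
    pred_high, pred_low = model(x, timeframe_ids)  # each of shape [batch_size]
    return pred_high.shape, pred_low.shape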
# --- Enhanced Data Processing ---
# This section relies on the helper functions get_aligned_candle_with_index and get_features_for_tf,
# which must be defined elsewhere in your code.
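# A minimal, hypothetical sketch of those helpers (assumed signatures, candle
# keys and logic; the real implementations live elsewhere in the project):
def get_aligned_candle_with_index(candles, target_ts):
    """Return (index, candle) of the latest candle with timestamp <= target_ts."""
    idx = 0
    for i, candle in enumerate(candles):
        if candle["timestamp"] <= target_ts:
            idx = i
        else:
            break
    return idx, candles[idx]

def get_features_for_tf(candles, index):
    """Return a FEATURES_PER_CHANNEL-long feature vector for one candle
    (assumes OHLCV keys on the candle dict)."""
    c = candles[index]
    features = [c["open"], c["high"], c["low"], c["close"], c["volume"]]
    features = features[:FEATURES_PER_CHANNEL]
    features += [0.0] * (FEATURES_PER_CHANNEL - len(features))
    return features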
class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes):
self.candles_dict = candles_dict
self.base_tf = base_tf
self.timeframes = timeframes
self.current_index = 0 # Initialize step pointer
def reset(self):
self.current_index = 0
return self.get_state(self.current_index)
def get_state(self, index):
"""Returns shape [num_channels, FEATURES_PER_CHANNEL]"""
"""Returns state as an array of shape [num_channels, FEATURES_PER_CHANNEL]."""
state_features = []
base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
@@ -127,10 +139,29 @@ class BacktestEnvironment:
# Indicator channels (placeholder - implement your indicators)
for _ in range(NUM_INDICATORS):
# Add indicator calculation here (one illustrative sketch follows after get_state)
state_features.append([0.0]*FEATURES_PER_CHANNEL)
state_features.append([0.0] * FEATURES_PER_CHANNEL)
return np.array(state_features, dtype=np.float32)
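# One illustrative way to fill the indicator placeholder in get_state (an
# assumption, not part of this commit; assumes a "close" key on each candle):
def _sma_indicator_features(self, index, period=10):
    window = self.candles_dict[self.base_tf][max(0, index - period + 1):index + 1]
    sma = sum(c["close"] for c in window) / len(window)
    return [sma] + [0.0] * (FEATURES_PER_CHANNEL - 1)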
def step(self, action):
"""
Advance the environment by one step.
Since this is for backtesting, action isn't used here.
Returns: current candle info, reward, next state, done, actual high, actual low.
"""
# Dummy implementation: you would generate targets based on your backtest logic.
done = (self.current_index >= len(self.candles_dict[self.base_tf]) - 2)
current_candle = self.candles_dict[self.base_tf][self.current_index]
# For example, take the next candle's high/low as targets
next_candle = self.candles_dict[self.base_tf][self.current_index + 1]
actual_high = next_candle["high"]
actual_low = next_candle["low"]
self.current_index += 1
next_state = self.get_state(self.current_index)
return current_candle, 0.0, next_state, done, actual_high, actual_low
def __len__(self):
return len(self.candles_dict[self.base_tf])
# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args):
@@ -143,24 +174,24 @@ def train_on_historical_data(env, model, device, args):
model.train()
while True:
# Prepare batch
# Prepare batch (here batch size is 1 for simplicity)
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device)
# Forward pass
pred_high, pred_low = model(state_tensor, timeframe_ids)
# Get targets from next candle
# Get target values from next candle (dummy targets from environment)
_, _, next_state, done, actual_high, actual_low = env.step(None) # Dummy action
target_high = torch.FloatTensor([actual_high]).to(device)
target_low = torch.FloatTensor([actual_low]).to(device)
# Custom loss
# Custom loss: use absolute error scaled by 2
high_loss = torch.abs(pred_high - target_high) * 2
low_loss = torch.abs(pred_low - target_low) * 2
loss = (high_loss + low_loss).mean()
# Backprop
# Backpropagation
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
@@ -176,7 +207,7 @@ def train_on_historical_data(env, model, device, args):
print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
save_checkpoint(model, epoch, total_loss)
# --- Mode Handling ---
# --- Mode Handling and Argument Parsing ---
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
@@ -185,31 +216,50 @@ def parse_args():
parser.add_argument('--threshold', type=float, default=0.005)
return parser.parse_args()
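# Example invocation (the script filename is an assumption):
#   python main.py --mode train --threshold 0.005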
def load_best_checkpoint(model, best_dir="models/best"):
# Dummy implementation for loading the best checkpoint.
# In real usage, locate and load your saved checkpoint here.
print("Loading best checkpoint (dummy implementation)")
# torch.load(...) can be invoked here.
return
def save_checkpoint(model, epoch, reward, last_dir="models/last", best_dir="models/best"):
# Dummy implementation for saving checkpoints.
print(f"Saving checkpoint for epoch {epoch}, reward: {reward:.4f}")
# --- Main Function ---
async def main():
args = parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize model
model = TradingModel(
num_channels=NUM_TIMEFRAMES+NUM_INDICATORS,
num_channels=NUM_TIMEFRAMES + NUM_INDICATORS,
num_timeframes=NUM_TIMEFRAMES
).to(device)
if args.mode == 'train':
# Initialize environment and train
env = BacktestEnvironment(...)
# Load historical candle data for backtesting (a sketch of load_candles_cache follows after main)
candles_dict = load_candles_cache("candles_cache.json")
if not candles_dict:
print("No historical candle data available for backtesting.")
return
base_tf = "1m" # Base timeframe
timeframes = ["1m", "5m", "15m", "1h", "1d"]
env = BacktestEnvironment(candles_dict, base_tf, timeframes)
train_on_historical_data(env, model, device, args)
elif args.mode == 'live':
# Load model and connect to live data
load_best_checkpoint(model)
while True:
# Process live data
# Make predictions and execute trades
# Process live data: fetch live candles, make predictions, execute trades
print("Processing live data...")
await asyncio.sleep(1)
elif args.mode == 'inference':
# Load model and run inference
load_best_checkpoint(model)
# Generate signals without training
print("Running inference...")
# Add your inference logic here
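# load_candles_cache is called in main() but not shown in this diff. A minimal
# sketch, assuming the cache is a JSON file mapping each timeframe (e.g. "1m",
# "5m") to a list of candle dicts with at least "timestamp", "high" and "low":
import json
import os

def load_candles_cache(path="candles_cache.json"):
    if not os.path.exists(path):
        return {}
    with open(path, "r") as f:
        return json.load(f)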
if __name__ == "__main__":
asyncio.run(main())