compile fix

This commit is contained in:
Dobromir Popov 2025-02-04 18:00:58 +02:00
parent f5b1692a82
commit 2ec75e66cb

View File

@ -85,37 +85,49 @@ class TradingModel(nn.Module):
) )
def forward(self, x, timeframe_ids):
    """Predict the next high and low from per-channel feature vectors.

    Args:
        x: Tensor of shape [batch_size, num_channels, FEATURES_PER_CHANNEL].
        timeframe_ids: Index tensor looked up in ``self.timeframe_embed``.
            The embeddings receive a singleton batch axis and are added to
            the channel-major activations, so one id per channel is
            expected.

    Returns:
        Tuple ``(high, low)``: the outputs of ``self.high_pred`` and
        ``self.low_pred`` on the pooled representation, with every size-1
        dimension squeezed away.
    """
    # Unpacking doubles as a check that x is 3-D.
    _, channel_count, _ = x.shape

    # One branch per channel; each consumes that channel's feature vector.
    branch_outputs = [
        self.channel_branches[idx](x[:, idx, :]) for idx in range(channel_count)
    ]

    # [batch, channels, hidden] -> [channels, batch, hidden]
    tokens = torch.stack(branch_outputs, dim=1).transpose(0, 1)

    # Broadcast the timeframe embeddings over the batch axis.
    tokens = tokens + self.timeframe_embed(timeframe_ids).unsqueeze(1)

    # Strictly upper-triangular boolean mask over the channel axis, created
    # directly on the input's device.
    # NOTE(review): presumably this blocks attention to later channels --
    # confirm against the type of self.transformer.
    seq_len = tokens.size(0)
    causal_mask = torch.triu(
        torch.ones(seq_len, seq_len, device=x.device), diagonal=1
    ).bool()
    encoded = self.transformer(tokens, src_mask=causal_mask)

    # Attention pooling: softmax over the channel axis, then a weighted sum.
    weights = torch.softmax(self.attn_pool(encoded), dim=0)
    pooled = (encoded * weights).sum(dim=0)

    # NOTE(review): a bare squeeze() also drops the batch axis when
    # batch_size == 1, yielding 0-dim tensors; kept as-is for callers.
    high = self.high_pred(pooled).squeeze()
    low = self.low_pred(pooled).squeeze()
    return high, low
# --- Enhanced Data Processing --- # --- Enhanced Data Processing ---
# NOTE: this section requires the helper functions get_aligned_candle_with_index
# and get_features_for_tf, which must be defined elsewhere in your code.
class BacktestEnvironment: class BacktestEnvironment:
def __init__(self, candles_dict, base_tf, timeframes):
    """Store the candle data and point the environment at the first candle.

    Args:
        candles_dict: Mapping from timeframe key to its candle sequence;
            the other methods index it as ``candles_dict[base_tf][i]``.
        base_tf: Key of the timeframe the environment steps through.
        timeframes: Timeframe keys, stored for use by the rest of the class.
    """
    # Step pointer: index of the current candle on the base timeframe.
    self.current_index = 0
    self.timeframes = timeframes
    self.base_tf = base_tf
    self.candles_dict = candles_dict
def reset(self):
    """Rewind the step pointer to the first candle and return its state."""
    start = 0
    self.current_index = start
    return self.get_state(start)
def get_state(self, index): def get_state(self, index):
"""Returns shape [num_channels, FEATURES_PER_CHANNEL]""" """Returns state as an array of shape [num_channels, FEATURES_PER_CHANNEL]."""
state_features = [] state_features = []
base_ts = self.candles_dict[self.base_tf][index]["timestamp"] base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
@ -127,10 +139,29 @@ class BacktestEnvironment:
# Indicator channels (placeholder - implement your indicators) # Indicator channels (placeholder - implement your indicators)
for _ in range(NUM_INDICATORS): for _ in range(NUM_INDICATORS):
# Add indicator calculation here state_features.append([0.0] * FEATURES_PER_CHANNEL)
state_features.append([0.0]*FEATURES_PER_CHANNEL)
return np.array(state_features, dtype=np.float32) return np.array(state_features, dtype=np.float32)
def step(self, action):
    """Advance the environment by one base-timeframe candle.

    This is a backtesting placeholder: ``action`` is accepted for interface
    compatibility but is not used, and the reward is always ``0.0``.

    Args:
        action: Ignored.

    Returns:
        Tuple ``(current_candle, reward, next_state, done, actual_high,
        actual_low)`` where ``current_candle`` is the candle at the
        pre-step index, ``next_state`` is ``get_state`` at the advanced
        index, ``done`` is True once the pre-step index has reached the
        second-to-last candle, and ``actual_high`` / ``actual_low`` are the
        ``"high"`` / ``"low"`` of the following candle.

    Raises:
        IndexError: If no candle follows the current one (fewer than two
            candles loaded, or stepping again after ``done``).
    """
    candles = self.candles_dict[self.base_tf]
    position = self.current_index

    # The targets come from the *next* candle, so one must exist. Fail with
    # a descriptive message instead of a bare list-index error; the pointer
    # is left untouched.
    if position + 1 >= len(candles):
        raise IndexError(
            f"no candle after index {position} in timeframe "
            f"{self.base_tf!r} ({len(candles)} candles loaded)"
        )

    done = position >= len(candles) - 2
    current_candle = candles[position]
    next_candle = candles[position + 1]
    actual_high = next_candle["high"]
    actual_low = next_candle["low"]

    self.current_index = position + 1
    next_state = self.get_state(self.current_index)
    return current_candle, 0.0, next_state, done, actual_high, actual_low
def __len__(self):
    """Return the number of candles loaded for the base timeframe."""
    base_candles = self.candles_dict[self.base_tf]
    return len(base_candles)
# --- Enhanced Training Loop --- # --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args): def train_on_historical_data(env, model, device, args):
@ -143,24 +174,24 @@ def train_on_historical_data(env, model, device, args):
model.train() model.train()
while True: while True:
# Prepare batch # Prepare batch (here batch size is 1 for simplicity)
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device) timeframe_ids = torch.arange(state.shape[0]).to(device)
# Forward pass # Forward pass
pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high, pred_low = model(state_tensor, timeframe_ids)
# Get targets from next candle # Get target values from next candle (dummy targets from environment)
_, _, next_state, done, actual_high, actual_low = env.step(None) # Dummy action _, _, next_state, done, actual_high, actual_low = env.step(None) # Dummy action
target_high = torch.FloatTensor([actual_high]).to(device) target_high = torch.FloatTensor([actual_high]).to(device)
target_low = torch.FloatTensor([actual_low]).to(device) target_low = torch.FloatTensor([actual_low]).to(device)
# Custom loss # Custom loss: use absolute error scaled by 2
high_loss = torch.abs(pred_high - target_high) * 2 high_loss = torch.abs(pred_high - target_high) * 2
low_loss = torch.abs(pred_low - target_low) * 2 low_loss = torch.abs(pred_low - target_low) * 2
loss = (high_loss + low_loss).mean() loss = (high_loss + low_loss).mean()
# Backprop # Backpropagation
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
@ -176,7 +207,7 @@ def train_on_historical_data(env, model, device, args):
print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}") print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
save_checkpoint(model, epoch, total_loss) save_checkpoint(model, epoch, total_loss)
# --- Mode Handling --- # --- Mode Handling and Argument Parsing ---
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train') parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
@ -185,31 +216,50 @@ def parse_args():
parser.add_argument('--threshold', type=float, default=0.005) parser.add_argument('--threshold', type=float, default=0.005)
return parser.parse_args() return parser.parse_args()
def load_best_checkpoint(model, best_dir="models/best"):
    """Placeholder for restoring the best saved checkpoint into ``model``.

    The current implementation only announces the call: neither argument is
    read and nothing is loaded. A real implementation would inspect the
    saved checkpoints and call ``torch.load(...)`` here.

    Args:
        model: Model that would receive the checkpoint weights (unused).
        best_dir: Directory holding the best checkpoints (unused).

    Returns:
        None.
    """
    notice = "Loading best checkpoint (dummy implementation)"
    print(notice)
def save_checkpoint(model, epoch, reward, last_dir="models/last", best_dir="models/best"):
    """Placeholder for persisting a training checkpoint.

    The current implementation only prints the epoch and the reward
    (formatted to four decimals); nothing is written to disk.

    Args:
        model: Model whose weights would be saved (unused).
        epoch: Epoch number included in the message.
        reward: Numeric score included in the message.
        last_dir: Directory for the most recent checkpoints (unused).
        best_dir: Directory for the best checkpoints (unused).
    """
    message = f"Saving checkpoint for epoch {epoch}, reward: {reward:.4f}"
    print(message)
# --- Main Function ---
async def main():
    """Entry coroutine: build the model and dispatch on ``--mode``.

    * ``train``: load cached candles from ``candles_cache.json`` and train
      on them; returns early with a message when no data is available.
    * ``live``: load the best checkpoint, then loop forever, printing a
      status line once per second (placeholder for live trading).
    * ``inference``: load the best checkpoint and print a status line
      (placeholder for signal generation).
    """
    args = parse_args()
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # One input channel per timeframe plus one per indicator.
    model = TradingModel(
        num_channels=NUM_TIMEFRAMES + NUM_INDICATORS,
        num_timeframes=NUM_TIMEFRAMES,
    ).to(device)

    mode = args.mode

    if mode == 'train':
        # Historical candle data drives the backtest environment.
        candles_dict = load_candles_cache("candles_cache.json")
        if not candles_dict:
            print("No historical candle data available for backtesting.")
            return
        # "1m" is the base timeframe the environment steps through.
        env = BacktestEnvironment(
            candles_dict, "1m", ["1m", "5m", "15m", "1h", "1d"]
        )
        train_on_historical_data(env, model, device, args)
        return

    if mode == 'live':
        load_best_checkpoint(model)
        # Placeholder loop: fetch live candles, predict, execute trades.
        while True:
            print("Processing live data...")
            await asyncio.sleep(1)

    if mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
# Script entry point: run the async main() to completion on a new event loop.
if __name__ == "__main__":
    asyncio.run(main())