compile fix
parent f5b1692a82
commit 2ec75e66cb
@@ -85,37 +85,49 @@ class TradingModel(nn.Module):
         )
 
     def forward(self, x, timeframe_ids):
-        # x shape: [batch_size, num_channels, features]
+        # x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
         batch_size, num_channels, _ = x.shape
 
-        # Process each channel
+        # Process each channel through its branch
        channel_outs = []
         for i in range(num_channels):
-            channel_out = self.channel_branches[i](x[:,i,:])
+            channel_out = self.channel_branches[i](x[:, i, :])
             channel_outs.append(channel_out)
 
-        # Stack and add embeddings
+        # Stack and add timeframe embeddings
         stacked = torch.stack(channel_outs, dim=1)  # [batch, channels, hidden]
         stacked = stacked.permute(1, 0, 2)  # [channels, batch, hidden]
 
-        # Add timeframe embeddings
+        # Add timeframe embeddings to each channel
         tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
         stacked = stacked + tf_embeds
 
-        # Transformer
-        src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool()
-        transformer_out = self.transformer(stacked, src_mask=src_mask.to(x.device))
+        # Apply Transformer
+        src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
+        transformer_out = self.transformer(stacked, src_mask=src_mask)
 
-        # Attention Pool
+        # Attention Pooling over channels
         attn_weights = torch.softmax(self.attn_pool(transformer_out), dim=0)
         aggregated = (transformer_out * attn_weights).sum(dim=0)
 
         return self.high_pred(aggregated).squeeze(), self.low_pred(aggregated).squeeze()
 
 # --- Enhanced Data Processing ---
+# Here you need to have the helper functions get_aligned_candle_with_index and get_features_for_tf
+# They must be defined elsewhere in your code.
 class BacktestEnvironment:
+    def __init__(self, candles_dict, base_tf, timeframes):
+        self.candles_dict = candles_dict
+        self.base_tf = base_tf
+        self.timeframes = timeframes
+        self.current_index = 0  # Initialize step pointer
+
+    def reset(self):
+        self.current_index = 0
+        return self.get_state(self.current_index)
+
     def get_state(self, index):
-        """Returns shape [num_channels, FEATURES_PER_CHANNEL]"""
+        """Returns state as an array of shape [num_channels, FEATURES_PER_CHANNEL]."""
         state_features = []
         base_ts = self.candles_dict[self.base_tf][index]["timestamp"]
 
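The substantive change in this hunk is the device fix: the boolean causal mask is now moved to `x.device` where it is built, rather than at the `transformer(...)` call, so the mask and the activations always live on the same device. A minimal sketch of the pattern, assuming a stock `nn.TransformerEncoder` (the repo's `self.transformer` may be a different module; note that `nn.TransformerEncoderLayer` accepts the mask as `src_mask=`, while `nn.TransformerEncoder` accepts it as `mask=`):

```python
import torch
import torch.nn as nn

# Sketch only, not the repo's model: a tiny encoder with a causal mask
# built directly on the input's device, mirroring the fix above.
hidden = 32
layer = nn.TransformerEncoderLayer(d_model=hidden, nhead=4)
encoder = nn.TransformerEncoder(layer, num_layers=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = encoder.to(device)

x = torch.randn(5, 2, hidden, device=device)  # [channels, batch, hidden]

# True above the diagonal = "may not attend"; building the mask on x.device
# avoids the CPU-vs-CUDA mismatch the old code could hit.
seq_len = x.size(0)
src_mask = torch.triu(
    torch.ones(seq_len, seq_len, device=x.device), diagonal=1
).bool()

out = encoder(x, mask=src_mask)
print(out.shape)  # torch.Size([5, 2, 32])
```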
@@ -127,10 +139,29 @@ class BacktestEnvironment:
 
         # Indicator channels (placeholder - implement your indicators)
         for _ in range(NUM_INDICATORS):
-            # Add indicator calculation here
-            state_features.append([0.0]*FEATURES_PER_CHANNEL)
+            state_features.append([0.0] * FEATURES_PER_CHANNEL)
 
         return np.array(state_features, dtype=np.float32)
 
+    def step(self, action):
+        """
+        Advance the environment by one step.
+        Since this is for backtesting, action isn't used here.
+        Returns: current candle info, reward, next state, done, actual high, actual low.
+        """
+        # Dummy implementation: you would generate targets based on your backtest logic.
+        done = (self.current_index >= len(self.candles_dict[self.base_tf]) - 2)
+        current_candle = self.candles_dict[self.base_tf][self.current_index]
+        # For example, take the next candle's high/low as targets
+        next_candle = self.candles_dict[self.base_tf][self.current_index + 1]
+        actual_high = next_candle["high"]
+        actual_low = next_candle["low"]
+        self.current_index += 1
+        next_state = self.get_state(self.current_index)
+        return current_candle, 0.0, next_state, done, actual_high, actual_low
+
+    def __len__(self):
+        return len(self.candles_dict[self.base_tf])
+
 # --- Enhanced Training Loop ---
 def train_on_historical_data(env, model, device, args):
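The comments added above `class BacktestEnvironment` say that `get_aligned_candle_with_index` and `get_features_for_tf` must be defined elsewhere; neither appears in this commit. A hedged sketch of what they might look like, assuming candles are time-sorted dicts with a `"timestamp"` plus OHLCV keys and that `FEATURES_PER_CHANNEL` equals five values per candle times a lookback window (the real helpers and feature layout may differ):

```python
# Hypothetical helpers; the repo's real definitions may differ.
# Candles are assumed to be time-sorted dicts with keys
# "timestamp", "open", "high", "low", "close", "volume".

def get_aligned_candle_with_index(candles, base_ts):
    """Return (index, candle) of the last candle at or before base_ts."""
    best_idx = 0
    for i, candle in enumerate(candles):
        if candle["timestamp"] <= base_ts:
            best_idx = i
        else:
            break
    return best_idx, candles[best_idx]

def get_features_for_tf(candles, index, period=10):
    """Flatten the last `period` candles ending at `index` into one vector,
    zero-padded at the front (assumes FEATURES_PER_CHANNEL == 5 * period)."""
    window = candles[max(0, index - period + 1):index + 1]
    features = []
    for candle in window:
        features.extend([candle["open"], candle["high"], candle["low"],
                         candle["close"], candle["volume"]])
    return [0.0] * (5 * period - len(features)) + features
```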
@@ -143,24 +174,24 @@ def train_on_historical_data(env, model, device, args):
         model.train()
 
         while True:
-            # Prepare batch
+            # Prepare batch (here batch size is 1 for simplicity)
             state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
             timeframe_ids = torch.arange(state.shape[0]).to(device)
 
             # Forward pass
             pred_high, pred_low = model(state_tensor, timeframe_ids)
 
-            # Get targets from next candle
+            # Get target values from next candle (dummy targets from environment)
             _, _, next_state, done, actual_high, actual_low = env.step(None)  # Dummy action
             target_high = torch.FloatTensor([actual_high]).to(device)
             target_low = torch.FloatTensor([actual_low]).to(device)
 
-            # Custom loss
+            # Custom loss: use absolute error scaled by 2
             high_loss = torch.abs(pred_high - target_high) * 2
             low_loss = torch.abs(pred_low - target_low) * 2
             loss = (high_loss + low_loss).mean()
 
-            # Backprop
+            # Backpropagation
             optimizer.zero_grad()
             loss.backward()
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
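The "custom loss" in this hunk is a scaled L1 error on each prediction head. The constant factor of 2 doubles the gradient magnitude but does not change where the minimum sits, so it mostly acts like a learning-rate tweak. A standalone restatement with a worked example:

```python
import torch

def high_low_loss(pred_high, pred_low, target_high, target_low, scale=2.0):
    """Scaled L1 over both heads, matching the training loop above."""
    high_loss = torch.abs(pred_high - target_high) * scale
    low_loss = torch.abs(pred_low - target_low) * scale
    return (high_loss + low_loss).mean()

# Errors of 0.5 on the high and 0.25 on the low give (0.5 + 0.25) * 2 = 1.5.
loss = high_low_loss(
    torch.tensor([100.5]), torch.tensor([99.25]),
    torch.tensor([100.0]), torch.tensor([99.0]),
)
print(loss.item())  # 1.5
```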
@@ -176,7 +207,7 @@ def train_on_historical_data(env, model, device, args):
         print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
         save_checkpoint(model, epoch, total_loss)
 
-# --- Mode Handling ---
+# --- Mode Handling and Argument Parsing ---
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
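Only `--mode` and `--threshold` are visible in this diff; the gap between hunks elides whatever other flags `parse_args` defines, and they are not guessed here. A self-contained sketch of how the two visible flags behave:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--mode', choices=['train', 'live', 'inference'], default='train')
parser.add_argument('--threshold', type=float, default=0.005)

# Explicit values are parsed and type-converted...
args = parser.parse_args(['--mode', 'inference', '--threshold', '0.01'])
print(args.mode, args.threshold)  # inference 0.01

# ...and omitted flags fall back to their defaults.
args = parser.parse_args([])
print(args.mode, args.threshold)  # train 0.005
```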
@@ -185,31 +216,50 @@ def parse_args():
     parser.add_argument('--threshold', type=float, default=0.005)
     return parser.parse_args()
 
+def load_best_checkpoint(model, best_dir="models/best"):
+    # Dummy implementation for loading the best checkpoint.
+    # In real usage, check your saved checkpoints.
+    print("Loading best checkpoint (dummy implementation)")
+    # torch.load(...) can be invoked here.
+    return
+
+def save_checkpoint(model, epoch, reward, last_dir="models/last", best_dir="models/best"):
+    # Dummy implementation for saving checkpoints.
+    print(f"Saving checkpoint for epoch {epoch}, reward: {reward:.4f}")
+
+# --- Main Function ---
 async def main():
     args = parse_args()
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     # Initialize model
     model = TradingModel(
-        num_channels=NUM_TIMEFRAMES+NUM_INDICATORS,
+        num_channels=NUM_TIMEFRAMES + NUM_INDICATORS,
         num_timeframes=NUM_TIMEFRAMES
     ).to(device)
 
     if args.mode == 'train':
-        # Initialize environment and train
-        env = BacktestEnvironment(...)
+        # Load historical candle data for backtesting
+        candles_dict = load_candles_cache("candles_cache.json")
+        if not candles_dict:
+            print("No historical candle data available for backtesting.")
+            return
+        base_tf = "1m"  # Base timeframe
+        timeframes = ["1m", "5m", "15m", "1h", "1d"]
+        env = BacktestEnvironment(candles_dict, base_tf, timeframes)
         train_on_historical_data(env, model, device, args)
     elif args.mode == 'live':
         # Load model and connect to live data
         load_best_checkpoint(model)
         while True:
-            # Process live data
-            # Make predictions and execute trades
+            # Process live data: fetch live candles, make predictions, execute trades
+            print("Processing live data...")
             await asyncio.sleep(1)
     elif args.mode == 'inference':
         # Load model and run inference
         load_best_checkpoint(model)
-        # Generate signals without training
+        print("Running inference...")
+        # Add your inference logic here
 
 if __name__ == "__main__":
     asyncio.run(main())
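`main()` now calls `load_candles_cache("candles_cache.json")`, which is not part of this diff. A hypothetical loader consistent with how the result is used downstream, namely a dict mapping timeframe strings such as `"1m"` to time-sorted lists of candle dicts with at least `"timestamp"`, `"high"`, and `"low"` keys (the real cache schema may differ):

```python
import json
import os

def load_candles_cache(path):
    """Hypothetical loader for the candle cache consumed by main()."""
    if not os.path.exists(path):
        return {}  # main() treats an empty dict as "no data"
    with open(path, "r") as f:
        data = json.load(f)
    # BacktestEnvironment indexes candles positionally, so keep each
    # timeframe's list sorted by timestamp.
    return {
        tf: sorted(candles, key=lambda c: c["timestamp"])
        for tf, candles in data.items()
    }
```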