wip

parent 070d58f2bf
commit 07047369c9
@@ -239,7 +239,7 @@ class BacktestEnvironment:
 
     def step(self, action):
         base_candles = self.candles_dict[self.base_tf]
-        # Handle end-of-data scenario: return dummy high/low values if needed.
+        # End-of-data: return dummy high/low targets
         if self.current_index >= len(base_candles) - 1:
            current_state = self.get_state(self.current_index)
            return current_state, 0.0, None, True, 0.0, 0.0
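After this hunk, step() always returns a six-tuple (current_state, reward, next_state, done, actual_high, actual_low), with None and dummy 0.0 values once the data is exhausted. A minimal caller sketch, assuming only what the hunk shows (env.reset() and this return shape); run_episode and policy are hypothetical names, not part of the commit:

def run_episode(env, policy=None):
    """Roll through the environment once, collecting next-candle high/low targets."""
    state = env.reset()
    targets = []
    while True:
        action = policy(state) if policy is not None else None
        current_state, reward, next_state, done, actual_high, actual_low = env.step(action)
        if next_state is None:
            # End-of-data guard: reward and the high/low targets are dummy 0.0 values.
            break
        targets.append((actual_high, actual_low))
        if done:
            break
        state = next_state
    return targets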
@@ -247,7 +247,6 @@ class BacktestEnvironment:
         current_state = self.get_state(self.current_index)
         next_index = self.current_index + 1
         next_state = self.get_state(next_index)
         current_candle = base_candles[self.current_index]
         next_candle = base_candles[next_index]
         reward = 0.0
 
@@ -272,24 +271,24 @@ class BacktestEnvironment:
 
         self.current_index = next_index
         done = (self.current_index >= len(base_candles) - 1)
         # Return the high and low of the next candle as the targets.
         actual_high = next_candle["high"]
         actual_low = next_candle["low"]
         return current_state, reward, next_state, done, actual_high, actual_low
 
     def __len__(self):
         return len(self.candles_dict[self.base_tf])
 
 # --- Enhanced Training Loop ---
-def train_on_historical_data(env, model, device, args):
+def train_on_historical_data(env, model, device, args, start_epoch=0):
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
-    for epoch in range(args.epochs):
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
+    for epoch in range(start_epoch, args.epochs):
         state = env.reset()
         total_loss = 0
         model.train()
         while True:
             state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
-            timeframe_ids = torch.arange(state.shape[0]).to(device)  # Expect shape[0]==num_channels
+            timeframe_ids = torch.arange(state.shape[0]).to(device)  # shape[0] == num_channels
             pred_high, pred_low = model(state_tensor, timeframe_ids)
             # Dummy targets from next candle's high/low
             _, _, next_state, done, actual_high, actual_low = env.step(None)
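A note on the scheduler change above: when resuming at start_epoch, the loop only runs args.epochs - start_epoch more epochs, so a freshly constructed CosineAnnealingLR gets that reduced T_max and its cosine decay bottoms out exactly at the final epoch. A standalone sketch of that behaviour (the dummy parameter and the concrete numbers are illustrative only, not from the commit):

import torch

epochs, start_epoch = 100, 40          # illustrative numbers
dummy_param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([dummy_param], lr=3e-4)
# T_max covers only the remaining epochs, so the LR reaches its minimum at the last epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs - start_epoch)
for epoch in range(start_epoch, epochs):
    # ... one training epoch would run here ...
    scheduler.step()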
@@ -355,6 +354,7 @@ def parse_args():
     parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--threshold', type=float, default=0.005)
+    # New flag: if set, we start training from scratch (ignoring any saved checkpoints)
+    parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch ignoring saved checkpoints.')
     return parser.parse_args()
 
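Because --start_fresh uses action='store_true', it defaults to False, so existing command lines keep the resume-from-checkpoint behaviour and passing the flag flips it to True. A quick argparse sketch of just that flag:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--start_fresh', action='store_true',
                    help='Start training from scratch ignoring saved checkpoints.')
print(parser.parse_args([]).start_fresh)                 # False
print(parser.parse_args(['--start_fresh']).start_fresh)  # True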
@@ -364,12 +364,11 @@ def random_action():
 # --- Main Function ---
 async def main():
     args = parse_args()
     # Use GPU if available, else fallback to CPU
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print("Using device:", device)
     timeframes = ["1m", "5m", "15m", "1h", "1d"]
     input_dim = len(timeframes) * 7  # 7 features per timeframe
     hidden_dim = 128
     output_dim = 3  # SELL, HOLD, BUY
     # Set total number of channels = NUM_TIMEFRAMES + NUM_INDICATORS.
     total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
     model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)
 
@@ -380,6 +379,7 @@ async def main():
         return
     base_tf = "1m"
     env = BacktestEnvironment(candles_dict, base_tf, timeframes)
 
+    start_epoch = 0
+    if not args.start_fresh:
         checkpoint = load_best_checkpoint(model)
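The diff does not show how start_epoch is derived from the loaded checkpoint, so the continuation below is an assumption: load_best_checkpoint() is treated as returning a dict with an 'epoch' key (or None when nothing is saved), which then feeds into train_on_historical_data:

start_epoch = 0
if not args.start_fresh:
    checkpoint = load_best_checkpoint(model)      # as in the hunk above
    if checkpoint is not None:
        # Assumed checkpoint layout; the real format is outside this diff.
        start_epoch = checkpoint.get('epoch', 0)
        print(f"Resuming from epoch {start_epoch}.")
else:
    print("Starting fresh; ignoring saved checkpoints.")

train_on_historical_data(env, model, device, args, start_epoch=start_epoch)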
@@ -403,10 +403,10 @@ async def main():
         preview_thread.start()
         print("Starting live trading loop. (Using random actions for simulation.)")
         while True:
-            state, reward, next_state, done = env.step(random_action())
+            state, reward, next_state, done, _, _ = env.step(random_action())
             if done:
                 print("Reached end of simulated data, resetting environment.")
-                state = env.reset(clear_trade_history=False)
+                state = env.reset()
             await asyncio.sleep(1)
     elif args.mode == 'inference':
         load_best_checkpoint(model)
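random_action() itself is outside this diff; given output_dim = 3 (SELL, HOLD, BUY) earlier in main(), a plausible stand-in for the simulated live loop is a uniform pick over the three action indices. A hypothetical sketch only, not the committed implementation:

import random

def random_action():
    # Assumed mapping: 0 = SELL, 1 = HOLD, 2 = BUY.
    return random.randint(0, 2)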