wip

Parent: 070d58f2bf
Commit: 07047369c9

@@ -239,7 +239,7 @@ class BacktestEnvironment:
     def step(self, action):
         base_candles = self.candles_dict[self.base_tf]
-        # Handle end-of-data scenario: return dummy high/low values if needed.
+        # End-of-data: return dummy high/low targets
         if self.current_index >= len(base_candles) - 1:
             current_state = self.get_state(self.current_index)
             return current_state, 0.0, None, True, 0.0, 0.0

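Note on the interface above: after this change, step() always returns a six-tuple, with None for next_state and dummy 0.0 high/low targets at end-of-data. A minimal sketch of the contract as a caller sees it (the env instance is assumed to come from the BacktestEnvironment constructor shown later in this diff):

    # Sketch of the step() return contract after this commit (illustrative, not part of the diff):
    state, reward, next_state, done, actual_high, actual_low = env.step(None)
    if done and next_state is None:
        # End-of-data path: the high/low targets are dummy 0.0 placeholders
        print("backtest exhausted; dummy targets:", actual_high, actual_low)
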
@@ -247,7 +247,6 @@ class BacktestEnvironment:
         current_state = self.get_state(self.current_index)
         next_index = self.current_index + 1
         next_state = self.get_state(next_index)
-        current_candle = base_candles[self.current_index]
         next_candle = base_candles[next_index]
         reward = 0.0

@@ -272,31 +271,31 @@ class BacktestEnvironment:

         self.current_index = next_index
         done = (self.current_index >= len(base_candles) - 1)
-        # Return the high and low of the next candle as the targets.
         actual_high = next_candle["high"]
         actual_low = next_candle["low"]
         return current_state, reward, next_state, done, actual_high, actual_low

     def __len__(self):
         return len(self.candles_dict[self.base_tf])

 # --- Enhanced Training Loop ---
-def train_on_historical_data(env, model, device, args):
+def train_on_historical_data(env, model, device, args, start_epoch=0):
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
-    for epoch in range(args.epochs):
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
+    for epoch in range(start_epoch, args.epochs):
         state = env.reset()
         total_loss = 0
         model.train()
         while True:
             state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
-            timeframe_ids = torch.arange(state.shape[0]).to(device)  # Expect shape[0]==num_channels
+            timeframe_ids = torch.arange(state.shape[0]).to(device)  # shape[0] == num_channels
             pred_high, pred_low = model(state_tensor, timeframe_ids)
             # Dummy targets from next candle's high/low
             _, _, next_state, done, actual_high, actual_low = env.step(None)
             target_high = torch.FloatTensor([actual_high]).to(device)
             target_low = torch.FloatTensor([actual_low]).to(device)
             high_loss = torch.abs(pred_high - target_high) * 2
             low_loss = torch.abs(pred_low - target_low) * 2
             loss = (high_loss + low_loss).mean()
             optimizer.zero_grad()
             loss.backward()

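As a quick sanity check on the weighted L1 loss in the loop above, a worked example with made-up scalar predictions (the numbers are illustrative only, not from the commit):

    import torch

    pred_high, target_high = torch.tensor([101.0]), torch.tensor([100.0])
    pred_low, target_low = torch.tensor([98.5]), torch.tensor([99.0])
    high_loss = torch.abs(pred_high - target_high) * 2   # |101 - 100| * 2 = 2.0
    low_loss = torch.abs(pred_low - target_low) * 2      # |98.5 - 99| * 2 = 1.0
    loss = (high_loss + low_loss).mean()                 # (2.0 + 1.0) / 1 = 3.0
    print(loss.item())                                   # 3.0
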
@@ -355,6 +354,7 @@ def parse_args():
     parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--threshold', type=float, default=0.005)
+    # New flag: if set, we start training from scratch (ignoring any saved checkpoints)
     parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch ignoring saved checkpoints.')
     return parser.parse_args()

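Example invocation with the flag; the entry-point filename and the 'train' mode value are assumptions, since neither appears in this diff:

    # Hypothetical usage (script name and --mode value assumed):
    #   python main.py --mode train --epochs 100 --lr 3e-4 --start_fresh
    # Omitting --start_fresh lets main() resume from the best saved checkpoint.
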
@@ -364,12 +364,11 @@ def random_action():
 # --- Main Function ---
 async def main():
     args = parse_args()
+    # Use GPU if available, else fall back to CPU
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print("Using device:", device)
     timeframes = ["1m", "5m", "15m", "1h", "1d"]
-    input_dim = len(timeframes) * 7  # 7 features per timeframe
     hidden_dim = 128
-    output_dim = 3  # SELL, HOLD, BUY
-    # Set total number of channels = NUM_TIMEFRAMES + NUM_INDICATORS.
     total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
     model = TradingModel(total_channels, NUM_TIMEFRAMES).to(device)

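For orientation, the channel bookkeeping that replaces input_dim/output_dim can be sketched as below. Only the names NUM_TIMEFRAMES and NUM_INDICATORS appear in the diff; the indicator count and the per-channel feature width are assumptions:

    # Assumed channel layout (constant values illustrative, defined elsewhere in the file):
    NUM_TIMEFRAMES = 5    # matches len(["1m", "5m", "15m", "1h", "1d"])
    NUM_INDICATORS = 2    # assumption; actual value not visible in this diff
    total_channels = NUM_TIMEFRAMES + NUM_INDICATORS
    # get_state() is then expected to return an array of shape (total_channels, features),
    # so torch.arange(state.shape[0]) in the training loop yields one id per channel.
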
@@ -380,6 +379,7 @@ async def main():
         return
     base_tf = "1m"
     env = BacktestEnvironment(candles_dict, base_tf, timeframes)

+    start_epoch = 0
     if not args.start_fresh:
         checkpoint = load_best_checkpoint(model)

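The hunk above cuts off before start_epoch is consumed; a hedged sketch of the likely continuation, assuming load_best_checkpoint returns a dict exposing the last finished epoch (the key name is a guess):

    # Hypothetical continuation (checkpoint schema assumed, not shown in the diff):
    start_epoch = 0
    if not args.start_fresh:
        checkpoint = load_best_checkpoint(model)
        if checkpoint is not None:
            start_epoch = checkpoint.get("epoch", 0)   # 'epoch' key is an assumption
    train_on_historical_data(env, model, device, args, start_epoch=start_epoch)
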
@@ -403,10 +403,10 @@ async def main():
         preview_thread.start()
         print("Starting live trading loop. (Using random actions for simulation.)")
         while True:
-            state, reward, next_state, done = env.step(random_action())
+            state, reward, next_state, done, _, _ = env.step(random_action())
             if done:
                 print("Reached end of simulated data, resetting environment.")
-                state = env.reset(clear_trade_history=False)
+                state = env.reset()
             await asyncio.sleep(1)
     elif args.mode == 'inference':
         load_best_checkpoint(model)