From f220e5fc4d22122e48157ae05c24920c4aecde67 Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Tue, 4 Feb 2025 21:22:26 +0200
Subject: [PATCH] fix training resume

---
 crypto/brian/index-deep-new.py | 51 +++++++++++++++++++---------------
 1 file changed, 29 insertions(+), 22 deletions(-)

diff --git a/crypto/brian/index-deep-new.py b/crypto/brian/index-deep-new.py
index e6423eb..855b6b0 100644
--- a/crypto/brian/index-deep-new.py
+++ b/crypto/brian/index-deep-new.py
@@ -14,7 +14,6 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from collections import deque
 from datetime import datetime
 import matplotlib.pyplot as plt
 import ccxt.async_support as ccxt
@@ -54,7 +53,7 @@ class PositionalEncoding(nn.Module):
 class TradingModel(nn.Module):
     def __init__(self, num_channels, num_timeframes, hidden_dim=128):
         super().__init__()
-        # Create a branch for each channel (each channel input has FEATURES_PER_CHANNEL features)
+        # Create branch for each channel
         self.channel_branches = nn.ModuleList([
             nn.Sequential(
                 nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@@ -63,7 +62,7 @@ class TradingModel(nn.Module):
                 nn.Dropout(0.1)
             ) for _ in range(num_channels)
         ])
-        # IMPORTANT FIX: Use num_channels (total channels) instead of num_timeframes.
+        # Embedding for channels 0..num_channels-1.
         self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
         self.pos_encoder = PositionalEncoding(hidden_dim)
         encoder_layers = TransformerEncoderLayer(
@@ -83,15 +82,15 @@ class TradingModel(nn.Module):
             nn.Linear(hidden_dim // 2, 1)
         )
     def forward(self, x, timeframe_ids):
-        # x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
+        # x: [batch_size, num_channels, FEATURES_PER_CHANNEL]
         batch_size, num_channels, _ = x.shape
         channel_outs = []
         for i in range(num_channels):
             channel_out = self.channel_branches[i](x[:, i, :])
             channel_outs.append(channel_out)
         stacked = torch.stack(channel_outs, dim=1)  # shape: [batch, channels, hidden]
-        stacked = stacked.permute(1, 0, 2)  # shape: [channels, batch, hidden]
-        # Use embedding for each channel (indices 0 to num_channels-1)
+        stacked = stacked.permute(1, 0, 2)  # shape: [channels, batch, hidden]
+        # Add an embedding for each channel.
         tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
         stacked = stacked + tf_embeds
         src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
@@ -169,14 +168,15 @@ def get_best_models(directory):
             continue
     return best_files

-def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
+def save_checkpoint(model, optimizer, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
     last_path = os.path.join(last_dir, last_filename)
     torch.save({
         "epoch": epoch,
         "reward": reward,
-        "model_state_dict": model.state_dict()
+        "model_state_dict": model.state_dict(),
+        "optimizer_state_dict": optimizer.state_dict()
     }, last_path)
     maintain_checkpoint_directory(last_dir, max_files=10)
     best_models = get_best_models(best_dir)
@@ -194,7 +194,8 @@ def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
         torch.save({
             "epoch": epoch,
             "reward": reward,
-            "model_state_dict": model.state_dict()
+            "model_state_dict": model.state_dict(),
+            "optimizer_state_dict": optimizer.state_dict()
         }, best_path)
         maintain_checkpoint_directory(best_dir, max_files=10)
     print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
@@ -252,11 +253,11 @@ class BacktestEnvironment:
         # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
         if self.position is None:
-            if action == 2:  # BUY signal: enter at next candle's open.
+            if action == 2:  # BUY: enter at next candle's open.
                 entry_price = next_candle["open"]
                 self.position = {"entry_price": entry_price, "entry_index": self.current_index}
         else:
-            if action == 0:  # SELL signal: exit at next candle's open.
+            if action == 0:  # SELL: exit at next candle's open.
                 exit_price = next_candle["open"]
                 reward = exit_price - self.position["entry_price"]
                 trade = {
@@ -279,18 +280,16 @@ class BacktestEnvironment:
         return len(self.candles_dict[self.base_tf])

 # --- Enhanced Training Loop ---
-def train_on_historical_data(env, model, device, args, start_epoch=0):
-    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
+def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
     for epoch in range(start_epoch, args.epochs):
         state = env.reset()
         total_loss = 0
         model.train()
         while True:
             state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
-            timeframe_ids = torch.arange(state.shape[0]).to(device)  # shape[0] == num_channels
+            timeframe_ids = torch.arange(state.shape[0]).to(device)
             pred_high, pred_low = model(state_tensor, timeframe_ids)
-            # Dummy targets from next candle's high/low
+            # Get targets from environment (dummy high/low from next candle)
             _, _, next_state, done, actual_high, actual_low = env.step(None)
             target_high = torch.FloatTensor([actual_high]).to(device)
             target_low = torch.FloatTensor([actual_low]).to(device)
@@ -307,7 +306,7 @@ def train_on_historical_data(env, model, device, args, start_epoch=0):
             state = next_state
         scheduler.step()
         print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
-        save_checkpoint(model, epoch, total_loss)
+        save_checkpoint(model, optimizer, epoch, total_loss)

 # --- Live Plotting Functions ---
 def update_live_chart(ax, candles, trade_history):
@@ -354,8 +353,8 @@ def parse_args():
     parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--threshold', type=float, default=0.005)
-    # New flag: if set, we start training from scratch (ignoring any saved checkpoints)
-    parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch ignoring saved checkpoints.')
+    # When set, training starts from scratch (ignoring saved checkpoints)
+    parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch.')
     return parser.parse_args()

 def random_action():
     return random.randint(0, 2)

 # --- Main Function ---
 async def main():
     args = parse_args()
-    # Use GPU if available, else fallback to CPU
+    # Use GPU if available; else CPU
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print("Using device:", device)
     timeframes = ["1m", "5m", "15m", "1h", "1d"]
@@ -381,6 +380,7 @@ async def main():
         env = BacktestEnvironment(candles_dict, base_tf, timeframes)

         start_epoch = 0
+        checkpoint = None
         if not args.start_fresh:
             checkpoint = load_best_checkpoint(model)
             if checkpoint is not None:
@@ -391,7 +391,14 @@ async def main():
         else:
             print("Starting training from scratch as requested.")

-        train_on_historical_data(env, model, device, args, start_epoch=start_epoch)
+        # Create optimizer and scheduler in main
+        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
+        if checkpoint is not None:
+            # Restore optimizer state for a true resume
+            optimizer.load_state_dict(checkpoint.get("optimizer_state_dict", {}))
+        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
+
     elif args.mode == 'live':
         load_best_checkpoint(model)
         candles_dict = load_candles_cache(CACHE_FILE)
@@ -411,7 +418,7 @@ async def main():
     elif args.mode == 'inference':
         load_best_checkpoint(model)
         print("Running inference...")
-        # Place your inference logic here.
+        # Your inference logic goes here.
     else:
         print("Invalid mode specified.")
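Note: the snippet below is an illustrative, self-contained sketch of the resume pattern this patch adopts (optimizer state saved next to the model weights, scheduler horizon computed from the resumed epoch). The checkpoint path, the resume() helper, and the toy model are placeholders for this sketch, not code from the repository.

import torch
import torch.nn as nn
import torch.optim as optim

CKPT_PATH = "checkpoint_last.pt"  # placeholder path, not the repo's LAST_DIR layout
TOTAL_EPOCHS = 100                # mirrors args.epochs in the patch

def save_checkpoint(model, optimizer, epoch, reward, path=CKPT_PATH):
    # Persist optimizer state alongside the weights so AdamW's moment
    # estimates survive a restart.
    torch.save({
        "epoch": epoch,
        "reward": reward,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)

def resume(model, optimizer, path=CKPT_PATH):
    # Returns the epoch to continue from; 0 when no checkpoint exists.
    try:
        ckpt = torch.load(path, map_location="cpu")
    except FileNotFoundError:
        return 0
    model.load_state_dict(ckpt["model_state_dict"])
    # Older checkpoints may lack optimizer state; skip silently in that case.
    if "optimizer_state_dict" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["epoch"] + 1

model = nn.Linear(8, 1)  # stand-in for TradingModel
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
start_epoch = resume(model, optimizer)
# The scheduler is built after resuming so T_max covers only the remaining
# epochs, matching the patch's args.epochs - start_epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_EPOCHS - start_epoch)
save_checkpoint(model, optimizer, start_epoch, reward=0.0)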