Fix training resume: save optimizer state in checkpoints and restore it on resume

This commit is contained in:
Dobromir Popov 2025-02-04 21:22:26 +02:00
parent 07047369c9
commit f220e5fc4d

View File

@ -14,7 +14,6 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from collections import deque
from datetime import datetime from datetime import datetime
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import ccxt.async_support as ccxt import ccxt.async_support as ccxt
@ -54,7 +53,7 @@ class PositionalEncoding(nn.Module):
class TradingModel(nn.Module): class TradingModel(nn.Module):
def __init__(self, num_channels, num_timeframes, hidden_dim=128): def __init__(self, num_channels, num_timeframes, hidden_dim=128):
super().__init__() super().__init__()
# Create a branch for each channel (each channel input has FEATURES_PER_CHANNEL features) # Create branch for each channel
self.channel_branches = nn.ModuleList([ self.channel_branches = nn.ModuleList([
nn.Sequential( nn.Sequential(
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim), nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@ -63,7 +62,7 @@ class TradingModel(nn.Module):
nn.Dropout(0.1) nn.Dropout(0.1)
) for _ in range(num_channels) ) for _ in range(num_channels)
]) ])
# IMPORTANT FIX: Use num_channels (total channels) instead of num_timeframes. # Embedding for channels 0..num_channels-1.
self.timeframe_embed = nn.Embedding(num_channels, hidden_dim) self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
self.pos_encoder = PositionalEncoding(hidden_dim) self.pos_encoder = PositionalEncoding(hidden_dim)
encoder_layers = TransformerEncoderLayer( encoder_layers = TransformerEncoderLayer(
@ -83,15 +82,15 @@ class TradingModel(nn.Module):
nn.Linear(hidden_dim // 2, 1) nn.Linear(hidden_dim // 2, 1)
) )
def forward(self, x, timeframe_ids): def forward(self, x, timeframe_ids):
# x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL] # x: [batch_size, num_channels, FEATURES_PER_CHANNEL]
batch_size, num_channels, _ = x.shape batch_size, num_channels, _ = x.shape
channel_outs = [] channel_outs = []
for i in range(num_channels): for i in range(num_channels):
channel_out = self.channel_branches[i](x[:, i, :]) channel_out = self.channel_branches[i](x[:, i, :])
channel_outs.append(channel_out) channel_outs.append(channel_out)
stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden] stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden]
stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden] stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden]
# Use embedding for each channel (indices 0 to num_channels-1) # Add an embedding for each channel.
tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1) tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
stacked = stacked + tf_embeds stacked = stacked + tf_embeds
src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device) src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
@ -169,14 +168,15 @@ def get_best_models(directory):
continue continue
return best_files return best_files
def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR): def save_checkpoint(model, optimizer, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt" last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
last_path = os.path.join(last_dir, last_filename) last_path = os.path.join(last_dir, last_filename)
torch.save({ torch.save({
"epoch": epoch, "epoch": epoch,
"reward": reward, "reward": reward,
"model_state_dict": model.state_dict() "model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict()
}, last_path) }, last_path)
maintain_checkpoint_directory(last_dir, max_files=10) maintain_checkpoint_directory(last_dir, max_files=10)
best_models = get_best_models(best_dir) best_models = get_best_models(best_dir)
@ -194,7 +194,8 @@ def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
torch.save({ torch.save({
"epoch": epoch, "epoch": epoch,
"reward": reward, "reward": reward,
"model_state_dict": model.state_dict() "model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict()
}, best_path) }, best_path)
maintain_checkpoint_directory(best_dir, max_files=10) maintain_checkpoint_directory(best_dir, max_files=10)
print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}") print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
@ -252,11 +253,11 @@ class BacktestEnvironment:
# Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY. # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
if self.position is None: if self.position is None:
if action == 2: # BUY signal: enter at next candle's open. if action == 2: # BUY: enter at next candle's open.
entry_price = next_candle["open"] entry_price = next_candle["open"]
self.position = {"entry_price": entry_price, "entry_index": self.current_index} self.position = {"entry_price": entry_price, "entry_index": self.current_index}
else: else:
if action == 0: # SELL signal: exit at next candle's open. if action == 0: # SELL: exit at next candle's open.
exit_price = next_candle["open"] exit_price = next_candle["open"]
reward = exit_price - self.position["entry_price"] reward = exit_price - self.position["entry_price"]
trade = { trade = {
@ -279,18 +280,16 @@ class BacktestEnvironment:
return len(self.candles_dict[self.base_tf]) return len(self.candles_dict[self.base_tf])
# --- Enhanced Training Loop --- # --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch=0): def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):
state = env.reset() state = env.reset()
total_loss = 0 total_loss = 0
model.train() model.train()
while True: while True:
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device) state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device) # shape[0] == num_channels timeframe_ids = torch.arange(state.shape[0]).to(device)
pred_high, pred_low = model(state_tensor, timeframe_ids) pred_high, pred_low = model(state_tensor, timeframe_ids)
# Dummy targets from next candle's high/low # Get targets from environment (dummy high/low from next candle)
_, _, next_state, done, actual_high, actual_low = env.step(None) _, _, next_state, done, actual_high, actual_low = env.step(None)
target_high = torch.FloatTensor([actual_high]).to(device) target_high = torch.FloatTensor([actual_high]).to(device)
target_low = torch.FloatTensor([actual_low]).to(device) target_low = torch.FloatTensor([actual_low]).to(device)
@ -307,7 +306,7 @@ def train_on_historical_data(env, model, device, args, start_epoch=0):
state = next_state state = next_state
scheduler.step() scheduler.step()
print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}") print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
save_checkpoint(model, epoch, total_loss) save_checkpoint(model, optimizer, epoch, total_loss)
# --- Live Plotting Functions --- # --- Live Plotting Functions ---
def update_live_chart(ax, candles, trade_history): def update_live_chart(ax, candles, trade_history):
@ -354,8 +353,8 @@ def parse_args():
parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--threshold', type=float, default=0.005) parser.add_argument('--threshold', type=float, default=0.005)
# New flag: if set, we start training from scratch (ignoring any saved checkpoints) # When set, training starts from scratch (ignoring saved checkpoints)
parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch ignoring saved checkpoints.') parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch.')
return parser.parse_args() return parser.parse_args()
def random_action(): def random_action():
@ -364,7 +363,7 @@ def random_action():
# --- Main Function --- # --- Main Function ---
async def main(): async def main():
args = parse_args() args = parse_args()
# Use GPU if available, else fallback to CPU # Use GPU if available; else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device) print("Using device:", device)
timeframes = ["1m", "5m", "15m", "1h", "1d"] timeframes = ["1m", "5m", "15m", "1h", "1d"]
@ -381,6 +380,7 @@ async def main():
env = BacktestEnvironment(candles_dict, base_tf, timeframes) env = BacktestEnvironment(candles_dict, base_tf, timeframes)
start_epoch = 0 start_epoch = 0
checkpoint = None
if not args.start_fresh: if not args.start_fresh:
checkpoint = load_best_checkpoint(model) checkpoint = load_best_checkpoint(model)
if checkpoint is not None: if checkpoint is not None:
@ -391,7 +391,14 @@ async def main():
else: else:
print("Starting training from scratch as requested.") print("Starting training from scratch as requested.")
train_on_historical_data(env, model, device, args, start_epoch=start_epoch) # Create optimizer and scheduler in main
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
if checkpoint is not None:
# Restore optimizer state for a true resume
optimizer.load_state_dict(checkpoint.get("optimizer_state_dict", {}))
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
elif args.mode == 'live': elif args.mode == 'live':
load_best_checkpoint(model) load_best_checkpoint(model)
candles_dict = load_candles_cache(CACHE_FILE) candles_dict = load_candles_cache(CACHE_FILE)
@ -411,7 +418,7 @@ async def main():
elif args.mode == 'inference': elif args.mode == 'inference':
load_best_checkpoint(model) load_best_checkpoint(model)
print("Running inference...") print("Running inference...")
# Place your inference logic here. # Your inference logic goes here.
else: else:
print("Invalid mode specified.") print("Invalid mode specified.")