fix training resume

This commit is contained in:
Dobromir Popov 2025-02-04 21:22:26 +02:00
parent 07047369c9
commit f220e5fc4d

View File

@@ -14,7 +14,6 @@ import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from datetime import datetime
import matplotlib.pyplot as plt
import ccxt.async_support as ccxt
@@ -54,7 +53,7 @@ class PositionalEncoding(nn.Module):
class TradingModel(nn.Module):
def __init__(self, num_channels, num_timeframes, hidden_dim=128):
super().__init__()
# Create a branch for each channel (each channel input has FEATURES_PER_CHANNEL features)
# Create branch for each channel
self.channel_branches = nn.ModuleList([
nn.Sequential(
nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@@ -63,7 +62,7 @@ class TradingModel(nn.Module):
nn.Dropout(0.1)
) for _ in range(num_channels)
])
# IMPORTANT FIX: Use num_channels (total channels) instead of num_timeframes.
# Embedding for channels 0..num_channels-1.
self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
self.pos_encoder = PositionalEncoding(hidden_dim)
encoder_layers = TransformerEncoderLayer(
@@ -83,15 +82,15 @@ class TradingModel(nn.Module):
nn.Linear(hidden_dim // 2, 1)
)
def forward(self, x, timeframe_ids):
# x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
# x: [batch_size, num_channels, FEATURES_PER_CHANNEL]
batch_size, num_channels, _ = x.shape
channel_outs = []
for i in range(num_channels):
channel_out = self.channel_branches[i](x[:, i, :])
channel_outs.append(channel_out)
stacked = torch.stack(channel_outs, dim=1) # shape: [batch, channels, hidden]
stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden]
# Use embedding for each channel (indices 0 to num_channels-1)
stacked = stacked.permute(1, 0, 2) # shape: [channels, batch, hidden]
# Add an embedding for each channel.
tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
stacked = stacked + tf_embeds
src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
@@ -169,14 +168,15 @@ def get_best_models(directory):
continue
return best_files
def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
def save_checkpoint(model, optimizer, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
last_path = os.path.join(last_dir, last_filename)
torch.save({
"epoch": epoch,
"reward": reward,
"model_state_dict": model.state_dict()
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict()
}, last_path)
maintain_checkpoint_directory(last_dir, max_files=10)
best_models = get_best_models(best_dir)
@@ -194,7 +194,8 @@ def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
torch.save({
"epoch": epoch,
"reward": reward,
"model_state_dict": model.state_dict()
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict()
}, best_path)
maintain_checkpoint_directory(best_dir, max_files=10)
print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
@@ -252,11 +253,11 @@ class BacktestEnvironment:
# Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
if self.position is None:
if action == 2: # BUY signal: enter at next candle's open.
if action == 2: # BUY: enter at next candle's open.
entry_price = next_candle["open"]
self.position = {"entry_price": entry_price, "entry_index": self.current_index}
else:
if action == 0: # SELL signal: exit at next candle's open.
if action == 0: # SELL: exit at next candle's open.
exit_price = next_candle["open"]
reward = exit_price - self.position["entry_price"]
trade = {
@@ -279,18 +280,16 @@ class BacktestEnvironment:
return len(self.candles_dict[self.base_tf])
# --- Enhanced Training Loop ---
def train_on_historical_data(env, model, device, args, start_epoch=0):
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
for epoch in range(start_epoch, args.epochs):
state = env.reset()
total_loss = 0
model.train()
while True:
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device) # shape[0] == num_channels
timeframe_ids = torch.arange(state.shape[0]).to(device)
pred_high, pred_low = model(state_tensor, timeframe_ids)
# Dummy targets from next candle's high/low
# Get targets from environment (dummy high/low from next candle)
_, _, next_state, done, actual_high, actual_low = env.step(None)
target_high = torch.FloatTensor([actual_high]).to(device)
target_low = torch.FloatTensor([actual_low]).to(device)
@@ -307,7 +306,7 @@ def train_on_historical_data(env, model, device, args, start_epoch=0):
state = next_state
scheduler.step()
print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
save_checkpoint(model, epoch, total_loss)
save_checkpoint(model, optimizer, epoch, total_loss)
# --- Live Plotting Functions ---
def update_live_chart(ax, candles, trade_history):
@@ -354,8 +353,8 @@ def parse_args():
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--threshold', type=float, default=0.005)
# New flag: if set, we start training from scratch (ignoring any saved checkpoints)
parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch ignoring saved checkpoints.')
# When set, training starts from scratch (ignoring saved checkpoints)
parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch.')
return parser.parse_args()
def random_action():
@@ -364,7 +363,7 @@ def random_action():
# --- Main Function ---
async def main():
args = parse_args()
# Use GPU if available, else fallback to CPU
# Use GPU if available; else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
timeframes = ["1m", "5m", "15m", "1h", "1d"]
@@ -381,6 +380,7 @@ async def main():
env = BacktestEnvironment(candles_dict, base_tf, timeframes)
start_epoch = 0
checkpoint = None
if not args.start_fresh:
checkpoint = load_best_checkpoint(model)
if checkpoint is not None:
@@ -391,7 +391,14 @@ async def main():
else:
print("Starting training from scratch as requested.")
train_on_historical_data(env, model, device, args, start_epoch=start_epoch)
# Create optimizer and scheduler in main
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
if checkpoint is not None:
# Restore optimizer state for a true resume
optimizer.load_state_dict(checkpoint.get("optimizer_state_dict", {}))
train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
elif args.mode == 'live':
load_best_checkpoint(model)
candles_dict = load_candles_cache(CACHE_FILE)
@@ -411,7 +418,7 @@ async def main():
elif args.mode == 'inference':
load_best_checkpoint(model)
print("Running inference...")
# Place your inference logic here.
# Your inference logic goes here.
else:
print("Invalid mode specified.")