fix training resume
commit f220e5fc4d (parent 07047369c9)
@@ -14,7 +14,6 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from collections import deque
 from datetime import datetime
 import matplotlib.pyplot as plt
 import ccxt.async_support as ccxt
@@ -54,7 +53,7 @@ class PositionalEncoding(nn.Module):
 class TradingModel(nn.Module):
     def __init__(self, num_channels, num_timeframes, hidden_dim=128):
         super().__init__()
-        # Create a branch for each channel (each channel input has FEATURES_PER_CHANNEL features)
+        # Create branch for each channel
         self.channel_branches = nn.ModuleList([
             nn.Sequential(
                 nn.Linear(FEATURES_PER_CHANNEL, hidden_dim),
@@ -63,7 +62,7 @@ class TradingModel(nn.Module):
                 nn.Dropout(0.1)
             ) for _ in range(num_channels)
         ])
-        # IMPORTANT FIX: Use num_channels (total channels) instead of num_timeframes.
+        # Embedding for channels 0..num_channels-1.
         self.timeframe_embed = nn.Embedding(num_channels, hidden_dim)
         self.pos_encoder = PositionalEncoding(hidden_dim)
         encoder_layers = TransformerEncoderLayer(
@@ -83,15 +82,15 @@ class TradingModel(nn.Module):
             nn.Linear(hidden_dim // 2, 1)
         )
     def forward(self, x, timeframe_ids):
-        # x shape: [batch_size, num_channels, FEATURES_PER_CHANNEL]
+        # x: [batch_size, num_channels, FEATURES_PER_CHANNEL]
         batch_size, num_channels, _ = x.shape
         channel_outs = []
         for i in range(num_channels):
             channel_out = self.channel_branches[i](x[:, i, :])
             channel_outs.append(channel_out)
         stacked = torch.stack(channel_outs, dim=1)  # shape: [batch, channels, hidden]
         stacked = stacked.permute(1, 0, 2)  # shape: [channels, batch, hidden]
-        # Use embedding for each channel (indices 0 to num_channels-1)
+        # Add an embedding for each channel.
         tf_embeds = self.timeframe_embed(timeframe_ids).unsqueeze(1)
         stacked = stacked + tf_embeds
         src_mask = torch.triu(torch.ones(stacked.size(0), stacked.size(0)), diagonal=1).bool().to(x.device)
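
Note (editor): the forward pass above runs each channel through its own branch, stacks the results, and adds a per-channel embedding before the transformer encoder. A minimal shape-check sketch, assuming the TradingModel and FEATURES_PER_CHANNEL defined in this file; batch and channel counts are arbitrary:

    num_channels, batch_size = 5, 4
    model = TradingModel(num_channels=num_channels, num_timeframes=num_channels)
    x = torch.randn(batch_size, num_channels, FEATURES_PER_CHANNEL)  # [batch, channels, features]
    timeframe_ids = torch.arange(num_channels)                       # one embedding index per channel
    pred_high, pred_low = model(x, timeframe_ids)                    # one high/low prediction per sample
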
@@ -169,14 +168,15 @@ def get_best_models(directory):
             continue
     return best_files
 
-def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
+def save_checkpoint(model, optimizer, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
     last_path = os.path.join(last_dir, last_filename)
     torch.save({
         "epoch": epoch,
         "reward": reward,
-        "model_state_dict": model.state_dict()
+        "model_state_dict": model.state_dict(),
+        "optimizer_state_dict": optimizer.state_dict()
     }, last_path)
     maintain_checkpoint_directory(last_dir, max_files=10)
     best_models = get_best_models(best_dir)
@@ -194,7 +194,8 @@ def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
         torch.save({
             "epoch": epoch,
             "reward": reward,
-            "model_state_dict": model.state_dict()
+            "model_state_dict": model.state_dict(),
+            "optimizer_state_dict": optimizer.state_dict()
         }, best_path)
         maintain_checkpoint_directory(best_dir, max_files=10)
     print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
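
Note (editor): with the optimizer state now stored next to the model state, resuming is a matter of loading both dictionaries back. A minimal sketch, assuming a checkpoint written by save_checkpoint above and the model/optimizer/device objects from this file; checkpoint_path is a placeholder:

    checkpoint = torch.load(checkpoint_path, map_location=device)  # checkpoint_path: placeholder
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1                          # continue after the saved epoch
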
@@ -252,11 +253,11 @@ class BacktestEnvironment:
 
         # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
         if self.position is None:
-            if action == 2:  # BUY signal: enter at next candle's open.
+            if action == 2:  # BUY: enter at next candle's open.
                 entry_price = next_candle["open"]
                 self.position = {"entry_price": entry_price, "entry_index": self.current_index}
         else:
-            if action == 0:  # SELL signal: exit at next candle's open.
+            if action == 0:  # SELL: exit at next candle's open.
                 exit_price = next_candle["open"]
                 reward = exit_price - self.position["entry_price"]
                 trade = {
@@ -279,18 +280,16 @@ class BacktestEnvironment:
         return len(self.candles_dict[self.base_tf])
 
 # --- Enhanced Training Loop ---
-def train_on_historical_data(env, model, device, args, start_epoch=0):
-    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
+def train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler):
     for epoch in range(start_epoch, args.epochs):
         state = env.reset()
         total_loss = 0
         model.train()
         while True:
             state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
-            timeframe_ids = torch.arange(state.shape[0]).to(device)  # shape[0] == num_channels
+            timeframe_ids = torch.arange(state.shape[0]).to(device)
             pred_high, pred_low = model(state_tensor, timeframe_ids)
-            # Dummy targets from next candle's high/low
+            # Get targets from environment (dummy high/low from next candle)
             _, _, next_state, done, actual_high, actual_low = env.step(None)
             target_high = torch.FloatTensor([actual_high]).to(device)
             target_low = torch.FloatTensor([actual_low]).to(device)
@@ -307,7 +306,7 @@ def train_on_historical_data(env, model, device, args, start_epoch=0):
             state = next_state
         scheduler.step()
         print(f"Epoch {epoch+1} Loss: {total_loss/len(env):.4f}")
-        save_checkpoint(model, epoch, total_loss)
+        save_checkpoint(model, optimizer, epoch, total_loss)
 
 # --- Live Plotting Functions ---
 def update_live_chart(ax, candles, trade_history):
@@ -354,8 +353,8 @@ def parse_args():
     parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--threshold', type=float, default=0.005)
-    # New flag: if set, we start training from scratch (ignoring any saved checkpoints)
-    parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch ignoring saved checkpoints.')
+    # When set, training starts from scratch (ignoring saved checkpoints)
+    parser.add_argument('--start_fresh', action='store_true', help='Start training from scratch.')
     return parser.parse_args()
 
 def random_action():
@@ -364,7 +363,7 @@ def random_action():
 # --- Main Function ---
 async def main():
     args = parse_args()
-    # Use GPU if available, else fallback to CPU
+    # Use GPU if available; else CPU
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print("Using device:", device)
     timeframes = ["1m", "5m", "15m", "1h", "1d"]
@@ -381,6 +380,7 @@ async def main():
         env = BacktestEnvironment(candles_dict, base_tf, timeframes)
 
         start_epoch = 0
+        checkpoint = None
         if not args.start_fresh:
             checkpoint = load_best_checkpoint(model)
             if checkpoint is not None:
@@ -391,7 +391,14 @@ async def main():
             else:
                 print("Starting training from scratch as requested.")
 
-        train_on_historical_data(env, model, device, args, start_epoch=start_epoch)
+        # Create optimizer and scheduler in main
+        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-5)
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs - start_epoch)
+        if checkpoint is not None:
+            # Restore optimizer state for a true resume
+            optimizer.load_state_dict(checkpoint.get("optimizer_state_dict", {}))
+        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler)
+
     elif args.mode == 'live':
         load_best_checkpoint(model)
         candles_dict = load_candles_cache(CACHE_FILE)
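
Note (editor): this commit restores the optimizer but re-creates the scheduler for the remaining epochs, so the cosine learning-rate schedule starts again from args.lr on resume. If an exact schedule resume is ever wanted, the scheduler state could be persisted the same way. A hypothetical extension (not part of this commit), assuming the epoch/reward/model/optimizer/scheduler/last_path/checkpoint names from the code above:

    payload = {
        "epoch": epoch,
        "reward": reward,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),  # extra key, not written by this commit
    }
    torch.save(payload, last_path)
    # ...and on resume:
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
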
@@ -411,7 +418,7 @@ async def main():
     elif args.mode == 'inference':
         load_best_checkpoint(model)
         print("Running inference...")
-        # Place your inference logic here.
+        # Your inference logic goes here.
     else:
         print("Invalid mode specified.")
 