# training/train_historical.py
import os
import json
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from model.trading_model import TradingModel
from data.data_utils import get_aligned_candle_with_index, get_features_for_tf

# --- Directories for saving models ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)

# --- File for saving candles cache ---
CACHE_FILE = "candles_cache.json"


# -------------------------------------
# Checkpoint Functions (same as before)
# -------------------------------------
def maintain_checkpoint_directory(directory, max_files=10):
    """Keep at most `max_files` checkpoints in `directory`, removing the oldest first."""
    files = os.listdir(directory)
    if len(files) > max_files:
        full_paths = [os.path.join(directory, f) for f in files]
        full_paths.sort(key=lambda x: os.path.getmtime(x))
        for f in full_paths[: len(files) - max_files]:
            os.remove(f)


def get_best_models(directory):
    """Return (loss, filename) tuples parsed from best-checkpoint filenames."""
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            loss = float(parts[1])
            best_files.append((loss, file))
        except (IndexError, ValueError):
            continue
    return best_files


def save_checkpoint(model, epoch, total_loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    last_path = os.path.join(last_dir, last_filename)
    torch.save({
        "epoch": epoch,
        "total_loss": total_loss,
        "model_state_dict": model.state_dict()
    }, last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)

    # Keep the 10 checkpoints with the lowest loss in the "best" directory.
    best_models = get_best_models(best_dir)
    add_to_best = False
    if len(best_models) < 10:
        add_to_best = True
    else:
        # Compare against the worst (highest-loss) checkpoint currently kept;
        # replace it if the new loss is lower.
        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
        if total_loss < worst_loss:
            add_to_best = True
            os.remove(os.path.join(best_dir, worst_file))
    if add_to_best:
        best_filename = f"best_{total_loss:.4f}_epoch_{epoch}_{timestamp}.pt"
        best_path = os.path.join(best_dir, best_filename)
        torch.save({
            "epoch": epoch,
            "total_loss": total_loss,
            "model_state_dict": model.state_dict()
        }, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with loss {total_loss:.4f}")


def load_best_checkpoint(model, best_dir=BEST_DIR):
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    # The best checkpoint is the one with the lowest loss.
    best_loss, best_file = min(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint


# -------------------------------------
# Training Loop on Historical Data
# -------------------------------------
def train_on_historical_data(env, rl_agent, num_epochs=10, epsilon=0.1):
    """
    Train the RL agent on historical data using the backtest environment.
    """
    model = rl_agent.model
    optimizer = rl_agent.optimizer
    replay_buffer = rl_agent.replay_buffer
    batch_size = rl_agent.batch_size
    gamma = rl_agent.gamma

    model.train()
    criterion = nn.MSELoss()  # or another suitable loss

    for epoch in range(num_epochs):
        state = env.reset()
        done = False
        total_reward = 0
        total_loss = 0

        while not done:
            # Agent takes an action (with exploration).
            action = rl_agent.act(state, epsilon=epsilon)
            current_state, reward, next_state, done = env.step(action)
            total_reward += reward

            # Store the experience in the replay buffer.
            replay_buffer.push(state, action, reward, next_state, done)

            # Train on a batch from the replay buffer.
            if len(replay_buffer) > batch_size:
                # Sample a batch from the replay buffer.
                states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

                # Convert the batch to PyTorch tensors.
                states = torch.tensor(states, dtype=torch.float32)
                actions = torch.tensor(actions, dtype=torch.long)
                rewards = torch.tensor(rewards, dtype=torch.float32)
                next_states = torch.tensor(next_states, dtype=torch.float32)
                dones = torch.tensor(dones, dtype=torch.float32)

                # Q-values for the current states.
                q_values = model(states)
                # Q-values for the next states (no gradient through the TD target).
                with torch.no_grad():
                    next_q_values = model(next_states)

                # Compute the TD target.
                td_target = rewards + gamma * torch.max(next_q_values, dim=1)[0] * (1 - dones)

                # Loss between the Q-values of the taken actions and the TD target.
                loss = criterion(q_values.gather(1, actions.unsqueeze(1)).squeeze(1), td_target)

                # Optimize the model.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            # Move to the next state.
            state = next_state

        print(f"Epoch {epoch + 1}/{num_epochs}, Total Reward: {total_reward:.4f}, Loss: {total_loss:.4f}")
        save_checkpoint(model, epoch, total_loss, LAST_DIR, BEST_DIR)


# -------------------------------------
# Caching Functions (for candles data)
# -------------------------------------
def save_candles_cache(filename, candles_dict):
    """
    Save the candles data to a JSON file.
    """
    # Copy each candle into a plain dict of JSON-serializable values
    # (cast numeric fields to float in case they are numpy scalars).
    serializable_candles_dict = {}
    for timeframe, candles in candles_dict.items():
        serializable_candles = []
        for candle in candles:
            serializable_candle = {
                'timestamp': candle['timestamp'],
                'open': float(candle['open']),
                'high': float(candle['high']),
                'low': float(candle['low']),
                'close': float(candle['close']),
                'volume': float(candle['volume'])
            }
            serializable_candles.append(serializable_candle)
        serializable_candles_dict[timeframe] = serializable_candles

    with open(filename, 'w') as f:
        json.dump(serializable_candles_dict, f)


def load_candles_cache(filename):
    """
    Load the candles data from a JSON file.
    """
    with open(filename, 'r') as f:
        candles_dict = json.load(f)
    # Convert price/volume fields back to floats.
    for timeframe, candles in candles_dict.items():
        for candle in candles:
            candle['open'] = float(candle['open'])
            candle['high'] = float(candle['high'])
            candle['low'] = float(candle['low'])
            candle['close'] = float(candle['close'])
            candle['volume'] = float(candle['volume'])
    return candles_dict
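

# -------------------------------------
# Example usage (illustrative sketch)
# -------------------------------------
# A minimal sketch of how this training loop might be driven. The
# `BacktestEnvironment` and `RLAgent` names below are assumptions for
# illustration only; substitute the project's actual environment and agent
# classes, which must expose reset()/step() and the attributes used by
# train_on_historical_data (model, optimizer, replay_buffer, batch_size,
# gamma, act()).
#
# if __name__ == "__main__":
#     candles = load_candles_cache(CACHE_FILE)            # cached OHLCV data
#     env = BacktestEnvironment(candles)                  # hypothetical env class
#     agent = RLAgent(state_dim=..., action_dim=...)      # hypothetical agent class
#     load_best_checkpoint(agent.model)                   # optionally resume from the best checkpoint
#     train_on_historical_data(env, agent, num_epochs=10, epsilon=0.1)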