implementations
crypto/gogo/training/train_historical.py (new file, 194 lines)

@@ -0,0 +1,194 @@
# training/train_historical.py
import os
import json
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from model.trading_model import TradingModel
from data.data_utils import get_aligned_candle_with_index, get_features_for_tf

# --- Directories for saving models ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)

# --- File for saving candles cache ---
CACHE_FILE = "candles_cache.json"


# -------------------------------------
# Checkpoint Functions (same as before)
# -------------------------------------
def maintain_checkpoint_directory(directory, max_files=10):
    """Keep at most `max_files` checkpoint files in `directory`, deleting the oldest first."""
    files = os.listdir(directory)
    if len(files) > max_files:
        full_paths = [os.path.join(directory, f) for f in files]
        full_paths.sort(key=lambda x: os.path.getmtime(x))
        for f in full_paths[: len(files) - max_files]:
            os.remove(f)


def get_best_models(directory):
    """Return (loss, filename) pairs parsed from names like best_<loss>_epoch_<n>_<timestamp>.pt."""
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            r = float(parts[1])
            best_files.append((r, file))
        except (IndexError, ValueError):
            continue
    return best_files


def save_checkpoint(model, epoch, total_loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Always save a 'last' checkpoint; also save a 'best' one if the loss ranks in the top 10."""
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    last_path = os.path.join(last_dir, last_filename)
    torch.save({
        "epoch": epoch,
        "total_loss": total_loss,
        "model_state_dict": model.state_dict()
    }, last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)

    best_models = get_best_models(best_dir)
    add_to_best = False
    if len(best_models) < 10:
        add_to_best = True
    else:
        # Compare against the worst (highest-loss) of the current best checkpoints
        # and replace it if the new loss is lower.
        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
        if total_loss < worst_loss:
            add_to_best = True
            os.remove(os.path.join(best_dir, worst_file))
    if add_to_best:
        best_filename = f"best_{total_loss:.4f}_epoch_{epoch}_{timestamp}.pt"
        best_path = os.path.join(best_dir, best_filename)
        torch.save({
            "epoch": epoch,
            "total_loss": total_loss,
            "model_state_dict": model.state_dict()
        }, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with loss {total_loss:.4f}")


def load_best_checkpoint(model, best_dir=BEST_DIR):
    """Load the checkpoint with the lowest recorded loss into `model`; return the checkpoint dict or None."""
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_loss, best_file = min(best_models, key=lambda x: x[0])  # lowest loss = best model
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint
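

# Minimal usage sketch for the checkpoint helpers above. The tiny nn.Linear is a
# stand-in model for illustration only; in practice any nn.Module whose state dict
# matches the saved checkpoints (e.g. TradingModel) would be used. Not called anywhere.
def _checkpoint_usage_example():
    demo_model = nn.Linear(4, 2)                   # placeholder model for the sketch
    checkpoint = load_best_checkpoint(demo_model)  # returns None when models/best is empty
    start_epoch = checkpoint["epoch"] + 1 if checkpoint else 0
    # Pretend one epoch of training produced this loss, then checkpoint it.
    save_checkpoint(demo_model, start_epoch, total_loss=0.1234)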


# -------------------------------------
# Training Loop on Historical Data
# -------------------------------------
def train_on_historical_data(env, rl_agent, num_epochs=10, epsilon=0.1):
    """
    Train the RL agent on historical data using the backtest environment.
    """
    model = rl_agent.model
    optimizer = rl_agent.optimizer
    replay_buffer = rl_agent.replay_buffer
    batch_size = rl_agent.batch_size
    gamma = rl_agent.gamma

    model.train()
    criterion = nn.MSELoss()  # or another suitable loss

    for epoch in range(num_epochs):
        state = env.reset()
        done = False
        total_reward = 0
        total_loss = 0

        while not done:
            # Agent takes action (with exploration).
            action = rl_agent.act(state, epsilon=epsilon)
            current_state, reward, next_state, done = env.step(action)
            total_reward += reward

            # Store experience in replay buffer.
            replay_buffer.push(state, action, reward, next_state, done)

            # Train on a batch from the replay buffer.
            if len(replay_buffer) > batch_size:
                # Sample a batch from the replay buffer.
                states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

                # Convert data to PyTorch tensors.
                states = torch.tensor(states, dtype=torch.float32)
                actions = torch.tensor(actions, dtype=torch.long)
                rewards = torch.tensor(rewards, dtype=torch.float32)
                next_states = torch.tensor(next_states, dtype=torch.float32)
                dones = torch.tensor(dones, dtype=torch.float32)

                # Compute Q-values for current states.
                q_values = model(states)

                # Compute Q-values for next states (no gradient through the target).
                with torch.no_grad():
                    next_q_values = model(next_states)

                # Compute the TD target: r + gamma * max_a' Q(s', a'), zeroed on terminal states.
                td_target = rewards + gamma * torch.max(next_q_values, dim=1)[0] * (1 - dones)

                # Compute the loss on the Q-values of the actions actually taken.
                loss = criterion(q_values.gather(1, actions.unsqueeze(1)).squeeze(1), td_target)

                # Optimize the model.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            # Move to the next state.
            state = next_state

        print(f"Epoch {epoch + 1}/{num_epochs}, Total Reward: {total_reward:.4f}, Loss: {total_loss:.4f}")
        save_checkpoint(model, epoch, total_loss, LAST_DIR, BEST_DIR)
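

# Small self-contained illustration of the TD-target / gather step used above, run on
# dummy tensors (illustrative only; shapes assume a 3-action Q-network and a batch of 2).
def _td_update_example():
    q_values = torch.tensor([[1.0, 2.0, 0.5],
                             [0.2, 0.1, 0.9]])        # Q(s, ·) for the current states
    next_q_values = torch.tensor([[0.3, 0.7, 0.1],
                                  [0.4, 0.2, 0.6]])   # Q(s', ·) for the next states
    actions = torch.tensor([1, 2])                    # actions actually taken
    rewards = torch.tensor([1.0, -0.5])
    dones = torch.tensor([0.0, 1.0])                  # second transition is terminal
    gamma = 0.99
    # Target: r + gamma * max_a' Q(s', a'), with the bootstrap term dropped when done.
    td_target = rewards + gamma * torch.max(next_q_values, dim=1)[0] * (1 - dones)
    # Q-values of the chosen actions: [2.0, 0.9]; the MSE loss compares these to td_target.
    chosen_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
    print(td_target, chosen_q)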


# -------------------------------------
# Caching Functions (for candles data)
# -------------------------------------
def save_candles_cache(filename, candles_dict):
    """
    Save the candles data to a JSON file.
    """
    # Cast numeric values to built-in Python floats so numpy scalars serialize cleanly to JSON.
    serializable_candles_dict = {}
    for timeframe, candles in candles_dict.items():
        serializable_candles = []
        for candle in candles:
            serializable_candle = {
                'timestamp': candle['timestamp'],
                'open': float(candle['open']),
                'high': float(candle['high']),
                'low': float(candle['low']),
                'close': float(candle['close']),
                'volume': float(candle['volume'])
            }
            serializable_candles.append(serializable_candle)
        serializable_candles_dict[timeframe] = serializable_candles

    with open(filename, 'w') as f:
        json.dump(serializable_candles_dict, f)


def load_candles_cache(filename):
    """
    Load the candles data from a JSON file.
    """
    with open(filename, 'r') as f:
        candles_dict = json.load(f)

    # Ensure numeric candle fields are floats after deserialization.
    for timeframe, candles in candles_dict.items():
        for candle in candles:
            candle['open'] = float(candle['open'])
            candle['high'] = float(candle['high'])
            candle['low'] = float(candle['low'])
            candle['close'] = float(candle['close'])
            candle['volume'] = float(candle['volume'])
    return candles_dict
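

# Hedged wiring sketch: how the pieces in this module might fit together. The
# BacktestEnvironment and RLAgent constructors are hypothetical names, not defined
# in this file; only the cache helpers and train_on_historical_data are real.
#
# if __name__ == "__main__":
#     candles = load_candles_cache(CACHE_FILE)       # assumes the cache file exists
#     env = BacktestEnvironment(candles)             # hypothetical backtest environment
#     agent = RLAgent(TradingModel(...))             # hypothetical agent wrapping the model
#     train_on_historical_data(env, agent, num_epochs=10, epsilon=0.1)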