checkpoints

Dobromir Popov 2025-02-02 00:55:14 +02:00
parent 46aee31942
commit f7f10bc17c


@@ -14,12 +14,19 @@ import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
from datetime import datetime
# --- Directories for saving models ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"
# -------------------------------------
# Utility functions for caching candles to file
# -------------------------------------
def load_candles_cache(filename):
    if os.path.exists(filename):
        try:
@@ -38,6 +45,81 @@ def save_candles_cache(filename, candles):
    except Exception as e:
        print("Error saving cache file:", e)
# -------------------------------------
# Functions for handling checkpoints
# -------------------------------------
def maintain_checkpoint_directory(directory, max_files=10):
    """Keep only the most recent max_files in the given directory, based on modification time."""
    files = os.listdir(directory)
    if len(files) > max_files:
        full_paths = [os.path.join(directory, f) for f in files]
        full_paths.sort(key=lambda x: os.path.getmtime(x))
        # Remove the oldest files.
        for f in full_paths[: len(files) - max_files]:
            os.remove(f)

def get_best_models(directory):
    """Return a list of (reward, filename) tuples for files in the best folder.
    Filenames are expected to look like: best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"""
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            # parts[1] holds the reward.
            r = float(parts[1])
            best_files.append((r, file))
        except Exception:
            continue
    return best_files
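A quick illustration of the filename convention this parser relies on (the filename below is hypothetical):

sample = "best_12.3456_epoch_7_20250201_120000.pt"   # hypothetical checkpoint name
print(float(sample.split("_")[1]))                   # -> 12.3456, the stored reward
# Files that do not follow the pattern are simply skipped by the except clause above.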
def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Always save the model state to last_dir; also save it to best_dir if the reward qualifies."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    last_path = os.path.join(last_dir, last_filename)
    torch.save({
        "epoch": epoch,
        "reward": reward,
        "model_state_dict": model.state_dict()
    }, last_path)
    # Keep only the last 10 models in last_dir.
    maintain_checkpoint_directory(last_dir, max_files=10)
    # If the best folder holds fewer than 10 checkpoints, simply add this one;
    # otherwise add it only if its reward beats the lowest reward currently kept.
    best_models = get_best_models(best_dir)
    add_to_best = False
    if len(best_models) < 10:
        add_to_best = True
    else:
        min_reward, min_file = min(best_models, key=lambda x: x[0])
        if reward > min_reward:
            add_to_best = True
            # Remove the worst checkpoint.
            os.remove(os.path.join(best_dir, min_file))
    if add_to_best:
        best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
        best_path = os.path.join(best_dir, best_filename)
        torch.save({
            "epoch": epoch,
            "reward": reward,
            "model_state_dict": model.state_dict()
        }, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")

def load_best_checkpoint(model, best_dir=BEST_DIR):
    """Load the checkpoint with the highest reward from the best directory, if one exists."""
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_reward, best_file = max(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint
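A minimal sketch of how these helpers fit together, using a throwaway nn.Linear in place of the real trading model and made-up rewards:

sketch_model = nn.Linear(8, 3)                       # stand-in for the actual model
for ep, rew in enumerate([1.0, 2.5, 0.7], start=1):  # dummy epoch rewards
    save_checkpoint(sketch_model, ep, rew)           # rotates models/last and models/best
restored = load_best_checkpoint(sketch_model)        # reloads the highest-reward weights
if restored is not None:
    print("restored epoch", restored["epoch"], "reward", restored["reward"])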
# -------------------------------------
# Neural Network Architecture Definition
# -------------------------------------
@@ -79,7 +161,7 @@ def compute_indicators(candle, additional_data):
    """
    Combine OHLCV candle data with extra indicator information.
    Base features: open, high, low, close, volume.
    Additional channels (e.g., simulated sentiment) are appended.
    """
    features = [
        candle.get('open', 0.0),
@@ -105,7 +187,7 @@ class ContinuousRLAgent:
        self.gamma = gamma

    def act(self, state, epsilon=0.1):
        # ε-greedy: choose a random action with probability epsilon.
        if np.random.rand() < epsilon:
            return np.random.randint(0, 3)
        state_tensor = torch.from_numpy(np.array(state, dtype=np.float32)).unsqueeze(0)
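A standalone illustration of the ε-greedy rule above, with toy Q-values standing in for the network output:

epsilon = 0.1
toy_q = [0.2, 0.5, 0.3]                  # assumed Q-values for SELL, HOLD, BUY
if np.random.rand() < epsilon:
    action = np.random.randint(0, 3)     # explore on roughly 10% of calls
else:
    action = int(np.argmax(toy_q))       # otherwise exploit the best-valued action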
@@ -115,12 +197,12 @@
        return action

    def train_step(self):
        # Only train if we have enough samples.
        if len(self.replay_buffer) < self.batch_size:
            return
        batch = self.replay_buffer.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states_tensor = torch.from_numpy(np.array(states, dtype=np.float32))
        actions_tensor = torch.tensor(actions, dtype=torch.int64)
@@ -128,15 +210,12 @@
        next_states_tensor = torch.from_numpy(np.array(next_states, dtype=np.float32))
        dones_tensor = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)
        Q_values = self.model(states_tensor)
        current_Q = Q_values.gather(1, actions_tensor.unsqueeze(1))
        with torch.no_grad():
            next_Q_values = self.model(next_states_tensor)
            max_next_Q = next_Q_values.max(1)[0].unsqueeze(1)
        target = rewards_tensor + self.gamma * max_next_Q * (1.0 - dones_tensor)
        loss = self.loss_fn(current_Q, target)
        self.optimizer.zero_grad()
        loss.backward()
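The target computed above is the standard one-step Q-learning target, reward + gamma * max_a Q(next_state, a), zeroed out on terminal transitions; a worked example with assumed numbers:

reward, gamma, max_next_q, done = 0.5, 0.99, 2.0, 0.0   # illustrative values only
target = reward + gamma * max_next_q * (1.0 - done)     # 0.5 + 0.99 * 2.0 = 2.48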
@@ -148,7 +227,7 @@ class ContinuousRLAgent:
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
    """
    Fetch historical OHLCV data for the given symbol and timeframe.
    "since" and "end_time" are in milliseconds.
    """
    candles = []
    since_ms = since
@@ -181,7 +260,6 @@ async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time,
    cached_candles = load_candles_cache(cache_file)
    if cached_candles:
        last_ts = cached_candles[-1]['timestamp']
        if last_ts < end_time:
            print("Fetching new candles to update cache...")
            new_candles = await fetch_historical_data(exchange, symbol, timeframe, last_ts + 1, end_time, batch_size)
@@ -209,7 +287,6 @@ class BacktestEnvironment:
    def get_state(self, index):
        candle = self.candles[index]
        sentiment = {
            'sentiment_score': np.random.rand(),
            'news_volume': np.random.rand(),
@@ -220,10 +297,9 @@
    def step(self, action):
        """
        Simulate a trading step.
        - If not in a position and action is BUY (2), enter a long position at the next candle's open.
        - If in a position and action is SELL (0), close the position at the next candle's open.
        Returns: (current_state, reward, next_state, done)
        """
        if self.current_index >= len(self.candles) - 1:
            return self.get_state(self.current_index), 0.0, None, True
@@ -237,14 +313,15 @@
        # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
        if self.position is None:
            if action == 2:  # BUY signal:
                entry_price = next_candle['open']
                self.position = {'entry_price': entry_price, 'entry_index': self.current_index}
        else:
            if action == 0:  # SELL signal:
                sell_price = next_candle['open']
                reward = sell_price - self.position['entry_price']
                self.position = None

        self.current_index = next_index
        done = (self.current_index >= len(self.candles) - 1)
        return current_state, reward, next_state, done
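The reward is just the per-unit price difference of the round trip, with no fees, slippage, or position sizing; for example, with assumed prices:

entry_price, sell_price = 100.0, 101.5   # hypothetical opens of the entry and exit candles
reward = sell_price - entry_price        # 1.5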
@@ -254,11 +331,11 @@
# -------------------------------------
def train_on_historical_data(env, rl_agent, num_epochs=10, epsilon=0.1):
    """
    For each epoch, run through the entire historical episode.
    At each step, pick an action (using ε-greedy), simulate a trade, store the experience,
    and update the model. Then log the cumulative reward and save checkpoints.
    """
    for epoch in range(1, num_epochs + 1):
        state = env.reset()
        done = False
        total_reward = 0.0
@@ -269,12 +346,14 @@ def train_on_historical_data(env, rl_agent, num_epochs=10, epsilon=0.1):
            state, reward, next_state, done = env.step(action)
            if next_state is None:
                next_state = np.zeros_like(prev_state)
            # Save the experience (state, action, reward, next_state, done).
            rl_agent.replay_buffer.add((prev_state, action, reward, next_state, done))
            rl_agent.train_step()
            total_reward += reward
            steps += 1
        print(f"Epoch {epoch}/{num_epochs} completed, total reward: {total_reward:.4f} over {steps} steps.")
        # Save a checkpoint after the epoch.
        save_checkpoint(rl_agent.model, epoch, total_reward, LAST_DIR, BEST_DIR)
# -------------------------------------
# Main Asynchronous Function for Backtest Training
@@ -285,7 +364,7 @@ async def main_backtest():
    timeframe = '1m'
    now = int(time.time() * 1000)
    one_day_ms = 24 * 60 * 60 * 1000
    # For example, fetch a 1-day period from 2 days ago until 1 day ago.
    since = now - one_day_ms * 2
    end_time = now - one_day_ms
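With these values the requested window covers exactly one day of 1-minute candles:

one_day_ms = 24 * 60 * 60 * 1000     # 86,400,000 ms
print(one_day_ms // 60_000)          # 1440 one-minute candles expected in the window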
@@ -305,27 +384,30 @@ async def main_backtest():
        await exchange.close()
        return

    # Save updated cache.
    save_candles_cache(CACHE_FILE, candles)

    # Initialize backtest environment.
    env = BacktestEnvironment(candles)

    # Model dimensions: 5 (OHLCV) + 3 (sentiment) = 8.
    input_dim = 8
    hidden_dim = 128
    output_dim = 3  # SELL, HOLD, BUY.
    model = TradingModel(input_dim, hidden_dim, output_dim)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    replay_buffer = ReplayBuffer(capacity=10000)
    rl_agent = ContinuousRLAgent(model, optimizer, replay_buffer, batch_size=32, gamma=0.99)

    # At training start, try to load the best checkpoint (if available).
    load_best_checkpoint(model, BEST_DIR)

    # Run training over historical data.
    num_epochs = 10  # Change as needed.
    train_on_historical_data(env, rl_agent, num_epochs=num_epochs, epsilon=0.1)

    # Final simulation (without exploration) to check cumulative profit.
    state = env.reset()
    done = False
    cumulative_reward = 0.0