gogo2/crypto/brian/index.py
Dobromir Popov f7f10bc17c checkpoints
2025-02-02 00:55:14 +02:00

#!/usr/bin/env python3
import sys
import asyncio
if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
from dotenv import load_dotenv
import os
import time
import json
import ccxt.async_support as ccxt
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
from datetime import datetime
# --- Directories for saving models ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"
# -------------------------------------
# Utility functions for caching candles to file
# -------------------------------------
def load_candles_cache(filename):
    if os.path.exists(filename):
        try:
            with open(filename, "r") as f:
                data = json.load(f)
            print(f"Loaded {len(data)} candles from cache.")
            return data
        except Exception as e:
            print("Error reading cache file:", e)
    return []


def save_candles_cache(filename, candles):
    try:
        with open(filename, "w") as f:
            json.dump(candles, f)
    except Exception as e:
        print("Error saving cache file:", e)
# -------------------------------------
# Functions for handling checkpoints
# -------------------------------------
def maintain_checkpoint_directory(directory, max_files=10):
    """Keep only the most recent max_files in a given directory, based on modification time."""
    files = os.listdir(directory)
    if len(files) > max_files:
        full_paths = [os.path.join(directory, f) for f in files]
        full_paths.sort(key=lambda x: os.path.getmtime(x))
        # Remove the oldest files.
        for f in full_paths[: len(files) - max_files]:
            os.remove(f)


def get_best_models(directory):
    """Return a list of (reward, filename) for files in the best folder.
    Expecting filenames like: best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"""
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            # parts[1] should be the reward.
            r = float(parts[1])
            best_files.append((r, file))
        except Exception:
            continue
    return best_files
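
# Illustrative example (filename is hypothetical, not from the original file): a name such as
#   best_12.3456_epoch_7_20250201_153000.pt
# splits on "_" into ["best", "12.3456", "epoch", "7", "20250201", "153000.pt"], so parts[1]
# parses as the reward 12.3456. Names that do not follow this pattern fail the float() call
# and are simply skipped.
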
def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Always save the model state to last_dir and, if the reward is high enough, also to best_dir."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    last_path = os.path.join(last_dir, last_filename)
    torch.save({
        "epoch": epoch,
        "reward": reward,
        "model_state_dict": model.state_dict()
    }, last_path)
    # Keep only the last 10 models in last_dir.
    maintain_checkpoint_directory(last_dir, max_files=10)
    # If the best folder holds fewer than 10 checkpoints, simply add this one;
    # otherwise, add it only if its reward exceeds the lowest reward currently in best.
    best_models = get_best_models(best_dir)
    add_to_best = False
    if len(best_models) < 10:
        add_to_best = True
    else:
        min_reward, min_file = min(best_models, key=lambda x: x[0])
        if reward > min_reward:
            add_to_best = True
            # Remove the worst checkpoint.
            os.remove(os.path.join(best_dir, min_file))
    if add_to_best:
        best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
        best_path = os.path.join(best_dir, best_filename)
        torch.save({
            "epoch": epoch,
            "reward": reward,
            "model_state_dict": model.state_dict()
        }, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
def load_best_checkpoint(model, best_dir=BEST_DIR):
    """Load the best checkpoint (the one with the highest reward) from the best directory, if available."""
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_reward, best_file = max(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint
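
# Note: main_backtest() ignores the dict returned here, so only the model weights are
# restored at startup; the stored "epoch" and "reward" entries are informational.
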
# -------------------------------------
# Neural Network Architecture Definition
# -------------------------------------
class TradingModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(TradingModel, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)
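
# With the dimensions chosen in main_backtest() (input_dim=8, hidden_dim=128, output_dim=3),
# this is an 8 -> 128 -> 128 -> 3 feed-forward network whose three outputs are treated as
# Q-values for the SELL / HOLD / BUY actions by ContinuousRLAgent.
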
# -------------------------------------
# Replay Buffer for Experience Storage
# -------------------------------------
class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def add(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        indices = np.random.choice(len(self.buffer), size=batch_size, replace=False)
        return [self.buffer[i] for i in indices]

    def __len__(self):
        return len(self.buffer)
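
# Each stored experience is the 5-tuple consumed by ContinuousRLAgent.train_step():
#   (state, action, reward, next_state, done)
# e.g. (np.ndarray of shape (8,), 2, 0.0, np.ndarray of shape (8,), False).
# sample() draws without replacement, so callers must ensure len(buffer) >= batch_size;
# train_step() checks this before sampling.
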
# -------------------------------------
# A Simple Indicator and Feature Preparation Function
# -------------------------------------
def compute_indicators(candle, additional_data):
    """
    Combine OHLCV candle data with extra indicator information.
    Base features: open, high, low, close, volume.
    Additional channels (e.g., simulated sentiment) are appended.
    """
    features = [
        candle.get('open', 0.0),
        candle.get('high', 0.0),
        candle.get('low', 0.0),
        candle.get('close', 0.0),
        candle.get('volume', 0.0),
    ]
    for value in additional_data.values():
        features.append(value)
    return np.array(features, dtype=np.float32)
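
# Resulting feature layout (as used by BacktestEnvironment.get_state below): an 8-element
# float32 vector [open, high, low, close, volume, sentiment_score, news_volume,
# social_engagement], which matches input_dim = 8 in main_backtest(). Dict insertion order
# (guaranteed in Python 3.7+) determines the order of the appended channels.
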
# -------------------------------------
# RL Agent with Q-Learning Update and Epsilon-Greedy Exploration
# -------------------------------------
class ContinuousRLAgent:
    def __init__(self, model, optimizer, replay_buffer, batch_size=32, gamma=0.99):
        self.model = model
        self.optimizer = optimizer
        self.replay_buffer = replay_buffer
        self.batch_size = batch_size
        self.loss_fn = nn.MSELoss()
        self.gamma = gamma

    def act(self, state, epsilon=0.1):
        # ε-greedy: choose a random action with probability epsilon.
        if np.random.rand() < epsilon:
            return np.random.randint(0, 3)
        state_tensor = torch.from_numpy(np.array(state, dtype=np.float32)).unsqueeze(0)
        with torch.no_grad():
            output = self.model(state_tensor)
        action = torch.argmax(output, dim=1).item()
        return action

    def train_step(self):
        # Only train once enough samples have accumulated.
        if len(self.replay_buffer) < self.batch_size:
            return
        batch = self.replay_buffer.sample(self.batch_size)
        # Unpack the batch; each experience is (state, action, reward, next_state, done).
        states, actions, rewards, next_states, dones = zip(*batch)
        states_tensor = torch.from_numpy(np.array(states, dtype=np.float32))
        actions_tensor = torch.tensor(actions, dtype=torch.int64)
        rewards_tensor = torch.from_numpy(np.array(rewards, dtype=np.float32)).unsqueeze(1)
        next_states_tensor = torch.from_numpy(np.array(next_states, dtype=np.float32))
        dones_tensor = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)
        Q_values = self.model(states_tensor)
        current_Q = Q_values.gather(1, actions_tensor.unsqueeze(1))
        with torch.no_grad():
            next_Q_values = self.model(next_states_tensor)
            max_next_Q = next_Q_values.max(1)[0].unsqueeze(1)
        target = rewards_tensor + self.gamma * max_next_Q * (1.0 - dones_tensor)
        loss = self.loss_fn(current_Q, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
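
# The update above is a standard one-step DQN target, using the same network for bootstrapping
# (no separate target network):
#   target = reward + gamma * max_a' Q(next_state, a') * (1 - done)
# For a terminal transition (done = 1.0) the bootstrap term drops out and target = reward.
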
# -------------------------------------
# Historical Data Fetching Function
# -------------------------------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
    """
    Fetch historical OHLCV data for the given symbol and timeframe.
    "since" and "end_time" are in milliseconds.
    """
    candles = []
    since_ms = since
    while True:
        try:
            batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size)
        except Exception as e:
            print("Error fetching historical data:", e)
            break
        if not batch:
            break
        for c in batch:
            candle_dict = {
                'timestamp': c[0],
                'open': c[1],
                'high': c[2],
                'low': c[3],
                'close': c[4],
                'volume': c[5]
            }
            candles.append(candle_dict)
        last_timestamp = batch[-1][0]
        if last_timestamp >= end_time:
            break
        since_ms = last_timestamp + 1
    print(f"Fetched {len(candles)} candles.")
    return candles
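
# Pagination note: each request advances since_ms to last_timestamp + 1, so batches do not
# overlap. Timestamps are in milliseconds; for the '1m' timeframe used in main_backtest(),
# consecutive candles are 60_000 ms apart. The final batch may include candles with
# timestamps past end_time, since the end condition is only checked after a whole batch has
# been appended.
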
async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time, cache_file=CACHE_FILE, batch_size=500):
    cached_candles = load_candles_cache(cache_file)
    if cached_candles:
        last_ts = cached_candles[-1]['timestamp']
        if last_ts < end_time:
            print("Fetching new candles to update cache...")
            new_candles = await fetch_historical_data(exchange, symbol, timeframe, last_ts + 1, end_time, batch_size)
            cached_candles.extend(new_candles)
        else:
            print("Cache covers the requested period.")
        return cached_candles
    else:
        candles = await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size)
        return candles
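
# Cache-update flow: when candles_cache.json exists but its newest timestamp is older than
# end_time, only the missing tail (last_ts + 1 .. end_time) is fetched and appended; the
# previously cached candles are kept and returned together with the new ones.
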
# -------------------------------------
# Backtest Environment Class Definition
# -------------------------------------
class BacktestEnvironment:
    def __init__(self, candles):
        self.candles = candles
        self.current_index = 0
        self.position = None  # Holds an open position, if any.

    def reset(self):
        self.current_index = 0
        self.position = None
        return self.get_state(self.current_index)

    def get_state(self, index):
        candle = self.candles[index]
        sentiment = {
            'sentiment_score': np.random.rand(),
            'news_volume': np.random.rand(),
            'social_engagement': np.random.rand()
        }
        return compute_indicators(candle, sentiment)

    def step(self, action):
        """
        Simulate a trading step.
        - If not in a position and action is BUY (2), enter a long position at the next candle's open.
        - If in a position and action is SELL (0), close the position at the next candle's open.
        Returns: (current_state, reward, next_state, done)
        """
        if self.current_index >= len(self.candles) - 1:
            return self.get_state(self.current_index), 0.0, None, True
        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        current_candle = self.candles[self.current_index]
        next_candle = self.candles[next_index]
        reward = 0.0
        # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
        if self.position is None:
            if action == 2:  # BUY signal: enter a long position at the next candle's open.
                entry_price = next_candle['open']
                self.position = {'entry_price': entry_price, 'entry_index': self.current_index}
        else:
            if action == 0:  # SELL signal: close the position at the next candle's open.
                sell_price = next_candle['open']
                reward = sell_price - self.position['entry_price']
                self.position = None
        self.current_index = next_index
        done = (self.current_index >= len(self.candles) - 1)
        return current_state, reward, next_state, done
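
# Worked example (hypothetical prices): with no open position, action 2 (BUY) at index t
# records entry_price = candles[t+1]['open'], say 100.0, and returns reward 0.0. If a later
# action 0 (SELL) sees the next candle open at 101.5, that step returns reward
# 101.5 - 100.0 = 1.5 and clears the position. HOLD (1) never changes the position.
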
# -------------------------------------
# Training Loop Over Historical Data (Backtest)
# -------------------------------------
def train_on_historical_data(env, rl_agent, num_epochs=10, epsilon=0.1):
    """
    For each epoch, run through the entire historical episode.
    At each step, pick an action (using ε-greedy), simulate a trade, store the experience,
    and update the model. Then log the cumulative reward and save checkpoints.
    """
    for epoch in range(1, num_epochs + 1):
        state = env.reset()
        done = False
        total_reward = 0.0
        steps = 0
        while not done:
            action = rl_agent.act(state, epsilon=epsilon)
            prev_state = state
            state, reward, next_state, done = env.step(action)
            if next_state is None:
                next_state = np.zeros_like(prev_state)
            # Store the experience (state, action, reward, next_state, done).
            rl_agent.replay_buffer.add((prev_state, action, reward, next_state, done))
            rl_agent.train_step()
            total_reward += reward
            steps += 1
            # Advance to the next state so the following action is chosen from the
            # candle the environment has actually moved to (mirrors the evaluation
            # loop in main_backtest).
            state = next_state
        print(f"Epoch {epoch}/{num_epochs} completed, total reward: {total_reward:.4f} over {steps} steps.")
        # Save a checkpoint after the epoch.
        save_checkpoint(rl_agent.model, epoch, total_reward, LAST_DIR, BEST_DIR)
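
# Rough episode length (assuming the 1-day window of '1m' candles fetched in main_backtest()):
# about 24 * 60 = 1440 candles, so each epoch runs roughly 1439 steps, and train_step() starts
# updating once the replay buffer holds batch_size (32) samples.
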
# -------------------------------------
# Main Asynchronous Function for Backtest Training
# -------------------------------------
async def main_backtest():
    # Define symbol, timeframe, and historical period.
    symbol = 'BTC/USDT'
    timeframe = '1m'
    now = int(time.time() * 1000)
    one_day_ms = 24 * 60 * 60 * 1000
    # For example, fetch a 1-day period from 2 days ago until 1 day ago.
    since = now - one_day_ms * 2
    end_time = now - one_day_ms
    # Initialize the exchange (using MEXC for example).
    mexc_api_key = os.environ.get('MEXC_API_KEY', 'YOUR_API_KEY')
    mexc_api_secret = os.environ.get('MEXC_API_SECRET', 'YOUR_SECRET_KEY')
    exchange = ccxt.mexc({
        'apiKey': mexc_api_key,
        'secret': mexc_api_secret,
        'enableRateLimit': True,
    })
    print("Fetching historical data...")
    candles = await get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time)
    if not candles:
        print("No historical data fetched.")
        await exchange.close()
        return
    # Save the updated cache.
    save_candles_cache(CACHE_FILE, candles)
    # Initialize the backtest environment.
    env = BacktestEnvironment(candles)
    # Model dimensions: 5 (OHLCV) + 3 (sentiment) = 8.
    input_dim = 8
    hidden_dim = 128
    output_dim = 3  # SELL, HOLD, BUY.
    model = TradingModel(input_dim, hidden_dim, output_dim)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    replay_buffer = ReplayBuffer(capacity=10000)
    rl_agent = ContinuousRLAgent(model, optimizer, replay_buffer, batch_size=32, gamma=0.99)
    # At training start, try loading the best checkpoint (if available).
    load_best_checkpoint(model, BEST_DIR)
    # Run training over historical data.
    num_epochs = 10  # Change as needed.
    train_on_historical_data(env, rl_agent, num_epochs=num_epochs, epsilon=0.1)
    # Final simulation (without exploration) to check cumulative profit.
    state = env.reset()
    done = False
    cumulative_reward = 0.0
    while not done:
        action = rl_agent.act(state, epsilon=0.0)
        state, reward, next_state, done = env.step(action)
        cumulative_reward += reward
        state = next_state
    print("Final backtest simulation cumulative profit:", cumulative_reward)
    await exchange.close()


if __name__ == "__main__":
    load_dotenv()
    asyncio.run(main_backtest())
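
# Usage note (assumption, not stated in the original file): load_dotenv() reads MEXC_API_KEY
# and MEXC_API_SECRET from a local .env file if one exists; fetching public OHLCV history via
# ccxt typically does not require real credentials, so the placeholder keys may be enough for
# this backtest. Run with: python index.py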