#!/usr/bin/env python3
import sys
import asyncio
if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
from dotenv import load_dotenv
import os
import time
import json
import ccxt.async_support as ccxt
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
from datetime import datetime
import matplotlib.pyplot as plt
# --- Directories for saving models ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"
# -------------------------------------
# Utility functions for caching candles to file
# -------------------------------------
def load_candles_cache(filename):
    if os.path.exists(filename):
        try:
            with open(filename, "r") as f:
                data = json.load(f)
            print(f"Loaded {len(data)} candles from cache.")
            return data
        except Exception as e:
            print("Error reading cache file:", e)
    return []
def save_candles_cache(filename, candles):
    try:
        with open(filename, "w") as f:
            json.dump(candles, f)
    except Exception as e:
        print("Error saving cache file:", e)
# -------------------------------------
# Functions for handling checkpoints
# -------------------------------------
def maintain_checkpoint_directory(directory, max_files=10):
    """Keep only the most recent max_files in a given directory based on modification time."""
    files = os.listdir(directory)
    if len(files) > max_files:
        full_paths = [os.path.join(directory, f) for f in files]
        full_paths.sort(key=lambda x: os.path.getmtime(x))
        # Remove the oldest files
        for f in full_paths[: len(files) - max_files]:
            os.remove(f)
def get_best_models(directory):
    """Return a list of (reward, filename) for files in the best folder.
    Expected filename format: best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"""
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            # parts[1] should be the reward
            r = float(parts[1])
            best_files.append((r, file))
        except Exception:
            continue
    return best_files
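# Example of the expected parse (filename is hypothetical, format as written by save_checkpoint):
#   "best_12.3456_epoch_7_20250202_010639.pt".split("_")[1] -> "12.3456" -> reward 12.3456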
def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Save the model state at each epoch to last_dir and, conditionally, to best_dir."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    last_path = os.path.join(last_dir, last_filename)
    torch.save({
        "epoch": epoch,
        "reward": reward,
        "model_state_dict": model.state_dict()
    }, last_path)
    # Maintain only the last 10 checkpoints
    maintain_checkpoint_directory(last_dir, max_files=10)
    best_models = get_best_models(best_dir)
    add_to_best = False
    if len(best_models) < 10:
        add_to_best = True
    else:
        min_reward, min_file = min(best_models, key=lambda x: x[0])
        if reward > min_reward:
            add_to_best = True
            os.remove(os.path.join(best_dir, min_file))
    if add_to_best:
        best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
        best_path = os.path.join(best_dir, best_filename)
        torch.save({
            "epoch": epoch,
            "reward": reward,
            "model_state_dict": model.state_dict()
        }, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
def load_best_checkpoint(model, best_dir=BEST_DIR):
    """Load the best checkpoint (with the highest reward) if available."""
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_reward, best_file = max(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint
# -------------------------------------
# Neural Network Architecture Definition
# -------------------------------------
class TradingModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(TradingModel, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)
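# Example instantiation, matching the dimensions used in main_backtest below:
# 8 input features (5 OHLCV + 3 sentiment channels) and 3 outputs (Q-values for SELL/HOLD/BUY).
#   model = TradingModel(input_dim=8, hidden_dim=128, output_dim=3)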
# -------------------------------------
# Replay Buffer for Experience Storage
# -------------------------------------
class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def add(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        indices = np.random.choice(len(self.buffer), size=batch_size, replace=False)
        return [self.buffer[i] for i in indices]

    def __len__(self):
        return len(self.buffer)
# -------------------------------------
# Indicator and Feature Preparation Function
# -------------------------------------
def compute_indicators(candle, additional_data):
    """
    Combine OHLCV candle data with extra indicator information.
    Base features: open, high, low, close, volume.
    Additional channels (e.g., simulated sentiment) are appended.
    """
    features = [
        candle.get('open', 0.0),
        candle.get('high', 0.0),
        candle.get('low', 0.0),
        candle.get('close', 0.0),
        candle.get('volume', 0.0),
    ]
    for value in additional_data.values():
        features.append(value)
    return np.array(features, dtype=np.float32)
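# Example feature vector as produced for BacktestEnvironment.get_state (8 values):
#   [open, high, low, close, volume, sentiment_score, news_volume, social_engagement]
# The three extra channels are whatever additional_data supplies, in insertion order.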
# -------------------------------------
# RL Agent with Q-Learning and Epsilon-Greedy Exploration
# -------------------------------------
class ContinuousRLAgent:
    def __init__(self, model, optimizer, replay_buffer, batch_size=32, gamma=0.99):
        self.model = model
        self.optimizer = optimizer
        self.replay_buffer = replay_buffer
        self.batch_size = batch_size
        self.loss_fn = nn.MSELoss()
        self.gamma = gamma

    def act(self, state, epsilon=0.1):
        if np.random.rand() < epsilon:
            return np.random.randint(0, 3)
        state_tensor = torch.from_numpy(np.array(state, dtype=np.float32)).unsqueeze(0)
        with torch.no_grad():
            output = self.model(state_tensor)
        action = torch.argmax(output, dim=1).item()
        return action

    def train_step(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        batch = self.replay_buffer.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states_tensor = torch.from_numpy(np.array(states, dtype=np.float32))
        actions_tensor = torch.tensor(actions, dtype=torch.int64)
        rewards_tensor = torch.from_numpy(np.array(rewards, dtype=np.float32)).unsqueeze(1)
        next_states_tensor = torch.from_numpy(np.array(next_states, dtype=np.float32))
        dones_tensor = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)
        Q_values = self.model(states_tensor)
        current_Q = Q_values.gather(1, actions_tensor.unsqueeze(1))
        with torch.no_grad():
            next_Q_values = self.model(next_states_tensor)
            max_next_Q = next_Q_values.max(1)[0].unsqueeze(1)
            target = rewards_tensor + self.gamma * max_next_Q * (1.0 - dones_tensor)
        loss = self.loss_fn(current_Q, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
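# Sketch of the update rule implemented by train_step (one-step Q-learning, with the
# same network used for the bootstrap target):
#   target = reward + gamma * max_a' Q(next_state, a') * (1 - done)
#   loss   = MSE(Q(state, action), target)
# Action indices used throughout this file: 0 = SELL, 1 = HOLD, 2 = BUY.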
# -------------------------------------
# Historical Data Fetching Functions
# -------------------------------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
    """
    Fetch historical OHLCV data for a given symbol and timeframe.
    "since" and "end_time" are given in milliseconds.
    """
    candles = []
    since_ms = since
    while True:
        try:
            batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size)
        except Exception as e:
            print("Error fetching historical data:", e)
            break
        if not batch:
            break
        for c in batch:
            candle_dict = {
                'timestamp': c[0],
                'open': c[1],
                'high': c[2],
                'low': c[3],
                'close': c[4],
                'volume': c[5]
            }
            candles.append(candle_dict)
        last_timestamp = batch[-1][0]
        if last_timestamp >= end_time:
            break
        since_ms = last_timestamp + 1
    print(f"Fetched {len(candles)} candles.")
    return candles
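# Timestamps are Unix epoch milliseconds, as ccxt's fetch_ohlcv expects. For example,
# with the '1m' timeframe used below, one day spans 24 * 60 * 60 * 1000 = 86_400_000 ms
# and contains 1440 candles, so a full day is fetched across several batches of 500.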
async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time, cache_file=CACHE_FILE, batch_size=500):
    cached_candles = load_candles_cache(cache_file)
    if cached_candles:
        last_ts = cached_candles[-1]['timestamp']
        if last_ts < end_time:
            print("Fetching new candles to update cache...")
            new_candles = await fetch_historical_data(exchange, symbol, timeframe, last_ts + 1, end_time, batch_size)
            cached_candles.extend(new_candles)
        else:
            print("Cache covers the requested period.")
        return cached_candles
    else:
        candles = await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size)
        return candles
# -------------------------------------
# Backtest Environment with Trade History Recording
# -------------------------------------
class BacktestEnvironment:
    def __init__(self, candles):
        self.candles = candles
        self.current_index = 0
        self.position = None      # Active position: dict with 'entry_price' and 'entry_index'
        self.trade_history = []   # List of closed trades

    def reset(self, clear_trade_history=True):
        self.current_index = 0
        self.position = None
        if clear_trade_history:
            self.trade_history = []
        return self.get_state(self.current_index)

    def get_state(self, index):
        candle = self.candles[index]
        sentiment = {
            'sentiment_score': np.random.rand(),
            'news_volume': np.random.rand(),
            'social_engagement': np.random.rand()
        }
        return compute_indicators(candle, sentiment)

    def step(self, action):
        """
        Simulate a trading step:
          - If not in a position and the action is BUY (2), record an entry at the next candle's open.
          - If in a position and the action is SELL (0), record an exit at the next candle's open and compute PnL.
        Returns: (current_state, reward, next_state, done)
        """
        if self.current_index >= len(self.candles) - 1:
            return self.get_state(self.current_index), 0.0, None, True
        current_state = self.get_state(self.current_index)
        next_index = self.current_index + 1
        next_state = self.get_state(next_index)
        next_candle = self.candles[next_index]
        reward = 0.0
        # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
        if self.position is None:
            if action == 2:  # BUY signal: enter position at next candle's open.
                entry_price = next_candle['open']
                self.position = {'entry_price': entry_price, 'entry_index': self.current_index}
        else:
            if action == 0:  # SELL signal: exit position at next candle's open.
                exit_price = next_candle['open']
                reward = exit_price - self.position['entry_price']
                trade = {
                    'entry_index': self.position['entry_index'],
                    'entry_price': self.position['entry_price'],
                    'exit_index': next_index,
                    'exit_price': exit_price,
                    'pnl': reward
                }
                self.trade_history.append(trade)
                self.position = None
        self.current_index = next_index
        done = (self.current_index >= len(self.candles) - 1)
        return current_state, reward, next_state, done
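# Worked example of the reward/PnL bookkeeping above (prices are illustrative):
#   BUY signal at index 10  -> entry_price = candles[11]['open'], e.g. 100.0
#   SELL signal at index 20 -> exit_price  = candles[21]['open'], e.g. 103.5
#   reward = pnl = 103.5 - 100.0 = 3.5 (one unit, long-only, no fees or slippage).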
# -------------------------------------
# Plot Trading Chart with Buy/Sell Markers and PnL Annotations
# -------------------------------------
def plot_trade_history(candles, trade_history):
    # Extract the close price series from the candles.
    close_prices = [candle['close'] for candle in candles]
    x = list(range(len(close_prices)))
    plt.figure(figsize=(12, 6))
    plt.plot(x, close_prices, label="Close Price", color="black", linewidth=1)
    # Attach the BUY/SELL labels only once to avoid duplicate legend entries.
    buy_plotted = False
    sell_plotted = False
    for trade in trade_history:
        entry_idx = trade["entry_index"]
        exit_idx = trade["exit_index"]
        entry_price = trade["entry_price"]
        exit_price = trade["exit_price"]
        pnl = trade["pnl"]
        if not buy_plotted:
            plt.plot(entry_idx, entry_price, marker="^", color="green", markersize=10, label="BUY")
            buy_plotted = True
        else:
            plt.plot(entry_idx, entry_price, marker="^", color="green", markersize=10)
        if not sell_plotted:
            plt.plot(exit_idx, exit_price, marker="v", color="red", markersize=10, label="SELL")
            sell_plotted = True
        else:
            plt.plot(exit_idx, exit_price, marker="v", color="red", markersize=10)
        plt.text(exit_idx, exit_price, f"{pnl:+.2f}", color="blue", fontsize=8)
    plt.title("Trade History with PnL After Order Close")
    plt.xlabel("Candle Index")
    plt.ylabel("Price")
    plt.legend()
    plt.grid(True)
    plt.show()
# -------------------------------------
# Training Loop Over Historical Data (Backtest)
# -------------------------------------
def train_on_historical_data(env, rl_agent, num_epochs=10, epsilon=0.1):
    """
    For each epoch, run through the historical episode.
    At each step, select an action (using epsilon-greedy exploration), simulate a trade,
    store the experience, and update the network.
    After the epoch, log the total reward and save checkpoints.
    """
    for epoch in range(1, num_epochs + 1):
        state = env.reset()  # clear trade history each epoch
        done = False
        total_reward = 0.0
        steps = 0
        while not done:
            action = rl_agent.act(state, epsilon=epsilon)
            prev_state = state
            _, reward, next_state, done = env.step(action)
            if next_state is None:
                next_state = np.zeros_like(prev_state)
            rl_agent.replay_buffer.add((prev_state, action, reward, next_state, done))
            rl_agent.train_step()
            # Advance to the observed next state so the next action is chosen from it.
            state = next_state
            total_reward += reward
            steps += 1
        print(f"Epoch {epoch}/{num_epochs} completed, total reward: {total_reward:.4f} over {steps} steps.")
        save_checkpoint(rl_agent.model, epoch, total_reward, LAST_DIR, BEST_DIR)
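# Note: total_reward is the sum of realized per-trade PnL over the epoch, and with
# epsilon=0.1 roughly 10% of the actions taken during training are random exploration.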
# -------------------------------------
# Main Asynchronous Function for Backtest Training and Charting
# -------------------------------------
async def main_backtest():
    # Define symbol, timeframe, and period.
    symbol = 'BTC/USDT'
    timeframe = '1m'
    now = int(time.time() * 1000)
    one_day_ms = 24 * 60 * 60 * 1000
    # For example, fetch a 1-day period from 2 days ago until 1 day ago.
    since = now - one_day_ms * 2
    end_time = now - one_day_ms
    # Initialize the exchange (using MEXC as an example).
    mexc_api_key = os.environ.get('MEXC_API_KEY', 'YOUR_API_KEY')
    mexc_api_secret = os.environ.get('MEXC_API_SECRET', 'YOUR_SECRET_KEY')
    exchange = ccxt.mexc({
        'apiKey': mexc_api_key,
        'secret': mexc_api_secret,
        'enableRateLimit': True,
    })
    print("Fetching historical data...")
    candles = await get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time)
    if not candles:
        print("No historical data fetched.")
        await exchange.close()
        return
    save_candles_cache(CACHE_FILE, candles)
    env = BacktestEnvironment(candles)
    # Model dimensions: 5 (OHLCV) + 3 (sentiment) = 8.
    input_dim = 8
    hidden_dim = 128
    output_dim = 3  # SELL, HOLD, BUY.
    model = TradingModel(input_dim, hidden_dim, output_dim)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    replay_buffer = ReplayBuffer(capacity=10000)
    rl_agent = ContinuousRLAgent(model, optimizer, replay_buffer, batch_size=32, gamma=0.99)
    # At training start, try loading the best checkpoint if available.
    load_best_checkpoint(model, BEST_DIR)
    # Run training (backtesting) over historical data.
    num_epochs = 10  # adjust as needed.
    train_on_historical_data(env, rl_agent, num_epochs=num_epochs, epsilon=0.1)
    # Final simulation (without exploration) to log trade history.
    state = env.reset(clear_trade_history=True)
    done = False
    cumulative_reward = 0.0
    while not done:
        action = rl_agent.act(state, epsilon=0.0)
        _, reward, next_state, done = env.step(action)
        cumulative_reward += reward
        state = next_state
    print("Final backtest simulation cumulative profit:", cumulative_reward)
    # Draw the chart: plot the close price with BUY/SELL markers and PnL annotations.
    plot_trade_history(candles, env.trade_history)
    await exchange.close()


if __name__ == "__main__":
    load_dotenv()
    asyncio.run(main_backtest())