gogo2/crypto/brian/index.py
#!/usr/bin/env python3
import sys
import asyncio
import os
import time
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
from datetime import datetime
import matplotlib.pyplot as plt
import ccxt.async_support as ccxt
# Use the selector event loop on Windows (the default proactor loop can cause issues with aiohttp/ccxt)
if sys.platform == 'win32':
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

# Load environment variables
from dotenv import load_dotenv
load_dotenv()
# Directory setup
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)
CACHE_FILE = "candles_cache.json"
# -----------------
# Helper Functions
# -----------------
def load_candles_cache(filename):
    try:
        with open(filename, "r") as f:
            data = json.load(f)
        print(f"Loaded cached data from {filename}.")
        return data
    except Exception as e:
        print(f"Error reading cache file: {e}")
        return {}

def save_candles_cache(filename, candles_dict):
    try:
        with open(filename, "w") as f:
            json.dump(candles_dict, f)
    except Exception as e:
        print(f"Error saving cache file: {e}")
def maintain_checkpoint_directory(directory, max_files=10):
    files = os.listdir(directory)
    if len(files) > max_files:
        full_paths = [os.path.join(directory, f) for f in files]
        full_paths.sort(key=lambda x: os.path.getmtime(x))
        for f in full_paths[:len(files) - max_files]:
            os.remove(f)

def get_best_models(directory):
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            r = float(parts[1])
            best_files.append((r, file))
        except Exception:
            continue
    return best_files
def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    last_path = os.path.join(last_dir, last_filename)
    torch.save({
        "epoch": epoch,
        "reward": reward,
        "model_state_dict": model.state_dict()
    }, last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)
    best_models = get_best_models(best_dir)
    add_to_best = False
    if len(best_models) < 10:
        add_to_best = True
    else:
        min_reward, min_file = min(best_models, key=lambda x: x[0])
        if reward > min_reward:
            add_to_best = True
            os.remove(os.path.join(best_dir, min_file))
    if add_to_best:
        best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
        best_path = os.path.join(best_dir, best_filename)
        torch.save({
            "epoch": epoch,
            "reward": reward,
            "model_state_dict": model.state_dict()
        }, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")

def load_best_checkpoint(model, best_dir=BEST_DIR):
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    best_reward, best_file = max(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with reward {best_reward:.4f}")
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint
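# Optional (not wired into main() below): training could resume from the best saved
# checkpoint, for example:
#   checkpoint = load_best_checkpoint(model)
#   start_epoch = (checkpoint["epoch"] + 1) if checkpoint else 1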
# --------------------------
# Technical Indicators
# --------------------------
def compute_sma(candles_list, index, period=10):
    start = max(0, index - period + 1)
    values = [candle["close"] for candle in candles_list[start:index+1]]
    return sum(values) / len(values) if values else 0.0

def compute_sma_volume(candles_list, index, period=10):
    start = max(0, index - period + 1)
    values = [candle["volume"] for candle in candles_list[start:index+1]]
    return sum(values) / len(values) if values else 0.0

def compute_rsi(candles_list, index, period=14):
    start = max(0, index - period + 1)
    values = [candle["close"] for candle in candles_list[start:index+1]]
    delta = [values[i+1] - values[i] for i in range(len(values)-1)]
    gain, loss = [], []
    for d in delta:
        if d > 0:
            gain.append(d)
            loss.append(0)
        else:
            gain.append(0)
            loss.append(abs(d))
    avg_gain = sum(gain) / len(gain) if gain else 0
    avg_loss = sum(loss) / len(loss) if loss else 0
    rs = avg_gain / avg_loss if avg_loss != 0 else 0
    rsi = 100 - (100 / (1 + rs)) if avg_loss != 0 else 0
    return rsi

def get_aligned_candle_with_index(candles_list, target_ts):
    """Return the index and candle of the latest candle at or before target_ts."""
    best_idx = 0
    for i, candle in enumerate(candles_list):
        if candle["timestamp"] <= target_ts:
            best_idx = i
        else:
            break
    return best_idx, candles_list[best_idx]

def get_features_for_tf(candles_list, index, period=10):
    """Build the per-timeframe feature vector: OHLCV plus SMA, volume SMA and RSI."""
    features = []
    candle = candles_list[index]
    features.extend([
        candle["open"], candle["high"], candle["low"], candle["close"], candle["volume"],
        compute_sma(candles_list, index, period), compute_sma_volume(candles_list, index, period),
        compute_rsi(candles_list, index, period)
    ])
    return features
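# Each timeframe contributes 8 features (open, high, low, close, volume, SMA, volume SMA, RSI),
# so the flattened environment state has 8 * len(timeframes) values.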
# -------------------
# Neural Network
# -------------------
class TransformerModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_heads=2, dropout=0.1):
        super().__init__()
        self.input_linear = nn.Linear(input_dim, hidden_dim)
        self.transformer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=n_heads, dropout=dropout, batch_first=True
        )
        self.output_linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.input_linear(x)
        x = x.unsqueeze(1)  # (batch, seq_len=1, hidden): treat each sample as a length-1 sequence
        x = self.transformer(x)
        x = x.squeeze(1)
        return self.output_linear(x)
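# Note: TransformerModel is not used by main() below; it takes the same
# (input_dim, hidden_dim, output_dim) arguments as TradingModel, so it could be
# swapped in as the Q-network if a transformer-based model is preferred.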
class TradingModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)
# -----------------
# Replay Buffer
# -----------------
class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def add(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        indices = np.random.choice(len(self.buffer), size=batch_size, replace=False)
        return [self.buffer[i] for i in indices]

    def __len__(self):
        return len(self.buffer)
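# Experiences are stored as (state, action, reward, next_state, done) tuples.
# sample() draws without replacement, so it should only be called once the buffer
# holds at least batch_size items (train_step below checks this).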
# -----------------
# RL Agent
# -----------------
class ContinuousRLAgent:
    def __init__(self, model, optimizer, replay_buffer, batch_size=32, gamma=0.99):
        self.model = model
        self.optimizer = optimizer
        self.replay_buffer = replay_buffer
        self.batch_size = batch_size
        self.loss_fn = nn.MSELoss()
        self.gamma = gamma

    def act(self, state, epsilon=0.1):
        # Epsilon-greedy selection over the three discrete actions (0 = sell, 1 = hold, 2 = buy)
        if np.random.rand() < epsilon:
            return np.random.randint(0, 3)
        state_tensor = torch.from_numpy(np.array(state, dtype=np.float32)).unsqueeze(0)
        with torch.no_grad():
            output = self.model(state_tensor)
        return torch.argmax(output, dim=1).item()

    def train_step(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        batch = self.replay_buffer.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states_tensor = torch.from_numpy(np.array(states, dtype=np.float32))
        actions_tensor = torch.tensor(actions, dtype=torch.int64).unsqueeze(1)
        rewards_tensor = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1)
        next_states_tensor = torch.from_numpy(np.array(next_states, dtype=np.float32))
        dones_tensor = torch.tensor(dones, dtype=torch.float32).unsqueeze(1)
        Q_values = self.model(states_tensor)
        current_Q = Q_values.gather(1, actions_tensor)
        with torch.no_grad():
            next_Q_values = self.model(next_states_tensor)
            max_next_Q = next_Q_values.max(1)[0].unsqueeze(1)
        target = rewards_tensor + self.gamma * max_next_Q * (1.0 - dones_tensor)
        loss = self.loss_fn(current_Q, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
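# train_step() performs a one-step DQN-style update with target
# r + gamma * max_a' Q(s', a') * (1 - done), using the same network for both the
# current and target Q-values (no separate target network).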
# -----------------
# Trading Environment
# -----------------
class TradingEnvironment:
    def __init__(self, candles_dict, base_tf="1m", timeframes=None):
        self.candles_dict = candles_dict
        self.base_tf = base_tf
        self.timeframes = timeframes if timeframes else [base_tf]
        self.current_index = 0
        self.position = None
        self.trade_history = []

    def reset(self):
        self.current_index = 0
        self.position = None
        return self.get_state()

    def get_state(self):
        # Concatenate the feature vectors of all timeframes, aligned on the base timeframe's timestamp.
        state_features = []
        base_ts = self.candles_dict[self.base_tf][self.current_index]["timestamp"]
        for tf in self.timeframes:
            candles = self.candles_dict[tf]
            aligned_idx, candle = get_aligned_candle_with_index(candles, base_ts)
            features = get_features_for_tf(candles, aligned_idx)
            state_features.extend(features)
        return np.array(state_features, dtype=np.float32)

    def step(self, action):
        done = self.current_index >= len(self.candles_dict[self.base_tf]) - 1
        if done:
            return self.get_state(), 0.0, None, True
        reward = 0.0
        current_candle = self.candles_dict[self.base_tf][self.current_index]
        next_candle = self.candles_dict[self.base_tf][self.current_index + 1]
        if self.position is None:
            if action == 2:  # Buy: open a long position at the next candle's open
                self.position = {"type": "long", "entry_price": next_candle["open"]}
        else:
            if action == 0:  # Sell: close the long position at the next candle's open
                exit_price = next_candle["open"]
                reward = exit_price - self.position["entry_price"]
                self.trade_history.append({
                    "entry_index": self.current_index,
                    "exit_index": self.current_index + 1,
                    "entry_price": self.position["entry_price"],
                    "exit_price": exit_price,
                    "pnl": reward
                })
                self.position = None
            # action == 1 (Hold) keeps the position open with zero reward
        self.current_index += 1
        next_state = self.get_state()
        done = self.current_index >= len(self.candles_dict[self.base_tf]) - 1
        return current_candle, reward, next_state, done
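# The environment models a single unit-sized long position, entered and exited at the
# next candle's open, with no fees or slippage; the reward is the raw price difference
# per closed trade. step() returns (current_candle, reward, next_state, done).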
# -----------------
# Fetching Data
# -----------------
async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
    candles = []
    since_ms = since
    while True:
        try:
            batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size)
        except Exception as e:
            print(f"Error fetching historical data for {timeframe}: {e}")
            break
        if not batch:
            break
        for c in batch:
            candle_dict = {
                'timestamp': c[0],
                'open': c[1],
                'high': c[2],
                'low': c[3],
                'close': c[4],
                'volume': c[5]
            }
            candles.append(candle_dict)
        last_timestamp = batch[-1][0]
        if last_timestamp >= end_time:
            break
        since_ms = last_timestamp + 1
    return candles
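# Note: the cache helpers above (load_candles_cache / save_candles_cache with CACHE_FILE)
# are defined but not called anywhere; fetched candles could be persisted between runs
# with, for example, save_candles_cache(CACHE_FILE, candles_dict) after fetching.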
# -----------------
# Training Loop
# -----------------
async def train_model(symbol, timeframes, model, optimizer, replay_buffer, num_epochs=10, epsilon=0.1):
    exchange = ccxt.mexc({
        'apiKey': os.environ.get('MEXC_API_KEY'),
        'secret': os.environ.get('MEXC_API_SECRET'),
        'enableRateLimit': True,
    })
    now = int(time.time() * 1000)
    period_ms = 1500 * 60 * 1000  # 1500 minutes of history
    since = now - period_ms
    end_time = now
    candles_dict = {}
    try:
        for tf in timeframes:
            print(f"Fetching historical data for timeframe {tf}...")
            candles_dict[tf] = await fetch_historical_data(exchange, symbol, tf, since, end_time)
    finally:
        await exchange.close()
    env = TradingEnvironment(candles_dict, base_tf=timeframes[0], timeframes=timeframes)
    agent = ContinuousRLAgent(model, optimizer, replay_buffer)
    for epoch in range(1, num_epochs + 1):
        state = env.reset()
        done = False
        total_reward = 0.0
        steps = 0
        while not done:
            action = agent.act(state, epsilon)
            # step() returns (current_candle, reward, next_state, done)
            _, reward, next_state, done = env.step(action)
            replay_buffer.add((state, action, reward, next_state, done))
            if len(replay_buffer) >= agent.batch_size:
                agent.train_step()
            total_reward += reward
            steps += 1
            state = next_state
        print(f"Epoch {epoch}/{num_epochs} completed, total reward: {total_reward:.4f} over {steps} steps.")
        save_checkpoint(model, epoch, total_reward)
# -----------------
# Main Function
# -----------------
async def main():
    symbol = 'BTC/USDT'
    timeframes = ["1m", "5m", "15m", "1h", "1d"]
    input_dim = len(timeframes) * 8  # 8 features per timeframe (OHLCV, SMA, volume SMA, RSI)
    hidden_dim = 128
    output_dim = 3  # Sell, Hold, Buy (actions 0, 1, 2)
    model = TradingModel(input_dim, hidden_dim, output_dim)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    replay_buffer = ReplayBuffer(capacity=10000)
    await train_model(symbol, timeframes, model, optimizer, replay_buffer)

if __name__ == "__main__":
    asyncio.run(main())