graph, ignores

parent 08f26785ea
commit c7c7acdb26

.gitignore (vendored): 2 additions
@@ -29,3 +29,5 @@ crypto/sol/logs/transation_details.json
 .env
 app_data.db
 crypto/sol/.vs/*
+crypto/brian/models/best/*
+crypto/brian/models/last/*
@@ -15,6 +15,7 @@ import torch.optim as optim
 import numpy as np
 from collections import deque
 from datetime import datetime
+import matplotlib.pyplot as plt
 
 # --- Directories for saving models ---
 LAST_DIR = os.path.join("models", "last")
@@ -60,12 +61,12 @@ def maintain_checkpoint_directory(directory, max_files=10):
 
 def get_best_models(directory):
     """Return a list of (reward, filename) for files in the best folder.
-    Expecting filenames like: best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"""
+    Expected filename format: best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"""
     best_files = []
     for file in os.listdir(directory):
         parts = file.split("_")
         try:
-            # parts[1] should be reward
+            # parts[1] should be the reward
             r = float(parts[1])
             best_files.append((r, file))
         except Exception:
@@ -73,21 +74,18 @@ def get_best_models(directory):
     return best_files
 
 def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
-    """Save the model state always to the last_dir and conditionally to best_dir if reward is high enough."""
+    """Save the model state at each epoch to last_dir and, conditionally, to best_dir."""
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    # last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
-    last_filename = f"model_last_epoch_{epoch}.pt"
+    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
     last_path = os.path.join(last_dir, last_filename)
     torch.save({
         "epoch": epoch,
         "reward": reward,
         "model_state_dict": model.state_dict()
     }, last_path)
-    # Keep only last 10 models in last_dir.
+    # Maintain only last 10 checkpoints
     maintain_checkpoint_directory(last_dir, max_files=10)
 
-    # Check the best folder – if fewer than 10, simply add;
-    # Otherwise, add only if reward is higher than the lowest reward in best.
     best_models = get_best_models(best_dir)
     add_to_best = False
     if len(best_models) < 10:
@@ -96,11 +94,9 @@ def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
         min_reward, min_file = min(best_models, key=lambda x: x[0])
         if reward > min_reward:
             add_to_best = True
-            # Remove the worst checkpoint.
             os.remove(os.path.join(best_dir, min_file))
     if add_to_best:
-        # best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
-        best_filename = f"best_epoch_{epoch}.pt"
+        best_filename = f"best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"
         best_path = os.path.join(best_dir, best_filename)
         torch.save({
             "epoch": epoch,
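
Note on the restored filename format: the reward is embedded as the second underscore-separated field, which is exactly what get_best_models() parses with file.split("_"). A minimal standalone check (illustrative only; the example filename is made up):

fname = "best_12.3456_epoch_7_20250101_093000.pt"
parts = fname.split("_")      # ["best", "12.3456", "epoch", "7", "20250101", "093000.pt"]
reward = float(parts[1])      # parts[1] is the reward, as the updated comment notes
print(f"parsed reward: {reward:.4f}")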
@@ -111,7 +107,7 @@ def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
     print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
 
 def load_best_checkpoint(model, best_dir=BEST_DIR):
-    """Load the best checkpoint (with highest reward) from the best directory if available."""
+    """Load the best checkpoint (with highest reward) if available."""
     best_models = get_best_models(best_dir)
     if not best_models:
         return None
@@ -157,7 +153,7 @@ class ReplayBuffer:
         return len(self.buffer)
 
 # -------------------------------------
-# A Simple Indicator and Feature Preparation Function
+# Indicator and Feature Preparation Function
 # -------------------------------------
 def compute_indicators(candle, additional_data):
     """
@@ -177,7 +173,7 @@ def compute_indicators(candle, additional_data):
     return np.array(features, dtype=np.float32)
 
 # -------------------------------------
-# RL Agent with Q-Learning Update and Epsilon-Greedy Exploration
+# RL Agent with Q-Learning and Epsilon-Greedy Exploration
 # -------------------------------------
 class ContinuousRLAgent:
     def __init__(self, model, optimizer, replay_buffer, batch_size=32, gamma=0.99):
@@ -189,7 +185,6 @@ class ContinuousRLAgent:
         self.gamma = gamma
 
     def act(self, state, epsilon=0.1):
-        # ε-greedy: choose random action with probability epsilon.
         if np.random.rand() < epsilon:
             return np.random.randint(0, 3)
         state_tensor = torch.from_numpy(np.array(state, dtype=np.float32)).unsqueeze(0)
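
For context, act() is a standard ε-greedy policy over the three discrete actions (0 = SELL, 1 = HOLD, 2 = BUY, per the mapping used in step()). A minimal sketch of the same selection rule, assuming q_values is the model output for a single state (illustrative only, not the agent's actual code path):

import numpy as np

def epsilon_greedy(q_values, epsilon=0.1):
    # Explore with probability epsilon: pick one of the 3 actions at random.
    if np.random.rand() < epsilon:
        return int(np.random.randint(0, 3))
    # Otherwise exploit: pick the action with the highest Q-value.
    return int(np.argmax(q_values))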
@@ -199,12 +194,10 @@ class ContinuousRLAgent:
         return action
 
     def train_step(self):
-        # Only train if we have enough samples.
         if len(self.replay_buffer) < self.batch_size:
             return
 
         batch = self.replay_buffer.sample(self.batch_size)
-        # Unpack the batch; each experience is (state, action, reward, next_state, done)
         states, actions, rewards, next_states, dones = zip(*batch)
         states_tensor = torch.from_numpy(np.array(states, dtype=np.float32))
         actions_tensor = torch.tensor(actions, dtype=torch.int64)
@@ -224,12 +217,12 @@ class ContinuousRLAgent:
         self.optimizer.step()
 
 # -------------------------------------
-# Historical Data Fetching Function
+# Historical Data Fetching Functions
 # -------------------------------------
 async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
     """
-    Fetch historical OHLCV data for the given symbol and timeframe.
-    "since" and "end_time" are in milliseconds.
+    Fetch historical OHLCV data for a given symbol and timeframe.
+    "since" and "end_time" are given in milliseconds.
     """
     candles = []
     since_ms = since
@@ -274,17 +267,20 @@ async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time,
     return candles
 
 # -------------------------------------
-# Backtest Environment Class Definition
+# Backtest Environment with Trade History Recording
 # -------------------------------------
 class BacktestEnvironment:
     def __init__(self, candles):
         self.candles = candles
         self.current_index = 0
-        self.position = None # Holds an open position, if any
+        self.position = None # Active position: dict with 'entry_price' and 'entry_index'
+        self.trade_history = [] # List of closed trades
 
-    def reset(self):
+    def reset(self, clear_trade_history=True):
         self.current_index = 0
         self.position = None
+        if clear_trade_history:
+            self.trade_history = []
         return self.get_state(self.current_index)
 
     def get_state(self, index):
@@ -298,9 +294,9 @@ class BacktestEnvironment:
 
     def step(self, action):
         """
-        Simulate a trading step.
-        - If not in a position and action is BUY (2), enter a long position at the next candle's open.
-        - If in a position and action is SELL (0), close the position at the next candle's open.
+        Simulate a trading step:
+        - If not in a position and action is BUY (2), record an entry at next candle's open.
+        - If in a position and action is SELL (0), record an exit at next candle's open and compute PnL.
         Returns: (current_state, reward, next_state, done)
         """
         if self.current_index >= len(self.candles) - 1:
@@ -314,31 +310,79 @@ class BacktestEnvironment:
         reward = 0.0
 
         # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
         # If not in a position:
         if self.position is None:
-            if action == 2: # BUY signal:
+            if action == 2: # BUY signal: enter position at next candle's open.
                 entry_price = next_candle['open']
                 self.position = {'entry_price': entry_price, 'entry_index': self.current_index}
         else:
-            if action == 0: # SELL signal:
-                sell_price = next_candle['open']
-                reward = sell_price - self.position['entry_price']
+            if action == 0: # SELL signal: exit position at next candle's open.
+                exit_price = next_candle['open']
+                reward = exit_price - self.position['entry_price']
+                trade = {
+                    'entry_index': self.position['entry_index'],
+                    'entry_price': self.position['entry_price'],
+                    'exit_index': next_index,
+                    'exit_price': exit_price,
+                    'pnl': reward
+                }
+                self.trade_history.append(trade)
                 self.position = None
 
         self.current_index = next_index
         done = (self.current_index >= len(self.candles) - 1)
         return current_state, reward, next_state, done
 
 # -------------------------------------
+# Plot Trading Chart with Buy/Sell Markers and PnL Annotations
+# -------------------------------------
+def plot_trade_history(candles, trade_history):
+    # Extract close price series from candles.
+    close_prices = [candle['close'] for candle in candles]
+    x = list(range(len(close_prices)))
+    plt.figure(figsize=(12, 6))
+    plt.plot(x, close_prices, label="Close Price", color="black", linewidth=1)
+
+    # Plot markers only once (avoid duplicate labels)
+    buy_plotted = False
+    sell_plotted = False
+    for trade in trade_history:
+        entry_idx = trade["entry_index"]
+        exit_idx = trade["exit_index"]
+        entry_price = trade["entry_price"]
+        exit_price = trade["exit_price"]
+        pnl = trade["pnl"]
+        if not buy_plotted:
+            plt.plot(entry_idx, entry_price, marker="^", color="green", markersize=10, label="BUY")
+            buy_plotted = True
+        else:
+            plt.plot(entry_idx, entry_price, marker="^", color="green", markersize=10)
+        if not sell_plotted:
+            plt.plot(exit_idx, exit_price, marker="v", color="red", markersize=10, label="SELL")
+            sell_plotted = True
+        else:
+            plt.plot(exit_idx, exit_price, marker="v", color="red", markersize=10)
+        plt.text(exit_idx, exit_price, f"{pnl:+.2f}", color="blue", fontsize=8)
+
+    plt.title("Trade History with PnL After Order Close")
+    plt.xlabel("Candle Index")
+    plt.ylabel("Price")
+    plt.legend()
+    plt.grid(True)
+    plt.show()
+
+# -------------------------------------
 # Training Loop Over Historical Data (Backtest)
 # -------------------------------------
 def train_on_historical_data(env, rl_agent, num_epochs=10, epsilon=0.1):
     """
-    For each epoch, run through the entire historical episode.
-    At each step, pick an action (using ε-greedy), simulate a trade, store the experience,
-    and update the model. Then log the cumulative reward and save checkpoints.
+    For each epoch, run through the historical episode.
+    At each step, select an action (using ε-greedy), simulate a trade,
+    store the experience, and update the network.
+    After the epoch, log the total reward and save checkpoints.
     """
     for epoch in range(1, num_epochs + 1):
-        state = env.reset()
+        state = env.reset()  # clear trade history each epoch
         done = False
         total_reward = 0.0
         steps = 0
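
Each closed trade is stored with the keys shown above ('entry_index', 'entry_price', 'exit_index', 'exit_price', 'pnl'), so the cumulative profit printed at the end of the backtest can also be recovered directly from env.trade_history. A small illustrative aggregation (the trade values are invented):

trades = [
    {'entry_index': 3, 'entry_price': 100.0, 'exit_index': 7, 'exit_price': 103.5, 'pnl': 3.5},
    {'entry_index': 9, 'entry_price': 102.0, 'exit_index': 12, 'exit_price': 101.0, 'pnl': -1.0},
]
total_pnl = sum(t['pnl'] for t in trades)
print(f"closed trades: {len(trades)}, total PnL: {total_pnl:+.2f}")  # -> closed trades: 2, total PnL: +2.50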
@@ -348,20 +392,18 @@ def train_on_historical_data(env, rl_agent, num_epochs=10, epsilon=0.1):
             state, reward, next_state, done = env.step(action)
             if next_state is None:
                 next_state = np.zeros_like(prev_state)
-            # Save the experience (state, action, reward, next_state, done)
             rl_agent.replay_buffer.add((prev_state, action, reward, next_state, done))
             rl_agent.train_step()
             total_reward += reward
             steps += 1
         print(f"Epoch {epoch}/{num_epochs} completed, total reward: {total_reward:.4f} over {steps} steps.")
-        # Save a checkpoint after the epoch.
         save_checkpoint(rl_agent.model, epoch, total_reward, LAST_DIR, BEST_DIR)
 
 # -------------------------------------
-# Main Asynchronous Function for Backtest Training
+# Main Asynchronous Function for Backtest Training and Charting
 # -------------------------------------
 async def main_backtest():
-    # Define symbol, timeframe, and historical period.
+    # Define symbol, timeframe, and period.
     symbol = 'BTC/USDT'
     timeframe = '1m'
     now = int(time.time() * 1000)
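
fetch_historical_data() expects "since" and "end_time" in milliseconds, and main_backtest() derives them from int(time.time() * 1000). A hypothetical helper (not part of the commit) that builds such a window:

import time

def last_n_days_ms(n_days):
    # The window ends "now" (in ms) and starts n_days earlier.
    end_time = int(time.time() * 1000)
    since = end_time - n_days * 24 * 60 * 60 * 1000
    return since, end_time

since, end_time = last_n_days_ms(1)  # e.g. the most recent 24 hours of 1m candles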
@@ -386,10 +428,7 @@ async def main_backtest():
         await exchange.close()
         return
 
-    # Save updated cache.
     save_candles_cache(CACHE_FILE, candles)
-
-    # Initialize backtest environment.
     env = BacktestEnvironment(candles)
 
     # Model dimensions: 5 (OHLCV) + 3 (sentiment) = 8.
@@ -402,15 +441,15 @@ async def main_backtest():
     replay_buffer = ReplayBuffer(capacity=10000)
     rl_agent = ContinuousRLAgent(model, optimizer, replay_buffer, batch_size=32, gamma=0.99)
 
-    # At training start, try loading a best checkpoint (if available).
+    # At training start, try loading the best checkpoint if available.
     load_best_checkpoint(model, BEST_DIR)
 
-    # Run training over historical data.
-    num_epochs = 10  # Change as needed.
+    # Run training (backtesting) over historical data.
+    num_epochs = 10  # adjust as needed.
     train_on_historical_data(env, rl_agent, num_epochs=num_epochs, epsilon=0.1)
 
-    # Final simulation (without exploration) to check cumulative profit.
-    state = env.reset()
+    # Final simulation (without exploration) to log trade history.
+    state = env.reset(clear_trade_history=True)
     done = False
     cumulative_reward = 0.0
     while not done:
@@ -420,6 +459,9 @@ async def main_backtest():
         state = next_state
     print("Final backtest simulation cumulative profit:", cumulative_reward)
 
+    # Draw the chart: plot close price with BUY/SELL markers and PnL annotations.
+    plot_trade_history(candles, env.trade_history)
+
     await exchange.close()
 
 if __name__ == "__main__":
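
plot_trade_history() ends with plt.show(), which needs an interactive matplotlib backend. When the backtest runs on a headless machine, saving the chart to a file is the usual alternative; a minimal sketch, assuming the figure is built exactly as above (the output path is arbitrary):

import matplotlib
matplotlib.use("Agg")                      # non-interactive backend, set before any figure is drawn
import matplotlib.pyplot as plt

# ... same plotting calls as in plot_trade_history() ...
plt.savefig("trade_history.png", dpi=150)  # write the chart to disk instead of plt.show()
plt.close()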