From 79c51c0d5df7a39459d3cfb776ee275295372163 Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Sun, 2 Feb 2025 01:06:56 +0200
Subject: [PATCH] wip

---
 crypto/brian/index.py | 256 +++++++++++++++++++++---------------------
 1 file changed, 127 insertions(+), 129 deletions(-)

diff --git a/crypto/brian/index.py b/crypto/brian/index.py
index 840e6bb..7691501 100644
--- a/crypto/brian/index.py
+++ b/crypto/brian/index.py
@@ -33,40 +33,35 @@ def load_candles_cache(filename):
     try:
         with open(filename, "r") as f:
             data = json.load(f)
-        print(f"Loaded {len(data)} candles from cache.")
+        print(f"Loaded cached data from {filename}.")
         return data
     except Exception as e:
         print("Error reading cache file:", e)
-    return []
+    return {}
 
-def save_candles_cache(filename, candles):
+def save_candles_cache(filename, candles_dict):
     try:
         with open(filename, "w") as f:
-            json.dump(candles, f)
+            json.dump(candles_dict, f)
     except Exception as e:
         print("Error saving cache file:", e)
 
 # -------------------------------------
-# Functions for handling checkpoints
+# Checkpoint Functions (same as before)
 # -------------------------------------
 def maintain_checkpoint_directory(directory, max_files=10):
-    """Keep only the most recent max_files in a given directory based on modification time."""
     files = os.listdir(directory)
     if len(files) > max_files:
         full_paths = [os.path.join(directory, f) for f in files]
         full_paths.sort(key=lambda x: os.path.getmtime(x))
-        # Remove the oldest files
        for f in full_paths[: len(files) - max_files]:
             os.remove(f)
 
 def get_best_models(directory):
-    """Return a list of (reward, filename) for files in the best folder.
-    Expected filename format: best_{reward:.4f}_epoch_{epoch}_{timestamp}.pt"""
     best_files = []
     for file in os.listdir(directory):
         parts = file.split("_")
         try:
-            # parts[1] should be the reward
             r = float(parts[1])
             best_files.append((r, file))
         except Exception:
@@ -74,7 +69,6 @@ def get_best_models(directory):
     return best_files
 
 def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
-    """Save the model state at each epoch to last_dir and, conditionally, to best_dir."""
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
     last_path = os.path.join(last_dir, last_filename)
@@ -83,7 +77,6 @@ def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
         "reward": reward,
         "model_state_dict": model.state_dict()
     }, last_path)
-    # Maintain only last 10 checkpoints
     maintain_checkpoint_directory(last_dir, max_files=10)
 
     best_models = get_best_models(best_dir)
@@ -107,7 +100,6 @@ def save_checkpoint(model, epoch, reward, last_dir=LAST_DIR, best_dir=BEST_DIR):
     print(f"Saved checkpoint for epoch {epoch} with reward {reward:.4f}")
 
 def load_best_checkpoint(model, best_dir=BEST_DIR):
-    """Load the best checkpoint (with highest reward) if available."""
     best_models = get_best_models(best_dir)
     if not best_models:
         return None
@@ -118,6 +110,41 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
     model.load_state_dict(checkpoint["model_state_dict"])
     return checkpoint
 
+# -------------------------------------
+# Technical Indicator Helper Functions
+# -------------------------------------
+def compute_sma(candles_list, index, period=10):
+    start = max(0, index - period + 1)
+    values = [candle["close"] for candle in candles_list[start:index+1]]
+    return sum(values) / len(values) if values else 0.0
+
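+# Illustrative sketch only (not part of the original change and unused below):
+# an exponential variant of the helper above; the name compute_ema and the
+# seeding choice are assumptions.
+def compute_ema(candles_list, index, period=10):
+    start = max(0, index - period + 1)
+    alpha = 2.0 / (period + 1)
+    ema = candles_list[start]["close"]  # seed with the oldest close in the window
+    for candle in candles_list[start + 1:index + 1]:
+        ema = alpha * candle["close"] + (1 - alpha) * ema
+    return ema
+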
+def compute_sma_volume(candles_list, index, period=10):
+    start = max(0, index - period + 1)
+    values = [candle["volume"] for candle in candles_list[start:index+1]]
+    return sum(values) / len(values) if values else 0.0
+
+def get_aligned_candle_with_index(candles_list, target_ts):
+    """Find the candle in the list whose timestamp is the largest that is <= target_ts."""
+    best_idx = 0
+    for i, candle in enumerate(candles_list):
+        if candle["timestamp"] <= target_ts:
+            best_idx = i
+        else:
+            break
+    return best_idx, candles_list[best_idx]
+
+def get_features_for_tf(candles_list, index, period=10):
+    """Return a vector of 7 features: open, high, low, close, volume, sma_close, sma_volume."""
+    candle = candles_list[index]
+    f_open = candle["open"]
+    f_high = candle["high"]
+    f_low = candle["low"]
+    f_close = candle["close"]
+    f_volume = candle["volume"]
+    sma_close = compute_sma(candles_list, index, period)
+    sma_volume = compute_sma_volume(candles_list, index, period)
+    return [f_open, f_high, f_low, f_close, f_volume, sma_close, sma_volume]
+
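+# Example layout (values hypothetical): get_features_for_tf(candles, 42) returns
+#   [open, high, low, close, volume, sma_close, sma_volume]  # 7 floats
+# so the multi-timeframe state built below concatenates 5 * 7 = 35 values.
+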
 # -------------------------------------
 # Neural Network Architecture Definition
 # -------------------------------------
@@ -152,26 +179,6 @@ class ReplayBuffer:
     def __len__(self):
         return len(self.buffer)
 
-# -------------------------------------
-# Indicator and Feature Preparation Function
-# -------------------------------------
-def compute_indicators(candle, additional_data):
-    """
-    Combine OHLCV candle data with extra indicator information.
-    Base features: open, high, low, close, volume.
-    Additional channels (e.g., simulated sentiment) are appended.
-    """
-    features = [
-        candle.get('open', 0.0),
-        candle.get('high', 0.0),
-        candle.get('low', 0.0),
-        candle.get('close', 0.0),
-        candle.get('volume', 0.0),
-    ]
-    for key, value in additional_data.items():
-        features.append(value)
-    return np.array(features, dtype=np.float32)
-
 # -------------------------------------
 # RL Agent with Q-Learning and Epsilon-Greedy Exploration
 # -------------------------------------
@@ -190,8 +197,7 @@ class ContinuousRLAgent:
         state_tensor = torch.from_numpy(np.array(state, dtype=np.float32)).unsqueeze(0)
         with torch.no_grad():
             output = self.model(state_tensor)
-        action = torch.argmax(output, dim=1).item()
-        return action
+        return torch.argmax(output, dim=1).item()
 
     def train_step(self):
         if len(self.replay_buffer) < self.batch_size:
@@ -217,20 +223,16 @@ class ContinuousRLAgent:
         self.optimizer.step()
 
 # -------------------------------------
-# Historical Data Fetching Functions
+# Historical Data Fetching Function (for a given timeframe)
 # -------------------------------------
 async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
-    """
-    Fetch historical OHLCV data for a given symbol and timeframe.
-    "since" and "end_time" are given in milliseconds.
-    """
     candles = []
     since_ms = since
     while True:
         try:
             batch = await exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since_ms, limit=batch_size)
         except Exception as e:
-            print("Error fetching historical data:", e)
+            print(f"Error fetching historical data for {timeframe}:", e)
             break
         if not batch:
             break
@@ -248,33 +250,23 @@ async def fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size=500):
         if last_timestamp >= end_time:
             break
         since_ms = last_timestamp + 1
-    print(f"Fetched {len(candles)} candles.")
+    print(f"Fetched {len(candles)} candles for timeframe {timeframe}.")
     return candles
 
-async def get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time, cache_file=CACHE_FILE, batch_size=500):
-    cached_candles = load_candles_cache(cache_file)
-    if cached_candles:
-        last_ts = cached_candles[-1]['timestamp']
-        if last_ts < end_time:
-            print("Fetching new candles to update cache...")
-            new_candles = await fetch_historical_data(exchange, symbol, timeframe, last_ts + 1, end_time, batch_size)
-            cached_candles.extend(new_candles)
-        else:
-            print("Cache covers the requested period.")
-        return cached_candles
-    else:
-        candles = await fetch_historical_data(exchange, symbol, timeframe, since, end_time, batch_size)
-        return candles
-
 # -------------------------------------
-# Backtest Environment with Trade History Recording
+# Backtest Environment with Multi-Timeframe State
 # -------------------------------------
 class BacktestEnvironment:
-    def __init__(self, candles):
-        self.candles = candles
-        self.current_index = 0
-        self.position = None  # Active position: dict with 'entry_price' and 'entry_index'
-        self.trade_history = []  # List of closed trades
+    def __init__(self, candles_dict, base_tf="1m", timeframes=None):
+        self.candles_dict = candles_dict  # dict of timeframe: candles_list
+        self.base_tf = base_tf
+        if timeframes is None:
+            self.timeframes = [base_tf]  # fallback to single timeframe
+        else:
+            self.timeframes = timeframes
+        self.trade_history = []  # record of closed trades
+        self.current_index = 0   # index on base_tf candles
+        self.position = None     # active position record
 
     def reset(self, clear_trade_history=True):
         self.current_index = 0
@@ -284,66 +276,69 @@ class BacktestEnvironment:
         return self.get_state(self.current_index)
 
     def get_state(self, index):
-        candle = self.candles[index]
-        sentiment = {
-            'sentiment_score': np.random.rand(),
-            'news_volume': np.random.rand(),
-            'social_engagement': np.random.rand()
-        }
-        return compute_indicators(candle, sentiment)
+        """Construct the state as the concatenated features of all timeframes.
+        For each timeframe, find the aligned candle for the base timeframe's timestamp."""
+        state_features = []
+        base_candle = self.candles_dict[self.base_tf][index]
+        base_ts = base_candle["timestamp"]
+        for tf in self.timeframes:
+            candles_list = self.candles_dict[tf]
+            # Get the candle from this timeframe that is closest to (and <=) base_ts.
+            aligned_index, _ = get_aligned_candle_with_index(candles_list, base_ts)
+            features = get_features_for_tf(candles_list, aligned_index, period=10)
+            state_features.extend(features)
+        return np.array(state_features, dtype=np.float32)
 
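+    # Note (assumed reading of the alignment): get_aligned_candle_with_index picks
+    # the most recent candle at or before the base timestamp, so higher timeframes
+    # never look ahead of the 1m candle being processed; with the five timeframes
+    # configured in main_backtest() the returned state is a 35-dim float32 vector.
+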
     def step(self, action):
         """
-        Simulate a trading step:
-          - If not in a position and action is BUY (2), record an entry at next candle's open.
-          - If in a position and action is SELL (0), record an exit at next candle's open and compute PnL.
+        Simulate a trading step based on the base timeframe.
+          - If not in a position and action is BUY (2), record entry at next candle's open.
+          - If in a position and action is SELL (0), record exit at next candle's open, computing PnL.
         Returns: (current_state, reward, next_state, done)
         """
-        if self.current_index >= len(self.candles) - 1:
+        base_candles = self.candles_dict[self.base_tf]
+        if self.current_index >= len(base_candles) - 1:
             return self.get_state(self.current_index), 0.0, None, True
 
         current_state = self.get_state(self.current_index)
         next_index = self.current_index + 1
         next_state = self.get_state(next_index)
-        current_candle = self.candles[self.current_index]
-        next_candle = self.candles[next_index]
+        current_candle = base_candles[self.current_index]
+        next_candle = base_candles[next_index]
         reward = 0.0
 
         # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
-        # If not in a position:
         if self.position is None:
             if action == 2:  # BUY signal: enter position at next candle's open.
-                entry_price = next_candle['open']
-                self.position = {'entry_price': entry_price, 'entry_index': self.current_index}
+                entry_price = next_candle["open"]
+                self.position = {"entry_price": entry_price, "entry_index": self.current_index}
         else:
-            if action == 0:  # SELL signal: exit position at next candle's open.
-                exit_price = next_candle['open']
-                reward = exit_price - self.position['entry_price']
+            if action == 0:  # SELL signal: close position at next candle's open.
+                exit_price = next_candle["open"]
+                reward = exit_price - self.position["entry_price"]
                 trade = {
-                    'entry_index': self.position['entry_index'],
-                    'entry_price': self.position['entry_price'],
-                    'exit_index': next_index,
-                    'exit_price': exit_price,
-                    'pnl': reward
+                    "entry_index": self.position["entry_index"],
+                    "entry_price": self.position["entry_price"],
+                    "exit_index": next_index,
+                    "exit_price": exit_price,
+                    "pnl": reward
                 }
                 self.trade_history.append(trade)
                 self.position = None
         self.current_index = next_index
-        done = (self.current_index >= len(self.candles) - 1)
+        done = (self.current_index >= len(base_candles) - 1)
         return current_state, reward, next_state, done
 
 # -------------------------------------
-# Plot Trading Chart with Buy/Sell Markers and PnL Annotations
+# Chart Plotting: Trade History & PnL
 # -------------------------------------
 def plot_trade_history(candles, trade_history):
-    # Extract close price series from candles.
     close_prices = [candle["close"] for candle in candles]
     x = list(range(len(close_prices)))
     plt.figure(figsize=(12, 6))
     plt.plot(x, close_prices, label="Close Price", color="black", linewidth=1)
 
-    # Plot markers only once (avoid duplicate labels)
     buy_plotted = False
     sell_plotted = False
     for trade in trade_history:
@@ -363,24 +358,17 @@ def plot_trade_history(candles, trade_history):
         else:
             plt.plot(exit_idx, exit_price, marker="v", color="red", markersize=10)
             plt.text(exit_idx, exit_price, f"{pnl:+.2f}", color="blue", fontsize=8)
-
-    plt.title("Trade History with PnL After Order Close")
-    plt.xlabel("Candle Index")
+    plt.title("Trade History with PnL")
+    plt.xlabel("Base Candle Index (1m)")
     plt.ylabel("Price")
     plt.legend()
     plt.grid(True)
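+    # Hypothetical addition: persist the chart as well, useful for headless runs
+    # where plt.show() displays nothing (the filename is an assumption).
+    plt.savefig("trade_history.png")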
     plt.show()
 
 # -------------------------------------
-# Training Loop Over Historical Data (Backtest)
+# Training Loop: Backtesting Trading Episodes
 # -------------------------------------
 def train_on_historical_data(env, rl_agent, num_epochs=10, epsilon=0.1):
-    """
-    For each epoch, run through the historical episode.
-    At each step, select an action (using ε-greedy), simulate a trade,
-    store the experience, and update the network.
-    After the epoch, log the total reward and save checkpoints.
-    """
     for epoch in range(1, num_epochs + 1):
         state = env.reset()  # clear trade history each epoch
         done = False
@@ -400,19 +388,19 @@ def train_on_historical_data(env, rl_agent, num_epochs=10, epsilon=0.1):
         save_checkpoint(rl_agent.model, epoch, total_reward, LAST_DIR, BEST_DIR)
 
 # -------------------------------------
-# Main Asynchronous Function for Backtest Training and Charting
+# Main Asynchronous Function for Training & Charting
 # -------------------------------------
 async def main_backtest():
-    # Define symbol, timeframe, and period.
     symbol = 'BTC/USDT'
-    timeframe = '1m'
+    # Define timeframes: we'll use 5 different ones.
+    timeframes = ["1m", "5m", "15m", "1h", "1d"]
     now = int(time.time() * 1000)
-    one_day_ms = 24 * 60 * 60 * 1000
-    # For example, fetch a 1-day period from 2 days ago until 1 day ago.
-    since = now - one_day_ms * 2
-    end_time = now - one_day_ms
+    # Use the base timeframe period of 1500 candles. For 1m, that is 1500 minutes.
+    period_ms = 1500 * 60 * 1000
+    since = now - period_ms
+    end_time = now
 
-    # Initialize exchange (using MEXC for example).
+    # Initialize exchange using MEXC (or your preferred exchange).
     mexc_api_key = os.environ.get('MEXC_API_KEY', 'YOUR_API_KEY')
     mexc_api_secret = os.environ.get('MEXC_API_SECRET', 'YOUR_SECRET_KEY')
     exchange = ccxt.mexc({
@@ -421,34 +409,36 @@ async def main_backtest():
         'enableRateLimit': True,
     })
 
-    print("Fetching historical data...")
-    candles = await get_cached_or_fetch_data(exchange, symbol, timeframe, since, end_time)
-    if not candles:
-        print("No historical data fetched.")
-        await exchange.close()
-        return
+    candles_dict = {}
+    for tf in timeframes:
+        print(f"Fetching historical data for timeframe {tf}...")
+        candles = await fetch_historical_data(exchange, symbol, tf, since, end_time, batch_size=500)
+        candles_dict[tf] = candles
 
-    save_candles_cache(CACHE_FILE, candles)
-    env = BacktestEnvironment(candles)
+    # Optionally, save the multi-timeframe cache.
+    save_candles_cache(CACHE_FILE, candles_dict)
 
-    # Model dimensions: 5 (OHLCV) + 3 (sentiment) = 8.
-    input_dim = 8
+    # Create the backtest environment using multi-timeframe data.
+    env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
+
+    # Neural Network dimensions: each timeframe produces 7 features.
+    input_dim = len(timeframes) * 7  # 7 features * 5 timeframes = 35.
     hidden_dim = 128
-    output_dim = 3  # SELL, HOLD, BUY.
+    output_dim = 3  # Actions: SELL, HOLD, BUY.
 
     model = TradingModel(input_dim, hidden_dim, output_dim)
     optimizer = optim.Adam(model.parameters(), lr=1e-4)
     replay_buffer = ReplayBuffer(capacity=10000)
     rl_agent = ContinuousRLAgent(model, optimizer, replay_buffer, batch_size=32, gamma=0.99)
 
-    # At training start, try loading the best checkpoint if available.
+    # Load best checkpoint if available.
     load_best_checkpoint(model, BEST_DIR)
 
-    # Run training (backtesting) over historical data.
-    num_epochs = 10  # adjust as needed.
+    # Train the agent over the historical period.
+    num_epochs = 10  # Adjust as needed.
     train_on_historical_data(env, rl_agent, num_epochs=num_epochs, epsilon=0.1)
 
-    # Final simulation (without exploration) to log trade history.
+    # Run a final simulation (without exploration) to record trade history.
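+    # (Assumption: the unchanged loop below selects actions greedily, e.g.
+    # rl_agent.act(state, epsilon=0.0), so no exploration noise enters the record.)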
     state = env.reset(clear_trade_history=True)
     done = False
     cumulative_reward = 0.0
@@ -457,10 +447,18 @@
         state, reward, next_state, done = env.step(action)
         cumulative_reward += reward
         state = next_state
-    print("Final backtest simulation cumulative profit:", cumulative_reward)
-
-    # Draw the chart: plot close price with BUY/SELL markers and PnL annotations.
-    plot_trade_history(candles, env.trade_history)
+    print("Final simulation cumulative profit:", cumulative_reward)
+
+    # Evaluate trade performance.
+    trades = env.trade_history
+    num_trades = len(trades)
+    num_wins = sum(1 for trade in trades if trade["pnl"] > 0)
+    win_rate = (num_wins / num_trades * 100) if num_trades > 0 else 0.0
+    total_profit = sum(trade["pnl"] for trade in trades)
+    print(f"Total trades: {num_trades}, Wins: {num_wins}, Win rate: {win_rate:.2f}%, Total Profit: {total_profit:.4f}")
+
+    # Plot chart with buy/sell markers on the base timeframe ("1m").
+    plot_trade_history(candles_dict["1m"], trades)
 
     await exchange.close()
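+    # Worked example of the stats above (hypothetical PnLs): trades of
+    # [+5.0, -2.0, +1.0] give num_trades=3, num_wins=2, win_rate=66.67%,
+    # total_profit=+4.0.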