From 4de63524686ce1481442b5bca818e3038cb889b7 Mon Sep 17 00:00:00 2001 From: Dobromir Popov Date: Mon, 17 Mar 2025 16:18:11 +0200 Subject: [PATCH] added CNN module --- crypto/gogo2/.vscode/launch.json | 2 +- crypto/gogo2/main.py | 2745 ++++++++++++++++-------------- 2 files changed, 1477 insertions(+), 1270 deletions(-) diff --git a/crypto/gogo2/.vscode/launch.json b/crypto/gogo2/.vscode/launch.json index fd260bd..165d15a 100644 --- a/crypto/gogo2/.vscode/launch.json +++ b/crypto/gogo2/.vscode/launch.json @@ -6,7 +6,7 @@ "type": "python", "request": "launch", "program": "main.py", - "args": ["--mode", "train", "--episodes", "1000"], + "args": ["--mode", "train", "--episodes", "10"], "console": "integratedTerminal", "justMyCode": true }, diff --git a/crypto/gogo2/main.py b/crypto/gogo2/main.py index b5c40fc..354e5dd 100644 --- a/crypto/gogo2/main.py +++ b/crypto/gogo2/main.py @@ -30,6 +30,7 @@ import matplotlib.pyplot as mpf import matplotlib.gridspec as gridspec import datetime from datetime import datetime as dt +from collections import defaultdict # Configure logging logging.basicConfig( @@ -64,25 +65,73 @@ Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state' # Add this function near the top of the file, after the imports but before any classes def find_local_extrema(prices, window=5): - """Find local minima (bottoms) and maxima (tops) in price data""" - bottoms = [] + """ + Find local extrema (tops and bottoms) in price series. + + Args: + prices: Array of price values + window: Window size for finding extrema + + Returns: + Tuple of (tops, bottoms) indices + """ tops = [] + bottoms = [] if len(prices) < window * 2 + 1: - return bottoms, tops + return tops, bottoms - for i in range(window, len(prices) - window): - # Check if this is a local minimum (bottom) - if all(prices[i] <= prices[i-j] for j in range(1, window+1)) and \ - all(prices[i] <= prices[i+j] for j in range(1, window+1)): - bottoms.append(i) + try: + # Use peak detection algorithms from scipy if available + from scipy.signal import find_peaks - # Check if this is a local maximum (top) - if all(prices[i] >= prices[i-j] for j in range(1, window+1)) and \ - all(prices[i] >= prices[i+j] for j in range(1, window+1)): - tops.append(i) + # Find peaks (tops) + peaks, _ = find_peaks(prices, distance=window) + tops = list(peaks) + + # Find valleys (bottoms) by inverting the prices + valleys, _ = find_peaks(-prices, distance=window) + bottoms = list(valleys) + + # Optional: Filter extrema for significance + if len(tops) > 0 and len(bottoms) > 0: + # Calculate average price move + avg_move = np.mean(np.abs(np.diff(prices))) + + # Filter tops and bottoms for significant moves + filtered_tops = [] + for top in tops: + # Check if this top is significantly higher than surrounding points + if top > window and top < len(prices) - window: + surrounding_min = min(prices[top-window:top+window]) + if prices[top] - surrounding_min > avg_move * 1.5: # 1.5x average move + filtered_tops.append(top) + + filtered_bottoms = [] + for bottom in bottoms: + # Check if this bottom is significantly lower than surrounding points + if bottom > window and bottom < len(prices) - window: + surrounding_max = max(prices[bottom-window:bottom+window]) + if surrounding_max - prices[bottom] > avg_move * 1.5: # 1.5x average move + filtered_bottoms.append(bottom) + + tops = filtered_tops + bottoms = filtered_bottoms - return bottoms, tops + except ImportError: + # Fallback to manual detection if scipy is not available + for i in range(window, len(prices) - window): + # Check if this point is a local maximum + if all(prices[i] >= prices[i - j] for j in range(1, window + 1)) and \ + all(prices[i] >= prices[i + j] for j in range(1, window + 1)): + tops.append(i) + + # Check if this point is a local minimum + if all(prices[i] <= prices[i - j] for j in range(1, window + 1)) and \ + all(prices[i] <= prices[i + j] for j in range(1, window + 1)): + bottoms.append(i) + + return tops, bottoms class ReplayMemory: def __init__(self, capacity): @@ -98,88 +147,28 @@ class ReplayMemory: return len(self.memory) class DQN(nn.Module): + """ + Wrapper class that uses LSTMAttentionDQN as the network architecture. + This maintains backward compatibility with any code expecting the DQN class. + """ def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4): super(DQN, self).__init__() + # Directly use LSTMAttentionDQN as the internal network + self.network = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads) + # Store network parameters for access self.state_size = state_size + self.action_size = action_size self.hidden_size = hidden_size self.lstm_layers = lstm_layers - - # Initial feature extraction - self.fc1 = nn.Linear(state_size, hidden_size) - # Use LayerNorm instead of BatchNorm for more stability with varying batch sizes - self.ln1 = nn.LayerNorm(hidden_size) - self.dropout1 = nn.Dropout(0.2) - - # LSTM layer for sequential data - self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=lstm_layers, batch_first=True, dropout=0.2) - - # Attention mechanism - self.attention = nn.MultiheadAttention(hidden_size, attention_heads) - - # Output layers with increased capacity - self.fc2 = nn.Linear(hidden_size, hidden_size) - self.ln2 = nn.LayerNorm(hidden_size) # LayerNorm instead of BatchNorm - self.dropout2 = nn.Dropout(0.2) - self.fc3 = nn.Linear(hidden_size, hidden_size // 2) - - # Dueling DQN architecture - self.value_stream = nn.Linear(hidden_size // 2, 1) - self.advantage_stream = nn.Linear(hidden_size // 2, action_size) - - # Transformer encoder for more complex pattern recognition - encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1) - self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2) - - def forward(self, x): - batch_size = x.size(0) if x.dim() > 1 else 1 - - # Ensure input has correct shape - if x.dim() == 1: - x = x.unsqueeze(0) # Add batch dimension - - # Check if state size matches expected input size - if x.size(1) != self.state_size: - # Handle mismatched input by either truncating or padding - if x.size(1) > self.state_size: - x = x[:, :self.state_size] # Truncate - else: - # Pad with zeros - padding = torch.zeros(batch_size, self.state_size - x.size(1), device=x.device) - x = torch.cat([x, padding], dim=1) - - # Initial feature extraction - x = self.fc1(x) - x = F.relu(self.ln1(x)) # LayerNorm works with any batch size - x = self.dropout1(x) - - # Reshape for LSTM - x_lstm = x.unsqueeze(1) if x.dim() == 2 else x - - # Process through LSTM - lstm_out, _ = self.lstm(x_lstm) - lstm_out = lstm_out.squeeze(1) if lstm_out.size(1) == 1 else lstm_out[:, -1] - - # Process through transformer for more complex patterns - transformer_input = x.unsqueeze(1) if x.dim() == 2 else x - transformer_out = self.transformer_encoder(transformer_input.transpose(0, 1)) - transformer_out = transformer_out.transpose(0, 1).mean(dim=1) - - # Combine LSTM and transformer outputs - x = lstm_out + transformer_out - - # Final layers - x = self.fc2(x) - x = F.relu(self.ln2(x)) # LayerNorm works with any batch size - x = self.dropout2(x) - x = F.relu(self.fc3(x)) - - # Dueling architecture - value = self.value_stream(x) - advantages = self.advantage_stream(x) - qvals = value + (advantages - advantages.mean(dim=1, keepdim=True)) - - return qvals + self.attention_heads = attention_heads + + def forward(self, state, x_1s=None, x_1m=None, x_1h=None, x_1d=None): + # Pass through to LSTMAttentionDQN + if x_1s is not None and x_1m is not None and x_1h is not None and x_1d is not None: + return self.network(state, x_1s, x_1m, x_1h, x_1d) + else: + return self.network(state) class PricePredictionModel(nn.Module): def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2): @@ -337,6 +326,9 @@ class TradingEnvironment: self.position_mode = "hedge" # For simultaneous long/short positions self.margin_mode = "cross" # Cross margin mode + # Initialize data format indicator (list or dict) + self.data_format_is_list = True + def reset(self): """Reset the environment to initial state""" self.balance = self.initial_balance @@ -357,7 +349,10 @@ class TradingEnvironment: # Keep data but reset current position if len(self.data) > self.window_size: self.current_step = self.window_size - self.current_price = self.data[self.current_step]['close'] + if self.data_format_is_list: + self.current_price = self.data[self.current_step][4] # Close price is at index 4 + else: + self.current_price = self.data[self.current_step]['close'] # Reset trade signals self.trade_signals = [] @@ -366,9 +361,17 @@ class TradingEnvironment: def add_data(self, candle): """Add a new candle to the data""" - self.data.append(candle) + # Check if candle is a list or dictionary + if isinstance(candle, list): + self.data_format_is_list = True + self.data.append(candle) + self.current_price = candle[4] # Close price is at index 4 + else: + self.data_format_is_list = False + self.data.append(candle) + self.current_price = candle['close'] + self._update_features() - self.current_price = candle['close'] def _initialize_features(self): """Initialize technical indicators and features""" @@ -376,7 +379,12 @@ class TradingEnvironment: return # Convert data to pandas DataFrame for easier calculation - df = pd.DataFrame(self.data) + if self.data_format_is_list: + # Convert list format to DataFrame + df = pd.DataFrame(self.data, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']) + else: + # Dictionary format + df = pd.DataFrame(self.data) # Basic price and volume self.features['price'] = df['close'].values @@ -464,8 +472,22 @@ class TradingEnvironment: } return next_state, 0, done, info + # Circuit breaker after consecutive losses + if self.count_consecutive_losses() >= 5: + logger.warning("Circuit breaker triggered after 5 consecutive losses") + return self.get_state(), -1, True, {'action': 'circuit_breaker_triggered'} + + # Reduce leverage in volatile markets + if self.is_volatile_market(): + self.leverage = MAX_LEVERAGE * 0.5 # Half leverage in volatile markets + else: + self.leverage = MAX_LEVERAGE + # Store current price before taking action - self.current_price = self.data[self.current_step]['close'] + if self.data_format_is_list: + self.current_price = self.data[self.current_step][4] # Close price + else: + self.current_price = self.data[self.current_step]['close'] # Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE) reward = self.calculate_reward(action) @@ -484,8 +506,13 @@ class TradingEnvironment: signal_type = 'close_short' if signal_type: + if self.data_format_is_list: + timestamp = self.data[self.current_step][0] # Timestamp + else: + timestamp = self.data[self.current_step]['timestamp'] + self.trade_signals.append({ - 'timestamp': self.data[self.current_step]['timestamp'], + 'timestamp': timestamp, 'price': self.current_price, 'type': signal_type, 'balance': self.balance, @@ -514,11 +541,15 @@ class TradingEnvironment: return next_state, reward, done, info def check_sl_tp(self): - """Check if stop loss or take profit has been hit""" + """Check if stop loss or take profit has been hit with improved trailing stop""" if self.position == 'flat': return if self.position == 'long': + # Implement trailing stop loss if in profit + if self.current_price > self.entry_price * 1.01: + self.stop_loss = max(self.stop_loss, self.current_price * 0.995) # Trail at 0.5% below current price + # Check stop loss if self.current_price <= self.stop_loss: # Stop loss hit @@ -531,45 +562,32 @@ class TradingEnvironment: # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar - self.episode_pnl += pnl_dollar - - # Update max drawdown - if self.balance > self.peak_balance: - self.peak_balance = self.balance - drawdown = (self.peak_balance - self.balance) / self.peak_balance - self.max_drawdown = max(self.max_drawdown, drawdown) # Record trade + trade_duration = self.current_step - self.entry_index self.trades.append({ 'type': 'long', 'entry': self.entry_price, 'exit': self.stop_loss, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, - 'duration': self.current_step - self.entry_index, - 'market_direction': self.get_market_direction(), + 'duration': trade_duration, + 'timestamp': self.data[self.current_step]['timestamp'], 'reason': 'stop_loss' }) - # Update win/loss count - self.loss_count += 1 + if pnl_dollar > 0: + self.win_count += 1 + else: + self.loss_count += 1 logger.info(f"STOP LOSS hit for long at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") - # Record signal for visualization - self.trade_signals.append({ - 'timestamp': self.data[self.current_step]['timestamp'], - 'price': self.stop_loss, - 'type': 'stop_loss_long', - 'balance': self.balance, - 'pnl': self.total_pnl - }) - # Reset position self.position = 'flat' + self.position_size = 0 self.entry_price = 0 self.entry_index = 0 - self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 @@ -585,47 +603,37 @@ class TradingEnvironment: # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar - self.episode_pnl += pnl_dollar - - # Update max drawdown - if self.balance > self.peak_balance: - self.peak_balance = self.balance # Record trade + trade_duration = self.current_step - self.entry_index self.trades.append({ 'type': 'long', 'entry': self.entry_price, 'exit': self.take_profit, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, - 'duration': self.current_step - self.entry_index, - 'market_direction': self.get_market_direction(), + 'duration': trade_duration, + 'timestamp': self.data[self.current_step]['timestamp'], 'reason': 'take_profit' }) - # Update win/loss count self.win_count += 1 logger.info(f"TAKE PROFIT hit for long at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") - # Record signal for visualization - self.trade_signals.append({ - 'timestamp': self.data[self.current_step]['timestamp'], - 'price': self.take_profit, - 'type': 'take_profit_long', - 'balance': self.balance, - 'pnl': self.total_pnl - }) - # Reset position self.position = 'flat' + self.position_size = 0 self.entry_price = 0 self.entry_index = 0 - self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 elif self.position == 'short': + # Implement trailing stop loss if in profit + if self.current_price < self.entry_price * 0.99: + self.stop_loss = min(self.stop_loss, self.current_price * 1.005) # Trail at 0.5% above current price + # Check stop loss if self.current_price >= self.stop_loss: # Stop loss hit @@ -638,45 +646,32 @@ class TradingEnvironment: # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar - self.episode_pnl += pnl_dollar - - # Update max drawdown - if self.balance > self.peak_balance: - self.peak_balance = self.balance - drawdown = (self.peak_balance - self.balance) / self.peak_balance - self.max_drawdown = max(self.max_drawdown, drawdown) # Record trade + trade_duration = self.current_step - self.entry_index self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.stop_loss, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, - 'duration': self.current_step - self.entry_index, - 'market_direction': self.get_market_direction(), + 'duration': trade_duration, + 'timestamp': self.data[self.current_step]['timestamp'], 'reason': 'stop_loss' }) - # Update win/loss count - self.loss_count += 1 + if pnl_dollar > 0: + self.win_count += 1 + else: + self.loss_count += 1 logger.info(f"STOP LOSS hit for short at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") - # Record signal for visualization - self.trade_signals.append({ - 'timestamp': self.data[self.current_step]['timestamp'], - 'price': self.stop_loss, - 'type': 'stop_loss_short', - 'balance': self.balance, - 'pnl': self.total_pnl - }) - # Reset position self.position = 'flat' + self.position_size = 0 self.entry_price = 0 self.entry_index = 0 - self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 @@ -692,43 +687,29 @@ class TradingEnvironment: # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar - self.episode_pnl += pnl_dollar - - # Update max drawdown - if self.balance > self.peak_balance: - self.peak_balance = self.balance # Record trade + trade_duration = self.current_step - self.entry_index self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.take_profit, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, - 'duration': self.current_step - self.entry_index, - 'market_direction': self.get_market_direction(), + 'duration': trade_duration, + 'timestamp': self.data[self.current_step]['timestamp'], 'reason': 'take_profit' }) - # Update win/loss count self.win_count += 1 logger.info(f"TAKE PROFIT hit for short at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") - # Record signal for visualization - self.trade_signals.append({ - 'timestamp': self.data[self.current_step]['timestamp'], - 'price': self.take_profit, - 'type': 'take_profit_short', - 'balance': self.balance, - 'pnl': self.total_pnl - }) - # Reset position self.position = 'flat' + self.position_size = 0 self.entry_price = 0 self.entry_index = 0 - self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 @@ -840,7 +821,7 @@ class TradingEnvironment: # NEW FEATURES START HERE - # 1. Price momentum features (rate of change over different periods) + # 1. Price momentum features (rate of change) if len(self.features['price']) >= 20: roc_5 = (latest_price / self.features['price'][-5] - 1.0) if self.features['price'][-5] != 0 else 0 roc_10 = (latest_price / self.features['price'][-10] - 1.0) if self.features['price'][-10] != 0 else 0 @@ -856,11 +837,21 @@ class TradingEnvironment: returns = np.diff(self.features['price'][-21:]) / self.features['price'][-21:-1] # Calculate volatility (standard deviation of returns) volatility = np.std(returns) - # Calculate normalized high-low range - high_low_range = np.mean([ - (self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close'] - for i in range(max(0, len(self.data)-5), len(self.data)) - ]) if len(self.data) > 0 else 0 + + # Calculate normalized high-low range based on data format + if self.data_format_is_list: + # List format: high at index 2, low at index 3, close at index 4 + high_low_range = np.mean([ + (self.data[i][2] - self.data[i][3]) / self.data[i][4] + for i in range(max(0, self.current_step-5), min(len(self.data), self.current_step+1)) + ]) if len(self.data) > 0 else 0 + else: + # Dictionary format + high_low_range = np.mean([ + (self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close'] + for i in range(max(0, self.current_step-5), min(len(self.data), self.current_step+1)) + ]) if len(self.data) > 0 else 0 + # ATR normalized by price atr_norm = self.features['atr'][-1] / latest_price if len(self.features['atr']) > 0 else 0 @@ -963,587 +954,56 @@ class TradingEnvironment: def calculate_reward(self, action): - """Calculate reward for the given action with improved penalties for losing trades""" - reward = 0 - - # Base reward for actions - if action == 0: # HOLD - reward = -0.01 # Small penalty for doing nothing - - elif action == 1: # BUY/LONG - if self.position == 'flat': - # Opening a long position - self.position = 'long' - self.entry_price = self.current_price - self.position_size = self.calculate_position_size() - self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100) - self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100) - - # Check if this is an optimal buy point (bottom) - current_idx = len(self.features['price']) - 1 - if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms: - reward += 2.0 # Bonus for buying at a bottom - else: - # Check if we're buying in a downtrend (bad) - if self.is_downtrend(): - reward -= 0.5 # Penalty for buying in downtrend - else: - reward += 0.1 # Small reward for opening a position - - logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") - - elif self.position == 'short': - # Close short and open long - pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100 - pnl_dollar = pnl_percent / 100 * self.position_size - - # Apply fees - pnl_dollar -= self.calculate_fees(self.position_size) - - # Update balance - self.balance += pnl_dollar - self.total_pnl += pnl_dollar - - # Record trade - trade_duration = len(self.features['price']) - self.entry_index - self.trades.append({ - 'type': 'short', - 'entry': self.entry_price, - 'exit': self.current_price, - 'pnl_percent': pnl_percent, - 'pnl_dollar': pnl_dollar, - 'duration': trade_duration, - 'market_direction': self.get_market_direction() - }) - - # Reward based on PnL with stronger penalties for losses - if pnl_dollar > 0: - reward += 1.0 + pnl_dollar / 10 # Positive reward for profit - self.win_count += 1 - else: - # Stronger penalty for losses, scaled by the size of the loss - loss_penalty = 1.0 + abs(pnl_dollar) / 5 - reward -= loss_penalty - self.loss_count += 1 - - # Extra penalty for closing a losing trade too quickly - if trade_duration < 5: - reward -= 0.5 # Penalty for very short losing trades - - logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") - - # Now open long - self.position = 'long' - self.entry_price = self.current_price - self.entry_index = len(self.features['price']) - 1 - self.position_size = self.calculate_position_size() - self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100) - self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100) - - # Check if this is an optimal buy point - if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms: - reward += 2.0 # Bonus for buying at a bottom - - logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") - - elif action == 2: # SELL/SHORT - if self.position == 'flat': - # Opening a short position - self.position = 'short' - self.entry_price = self.current_price - self.position_size = self.calculate_position_size() - self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100) - self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100) - - # Check if this is an optimal sell point (top) - current_idx = len(self.features['price']) - 1 - if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops: - reward += 2.0 # Bonus for selling at a top - else: - reward += 0.1 # Small reward for opening a position - - logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") - - elif self.position == 'long': - # Close long and open short - pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100 - pnl_dollar = pnl_percent / 100 * self.position_size - - # Apply fees - pnl_dollar -= self.calculate_fees(self.position_size) - - # Update balance - self.balance += pnl_dollar - self.total_pnl += pnl_dollar - - # Record trade - self.trades.append({ - 'type': 'long', - 'entry': self.entry_price, - 'exit': self.current_price, - 'pnl_percent': pnl_percent, - 'pnl_dollar': pnl_dollar - }) - - # Reward based on PnL - if pnl_dollar > 0: - reward += 1.0 + pnl_dollar / 10 # Positive reward for profit - self.win_count += 1 - else: - reward -= 1.0 # Negative reward for loss - self.loss_count += 1 - - logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") - - # Now open short - self.position = 'short' - self.entry_price = self.current_price - self.position_size = self.calculate_position_size() - self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100) - self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100) - - # Check if this is an optimal sell point - current_idx = len(self.features['price']) - 1 - if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops: - reward += 2.0 # Bonus for selling at a top - - logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}") - - elif action == 3: # CLOSE - if self.position == 'long': - # Close long position - pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100 - pnl_dollar = pnl_percent / 100 * self.position_size - - # Apply fees - pnl_dollar -= self.calculate_fees(self.position_size) - - # Update balance - self.balance += pnl_dollar - self.total_pnl += pnl_dollar - self.episode_pnl += pnl_dollar - - # Update max drawdown - if self.balance > self.peak_balance: - self.peak_balance = self.balance - drawdown = (self.peak_balance - self.balance) / self.peak_balance - self.max_drawdown = max(self.max_drawdown, drawdown) - - # Record trade - self.trades.append({ - 'type': 'long', - 'entry': self.entry_price, - 'exit': self.current_price, - 'pnl_percent': pnl_percent, - 'pnl_dollar': pnl_dollar - }) - - # Reward based on PnL - if pnl_dollar > 0: - reward += 1.0 + pnl_dollar / 10 # Positive reward for profit - self.win_count += 1 - else: - reward -= 1.0 # Negative reward for loss - self.loss_count += 1 - - logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") - - # Reset position - self.position = 'flat' - self.entry_price = 0 - self.position_size = 0 - self.stop_loss = 0 - self.take_profit = 0 - - elif self.position == 'short': - # Close short position - pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100 - pnl_dollar = pnl_percent / 100 * self.position_size - - # Apply fees - pnl_dollar -= self.calculate_fees(self.position_size) - - # Update balance - self.balance += pnl_dollar - self.total_pnl += pnl_dollar - self.episode_pnl += pnl_dollar - - # Update max drawdown - if self.balance > self.peak_balance: - self.peak_balance = self.balance - drawdown = (self.peak_balance - self.balance) / self.peak_balance - self.max_drawdown = max(self.max_drawdown, drawdown) - - # Record trade - self.trades.append({ - 'type': 'short', - 'entry': self.entry_price, - 'exit': self.current_price, - 'pnl_percent': pnl_percent, - 'pnl_dollar': pnl_dollar - }) - - # Reward based on PnL - if pnl_dollar > 0: - reward += 1.0 + pnl_dollar / 10 # Positive reward for profit - self.win_count += 1 - else: - reward -= 1.0 # Negative reward for loss - self.loss_count += 1 - - logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") - - # Reset position - self.position = 'flat' - self.entry_price = 0 - self.position_size = 0 - self.stop_loss = 0 - self.take_profit = 0 - - # Add prediction accuracy component to reward - if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0: - # Compare the first prediction with actual price - if len(self.data) > 1: - actual_price = self.data[-1]['close'] - predicted_price = self.predicted_prices[0] - prediction_error = abs(predicted_price - actual_price) / actual_price - - # Reward accurate predictions, penalize bad ones - if prediction_error < 0.005: # Less than 0.5% error - reward += 0.5 - elif prediction_error > 0.02: # More than 2% error - reward -= 0.5 - - return reward - - def is_downtrend(self): - """Check if the market is in a downtrend""" - if len(self.features['price']) < 20: - return False - - # Use EMA to determine trend - short_ema = self.features['ema_9'][-1] - long_ema = self.features['ema_21'][-1] - - # Downtrend if short EMA is below long EMA - return short_ema < long_ema - - def is_uptrend(self): - """Check if the market is in an uptrend""" - if len(self.features['price']) < 20: - return False - - # Use EMA to determine trend - short_ema = self.features['ema_9'][-1] - long_ema = self.features['ema_21'][-1] - - # Uptrend if short EMA is above long EMA - return short_ema > long_ema - - def get_market_direction(self): - """Get the current market direction""" - if self.is_uptrend(): - return "uptrend" - elif self.is_downtrend(): - return "downtrend" - else: - return "sideways" - - def analyze_trades(self): - """Analyze completed trades to identify patterns""" - if not self.trades: - return {} - - analysis = { - 'total_trades': len(self.trades), - 'winning_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) > 0), - 'losing_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) <= 0), - 'avg_win': 0, - 'avg_loss': 0, - 'avg_duration': 0, - 'uptrend_win_rate': 0, - 'downtrend_win_rate': 0, - 'sideways_win_rate': 0 - } - - # Calculate averages - wins = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) > 0] - losses = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) <= 0] - durations = [t.get('duration', 0) for t in self.trades] - - analysis['avg_win'] = sum(wins) / len(wins) if wins else 0 - analysis['avg_loss'] = sum(losses) / len(losses) if losses else 0 - analysis['avg_duration'] = sum(durations) / len(durations) if durations else 0 - - # Calculate win rates by market direction - for direction in ['uptrend', 'downtrend', 'sideways']: - direction_trades = [t for t in self.trades if t.get('market_direction') == direction] - if direction_trades: - wins_in_direction = sum(1 for t in direction_trades if t.get('pnl_dollar', 0) > 0) - analysis[f'{direction}_win_rate'] = wins_in_direction / len(direction_trades) * 100 - - return analysis - - def initialize_price_predictor(self, device="cpu"): - """Initialize the price prediction model""" - self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5) - self.price_predictor.to(device) - self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3) - self.predicted_prices = np.array([]) - - def train_price_predictor(self): - """Train the price prediction model on recent data""" - if len(self.features['price']) < 35: - return 0.0 - - # Get price history - price_history = self.features['price'] - - # Train the model - loss = self.price_predictor.train_on_new_data( - price_history, - self.price_predictor_optimizer, - epochs=5 - ) - - return loss - - def update_price_predictions(self): - """Update price predictions""" - if len(self.features['price']) < 30 or not hasattr(self, 'price_predictor') or self.price_predictor is None: - self.predicted_prices = np.array([]) - return - - # Get price history - price_history = self.features['price'] - - try: - # Get predictions - self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5) - except Exception as e: - logger.warning(f"Error updating predictions: {e}") - self.predicted_prices = np.array([]) - - def identify_optimal_trades(self): - """Identify optimal entry and exit points based on local extrema""" - if len(self.features['price']) < 20: - return - - # Find local bottoms and tops - bottoms, tops = find_local_extrema(self.features['price'], window=5) - - # Store optimal trade points - self.optimal_bottoms = bottoms # Buy points - self.optimal_tops = tops # Sell points - - # Create optimal trade signals - self.optimal_signals = np.zeros(len(self.features['price'])) - for i in bottoms: - if 0 <= i < len(self.optimal_signals): # Ensure index is valid - self.optimal_signals[i] = 1 # Buy signal - for i in tops: - if 0 <= i < len(self.optimal_signals): # Ensure index is valid - self.optimal_signals[i] = -1 # Sell signal - - logger.info(f"Identified {len(bottoms)} optimal buy points and {len(tops)} optimal sell points") - - def calculate_position_size(self): - """Calculate position size based on current balance and risk parameters""" - # Use a fixed percentage of balance for each trade - risk_percent = 5.0 # Risk 5% of balance per trade - - # Calculate position size with leverage - position_size = self.balance * (risk_percent / 100) * MAX_LEVERAGE - - # Apply a safety factor to avoid liquidation - safety_factor = 0.8 - position_size *= safety_factor - - # Ensure minimum position size - min_position = 10.0 # Minimum position size in USD - position_size = max(position_size, min(min_position, self.balance * 0.5)) - - # Ensure position size doesn't exceed balance * leverage - max_position = self.balance * MAX_LEVERAGE - position_size = min(position_size, max_position) - - return position_size - - def calculate_fees(self, position_size): - """Calculate trading fees for a given position size""" - # Typical fee rate for crypto exchanges (0.1%) - fee_rate = 0.001 - - # Calculate fee - fee = position_size * fee_rate - - return fee - - def is_uncertain_market(self): - """Check if the market is in an uncertain/sideways state""" - if len(self.features['price']) < 20: - return True - - # Check if price is within a narrow range - recent_prices = self.features['price'][-20:] - price_range = (max(recent_prices) - min(recent_prices)) / np.mean(recent_prices) - - # Check if EMAs are close to each other - if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0: - short_ema = self.features['ema_9'][-1] - long_ema = self.features['ema_21'][-1] - ema_diff = abs(short_ema - long_ema) / long_ema - - # Return True if price range is small and EMAs are close - return price_range < 0.02 and ema_diff < 0.005 - - return price_range < 0.015 # Very narrow range - - def is_near_support(self): - """Check if current price is near a support level""" - if not hasattr(self, 'features') or len(self.features['price']) < 30: - return False - - # Find recent lows - prices = self.features['price'][-30:] - lows = [] - - for i in range(1, len(prices)-1): - if prices[i] < prices[i-1] and prices[i] < prices[i+1]: - lows.append(prices[i]) - - if not lows: - return False - - # Check if current price is near any of these lows - current_price = self.current_price - for low in lows: - if abs(current_price - low) / low < 0.01: # Within 1% of a recent low - return True - - return False - - def is_near_resistance(self): - """Check if current price is near a resistance level""" - if not hasattr(self, 'features') or len(self.features['price']) < 30: - return False - - # Find recent highs - prices = self.features['price'][-30:] - highs = [] - - for i in range(1, len(prices)-1): - if prices[i] > prices[i-1] and prices[i] > prices[i+1]: - highs.append(prices[i]) - - if not highs: - return False - - # Check if current price is near any of these highs - current_price = self.current_price - for high in highs: - if abs(current_price - high) / high < 0.01: # Within 1% of a recent high - return True - - return False - - def is_market_turning(self): - """Check if the market is potentially changing direction""" - if len(self.features['price']) < 20: - return False - - # Check for divergence between price and momentum indicators - if len(self.features['rsi']) > 5: - # Price making higher highs but RSI making lower highs (bearish divergence) - price_trend = self.features['price'][-1] > self.features['price'][-5] - rsi_trend = self.features['rsi'][-1] < self.features['rsi'][-5] - - if price_trend != rsi_trend: - return True - - # Check for EMA crossover - if len(self.features['ema_9']) > 1 and len(self.features['ema_21']) > 1: - short_ema_prev = self.features['ema_9'][-2] - long_ema_prev = self.features['ema_21'][-2] - short_ema_curr = self.features['ema_9'][-1] - long_ema_curr = self.features['ema_21'][-1] - - # Check if EMAs just crossed - if (short_ema_prev < long_ema_prev and short_ema_curr > long_ema_curr) or \ - (short_ema_prev > long_ema_prev and short_ema_curr < long_ema_curr): - return True - - return False - - def is_market_against_position(self, position_type): - """Check if market conditions have turned against the current position""" - if position_type == 'long': - # For long positions, check if market has turned bearish - return self.is_downtrend() and not self.is_near_support() - elif position_type == 'short': - # For short positions, check if market has turned bullish - return self.is_uptrend() and not self.is_near_resistance() - - return False - - def is_near_optimal_exit(self, position_type): - """Check if current price is near an optimal exit point for the position""" - current_idx = len(self.features['price']) - 1 - - if position_type == 'long' and hasattr(self, 'optimal_tops'): - # For long positions, optimal exit is near tops - for top_idx in self.optimal_tops: - if abs(current_idx - top_idx) < 3: # Within 3 candles of a top - return True - elif position_type == 'short' and hasattr(self, 'optimal_bottoms'): - # For short positions, optimal exit is near bottoms - for bottom_idx in self.optimal_bottoms: - if abs(current_idx - bottom_idx) < 3: # Within 3 candles of a bottom - return True - - return False - - def calculate_future_profit_potential(self, position_type, lookahead=20): """ - Calculate potential profit if position is held for a certain period - This is used for retrospective backtesting rewards + Calculate reward for taking the given action. Args: - position_type: 'long' or 'short' - lookahead: Number of candles to look ahead - + action: The action taken (0=hold, 1=buy, 2=sell, 3=close) + Returns: - Potential profit percentage + The calculated reward """ - if len(self.data) <= 1 or self.current_step >= len(self.data): - return 0 + reward = 0 # Get current price - current_price = self.current_price + if self.data_format_is_list: + current_price = self.data[self.current_step][4] # Close price + else: + current_price = self.data[self.current_step]['close'] - # Get future prices (if available in historical data) - future_prices = [] - current_idx = self.current_step + # Base reward component based on price movement + price_change_pct = 0 + if self.current_step > 0: + if self.data_format_is_list: + prev_price = self.data[self.current_step-1][4] # Previous close + else: + prev_price = self.data[self.current_step-1]['close'] + + price_change_pct = (current_price - prev_price) / prev_price - # Safely get future prices - for i in range(1, min(lookahead + 1, len(self.data) - current_idx)): - if current_idx + i < len(self.data): - future_prices.append(self.data[current_idx + i]['close']) + # Check if we have CNN patterns available + pattern_confidence = 0 + if hasattr(self, 'cnn_patterns'): + if action == 1 and 'long_confidence' in self.cnn_patterns: # Buy action + pattern_confidence = self.cnn_patterns['long_confidence'] + elif action == 2 and 'short_confidence' in self.cnn_patterns: # Sell action + pattern_confidence = self.cnn_patterns['short_confidence'] - if not future_prices: - return 0 + # Action-specific rewards + if action == 0: # HOLD + # Small positive reward for holding in the right direction of market movement + if self.position == 'long' and price_change_pct > 0: + reward += 0.1 + price_change_pct * 10 + elif self.position == 'short' and price_change_pct < 0: + reward += 0.1 + abs(price_change_pct) * 10 + else: + # Small negative reward for holding in the wrong direction + reward -= 0.1 - # Calculate potential profit - if position_type == 'long': - # For long positions, find the maximum price in the future - max_future_price = max(future_prices) - potential_profit = (max_future_price - current_price) / current_price * 100 - else: # short - # For short positions, find the minimum price in the future - min_future_price = min(future_prices) - potential_profit = (current_price - min_future_price) / current_price * 100 + # Add CNN pattern confidence to reward + reward += pattern_confidence * 10 - return potential_profit + return reward async def initialize_futures(self, exchange): """Initialize futures trading parameters""" @@ -1600,6 +1060,306 @@ class TradingEnvironment: logger.error(f"Trade execution failed: {e}") return None + def trades_in_last_n_candles(self, n=20): + """Count the number of trades in the last n candles""" + if len(self.trades) == 0: + return 0 + + if self.data_format_is_list: + # List format: timestamp at index 0 + current_time = self.data[self.current_step][0] + n_candles_ago = self.data[max(0, self.current_step - n)][0] + else: + # Dictionary format + current_time = self.data[self.current_step]['timestamp'] + n_candles_ago = self.data[max(0, self.current_step - n)]['timestamp'] + + count = 0 + for trade in reversed(self.trades): + if 'timestamp' in trade and trade['timestamp'] >= n_candles_ago and trade['timestamp'] <= current_time: + count += 1 + else: + # Older trades, we can stop counting + break + + return count + + def count_consecutive_losses(self): + """Count the number of consecutive losing trades""" + count = 0 + for trade in reversed(self.trades): + if trade.get('pnl_dollar', 0) < 0: + count += 1 + else: + break + return count + + def is_volatile_market(self): + """Determine if the current market is volatile""" + if len(self.features['price']) < 20: + return False + + recent_prices = self.features['price'][-20:] + avg_price = sum(recent_prices) / len(recent_prices) + volatility = sum([abs(p - avg_price) / avg_price for p in recent_prices]) / len(recent_prices) + + return volatility > 0.01 # 1% average deviation is considered volatile + + def is_uptrend(self): + """Determine if the market is in an uptrend""" + if len(self.features['ema_9']) < 2 or len(self.features['ema_21']) < 2: + return False + + # Short-term trend + short_trend = self.features['ema_9'][-1] > self.features['ema_9'][-2] + + # Medium-term trend + medium_trend = self.features['ema_9'][-1] > self.features['ema_21'][-1] + + return short_trend and medium_trend + + def is_downtrend(self): + """Determine if the market is in a downtrend""" + if len(self.features['ema_9']) < 2 or len(self.features['ema_21']) < 2: + return False + + # Short-term trend + short_trend = self.features['ema_9'][-1] < self.features['ema_9'][-2] + + # Medium-term trend + medium_trend = self.features['ema_9'][-1] < self.features['ema_21'][-1] + + return short_trend and medium_trend + + def calculate_position_size(self): + """Calculate position size based on risk management rules""" + # Reduce position size after losses + consecutive_losses = self.count_consecutive_losses() + risk_factor = max(0.3, 1.0 - (consecutive_losses * 0.1)) # Reduce by 10% per loss, min 30% + + # Calculate position size based on available balance and risk + max_risk_amount = self.balance * 0.02 # Risk 2% per trade + position_size = max_risk_amount / (STOP_LOSS_PERCENT / 100 * self.current_price) + + # Apply leverage + position_size = position_size * self.leverage + + # Cap at available balance + position_size = min(position_size, self.balance * self.leverage) + + return position_size * risk_factor + + # ... existing identify_optimal_trades method ... + + def update_cnn_patterns(self, candle_data=None): + """ + Update CNN patterns using multi-timeframe data. + + Args: + candle_data: Dictionary containing candle data for different timeframes + """ + if not candle_data: + return + + try: + # Check if we have the necessary timeframes + required_timeframes = ['1m', '1h', '1d'] + if not all(tf in candle_data for tf in required_timeframes): + logging.warning(f"Missing required timeframes for CNN pattern detection") + return + + # Initialize patterns if not already done + if not hasattr(self, 'cnn_patterns'): + self.cnn_patterns = {} + + # Extract features from candle data + features = {} + + # Process each timeframe + for tf in required_timeframes: + candles = candle_data[tf] + if not candles or len(candles) < 30: + continue + + # Convert to numpy arrays for easier processing + closes = np.array([c[4] for c in candles[-100:]]) + highs = np.array([c[2] for c in candles[-100:]]) + lows = np.array([c[3] for c in candles[-100:]]) + + # Simple feature extraction + # 1. Detect trends + ema20 = self._calculate_ema(closes, 20) + ema50 = self._calculate_ema(closes, 50) + + uptrend = ema20[-1] > ema50[-1] and closes[-1] > ema20[-1] + downtrend = ema20[-1] < ema50[-1] and closes[-1] < ema20[-1] + + # 2. Detect potential reversal patterns + # Find local extrema + tops, bottoms = find_local_extrema(closes, window=10) + + # Check if we're near a bottom (potential buy) + near_bottom = False + bottom_confidence = 0 + if bottoms and len(bottoms) > 0: + last_bottom = bottoms[-1] + if len(closes) - last_bottom < 5: # Recent bottom + bottom_dist = abs(closes[-1] - closes[last_bottom]) / closes[last_bottom] + if bottom_dist < 0.01: # Within 1% of the bottom + near_bottom = True + # Higher confidence if volume is increasing + bottom_confidence = 0.8 - bottom_dist * 50 # 0.8 to 0.3 range + + # Check if we're near a top (potential sell) + near_top = False + top_confidence = 0 + if tops and len(tops) > 0: + last_top = tops[-1] + if len(closes) - last_top < 5: # Recent top + top_dist = abs(closes[-1] - closes[last_top]) / closes[last_top] + if top_dist < 0.01: # Within 1% of the top + near_top = True + # Higher confidence if volume is increasing + top_confidence = 0.8 - top_dist * 50 # 0.8 to 0.3 range + + # Store features for this timeframe + features[tf] = { + 'uptrend': uptrend, + 'downtrend': downtrend, + 'near_bottom': near_bottom, + 'bottom_confidence': bottom_confidence, + 'near_top': near_top, + 'top_confidence': top_confidence + } + + # Combine features across timeframes to get overall pattern confidence + long_confidence = 0 + short_confidence = 0 + + # Weight each timeframe (higher weight for longer timeframes) + weights = {'1m': 0.2, '1h': 0.3, '1d': 0.5} + + for tf, tf_features in features.items(): + weight = weights.get(tf, 0.2) + + # Add to long confidence + if tf_features['uptrend'] or tf_features['near_bottom']: + long_confidence += weight * (0.6 if tf_features['uptrend'] else 0) + \ + weight * (tf_features['bottom_confidence'] if tf_features['near_bottom'] else 0) + + # Add to short confidence + if tf_features['downtrend'] or tf_features['near_top']: + short_confidence += weight * (0.6 if tf_features['downtrend'] else 0) + \ + weight * (tf_features['top_confidence'] if tf_features['near_top'] else 0) + + # Normalize confidence scores to [0, 1] + long_confidence = min(1.0, long_confidence) + short_confidence = min(1.0, short_confidence) + + # Update patterns + self.cnn_patterns = { + 'long_confidence': long_confidence, + 'short_confidence': short_confidence, + 'features': features + } + + logging.debug(f"Updated CNN patterns - Long: {long_confidence:.2f}, Short: {short_confidence:.2f}") + + except Exception as e: + logging.error(f"Error updating CNN patterns: {e}") + + def _calculate_ema(self, data, span): + """Calculate exponential moving average""" + alpha = 2 / (span + 1) + alpha_rev = 1 - alpha + + ema = np.zeros_like(data) + ema[0] = data[0] + + for i in range(1, len(data)): + ema[i] = alpha * data[i] + alpha_rev * ema[i-1] + + return ema + + def is_uncertain_market(self): + """Determine if the market is in an uncertain/range-bound state""" + if len(self.features['price']) < 30: + return False + + # Check if EMAs are close to each other (no clear trend) + if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0: + ema9 = self.features['ema_9'][-1] + ema21 = self.features['ema_21'][-1] + + # If EMAs are within 0.2% of each other, market is uncertain + if abs(ema9 - ema21) / ema21 < 0.002: + return True + + # Check if price is oscillating without clear direction + if len(self.features['price']) >= 10: + recent_prices = self.features['price'][-10:] + ups = downs = 0 + for i in range(1, len(recent_prices)): + if recent_prices[i] > recent_prices[i-1]: + ups += 1 + else: + downs += 1 + + # If there's a mix of ups and downs (neither dominates heavily) + return abs(ups - downs) < 3 + + return False + + def is_near_support(self): + """Determine if the current price is near a support level""" + if len(self.features['price']) < 30: + return False + + # Use Bollinger lower band as support + if len(self.features['bollinger_lower']) > 0 and len(self.features['price']) > 0: + current_price = self.features['price'][-1] + lower_band = self.features['bollinger_lower'][-1] + + # If price is within 0.5% of the lower band + if (current_price - lower_band) / current_price < 0.005: + return True + + # Check if we're near recent lows + if len(self.features['price']) >= 20: + current_price = self.features['price'][-1] + min_price = min(self.features['price'][-20:]) + + # If within 1% of recent lows + if (current_price - min_price) / current_price < 0.01: + return True + + return False + + def is_near_resistance(self): + """Determine if the current price is near a resistance level""" + if len(self.features['price']) < 30: + return False + + # Use Bollinger upper band as resistance + if len(self.features['bollinger_upper']) > 0 and len(self.features['price']) > 0: + current_price = self.features['price'][-1] + upper_band = self.features['bollinger_upper'][-1] + + # If price is within 0.5% of the upper band + if (upper_band - current_price) / current_price < 0.005: + return True + + # Check if we're near recent highs + if len(self.features['price']) >= 20: + current_price = self.features['price'][-1] + max_price = max(self.features['price'][-20:]) + + # If within 1% of recent highs + if (max_price - current_price) / current_price < 0.01: + return True + + return False + # Ensure GPU usage if available def get_device(): """Get the best available device (CUDA GPU or CPU)""" @@ -1626,9 +1386,9 @@ class Agent: # Set device self.device = device if device is not None else get_device() - # Initialize networks - self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) - self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) + # Initialize networks - use LSTMAttentionDQN instead of DQN + self.policy_net = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) + self.target_net = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) self.target_net.load_state_dict(self.policy_net.state_dict()) # Initialize optimizer @@ -1651,7 +1411,14 @@ class Agent: # Initialize GradScaler for mixed precision training self.scaler = torch.cuda.amp.GradScaler() if self.device.type == "cuda" else None - # Rest of the initialization code... + # Initialize candle cache for multi-timeframe data + self.candle_cache = CandleCache() + + # Store model name for logging + self.model_name = f"LSTM_Attention_DQN_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}" + + logger.info(f"Initialized agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}") + logger.info(f"Using device: {self.device}") def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8): """Expand the model to handle more features or increase capacity""" @@ -1662,9 +1429,9 @@ class Agent: old_state_dict = self.policy_net.state_dict() # Create new larger networks - new_policy_net = DQN(new_state_size, self.action_size, + new_policy_net = LSTMAttentionDQN(new_state_size, self.action_size, new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device) - new_target_net = DQN(new_state_size, self.action_size, + new_target_net = LSTMAttentionDQN(new_state_size, self.action_size, new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device) # Transfer weights for common layers @@ -1702,25 +1469,78 @@ class Agent: return True - def select_action(self, state, training=True): - sample = random.random() + def select_action(self, state, training=True, candle_data=None): + """ + Select an action using the policy network. - if training: - # Epsilon decay - self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \ - np.exp(-1. * self.steps_done / EPSILON_DECAY) - self.steps_done += 1 + Args: + state: The current state + training: Whether we're in training mode (for epsilon-greedy) + candle_data: Dictionary with '1s', '1m', '1h', '1d' candle data + + Returns: + The selected action + """ + # ... existing code ... - if sample > self.epsilon or not training: - with torch.no_grad(): - state_tensor = torch.FloatTensor(state).to(self.device) - action_values = self.policy_net(state_tensor) - return action_values.max(1)[1].item() - else: + # Add CNN processing if candle data is available + cnn_inputs = None + if candle_data and all(k in candle_data for k in ['1s', '1m', '1h', '1d']): + # Process candle data into tensors + x_1s = self.prepare_candle_tensor(candle_data['1s']) + x_1m = self.prepare_candle_tensor(candle_data['1m']) + x_1h = self.prepare_candle_tensor(candle_data['1h']) + x_1d = self.prepare_candle_tensor(candle_data['1d']) + + cnn_inputs = (x_1s, x_1m, x_1h, x_1d) + + # Use epsilon-greedy strategy during training + if training and random.random() < self.epsilon: return random.randrange(self.action_size) + + with torch.no_grad(): + state_tensor = torch.FloatTensor(state).to(self.device) + + if cnn_inputs: + q_values = self.policy_net(state_tensor, *cnn_inputs) + else: + q_values = self.policy_net(state_tensor) + + return q_values.max(1)[1].item() + + def prepare_candle_tensor(self, candles, max_candles=300): + """Convert candle data to tensors for CNN input""" + if not candles: + # Return zeros if no candles available + return torch.zeros((1, 5, max_candles), device=self.device) + + # Limit to the most recent candles + candles = candles[-max_candles:] + + # Extract OHLCV data + ohlcv = np.array([[c[1], c[2], c[3], c[4], c[5]] for c in candles], dtype=np.float32) + + # Normalize the data + if len(ohlcv) > 0: + # Simple min-max normalization per column + min_vals = ohlcv.min(axis=0, keepdims=True) + max_vals = ohlcv.max(axis=0, keepdims=True) + range_vals = max_vals - min_vals + range_vals[range_vals == 0] = 1 # Avoid division by zero + ohlcv = (ohlcv - min_vals) / range_vals + + # Pad if needed + padded = np.zeros((max_candles, 5), dtype=np.float32) + padded[-len(ohlcv):] = ohlcv + + # Convert to tensor [batch, channels, sequence] + tensor = torch.FloatTensor(padded.transpose(1, 0)).unsqueeze(0).to(self.device) + return tensor + else: + return torch.zeros((1, 5, max_candles), device=self.device) def learn(self): - """Learn from a batch of experiences""" + """Learn from a batch of experiences with GPU acceleration and CNN features""" if len(self.memory) < BATCH_SIZE: return None @@ -1735,31 +1555,32 @@ class Agent: next_states = torch.FloatTensor([e.next_state for e in experiences]).to(self.device) dones = torch.FloatTensor([e.done for e in experiences]).to(self.device) - # Use mixed precision for forward/backward passes + # Use mixed precision for forward/backward passes if on GPU if self.device.type == "cuda" and self.scaler is not None: - with torch.amp.autocast('cuda'): - # Compute Q values + with torch.cuda.amp.autocast(): + # Compute current Q values current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)) - # Compute next Q values with target network + # Compute next Q values with torch.no_grad(): next_q_values = self.target_net(next_states).max(1)[0] - target_q_values = rewards + (GAMMA * next_q_values * (1 - dones)) - # Reshape target values to match current_q_values + # Compute target Q values + target_q_values = rewards + (GAMMA * next_q_values * (1 - dones)) target_q_values = target_q_values.unsqueeze(1) # Compute loss loss = F.smooth_l1_loss(current_q_values, target_q_values) - - # Backward pass with mixed precision + + # Backward pass with gradient scaling self.optimizer.zero_grad() self.scaler.scale(loss).backward() - # Gradient clipping to prevent exploding gradients + # Clip gradients self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0) + # Update weights self.scaler.step(self.optimizer) self.scaler.update() else: @@ -1800,6 +1621,11 @@ class Agent: logger.error(f"Error during learning: {e}") logger.error(f"Traceback: {traceback.format_exc()}") return None + + def update_epsilon(self, episode): + """Update epsilon value based on episode number""" + self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay) + return self.epsilon def update_target_network(self): self.target_net.load_state_dict(self.policy_net.state_dict()) @@ -1925,40 +1751,62 @@ class Agent: logger.error(traceback.format_exc()) raise - def add_chart_to_tensorboard(self, env, global_step): - """Add trading chart to TensorBoard""" + def add_chart_to_tensorboard(self, env, step): + """Add candlestick chart to tensorboard""" try: - if len(env.data) < 10: + # Initialize writer if it doesn't exist + if not hasattr(self, 'writer') or self.writer is None: + self.writer = SummaryWriter(log_dir=f'runs/{self.model_name}') + + # Check if we have enough data + if not hasattr(env, 'data') or len(env.data) < 20: + logger.warning("Not enough data for chart in TensorBoard") return - # Create chart image - chart_img = create_candlestick_figure( - env.data, - env.trade_signals, - window_size=100, - title=f"Trading Chart - Step {global_step}" - ) + # Get position value (convert from string if needed) + position_value = 0 # Default to flat + if hasattr(env, 'position'): + if isinstance(env.position, str): + # Map string positions to numeric values + position_map = {'flat': 0, 'long': 1, 'short': -1} + position_value = position_map.get(env.position.lower(), 0) + else: + position_value = float(env.position) + + # Log metrics to tensorboard + self.writer.add_scalar('Trading/Position', position_value, step) - if chart_img is not None: - # Convert PIL image to numpy array for TensorBoard - chart_array = np.array(chart_img) - # TensorBoard expects [C, H, W] format - chart_array = np.transpose(chart_array, (2, 0, 1)) - self.writer.add_image('Trading Chart', chart_array, global_step) + if hasattr(env, 'balance'): + self.writer.add_scalar('Trading/Balance', env.balance, step) - # Add position information as text - entry_price = env.entry_price if env.entry_price else 0.00 - position_info = f""" - **Current Position**: {env.position.upper()} - **Entry Price**: ${entry_price:.2f} - **Current Price**: ${env.data[-1]['close']:.2f} - **Position Size**: ${env.position_size:.2f} - **Unrealized PnL**: ${env.total_pnl:.2f} - """ - self.writer.add_text('Position', position_info, global_step) + if hasattr(env, 'total_pnl'): + self.writer.add_scalar('Trading/Total_PnL', env.total_pnl, step) + + if hasattr(env, 'max_drawdown'): + self.writer.add_scalar('Trading/Drawdown', env.max_drawdown, step) + + if hasattr(env, 'win_rate'): + self.writer.add_scalar('Trading/Win_Rate', env.win_rate, step) + + if hasattr(env, 'trade_count'): + self.writer.add_scalar('Trading/Trade_Count', env.trade_count, step) + + # Get recent trades if available + recent_trades = [] + if hasattr(env, 'trades') and env.trades: + recent_trades = env.trades[-10:] # Last 10 trades + + # Create candlestick figure with the last 100 candles and recent trades + fig = create_candlestick_figure(env.data[-100:], recent_trades) + + # Add figure to tensorboard + self.writer.add_figure('Trading/Chart', fig, step) + + # Close figure to free resources + plt.close(fig) + except Exception as e: - logger.error(f"Error adding chart to TensorBoard: {str(e)}") - # Continue without visualization rather than crashing + logger.warning(f"Error adding chart to tensorboard: {e}") async def get_live_prices(symbol="ETH/USDT", timeframe="1m"): """Get live price data using websockets""" @@ -1999,7 +1847,37 @@ async def get_live_prices(symbol="ETH/USDT", timeframe="1m"): break async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000): - """Train the agent using historical and live data with GPU acceleration""" + """ + Train the agent using reinforcement learning with multi-timeframe data and CNN pattern recognition. + + Args: + agent: The agent to train + env: The trading environment + num_episodes: Number of episodes to train for + max_steps_per_episode: Maximum steps per episode + + Returns: + Training statistics + """ + # Initialize TensorBoard writer if not already done + try: + if agent.writer is None: + from torch.utils.tensorboard import SummaryWriter + agent.writer = SummaryWriter(log_dir=f'runs/{agent.model_name}') + + writer = agent.writer + except Exception as e: + logging.error(f"Failed to initialize TensorBoard: {e}") + writer = None + + # Initialize exchange for data fetching + try: + exchange = await initialize_exchange() + logging.info("Initialized exchange for data fetching") + except Exception as e: + logging.error(f"Failed to initialize exchange: {e}") + exchange = None + # Initialize statistics tracking stats = { 'episode_rewards': [], @@ -2009,225 +1887,309 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000) 'episode_pnls': [], 'cumulative_pnl': [], 'drawdowns': [], - 'prediction_accuracy': [], - 'trade_analysis': [] + 'trade_counts': [], + 'loss_values': [] } # Track best models best_reward = float('-inf') best_pnl = float('-inf') - # Initialize TensorBoard writer if not already initialized - if not hasattr(agent, 'writer') or agent.writer is None: - agent.writer = SummaryWriter('runs/training') + # Make directory for models if it doesn't exist + os.makedirs('models', exist_ok=True) - # Training loop + # Start training loop for episode in range(num_episodes): try: # Reset environment state = env.reset() episode_reward = 0 - prediction_loss = 0 + episode_losses = [] + + # Fetch multi-timeframe data at the start of the episode + candle_data = None + if exchange: + try: + candle_data = await fetch_multi_timeframe_data( + exchange, "ETH/USDT", agent.candle_cache + ) + # Update CNN patterns + env.update_cnn_patterns(candle_data) + logging.info(f"Fetched multi-timeframe data for episode {episode+1}") + except Exception as e: + logging.error(f"Failed to fetch candle data: {e}") # Episode loop for step in range(max_steps_per_episode): - # Select action - action = agent.select_action(state) - - # Take action try: + # Select action using CNN-enhanced policy + action = agent.select_action(state, training=True, candle_data=candle_data) + + # Take action next_state, reward, done, info = env.step(action) + + # Store transition in replay memory + agent.memory.push(state, action, reward, next_state, done) + + # Move to the next state + state = next_state + + # Update episode reward + episode_reward += reward + + # Learn from experience + if len(agent.memory) > BATCH_SIZE: + try: + loss = agent.learn() + if loss is not None: + episode_losses.append(loss) + # Log loss to TensorBoard + global_step = episode * max_steps_per_episode + step + if writer: + writer.add_scalar('Loss/step', loss, global_step) + except Exception as e: + logging.error(f"Error during learning: {e}") + + # Update target network periodically + if step % TARGET_UPDATE == 0: + agent.update_target_network() + + # Update price predictions and CNN patterns periodically + if step % 50 == 0: + try: + # Update internal environment predictions + if hasattr(env, 'update_price_predictions'): + env.update_price_predictions() + if hasattr(env, 'identify_optimal_trades'): + env.identify_optimal_trades() + + # Fetch fresh candle data periodically + if exchange: + try: + candle_data = await fetch_multi_timeframe_data( + exchange, "ETH/USDT", agent.candle_cache + ) + + # Update CNN patterns with the new candle data + env.update_cnn_patterns(candle_data) + logging.info(f"Updated multi-timeframe data at step {step}") + except Exception as e: + logging.error(f"Failed to fetch candle data: {e}") + except Exception as e: + logging.warning(f"Error updating predictions: {e}") + + # Add chart to TensorBoard periodically + if step % 100 == 0 or (step == max_steps_per_episode - 1) or done: + try: + global_step = episode * max_steps_per_episode + step + if writer: + agent.add_chart_to_tensorboard(env, global_step) + except Exception as e: + logging.warning(f"Error adding chart to TensorBoard: {e}") + + if done: + break + except Exception as e: - logger.error(f"Error in step function: {e}") - break - - # Store transition in replay memory - agent.memory.push(state, action, reward, next_state, done) - - # Move to the next state - state = next_state - - # Update episode reward - episode_reward += reward - - # Learn from experience - if len(agent.memory) > BATCH_SIZE: - agent.learn() - - # Update price predictions periodically - if step % 50 == 0: - try: - env.update_price_predictions() - env.identify_optimal_trades() - except Exception as e: - logger.warning(f"Error updating predictions: {e}") - - # Add chart to TensorBoard periodically - if step % 50 == 0 or (step == max_steps_per_episode - 1) or done: - try: - global_step = episode * max_steps_per_episode + step - agent.add_chart_to_tensorboard(env, global_step) - except Exception as e: - logger.warning(f"Error adding chart to TensorBoard: {e}") - - # End episode if done - if done: + logging.error(f"Error in training step: {e}") break - # Update target network periodically - if episode % TARGET_UPDATE == 0: - agent.update_target_network() + # Calculate statistics from this episode + balance = env.balance + pnl = balance - env.initial_balance if hasattr(env, 'initial_balance') else 0 - # Calculate win rate - total_trades = env.win_count + env.loss_count - win_rate = (env.win_count / total_trades * 100) if total_trades > 0 else 0 - - # Train price predictor - try: - if episode % 5 == 0 and len(env.data) > 50: - prediction_loss = env.train_price_predictor() - except Exception as e: - logger.warning(f"Error training price predictor: {e}") - prediction_loss = 0 - - # Analyze trades - try: + # Get trading statistics + trade_analysis = None + if hasattr(env, 'analyze_trades'): trade_analysis = env.analyze_trades() - stats['trade_analysis'].append(trade_analysis) - except Exception as e: - logger.warning(f"Error analyzing trades: {e}") - trade_analysis = {} - stats['trade_analysis'].append({}) - # Calculate prediction accuracy - prediction_accuracy = 0.0 - try: - if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0: - if len(env.data) > 5: - actual_prices = [candle['close'] for candle in env.data[-5:]] - predicted = env.predicted_prices[:min(5, len(actual_prices))] - errors = [abs(p - a) / a for p, a in zip(predicted, actual_prices[:len(predicted)])] - prediction_accuracy = 100 * (1 - sum(errors) / len(errors)) - except Exception as e: - logger.warning(f"Error calculating prediction accuracy: {e}") + win_rate = trade_analysis['win_rate'] if trade_analysis and 'win_rate' in trade_analysis else 0 + trade_count = trade_analysis['total_trades'] if trade_analysis and 'total_trades' in trade_analysis else 0 + max_drawdown = trade_analysis['max_drawdown'] if trade_analysis and 'max_drawdown' in trade_analysis else 0 - # Log statistics + # Calculate average loss for this episode + avg_loss = sum(episode_losses) / len(episode_losses) if episode_losses else 0 + + # Log episode metrics to TensorBoard + if writer: + writer.add_scalar('Reward/episode', episode_reward, episode) + writer.add_scalar('Balance/episode', balance, episode) + writer.add_scalar('PnL/episode', pnl, episode) + writer.add_scalar('WinRate/episode', win_rate, episode) + writer.add_scalar('TradeCount/episode', trade_count, episode) + writer.add_scalar('Drawdown/episode', max_drawdown, episode) + writer.add_scalar('Loss/episode', avg_loss, episode) + writer.add_scalar('Epsilon/episode', agent.epsilon, episode) + + # Update stats dictionary stats['episode_rewards'].append(episode_reward) stats['episode_lengths'].append(step + 1) - stats['balances'].append(env.balance) + stats['balances'].append(balance) stats['win_rates'].append(win_rate) - stats['episode_pnls'].append(env.episode_pnl) - stats['cumulative_pnl'].append(env.total_pnl) - stats['drawdowns'].append(env.max_drawdown * 100) - stats['prediction_accuracy'].append(prediction_accuracy) + stats['episode_pnls'].append(pnl) + stats['drawdowns'].append(max_drawdown) + stats['trade_counts'].append(trade_count) + stats['loss_values'].append(avg_loss) - # Log detailed trade analysis - if trade_analysis: - logger.info(f"Trade Analysis: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, " - f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends | " - f"Avg Win=${trade_analysis.get('avg_win', 0):.2f}, Avg Loss=${trade_analysis.get('avg_loss', 0):.2f}") + # Calculate and update cumulative PnL + if len(stats['episode_pnls']) > 0: + cumulative_pnl = sum(stats['episode_pnls']) + if 'cumulative_pnl' not in stats: + stats['cumulative_pnl'] = [] + stats['cumulative_pnl'].append(cumulative_pnl) + if writer: + writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode) - # Log to TensorBoard - agent.writer.add_scalar('Reward/train', episode_reward, episode) - agent.writer.add_scalar('Balance/train', env.balance, episode) - agent.writer.add_scalar('WinRate/train', win_rate, episode) - agent.writer.add_scalar('PnL/episode', env.episode_pnl, episode) - agent.writer.add_scalar('PnL/cumulative', env.total_pnl, episode) - agent.writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode) - agent.writer.add_scalar('PredictionLoss', prediction_loss, episode) - agent.writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode) - - # Add final chart for this episode - try: - agent.add_chart_to_tensorboard(env, (episode + 1) * max_steps_per_episode) - except Exception as e: - logger.warning(f"Error adding final chart: {e}") - - logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, " - f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, " - f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}, " - f"Max Drawdown={env.max_drawdown*100:.1f}%, Pred Accuracy={prediction_accuracy:.1f}%") - - # Save best model by reward + # Save model if this is the best reward or PnL if episode_reward > best_reward: best_reward = episode_reward - agent.save("models/trading_agent_best_reward.pt") + agent.save('models/trading_agent_best_reward.pt') + logging.info(f"New best reward: {best_reward:.2f}") - # Save best model by PnL - if env.episode_pnl > best_pnl: - best_pnl = env.episode_pnl - agent.save("models/trading_agent_best_pnl.pt") + if pnl > best_pnl: + best_pnl = pnl + agent.save('models/trading_agent_best_pnl.pt') + logging.info(f"New best PnL: ${best_pnl:.2f}") - # Save checkpoint + # Save checkpoint periodically if episode % 10 == 0: - agent.save(f"models/trading_agent_episode_{episode}.pt") - + agent.save(f'models/trading_agent_checkpoint_{episode}.pt') + + # Update epsilon + agent.update_epsilon(episode) + + # Log training progress + logging.info(f"Episode {episode+1}/{num_episodes} | " + + f"Reward: {episode_reward:.2f} | " + + f"Balance: ${balance:.2f} | " + + f"PnL: ${pnl:.2f} | " + + f"Win Rate: {win_rate:.2f} | " + + f"Trades: {trade_count} | " + + f"Loss: {avg_loss:.5f} | " + + f"Epsilon: {agent.epsilon:.4f}") + except Exception as e: - logger.error(f"Error in episode {episode}: {e}") + logging.error(f"Error in episode {episode}: {e}") + import traceback + logging.error(traceback.format_exc()) continue # Save final model - agent.save("models/trading_agent_final.pt") + agent.save('models/trading_agent_final.pt') - # Plot training results - plot_training_results(stats) + # Save training statistics to file + try: + import pandas as pd + stats_df = pd.DataFrame(stats) + stats_df.to_csv('training_stats.csv', index=False) + logging.info(f"Training statistics saved to training_stats.csv") + except Exception as e: + logging.error(f"Failed to save training statistics: {e}") + # Fallback to numpy save + np.save('training_stats.npy', stats) return stats def plot_training_results(stats): - """Plot detailed training results""" - plt.figure(figsize=(20, 15)) - - # Plot rewards - plt.subplot(3, 2, 1) - plt.plot(stats['episode_rewards']) - plt.title('Episode Rewards') - plt.xlabel('Episode') - plt.ylabel('Reward') - - # Plot balance - plt.subplot(3, 2, 2) - plt.plot(stats['balances']) - plt.title('Account Balance') - plt.xlabel('Episode') - plt.ylabel('Balance ($)') - - # Plot win rate - plt.subplot(3, 2, 3) - plt.plot(stats['win_rates']) - plt.title('Win Rate') - plt.xlabel('Episode') - plt.ylabel('Win Rate (%)') - - # Plot episode PnL - plt.subplot(3, 2, 4) - plt.plot(stats['episode_pnls']) - plt.title('Episode PnL') - plt.xlabel('Episode') - plt.ylabel('PnL ($)') - - # Plot cumulative PnL - plt.subplot(3, 2, 5) - plt.plot(stats['cumulative_pnl']) - plt.title('Cumulative PnL') - plt.xlabel('Episode') - plt.ylabel('Cumulative PnL ($)') - - # Plot drawdown - plt.subplot(3, 2, 6) - plt.plot(stats['drawdowns']) - plt.title('Maximum Drawdown') - plt.xlabel('Episode') - plt.ylabel('Drawdown (%)') - - plt.tight_layout() - plt.savefig('training_results.png') - - # Save statistics to CSV - df = pd.DataFrame(stats) - df.to_csv('training_stats.csv', index=False) - - logger.info("Training statistics saved to training_stats.csv and training_results.png") + """Plot training results and save to file""" + try: + # Check if we have data to plot + if not stats or len(stats.get('episode_rewards', [])) == 0: + logger.warning("No training data to plot") + return + + # Create a DataFrame with consistent lengths + max_len = max(len(stats.get(key, [])) for key in stats) + + # Ensure all arrays have the same length by padding with the last value or zeros + processed_stats = {} + for key, values in stats.items(): + if not values: # Skip empty lists + continue + + # Pad arrays to the same length + if len(values) < max_len: + if len(values) > 0: + # Pad with the last value + values = values + [values[-1]] * (max_len - len(values)) + else: + # Pad with zeros + values = [0] * max_len + + processed_stats[key] = values[:max_len] # Trim if longer + + # Create DataFrame + df = pd.DataFrame(processed_stats) + + # Add episode column + df['episode'] = range(1, len(df) + 1) + + # Create figure with subplots + fig, axes = plt.subplots(3, 2, figsize=(15, 15)) + + # Plot episode rewards + if 'episode_rewards' in df.columns: + axes[0, 0].plot(df['episode'], df['episode_rewards']) + axes[0, 0].set_title('Episode Rewards') + axes[0, 0].set_xlabel('Episode') + axes[0, 0].set_ylabel('Reward') + axes[0, 0].grid(True) + + # Plot account balance + if 'balances' in df.columns: + axes[0, 1].plot(df['episode'], df['balances']) + axes[0, 1].set_title('Account Balance') + axes[0, 1].set_xlabel('Episode') + axes[0, 1].set_ylabel('Balance ($)') + axes[0, 1].grid(True) + + # Plot win rate + if 'win_rates' in df.columns: + axes[1, 0].plot(df['episode'], df['win_rates']) + axes[1, 0].set_title('Win Rate') + axes[1, 0].set_xlabel('Episode') + axes[1, 0].set_ylabel('Win Rate') + axes[1, 0].set_ylim([0, 1]) + axes[1, 0].grid(True) + + # Plot episode PnL + if 'episode_pnls' in df.columns: + axes[1, 1].plot(df['episode'], df['episode_pnls']) + axes[1, 1].set_title('Episode PnL') + axes[1, 1].set_xlabel('Episode') + axes[1, 1].set_ylabel('PnL ($)') + axes[1, 1].grid(True) + + # Plot cumulative PnL + if 'cumulative_pnl' in df.columns: + axes[2, 0].plot(df['episode'], df['cumulative_pnl']) + axes[2, 0].set_title('Cumulative PnL') + axes[2, 0].set_xlabel('Episode') + axes[2, 0].set_ylabel('Cumulative PnL ($)') + axes[2, 0].grid(True) + + # Plot maximum drawdown + if 'drawdowns' in df.columns: + axes[2, 1].plot(df['episode'], df['drawdowns']) + axes[2, 1].set_title('Maximum Drawdown') + axes[2, 1].set_xlabel('Episode') + axes[2, 1].set_ylabel('Drawdown') + axes[2, 1].grid(True) + + # Adjust layout + plt.tight_layout() + + # Save figure + plt.savefig('training_results.png') + logger.info("Training results saved to training_results.png") + + # Save statistics to CSV + df.to_csv('training_stats.csv', index=False) + logger.info("Training statistics saved to training_stats.csv") + + except Exception as e: + logger.error(f"Error plotting training results: {e}") + logger.error(traceback.format_exc()) def evaluate_agent(agent, env, num_episodes=10): """Evaluate the agent on test data""" @@ -2390,210 +2352,136 @@ async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit return [] async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50): - """Run the trading bot in live mode with enhanced error handling""" - logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe") - logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE TRADING'}") + """ + Run live trading using the trained agent. - # Verify agent is properly initialized + Args: + agent: Trained trading agent + env: Trading environment + exchange: Exchange instance + symbol: Trading symbol + timeframe: Trading timeframe + demo: Whether to run in demo mode (no real trades) + leverage: Leverage to use + """ try: - # Ensure agent has all required attributes - if not hasattr(agent, 'hidden_size'): - agent.hidden_size = 256 # Default value - logger.warning("Agent missing hidden_size attribute, using default: 256") - - if not hasattr(agent, 'lstm_layers'): - agent.lstm_layers = 2 # Default value - logger.warning("Agent missing lstm_layers attribute, using default: 2") - - if not hasattr(agent, 'attention_heads'): - agent.attention_heads = 4 # Default value - logger.warning("Agent missing attention_heads attribute, using default: 4") - - logger.info(f"Agent configuration: state_size={agent.state_size}, action_size={agent.action_size}, hidden_size={agent.hidden_size}") - except Exception as e: - logger.error(f"Error checking agent configuration: {e}") - # Continue anyway, as these are just informational attributes - - if not demo: - # Confirm with user before starting live trading - confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ") - if confirmation != "CONFIRM": - logger.info("Live trading canceled by user") - return - - # Initialize futures trading if not in demo mode - try: + logging.info(f"Starting live trading - Demo: {demo}, Symbol: {symbol}, Timeframe: {timeframe}") + + # Initialize candle cache + if not hasattr(agent, 'candle_cache'): + agent.candle_cache = CandleCache() + + # Get latest candle data for all timeframes + candle_data = await fetch_multi_timeframe_data(exchange, symbol, agent.candle_cache) + + # Set up environment with initial data + env.reset() + # Add historical data to environment + for candle in candle_data['1m'][-200:]: # Use last 200 candles for initial state + env.add_data(candle) + + # Update CNN patterns with multi-timeframe data + env.update_cnn_patterns(candle_data) + + # Initialize futures market if not in demo mode + if not demo: await env.initialize_futures(exchange) - logger.info(f"Futures trading initialized with {leverage}x leverage") - except Exception as e: - logger.error(f"Failed to initialize futures trading: {str(e)}") - logger.info("Falling back to demo mode for safety") - demo = True - - # Initialize TensorBoard for monitoring - if not hasattr(agent, 'writer') or agent.writer is None: - from torch.utils.tensorboard import SummaryWriter - # Fix the datetime usage here - current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{current_time}') - - # Track performance metrics - trades_count = 0 - winning_trades = 0 - total_profit = 0 - max_drawdown = 0 - peak_balance = env.balance - step_counter = 0 - prev_position = 'flat' - - # Create directory for trade logs - os.makedirs('trade_logs', exist_ok=True) - # Fix the datetime usage here - current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - trade_log_path = f'trade_logs/trades_{current_time}.csv' - with open(trade_log_path, 'w') as f: - f.write("timestamp,action,price,position_size,balance,pnl\n") - - logger.info("Entering live trading loop...") - - try: + # Set leverage + try: + await exchange.futures.set_leverage(leverage, symbol) + logging.info(f"Set leverage to {leverage}x for {symbol}") + except Exception as e: + logging.error(f"Error setting leverage: {e}") + + step = 0 while True: try: - # Fetch latest candle data - candle = await get_latest_candle(exchange, symbol) - if candle is None: - logger.warning("Failed to fetch latest candle, retrying in 5 seconds...") - await asyncio.sleep(5) - continue + # Get latest candle + latest_candle = await get_latest_candle(exchange, symbol) + if latest_candle: + # Only add if we don't already have this candle + env.add_data(latest_candle) - # Add new data to environment - env.add_data(candle) + # Every 5 minutes, update the multi-timeframe data and CNN patterns + if step % 5 == 0: + candle_data = await fetch_multi_timeframe_data(exchange, symbol, agent.candle_cache) + env.update_cnn_patterns(candle_data) + logging.info("Updated multi-timeframe data and CNN patterns") - # Get current state and select action + # Update price predictions and identify optimal trades + env.update_price_predictions() + env.identify_optimal_trades() + + # Get current state state = env.get_state() - # Verify state shape matches agent's expected input - if state.shape[0] != agent.state_size: - logger.warning(f"State size mismatch: got {state.shape[0]}, expected {agent.state_size}") - # Pad or truncate state to match expected size - if state.shape[0] < agent.state_size: - state = np.pad(state, (0, agent.state_size - state.shape[0])) - else: - state = state[:agent.state_size] + # Select action + action = agent.select_action(state, training=False, candle_data=candle_data) - action = agent.select_action(state, training=False) - - # Ensure action is valid - if action >= agent.action_size: - logger.warning(f"Invalid action {action}, clipping to {agent.action_size-1}") - action = agent.action_size - 1 - - # Log action - action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE" - logger.info(f"Step {step_counter}: Action selected: {action_name}, Price: ${env.data[-1]['close']:.2f}") - - # Execute action if not demo: - # Execute real trade on exchange - current_price = env.data[-1]['close'] - trade_result = await env.execute_real_trade(exchange, action, current_price) - if trade_result is None or not isinstance(trade_result, dict) or not trade_result.get('success', False): - error_msg = trade_result.get('error', 'Unknown error') if isinstance(trade_result, dict) else 'Trade execution failed' - logger.error(f"Trade execution failed: {error_msg}") - # Continue with simulated trade for tracking purposes + # Execute real trade + current_price = env.data[-1][4] if len(env.data) > 0 else None + if current_price: + await env.execute_real_trade(exchange, action, current_price) - # Update environment with action (simulated in demo mode) - try: - next_state, reward, done, info = env.step(action) - except ValueError as e: - # Handle case where step returns 3 values instead of 4 - if "not enough values to unpack" in str(e): - logger.warning("Step function returned 3 values instead of 4, creating info dict") - next_state, reward, done = env.step(action) - info = { - 'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close', - 'price': env.current_price, - 'balance': env.balance, - 'position': env.position, - 'pnl': env.total_pnl - } - else: - raise + # Step environment + next_state, reward, done, info = env.step(action) - # Log trade if position changed - if env.position != prev_position: - trades_count += 1 - if env.last_trade_profit > 0: - winning_trades += 1 - total_profit += env.last_trade_profit - - # Log trade details - with open(trade_log_path, 'a') as f: - f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n") - - logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}") + # Log + balance = env.balance + position = env.position + position_type = env.position_type if position else "None" + entry_price = env.entry_price if position else 0 + current_price = env.data[-1][4] if len(env.data) > 0 else 0 - # Update performance metrics - if env.balance > peak_balance: - peak_balance = env.balance - current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0 - if current_drawdown > max_drawdown: - max_drawdown = current_drawdown + # Calculate PnL if position is open + pnl = 0 + if position and entry_price > 0 and current_price > 0: + if position_type == 'long': + pnl = (current_price - entry_price) / entry_price * 100 * leverage + else: # short + pnl = (entry_price - current_price) / entry_price * 100 * leverage - # Update TensorBoard metrics - step_counter += 1 - agent.writer.add_scalar('Live/Balance', env.balance, step_counter) - agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter) - agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter) + # Log status + actions = ["HOLD", "BUY", "SELL"] + logging.info(f"Step {step}: Action={actions[action]}, " + f"Balance=${balance:.2f}, Position={position_type}, " + f"Entry=${entry_price:.2f}, Current=${current_price:.2f}, " + f"PnL={pnl:.2f}%") - # Update chart visualization - if step_counter % 5 == 0 or env.position != prev_position: - agent.add_chart_to_tensorboard(env, step_counter) - - # Log performance summary - if trades_count > 0: - win_rate = (winning_trades / trades_count) * 100 - agent.writer.add_scalar('Live/WinRate', win_rate, step_counter) - - performance_text = f""" - **Live Trading Performance** - Balance: ${env.balance:.2f} - Total PnL: ${env.total_pnl:.2f} - Trades: {trades_count} - Win Rate: {win_rate:.1f}% - Max Drawdown: {max_drawdown*100:.1f}% - """ - agent.writer.add_text('Performance', performance_text, step_counter) + # Update TensorBoard every 30 steps + if step % 30 == 0: + try: + agent.add_chart_to_tensorboard(env, step) + except Exception as e: + logging.warning(f"Error updating TensorBoard: {e}") - prev_position = env.position + # Limit update rate to avoid Binance API limits + await asyncio.sleep(10) # 10 seconds between updates - # Wait for next candle - logger.info(f"Waiting for next candle... (Step {step_counter})") - await asyncio.sleep(10) # Check every 10 seconds + step += 1 except Exception as e: - logger.error(f"Error in live trading loop: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - logger.info("Continuing after error...") - await asyncio.sleep(30) # Wait longer after an error - - except KeyboardInterrupt: - logger.info("Live trading stopped by user") - - # Final performance report - if trades_count > 0: - win_rate = (winning_trades / trades_count) * 100 - logger.info(f"Trading session summary:") - logger.info(f"Total trades: {trades_count}") - logger.info(f"Win rate: {win_rate:.1f}%") - logger.info(f"Final balance: ${env.balance:.2f}") - logger.info(f"Total profit: ${total_profit:.2f}") - logger.info(f"Maximum drawdown: {max_drawdown*100:.1f}%") - logger.info(f"Trade log saved to: {trade_log_path}") + logging.error(f"Error in live trading loop: {e}") + await asyncio.sleep(30) # Wait longer on error + + except Exception as e: + logging.error(f"Error in live trading: {e}") + return False + + return True async def get_latest_candle(exchange, symbol): - """Get the latest candle data""" + """ + Get the latest candle for a symbol. + + Args: + exchange: Exchange instance + symbol: Trading pair symbol + + Returns: + Latest candle data or None on failure + """ try: # Use the refactored fetch method with limit=1 data = await fetch_ohlcv_data(exchange, symbol, "1m", 1) @@ -2607,45 +2495,68 @@ async def get_latest_candle(exchange, symbol): logger.error(f"Failed to fetch latest candle: {e}") return None -async def fetch_ohlcv_data(exchange, symbol, timeframe, limit): - """Fetch OHLCV data with proper handling for both async and standard CCXT""" - try: - # Check if exchange has fetchOHLCV method - if not hasattr(exchange, 'fetchOHLCV'): - logger.error("Exchange does not support OHLCV data fetching") - return [] - - # Handle different CCXT versions - if hasattr(exchange, 'has') and exchange.has.get('fetchOHLCVAsync', False): - # Use async method if available - ohlcv = await exchange.fetchOHLCV(symbol, timeframe, limit=limit) - else: - # Use synchronous method with run_in_executor - loop = asyncio.get_event_loop() - ohlcv = await loop.run_in_executor( - None, - lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=limit) - ) - - # Convert to list of dictionaries - data = [] - for candle in ohlcv: - timestamp, open_price, high, low, close, volume = candle - data.append({ - 'timestamp': timestamp, - 'open': open_price, - 'high': high, - 'low': low, - 'close': close, - 'volume': volume - }) - - logger.info(f"Fetched {len(data)} candles for {symbol} ({timeframe})") - return data +async def fetch_ohlcv_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000): + """ + Fetch OHLCV data from exchange with error handling and retry logic. + + Args: + exchange: The exchange instance + symbol: Trading pair symbol + timeframe: Candle timeframe + limit: Number of candles to fetch - except Exception as e: - logger.error(f"Error fetching OHLCV data: {e}") - return [] + Returns: + List of candle data or empty list on failure + """ + max_retries = 3 + retry_delay = 5 + + for attempt in range(max_retries): + try: + logging.info(f"Fetching {limit} {timeframe} candles for {symbol} (attempt {attempt+1}/{max_retries})") + + # Check if exchange has fetch_ohlcv method + if not hasattr(exchange, 'fetch_ohlcv'): + logging.error("Exchange does not support OHLCV data fetching") + return [] + + # Fetch OHLCV data from exchange using asyncio if available, otherwise use run_in_executor + try: + if hasattr(exchange, 'has') and exchange.has.get('fetchOHLCVAsync', False): + ohlcv = await exchange.fetchOHLCVAsync(symbol, timeframe, limit=limit) + else: + # Run in executor to avoid blocking + loop = asyncio.get_event_loop() + ohlcv = await loop.run_in_executor( + None, + lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=limit) + ) + except Exception as e: + logging.error(f"Failed to fetch OHLCV data: {e}") + await asyncio.sleep(retry_delay) + continue + + if not ohlcv or len(ohlcv) == 0: + logging.warning(f"No data returned from exchange (attempt {attempt+1}/{max_retries})") + await asyncio.sleep(retry_delay) + continue + + # Convert to list of lists format + data = [] + for candle in ohlcv: + timestamp, open_price, high, low, close, volume = candle + data.append([timestamp, open_price, high, low, close, volume]) + + logging.info(f"Successfully fetched {len(data)} candles") + return data + + except Exception as e: + logging.error(f"Error fetching OHLCV data (attempt {attempt+1}/{max_retries}): {e}") + if attempt < max_retries - 1: + await asyncio.sleep(retry_delay) + + logging.error(f"Failed to fetch OHLCV data after {max_retries} attempts") + return [] # Add this near the top of the file, after imports def ensure_pytorch_compatibility(): @@ -2787,81 +2698,377 @@ async def main(): logger.warning(f"Could not properly close exchange connection: {e}") # Add this function near the top with other utility functions -def create_candlestick_figure(data, trade_signals, window_size=100, title=""): - """Create a candlestick chart with trade signals for TensorBoard visualization""" - if len(data) < 10: - return None - +def create_candlestick_figure(data, trades=None, title="Trading Chart"): + """Create a candlestick chart with trades marked""" try: - # Create figure - fig = plt.figure(figsize=(12, 8)) - - # Prepare data for plotting - df = pd.DataFrame(data[-window_size:]) - df['date'] = pd.to_datetime(df['timestamp'], unit='ms') - df.set_index('date', inplace=True) - - # Create subplot grid - gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1]) - price_ax = plt.subplot(gs[0]) - volume_ax = plt.subplot(gs[1], sharex=price_ax) - - # Plot candlesticks - use a simpler approach if mplfinance fails - try: - # Use a different style or approach that doesn't use 'type' parameter - mpf.plot(df, type='candle', ax=price_ax, volume=volume_ax, style='yahoo') - except Exception as e: - logger.warning(f"Error plotting with mplfinance: {e}, falling back to simple plot") - # Fallback to simple plot - price_ax.plot(df.index, df['close'], label='Price') - volume_ax.bar(df.index, df['volume'], color='blue', alpha=0.5) - - # Add trade signals - for signal in trade_signals: - try: - timestamp = pd.to_datetime(signal['timestamp'], unit='ms') - price = signal['price'] + if data is None or len(data) < 5: + logger.warning("Not enough data for candlestick chart") + return None + + # Convert data to DataFrame if it's not already + if not isinstance(data, pd.DataFrame): + df = pd.DataFrame(data) + else: + df = data.copy() + + # Ensure required columns exist + required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] + for col in required_columns: + if col not in df.columns: + logger.warning(f"Missing required column {col} for candlestick chart") + return None - if signal['type'] == 'buy': - price_ax.plot(timestamp, price, '^', color='green', markersize=10) - elif signal['type'] == 'sell': - price_ax.plot(timestamp, price, 'v', color='red', markersize=10) - elif signal['type'] == 'close_long': - price_ax.plot(timestamp, price, 'x', color='gold', markersize=10) - elif signal['type'] == 'close_short': - price_ax.plot(timestamp, price, 'x', color='black', markersize=10) - elif 'stop_loss' in signal['type']: - price_ax.plot(timestamp, price, 'X', color='purple', markersize=10) - elif 'take_profit' in signal['type']: - price_ax.plot(timestamp, price, '*', color='cyan', markersize=10) - except Exception as e: - logger.warning(f"Error plotting signal: {e}") - continue + # Format dates + if 'timestamp' in df.columns: + if isinstance(df['timestamp'].iloc[0], (int, float)): + # Convert timestamp to datetime if it's numeric + df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') + + # Set timestamp as index if it's not already + if df.index.name != 'timestamp': + df.set_index('timestamp', inplace=True) - # Add balance and PnL annotation - if trade_signals and 'balance' in trade_signals[-1] and 'pnl' in trade_signals[-1]: - balance = trade_signals[-1]['balance'] - pnl = trade_signals[-1]['pnl'] - price_ax.annotate(f"Balance: ${balance:.2f}\nPnL: ${pnl:.2f}", - xy=(0.02, 0.95), xycoords='axes fraction', - bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8)) + # Rename columns for mplfinance + df_mpf = df.copy() + if 'open' in df_mpf.columns: + df_mpf = df_mpf.rename(columns={ + 'open': 'Open', + 'high': 'High', + 'low': 'Low', + 'close': 'Close', + 'volume': 'Volume' + }) - # Set title and format - price_ax.set_title(title) - fig.tight_layout() + # Create a simple matplotlib figure instead + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), + gridspec_kw={'height_ratios': [3, 1]}) - # Convert to image - buf = io.BytesIO() - fig.savefig(buf, format='png') - buf.seek(0) - plt.close(fig) - img = Image.open(buf) - return img + # Plot candlesticks manually + for i in range(len(df_mpf)): + # Get date and prices + date = df_mpf.index[i] + open_price = df_mpf['Open'].iloc[i] + high_price = df_mpf['High'].iloc[i] + low_price = df_mpf['Low'].iloc[i] + close_price = df_mpf['Close'].iloc[i] + + # Determine color based on price movement + color = 'green' if close_price >= open_price else 'red' + + # Plot candle body + body_height = abs(close_price - open_price) + body_bottom = min(close_price, open_price) + ax1.bar(date, body_height, bottom=body_bottom, width=0.6, + color=color, alpha=0.6) + + # Plot wick + ax1.plot([date, date], [low_price, high_price], color=color, linewidth=1) + + # Plot volume + ax2.bar(df_mpf.index, df_mpf['Volume'], width=0.6, color='blue', alpha=0.5) + + # Mark trades if available + if trades and len(trades) > 0: + for trade in trades: + if 'timestamp' not in trade or 'type' not in trade or 'price' not in trade: + continue + + # Convert timestamp to datetime if needed + if isinstance(trade['timestamp'], (int, float)): + trade_time = pd.to_datetime(trade['timestamp'], unit='ms') + else: + trade_time = trade['timestamp'] + + # Determine marker color based on trade type + marker_color = 'green' if trade['type'].lower() == 'buy' else 'red' + + # Add marker at trade price + ax1.scatter(trade_time, trade['price'], color=marker_color, + marker='^' if trade['type'].lower() == 'buy' else 'v', + s=100, zorder=5) + + # Set title and labels + ax1.set_title(title) + ax1.set_ylabel('Price') + ax2.set_ylabel('Volume') + ax1.grid(True) + ax2.grid(True) + + # Format x-axis + plt.setp(ax1.get_xticklabels(), visible=False) + + # Adjust layout + plt.tight_layout() + + return fig except Exception as e: - logger.error(f"Error creating chart: {str(e)}") + logger.error(f"Error creating candlestick figure: {e}") return None +class CandlePatternCNN(nn.Module): + """ + Multi-timeframe CNN for candle pattern recognition. + Extracts features from 1s, 1m, 1h, and 1d candle data. + """ + def __init__(self, input_channels=5, feature_dimension=512): + super(CandlePatternCNN, self).__init__() + + # Base convolutional network for each timeframe + self.base_conv = nn.Sequential( + nn.Conv2d(input_channels, 64, kernel_size=(1, 3), padding=(0, 1)), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.Conv2d(64, 128, kernel_size=(1, 5), padding=(0, 2)), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.MaxPool2d(kernel_size=(1, 2)), + nn.Conv2d(128, 256, kernel_size=(1, 5), padding=(0, 2)), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.MaxPool2d(kernel_size=(1, 2)), + nn.Conv2d(256, 512, kernel_size=(1, 3), padding=(0, 1)), + nn.BatchNorm2d(512), + nn.ReLU() + ) + + # Feature fusion layers + self.fusion = nn.Sequential( + nn.Linear(512 * 4 * 75, 2048), # 4 timeframes, assume 300/4=75 candles after pooling + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(2048, 1024), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(1024, feature_dimension) + ) + + # Store intermediate activations + self.intermediate_features = {} + + def forward(self, x_1s, x_1m, x_1h, x_1d): + """ + Process candle data from multiple timeframes. + + Args: + x_1s: Tensor of shape [batch, channels, history_len] for 1-second candles + x_1m: Tensor of shape [batch, channels, history_len] for 1-minute candles + x_1h: Tensor of shape [batch, channels, history_len] for 1-hour candles + x_1d: Tensor of shape [batch, channels, history_len] for 1-day candles + + Returns: + Tensor of extracted features + """ + # Add a dimension for the conv2d to work properly + x_1s = x_1s.unsqueeze(2) # [batch, channels, 1, history_len] + x_1m = x_1m.unsqueeze(2) + x_1h = x_1h.unsqueeze(2) + x_1d = x_1d.unsqueeze(2) + + # Extract features from each timeframe + feat_1s = self.base_conv(x_1s) + feat_1m = self.base_conv(x_1m) + feat_1h = self.base_conv(x_1h) + feat_1d = self.base_conv(x_1d) + + # Store intermediate features + self.intermediate_features['1s'] = feat_1s + self.intermediate_features['1m'] = feat_1m + self.intermediate_features['1h'] = feat_1h + self.intermediate_features['1d'] = feat_1d + + # Flatten and concatenate features + batch_size = x_1s.size(0) + feat_1s = feat_1s.view(batch_size, -1) + feat_1m = feat_1m.view(batch_size, -1) + feat_1h = feat_1h.view(batch_size, -1) + feat_1d = feat_1d.view(batch_size, -1) + + combined_features = torch.cat([feat_1s, feat_1m, feat_1h, feat_1d], dim=1) + + # Process through fusion layers + output = self.fusion(combined_features) + + # Store final layer features + self.intermediate_features['fusion'] = output + + return output + + def get_features(self): + """Returns dictionary of intermediate features for use by the agent""" + return self.intermediate_features + +# Add candle cache system +class CandleCache: + """ + Cache system for candles of different timeframes. + Reduces API calls by storing and updating candle data. + """ + def __init__(self): + self.candles = { + '1s': [], + '1m': [], + '1h': [], + '1d': [] + } + self.last_updated = { + '1s': None, + '1m': None, + '1h': None, + '1d': None + } + + def add_candles(self, timeframe, new_candles): + """Add new candles to the cache""" + if not self.candles[timeframe]: + self.candles[timeframe] = new_candles + else: + # Find the last timestamp in our current cache + last_timestamp = self.candles[timeframe][-1][0] + + # Add only candles newer than our last cached one + for candle in new_candles: + if candle[0] > last_timestamp: + self.candles[timeframe].append(candle) + + self.last_updated[timeframe] = datetime.datetime.now() + + def get_candles(self, timeframe, limit=300): + """Get the most recent candles for a timeframe""" + if not self.candles[timeframe]: + return [] + + return self.candles[timeframe][-limit:] + + def needs_update(self, timeframe, max_age_seconds): + """Check if the cache needs to be updated""" + if not self.last_updated[timeframe]: + return True + + age = (datetime.datetime.now() - self.last_updated[timeframe]).total_seconds() + return age > max_age_seconds + +async def fetch_multi_timeframe_data(exchange, symbol, candle_cache): + """Fetch candle data for multiple timeframes, using cache when possible""" + update_intervals = { + '1s': 10, # Update every 10 seconds + '1m': 60, # Update every 1 minute + '1h': 3600, # Update every 1 hour + '1d': 86400 # Update every 1 day + } + + limits = { + '1s': 1000, + '1m': 1000, + '1h': 500, + '1d': 300 + } + + for timeframe, interval in update_intervals.items(): + if candle_cache.needs_update(timeframe, interval): + try: + logging.info(f"Fetching {timeframe} candle data for {symbol}") + candles = await fetch_ohlcv_data(exchange, symbol, timeframe, limits[timeframe]) + candle_cache.add_candles(timeframe, candles) + logging.info(f"Fetched {len(candles)} {timeframe} candles") + except Exception as e: + logging.error(f"Error fetching {timeframe} candle data: {e}") + + return { + '1s': candle_cache.get_candles('1s'), + '1m': candle_cache.get_candles('1m'), + '1h': candle_cache.get_candles('1h'), + '1d': candle_cache.get_candles('1d') + } + +# Modify the LSTMAttentionDQN class to incorporate the CNN features +class LSTMAttentionDQN(nn.Module): + def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4): + super(LSTMAttentionDQN, self).__init__() + + # CNN for pattern recognition + self.cnn = CandlePatternCNN(input_channels=5, feature_dimension=512) + + # Calculate expanded state size with CNN features + self.expanded_state_size = state_size + 512 # Original state + CNN features + + # LSTM layers + self.lstm = nn.LSTM( + input_size=self.expanded_state_size, + hidden_size=hidden_size, + num_layers=lstm_layers, + batch_first=True + ) + + # Attention mechanism + self.attention = nn.MultiheadAttention( + embed_dim=hidden_size, + num_heads=attention_heads + ) + + # Output layers + self.advantage_stream = nn.Sequential( + nn.Linear(hidden_size, hidden_size // 2), + nn.ReLU(), + nn.Linear(hidden_size // 2, action_size) + ) + + self.value_stream = nn.Sequential( + nn.Linear(hidden_size, hidden_size // 2), + nn.ReLU(), + nn.Linear(hidden_size // 2, 1) + ) + + def forward(self, state, x_1s=None, x_1m=None, x_1h=None, x_1d=None): + # Handle different input shapes + if len(state.shape) == 1: + # Add batch dimension if missing + state = state.unsqueeze(0) + + if len(state.shape) == 2: + # Add sequence dimension if missing + state = state.unsqueeze(1) + + batch_size = state.size(0) + seq_len = state.size(1) + + # If CNN inputs are provided, process them and concatenate with state + if x_1s is not None and x_1m is not None and x_1h is not None and x_1d is not None: + cnn_features = self.cnn(x_1s, x_1m, x_1h, x_1d) + + # Expand CNN features to match sequence length of state + cnn_features = cnn_features.unsqueeze(1).expand(-1, seq_len, -1) + + # Concatenate state with CNN features + state = torch.cat([state, cnn_features], dim=2) + else: + # If CNN inputs not provided, pad with zeros + padding = torch.zeros(batch_size, seq_len, 512, device=state.device) + state = torch.cat([state, padding], dim=2) + + # Process through LSTM + lstm_out, _ = self.lstm(state) + + # Apply attention + # Reshape for attention: [seq_len, batch_size, hidden_size] + lstm_out_permuted = lstm_out.permute(1, 0, 2) + attn_output, _ = self.attention(lstm_out_permuted, lstm_out_permuted, lstm_out_permuted) + + # Reshape back: [batch_size, seq_len, hidden_size] + attn_output = attn_output.permute(1, 0, 2) + + # Use the output of the last timestep + features = attn_output[:, -1, :] + + # Dueling architecture + advantage = self.advantage_stream(features) + value = self.value_stream(features) + + # Combine value and advantage streams + q_values = value + advantage - advantage.mean(dim=1, keepdim=True) + + return q_values + if __name__ == "__main__": try: asyncio.run(main())