diff --git a/main.py b/main.py
index 74d672a..b5c40fc 100644
--- a/main.py
+++ b/main.py
@@ -30,14 +30,6 @@ import matplotlib.pyplot as mpf
 import matplotlib.gridspec as gridspec
 import datetime
 from datetime import datetime as dt
-from collections import defaultdict
-from gym.spaces import Discrete, Box
-import csv
-import gc
-import shutil
-import math
-import platform
-import ctypes
 
 # Configure logging
 logging.basicConfig(
@@ -72,73 +64,25 @@ Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state'
 
 # Add this function near the top of the file, after the imports but before any classes
 def find_local_extrema(prices, window=5):
-    """
-    Find local extrema (tops and bottoms) in price series.
-
-    Args:
-        prices: Array of price values
-        window: Window size for finding extrema
-
-    Returns:
-        Tuple of (tops, bottoms) indices
-    """
-    tops = []
+    """Find local minima (bottoms) and maxima (tops) in price data"""
     bottoms = []
+    tops = []
 
     if len(prices) < window * 2 + 1:
-        return tops, bottoms
+        return bottoms, tops
 
-    try:
-        # Use peak detection algorithms from scipy if available
-        from scipy.signal import find_peaks
+    for i in range(window, len(prices) - window):
+        # Check if this is a local minimum (bottom)
+        if all(prices[i] <= prices[i-j] for j in range(1, window+1)) and \
+           all(prices[i] <= prices[i+j] for j in range(1, window+1)):
+            bottoms.append(i)
 
-        # Find peaks (tops)
-        peaks, _ = find_peaks(prices, distance=window)
-        tops = list(peaks)
-
-        # Find valleys (bottoms) by inverting the prices
-        valleys, _ = find_peaks(-prices, distance=window)
-        bottoms = list(valleys)
-
-        # Optional: Filter extrema for significance
-        if len(tops) > 0 and len(bottoms) > 0:
-            # Calculate average price move
-            avg_move = np.mean(np.abs(np.diff(prices)))
-
-            # Filter tops and bottoms for significant moves
-            filtered_tops = []
-            for top in tops:
-                # Check if this top is significantly higher than surrounding points
-                if top > window and top < len(prices) - window:
-                    surrounding_min = min(prices[top-window:top+window])
-                    if prices[top] - surrounding_min > avg_move * 1.5:  # 1.5x average move
-                        filtered_tops.append(top)
-
-            filtered_bottoms = []
-            for bottom in bottoms:
-                # Check if this bottom is significantly lower than surrounding points
-                if bottom > window and bottom < len(prices) - window:
-                    surrounding_max = max(prices[bottom-window:bottom+window])
-                    if surrounding_max - prices[bottom] > avg_move * 1.5:  # 1.5x average move
-                        filtered_bottoms.append(bottom)
-
-            tops = filtered_tops
-            bottoms = filtered_bottoms
+        # Check if this is a local maximum (top)
+        if all(prices[i] >= prices[i-j] for j in range(1, window+1)) and \
+           all(prices[i] >= prices[i+j] for j in range(1, window+1)):
+            tops.append(i)
 
-    except ImportError:
-        # Fallback to manual detection if scipy is not available
-        for i in range(window, len(prices) - window):
-            # Check if this point is a local maximum
-            if all(prices[i] >= prices[i - j] for j in range(1, window + 1)) and \
-               all(prices[i] >= prices[i + j] for j in range(1, window + 1)):
-                tops.append(i)
-
-            # Check if this point is a local minimum
-            if all(prices[i] <= prices[i - j] for j in range(1, window + 1)) and \
-               all(prices[i] <= prices[i + j] for j in range(1, window + 1)):
-                bottoms.append(i)
-
-    return tops, bottoms
+    return bottoms, tops
 
 class ReplayMemory:
     def __init__(self, capacity):
@@ -154,21 +98,88 @@ class ReplayMemory:
         return len(self.memory)
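# --- Editor's sketch (not part of the patch): a quick sanity check for the
# rewritten find_local_extrema. Note the return order is now (bottoms, tops),
# the reverse of the removed scipy-based version, so every call site must be
# updated to match. Synthetic data; all names below are illustrative only.
import numpy as np

prices = np.sin(np.linspace(0, 6 * np.pi, 120)) * 10 + 100  # smooth synthetic wave
bottoms, tops = find_local_extrema(prices, window=5)
assert all(prices[b] <= prices[b - 5:b + 6].min() for b in bottoms)
assert all(prices[t] >= prices[t - 5:t + 6].max() for t in tops)
print(f"{len(bottoms)} bottoms, {len(tops)} tops")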
 
 class DQN(nn.Module):
-    """Deep Q-Network with enhanced architecture"""
     def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4):
         super(DQN, self).__init__()
-        self.network = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads)
+
+        self.state_size = state_size
         self.hidden_size = hidden_size
         self.lstm_layers = lstm_layers
-        self.attention_heads = attention_heads
-
-    def forward(self, state, x_1s=None, x_1m=None, x_1h=None, x_1d=None):
-        # Pass through to LSTMAttentionDQN
-        if x_1m is not None and x_1h is not None and x_1d is not None:
-            return self.network(state, x_1s, x_1m, x_1h, x_1d)
-        else:
-            return self.network(state)
+
+        # Initial feature extraction
+        self.fc1 = nn.Linear(state_size, hidden_size)
+        # Use LayerNorm instead of BatchNorm for more stability with varying batch sizes
+        self.ln1 = nn.LayerNorm(hidden_size)
+        self.dropout1 = nn.Dropout(0.2)
+
+        # LSTM layer for sequential data
+        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=lstm_layers, batch_first=True, dropout=0.2)
+
+        # Attention mechanism (defined here but not applied in forward(), which
+        # relies on the transformer encoder below instead)
+        self.attention = nn.MultiheadAttention(hidden_size, attention_heads)
+
+        # Output layers with increased capacity
+        self.fc2 = nn.Linear(hidden_size, hidden_size)
+        self.ln2 = nn.LayerNorm(hidden_size)  # LayerNorm instead of BatchNorm
+        self.dropout2 = nn.Dropout(0.2)
+        self.fc3 = nn.Linear(hidden_size, hidden_size // 2)
+
+        # Dueling DQN architecture
+        self.value_stream = nn.Linear(hidden_size // 2, 1)
+        self.advantage_stream = nn.Linear(hidden_size // 2, action_size)
+
+        # Transformer encoder for more complex pattern recognition
+        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1)
+        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
+
+    def forward(self, x):
+        batch_size = x.size(0) if x.dim() > 1 else 1
+
+        # Ensure input has correct shape
+        if x.dim() == 1:
+            x = x.unsqueeze(0)  # Add batch dimension
+
+        # Check if state size matches expected input size
+        if x.size(1) != self.state_size:
+            # Handle mismatched input by either truncating or padding
+            if x.size(1) > self.state_size:
+                x = x[:, :self.state_size]  # Truncate
+            else:
+                # Pad with zeros
+                padding = torch.zeros(batch_size, self.state_size - x.size(1), device=x.device)
+                x = torch.cat([x, padding], dim=1)
+
+        # Initial feature extraction
+        x = self.fc1(x)
+        x = F.relu(self.ln1(x))  # LayerNorm works with any batch size
+        x = self.dropout1(x)
+
+        # Reshape for LSTM
+        x_lstm = x.unsqueeze(1) if x.dim() == 2 else x
+
+        # Process through LSTM
+        lstm_out, _ = self.lstm(x_lstm)
+        lstm_out = lstm_out.squeeze(1) if lstm_out.size(1) == 1 else lstm_out[:, -1]
+
+        # Process through transformer for more complex patterns
+        transformer_input = x.unsqueeze(1) if x.dim() == 2 else x
+        transformer_out = self.transformer_encoder(transformer_input.transpose(0, 1))
+        transformer_out = transformer_out.transpose(0, 1).mean(dim=1)
+
+        # Combine LSTM and transformer outputs
+        x = lstm_out + transformer_out
+
+        # Final layers
+        x = self.fc2(x)
+        x = F.relu(self.ln2(x))  # LayerNorm works with any batch size
+        x = self.dropout2(x)
+        x = F.relu(self.fc3(x))
+
+        # Dueling architecture
+        value = self.value_stream(x)
+        advantages = self.advantage_stream(x)
+        qvals = value + (advantages - advantages.mean(dim=1, keepdim=True))
+
+        return qvals
 
 class PricePredictionModel(nn.Module):
     def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2):
@@ -267,253 +278,67 @@ class PricePredictionModel(nn.Module):
         return total_loss / epochs
 
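# --- Editor's sketch (not part of the patch): a shape smoke test for the new
# dueling DQN, covering the batch path, the 1-D input path, and the zero-pad
# guard in forward(). Sizes are arbitrary and purely for illustration.
import torch

net = DQN(state_size=64, action_size=4, hidden_size=128)
net.eval()  # disable dropout for a deterministic check

with torch.no_grad():
    q_batch = net(torch.randn(8, 64))    # normal batch         -> (8, 4)
    q_single = net(torch.randn(64))      # 1-D state unsqueezed -> (1, 4)
    q_padded = net(torch.randn(8, 50))   # short input zero-padded to state_size

assert q_batch.shape == (8, 4)
assert q_single.shape == (1, 4)
assert q_padded.shape == (8, 4)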
 class TradingEnvironment:
-    def __init__(self, data=None, features=None, feature_extractors=None, initial_balance=10000, leverage=50,
-                 window_size=100, commission=0.0004, api_key=None, api_secret=None, exchange_id='binance',
-                 symbol='ETH/USDT', timeframe='1m', init_length=5000, max_steps=10000):
+    def __init__(self, initial_balance=INITIAL_BALANCE, window_size=30, demo=True):
         """Initialize the trading environment"""
-        self.api_key = api_key
-        self.api_secret = api_secret
-        self.exchange_id = exchange_id
-        self.symbol = symbol
-        self.timeframe = timeframe
-        self.init_length = init_length
-
-        # TODO: For 1s/ticks timeframes, implement WebSocket API integration for real-time data
-
-        try:
-            # Initialize exchange if API credentials are provided
-            if api_key and api_secret:
-                self.exchange = initialize_exchange(exchange_id, api_key, api_secret)
-                logger.info(f"Exchange initialized: {exchange_id}")
-                # Fetch historical data
-                self.data = fetch_candles(self.exchange, self.symbol, self.timeframe, limit=self.init_length)
-                if not self.data:
-                    raise ValueError(f"No data fetched for {self.symbol} on {self.exchange_id}")
-                self.data_format_is_list = isinstance(self.data[0], list)
-                logger.info(f"Loaded {len(self.data)} candles from exchange")
-            elif data is not None:  # Use provided data
-                self.data = data
-                self.data_format_is_list = isinstance(self.data[0], list)
-                logger.info(f"Using provided data with {len(self.data)} candles")
-            else:
-                # Initialize with empty data, we'll load it later with fetch_initial_data
-                logger.warning("No data provided, initializing with empty data")
-                self.data = []
-                self.data_format_is_list = True
-        except Exception as e:
-            logger.error(f"Error initializing environment: {e}")
-            raise
-
-        # Initialize features and feature extractors
-        if features is not None:
-            self.features = features
-            # Create a dictionary of features
-            self.features_dict = {f"feature_{i}": feature for i, feature in enumerate(features)}
-        else:
-            # Initialize features as a dictionary, not a list
-            self.features = {
-                'price': [],
-                'volume': [],
-                'rsi': [],
-                'macd': [],
-                'macd_signal': [],
-                'macd_hist': [],
-                'bollinger_upper': [],
-                'bollinger_mid': [],
-                'bollinger_lower': [],
-                'stoch_k': [],
-                'stoch_d': [],
-                'ema_9': [],
-                'ema_21': [],
-                'atr': []
-            }
-            self.features_dict = {}
-
-        if feature_extractors is None:
-            feature_extractors = []
-        self.feature_extractors = feature_extractors
-
-        # Environment parameters
         self.initial_balance = initial_balance
         self.balance = initial_balance
-        self.leverage = leverage
+        self.window_size = window_size
+        self.demo = demo
+        self.data = []
         self.position = 'flat'  # 'flat', 'long', or 'short'
         self.position_size = 0
         self.entry_price = 0
         self.entry_index = 0
         self.stop_loss = 0
         self.take_profit = 0
-        self.commission = commission
-        self.total_pnl = 0
-        self.total_fees = 0.0  # Track total fees paid
         self.trades = []
-        self.trade_signals = []
-        self.current_step = 0
-        self.window_size = window_size
-        self.max_steps = max_steps
-        self.peak_balance = initial_balance
-        self.max_drawdown = 0
-        self.current_price = 0
         self.win_count = 0
         self.loss_count = 0
-        self.min_position_size = 100  # Minimum position size in USD
+        self.total_pnl = 0.0
+        self.episode_pnl = 0.0
+        self.peak_balance = initial_balance
+        self.max_drawdown = 0.0
+        self.current_step = 0
+        self.current_price = 0
 
-        # Track candle patterns and reversal points
-        self.patterns = {}
-        self.reversal_points = []
+        # For tracking signals for visualization
+        self.trade_signals = []
 
-        # Define observation and action spaces
-        num_features = len(self.features) if hasattr(self, 'features') and self.features else 0
-        state_dim = window_size * 5 + 5 + num_features  # OHLCV + position info + features
+        # Initialize features
+        self.features = {
+            'price': [],
+            'volume': [],
+            'rsi': [],
+            'macd': [],
+            'macd_signal': [],
+            'macd_hist': [],
+            'bollinger_upper': [],
+            'bollinger_mid': [],
+            'bollinger_lower': [],
+            'stoch_k': [],
+            'stoch_d': [],
+            'ema_9': [],
+            'ema_21': [],
+            'atr': []
+        }
 
-        self.action_space = Discrete(4)  # 0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE
-        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32)
+        # Initialize price predictor
+        self.price_predictor = None
+        self.predicted_prices = np.array([])
 
-        # Check if we have enough data
-        if len(self.data) < self.window_size:
-            logger.warning(f"Data length {len(self.data)} is less than window size {self.window_size}")
-
-    def calculate_reward(self, action):
-        """Calculate reward based on the action taken"""
-        reward = 0
+        # Initialize optimal trade tracking
+        self.optimal_bottoms = []
+        self.optimal_tops = []
+        self.optimal_signals = np.array([])
 
-        # Base reward structure
-        if self.position == 'flat':
-            if action == 0:  # HOLD when flat
-                reward = 0.01  # Small reward for holding when no position
-            elif action == 1:  # BUY/LONG
-                # Check for buy signal in CNN patterns
-                if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
-                    buy_confidence = self.cnn_patterns['long_confidence']
-                    # Scale by confidence
-                    reward = 0.1 * buy_confidence * 10
-                else:
-                    reward = 0.1  # Default reward for taking a position
-
-                # Apply fee penalty
-                if self.position_size > 0:
-                    fee = (self.position_size / 1900) * 1
-                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
-                    reward -= fee_penalty
-            elif action == 2:  # SELL/SHORT
-                # Check for sell signal in CNN patterns
-                if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
-                    sell_confidence = self.cnn_patterns['short_confidence']
-                    # Scale by confidence
-                    reward = 0.1 * sell_confidence * 10
-                else:
-                    reward = 0.1  # Default reward for taking a position
-
-                # Apply fee penalty
-                if self.position_size > 0:
-                    fee = (self.position_size / 1900) * 1
-                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
-                    reward -= fee_penalty
-            elif action == 3:  # CLOSE when no position
-                reward = -0.1  # Penalty for trying to close no position
+        # Add these new attributes
+        self.leverage = MAX_LEVERAGE
+        self.futures_symbol = "ETH_USDT"  # Example futures symbol
+        self.position_mode = "hedge"  # For simultaneous long/short positions
+        self.margin_mode = "cross"  # Cross margin mode
 
-        elif self.position == 'long':
-            if action == 0:  # HOLD long position
-                # Calculate price change since entry
-                price_change = (self.current_price - self.entry_price) / self.entry_price
-
-                # Reward or penalize based on price movement
-                if price_change > 0:
-                    reward = price_change * 10  # Reward for holding profitable position
-                else:
-                    reward = price_change * 5  # Smaller penalty for holding losing position
-
-            elif action == 1:  # BUY when already long
-                reward = -0.1  # Penalty for redundant action
-
-            elif action == 2:  # SELL when long (reversal)
-                # Calculate PnL
-                pnl_percent = (self.current_price - self.entry_price) / self.entry_price
-
-                if pnl_percent > 0:
-                    reward = -0.5  # Penalty for closing profitable long position to go short
-                else:
-                    # Check for sell signal in CNN patterns
-                    if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
-                        sell_confidence = self.cnn_patterns['short_confidence']
-                        reward = 0.2 * sell_confidence * 10  # Reward for correct reversal
-                    else:
-                        reward = 0.2  # Default reward for cutting loss
-
-                # Apply fee penalty
-                if self.position_size > 0:
-                    fee = (self.position_size / 1900) * 1
-                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
-                    reward -= fee_penalty
-
-            elif action == 3:  # CLOSE long position
-                # Calculate PnL
-                pnl_percent = (self.current_price - self.entry_price) / self.entry_price
-
-                if pnl_percent > 0:
-                    reward = pnl_percent * 15  # Higher reward for taking profit
-                else:
-                    reward = pnl_percent * 5  # Smaller penalty for cutting loss
-
-                # Apply fee penalty
-                if self.position_size > 0:
-                    fee = (self.position_size / 1900) * 1
-                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
-                    reward -= fee_penalty
-
-        elif self.position == 'short':
-            if action == 0:  # HOLD short position
-                # Calculate price change since entry
-                price_change = (self.entry_price - self.current_price) / self.entry_price
-
-                # Reward or penalize based on price movement
-                if price_change > 0:
-                    reward = price_change * 10  # Reward for holding profitable position
-                else:
-                    reward = price_change * 5  # Smaller penalty for holding losing position
-
-            elif action == 1:  # BUY when short (reversal)
-                # Calculate PnL
-                pnl_percent = (self.entry_price - self.current_price) / self.entry_price
-
-                if pnl_percent > 0:
-                    reward = -0.5  # Penalty for closing profitable short position to go long
-                else:
-                    # Check for buy signal in CNN patterns
-                    if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
-                        buy_confidence = self.cnn_patterns['long_confidence']
-                        reward = 0.2 * buy_confidence * 10  # Reward for correct reversal
-                    else:
-                        reward = 0.2  # Default reward for cutting loss
-
-                # Apply fee penalty
-                if self.position_size > 0:
-                    fee = (self.position_size / 1900) * 1
-                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
-                    reward -= fee_penalty
-
-            elif action == 2:  # SELL when already short
-                reward = -0.1  # Penalty for redundant action
-
-            elif action == 3:  # CLOSE short position
-                # Calculate PnL
-                pnl_percent = (self.entry_price - self.current_price) / self.entry_price
-
-                if pnl_percent > 0:
-                    reward = pnl_percent * 15  # Higher reward for taking profit
-                else:
-                    reward = pnl_percent * 5  # Smaller penalty for cutting loss
-
-                # Apply fee penalty
-                if self.position_size > 0:
-                    fee = (self.position_size / 1900) * 1
-                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
-                    reward -= fee_penalty
-
-        return reward
-
     def reset(self):
-        """Reset the environment to its initial state and return the initial observation"""
+        """Reset the environment to initial state"""
         self.balance = self.initial_balance
         self.position = 'flat'
         self.position_size = 0
@@ -521,31 +346,29 @@ class TradingEnvironment:
         self.entry_index = 0
         self.stop_loss = 0
         self.take_profit = 0
-        self.current_step = 0
         self.trades = []
-        self.trade_signals = []
-        self.total_pnl = 0.0
-        self.total_fees = 0.0
-        self.peak_balance = self.initial_balance
-        self.max_drawdown = 0.0
         self.win_count = 0
         self.loss_count = 0
+        self.episode_pnl = 0.0
+        self.peak_balance = self.initial_balance
+        self.max_drawdown = 0.0
+        self.current_step = 0
+
+        # Keep data but reset current position
+        if len(self.data) > self.window_size:
+            self.current_step = self.window_size
+            self.current_price = self.data[self.current_step]['close']
+
+        # Reset trade signals
+        self.trade_signals = []
 
         return self.get_state()
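# --- Editor's sketch (not part of the patch): the minimal interaction loop the
# new reset()/step() contract supports. Assumes candles were loaded via
# add_data() as dicts with OHLCV keys; `historical_candles` and the random
# policy below are hypothetical and purely illustrative.
import random

env = TradingEnvironment(initial_balance=100, window_size=30, demo=True)
for candle in historical_candles:  # hypothetical iterable of candle dicts
    env.add_data(candle)

state = env.reset()
done = False
while not done:
    action = random.randint(0, 3)  # 0 HOLD, 1 BUY/LONG, 2 SELL/SHORT, 3 CLOSE
    state, reward, done, info = env.step(action)
print(info['balance'], info['pnl'])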
    def add_data(self, candle):
         """Add a new candle to the data"""
-        # Check if candle is a list or dictionary
-        if isinstance(candle, list):
-            self.data_format_is_list = True
-            self.data.append(candle)
-            self.current_price = candle[4]  # Close price is at index 4
-        else:
-            self.data_format_is_list = False
-            self.data.append(candle)
-            self.current_price = candle['close']
-
+        self.data.append(candle)
         self._update_features()
+        self.current_price = candle['close']
 
     def _initialize_features(self):
         """Initialize technical indicators and features"""
@@ -553,12 +376,7 @@ class TradingEnvironment:
             return
 
         # Convert data to pandas DataFrame for easier calculation
-        if self.data_format_is_list:
-            # Convert list format to DataFrame
-            df = pd.DataFrame(self.data, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
-        else:
-            # Dictionary format
-            df = pd.DataFrame(self.data)
+        df = pd.DataFrame(self.data)
 
         # Basic price and volume
         self.features['price'] = df['close'].values
@@ -646,247 +464,12 @@ class TradingEnvironment:
             }
             return next_state, 0, done, info
 
-        # Circuit breaker after consecutive losses
-        if self.count_consecutive_losses() >= 5:
-            logger.warning("Circuit breaker triggered after 5 consecutive losses")
-            return self.get_state(), -1, True, {'action': 'circuit_breaker_triggered'}
-
-        # Reduce leverage in volatile markets
-        if self.is_volatile_market():
-            self.leverage = MAX_LEVERAGE * 0.5  # Half leverage in volatile markets
-        else:
-            self.leverage = MAX_LEVERAGE
-
         # Store current price before taking action
-        if self.data_format_is_list:
-            self.current_price = self.data[self.current_step][4]  # Close price
-        else:
-            self.current_price = self.data[self.current_step]['close']
+        self.current_price = self.data[self.current_step]['close']
 
         # Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
         reward = self.calculate_reward(action)
 
-        # Execute the action
-        initial_balance = self.balance  # Store initial balance to calculate PnL
-
-        # Open long position
-        if action == 1 and self.position != 'long':
-            if self.position == 'short':
-                # Close short position first
-                if self.position_size > 0:
-                    # Calculate PnL
-                    pnl_percent = (self.entry_price - self.current_price) / self.entry_price
-                    pnl_dollar = pnl_percent * self.position_size * self.leverage
-
-                    # Update balance and record trade
-                    self.balance += pnl_dollar
-                    self.total_pnl += pnl_dollar
-
-                    # Apply trading fee (1 USD per 1.9k position)
-                    fee = (self.position_size / 1900) * 1
-                    self.balance -= fee
-                    self.total_fees += fee
-
-                    # Record trade
-                    trade_duration = self.current_step - self.entry_index
-                    if self.data_format_is_list:
-                        timestamp = self.data[self.current_step][0]  # Timestamp
-                    else:
-                        timestamp = self.data[self.current_step]['timestamp']
-
-                    self.trades.append({
-                        'type': 'short',
-                        'entry': self.entry_price,
-                        'exit': self.current_price,
-                        'pnl_percent': pnl_percent,
-                        'pnl_dollar': pnl_dollar,
-                        'fee': fee,
-                        'net_pnl': pnl_dollar - fee,
-                        'duration': trade_duration,
-                        'timestamp': timestamp,
-                        'reason': 'action_change'
-                    })
-
-                    # Update win/loss count
-                    if pnl_dollar > 0:
-                        self.win_count += 1
-                    else:
-                        self.loss_count += 1
-
-            # Now open long position
-            self.position = 'long'
-            self.entry_price = self.current_price
-            self.entry_index = self.current_step
-
-            # Calculate position size with risk management
-            self.position_size = self.calculate_position_size()
-
-            # Apply trading fee (1 USD per 1.9k position)
-            fee = (self.position_size / 1900) * 1
-            self.balance -= fee
-            self.total_fees += fee
-
-            # Set stop loss and take profit
-            sl_percent = 0.02  # 2% stop loss
-            tp_percent = 0.04  # 4% take profit
-
-            self.stop_loss = self.entry_price * (1 - sl_percent)
-            self.take_profit = self.entry_price * (1 + tp_percent)
-
-        # Open short position
-        elif action == 2 and self.position != 'short':
-            if self.position == 'long':
-                # Close long position first
-                if self.position_size > 0:
-                    # Calculate PnL
-                    pnl_percent = (self.current_price - self.entry_price) / self.entry_price
-                    pnl_dollar = pnl_percent * self.position_size * self.leverage
-
-                    # Update balance and record trade
-                    self.balance += pnl_dollar
-                    self.total_pnl += pnl_dollar
-
-                    # Apply trading fee (1 USD per 1.9k position)
-                    fee = (self.position_size / 1900) * 1
-                    self.balance -= fee
-                    self.total_fees += fee
-
-                    # Record trade
-                    trade_duration = self.current_step - self.entry_index
-                    if self.data_format_is_list:
-                        timestamp = self.data[self.current_step][0]  # Timestamp
-                    else:
-                        timestamp = self.data[self.current_step]['timestamp']
-
-                    self.trades.append({
-                        'type': 'long',
-                        'entry': self.entry_price,
-                        'exit': self.current_price,
-                        'pnl_percent': pnl_percent,
-                        'pnl_dollar': pnl_dollar,
-                        'fee': fee,
-                        'net_pnl': pnl_dollar - fee,
-                        'duration': trade_duration,
-                        'timestamp': timestamp,
-                        'reason': 'action_change'
-                    })
-
-                    # Update win/loss count
-                    if pnl_dollar > 0:
-                        self.win_count += 1
-                    else:
-                        self.loss_count += 1
-
-            # Now open short position
-            self.position = 'short'
-            self.entry_price = self.current_price
-            self.entry_index = self.current_step
-
-            # Calculate position size with risk management
-            self.position_size = self.calculate_position_size()
-
-            # Apply trading fee (1 USD per 1.9k position)
-            fee = (self.position_size / 1900) * 1
-            self.balance -= fee
-            self.total_fees += fee
-
-            # Set stop loss and take profit
-            sl_percent = 0.02  # 2% stop loss
-            tp_percent = 0.04  # 4% take profit
-
-            self.stop_loss = self.entry_price * (1 + sl_percent)
-            self.take_profit = self.entry_price * (1 - tp_percent)
-
-        # Close position
-        elif action == 3 and self.position != 'flat':
-            if self.position == 'long':
-                # Calculate PnL
-                pnl_percent = (self.current_price - self.entry_price) / self.entry_price
-                pnl_dollar = pnl_percent * self.position_size * self.leverage
-
-                # Update balance and record trade
-                self.balance += pnl_dollar
-                self.total_pnl += pnl_dollar
-
-                # Apply trading fee (1 USD per 1.9k position)
-                fee = (self.position_size / 1900) * 1
-                self.balance -= fee
-                self.total_fees += fee
-
-                # Record trade
-                trade_duration = self.current_step - self.entry_index
-                if self.data_format_is_list:
-                    timestamp = self.data[self.current_step][0]  # Timestamp
-                else:
-                    timestamp = self.data[self.current_step]['timestamp']
-
-                self.trades.append({
-                    'type': 'long',
-                    'entry': self.entry_price,
-                    'exit': self.current_price,
-                    'pnl_percent': pnl_percent,
-                    'pnl_dollar': pnl_dollar,
-                    'fee': fee,
-                    'net_pnl': pnl_dollar - fee,
-                    'duration': trade_duration,
-                    'timestamp': timestamp,
-                    'reason': 'close_action'
-                })
-
-                # Update win/loss count
-                if pnl_dollar > 0:
-                    self.win_count += 1
-                else:
-                    self.loss_count += 1
-
-            elif self.position == 'short':
-                # Calculate PnL
-                pnl_percent = (self.entry_price - self.current_price) / self.entry_price
-                pnl_dollar = pnl_percent * self.position_size * self.leverage
-
-                # Update balance and record trade
-                self.balance += pnl_dollar
-                self.total_pnl += pnl_dollar
-
-                # Apply trading fee (1 USD per 1.9k position)
-                fee = (self.position_size / 1900) * 1
-                self.balance -= fee
-                self.total_fees += fee
-
-                # Record trade
-                trade_duration = self.current_step - self.entry_index
-                if self.data_format_is_list:
-                    timestamp = self.data[self.current_step][0]  # Timestamp
-                else:
-                    timestamp = self.data[self.current_step]['timestamp']
-
-                self.trades.append({
-                    'type': 'short',
-                    'entry': self.entry_price,
-                    'exit': self.current_price,
-                    'pnl_percent': pnl_percent,
-                    'pnl_dollar': pnl_dollar,
-                    'fee': fee,
-                    'net_pnl': pnl_dollar - fee,
-                    'duration': trade_duration,
-                    'timestamp': timestamp,
-                    'reason': 'close_action'
-                })
-
-                # Update win/loss count
-                if pnl_dollar > 0:
-                    self.win_count += 1
-                else:
-                    self.loss_count += 1
-
-            # Reset position
-            self.position = 'flat'
-            self.position_size = 0
-            self.entry_price = 0
-            self.entry_index = 0
-            self.stop_loss = 0
-            self.take_profit = 0
-
         # Record trade signal for visualization
         if action > 0:  # If not HOLD
             signal_type = None
@@ -901,13 +484,8 @@ class TradingEnvironment:
                 signal_type = 'close_short'
 
             if signal_type:
-                if self.data_format_is_list:
-                    timestamp = self.data[self.current_step][0]  # Timestamp
-                else:
-                    timestamp = self.data[self.current_step]['timestamp']
-
                 self.trade_signals.append({
-                    'timestamp': timestamp,
+                    'timestamp': self.data[self.current_step]['timestamp'],
                     'price': self.current_price,
                     'type': signal_type,
                     'balance': self.balance,
@@ -924,36 +502,23 @@ class TradingEnvironment:
 
         # Get new state
         next_state = self.get_state()
 
-        # Update peak balance and drawdown
-        if self.balance > self.peak_balance:
-            self.peak_balance = self.balance
-
-        current_drawdown = (self.peak_balance - self.balance) / self.peak_balance if self.peak_balance > 0 else 0
-        self.max_drawdown = max(self.max_drawdown, current_drawdown)
-
         # Create info dictionary
         info = {
             'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
             'price': self.current_price,
             'balance': self.balance,
             'position': self.position,
-            'pnl': self.total_pnl,
-            'fees': self.total_fees,
-            'net_pnl': self.total_pnl - self.total_fees
+            'pnl': self.total_pnl
         }
 
         return next_state, reward, done, info
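# --- Editor's sketch (not part of the patch): the fee arithmetic this revision
# standardizes on. The removed inline model charged $1 per $1,900 of position
# (~0.0526%); the new calculate_fees() below uses a flat 0.1% rate instead.
position_size = 1900.0

old_fee = (position_size / 1900) * 1    # -> $1.00 under the removed model
new_fee = position_size * 0.001         # -> $1.90 under calculate_fees()

# A 1% favorable move on a $1,900 position, net of the new fee:
pnl_dollar = 0.01 * position_size - new_fee   # 19.00 - 1.90 = $17.10
print(old_fee, new_fee, pnl_dollar)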
${pnl_dollar:.2f}") + # Record signal for visualization + self.trade_signals.append({ + 'timestamp': self.data[self.current_step]['timestamp'], + 'price': self.stop_loss, + 'type': 'stop_loss_long', + 'balance': self.balance, + 'pnl': self.total_pnl + }) + # Reset position self.position = 'flat' - self.position_size = 0 self.entry_price = 0 self.entry_index = 0 + self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 @@ -1007,37 +585,47 @@ class TradingEnvironment: # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar + self.episode_pnl += pnl_dollar + + # Update max drawdown + if self.balance > self.peak_balance: + self.peak_balance = self.balance # Record trade - trade_duration = self.current_step - self.entry_index self.trades.append({ 'type': 'long', 'entry': self.entry_price, 'exit': self.take_profit, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, - 'duration': trade_duration, - 'timestamp': self.data[self.current_step]['timestamp'], + 'duration': self.current_step - self.entry_index, + 'market_direction': self.get_market_direction(), 'reason': 'take_profit' }) + # Update win/loss count self.win_count += 1 logger.info(f"TAKE PROFIT hit for long at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") + # Record signal for visualization + self.trade_signals.append({ + 'timestamp': self.data[self.current_step]['timestamp'], + 'price': self.take_profit, + 'type': 'take_profit_long', + 'balance': self.balance, + 'pnl': self.total_pnl + }) + # Reset position self.position = 'flat' - self.position_size = 0 self.entry_price = 0 self.entry_index = 0 + self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 elif self.position == 'short': - # Implement trailing stop loss if in profit - if self.current_price < self.entry_price * 0.99: - self.stop_loss = min(self.stop_loss, self.current_price * 1.005) # Trail at 0.5% above current price - # Check stop loss if self.current_price >= self.stop_loss: # Stop loss hit @@ -1050,32 +638,45 @@ class TradingEnvironment: # Update balance self.balance += pnl_dollar self.total_pnl += pnl_dollar + self.episode_pnl += pnl_dollar + + # Update max drawdown + if self.balance > self.peak_balance: + self.peak_balance = self.balance + drawdown = (self.peak_balance - self.balance) / self.peak_balance + self.max_drawdown = max(self.max_drawdown, drawdown) # Record trade - trade_duration = self.current_step - self.entry_index self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.stop_loss, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, - 'duration': trade_duration, - 'timestamp': self.data[self.current_step]['timestamp'], + 'duration': self.current_step - self.entry_index, + 'market_direction': self.get_market_direction(), 'reason': 'stop_loss' }) - if pnl_dollar > 0: - self.win_count += 1 - else: - self.loss_count += 1 + # Update win/loss count + self.loss_count += 1 logger.info(f"STOP LOSS hit for short at {self.stop_loss} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") + # Record signal for visualization + self.trade_signals.append({ + 'timestamp': self.data[self.current_step]['timestamp'], + 'price': self.stop_loss, + 'type': 'stop_loss_short', + 'balance': self.balance, + 'pnl': self.total_pnl + }) + # Reset position self.position = 'flat' - self.position_size = 0 self.entry_price = 0 self.entry_index = 0 + self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 @@ -1091,29 +692,43 @@ class TradingEnvironment: # Update balance self.balance += pnl_dollar self.total_pnl += 
pnl_dollar + self.episode_pnl += pnl_dollar + + # Update max drawdown + if self.balance > self.peak_balance: + self.peak_balance = self.balance # Record trade - trade_duration = self.current_step - self.entry_index self.trades.append({ 'type': 'short', 'entry': self.entry_price, 'exit': self.take_profit, 'pnl_percent': pnl_percent, 'pnl_dollar': pnl_dollar, - 'duration': trade_duration, - 'timestamp': self.data[self.current_step]['timestamp'], + 'duration': self.current_step - self.entry_index, + 'market_direction': self.get_market_direction(), 'reason': 'take_profit' }) + # Update win/loss count self.win_count += 1 logger.info(f"TAKE PROFIT hit for short at {self.take_profit} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}") + # Record signal for visualization + self.trade_signals.append({ + 'timestamp': self.data[self.current_step]['timestamp'], + 'price': self.take_profit, + 'type': 'take_profit_short', + 'balance': self.balance, + 'pnl': self.total_pnl + }) + # Reset position self.position = 'flat' - self.position_size = 0 self.entry_price = 0 self.entry_index = 0 + self.position_size = 0 self.stop_loss = 0 self.take_profit = 0 @@ -1225,7 +840,7 @@ class TradingEnvironment: # NEW FEATURES START HERE - # 1. Price momentum features (rate of change) + # 1. Price momentum features (rate of change over different periods) if len(self.features['price']) >= 20: roc_5 = (latest_price / self.features['price'][-5] - 1.0) if self.features['price'][-5] != 0 else 0 roc_10 = (latest_price / self.features['price'][-10] - 1.0) if self.features['price'][-10] != 0 else 0 @@ -1241,21 +856,11 @@ class TradingEnvironment: returns = np.diff(self.features['price'][-21:]) / self.features['price'][-21:-1] # Calculate volatility (standard deviation of returns) volatility = np.std(returns) - - # Calculate normalized high-low range based on data format - if self.data_format_is_list: - # List format: high at index 2, low at index 3, close at index 4 - high_low_range = np.mean([ - (self.data[i][2] - self.data[i][3]) / self.data[i][4] - for i in range(max(0, self.current_step-5), min(len(self.data), self.current_step+1)) - ]) if len(self.data) > 0 else 0 - else: - # Dictionary format - high_low_range = np.mean([ - (self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close'] - for i in range(max(0, self.current_step-5), min(len(self.data), self.current_step+1)) - ]) if len(self.data) > 0 else 0 - + # Calculate normalized high-low range + high_low_range = np.mean([ + (self.data[i]['high'] - self.data[i]['low']) / self.data[i]['close'] + for i in range(max(0, len(self.data)-5), len(self.data)) + ]) if len(self.data) > 0 else 0 # ATR normalized by price atr_norm = self.features['atr'][-1] / latest_price if len(self.features['atr']) > 0 else 0 @@ -1358,84 +963,588 @@ class TradingEnvironment: def calculate_reward(self, action): - """ - Calculate reward for taking the given action. 
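# --- Editor's sketch (not part of the patch): the simplified volatility block
# above in isolation. The new code only supports dict-style candles (the
# list-format branch was removed); the toy values are illustrative.
import numpy as np

prices = np.array([100.0, 101.0, 99.5, 102.0, 103.0] * 5)[-21:]
returns = np.diff(prices) / prices[:-1]
volatility = np.std(returns)                  # stdev of simple returns

candles = [{'high': 103.0, 'low': 99.0, 'close': 101.0}] * 5
high_low_range = np.mean([(c['high'] - c['low']) / c['close'] for c in candles])
print(volatility, high_low_range)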
@@ -1358,84 +963,588 @@ class TradingEnvironment:
 
     def calculate_reward(self, action):
-        """
-        Calculate reward for taking the given action.
-
-        Args:
-            action: The action taken (0=hold, 1=buy, 2=sell, 3=close)
-
-        Returns:
-            The calculated reward
-        """
+        """Calculate reward for the given action with improved penalties for losing trades"""
         reward = 0
 
-        # Get current price
-        if self.data_format_is_list:
-            current_price = self.data[self.current_step][4]  # Close price
-        else:
-            current_price = self.data[self.current_step]['close']
-
-        # Base reward component based on price movement
-        price_change_pct = 0
-        if self.current_step > 0:
-            if self.data_format_is_list:
-                prev_price = self.data[self.current_step-1][4]  # Previous close
-            else:
-                prev_price = self.data[self.current_step-1]['close']
-
-            price_change_pct = (current_price - prev_price) / prev_price
-
-        # Check if we have CNN patterns available
-        pattern_confidence = 0
-        if hasattr(self, 'cnn_patterns'):
-            if action == 1 and 'long_confidence' in self.cnn_patterns:  # Buy action
-                pattern_confidence = self.cnn_patterns['long_confidence']
-            elif action == 2 and 'short_confidence' in self.cnn_patterns:  # Sell action
-                pattern_confidence = self.cnn_patterns['short_confidence']
-
-        # Action-specific rewards
+        # Base reward for actions
         if action == 0:  # HOLD
-            # Small positive reward for holding in the right direction of market movement
-            if self.position == 'long' and price_change_pct > 0:
-                reward += 0.1 + price_change_pct * 10
-            elif self.position == 'short' and price_change_pct < 0:
-                reward += 0.1 + abs(price_change_pct) * 10
-            else:
-                # Small negative reward for holding in the wrong direction
-                reward -= 0.1
-        elif action == 1 or action == 2:  # BUY or SELL
-            # Apply trading fee as negative reward (1 USD per 1.9k position size)
-            position_size = self.calculate_position_size()
-            fee = (position_size / 1900) * 1  # Trading fee in USD
-
-            # Penalty for fee
-            fee_penalty = fee / 10  # Scale down to make it a reasonable penalty
-            reward -= fee_penalty
-
-            # Logging
-            if hasattr(self, 'total_fees'):
-                self.total_fees += fee
-            else:
-                self.total_fees = fee
-        elif action == 3:  # CLOSE
-            # Apply trading fee as negative reward (1 USD per 1.9k position size)
-            fee = (self.position_size / 1900) * 1  # Trading fee in USD
-
-            # Penalty for fee
-            fee_penalty = fee / 10  # Scale down to make it a reasonable penalty
-            reward -= fee_penalty
-
-            # Logging
-            if hasattr(self, 'total_fees'):
-                self.total_fees += fee
-            else:
-                self.total_fees = fee
-
-        # Add CNN pattern confidence to reward
-        reward += pattern_confidence * 10
+            reward = -0.01  # Small penalty for doing nothing
 
+        elif action == 1:  # BUY/LONG
+            if self.position == 'flat':
+                # Opening a long position
+                self.position = 'long'
+                self.entry_price = self.current_price
+                self.entry_index = len(self.features['price']) - 1  # record entry for duration tracking
+                self.position_size = self.calculate_position_size()
+                self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100)
+                self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100)
+
+                # Check if this is an optimal buy point (bottom)
+                current_idx = len(self.features['price']) - 1
+                if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms:
+                    reward += 2.0  # Bonus for buying at a bottom
+                else:
+                    # Check if we're buying in a downtrend (bad)
+                    if self.is_downtrend():
+                        reward -= 0.5  # Penalty for buying in downtrend
+                    else:
+                        reward += 0.1  # Small reward for opening a position
+
+                logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
+
+            elif self.position == 'short':
+                # Close short and open long
+                pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
+                pnl_dollar = pnl_percent / 100 * self.position_size
+
+                # Apply fees
+                pnl_dollar -= self.calculate_fees(self.position_size)
+
+                # Update balance
+                self.balance += pnl_dollar
+                self.total_pnl += pnl_dollar
+
+                # Record trade
+                trade_duration = len(self.features['price']) - self.entry_index
+                self.trades.append({
+                    'type': 'short',
+                    'entry': self.entry_price,
+                    'exit': self.current_price,
+                    'pnl_percent': pnl_percent,
+                    'pnl_dollar': pnl_dollar,
+                    'duration': trade_duration,
+                    'market_direction': self.get_market_direction()
+                })
+
+                # Reward based on PnL with stronger penalties for losses
+                if pnl_dollar > 0:
+                    reward += 1.0 + pnl_dollar / 10  # Positive reward for profit
+                    self.win_count += 1
+                else:
+                    # Stronger penalty for losses, scaled by the size of the loss
+                    loss_penalty = 1.0 + abs(pnl_dollar) / 5
+                    reward -= loss_penalty
+                    self.loss_count += 1
+
+                    # Extra penalty for closing a losing trade too quickly
+                    if trade_duration < 5:
+                        reward -= 0.5  # Penalty for very short losing trades
+
+                logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
+
+                # Now open long
+                self.position = 'long'
+                self.entry_price = self.current_price
+                self.entry_index = len(self.features['price']) - 1
+                self.position_size = self.calculate_position_size()
+                self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100)
+                self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100)
+
+                # Check if this is an optimal buy point
+                if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
+                    reward += 2.0  # Bonus for buying at a bottom
+
+                logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
+
+        elif action == 2:  # SELL/SHORT
+            if self.position == 'flat':
+                # Opening a short position
+                self.position = 'short'
+                self.entry_price = self.current_price
+                self.entry_index = len(self.features['price']) - 1  # record entry for duration tracking
+                self.position_size = self.calculate_position_size()
+                self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100)
+                self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100)
+
+                # Check if this is an optimal sell point (top)
+                current_idx = len(self.features['price']) - 1
+                if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
+                    reward += 2.0  # Bonus for selling at a top
+                else:
+                    reward += 0.1  # Small reward for opening a position
+
+                logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
+
+            elif self.position == 'long':
+                # Close long and open short
+                pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
+                pnl_dollar = pnl_percent / 100 * self.position_size
+
+                # Apply fees
+                pnl_dollar -= self.calculate_fees(self.position_size)
+
+                # Update balance
+                self.balance += pnl_dollar
+                self.total_pnl += pnl_dollar
+
+                # Record trade
+                self.trades.append({
+                    'type': 'long',
+                    'entry': self.entry_price,
+                    'exit': self.current_price,
+                    'pnl_percent': pnl_percent,
+                    'pnl_dollar': pnl_dollar
+                })
+
+                # Reward based on PnL
+                if pnl_dollar > 0:
+                    reward += 1.0 + pnl_dollar / 10  # Positive reward for profit
+                    self.win_count += 1
+                else:
+                    reward -= 1.0  # Negative reward for loss
+                    self.loss_count += 1
+
+                logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
+
+                # Now open short
+                self.position = 'short'
+                self.entry_price = self.current_price
+                self.entry_index = len(self.features['price']) - 1
+                self.position_size = self.calculate_position_size()
+                self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100)
+                self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100)
+
+                # Check if this is an optimal sell point
+                current_idx = len(self.features['price']) - 1
+                if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
+                    reward += 2.0  # Bonus for selling at a top
+
+                logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
+
+        elif action == 3:  # CLOSE
+            if self.position == 'long':
+                # Close long position
+                pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
+                pnl_dollar = pnl_percent / 100 * self.position_size
+
+                # Apply fees
+                pnl_dollar -= self.calculate_fees(self.position_size)
+
+                # Update balance
+                self.balance += pnl_dollar
+                self.total_pnl += pnl_dollar
+                self.episode_pnl += pnl_dollar
+
+                # Update max drawdown
+                if self.balance > self.peak_balance:
+                    self.peak_balance = self.balance
+                drawdown = (self.peak_balance - self.balance) / self.peak_balance
+                self.max_drawdown = max(self.max_drawdown, drawdown)
+
+                # Record trade
+                self.trades.append({
+                    'type': 'long',
+                    'entry': self.entry_price,
+                    'exit': self.current_price,
+                    'pnl_percent': pnl_percent,
+                    'pnl_dollar': pnl_dollar
+                })
+
+                # Reward based on PnL
+                if pnl_dollar > 0:
+                    reward += 1.0 + pnl_dollar / 10  # Positive reward for profit
+                    self.win_count += 1
+                else:
+                    reward -= 1.0  # Negative reward for loss
+                    self.loss_count += 1
+
+                logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
+
+                # Reset position
+                self.position = 'flat'
+                self.entry_price = 0
+                self.position_size = 0
+                self.stop_loss = 0
+                self.take_profit = 0
+
+            elif self.position == 'short':
+                # Close short position
+                pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
+                pnl_dollar = pnl_percent / 100 * self.position_size
+
+                # Apply fees
+                pnl_dollar -= self.calculate_fees(self.position_size)
+
+                # Update balance
+                self.balance += pnl_dollar
+                self.total_pnl += pnl_dollar
+                self.episode_pnl += pnl_dollar
+
+                # Update max drawdown
+                if self.balance > self.peak_balance:
+                    self.peak_balance = self.balance
+                drawdown = (self.peak_balance - self.balance) / self.peak_balance
+                self.max_drawdown = max(self.max_drawdown, drawdown)
+
+                # Record trade
+                self.trades.append({
+                    'type': 'short',
+                    'entry': self.entry_price,
+                    'exit': self.current_price,
+                    'pnl_percent': pnl_percent,
+                    'pnl_dollar': pnl_dollar
+                })
+
+                # Reward based on PnL
+                if pnl_dollar > 0:
+                    reward += 1.0 + pnl_dollar / 10  # Positive reward for profit
+                    self.win_count += 1
+                else:
+                    reward -= 1.0  # Negative reward for loss
+                    self.loss_count += 1
+
+                logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
+
+                # Reset position
+                self.position = 'flat'
+                self.entry_price = 0
+                self.position_size = 0
+                self.stop_loss = 0
+                self.take_profit = 0
+
+        # Add prediction accuracy component to reward
+        if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
+            # Compare the first prediction with actual price
+            if len(self.data) > 1:
+                actual_price = self.data[-1]['close']
+                predicted_price = self.predicted_prices[0]
+                prediction_error = abs(predicted_price - actual_price) / actual_price
+
+                # Reward accurate predictions, penalize bad ones
+                if prediction_error < 0.005:  # Less than 0.5% error
+                    reward += 0.5
+                elif prediction_error > 0.02:  # More than 2% error
+                    reward -= 0.5
 
         return reward
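# --- Editor's sketch (not part of the patch): tracing the loss-penalty branch
# of the new calculate_reward() by hand. Reversing out of a short after a 1%
# adverse move on a $500 position, 4 candles after entry; values illustrative.
pnl_percent = -1.0                             # (entry - current) / entry * 100
position_size = 500.0
pnl_dollar = pnl_percent / 100 * position_size  # -$5.00 before fees
pnl_dollar -= position_size * 0.001             # minus $0.50 fee -> -$5.50

reward = 0.0
reward -= 1.0 + abs(pnl_dollar) / 5             # scaled loss penalty -> -2.10
reward -= 0.5                                   # short-duration penalty (< 5 candles)
print(round(reward, 2))                          # -2.6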
+    def is_downtrend(self):
+        """Check if the market is in a downtrend"""
+        if len(self.features['price']) < 20:
+            return False
+
+        # Use EMA to determine trend
+        short_ema = self.features['ema_9'][-1]
+        long_ema = self.features['ema_21'][-1]
+
+        # Downtrend if short EMA is below long EMA
+        return short_ema < long_ema
+
+    def is_uptrend(self):
+        """Check if the market is in an uptrend"""
+        if len(self.features['price']) < 20:
+            return False
+
+        # Use EMA to determine trend
+        short_ema = self.features['ema_9'][-1]
+        long_ema = self.features['ema_21'][-1]
+
+        # Uptrend if short EMA is above long EMA
+        return short_ema > long_ema
+
+    def get_market_direction(self):
+        """Get the current market direction"""
+        if self.is_uptrend():
+            return "uptrend"
+        elif self.is_downtrend():
+            return "downtrend"
+        else:
+            return "sideways"
+
+    def analyze_trades(self):
+        """Analyze completed trades to identify patterns"""
+        if not self.trades:
+            return {}
+
+        analysis = {
+            'total_trades': len(self.trades),
+            'winning_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) > 0),
+            'losing_trades': sum(1 for t in self.trades if t.get('pnl_dollar', 0) <= 0),
+            'avg_win': 0,
+            'avg_loss': 0,
+            'avg_duration': 0,
+            'uptrend_win_rate': 0,
+            'downtrend_win_rate': 0,
+            'sideways_win_rate': 0
+        }
+
+        # Calculate averages
+        wins = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) > 0]
+        losses = [t.get('pnl_dollar', 0) for t in self.trades if t.get('pnl_dollar', 0) <= 0]
+        durations = [t.get('duration', 0) for t in self.trades]
+
+        analysis['avg_win'] = sum(wins) / len(wins) if wins else 0
+        analysis['avg_loss'] = sum(losses) / len(losses) if losses else 0
+        analysis['avg_duration'] = sum(durations) / len(durations) if durations else 0
+
+        # Calculate win rates by market direction
+        for direction in ['uptrend', 'downtrend', 'sideways']:
+            direction_trades = [t for t in self.trades if t.get('market_direction') == direction]
+            if direction_trades:
+                wins_in_direction = sum(1 for t in direction_trades if t.get('pnl_dollar', 0) > 0)
+                analysis[f'{direction}_win_rate'] = wins_in_direction / len(direction_trades) * 100
+
+        return analysis
+
+    def initialize_price_predictor(self, device="cpu"):
+        """Initialize the price prediction model"""
+        self.price_predictor = PricePredictionModel(input_size=30, hidden_size=128, output_size=5)
+        self.price_predictor.to(device)
+        self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-3)
+        self.predicted_prices = np.array([])
+
+    def train_price_predictor(self):
+        """Train the price prediction model on recent data"""
+        if len(self.features['price']) < 35:
+            return 0.0
+
+        # Get price history
+        price_history = self.features['price']
+
+        # Train the model
+        loss = self.price_predictor.train_on_new_data(
+            price_history,
+            self.price_predictor_optimizer,
+            epochs=5
+        )
+
+        return loss
+
+    def update_price_predictions(self):
+        """Update price predictions"""
+        if len(self.features['price']) < 30 or not hasattr(self, 'price_predictor') or self.price_predictor is None:
+            self.predicted_prices = np.array([])
+            return
+
+        # Get price history
+        price_history = self.features['price']
+
+        try:
+            # Get predictions
+            self.predicted_prices = self.price_predictor.predict_next_candles(price_history, num_candles=5)
+        except Exception as e:
+            logger.warning(f"Error updating predictions: {e}")
+            self.predicted_prices = np.array([])
+
+    def identify_optimal_trades(self):
+        """Identify optimal entry and exit points based on local extrema"""
+        if len(self.features['price']) < 20:
+            return
+
+        # Find local bottoms and tops
+        bottoms, tops = find_local_extrema(self.features['price'], window=5)
+
+        # Store optimal trade points
+        self.optimal_bottoms = bottoms  # Buy points
+        self.optimal_tops = tops  # Sell points
+
+        # Create optimal trade signals
+        self.optimal_signals = np.zeros(len(self.features['price']))
+        for i in bottoms:
+            if 0 <= i < len(self.optimal_signals):  # Ensure index is valid
+                self.optimal_signals[i] = 1  # Buy signal
+        for i in tops:
+            if 0 <= i < len(self.optimal_signals):  # Ensure index is valid
+                self.optimal_signals[i] = -1  # Sell signal
+
+        logger.info(f"Identified {len(bottoms)} optimal buy points and {len(tops)} optimal sell points")
+
+    def calculate_position_size(self):
+        """Calculate position size based on current balance and risk parameters"""
+        # Use a fixed percentage of balance for each trade
+        risk_percent = 5.0  # Risk 5% of balance per trade
+
+        # Calculate position size with leverage
+        position_size = self.balance * (risk_percent / 100) * MAX_LEVERAGE
+
+        # Apply a safety factor to avoid liquidation
+        safety_factor = 0.8
+        position_size *= safety_factor
+
+        # Ensure minimum position size
+        min_position = 10.0  # Minimum position size in USD
+        position_size = max(position_size, min(min_position, self.balance * 0.5))
+
+        # Ensure position size doesn't exceed balance * leverage
+        max_position = self.balance * MAX_LEVERAGE
+        position_size = min(position_size, max_position)
+
+        return position_size
+
+    def calculate_fees(self, position_size):
+        """Calculate trading fees for a given position size"""
+        # Typical fee rate for crypto exchanges (0.1%)
+        fee_rate = 0.001
+
+        # Calculate fee
+        fee = position_size * fee_rate
+
+        return fee
+
+    def is_uncertain_market(self):
+        """Check if the market is in an uncertain/sideways state"""
+        if len(self.features['price']) < 20:
+            return True
+
+        # Check if price is within a narrow range
+        recent_prices = self.features['price'][-20:]
+        price_range = (max(recent_prices) - min(recent_prices)) / np.mean(recent_prices)
+
+        # Check if EMAs are close to each other
+        if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0:
+            short_ema = self.features['ema_9'][-1]
+            long_ema = self.features['ema_21'][-1]
+            ema_diff = abs(short_ema - long_ema) / long_ema
+
+            # Return True if price range is small and EMAs are close
+            return price_range < 0.02 and ema_diff < 0.005
+
+        return price_range < 0.015  # Very narrow range
+
+    def is_near_support(self):
+        """Check if current price is near a support level"""
+        if not hasattr(self, 'features') or len(self.features['price']) < 30:
+            return False
+
+        # Find recent lows
+        prices = self.features['price'][-30:]
+        lows = []
+
+        for i in range(1, len(prices)-1):
+            if prices[i] < prices[i-1] and prices[i] < prices[i+1]:
+                lows.append(prices[i])
+
+        if not lows:
+            return False
+
+        # Check if current price is near any of these lows
+        current_price = self.current_price
+        for low in lows:
+            if abs(current_price - low) / low < 0.01:  # Within 1% of a recent low
+                return True
+
+        return False
+
+    def is_near_resistance(self):
+        """Check if current price is near a resistance level"""
+        if not hasattr(self, 'features') or len(self.features['price']) < 30:
+            return False
+
+        # Find recent highs
+        prices = self.features['price'][-30:]
+        highs = []
+
+        for i in range(1, len(prices)-1):
+            if prices[i] > prices[i-1] and prices[i] > prices[i+1]:
+                highs.append(prices[i])
+
+        if not highs:
+            return False
+
+        # Check if current price is near any of these highs
+        current_price = self.current_price
+        for high in highs:
+            if abs(current_price - high) / high < 0.01:  # Within 1% of a recent high
+                return True
+
+        return False
+
+    def is_market_turning(self):
+        """Check if the market is potentially changing direction"""
+        if len(self.features['price']) < 20:
+            return False
+
+        # Check for divergence between price and momentum indicators
+        if len(self.features['rsi']) > 5:
+            # Price making higher highs but RSI making lower highs (bearish divergence)
+            price_trend = self.features['price'][-1] > self.features['price'][-5]
+            rsi_trend = self.features['rsi'][-1] < self.features['rsi'][-5]
+
+            if price_trend != rsi_trend:
+                return True
+
+        # Check for EMA crossover
+        if len(self.features['ema_9']) > 1 and len(self.features['ema_21']) > 1:
+            short_ema_prev = self.features['ema_9'][-2]
+            long_ema_prev = self.features['ema_21'][-2]
+            short_ema_curr = self.features['ema_9'][-1]
+            long_ema_curr = self.features['ema_21'][-1]
+
+            # Check if EMAs just crossed
+            if (short_ema_prev < long_ema_prev and short_ema_curr > long_ema_curr) or \
+               (short_ema_prev > long_ema_prev and short_ema_curr < long_ema_curr):
+                return True
+
+        return False
+
+    def is_market_against_position(self, position_type):
+        """Check if market conditions have turned against the current position"""
+        if position_type == 'long':
+            # For long positions, check if market has turned bearish
+            return self.is_downtrend() and not self.is_near_support()
+        elif position_type == 'short':
+            # For short positions, check if market has turned bullish
+            return self.is_uptrend() and not self.is_near_resistance()
+
+        return False
+
+    def is_near_optimal_exit(self, position_type):
+        """Check if current price is near an optimal exit point for the position"""
+        current_idx = len(self.features['price']) - 1
+
+        if position_type == 'long' and hasattr(self, 'optimal_tops'):
+            # For long positions, optimal exit is near tops
+            for top_idx in self.optimal_tops:
+                if abs(current_idx - top_idx) < 3:  # Within 3 candles of a top
+                    return True
+        elif position_type == 'short' and hasattr(self, 'optimal_bottoms'):
+            # For short positions, optimal exit is near bottoms
+            for bottom_idx in self.optimal_bottoms:
+                if abs(current_idx - bottom_idx) < 3:  # Within 3 candles of a bottom
+                    return True
+
+        return False
+
+    def calculate_future_profit_potential(self, position_type, lookahead=20):
+        """
+        Calculate potential profit if position is held for a certain period.
+        This is used for retrospective backtesting rewards.
+
+        Args:
+            position_type: 'long' or 'short'
+            lookahead: Number of candles to look ahead
+
+        Returns:
+            Potential profit percentage
+        """
+        if len(self.data) <= 1 or self.current_step >= len(self.data):
+            return 0
+
+        # Get current price
+        current_price = self.current_price
+
+        # Get future prices (if available in historical data)
+        future_prices = []
+        current_idx = self.current_step
+
+        # Safely get future prices
+        for i in range(1, min(lookahead + 1, len(self.data) - current_idx)):
+            if current_idx + i < len(self.data):
+                future_prices.append(self.data[current_idx + i]['close'])
+
+        if not future_prices:
+            return 0
+
+        # Calculate potential profit
+        if position_type == 'long':
+            # For long positions, find the maximum price in the future
+            max_future_price = max(future_prices)
+            potential_profit = (max_future_price - current_price) / current_price * 100
+        else:  # short
+            # For short positions, find the minimum price in the future
+            min_future_price = min(future_prices)
+            potential_profit = (current_price - min_future_price) / current_price * 100
+
+        return potential_profit
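# --- Editor's sketch (not part of the patch): what
# calculate_future_profit_potential() measures, on a toy future-price path.
# For a long it reports the best exit within the lookahead window; values
# below are illustrative only.
current_price = 100.0
future_prices = [99.0, 101.0, 104.0, 102.0]   # next 4 candles

long_potential = (max(future_prices) - current_price) / current_price * 100   # 4.0%
short_potential = (current_price - min(future_prices)) / current_price * 100  # 1.0%
print(long_potential, short_potential)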
parameters""" if not self.demo: @@ -1491,398 +1600,6 @@ class TradingEnvironment: logger.error(f"Trade execution failed: {e}") return None - def trades_in_last_n_candles(self, n=20): - """Count the number of trades in the last n candles""" - if len(self.trades) == 0: - return 0 - - if self.data_format_is_list: - # List format: timestamp at index 0 - current_time = self.data[self.current_step][0] - n_candles_ago = self.data[max(0, self.current_step - n)][0] - else: - # Dictionary format - current_time = self.data[self.current_step]['timestamp'] - n_candles_ago = self.data[max(0, self.current_step - n)]['timestamp'] - - count = 0 - for trade in reversed(self.trades): - if 'timestamp' in trade and trade['timestamp'] >= n_candles_ago and trade['timestamp'] <= current_time: - count += 1 - else: - # Older trades, we can stop counting - break - - return count - - def count_consecutive_losses(self): - """Count the number of consecutive losing trades""" - count = 0 - for trade in reversed(self.trades): - if trade.get('pnl_dollar', 0) < 0: - count += 1 - else: - break - return count - - def is_volatile_market(self): - """Determine if the current market is volatile""" - if len(self.features['price']) < 20: - return False - - recent_prices = self.features['price'][-20:] - avg_price = sum(recent_prices) / len(recent_prices) - volatility = sum([abs(p - avg_price) / avg_price for p in recent_prices]) / len(recent_prices) - - return volatility > 0.01 # 1% average deviation is considered volatile - - def is_uptrend(self): - """Determine if the market is in an uptrend""" - if len(self.features['ema_9']) < 2 or len(self.features['ema_21']) < 2: - return False - - # Short-term trend - short_trend = self.features['ema_9'][-1] > self.features['ema_9'][-2] - - # Medium-term trend - medium_trend = self.features['ema_9'][-1] > self.features['ema_21'][-1] - - return short_trend and medium_trend - - def is_downtrend(self): - """Determine if the market is in a downtrend""" - if len(self.features['ema_9']) < 2 or len(self.features['ema_21']) < 2: - return False - - # Short-term trend - short_trend = self.features['ema_9'][-1] < self.features['ema_9'][-2] - - # Medium-term trend - medium_trend = self.features['ema_9'][-1] < self.features['ema_21'][-1] - - return short_trend and medium_trend - - def calculate_position_size(self): - """Calculate position size based on risk management rules""" - # Reduce position size after losses - consecutive_losses = self.count_consecutive_losses() - risk_factor = max(0.3, 1.0 - (consecutive_losses * 0.1)) # Reduce by 10% per loss, min 30% - - # Calculate position size based on available balance and risk - max_risk_amount = self.balance * 0.02 # Risk 2% per trade - position_size = max_risk_amount / (STOP_LOSS_PERCENT / 100 * self.current_price) - - # Apply leverage - position_size = position_size * self.leverage - - # Cap at available balance - position_size = min(position_size, self.balance * self.leverage) - - return position_size * risk_factor - - # ... existing identify_optimal_trades method ... - - def update_cnn_patterns(self, candle_data=None): - """ - Update CNN patterns using multi-timeframe data. 
- - Args: - candle_data: Dictionary containing candle data for different timeframes - """ - if not candle_data: - return - - try: - # Check if we have the necessary timeframes - required_timeframes = ['1m', '1h', '1d'] - if not all(tf in candle_data for tf in required_timeframes): - logging.warning(f"Missing required timeframes for CNN pattern detection") - return - - # Initialize patterns if not already done - if not hasattr(self, 'cnn_patterns'): - self.cnn_patterns = {} - - # Extract features from candle data - features = {} - - # Process each timeframe - for tf in required_timeframes: - candles = candle_data[tf] - if not candles or len(candles) < 30: - continue - - # Convert to numpy arrays for easier processing - closes = np.array([c[4] for c in candles[-100:]]) - highs = np.array([c[2] for c in candles[-100:]]) - lows = np.array([c[3] for c in candles[-100:]]) - - # Simple feature extraction - # 1. Detect trends - ema20 = self._calculate_ema(closes, 20) - ema50 = self._calculate_ema(closes, 50) - - uptrend = ema20[-1] > ema50[-1] and closes[-1] > ema20[-1] - downtrend = ema20[-1] < ema50[-1] and closes[-1] < ema20[-1] - - # 2. Detect potential reversal patterns - # Find local extrema - tops, bottoms = find_local_extrema(closes, window=10) - - # Check if we're near a bottom (potential buy) - near_bottom = False - bottom_confidence = 0 - if bottoms and len(bottoms) > 0: - last_bottom = bottoms[-1] - if len(closes) - last_bottom < 5: # Recent bottom - bottom_dist = abs(closes[-1] - closes[last_bottom]) / closes[last_bottom] - if bottom_dist < 0.01: # Within 1% of the bottom - near_bottom = True - # Higher confidence if volume is increasing - bottom_confidence = 0.8 - bottom_dist * 50 # 0.8 to 0.3 range - - # Check if we're near a top (potential sell) - near_top = False - top_confidence = 0 - if tops and len(tops) > 0: - last_top = tops[-1] - if len(closes) - last_top < 5: # Recent top - top_dist = abs(closes[-1] - closes[last_top]) / closes[last_top] - if top_dist < 0.01: # Within 1% of the top - near_top = True - # Higher confidence if volume is increasing - top_confidence = 0.8 - top_dist * 50 # 0.8 to 0.3 range - - # Store features for this timeframe - features[tf] = { - 'uptrend': uptrend, - 'downtrend': downtrend, - 'near_bottom': near_bottom, - 'bottom_confidence': bottom_confidence, - 'near_top': near_top, - 'top_confidence': top_confidence - } - - # Combine features across timeframes to get overall pattern confidence - long_confidence = 0 - short_confidence = 0 - - # Weight each timeframe (higher weight for longer timeframes) - weights = {'1m': 0.2, '1h': 0.3, '1d': 0.5} - - for tf, tf_features in features.items(): - weight = weights.get(tf, 0.2) - - # Add to long confidence - if tf_features['uptrend'] or tf_features['near_bottom']: - long_confidence += weight * (0.6 if tf_features['uptrend'] else 0) + \ - weight * (tf_features['bottom_confidence'] if tf_features['near_bottom'] else 0) - - # Add to short confidence - if tf_features['downtrend'] or tf_features['near_top']: - short_confidence += weight * (0.6 if tf_features['downtrend'] else 0) + \ - weight * (tf_features['top_confidence'] if tf_features['near_top'] else 0) - - # Normalize confidence scores to [0, 1] - long_confidence = min(1.0, long_confidence) - short_confidence = min(1.0, short_confidence) - - # Update patterns - self.cnn_patterns = { - 'long_confidence': long_confidence, - 'short_confidence': short_confidence, - 'features': features - } - - logging.debug(f"Updated CNN patterns - Long: {long_confidence:.2f}, 
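The pattern aggregation in this removed block is a weighted vote across timeframes: each contributes 0.6 times its weight for a trend signal, plus weight times confidence for a nearby extreme, with totals clipped to [0, 1]. Isolated with the same default weights (all names illustrative):

def blend_confidence(features, weights=None):
    """Combine per-timeframe signals into (long, short) confidence in [0, 1]."""
    weights = weights or {'1m': 0.2, '1h': 0.3, '1d': 0.5}
    long_c = short_c = 0.0
    for tf, f in features.items():
        w = weights.get(tf, 0.2)
        long_c += w * (0.6 if f.get('uptrend') else 0.0)
        long_c += w * f.get('bottom_confidence', 0.0) if f.get('near_bottom') else 0.0
        short_c += w * (0.6 if f.get('downtrend') else 0.0)
        short_c += w * f.get('top_confidence', 0.0) if f.get('near_top') else 0.0
    return min(1.0, long_c), min(1.0, short_c)

longs, shorts = blend_confidence({
    '1m': {'uptrend': True, 'near_bottom': True, 'bottom_confidence': 0.7},
    '1h': {'uptrend': True},
    '1d': {'downtrend': True},
})
print(round(longs, 2), round(shorts, 2))  # 0.44 0.3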
Short: {short_confidence:.2f}") - - except Exception as e: - logging.error(f"Error updating CNN patterns: {e}") - - def _calculate_ema(self, data, span): - """Calculate exponential moving average""" - alpha = 2 / (span + 1) - alpha_rev = 1 - alpha - - ema = np.zeros_like(data) - ema[0] = data[0] - - for i in range(1, len(data)): - ema[i] = alpha * data[i] + alpha_rev * ema[i-1] - - return ema - - def is_uncertain_market(self): - """Determine if the market is in an uncertain/range-bound state""" - if len(self.features['price']) < 30: - return False - - # Check if EMAs are close to each other (no clear trend) - if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0: - ema9 = self.features['ema_9'][-1] - ema21 = self.features['ema_21'][-1] - - # If EMAs are within 0.2% of each other, market is uncertain - if abs(ema9 - ema21) / ema21 < 0.002: - return True - - # Check if price is oscillating without clear direction - if len(self.features['price']) >= 10: - recent_prices = self.features['price'][-10:] - ups = downs = 0 - for i in range(1, len(recent_prices)): - if recent_prices[i] > recent_prices[i-1]: - ups += 1 - else: - downs += 1 - - # If there's a mix of ups and downs (neither dominates heavily) - return abs(ups - downs) < 3 - - return False - - def is_near_support(self): - """Determine if the current price is near a support level""" - if len(self.features['price']) < 30: - return False - - # Use Bollinger lower band as support - if len(self.features['bollinger_lower']) > 0 and len(self.features['price']) > 0: - current_price = self.features['price'][-1] - lower_band = self.features['bollinger_lower'][-1] - - # If price is within 0.5% of the lower band - if (current_price - lower_band) / current_price < 0.005: - return True - - # Check if we're near recent lows - if len(self.features['price']) >= 20: - current_price = self.features['price'][-1] - min_price = min(self.features['price'][-20:]) - - # If within 1% of recent lows - if (current_price - min_price) / current_price < 0.01: - return True - - return False - - def is_near_resistance(self): - """Determine if the current price is near a resistance level""" - if len(self.features['price']) < 30: - return False - - # Use Bollinger upper band as resistance - if len(self.features['bollinger_upper']) > 0 and len(self.features['price']) > 0: - current_price = self.features['price'][-1] - upper_band = self.features['bollinger_upper'][-1] - - # If price is within 0.5% of the upper band - if (upper_band - current_price) / current_price < 0.005: - return True - - # Check if we're near recent highs - if len(self.features['price']) >= 20: - current_price = self.features['price'][-1] - max_price = max(self.features['price'][-20:]) - - # If within 1% of recent highs - if (max_price - current_price) / current_price < 0.01: - return True - - return False - - def add_chart_to_tensorboard(self, writer, step, title='Trading Chart'): - """ - Add a candlestick chart and metrics to TensorBoard - - Parameters: - - writer: TensorBoard writer - - step: Current step - - title: Title for the chart - """ - try: - # Initialize writer if not provided - if writer is None: - from torch.utils.tensorboard import SummaryWriter - writer = SummaryWriter() - - # Log basic metrics - writer.add_scalar('Balance', self.balance, step) - writer.add_scalar('Total_PnL', self.total_pnl, step) - - # Log total fees if available - if hasattr(self, 'total_fees'): - writer.add_scalar('Total_Fees', self.total_fees, step) - writer.add_scalar('Net_PnL', self.total_pnl - 
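The removed _calculate_ema implements the standard recurrence ema[i] = alpha*x[i] + (1-alpha)*ema[i-1] with alpha = 2/(span+1), seeded with the first observation. That matches pandas' ewm with adjust=False, which makes a quick cross-check possible:

import numpy as np
import pandas as pd

def ema(data, span):
    """EMA via the explicit recurrence, seeded with the first value."""
    alpha = 2 / (span + 1)
    out = np.zeros(len(data))
    out[0] = data[0]
    for i in range(1, len(data)):
        out[i] = alpha * data[i] + (1 - alpha) * out[i - 1]
    return out

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(np.allclose(ema(x, 3), pd.Series(x).ewm(span=3, adjust=False).mean().to_numpy()))  # True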
self.total_fees, step) - - # Log position info - writer.add_scalar('Position_Size', self.position_size, step) - - # Log drawdown and win rate - writer.add_scalar('Max_Drawdown', self.max_drawdown, step) - - win_rate = self.win_count / (self.win_count + self.loss_count) if (self.win_count + self.loss_count) > 0 else 0 - writer.add_scalar('Win_Rate', win_rate, step) - - # Log trade count - writer.add_scalar('Trade_Count', len(self.trades), step) - - # Check if we have enough data for candlestick chart - if len(self.data) <= 0: - logger.warning("No data available for candlestick chart") - return - - # Create figure for candlestick chart (last 100 data points) - start_idx = max(0, self.current_step - 100) - end_idx = self.current_step - - # Get recent trades for visualization (last 10 trades) - recent_trades = self.trades[-10:] if self.trades else [] - - try: - fig = create_candlestick_figure( - self.data[start_idx:end_idx+1], - title=title, - trades=recent_trades - ) - - # Add figure to TensorBoard - writer.add_figure('Candlestick_Chart', fig, step) - - # Close figure to free memory - plt.close(fig) - - except Exception as e: - logger.error(f"Error creating candlestick chart: {e}") - - except Exception as e: - logger.error(f"Error adding chart to TensorBoard: {e}") - # Continue execution even if chart fails - - def get_realtime_state(self, tick_data): - """ - Create a state representation optimized for real-time processing. - This is a streamlined version of get_state() designed for minimal latency. - - TODO: Implement optimized state creation from tick data - """ - # This would be a simplified version of get_state that processes only - # the most important features needed for real-time decision making - - # Example implementation: - # realtime_features = { - # 'price': tick_data['price'], - # 'volume': tick_data['volume'], - # 'ema_short': self._calculate_ema(tick_data['price'], 9), - # 'ema_long': self._calculate_ema(tick_data['price'], 21), - # } - - # Convert to tensor or numpy array in the required format - # return torch.tensor([...], dtype=torch.float32) - - # Placeholder - return np.zeros((self.observation_space.shape[0],), dtype=np.float32) - # Ensure GPU usage if available def get_device(): """Get the best available device (CUDA GPU or CPU)""" @@ -1909,9 +1626,9 @@ class Agent: # Set device self.device = device if device is not None else get_device() - # Initialize networks - use LSTMAttentionDQN instead of DQN - self.policy_net = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) - self.target_net = LSTMAttentionDQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) + # Initialize networks + self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) + self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device) self.target_net.load_state_dict(self.policy_net.state_dict()) # Initialize optimizer @@ -1922,8 +1639,6 @@ class Agent: # Initialize exploration parameters self.epsilon = EPSILON_START - self.epsilon_start = EPSILON_START - self.epsilon_end = EPSILON_END self.epsilon_decay = EPSILON_DECAY self.epsilon_min = EPSILON_END @@ -1934,16 +1649,9 @@ class Agent: self.writer = None # Initialize GradScaler for mixed precision training - self.scaler = torch.amp.GradScaler('cuda') if self.device.type == "cuda" else None + self.scaler = torch.cuda.amp.GradScaler() if self.device.type == "cuda" else None - # 
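get_realtime_state above is still a stub. One way the streamlined version could look, strictly as a sketch: keep a small rolling window and emit a handful of normalized features per tick. All names are illustrative, and the rolling means are cheap stand-ins for the EMAs the stub mentions:

from collections import deque
import numpy as np

class RealtimeState:
    """Tiny rolling feature window for low-latency inference (sketch)."""
    def __init__(self, maxlen=21):
        self.prices = deque(maxlen=maxlen)

    def update(self, tick):
        self.prices.append(tick['price'])
        p = np.asarray(self.prices, dtype=np.float32)
        last = p[-1]
        ema9 = p[-9:].mean()   # rolling means as cheap EMA stand-ins
        ema21 = p.mean()
        return np.array([last, tick['volume'], ema9 / last - 1, ema21 / last - 1],
                        dtype=np.float32)

rs = RealtimeState()
for i in range(25):
    state = rs.update({'price': 100.0 + 0.1 * i, 'volume': 1.0})
print(state)  # [price, volume, short-trend, long-trend]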
Initialize candle cache for multi-timeframe data - self.candle_cache = CandleCache() - - # Store model name for logging - self.model_name = f"LSTM_Attention_DQN_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}" - - logger.info(f"Initialized agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}") - logger.info(f"Using device: {self.device}") + # Rest of the initialization code... def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8): """Expand the model to handle more features or increase capacity""" @@ -1954,9 +1662,9 @@ class Agent: old_state_dict = self.policy_net.state_dict() # Create new larger networks - new_policy_net = LSTMAttentionDQN(new_state_size, self.action_size, + new_policy_net = DQN(new_state_size, self.action_size, new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device) - new_target_net = LSTMAttentionDQN(new_state_size, self.action_size, + new_target_net = DQN(new_state_size, self.action_size, new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device) # Transfer weights for common layers @@ -1994,78 +1702,25 @@ class Agent: return True - def select_action(self, state, training=True, candle_data=None): - """ - Select an action using the policy network. + def select_action(self, state, training=True): + sample = random.random() - Args: - state: The current state - training: Whether we're in training mode (for epsilon-greedy) - candle_data: Dictionary with ['1s'-later], '1m', '1h', '1d' candle data - - Returns: - The selected action - """ - # ... existing code ... + if training: + # Epsilon decay + self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \ + np.exp(-1. * self.steps_done / EPSILON_DECAY) + self.steps_done += 1 - # Add CNN processing if candle data is available - cnn_inputs = None - if candle_data and all(k in candle_data for k in [ '1m', '1h', '1d']): - # Process candle data into tensors - # x_1s = self.prepare_candle_tensor(candle_data['1s']) - x_1m = self.prepare_candle_tensor(candle_data['1m']) - x_1h = self.prepare_candle_tensor(candle_data['1h']) - x_1d = self.prepare_candle_tensor(candle_data['1d']) - - cnn_inputs = (x_1m, x_1h, x_1d) - - # Use epsilon-greedy strategy during training - if training and random.random() < self.epsilon: - return random.randrange(self.action_size) - - with torch.no_grad(): - state_tensor = torch.FloatTensor(state).to(self.device) - - if cnn_inputs: - q_values = self.policy_net(state_tensor, *cnn_inputs) - else: - q_values = self.policy_net(state_tensor) - - return q_values.max(1)[1].item() - - def prepare_candle_tensor(self, candles, max_candles=300): - """Convert candle data to tensors for CNN input""" - if not candles: - # Return zeros if no candles available - return torch.zeros((1, 5, max_candles), device=self.device) - - # Limit to the most recent candles - candles = candles[-max_candles:] - - # Extract OHLCV data - ohlcv = np.array([[c[1], c[2], c[3], c[4], c[5]] for c in candles], dtype=np.float32) - - # Normalize the data - if len(ohlcv) > 0: - # Simple min-max normalization per column - min_vals = ohlcv.min(axis=0, keepdims=True) - max_vals = ohlcv.max(axis=0, keepdims=True) - range_vals = max_vals - min_vals - range_vals[range_vals == 0] = 1 # Avoid division by zero - ohlcv = (ohlcv - min_vals) / range_vals - - # Pad if needed - padded = np.zeros((max_candles, 5), dtype=np.float32) - padded[-len(ohlcv):] = ohlcv - - # Convert to tensor [batch, channels, sequence] - tensor = 
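The new select_action anneals exploration with the classic exponential schedule epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * exp(-steps/EPSILON_DECAY). A few sample points, with assumed values for the three constants:

import numpy as np

EPSILON_START, EPSILON_END, EPSILON_DECAY = 1.0, 0.05, 10_000  # assumed values

def epsilon_at(steps_done):
    return EPSILON_END + (EPSILON_START - EPSILON_END) * np.exp(-steps_done / EPSILON_DECAY)

for s in (0, 5_000, 10_000, 50_000):
    print(s, round(float(epsilon_at(s)), 3))  # 1.0, 0.626, 0.399, 0.056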
torch.FloatTensor(padded.transpose(1, 0)).unsqueeze(0).to(self.device) - return tensor + if sample > self.epsilon or not training: + with torch.no_grad(): + state_tensor = torch.FloatTensor(state).to(self.device) + action_values = self.policy_net(state_tensor) + return action_values.max(1)[1].item() else: - return torch.zeros((1, 5, max_candles), device=self.device) + return random.randrange(self.action_size) def learn(self): - """Learn from a batch of experiences with GPU acceleration and CNN features""" + """Learn from a batch of experiences""" if len(self.memory) < BATCH_SIZE: return None @@ -2080,32 +1735,31 @@ class Agent: next_states = torch.FloatTensor([e.next_state for e in experiences]).to(self.device) dones = torch.FloatTensor([e.done for e in experiences]).to(self.device) - # Use mixed precision for forward/backward passes if on GPU + # Use mixed precision for forward/backward passes if self.device.type == "cuda" and self.scaler is not None: - with torch.cuda.amp.autocast(): - # Compute current Q values + with torch.amp.autocast('cuda'): + # Compute Q values current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)) - # Compute next Q values + # Compute next Q values with target network with torch.no_grad(): next_q_values = self.target_net(next_states).max(1)[0] + target_q_values = rewards + (GAMMA * next_q_values * (1 - dones)) - # Compute target Q values - target_q_values = rewards + (GAMMA * next_q_values * (1 - dones)) + # Reshape target values to match current_q_values target_q_values = target_q_values.unsqueeze(1) # Compute loss loss = F.smooth_l1_loss(current_q_values, target_q_values) - - # Backward pass with gradient scaling + + # Backward pass with mixed precision self.optimizer.zero_grad() self.scaler.scale(loss).backward() - # Clip gradients + # Gradient clipping to prevent exploding gradients self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0) - # Update weights self.scaler.step(self.optimizer) self.scaler.update() else: @@ -2146,41 +1800,53 @@ class Agent: logger.error(f"Error during learning: {e}") logger.error(f"Traceback: {traceback.format_exc()}") return None - - def update_epsilon(self, episode): - """Update epsilon value based on episode number""" - # Calculate epsilon using a linear decay formula - epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \ - max(0, (self.epsilon_decay - episode)) / self.epsilon_decay - - # Update self.epsilon with the calculated value - self.epsilon = max(self.epsilon_min, epsilon) - - return self.epsilon def update_target_network(self): self.target_net.load_state_dict(self.policy_net.state_dict()) def save(self, path="models/trading_agent_best_pnl.pt"): - """Save the model using a robust saving approach with multiple fallbacks""" + """Save the model in a format compatible with PyTorch 2.6+""" try: # Create directory if it doesn't exist os.makedirs(os.path.dirname(path), exist_ok=True) - # Call robust save function - success = robust_save(self, path) - - if success: - logger.info(f"Model saved successfully to {path}") - return True - else: - logger.error(f"All save attempts failed for path: {path}") - return False + # Ensure architecture parameters are set + if not hasattr(self, 'hidden_size'): + self.hidden_size = 256 # Default value + logger.warning("Setting default hidden_size=256 for saving") + if not hasattr(self, 'lstm_layers'): + self.lstm_layers = 2 # Default value + logger.warning("Setting default lstm_layers=2 for saving") + + 
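The mixed-precision branch of learn() follows the standard autocast/GradScaler recipe; the one subtle step is calling unscale_ before clip_grad_norm_, so clipping sees true gradient magnitudes rather than scaled ones. The skeleton reduced to its essentials (a sketch assuming a CUDA device and targets shaped [batch, 1]):

import torch
import torch.nn.functional as F

def amp_step(policy_net, optimizer, scaler, states, actions, targets, max_norm=1.0):
    """One clipped mixed-precision DQN update (sketch, CUDA assumed)."""
    with torch.amp.autocast('cuda'):
        q = policy_net(states).gather(1, actions.unsqueeze(1))  # Q(s, a), shape [batch, 1]
        loss = F.smooth_l1_loss(q, targets)
    optimizer.zero_grad()
    scaler.scale(loss).backward()   # scaled backward pass
    scaler.unscale_(optimizer)      # restore true gradient scale before clipping
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_norm)
    scaler.step(optimizer)          # skips the step if gradients overflowed
    scaler.update()
    return loss.item()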
if not hasattr(self, 'attention_heads'): + self.attention_heads = 4 # Default value + logger.warning("Setting default attention_heads=4 for saving") + + # Save model state + checkpoint = { + 'policy_net': self.policy_net.state_dict(), + 'target_net': self.target_net.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'epsilon': self.epsilon, + 'state_size': self.state_size, + 'action_size': self.action_size, + 'hidden_size': self.hidden_size, + 'lstm_layers': self.lstm_layers, + 'attention_heads': self.attention_heads + } + + # Save scaler state if it exists + if hasattr(self, 'scaler') and self.scaler is not None: + checkpoint['scaler'] = self.scaler.state_dict() + + # Save with pickle_protocol=4 for better compatibility + torch.save(checkpoint, path, _use_new_zipfile_serialization=True, pickle_protocol=4) + logger.info(f"Model saved to {path}") except Exception as e: - logger.error(f"Error in save method: {e}") + logger.error(f"Error saving model: {e}") + import traceback logger.error(traceback.format_exc()) - return False def load(self, path="models/trading_agent_best_pnl.pt"): """Load a trained model with improved error handling for PyTorch 2.6 compatibility""" @@ -2259,110 +1925,40 @@ class Agent: logger.error(traceback.format_exc()) raise - def add_chart_to_tensorboard(self, env, step): - """Add candlestick chart to tensorboard and various metrics""" + def add_chart_to_tensorboard(self, env, global_step): + """Add trading chart to TensorBoard""" try: - # Initialize writer if it doesn't exist - if not hasattr(self, 'writer') or self.writer is None: - self.writer = SummaryWriter(log_dir=f'runs/{self.model_name}') - - # Check if we have enough data - if not hasattr(env, 'data') or len(env.data) < 20: - logger.warning("Not enough data for chart in TensorBoard") + if len(env.data) < 10: return - # Get position value (convert from string if needed) - position_value = 0 # Default to flat - if hasattr(env, 'position'): - if isinstance(env.position, str): - # Map string positions to numeric values - position_map = {'flat': 0, 'long': 1, 'short': -1} - position_value = position_map.get(env.position.lower(), 0) - else: - position_value = float(env.position) - - # Log metrics to tensorboard - self.writer.add_scalar('Trading/Position', position_value, step) + # Create chart image + chart_img = create_candlestick_figure( + env.data, + env.trade_signals, + window_size=100, + title=f"Trading Chart - Step {global_step}" + ) - if hasattr(env, 'balance'): - self.writer.add_scalar('Trading/Balance', env.balance, step) + if chart_img is not None: + # Convert PIL image to numpy array for TensorBoard + chart_array = np.array(chart_img) + # TensorBoard expects [C, H, W] format + chart_array = np.transpose(chart_array, (2, 0, 1)) + self.writer.add_image('Trading Chart', chart_array, global_step) - if hasattr(env, 'total_pnl'): - self.writer.add_scalar('Trading/Total_PnL', env.total_pnl, step) - - if hasattr(env, 'max_drawdown'): - self.writer.add_scalar('Trading/Drawdown', env.max_drawdown, step) - - if hasattr(env, 'win_rate'): - self.writer.add_scalar('Trading/Win_Rate', env.win_rate, step) - - if hasattr(env, 'trade_count'): - self.writer.add_scalar('Trading/Trade_Count', env.trade_count, step) - - # Log trading fees - if hasattr(env, 'total_fees'): - self.writer.add_scalar('Trading/Total_Fees', env.total_fees, step) - # Also log net PnL (after fees) - if hasattr(env, 'total_pnl'): - self.writer.add_scalar('Trading/Net_PnL_After_Fees', env.total_pnl - env.total_fees, step) - - # Add candlestick 
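Because the checkpoint below stores the architecture hyperparameters next to the weights, a loader can rebuild the network before restoring state. A sketch of the matching load path, reusing this file's DQN class; weights_only=False is an assumption needed because the checkpoint mixes plain Python values with tensors, so only use it on checkpoints you trust:

import torch

def load_checkpoint(path, device='cpu'):
    """Rebuild a DQN from a checkpoint saved by Agent.save (sketch)."""
    # weights_only=False: the dict holds plain ints alongside tensors
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model = DQN(ckpt['state_size'], ckpt['action_size'], ckpt['hidden_size'],
                ckpt['lstm_layers'], ckpt['attention_heads']).to(device)
    model.load_state_dict(ckpt['policy_net'])
    model.eval()
    return model, ckpt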
chart if we have enough data - if len(env.data) >= 100: - try: - # Use the last 100 candles for the chart - recent_data = env.data[-100:] - - # Get recent trades if available - recent_trades = None - if hasattr(env, 'trades') and len(env.trades) > 0: - recent_trades = env.trades[-10:] # Last 10 trades - - # Create candlestick figure - fig = create_candlestick_figure(recent_data, recent_trades, f"Trading Chart - Step {step}") - - if fig: - # Add to tensorboard - self.writer.add_figure('Trading/Chart', fig, step) - - # Close figure to free memory - plt.close(fig) - except Exception as e: - logger.warning(f"Error creating candlestick chart: {e}") + # Add position information as text + entry_price = env.entry_price if env.entry_price else 0.00 + position_info = f""" + **Current Position**: {env.position.upper()} + **Entry Price**: ${entry_price:.2f} + **Current Price**: ${env.data[-1]['close']:.2f} + **Position Size**: ${env.position_size:.2f} + **Unrealized PnL**: ${env.total_pnl:.2f} + """ + self.writer.add_text('Position', position_info, global_step) except Exception as e: - logger.error(f"Error in add_chart_to_tensorboard: {e}") - - def select_action_realtime(self, state): - """ - Select action with minimal latency for real-time trading. - Optimized version of select_action for ultra-low latency requirements. - - TODO: Implement optimized action selection for real-time trading - """ - # Convert to tensor if needed - state_tensor = torch.tensor(state, dtype=torch.float32) - - # Fast forward pass through the network - with torch.no_grad(): - q_values = self.policy_net.forward_realtime(state_tensor.unsqueeze(0)) - - # Get the action with highest Q-value - action = q_values.max(1)[1].item() - - return action - - def forward_realtime(self, state): - """ - Optimized forward pass for real-time trading with minimal latency. - - TODO: Implement streamlined forward pass that prioritizes speed - """ - # For now, just use the regular forward pass - # This could be optimized later with techniques like: - # - Using a smaller model for real-time decisions - # - Skipping certain layers or calculations - # - Using quantized weights or other optimizations - - return self.forward(state) + logger.error(f"Error adding chart to TensorBoard: {str(e)}") + # Continue without visualization rather than crashing async def get_live_prices(symbol="ETH/USDT", timeframe="1m"): """Get live price data using websockets""" @@ -2402,39 +1998,8 @@ async def get_live_prices(symbol="ETH/USDT", timeframe="1m"): await asyncio.sleep(5) break -async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, use_compact_save=False): - """ - Train the agent in the environment. 
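add_image expects CHW layout by default, which is why the new add_chart_to_tensorboard transposes the array; passing dataformats='HWC' is the equivalent no-transpose route. A self-contained check (the log directory is illustrative):

import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/image_demo')
img_hwc = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)  # H x W x C

writer.add_image('chart_chw', np.transpose(img_hwc, (2, 0, 1)), 0)  # default CHW layout
writer.add_image('chart_hwc', img_hwc, 0, dataformats='HWC')        # same image, no transpose
writer.close()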
- - Args: - agent: The agent to train - env: The trading environment - num_episodes: Number of episodes to train for - max_steps_per_episode: Maximum steps per episode - use_compact_save: Whether to use compact save (for low disk space) - - Returns: - Training statistics - """ - # Initialize TensorBoard writer if not already done - try: - if agent.writer is None: - from torch.utils.tensorboard import SummaryWriter - agent.writer = SummaryWriter(log_dir=f'runs/{agent.model_name}') - - writer = agent.writer - except Exception as e: - logging.error(f"Failed to initialize TensorBoard: {e}") - writer = None - - # Initialize exchange for data fetching - try: - exchange = await initialize_exchange() - logging.info("Initialized exchange for data fetching") - except Exception as e: - logging.error(f"Failed to initialize exchange: {e}") - exchange = None - +async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000): + """Train the agent using historical and live data with GPU acceleration""" # Initialize statistics tracking stats = { 'episode_rewards': [], @@ -2444,409 +2009,225 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, 'episode_pnls': [], 'cumulative_pnl': [], 'drawdowns': [], - 'trade_counts': [], - 'loss_values': [], - 'fees': [], # Track fees - 'net_pnl_after_fees': [] # Track net PnL after fees + 'prediction_accuracy': [], + 'trade_analysis': [] } # Track best models best_reward = float('-inf') best_pnl = float('-inf') - best_net_pnl = float('-inf') # Track best net PnL (after fees) - # Make directory for models if it doesn't exist - os.makedirs('models', exist_ok=True) + # Initialize TensorBoard writer if not already initialized + if not hasattr(agent, 'writer') or agent.writer is None: + agent.writer = SummaryWriter('runs/training') - # Memory management function - def clean_memory(): - """Clean up memory to avoid memory leaks""" - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # Start training loop + # Training loop for episode in range(num_episodes): try: - # Clean up memory before starting a new episode - clean_memory() - # Reset environment state = env.reset() episode_reward = 0 - episode_losses = [] - - # Fetch multi-timeframe data at the start of the episode - candle_data = None - if exchange: - try: - candle_data = await fetch_multi_timeframe_data( - exchange, "ETH/USDT", agent.candle_cache - ) - # Update CNN patterns - env.update_cnn_patterns(candle_data) - logging.info(f"Fetched multi-timeframe data for episode {episode+1}") - except Exception as e: - logging.error(f"Failed to fetch candle data: {e}") - - # Track consecutive errors - consecutive_errors = 0 - max_consecutive_errors = 5 + prediction_loss = 0 # Episode loop for step in range(max_steps_per_episode): + # Select action + action = agent.select_action(state) + + # Take action try: - # Select action using CNN-enhanced policy - action = agent.select_action(state, training=True, candle_data=candle_data) - - # Take action next_state, reward, done, info = env.step(action) - - # Store transition in replay memory - agent.memory.push(state, action, reward, next_state, done) - - # Move to the next state - state = next_state - - # Update episode reward - episode_reward += reward - - # Learn from experience - if len(agent.memory) > BATCH_SIZE: - try: - loss = agent.learn() - if loss is not None: - episode_losses.append(loss) - # Log loss to TensorBoard - global_step = episode * max_steps_per_episode + step - if writer: - 
writer.add_scalar('Loss/step', loss, global_step) - - # Reset consecutive errors counter on successful learning - consecutive_errors = 0 - except Exception as e: - logging.error(f"Error during learning: {e}") - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors: - logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") - break - - # Update target network periodically - if step % TARGET_UPDATE == 0: - agent.update_target_network() - - # Update price predictions and CNN patterns periodically - if step % 50 == 0: - try: - # Update internal environment predictions - if hasattr(env, 'update_price_predictions'): - env.update_price_predictions() - if hasattr(env, 'identify_optimal_trades'): - env.identify_optimal_trades() - - # Fetch fresh candle data periodically - if exchange: - try: - candle_data = await fetch_multi_timeframe_data( - exchange, "ETH/USDT", agent.candle_cache - ) - - # Update CNN patterns with the new candle data - env.update_cnn_patterns(candle_data) - logging.info(f"Updated multi-timeframe data at step {step}") - except Exception as e: - logging.error(f"Failed to fetch candle data: {e}") - except Exception as e: - logging.warning(f"Error updating predictions: {e}") - - # Clean memory periodically during long episodes - if step % 200 == 0 and step > 0: - clean_memory() - - # Add chart to TensorBoard periodically - if step % 100 == 0 or (step == max_steps_per_episode - 1) or done: - try: - global_step = episode * max_steps_per_episode + step - if writer: - agent.add_chart_to_tensorboard(env, global_step) - except Exception as e: - logging.warning(f"Error adding chart to TensorBoard: {e}") - - if done: - break - except Exception as e: - logging.error(f"Error in training step: {e}") - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors: - logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") - break + logger.error(f"Error in step function: {e}") + break + + # Store transition in replay memory + agent.memory.push(state, action, reward, next_state, done) + + # Move to the next state + state = next_state + + # Update episode reward + episode_reward += reward + + # Learn from experience + if len(agent.memory) > BATCH_SIZE: + agent.learn() + + # Update price predictions periodically + if step % 50 == 0: + try: + env.update_price_predictions() + env.identify_optimal_trades() + except Exception as e: + logger.warning(f"Error updating predictions: {e}") + + # Add chart to TensorBoard periodically + if step % 50 == 0 or (step == max_steps_per_episode - 1) or done: + try: + global_step = episode * max_steps_per_episode + step + agent.add_chart_to_tensorboard(env, global_step) + except Exception as e: + logger.warning(f"Error adding chart to TensorBoard: {e}") + + # End episode if done + if done: + break - # Calculate statistics from this episode - balance = env.balance - pnl = balance - env.initial_balance if hasattr(env, 'initial_balance') else 0 - fees = env.total_fees if hasattr(env, 'total_fees') else 0 - net_pnl = pnl - fees # Calculate net PnL after fees + # Update target network periodically + if episode % TARGET_UPDATE == 0: + agent.update_target_network() - # Get trading statistics - trade_analysis = None - if hasattr(env, 'analyze_trades'): + # Calculate win rate + total_trades = env.win_count + env.loss_count + win_rate = (env.win_count / total_trades * 100) if total_trades > 0 else 0 + + # Train price predictor + try: + if episode % 5 == 0 and 
len(env.data) > 50: + prediction_loss = env.train_price_predictor() + except Exception as e: + logger.warning(f"Error training price predictor: {e}") + prediction_loss = 0 + + # Analyze trades + try: trade_analysis = env.analyze_trades() + stats['trade_analysis'].append(trade_analysis) + except Exception as e: + logger.warning(f"Error analyzing trades: {e}") + trade_analysis = {} + stats['trade_analysis'].append({}) - win_rate = trade_analysis['win_rate'] if trade_analysis and 'win_rate' in trade_analysis else 0 - trade_count = trade_analysis['total_trades'] if trade_analysis and 'total_trades' in trade_analysis else 0 - max_drawdown = trade_analysis['max_drawdown'] if trade_analysis and 'max_drawdown' in trade_analysis else 0 + # Calculate prediction accuracy + prediction_accuracy = 0.0 + try: + if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0: + if len(env.data) > 5: + actual_prices = [candle['close'] for candle in env.data[-5:]] + predicted = env.predicted_prices[:min(5, len(actual_prices))] + errors = [abs(p - a) / a for p, a in zip(predicted, actual_prices[:len(predicted)])] + prediction_accuracy = 100 * (1 - sum(errors) / len(errors)) + except Exception as e: + logger.warning(f"Error calculating prediction accuracy: {e}") - # Calculate average loss for this episode - avg_loss = sum(episode_losses) / len(episode_losses) if episode_losses else 0 - - # Log episode metrics to TensorBoard - if writer: - writer.add_scalar('Reward/episode', episode_reward, episode) - writer.add_scalar('Balance/episode', balance, episode) - writer.add_scalar('PnL/episode', pnl, episode) - writer.add_scalar('NetPnL/episode', net_pnl, episode) - writer.add_scalar('Fees/episode', fees, episode) - writer.add_scalar('WinRate/episode', win_rate, episode) - writer.add_scalar('TradeCount/episode', trade_count, episode) - writer.add_scalar('Drawdown/episode', max_drawdown, episode) - writer.add_scalar('Loss/episode', avg_loss, episode) - writer.add_scalar('Epsilon/episode', agent.epsilon, episode) - - # Update stats dictionary + # Log statistics stats['episode_rewards'].append(episode_reward) stats['episode_lengths'].append(step + 1) - stats['balances'].append(balance) + stats['balances'].append(env.balance) stats['win_rates'].append(win_rate) - stats['episode_pnls'].append(pnl) - stats['drawdowns'].append(max_drawdown) - stats['trade_counts'].append(trade_count) - stats['loss_values'].append(avg_loss) - stats['fees'].append(fees) - stats['net_pnl_after_fees'].append(net_pnl) + stats['episode_pnls'].append(env.episode_pnl) + stats['cumulative_pnl'].append(env.total_pnl) + stats['drawdowns'].append(env.max_drawdown * 100) + stats['prediction_accuracy'].append(prediction_accuracy) - # Calculate and update cumulative PnL - if len(stats['episode_pnls']) > 0: - cumulative_pnl = sum(stats['episode_pnls']) - if 'cumulative_pnl' not in stats: - stats['cumulative_pnl'] = [] - stats['cumulative_pnl'].append(cumulative_pnl) - if writer: - writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode) - writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode) + # Log detailed trade analysis + if trade_analysis: + logger.info(f"Trade Analysis: Win Rate={trade_analysis.get('uptrend_win_rate', 0):.1f}% in uptrends, " + f"{trade_analysis.get('downtrend_win_rate', 0):.1f}% in downtrends | " + f"Avg Win=${trade_analysis.get('avg_win', 0):.2f}, Avg Loss=${trade_analysis.get('avg_loss', 0):.2f}") - # Save model if this is the best reward or PnL + # Log to TensorBoard + 
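The prediction accuracy logged above is essentially 100 * (1 - mean absolute percentage error) over the last few candles; note it can go negative when forecasts are far off. Isolated with made-up numbers:

def prediction_accuracy(predicted, actual):
    """100 * (1 - MAPE); negative when forecasts are badly off."""
    pairs = list(zip(predicted, actual))
    if not pairs:
        return 0.0
    errors = [abs(p - a) / a for p, a in pairs]
    return 100 * (1 - sum(errors) / len(errors))

print(round(prediction_accuracy([101, 103, 99], [100, 100, 100]), 2))  # 98.33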
agent.writer.add_scalar('Reward/train', episode_reward, episode) + agent.writer.add_scalar('Balance/train', env.balance, episode) + agent.writer.add_scalar('WinRate/train', win_rate, episode) + agent.writer.add_scalar('PnL/episode', env.episode_pnl, episode) + agent.writer.add_scalar('PnL/cumulative', env.total_pnl, episode) + agent.writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode) + agent.writer.add_scalar('PredictionLoss', prediction_loss, episode) + agent.writer.add_scalar('PredictionAccuracy', prediction_accuracy, episode) + + # Add final chart for this episode + try: + agent.add_chart_to_tensorboard(env, (episode + 1) * max_steps_per_episode) + except Exception as e: + logger.warning(f"Error adding final chart: {e}") + + logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, " + f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, " + f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}, " + f"Max Drawdown={env.max_drawdown*100:.1f}%, Pred Accuracy={prediction_accuracy:.1f}%") + + # Save best model by reward if episode_reward > best_reward: best_reward = episode_reward - try: - if use_compact_save: - success = compact_save(agent, 'models/trading_agent_best_reward.pt') - else: - success = agent.save('models/trading_agent_best_reward.pt') - if success: - logging.info(f"New best reward: {best_reward:.2f}") - except Exception as e: - logging.error(f"Error saving best reward model: {e}") + agent.save("models/trading_agent_best_reward.pt") - if pnl > best_pnl: - best_pnl = pnl - try: - if use_compact_save: - success = compact_save(agent, 'models/trading_agent_best_pnl.pt') - else: - success = agent.save('models/trading_agent_best_pnl.pt') - if success: - logging.info(f"New best PnL: ${best_pnl:.2f}") - except Exception as e: - logging.error(f"Error saving best PnL model: {e}") - - # Save model if this is the best net PnL (after fees) - if net_pnl > best_net_pnl: - best_net_pnl = net_pnl - try: - if use_compact_save: - success = compact_save(agent, 'models/trading_agent_best_net_pnl.pt') - else: - success = agent.save('models/trading_agent_best_net_pnl.pt') - if success: - logging.info(f"New best Net PnL: ${best_net_pnl:.2f}") - except Exception as e: - logging.error(f"Error saving best net PnL model: {e}") + # Save best model by PnL + if env.episode_pnl > best_pnl: + best_pnl = env.episode_pnl + agent.save("models/trading_agent_best_pnl.pt") - # Save checkpoint periodically + # Save checkpoint if episode % 10 == 0: - try: - if use_compact_save: - compact_save(agent, f'models/trading_agent_checkpoint_{episode}.pt') - else: - agent.save(f'models/trading_agent_checkpoint_{episode}.pt') - except Exception as e: - logging.error(f"Error saving checkpoint model: {e}") - - # Update epsilon - agent.update_epsilon(episode) - - # Log training progress - logging.info(f"Episode {episode+1}/{num_episodes} | " + - f"Reward: {episode_reward:.2f} | " + - f"Balance: ${balance:.2f} | " + - f"PnL: ${pnl:.2f} | " + - f"Fees: ${fees:.2f} | " + - f"Net PnL: ${net_pnl:.2f} | " + - f"Win Rate: {win_rate:.2f} | " + - f"Trades: {trade_count} | " + - f"Loss: {avg_loss:.5f} | " + - f"Epsilon: {agent.epsilon:.4f}") - + agent.save(f"models/trading_agent_episode_{episode}.pt") + except Exception as e: - logging.error(f"Error in episode {episode}: {e}") - logging.error(traceback.format_exc()) + logger.error(f"Error in episode {episode}: {e}") continue - # Clean memory before saving final model - clean_memory() - # Save final model - try: - if 
use_compact_save: - compact_save(agent, 'models/trading_agent_final.pt') - else: - agent.save('models/trading_agent_final.pt') - except Exception as e: - logging.error(f"Error saving final model: {e}") + agent.save("models/trading_agent_final.pt") - # Save training statistics to file - try: - import pandas as pd - - # Make sure all arrays in stats are the same length by padding with NaN - max_length = max(len(v) for k, v in stats.items() if isinstance(v, list)) - for k, v in stats.items(): - if isinstance(v, list) and len(v) < max_length: - stats[k] = v + [float('nan')] * (max_length - len(v)) - - # Create dataframe and save - stats_df = pd.DataFrame(stats) - stats_df.to_csv('training_stats.csv', index=False) - logging.info(f"Training statistics saved to training_stats.csv") - except Exception as e: - logging.error(f"Failed to save training statistics: {e}") - logging.error(traceback.format_exc()) - - # Close exchange if it's still open - if exchange: - try: - # Check if exchange has the close method (ccxt.async_support) - if hasattr(exchange, 'close'): - await exchange.close() - logging.info("Closed exchange connection") - else: - logging.info("Exchange doesn't have close method (standard ccxt), skipping close") - except Exception as e: - logging.error(f"Error closing exchange: {e}") + # Plot training results + plot_training_results(stats) return stats def plot_training_results(stats): - """Plot training results and save to file""" - try: - # Check if we have data to plot - if not stats or len(stats.get('episode_rewards', [])) == 0: - logger.warning("No training data to plot") - return - - # Create a DataFrame with consistent lengths - max_len = max(len(stats.get(key, [])) for key in stats) - - # Ensure all arrays have the same length by padding with the last value or zeros - processed_stats = {} - for key, values in stats.items(): - if not values: # Skip empty lists - continue - - # Pad arrays to the same length - if len(values) < max_len: - if len(values) > 0: - # Pad with the last value - values = values + [values[-1]] * (max_len - len(values)) - else: - # Pad with zeros - values = [0] * max_len - - processed_stats[key] = values[:max_len] # Trim if longer - - # Create DataFrame - df = pd.DataFrame(processed_stats) - - # Add episode column - df['episode'] = range(1, len(df) + 1) - - # Create figure with subplots - fig, axes = plt.subplots(3, 2, figsize=(15, 15)) - - # Plot episode rewards - if 'episode_rewards' in df.columns: - axes[0, 0].plot(df['episode'], df['episode_rewards']) - axes[0, 0].set_title('Episode Rewards') - axes[0, 0].set_xlabel('Episode') - axes[0, 0].set_ylabel('Reward') - axes[0, 0].grid(True) - - # Plot account balance - if 'balances' in df.columns: - axes[0, 1].plot(df['episode'], df['balances']) - axes[0, 1].set_title('Account Balance') - axes[0, 1].set_xlabel('Episode') - axes[0, 1].set_ylabel('Balance ($)') - axes[0, 1].grid(True) - - # Plot win rate - if 'win_rates' in df.columns: - axes[1, 0].plot(df['episode'], df['win_rates']) - axes[1, 0].set_title('Win Rate') - axes[1, 0].set_xlabel('Episode') - axes[1, 0].set_ylabel('Win Rate') - axes[1, 0].set_ylim([0, 1]) - axes[1, 0].grid(True) - - # Plot episode PnL - if 'episode_pnls' in df.columns: - axes[1, 1].plot(df['episode'], df['episode_pnls']) - axes[1, 1].set_title('Episode PnL') - axes[1, 1].set_xlabel('Episode') - axes[1, 1].set_ylabel('PnL ($)') - axes[1, 1].grid(True) - - # Plot cumulative PnL - if 'cumulative_pnl' in df.columns: - axes[2, 0].plot(df['episode'], df['cumulative_pnl']) - axes[2, 
0].set_title('Cumulative PnL') - axes[2, 0].set_xlabel('Episode') - axes[2, 0].set_ylabel('Cumulative PnL ($)') - axes[2, 0].grid(True) - - # Plot maximum drawdown - if 'drawdowns' in df.columns: - axes[2, 1].plot(df['episode'], df['drawdowns']) - axes[2, 1].set_title('Maximum Drawdown') - axes[2, 1].set_xlabel('Episode') - axes[2, 1].set_ylabel('Drawdown') - axes[2, 1].grid(True) - - # Adjust layout - plt.tight_layout() - - # Save figure - plt.savefig('training_results.png') - logger.info("Training results saved to training_results.png") - - # Save statistics to CSV - df.to_csv('training_stats.csv', index=False) - logger.info("Training statistics saved to training_stats.csv") - - except Exception as e: - logger.error(f"Error plotting training results: {e}") - logger.error(traceback.format_exc()) + """Plot detailed training results""" + plt.figure(figsize=(20, 15)) + + # Plot rewards + plt.subplot(3, 2, 1) + plt.plot(stats['episode_rewards']) + plt.title('Episode Rewards') + plt.xlabel('Episode') + plt.ylabel('Reward') + + # Plot balance + plt.subplot(3, 2, 2) + plt.plot(stats['balances']) + plt.title('Account Balance') + plt.xlabel('Episode') + plt.ylabel('Balance ($)') + + # Plot win rate + plt.subplot(3, 2, 3) + plt.plot(stats['win_rates']) + plt.title('Win Rate') + plt.xlabel('Episode') + plt.ylabel('Win Rate (%)') + + # Plot episode PnL + plt.subplot(3, 2, 4) + plt.plot(stats['episode_pnls']) + plt.title('Episode PnL') + plt.xlabel('Episode') + plt.ylabel('PnL ($)') + + # Plot cumulative PnL + plt.subplot(3, 2, 5) + plt.plot(stats['cumulative_pnl']) + plt.title('Cumulative PnL') + plt.xlabel('Episode') + plt.ylabel('Cumulative PnL ($)') + + # Plot drawdown + plt.subplot(3, 2, 6) + plt.plot(stats['drawdowns']) + plt.title('Maximum Drawdown') + plt.xlabel('Episode') + plt.ylabel('Drawdown (%)') + + plt.tight_layout() + plt.savefig('training_results.png') + + # Save statistics to CSV + df = pd.DataFrame(stats) + df.to_csv('training_stats.csv', index=False) + + logger.info("Training statistics saved to training_stats.csv and training_results.png") def evaluate_agent(agent, env, num_episodes=10): """Evaluate the agent on test data""" @@ -3008,437 +2389,211 @@ async def get_historical_data(exchange, symbol="ETH/USDT", timeframe="1m", limit logger.error(f"Failed to fetch historical data: {e}") return [] -async def live_trading( - symbol="ETH/USDT", - timeframe="1m", - model_path=None, - demo=False, - leverage=50, - initial_balance=1000, - max_position_size=0.1, - commission=0.0004, - window_size=30, - update_interval=60, - stop_loss_pct=0.02, - take_profit_pct=0.04, - max_trades_per_day=10, - risk_per_trade=0.02, - use_trailing_stop=False, - trailing_stop_callback=0.005, - use_dynamic_sizing=True, - use_volatility_sizing=True, - use_multi_timeframe=True, - use_sentiment=False, - use_limit_orders=False, - use_dollar_cost_avg=False, - use_grid_trading=False, - use_martingale=False, - use_anti_martingale=False, - use_custom_indicators=True, - use_ml_predictions=True, - use_ensemble=True, - use_reinforcement=True, - use_risk_management=True, - use_portfolio_management=False, - use_position_sizing=True, - use_stop_loss=True, - use_take_profit=True, - use_trailing_stop_loss=False, - use_dynamic_stop_loss=True, - use_dynamic_take_profit=True, - use_dynamic_trailing_stop=False, - use_dynamic_position_sizing=True, - use_dynamic_leverage=False, - use_dynamic_risk_per_trade=True, - use_dynamic_max_trades_per_day=False, - use_dynamic_update_interval=False, - use_dynamic_window_size=False, - 
use_dynamic_commission=False, - use_dynamic_timeframe=False, - use_dynamic_symbol=False, - use_dynamic_model_path=False, - use_dynamic_demo=False, - use_dynamic_leverage_value=False, - use_dynamic_initial_balance=False, - use_dynamic_max_position_size=False, - use_dynamic_stop_loss_pct=False, - use_dynamic_take_profit_pct=False, - use_dynamic_risk_per_trade_value=False, - use_dynamic_trailing_stop_callback=False, - use_dynamic_use_trailing_stop=False, - use_dynamic_use_dynamic_sizing=False, - use_dynamic_use_volatility_sizing=False, - use_dynamic_use_multi_timeframe=False, - use_dynamic_use_sentiment=False, - use_dynamic_use_limit_orders=False, - use_dynamic_use_dollar_cost_avg=False, - use_dynamic_use_grid_trading=False, - use_dynamic_use_martingale=False, - use_dynamic_use_anti_martingale=False, - use_dynamic_use_custom_indicators=False, - use_dynamic_use_ml_predictions=False, - use_dynamic_use_ensemble=False, - use_dynamic_use_reinforcement=False, - use_dynamic_use_risk_management=False, - use_dynamic_use_portfolio_management=False, - use_dynamic_use_position_sizing=False, - use_dynamic_use_stop_loss=False, - use_dynamic_use_take_profit=False, - use_dynamic_use_trailing_stop_loss=False, - use_dynamic_use_dynamic_stop_loss=False, - use_dynamic_use_dynamic_take_profit=False, - use_dynamic_use_dynamic_trailing_stop=False, - use_dynamic_use_dynamic_position_sizing=False, - use_dynamic_use_dynamic_leverage=False, - use_dynamic_use_dynamic_risk_per_trade=False, - use_dynamic_use_dynamic_max_trades_per_day=False, - use_dynamic_use_dynamic_update_interval=False, - use_dynamic_use_dynamic_window_size=False, - use_dynamic_use_dynamic_commission=False, - use_dynamic_use_dynamic_timeframe=False, - use_dynamic_use_dynamic_symbol=False, - use_dynamic_use_dynamic_model_path=False, - use_dynamic_use_dynamic_demo=False, - use_dynamic_use_dynamic_leverage_value=False, - use_dynamic_use_dynamic_initial_balance=False, - use_dynamic_use_dynamic_max_position_size=False, - use_dynamic_use_dynamic_stop_loss_pct=False, - use_dynamic_use_dynamic_take_profit_pct=False, - use_dynamic_use_dynamic_risk_per_trade_value=False, - use_dynamic_use_dynamic_trailing_stop_callback=False, -): - """ - Live trading function that connects to the exchange and trades in real-time. 
- - Args: - symbol: Trading pair symbol - timeframe: Timeframe for trading - model_path: Path to the trained model - demo: Whether to use demo mode (sandbox) - leverage: Leverage to use - initial_balance: Initial balance - max_position_size: Maximum position size as a percentage of balance - commission: Commission rate - window_size: Window size for the environment - update_interval: Interval to update data in seconds - stop_loss_pct: Stop loss percentage - take_profit_pct: Take profit percentage - max_trades_per_day: Maximum trades per day - risk_per_trade: Risk per trade as a percentage of balance - use_trailing_stop: Whether to use trailing stop - trailing_stop_callback: Trailing stop callback percentage - use_dynamic_sizing: Whether to use dynamic position sizing - use_volatility_sizing: Whether to use volatility-based position sizing - use_multi_timeframe: Whether to use multi-timeframe analysis - use_sentiment: Whether to use sentiment analysis - use_limit_orders: Whether to use limit orders - use_dollar_cost_avg: Whether to use dollar cost averaging - use_grid_trading: Whether to use grid trading - use_martingale: Whether to use martingale strategy - use_anti_martingale: Whether to use anti-martingale strategy - use_custom_indicators: Whether to use custom indicators - use_ml_predictions: Whether to use ML predictions - use_ensemble: Whether to use ensemble methods - use_reinforcement: Whether to use reinforcement learning - use_risk_management: Whether to use risk management - use_portfolio_management: Whether to use portfolio management - use_position_sizing: Whether to use position sizing - use_stop_loss: Whether to use stop loss - use_take_profit: Whether to use take profit - use_trailing_stop_loss: Whether to use trailing stop loss - use_dynamic_stop_loss: Whether to use dynamic stop loss - use_dynamic_take_profit: Whether to use dynamic take profit - use_dynamic_trailing_stop: Whether to use dynamic trailing stop - use_dynamic_position_sizing: Whether to use dynamic position sizing - use_dynamic_leverage: Whether to use dynamic leverage - use_dynamic_risk_per_trade: Whether to use dynamic risk per trade - use_dynamic_max_trades_per_day: Whether to use dynamic max trades per day - use_dynamic_update_interval: Whether to use dynamic update interval - use_dynamic_window_size: Whether to use dynamic window size - use_dynamic_commission: Whether to use dynamic commission - use_dynamic_timeframe: Whether to use dynamic timeframe - use_dynamic_symbol: Whether to use dynamic symbol - use_dynamic_model_path: Whether to use dynamic model path - use_dynamic_demo: Whether to use dynamic demo - use_dynamic_leverage_value: Whether to use dynamic leverage value - use_dynamic_initial_balance: Whether to use dynamic initial balance - use_dynamic_max_position_size: Whether to use dynamic max position size - use_dynamic_stop_loss_pct: Whether to use dynamic stop loss percentage - use_dynamic_take_profit_pct: Whether to use dynamic take profit percentage - use_dynamic_risk_per_trade_value: Whether to use dynamic risk per trade value - use_dynamic_trailing_stop_callback: Whether to use dynamic trailing stop callback - """ +async def live_trading(agent, env, exchange, symbol="ETH/USDT", timeframe="1m", demo=True, leverage=50): + """Run the trading bot in live mode with enhanced error handling""" logger.info(f"Starting live trading for {symbol} on {timeframe} timeframe") - logger.info(f"Demo mode: {demo}, Leverage: {leverage}x") + logger.info(f"Mode: {'DEMO (paper trading)' if demo else 'LIVE 
TRADING'}") + + # Verify agent is properly initialized + try: + # Ensure agent has all required attributes + if not hasattr(agent, 'hidden_size'): + agent.hidden_size = 256 # Default value + logger.warning("Agent missing hidden_size attribute, using default: 256") + + if not hasattr(agent, 'lstm_layers'): + agent.lstm_layers = 2 # Default value + logger.warning("Agent missing lstm_layers attribute, using default: 2") + + if not hasattr(agent, 'attention_heads'): + agent.attention_heads = 4 # Default value + logger.warning("Agent missing attention_heads attribute, using default: 4") + + logger.info(f"Agent configuration: state_size={agent.state_size}, action_size={agent.action_size}, hidden_size={agent.hidden_size}") + except Exception as e: + logger.error(f"Error checking agent configuration: {e}") + # Continue anyway, as these are just informational attributes + + if not demo: + # Confirm with user before starting live trading + confirmation = input(f"⚠️ WARNING: You are about to start LIVE TRADING with real funds on {symbol}. Type 'CONFIRM' to continue: ") + if confirmation != "CONFIRM": + logger.info("Live trading canceled by user") + return + + # Initialize futures trading if not in demo mode + try: + await env.initialize_futures(exchange) + logger.info(f"Futures trading initialized with {leverage}x leverage") + except Exception as e: + logger.error(f"Failed to initialize futures trading: {str(e)}") + logger.info("Falling back to demo mode for safety") + demo = True + + # Initialize TensorBoard for monitoring + if not hasattr(agent, 'writer') or agent.writer is None: + from torch.utils.tensorboard import SummaryWriter + # Fix the datetime usage here + current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + agent.writer = SummaryWriter(f'runs/live_{symbol.replace("/", "_")}_{current_time}') + + # Track performance metrics + trades_count = 0 + winning_trades = 0 + total_profit = 0 + max_drawdown = 0 + peak_balance = env.balance + step_counter = 0 + prev_position = 'flat' + + # Create directory for trade logs + os.makedirs('trade_logs', exist_ok=True) + # Fix the datetime usage here + current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + trade_log_path = f'trade_logs/trades_{current_time}.csv' + with open(trade_log_path, 'w') as f: + f.write("timestamp,action,price,position_size,balance,pnl\n") + + logger.info("Entering live trading loop...") - # Flag to track if we're using mock trading - using_mock_trading = False - - # Initialize exchange try: - exchange = await initialize_exchange() - - # Try to set sandbox mode if demo is True - if demo: - try: - exchange.set_sandbox_mode(demo) - logger.info(f"Sandbox mode set to {demo}") - except Exception as e: - logger.warning(f"Exchange doesn't support sandbox mode: {e}") - logger.info("Continuing in mock trading mode instead") - using_mock_trading = True - - # Set leverage - if not demo or using_mock_trading: - try: - await exchange.set_leverage(leverage, symbol) - logger.info(f"Leverage set to {leverage}x") - except Exception as e: - logger.warning(f"Failed to set leverage: {e}") - - # Initialize environment - env = TradingEnvironment( - initial_balance=initial_balance, - leverage=leverage, - window_size=window_size, - commission=commission, - symbol=symbol, - timeframe=timeframe, - max_position_size=max_position_size, - stop_loss_pct=stop_loss_pct, - take_profit_pct=take_profit_pct, - max_trades_per_day=max_trades_per_day, - risk_per_trade=risk_per_trade, - use_trailing_stop=use_trailing_stop, - 
trailing_stop_callback=trailing_stop_callback, - use_dynamic_sizing=use_dynamic_sizing, - use_volatility_sizing=use_volatility_sizing, - use_multi_timeframe=use_multi_timeframe, - use_sentiment=use_sentiment, - use_limit_orders=use_limit_orders, - use_dollar_cost_avg=use_dollar_cost_avg, - use_grid_trading=use_grid_trading, - use_martingale=use_martingale, - use_anti_martingale=use_anti_martingale, - use_custom_indicators=use_custom_indicators, - use_ml_predictions=use_ml_predictions, - use_ensemble=use_ensemble, - use_reinforcement=use_reinforcement, - use_risk_management=use_risk_management, - use_portfolio_management=use_portfolio_management, - use_position_sizing=use_position_sizing, - use_stop_loss=use_stop_loss, - use_take_profit=use_take_profit, - use_trailing_stop_loss=use_trailing_stop_loss, - use_dynamic_stop_loss=use_dynamic_stop_loss, - use_dynamic_take_profit=use_dynamic_take_profit, - use_dynamic_trailing_stop=use_dynamic_trailing_stop, - use_dynamic_position_sizing=use_dynamic_position_sizing, - use_dynamic_leverage=use_dynamic_leverage, - use_dynamic_risk_per_trade=use_dynamic_risk_per_trade, - use_dynamic_max_trades_per_day=use_dynamic_max_trades_per_day, - use_dynamic_update_interval=use_dynamic_update_interval, - use_dynamic_window_size=use_dynamic_window_size, - use_dynamic_commission=use_dynamic_commission, - use_dynamic_timeframe=use_dynamic_timeframe, - use_dynamic_symbol=use_dynamic_symbol, - use_dynamic_model_path=use_dynamic_model_path, - use_dynamic_demo=use_dynamic_demo, - use_dynamic_leverage_value=use_dynamic_leverage_value, - use_dynamic_initial_balance=use_dynamic_initial_balance, - use_dynamic_max_position_size=use_dynamic_max_position_size, - use_dynamic_stop_loss_pct=use_dynamic_stop_loss_pct, - use_dynamic_take_profit_pct=use_dynamic_take_profit_pct, - use_dynamic_risk_per_trade_value=use_dynamic_risk_per_trade_value, - use_dynamic_trailing_stop_callback=use_dynamic_trailing_stop_callback, - use_dynamic_use_trailing_stop=use_dynamic_use_trailing_stop, - use_dynamic_use_dynamic_sizing=use_dynamic_use_dynamic_sizing, - use_dynamic_use_volatility_sizing=use_dynamic_use_volatility_sizing, - use_dynamic_use_multi_timeframe=use_dynamic_use_multi_timeframe, - use_dynamic_use_sentiment=use_dynamic_use_sentiment, - use_dynamic_use_limit_orders=use_dynamic_use_limit_orders, - use_dynamic_use_dollar_cost_avg=use_dynamic_use_dollar_cost_avg, - use_dynamic_use_grid_trading=use_dynamic_use_grid_trading, - use_dynamic_use_martingale=use_dynamic_use_martingale, - use_dynamic_use_anti_martingale=use_dynamic_use_anti_martingale, - use_dynamic_use_custom_indicators=use_dynamic_use_custom_indicators, - use_dynamic_use_ml_predictions=use_dynamic_use_ml_predictions, - use_dynamic_use_ensemble=use_dynamic_use_ensemble, - use_dynamic_use_reinforcement=use_dynamic_use_reinforcement, - use_dynamic_use_risk_management=use_dynamic_use_risk_management, - use_dynamic_use_portfolio_management=use_dynamic_use_portfolio_management, - use_dynamic_use_position_sizing=use_dynamic_use_position_sizing, - use_dynamic_use_stop_loss=use_dynamic_use_stop_loss, - use_dynamic_use_take_profit=use_dynamic_use_take_profit, - use_dynamic_use_trailing_stop_loss=use_dynamic_use_trailing_stop_loss, - use_dynamic_use_dynamic_stop_loss=use_dynamic_use_dynamic_stop_loss, - use_dynamic_use_dynamic_take_profit=use_dynamic_use_dynamic_take_profit, - use_dynamic_use_dynamic_trailing_stop=use_dynamic_use_dynamic_trailing_stop, - use_dynamic_use_dynamic_position_sizing=use_dynamic_use_dynamic_position_sizing, - 
use_dynamic_use_dynamic_leverage=use_dynamic_use_dynamic_leverage, - use_dynamic_use_dynamic_risk_per_trade=use_dynamic_use_dynamic_risk_per_trade, - use_dynamic_use_dynamic_max_trades_per_day=use_dynamic_use_dynamic_max_trades_per_day, - use_dynamic_use_dynamic_update_interval=use_dynamic_use_dynamic_update_interval, - use_dynamic_use_dynamic_window_size=use_dynamic_use_dynamic_window_size, - use_dynamic_use_dynamic_commission=use_dynamic_use_dynamic_commission, - use_dynamic_use_dynamic_timeframe=use_dynamic_use_dynamic_timeframe, - use_dynamic_use_dynamic_symbol=use_dynamic_use_dynamic_symbol, - use_dynamic_use_dynamic_model_path=use_dynamic_use_dynamic_model_path, - use_dynamic_use_dynamic_demo=use_dynamic_use_dynamic_demo, - use_dynamic_use_dynamic_leverage_value=use_dynamic_use_dynamic_leverage_value, - use_dynamic_use_dynamic_initial_balance=use_dynamic_use_dynamic_initial_balance, - use_dynamic_use_dynamic_max_position_size=use_dynamic_use_dynamic_max_position_size, - use_dynamic_use_dynamic_stop_loss_pct=use_dynamic_use_dynamic_stop_loss_pct, - use_dynamic_use_dynamic_take_profit_pct=use_dynamic_use_dynamic_take_profit_pct, - use_dynamic_use_dynamic_risk_per_trade_value=use_dynamic_use_dynamic_risk_per_trade_value, - use_dynamic_use_dynamic_trailing_stop_callback=use_dynamic_use_dynamic_trailing_stop_callback, - ) - - # Fetch initial data - logger.info(f"Fetching initial data for {symbol}") - await fetch_and_update_data(exchange, env, symbol, timeframe) - - # Initialize agent - STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64 - ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4 - agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE, hidden_size=384) - - # Load model if provided - if model_path: - agent.load(model_path) - logger.info(f"Model loaded successfully from {model_path}") - - # Initialize TensorBoard writer - agent.writer = SummaryWriter(log_dir=f"runs/live_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - - # Initialize trading statistics - trades = [] - total_pnl = 0 - win_count = 0 - loss_count = 0 - - # Initialize trading log file - log_file = f"live_trading_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" - with open(log_file, 'w') as f: - f.write("timestamp,action,price,position_size,balance,pnl\n") - - # Start live trading loop - logger.info(f"Starting live trading with {symbol} on {timeframe} timeframe") - - # Main trading loop - step_counter = 0 - last_update_time = time.time() - while True: - # Get current state - state = env.get_state() - - # Select action - action = agent.select_action(state, training=False) - - # Take action - next_state, reward, done, info = env.step(action) - - # Log action and results - if info.get('trade_executed', False): - trade_data = { - 'timestamp': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'action': info['action'], - 'price': env.current_price, - 'position_size': env.position_size, - 'balance': env.balance, - 'pnl': env.last_trade_profit - } + try: + # Fetch latest candle data + candle = await get_latest_candle(exchange, symbol) + if candle is None: + logger.warning("Failed to fetch latest candle, retrying in 5 seconds...") + await asyncio.sleep(5) + continue - trades.append(trade_data) + # Add new data to environment + env.add_data(candle) - # Update statistics - if env.last_trade_profit > 0: - win_count += 1 - total_pnl += env.last_trade_profit - else: - loss_count += 1 + # Get current state and select action + state = env.get_state() - # Log 
trade to file
-                with open(log_file, 'a') as f:
-                    f.write(f"{trade_data['timestamp']},{trade_data['action']},{trade_data['price']},{trade_data['position_size']},{trade_data['balance']},{trade_data['pnl']}\n")
+            # Verify state shape matches agent's expected input
+            if state.shape[0] != agent.state_size:
+                logger.warning(f"State size mismatch: got {state.shape[0]}, expected {agent.state_size}")
+                # Pad or truncate state to match expected size
+                if state.shape[0] < agent.state_size:
+                    state = np.pad(state, (0, agent.state_size - state.shape[0]))
+                else:
+                    state = state[:agent.state_size]
-            logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
-            
-            # Update TensorBoard metrics
-            if step_counter % 10 == 0:
+            action = agent.select_action(state, training=False)
+            
+            # Ensure the action index is within the agent's action space
+            if action >= agent.action_size:
+                logger.warning(f"Invalid action {action}, clipping to {agent.action_size-1}")
+                action = agent.action_size - 1
+            
+            # Log the selected action
+            action_name = "HOLD" if action == 0 else "BUY" if action == 1 else "SELL" if action == 2 else "CLOSE"
+            logger.info(f"Step {step_counter}: Action selected: {action_name}, Price: ${env.data[-1]['close']:.2f}")
+            
+            # Execute action
+            if not demo:
+                # Execute real trade on exchange
+                current_price = env.data[-1]['close']
+                trade_result = await env.execute_real_trade(exchange, action, current_price)
+                if trade_result is None or not isinstance(trade_result, dict) or not trade_result.get('success', False):
+                    error_msg = trade_result.get('error', 'Unknown error') if isinstance(trade_result, dict) else 'Trade execution failed'
+                    logger.error(f"Trade execution failed: {error_msg}")
+                    # Continue with the simulated trade below for tracking purposes
+            
+            # Update environment with action (simulated in demo mode).
+            # Unpack the result by length instead of retrying on ValueError:
+            # calling env.step(action) a second time would execute the action
+            # twice. Some versions of step() return 3 values instead of 4.
+            step_result = env.step(action)
+            if len(step_result) == 4:
+                next_state, reward, done, info = step_result
+            else:
+                logger.warning("Step function returned 3 values instead of 4, creating info dict")
+                next_state, reward, done = step_result
+                info = {
+                    'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
+                    'price': env.current_price,
+                    'balance': env.balance,
+                    'position': env.position,
+                    'pnl': env.total_pnl
+                }
+            
+            # Log trade if position changed
+            if env.position != prev_position:
+                trades_count += 1
+                if env.last_trade_profit > 0:
+                    winning_trades += 1
+                total_profit += env.last_trade_profit
+                
+                # Log trade details
+                with open(trade_log_path, 'a') as f:
+                    f.write(f"{datetime.datetime.now().isoformat()},{info['action']},{env.data[-1]['close']},{env.position_size},{env.balance},{env.last_trade_profit}\n")
+                
+                logger.info(f"Trade executed: {info['action']} at ${env.data[-1]['close']:.2f}, PnL: ${env.last_trade_profit:.2f}")
+            
+            # Update performance metrics
+            if env.balance > peak_balance:
+                peak_balance = env.balance
+            current_drawdown = (peak_balance - env.balance) / peak_balance if peak_balance > 0 else 0
+            if current_drawdown > max_drawdown:
+                max_drawdown = current_drawdown
+            
+            # Update TensorBoard metrics
+            step_counter += 1
             agent.writer.add_scalar('Live/Balance', env.balance, step_counter)
             agent.writer.add_scalar('Live/PnL', env.total_pnl, step_counter)
-                agent.writer.add_scalar('Live/Reward', reward, step_counter)
-            
-            # Check if it's time to update data
-            current_time = time.time()
-            if current_time - last_update_time > update_interval:
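
# A tolerant alternative to the length check above: this sketch (the helper
# name is ours, assuming gym/gymnasium-style environments, not code from this
# patch) normalizes 3-, 4- and 5-tuple step() results in one place:

def unpack_step(result):
    """Normalize env.step() results to (next_state, reward, done, info)."""
    if len(result) == 5:  # gymnasium API: obs, reward, terminated, truncated, info
        next_state, reward, terminated, truncated, info = result
        return next_state, reward, terminated or truncated, info
    if len(result) == 4:  # classic gym API: obs, reward, done, info
        return result
    next_state, reward, done = result  # bare 3-tuple, as handled above
    return next_state, reward, done, {}

# Usage: next_state, reward, done, info = unpack_step(env.step(action))

-                await 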
fetch_and_update_data(exchange, env, symbol, timeframe) - last_update_time = current_time + agent.writer.add_scalar('Live/Drawdown', current_drawdown * 100, step_counter) - # Print status update - win_rate = win_count / (win_count + loss_count) if (win_count + loss_count) > 0 else 0 - logger.info(f""" - Step: {step_counter} - Balance: ${env.balance:.2f} - Total PnL: ${env.total_pnl:.2f} - Win Rate: {win_rate:.2f} - Trades: {len(trades)} - """) - - # Move to next state - state = next_state - step_counter += 1 - - # Sleep to avoid excessive API calls - await asyncio.sleep(1) - - # Check for manual stop - if done: - break + # Update chart visualization + if step_counter % 5 == 0 or env.position != prev_position: + agent.add_chart_to_tensorboard(env, step_counter) + + # Log performance summary + if trades_count > 0: + win_rate = (winning_trades / trades_count) * 100 + agent.writer.add_scalar('Live/WinRate', win_rate, step_counter) + + performance_text = f""" + **Live Trading Performance** + Balance: ${env.balance:.2f} + Total PnL: ${env.total_pnl:.2f} + Trades: {trades_count} + Win Rate: {win_rate:.1f}% + Max Drawdown: {max_drawdown*100:.1f}% + """ + agent.writer.add_text('Performance', performance_text, step_counter) + + prev_position = env.position + + # Wait for next candle + logger.info(f"Waiting for next candle... (Step {step_counter})") + await asyncio.sleep(10) # Check every 10 seconds + + except Exception as e: + logger.error(f"Error in live trading loop: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + logger.info("Continuing after error...") + await asyncio.sleep(30) # Wait longer after an error + + except KeyboardInterrupt: + logger.info("Live trading stopped by user") - # Close TensorBoard writer - agent.writer.close() - - # Save final statistics - win_rate = win_count / (win_count + loss_count) if (win_count + loss_count) > 0 else 0 - logger.info(f""" - Live Trading Summary: - Total Steps: {step_counter} - Final Balance: ${env.balance:.2f} - Total PnL: ${env.total_pnl:.2f} - Win Rate: {win_rate:.2f} - Total Trades: {len(trades)} - """) - - # Close exchange connection - try: - await exchange.close() - logger.info("Exchange connection closed") - except Exception as e: - logger.warning(f"Error closing exchange connection: {e}") - - except Exception as e: - logger.error(f"Error in live trading: {e}") - logger.error(traceback.format_exc()) - try: - await exchange.close() - except: - pass - logger.info("Exchange connection closed") + # Final performance report + if trades_count > 0: + win_rate = (winning_trades / trades_count) * 100 + logger.info(f"Trading session summary:") + logger.info(f"Total trades: {trades_count}") + logger.info(f"Win rate: {win_rate:.1f}%") + logger.info(f"Final balance: ${env.balance:.2f}") + logger.info(f"Total profit: ${total_profit:.2f}") + logger.info(f"Maximum drawdown: {max_drawdown*100:.1f}%") + logger.info(f"Trade log saved to: {trade_log_path}") async def get_latest_candle(exchange, symbol): - """ - Get the latest candle for a symbol. 
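
# The win-rate and max-drawdown bookkeeping in the live loop can be checked
# against a small pure helper; a sketch (the function name is ours, not part
# of the patch):

def session_stats(balances, wins, trades):
    """Return (win_rate_pct, max_drawdown_pct) for a balance history."""
    peak, max_dd = float('-inf'), 0.0
    for b in balances:
        peak = max(peak, b)
        if peak > 0:
            # Drawdown is the fractional drop from the running peak
            max_dd = max(max_dd, (peak - b) / peak)
    win_rate = (wins / trades) * 100 if trades else 0.0
    return win_rate, max_dd * 100

# e.g. session_stats([100, 110, 99, 120], wins=2, trades=3) ~= (66.7, 10.0):
# the dip from the running peak 110 down to 99 is a 10% drawdown
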
- 
-    Args:
-        exchange: Exchange instance
-        symbol: Trading pair symbol
-        
-    Returns:
-        Latest candle data or None on failure
-    """
+    """Get the latest candle data"""
     try:
         # Use the refactored fetch method with limit=1
         data = await fetch_ohlcv_data(exchange, symbol, "1m", 1)
@@ -3452,68 +2607,45 @@ async def get_latest_candle(exchange, symbol):
         logger.error(f"Failed to fetch latest candle: {e}")
         return None
 
-async def fetch_ohlcv_data(exchange, symbol="ETH/USDT", timeframe="1m", limit=1000):
-    """
-    Fetch OHLCV data from exchange with error handling and retry logic.
-    
-    Args:
-        exchange: The exchange instance
-        symbol: Trading pair symbol
-        timeframe: Candle timeframe
-        limit: Number of candles to fetch
+async def fetch_ohlcv_data(exchange, symbol, timeframe, limit):
+    """Fetch OHLCV data with proper handling for both async and standard CCXT"""
+    try:
+        # Check that the exchange supports OHLCV fetching at all
+        if not hasattr(exchange, 'fetch_ohlcv'):
+            logger.error("Exchange does not support OHLCV data fetching")
+            return []
+        
+        # Handle both ccxt.async_support and standard (blocking) CCXT:
+        # async exchanges expose fetch_ohlcv as a coroutine function
+        if asyncio.iscoroutinefunction(exchange.fetch_ohlcv):
+            ohlcv = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
+        else:
+            # Run the blocking call in an executor so the event loop stays free
+            loop = asyncio.get_event_loop()
+            ohlcv = await loop.run_in_executor(
+                None, 
+                lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
+            )
+        
+        # Convert to list of dictionaries
+        data = []
+        for candle in ohlcv:
+            timestamp, open_price, high, low, close, volume = candle
+            data.append({
+                'timestamp': timestamp,
+                'open': open_price,
+                'high': high,
+                'low': low,
+                'close': close,
+                'volume': volume
+            })
+        
+        logger.info(f"Fetched {len(data)} candles for {symbol} ({timeframe})")
+        return data
-    
-    Returns:
-        List of candle data or empty list on failure
-    """
-    max_retries = 3
-    retry_delay = 5
-    
-    for attempt in range(max_retries):
-        try:
-            logging.info(f"Fetching {limit} {timeframe} candles for {symbol} (attempt {attempt+1}/{max_retries})")
-            
-            # Check if exchange has fetch_ohlcv method
-            if not hasattr(exchange, 'fetch_ohlcv'):
-                logging.error("Exchange does not support OHLCV data fetching")
-                return []
-            
-            # Fetch OHLCV data from exchange using asyncio if available, otherwise use run_in_executor
-            try:
-                if hasattr(exchange, 'has') and exchange.has.get('fetchOHLCVAsync', False):
-                    ohlcv = await exchange.fetchOHLCVAsync(symbol, timeframe, limit=limit)
-                else:
-                    # Run in executor to avoid blocking
-                    loop = asyncio.get_event_loop()
-                    ohlcv = await loop.run_in_executor(
-                        None, 
-                        lambda: exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
-                    )
-            except Exception as e:
-                logging.error(f"Failed to fetch OHLCV data: {e}")
-                await asyncio.sleep(retry_delay)
-                continue
-            
-            if not ohlcv or len(ohlcv) == 0:
-                logging.warning(f"No data returned from exchange (attempt {attempt+1}/{max_retries})")
-                await asyncio.sleep(retry_delay)
-                continue
-            
-            # Convert to list of lists format
-            data = []
-            for candle in ohlcv:
-                timestamp, open_price, high, low, close, volume = candle
-                data.append([timestamp, open_price, high, low, close, volume])
-            
-            logging.info(f"Successfully fetched {len(data)} candles")
-            return data
-            
-        except Exception as e:
-            logging.error(f"Error fetching OHLCV data (attempt {attempt+1}/{max_retries}): {e}")
-            if attempt < max_retries - 1:
-                await asyncio.sleep(retry_delay)
-    
-    logging.error(f"Failed to fetch OHLCV data after {max_retries} attempts")
-    return []
+    except Exception as e:
+        logger.error(f"Error 
fetching OHLCV data: {e}") + return [] # Add this near the top of the file, after imports def ensure_pytorch_compatibility(): @@ -3540,8 +2672,6 @@ async def main(): help='Operation mode: train, eval, or live') parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes for training or evaluation') - parser.add_argument('--max_steps', type=int, default=1000, - help='Maximum steps per episode for training') parser.add_argument('--demo', type=str, choices=['true', 'false'], default='true', help='Run in demo mode (paper trading) if true') parser.add_argument('--symbol', type=str, default='ETH/USDT', @@ -3550,10 +2680,8 @@ async def main(): help='Candle timeframe (1m, 5m, 15m, 1h, etc.)') parser.add_argument('--leverage', type=int, default=50, help='Leverage for futures trading') - parser.add_argument('--model', type=str, default='models/trading_agent_best_net_pnl.pt', + parser.add_argument('--model', type=str, default=None, help='Path to model file for evaluation or live trading') - parser.add_argument('--compact_save', action='store_true', - help='Use compact model saving (for low disk space)') args = parser.parse_args() @@ -3569,19 +2697,12 @@ async def main(): # Initialize exchange exchange = await initialize_exchange() - # Create environment with updated parameters - env = TradingEnvironment( - initial_balance=INITIAL_BALANCE, - window_size=30, - leverage=args.leverage, - exchange_id='mexc', - symbol=args.symbol, - timeframe=args.timeframe - ) + # Create environment + env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode) if args.mode == 'train': # Fetch initial data for training - await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000) + await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000) # Create agent with consistent parameters # Note: Using STATE_SIZE and action_size=4 for consistency @@ -3589,9 +2710,7 @@ async def main(): # Train the agent logger.info(f"Starting training for {args.episodes} episodes...") - stats = await train_agent(agent, env, num_episodes=args.episodes, - max_steps_per_episode=args.max_steps, - use_compact_save=args.compact_save) + stats = await train_agent(agent, env, num_episodes=args.episodes) elif args.mode == 'eval' or args.mode == 'live': # Fetch initial data for the specified symbol and timeframe @@ -3668,1498 +2787,82 @@ async def main(): logger.warning(f"Could not properly close exchange connection: {e}") # Add this function near the top with other utility functions -def create_candlestick_figure(data, trades=None, title="Trading Chart"): - """Create a candlestick chart with trades marked""" - try: - if data is None or len(data) < 5: - logger.warning("Not enough data for candlestick chart") - return None - - # Convert data to DataFrame if it's not already - if not isinstance(data, pd.DataFrame): - df = pd.DataFrame(data) - else: - df = data.copy() - - # Ensure required columns exist - required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume'] - for col in required_columns: - if col not in df.columns: - logger.warning(f"Missing required column {col} for candlestick chart") - return None - - # Format dates - if 'timestamp' in df.columns: - if isinstance(df['timestamp'].iloc[0], (int, float)): - # Convert timestamp to datetime if it's numeric - df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') - - # Set timestamp as index if it's not already - if df.index.name != 'timestamp': - df.set_index('timestamp', inplace=True) - - # Rename columns for 
mplfinance - df_mpf = df.copy() - if 'open' in df_mpf.columns: - df_mpf = df_mpf.rename(columns={ - 'open': 'Open', - 'high': 'High', - 'low': 'Low', - 'close': 'Close', - 'volume': 'Volume' - }) - - # Create a simple matplotlib figure instead - fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8), - gridspec_kw={'height_ratios': [3, 1]}) - - # Plot candlesticks manually - for i in range(len(df_mpf)): - # Get date and prices - date = df_mpf.index[i] - open_price = df_mpf['Open'].iloc[i] - high_price = df_mpf['High'].iloc[i] - low_price = df_mpf['Low'].iloc[i] - close_price = df_mpf['Close'].iloc[i] - - # Determine color based on price movement - color = 'green' if close_price >= open_price else 'red' - - # Plot candle body - body_height = abs(close_price - open_price) - body_bottom = min(close_price, open_price) - ax1.bar(date, body_height, bottom=body_bottom, width=0.6, - color=color, alpha=0.6) - - # Plot wick - ax1.plot([date, date], [low_price, high_price], color=color, linewidth=1) - - # Plot volume - ax2.bar(df_mpf.index, df_mpf['Volume'], width=0.6, color='blue', alpha=0.5) - - # Mark trades if available - if trades and len(trades) > 0: - for trade in trades: - if 'timestamp' not in trade or 'type' not in trade or 'price' not in trade: - continue - - # Convert timestamp to datetime if needed - if isinstance(trade['timestamp'], (int, float)): - trade_time = pd.to_datetime(trade['timestamp'], unit='ms') - else: - trade_time = trade['timestamp'] - - # Determine marker color based on trade type - marker_color = 'green' if trade['type'].lower() == 'buy' else 'red' - - # Add marker at trade price - ax1.scatter(trade_time, trade['price'], color=marker_color, - marker='^' if trade['type'].lower() == 'buy' else 'v', - s=100, zorder=5) - - # Set title and labels - ax1.set_title(title) - ax1.set_ylabel('Price') - ax2.set_ylabel('Volume') - ax1.grid(True) - ax2.grid(True) - - # Format x-axis - plt.setp(ax1.get_xticklabels(), visible=False) - - # Adjust layout - plt.tight_layout() - - return fig - - except Exception as e: - logger.error(f"Error creating candlestick figure: {e}") +def create_candlestick_figure(data, trade_signals, window_size=100, title=""): + """Create a candlestick chart with trade signals for TensorBoard visualization""" + if len(data) < 10: return None - -class CandlePatternCNN(nn.Module): - """Convolutional neural network for detecting candlestick patterns""" - - def __init__(self, input_channels=5, feature_dimension=512): - super(CandlePatternCNN, self).__init__() - self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1) - self.relu1 = nn.ReLU() - self.pool1 = nn.MaxPool2d(kernel_size=2) - self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) - self.relu2 = nn.ReLU() - self.pool2 = nn.MaxPool2d(kernel_size=2) - self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) - self.relu3 = nn.ReLU() - self.pool3 = nn.MaxPool2d(kernel_size=2) - # Projection layers - self.fc1 = nn.Linear(128 * 4 * 4, 1024) - self.relu4 = nn.ReLU() - self.fc2 = nn.Linear(1024, feature_dimension) + try: + # Create figure + fig = plt.figure(figsize=(12, 8)) - # Initialize intermediate features as empty tensors, not as a dict - # This makes the model TorchScript compatible - self.feature_1m = torch.zeros(1, feature_dimension) - self.feature_1h = torch.zeros(1, feature_dimension) - self.feature_1d = torch.zeros(1, feature_dimension) - - def forward(self, x_1m, x_1h, x_1d): - # Process 1m data - feat_1m = self.process_timeframe(x_1m) + # Prepare data for plotting + df = 
pd.DataFrame(data[-window_size:])
+        df['date'] = pd.to_datetime(df['timestamp'], unit='ms')
+        df.set_index('date', inplace=True)
-        # Process 1h data
-        feat_1h = self.process_timeframe(x_1h)
+        # Create subplot grid
+        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
+        price_ax = plt.subplot(gs[0])
+        volume_ax = plt.subplot(gs[1], sharex=price_ax)
-        # Process 1d data
-        feat_1d = self.process_timeframe(x_1d)
+        # Plot candlesticks - use a simpler approach if mplfinance fails
+        try:
+            # Try an mplfinance-style candle plot first; any failure falls
+            # through to the simple line/volume plot below
+            mpf.plot(df, type='candle', ax=price_ax, volume=volume_ax, style='yahoo')
+        except Exception as e:
+            logger.warning(f"Error plotting with mplfinance: {e}, falling back to simple plot")
+            # Fallback to simple plot
+            price_ax.plot(df.index, df['close'], label='Price')
+            volume_ax.bar(df.index, df['volume'], color='blue', alpha=0.5)
-        # Store features as attributes instead of in a dictionary
-        self.feature_1m = feat_1m
-        self.feature_1h = feat_1h
-        self.feature_1d = feat_1d
-        
-        # Concatenate features from different timeframes
-        combined_features = torch.cat([feat_1m, feat_1h, feat_1d], dim=1)
-        
-        return combined_features
-    
-    def process_timeframe(self, x):
-        """Process a single timeframe batch of data"""
-        # Ensure proper shape for input, handle both batched and single inputs
-        if len(x.shape) == 3:  # Single input, shape: [channels, height, width]
-            x = x.unsqueeze(0)  # Add batch dimension
-        
-        x = self.pool1(self.relu1(self.conv1(x)))
-        x = self.pool2(self.relu2(self.conv2(x)))
-        x = self.pool3(self.relu3(self.conv3(x)))
-        
-        # Flatten the spatial dimensions for the fully connected layer
-        x = x.view(x.size(0), -1)
-        
-        x = self.relu4(self.fc1(x))
-        x = self.fc2(x)
-        
-        return x
-    
-    def get_features(self):
-        """Return features for each timeframe"""
-        # Use properties instead of dict for TorchScript compatibility
-        return self.feature_1m, self.feature_1h, self.feature_1d
-
-# Add candle cache system
-class CandleCache:
-    """
-    Cache system for candles of different timeframes.
-    Reduces API calls by storing and updating candle data. 
- """ - def __init__(self): - self.candles = { - '1m': [], - '1h': [], - '1d': [] - } - self.last_updated = { - '1m': None, - '1h': None, - '1d': None - } - # Add ticks channel for real-time data (WebSocket) - self.ticks = [] - self.last_tick_time = None - - def add_candles(self, timeframe, new_candles): - """Add new candles to the cache""" - if not self.candles[timeframe]: - self.candles[timeframe] = new_candles - else: - # Find the last timestamp in our current cache - last_timestamp = self.candles[timeframe][-1][0] - - # Add only candles newer than our last cached one - for candle in new_candles: - if candle[0] > last_timestamp: - self.candles[timeframe].append(candle) - - self.last_updated[timeframe] = datetime.datetime.now() - - def add_tick(self, tick_data): - """Add a new tick to the ticks buffer""" - self.ticks.append(tick_data) - self.last_tick_time = datetime.datetime.now() - - # Keep only the most recent 1000 ticks to prevent memory issues - if len(self.ticks) > 1000: - self.ticks = self.ticks[-1000:] - - def get_ticks(self, limit=None): - """Get the most recent ticks from the buffer""" - if not self.ticks: - return [] - - if limit and limit > 0: - return self.ticks[-limit:] - return self.ticks - - def get_candles(self, timeframe, limit=300): - """Get the most recent candles for a timeframe""" - if not self.candles[timeframe]: - return [] - - return self.candles[timeframe][-limit:] - - def needs_update(self, timeframe, max_age_seconds): - """Check if the cache needs to be updated""" - if not self.last_updated[timeframe]: - return True - - age = (datetime.datetime.now() - self.last_updated[timeframe]).total_seconds() - return age > max_age_seconds - -async def fetch_multi_timeframe_data(exchange, symbol, candle_cache): - """Fetch candle data for multiple timeframes, using cache when possible""" - update_intervals = { - '1m': 60, # Update every 1 minute - '1h': 3600, # Update every 1 hour - '1d': 86400 # Update every 1 day - } - - # TODO: For 1s/tick timeframes, we'll implement the exchange's WebSocket API - # for real-time data streaming in the future. This will enable ultra-low latency - # trading signals with minimal delay between market data reception and action execution. - # A WebSocket implementation is already prepared in the RealTimeDataStream class. 
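
# A minimal sketch of the WebSocket streaming the TODO above describes,
# assuming the `websockets` package is installed; the URL and message fields
# are placeholders, not MEXC's real endpoint or payload schema:

import json
import websockets  # assumed dependency

WS_URL = "wss://example-exchange/ws"  # placeholder, not a real endpoint

async def stream_ticks(candle_cache, symbol="ETH/USDT"):
    """Push raw trade ticks into CandleCache.add_tick as they arrive."""
    async with websockets.connect(WS_URL) as ws:
        # Subscription payloads are exchange-specific; this one is invented
        await ws.send(json.dumps({"op": "subscribe", "channel": "trades", "symbol": symbol}))
        while True:
            msg = json.loads(await ws.recv())
            # Field names 't'/'p'/'q' are assumed; map them to the tick
            # structure the cache expects
            candle_cache.add_tick({
                'timestamp': msg.get('t'),
                'price': msg.get('p'),
                'volume': msg.get('q'),
                'side': msg.get('side', 'buy'),
            })
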
- - limits = { - '1m': 1000, - '1h': 500, - '1d': 300 - } - - for timeframe, interval in update_intervals.items(): - if candle_cache.needs_update(timeframe, interval): + # Add trade signals + for signal in trade_signals: try: - logging.info(f"Fetching {timeframe} candle data for {symbol}") - candles = await fetch_ohlcv_data(exchange, symbol, timeframe, limits[timeframe]) - candle_cache.add_candles(timeframe, candles) - logging.info(f"Fetched {len(candles)} {timeframe} candles") + timestamp = pd.to_datetime(signal['timestamp'], unit='ms') + price = signal['price'] + + if signal['type'] == 'buy': + price_ax.plot(timestamp, price, '^', color='green', markersize=10) + elif signal['type'] == 'sell': + price_ax.plot(timestamp, price, 'v', color='red', markersize=10) + elif signal['type'] == 'close_long': + price_ax.plot(timestamp, price, 'x', color='gold', markersize=10) + elif signal['type'] == 'close_short': + price_ax.plot(timestamp, price, 'x', color='black', markersize=10) + elif 'stop_loss' in signal['type']: + price_ax.plot(timestamp, price, 'X', color='purple', markersize=10) + elif 'take_profit' in signal['type']: + price_ax.plot(timestamp, price, '*', color='cyan', markersize=10) except Exception as e: - logging.error(f"Error fetching {timeframe} candle data: {e}") - - return { - '1m': candle_cache.get_candles('1m'), - '1h': candle_cache.get_candles('1h'), - '1d': candle_cache.get_candles('1d') - } - -# Modify the LSTMAttentionDQN class to incorporate the CNN features -class LSTMAttentionDQN(nn.Module): - def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4): - super(LSTMAttentionDQN, self).__init__() - self.state_size = state_size - self.action_size = action_size - self.hidden_size = hidden_size - self.lstm_layers = lstm_layers - self.attention_heads = attention_heads - - # LSTM layer - self.lstm = nn.LSTM( - input_size=state_size, - hidden_size=hidden_size, - num_layers=lstm_layers, - batch_first=True, - dropout=0.2 if lstm_layers > 1 else 0 - ) - - # Multi-head self-attention - self.attention = nn.MultiheadAttention( - embed_dim=hidden_size, - num_heads=attention_heads, - dropout=0.1 - ) - - # Value stream - self.value_stream = nn.Sequential( - nn.Linear(hidden_size, 128), - nn.ReLU(), - nn.Linear(128, 1) - ) - - # Advantage stream - self.advantage_stream = nn.Sequential( - nn.Linear(hidden_size, 128), - nn.ReLU(), - nn.Linear(128, action_size) - ) - - # Fusion for multi-timeframe data - self.cnn_fusion = nn.Sequential( - nn.Linear(512 * 3, 1024), # 512 features from each of the 3 timeframes - nn.ReLU(), - nn.Dropout(0.3), - nn.Linear(1024, hidden_size) - ) - - # Initialize weights - self.apply(self._init_weights) - - def _init_weights(self, module): - if isinstance(module, nn.Linear): - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - elif isinstance(module, nn.LSTM): - for name, param in module.named_parameters(): - if 'weight' in name: - nn.init.xavier_uniform_(param) - elif 'bias' in name: - nn.init.constant_(param, 0) - - def forward(self, state, x_1m=None, x_1h=None, x_1d=None): - """ - Forward pass handling different input shapes and optional CNN features - - Args: - state: Primary state vector (batch_size, sequence_length, state_size) - x_1m, x_1h, x_1d: Optional CNN features from different timeframes - - Returns: - Q-values for each action - """ - batch_size = state.size(0) - - # Handle CNN features if provided - if x_1m is not None and x_1h is not None and x_1d is not None: - 
# Ensure all CNN features have batch dimension - if len(x_1m.shape) == 2: - x_1m = x_1m.unsqueeze(0) - if len(x_1h.shape) == 2: - x_1h = x_1h.unsqueeze(0) - if len(x_1d.shape) == 2: - x_1d = x_1d.unsqueeze(0) - - # Ensure batch dimensions match - if x_1m.size(0) != batch_size: - x_1m = x_1m.expand(batch_size, -1, -1) if x_1m.size(0) == 1 else x_1m[:batch_size] - if x_1h.size(0) != batch_size: - x_1h = x_1h.expand(batch_size, -1, -1) if x_1h.size(0) == 1 else x_1h[:batch_size] - if x_1d.size(0) != batch_size: - x_1d = x_1d.expand(batch_size, -1, -1) if x_1d.size(0) == 1 else x_1d[:batch_size] - - # Check dimensions before concatenation - if x_1m.dim() == 3 and x_1m.size(1) == 512 and x_1h.size(1) == 512 and x_1d.size(1) == 512: - # Already in correct format [batch, features] - cnn_combined = torch.cat([x_1m, x_1h, x_1d], dim=1) - elif x_1m.dim() == 2 and x_1m.size(1) == 512 and x_1h.size(1) == 512 and x_1d.size(1) == 512: - # Dimensions correct but missing batch dimension - cnn_combined = torch.cat([x_1m, x_1h, x_1d], dim=1).unsqueeze(0) - else: - # Reshape to ensure correct dimensions - x_1m_flat = x_1m.reshape(batch_size, -1) - x_1h_flat = x_1h.reshape(batch_size, -1) - x_1d_flat = x_1d.reshape(batch_size, -1) - - # Handle variable dimensions more gracefully - needed_features = 512 - if x_1m_flat.size(1) < needed_features: - x_1m_flat = F.pad(x_1m_flat, (0, needed_features - x_1m_flat.size(1))) - else: - x_1m_flat = x_1m_flat[:, :needed_features] - - if x_1h_flat.size(1) < needed_features: - x_1h_flat = F.pad(x_1h_flat, (0, needed_features - x_1h_flat.size(1))) - else: - x_1h_flat = x_1h_flat[:, :needed_features] - - if x_1d_flat.size(1) < needed_features: - x_1d_flat = F.pad(x_1d_flat, (0, needed_features - x_1d_flat.size(1))) - else: - x_1d_flat = x_1d_flat[:, :needed_features] - - # Concatenate - cnn_combined = torch.cat([x_1m_flat, x_1h_flat, x_1d_flat], dim=1) - - # Use CNN fusion network to reduce dimension - cnn_features = self.cnn_fusion(cnn_combined) - - # Reshape to match LSTM input shape - cnn_features = cnn_features.view(batch_size, 1, self.hidden_size) - - # Combine with state input by concatenating along sequence dimension - if state.dim() < 3: - # If state is 2D [batch, features], reshape to 3D [batch, 1, features] - state = state.unsqueeze(1) - - # Ensure state has proper dimensions - if state.size(2) != self.state_size: - # If state dimension doesn't match, reshape or pad - if state.size(2) > self.state_size: - state = state[:, :, :self.state_size] - else: - state = F.pad(state, (0, self.state_size - state.size(2))) - - # Concatenate along sequence dimension - combined_input = torch.cat([state, cnn_features], dim=1) - else: - # Use only state input if CNN features not provided - combined_input = state - if combined_input.dim() < 3: - # If state is 2D [batch, features], reshape to 3D [batch, 1, features] - combined_input = combined_input.unsqueeze(1) - - # Ensure state has proper dimensions - if combined_input.size(2) != self.state_size: - # If state dimension doesn't match, reshape or pad - if combined_input.size(2) > self.state_size: - combined_input = combined_input[:, :, :self.state_size] - else: - combined_input = F.pad(combined_input, (0, self.state_size - combined_input.size(2))) - - # Pass through LSTM - lstm_out, _ = self.lstm(combined_input) - - # Apply self-attention to LSTM output - # Transform to shape required by MultiheadAttention (seq_len, batch, hidden) - attn_input = lstm_out.transpose(0, 1) - attn_output, _ = self.attention(attn_input, attn_input, 
attn_input) - - # Transform back to (batch, seq_len, hidden) - attn_output = attn_output.transpose(0, 1) - - # Use last output after attention - attn_out = attn_output[:, -1] - - # Value and advantage streams (dueling architecture) - value = self.value_stream(attn_out) - advantage = self.advantage_stream(attn_out) - - # Combine value and advantage for Q-values - q_values = value + advantage - advantage.mean(dim=1, keepdim=True) - - return q_values - - def forward_realtime(self, x): - """Simplified forward pass for realtime inference""" - # Adapt x to the right format if needed - if isinstance(x, np.ndarray): - x = torch.FloatTensor(x) - - # Add batch dimension if not present - if x.dim() == 1: - x = x.unsqueeze(0) - - # Add sequence dimension if not present - if x.dim() == 2: - x = x.unsqueeze(1) - - # Basic forward pass - lstm_out, _ = self.lstm(x) - - # Apply attention - attn_input = lstm_out.transpose(0, 1) - attn_output, _ = self.attention(attn_input, attn_input, attn_input) - attn_output = attn_output.transpose(0, 1) - - # Get last output after attention - features = attn_output[:, -1] - - # Value and advantage streams - value = self.value_stream(features) - advantage = self.advantage_stream(features) - - # Combine for Q-values - q_values = value + advantage - advantage.mean(dim=1, keepdim=True) - - return q_values - -# Add this class after the CandleCache class - -class RealTimeDataStream: - """ - Class for handling WebSocket API connections for ultra-low latency trading signals. - Provides real-time data streaming at 1-second intervals or faster for immediate trading decisions. - """ - - def __init__(self, exchange, symbol, callback_fn=None): - """ - Initialize the real-time data stream with WebSocket connection - - Args: - exchange: The exchange API client - symbol: Trading pair symbol (e.g. 
'ETH/USDT') - callback_fn: Function to call when new data is received - """ - self.exchange = exchange - self.symbol = symbol - self.callback_fn = callback_fn - self.websocket = None - self.connected = False - self.last_tick_time = None - self.tick_buffer = [] - self.latency_stats = [] - self.logger = logging.getLogger(__name__) - - # Statistics for monitoring performance - self.total_ticks = 0 - self.avg_latency_ms = 0 - self.max_latency_ms = 0 - - # Candle cache for storing processed data - self.candle_cache = CandleCache() - - async def connect(self): - """Connect to the exchange WebSocket API""" - # TODO: Implement actual WebSocket connection logic - self.logger.info(f"Connecting to WebSocket for {self.symbol}...") - try: - # This will be replaced with actual WebSocket connection code - self.websocket = None # Placeholder - self.connected = True - self.logger.info(f"Connected to WebSocket for {self.symbol}") - return True - except Exception as e: - self.logger.error(f"WebSocket connection error: {e}") - return False - - async def subscribe(self): - """Subscribe to relevant data channels""" - # TODO: Implement actual WebSocket subscription logic - self.logger.info(f"Subscribing to {self.symbol} ticks...") - try: - # This will be replaced with actual subscription code - return True - except Exception as e: - self.logger.error(f"WebSocket subscription error: {e}") - return False - - async def process_message(self, message): - """ - Process incoming WebSocket message - - Args: - message: The raw WebSocket message - - Returns: - Processed tick data - """ - # TODO: Implement actual WebSocket message processing logic - try: - # Track tick receipt time for latency calculations - receive_time = time.time() * 1000 # milliseconds - - # This is a placeholder - actual implementation will parse the message - # Example tick data structure (will vary by exchange): - tick_data = { - 'timestamp': receive_time, - 'price': 0.0, # Will be replaced with actual price - 'volume': 0.0, # Will be replaced with actual volume - 'side': 'buy', # or 'sell' - 'exchange_time': 0, # Will be replaced with exchange timestamp - 'latency_ms': 0 # Will be calculated - } - - # Calculate latency (difference between our receive time and exchange time) - if 'exchange_time' in tick_data and tick_data['exchange_time'] > 0: - latency = receive_time - tick_data['exchange_time'] - tick_data['latency_ms'] = latency - - # Update latency statistics - self.latency_stats.append(latency) - if len(self.latency_stats) > 1000: - self.latency_stats = self.latency_stats[-1000:] - - self.total_ticks += 1 - self.avg_latency_ms = sum(self.latency_stats) / len(self.latency_stats) - self.max_latency_ms = max(self.max_latency_ms, latency) - - # Store tick in buffer - self.tick_buffer.append(tick_data) - self.candle_cache.add_tick(tick_data) - self.last_tick_time = datetime.datetime.now() - - # Keep buffer size reasonable - if len(self.tick_buffer) > 1000: - self.tick_buffer = self.tick_buffer[-1000:] - - # Call callback function if provided - if self.callback_fn: - await self.callback_fn(tick_data) - - return tick_data - except Exception as e: - self.logger.error(f"Error processing WebSocket message: {e}") - return None - - def prepare_nn_input(self, model=None, state=None): - """ - Prepare network inputs from tick data for real-time inference - - Args: - model: The neural network model - state: Current state representation - - Returns: - Prepared tensors for model input - """ - # Get the most recent ticks - ticks = 
self.candle_cache.get_ticks(limit=300) - - if not ticks or len(ticks) < 10: - # Not enough ticks for meaningful processing - return None - - try: - # Extract price and volume data from ticks - prices = np.array([t['price'] for t in ticks if 'price' in t]) - volumes = np.array([t['volume'] for t in ticks if 'volume' in t]) - - if len(prices) < 10: - return None - - # Normalize data - min_price, max_price = prices.min(), prices.max() - price_range = max_price - min_price - if price_range == 0: - price_range = 1 - - normalized_prices = (prices - min_price) / price_range - - # Create tick tensor - this is flexible-length data - # Format as sequence for time-series analysis - tick_data = torch.FloatTensor(normalized_prices).unsqueeze(0).unsqueeze(0) - - return { - 'state': state, - 'ticks': tick_data - } - except Exception as e: - self.logger.error(f"Error preparing neural network input: {e}") - return None - - def get_latency_stats(self): - """Get statistics about WebSocket connection latency""" - return { - 'total_ticks': self.total_ticks, - 'avg_latency_ms': self.avg_latency_ms, - 'max_latency_ms': self.max_latency_ms, - 'last_update': self.last_tick_time.isoformat() if self.last_tick_time else None - } - - async def close(self): - """Close the WebSocket connection""" - if self.connected and self.websocket: - try: - # This will be replaced with actual close logic - self.connected = False - self.logger.info(f"Closed WebSocket connection for {self.symbol}") - return True - except Exception as e: - self.logger.error(f"Error closing WebSocket connection: {e}") - return False - -class BacktestCandles(CandleCache): - """ - Special cache for backtesting that retrieves historical data from specific time periods - without contaminating the main cache. Used for running simulations "as if" we were - at a different point in time. - """ - def __init__(self, since_timestamp=None, until_timestamp=None): - """ - Initialize backtesting candle cache. - - Args: - since_timestamp: Start timestamp for backtesting (milliseconds) - until_timestamp: End timestamp for backtesting (milliseconds) - """ - super().__init__() - # Since and until timestamps for backtesting - self.since_timestamp = since_timestamp - self.until_timestamp = until_timestamp - # Flag to indicate this is a backtesting cache - self.is_backtesting = True - # Optional name for backtesting period (e.g., "Day 1 - 24h ago") - self.period_name = None - - async def fetch_historical_timeframe(self, exchange, symbol, timeframe, limit=1000): - """ - Fetch historical data for a specific timeframe and time period. 
- - Args: - exchange: The exchange instance - symbol: Trading pair symbol - timeframe: Candle timeframe - limit: Number of candles to fetch - - Returns: - Dictionary with candle data for the timeframe - """ - try: - logging.info(f"Fetching historical {timeframe} candles for {symbol} " + - f"(since: {self.format_timestamp(self.since_timestamp) if self.since_timestamp else 'None'}, " + - f"until: {self.format_timestamp(self.until_timestamp) if self.until_timestamp else 'None'})") - - candles = await self.fetch_ohlcv_with_timerange(exchange, symbol, timeframe, - limit, self.since_timestamp, self.until_timestamp) - - if candles: - # Store in the appropriate timeframe - self.candles[timeframe] = candles - self.last_updated[timeframe] = datetime.datetime.now() - logging.info(f"Fetched {len(candles)} historical {timeframe} candles for backtesting") - else: - logging.warning(f"No historical {timeframe} candles found for the specified time period") - - return candles - except Exception as e: - logging.error(f"Error fetching historical {timeframe} data: {e}") - return [] - - async def fetch_all_timeframes(self, exchange, symbol): - """ - Fetch historical data for all timeframes. - - Args: - exchange: The exchange instance - symbol: Trading pair symbol - - Returns: - Dictionary with candle data for all timeframes - """ - # Define limits for each timeframe - limits = { - '1m': 1000, - '1h': 500, - '1d': 300 - } - - # Fetch data for each timeframe - for timeframe, limit in limits.items(): - await self.fetch_historical_timeframe(exchange, symbol, timeframe, limit) - - # Return the candles dictionary - return { - '1m': self.get_candles('1m'), - '1h': self.get_candles('1h'), - '1d': self.get_candles('1d') - } - - async def fetch_ohlcv_with_timerange(self, exchange, symbol, timeframe, limit, since=None, until=None): - """ - Fetch OHLCV data within a specific time range. 
- - Args: - exchange: The exchange instance - symbol: Trading pair symbol - timeframe: Candle timeframe - limit: Number of candles to fetch - since: Start timestamp (milliseconds) - until: End timestamp (milliseconds) - - Returns: - List of candle data - """ - max_retries = 3 - retry_delay = 5 - - for attempt in range(max_retries): - try: - logging.info(f"Fetching {limit} {timeframe} candles for {symbol} " + - f"(since: {self.format_timestamp(since) if since else 'None'}, " + - f"until: {self.format_timestamp(until) if until else 'None'}) " + - f"(attempt {attempt+1}/{max_retries})") - - # Check if exchange has fetch_ohlcv method - if not hasattr(exchange, 'fetch_ohlcv'): - logging.error("Exchange does not support OHLCV data fetching") - return [] - - # Fetch OHLCV data from exchange using asyncio if available, otherwise use run_in_executor - try: - if hasattr(exchange, 'has') and exchange.has.get('fetchOHLCVAsync', False): - ohlcv = await exchange.fetchOHLCVAsync(symbol, timeframe, since=since, limit=limit) - else: - # Run in executor to avoid blocking - loop = asyncio.get_event_loop() - ohlcv = await loop.run_in_executor( - None, - lambda: exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit) - ) - except Exception as e: - logging.error(f"Failed to fetch OHLCV data: {e}") - await asyncio.sleep(retry_delay) - continue - - if not ohlcv or len(ohlcv) == 0: - logging.warning(f"No data returned from exchange (attempt {attempt+1}/{max_retries})") - await asyncio.sleep(retry_delay) - continue - - # Filter candles if until timestamp is provided - if until is not None: - ohlcv = [candle for candle in ohlcv if candle[0] <= until] - - # Convert to list of lists format - data = [] - for candle in ohlcv: - timestamp, open_price, high, low, close, volume = candle - data.append([timestamp, open_price, high, low, close, volume]) - - logging.info(f"Successfully fetched {len(data)} historical candles") - return data - - except Exception as e: - logging.error(f"Error fetching historical OHLCV data (attempt {attempt+1}/{max_retries}): {e}") - if attempt < max_retries - 1: - await asyncio.sleep(retry_delay) - - logging.error(f"Failed to fetch historical OHLCV data after {max_retries} attempts") - return [] - - def format_timestamp(self, timestamp): - """Format a timestamp for readable logging""" - if timestamp is None: - return "None" - - try: - dt = datetime.datetime.fromtimestamp(timestamp / 1000.0) - return dt.strftime('%Y-%m-%d %H:%M:%S') - except: - return str(timestamp) - -async def train_with_backtesting(agent, env, symbol="ETH/USDT", - since_timestamp=None, until_timestamp=None, - num_episodes=10, max_steps_per_episode=1000, - period_name=None): - """ - Train agent with backtesting on historical data. 
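
# For reference, one way to derive the since/until millisecond timestamps
# these backtesting helpers take -- a sketch that windows the run "as if"
# it were one day in the past:

import datetime

now_ms = int(datetime.datetime.now().timestamp() * 1000)
day_ms = 24 * 60 * 60 * 1000

backtest_cache = BacktestCandles(
    since_timestamp=now_ms - 2 * day_ms,  # window start: 48h ago
    until_timestamp=now_ms - day_ms,      # window end: 24h ago
)
backtest_cache.period_name = "Day 1 - 24h ago"
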
- - Args: - agent: The agent to train - env: Trading environment - symbol: Trading pair symbol - since_timestamp: Start timestamp for backtesting - until_timestamp: End timestamp for backtesting - num_episodes: Number of episodes to train - max_steps_per_episode: Maximum steps per episode - period_name: Name of the backtest period - - Returns: - Training statistics dictionary - """ - # Create a backtesting candle cache - backtest_cache = BacktestCandles(since_timestamp, until_timestamp) - if period_name: - backtest_cache.period_name = period_name - logging.info(f"Starting backtesting for period: {period_name}") - - # Initialize exchange for data fetching - exchange = None - try: - exchange = await initialize_exchange() - logging.info("Initialized exchange for backtesting") - except Exception as e: - logging.error(f"Failed to initialize exchange: {e}") - return None - - # Initialize statistics tracking - stats = { - 'period': period_name, - 'since_timestamp': since_timestamp, - 'until_timestamp': until_timestamp, - 'episode_rewards': [], - 'episode_lengths': [], - 'balances': [], - 'win_rates': [], - 'episode_pnls': [], - 'cumulative_pnl': [], - 'drawdowns': [], - 'trade_counts': [], - 'loss_values': [], - 'fees': [], - 'net_pnl_after_fees': [] - } - - # Memory management function - def clean_memory(): - """Clean up memory to avoid memory leaks""" - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # Fetch historical data for all timeframes - try: - clean_memory() # Clean memory before fetching data - candle_data = await backtest_cache.fetch_all_timeframes(exchange, symbol) - if not candle_data or not candle_data['1m']: - logging.error(f"No historical data available for backtesting period: {period_name}") - try: - await exchange.close() - except Exception as e: - logging.error(f"Error closing exchange: {e}") - return None - - logging.info(f"Fetched historical data for backtesting: {len(candle_data['1m'])} minute candles") - except Exception as e: - logging.error(f"Failed to fetch historical data for backtesting: {e}") - try: - await exchange.close() - except Exception as exchange_err: - logging.error(f"Error closing exchange: {exchange_err}") - return None - - # Track best models - best_reward = float('-inf') - best_pnl = float('-inf') - best_net_pnl = float('-inf') - - # Make directory for backtesting models if it doesn't exist - os.makedirs('models/backtest', exist_ok=True) - - # Start backtesting training loop - for episode in range(num_episodes): - try: - # Clean memory before starting a new episode - clean_memory() - - # Reset environment - state = env.reset() - episode_reward = 0 - episode_losses = [] - - # Update CNN patterns with historical data - env.update_cnn_patterns(candle_data) - - # Track consecutive errors for circuit breaker - consecutive_errors = 0 - max_consecutive_errors = 5 - - # Episode loop - for step in range(max_steps_per_episode): - try: - # Select action using CNN-enhanced policy - action = agent.select_action(state, training=True, candle_data=candle_data) - - # Take action - next_state, reward, done, info = env.step(action) - - # Store transition in replay memory - agent.memory.push(state, action, reward, next_state, done) - - # Move to the next state - state = next_state - - # Update episode reward - episode_reward += reward - - # Learn from experience - if len(agent.memory) > BATCH_SIZE: - try: - loss = agent.learn() - if loss is not None: - episode_losses.append(loss) - # Reset consecutive errors counter on successful learning - 
consecutive_errors = 0 - except Exception as e: - logging.error(f"Error during learning: {e}") - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors: - logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") - break - - # Update target network periodically - if step % TARGET_UPDATE == 0: - agent.update_target_network() - - # Clean memory periodically during long episodes - if step % 200 == 0 and step > 0: - clean_memory() - - # End episode if done - if done: - break - - except Exception as e: - logging.error(f"Error in training step: {e}") - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors: - logging.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors") - break - - # Calculate statistics - mean_loss = np.mean(episode_losses) if episode_losses else 0 - balance = env.balance - pnl = balance - env.initial_balance - fees = env.total_fees - net_pnl = pnl - fees - win_rate = env.win_rate if hasattr(env, 'win_rate') else 0 - trade_count = env.trade_count if hasattr(env, 'trade_count') else 0 - - # Update epsilon for exploration - epsilon = agent.update_epsilon(episode) - - # Update statistics - stats['episode_rewards'].append(episode_reward) - stats['episode_lengths'].append(step + 1) - stats['balances'].append(balance) - stats['win_rates'].append(win_rate) - stats['episode_pnls'].append(pnl) - stats['drawdowns'].append(env.max_drawdown) - stats['trade_counts'].append(trade_count) - stats['loss_values'].append(mean_loss) - stats['fees'].append(fees) - stats['net_pnl_after_fees'].append(net_pnl) - - # Calculate and update cumulative PnL - if len(stats['episode_pnls']) > 0: - cumulative_pnl = sum(stats['episode_pnls']) - if 'cumulative_pnl' not in stats: - stats['cumulative_pnl'] = [] - stats['cumulative_pnl'].append(cumulative_pnl) - if writer: - writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode) - writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode) - - # Save model if this is the best reward or PnL - if episode_reward > best_reward: - best_reward = episode_reward - model_path = f"models/backtest/{period_name}_best_reward.pt" if period_name else "models/backtest/best_reward.pt" - try: - agent.save(model_path) - logging.info(f"New best reward: {best_reward:.2f}") - except Exception as e: - logging.error(f"Error saving best reward model: {e}") - logging.info(f"New best reward: {best_reward:.2f} (model not saved)") - - if pnl > best_pnl: - best_pnl = pnl - model_path = f"models/backtest/{period_name}_best_pnl.pt" if period_name else "models/backtest/best_pnl.pt" - try: - agent.save(model_path) - logging.info(f"New best PnL: ${best_pnl:.2f}") - except Exception as e: - logging.error(f"Error saving best PnL model: {e}") - logging.info(f"New best PnL: ${best_pnl:.2f} (model not saved)") - - # Save model if this is the best net PnL (after fees) - if net_pnl > best_net_pnl: - best_net_pnl = net_pnl - model_path = f"models/backtest/{period_name}_best_net_pnl.pt" if period_name else "models/backtest/best_net_pnl.pt" - try: - agent.save(model_path) - logging.info(f"New best Net PnL: ${best_net_pnl:.2f}") - except Exception as e: - logging.error(f"Error saving best net PnL model: {e}") - logging.info(f"New best Net PnL: ${best_net_pnl:.2f} (model not saved)") - - # Save checkpoint periodically - if episode % 10 == 0: - try: - if use_compact_save: - compact_save(agent, f'models/trading_agent_checkpoint_{episode}.pt') - else: - 
agent.save(f'models/trading_agent_checkpoint_{episode}.pt') - except Exception as e: - logging.error(f"Error saving checkpoint model: {e}") - - # Update epsilon - agent.update_epsilon(episode) - - # Log training progress - logging.info(f"Episode {episode+1}/{num_episodes} | " + - f"Reward: {episode_reward:.2f} | " + - f"Balance: ${balance:.2f} | " + - f"PnL: ${pnl:.2f} | " + - f"Fees: ${fees:.2f} | " + - f"Net PnL: ${net_pnl:.2f} | " + - f"Win Rate: {win_rate:.2f} | " + - f"Trades: {trade_count} | " + - f"Loss: {mean_loss:.5f} | " + - f"Epsilon: {agent.epsilon:.4f}") - - except Exception as e: - logging.error(f"Error in episode {episode}: {e}") - logging.error(traceback.format_exc()) - continue - - # Clean memory before saving final model - clean_memory() - - # Save final model - if period_name: - try: - agent.save(f"models/backtest/{period_name}_final.pt") - logging.info(f"Saved final model for period: {period_name}") - except Exception as e: - logging.error(f"Error saving final model: {e}") - - # Save backtesting statistics - stats_file = f"backtest_stats_{period_name}.csv" if period_name else "backtest_stats.csv" - try: - with open(stats_file, 'w', newline='') as f: - writer = csv.writer(f) - writer.writerow(['Episode', 'Reward', 'Balance', 'PnL', 'Fees', 'Net PnL', 'Win Rate', 'Trades', 'Loss']) - for i in range(len(stats['episode_rewards'])): - writer.writerow([ - i+1, - stats['episode_rewards'][i], - stats['balances'][i], - stats['episode_pnls'][i], - stats['fees'][i], - stats['net_pnl_after_fees'][i], - stats['win_rates'][i], - stats['trade_counts'][i], - stats['loss_values'][i] - ]) - logging.info(f"Backtesting statistics saved to {stats_file}") - except Exception as e: - logging.error(f"Error saving backtesting statistics: {e}") - - # Close exchange connection - if exchange: - try: - await exchange.close() - logging.info("Exchange connection closed successfully") - except AttributeError: - # Some exchanges don't have a close method - logging.info("Exchange doesn't have a close method, skipping") - except Exception as e: - logging.error(f"Error closing exchange connection: {e}") - - return stats - -# Implement a robust save function to handle PyTorch serialization errors -def robust_save(model, path): - """ - Save a model with multiple fallback approaches to ensure file is saved - even in low disk space conditions. 
- """ - logger.info(f"Saving model to {path}.backup (attempt 1)") - backup_path = f"{path}.backup" - - # Attempt 1: Regular save to backup file - try: - checkpoint = { - 'policy_net': model.policy_net.state_dict(), - 'target_net': model.target_net.state_dict(), - 'optimizer': model.optimizer.state_dict(), - 'epsilon': model.epsilon - } - torch.save(checkpoint, backup_path) - logger.info(f"Successfully saved to {backup_path}") - - # If successful, copy to final path - try: - shutil.copy2(backup_path, path) - logger.info(f"Copied backup to {path}") - logger.info(f"Model saved successfully to {path}") - return True - except Exception as e: - logger.warning(f"Failed to copy backup to main file: {str(e)}") - logger.info(f"Using backup file as the main save") - return True - except Exception as e: - logger.warning(f"First save attempt failed: {str(e)}") - - # Attempt 2: Try with older pickle protocol - logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)") - try: - checkpoint = { - 'policy_net': model.policy_net.state_dict(), - 'target_net': model.target_net.state_dict(), - 'optimizer': model.optimizer.state_dict(), - 'epsilon': model.epsilon - } - torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2) - logger.info(f"Successfully saved to {path} with protocol 2") - return True - except Exception as e: - logger.warning(f"Second save attempt failed: {str(e)}") - - # Attempt 3: Try without optimizer - logger.info(f"Saving model to {path} (attempt 3 - without optimizer)") - try: - checkpoint = { - 'policy_net': model.policy_net.state_dict(), - 'target_net': model.target_net.state_dict(), - 'epsilon': model.epsilon - } - torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2) - logger.info(f"Successfully saved to {path} without optimizer") - return True - except Exception as e: - logger.warning(f"Third save attempt failed: {str(e)}") - - # Attempt 4: Save model structure (as JSON) and parameters separately - logger.info(f"Saving model to {path} (attempt 4 - model structure as JSON)") - try: - # Save only essential model parameters as JSON - model_params = { - 'epsilon': float(model.epsilon), - 'state_size': model.state_size, - 'action_size': model.action_size, - 'hidden_size': model.hidden_size, - 'lstm_layers': model.policy_net.lstm_layers if hasattr(model.policy_net, 'lstm_layers') else 2, - 'attention_heads': model.policy_net.attention_heads if hasattr(model.policy_net, 'attention_heads') else 4 - } - - params_path = f"{path}.params.json" - with open(params_path, 'w') as f: - json.dump(model_params, f) - logger.info(f"Successfully saved model parameters to {params_path}") - - # Now try to save a smaller version of the model without CNN components - # This is a more minimal save for recovery purposes - try: - # Create stripped down checkpoint with minimal components - minimal_checkpoint = { - 'epsilon': model.epsilon, - 'state_size': model.state_size, - 'action_size': model.action_size, - 'hidden_size': model.hidden_size - } - - minimal_path = f"{path}.minimal" - torch.save(minimal_checkpoint, minimal_path, _use_new_zipfile_serialization=False, pickle_protocol=2) - logger.info(f"Successfully saved minimal checkpoint to {minimal_path}") - except Exception as e: - logger.warning(f"Minimal checkpoint save failed: {str(e)}") - - logger.info(f"Model saved successfully to {path}") - return True - except Exception as e: - logger.error(f"All save attempts failed for {path}: {str(e)}") - return False - -def 
-
-def cleanup_model_files(keep_best=True, keep_latest_n=5, aggressive=False):
-    """
-    Delete old model files to free up disk space.
-    
-    Args:
-        keep_best (bool): Whether to keep the best model files (reward, pnl, net_pnl)
-        keep_latest_n (int): Number of latest checkpoint files to keep
-        aggressive (bool): If True, apply more aggressive cleanup in very low disk scenarios
-    """
-    try:
-        logging.info(f"Running model file cleanup: keep_best={keep_best}, keep_latest_n={keep_latest_n}, aggressive={aggressive}")
-        models_dir = "models"
-        
-        # Get all files in the models directory
-        all_files = os.listdir(models_dir)
-        
-        # Files to potentially delete
-        checkpoint_files = []
-        backup_files = []
-        params_files = []
-        dated_files = []
-        
-        # Best files to keep if keep_best is True
-        best_patterns = [
-            "trading_agent_best_reward.pt",
-            "trading_agent_best_pnl.pt",
-            "trading_agent_best_net_pnl.pt",
-            "trading_agent_final.pt"
-        ]
-        
-        # Categorize files for potential deletion
-        for filename in all_files:
-            file_path = os.path.join(models_dir, filename)
-            
-            # Skip directories
-            if os.path.isdir(file_path):
+                logger.warning(f"Error plotting signal: {e}")
                 continue
-            
-            # Skip current best files if keep_best is True
-            if keep_best and any(filename == pattern for pattern in best_patterns):
-                continue
-            
-            # Check for different file types
-            if "checkpoint" in filename and filename.endswith(".pt"):
-                checkpoint_files.append((filename, os.path.getmtime(file_path), file_path))
-            elif filename.endswith(".backup"):
-                backup_files.append((filename, os.path.getmtime(file_path), file_path))
-            elif filename.endswith(".params.json"):
-                params_files.append((filename, os.path.getmtime(file_path), file_path))
-            elif "_2025" in filename or "_2024" in filename:  # Files with date stamps
-                dated_files.append((filename, os.path.getmtime(file_path), file_path))
-        
-        bytes_freed = 0
-        files_deleted = 0
+        # Add balance and PnL annotation
+        if trade_signals and 'balance' in trade_signals[-1] and 'pnl' in trade_signals[-1]:
+            balance = trade_signals[-1]['balance']
+            pnl = trade_signals[-1]['pnl']
+            price_ax.annotate(f"Balance: ${balance:.2f}\nPnL: ${pnl:.2f}",
+                             xy=(0.02, 0.95), xycoords='axes fraction',
+                             bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))
        
-        # Process checkpoint files - keep the newest N
-        if len(checkpoint_files) > keep_latest_n:
-            # Sort by modification time (newest first)
-            checkpoint_files.sort(key=lambda x: x[1], reverse=True)
-            
-            # Keep the newest N files
-            files_to_delete = checkpoint_files[keep_latest_n:]
-            
-            # Delete old checkpoint files
-            for _, _, file_path in files_to_delete:
-                try:
-                    file_size = os.path.getsize(file_path)
-                    os.remove(file_path)
-                    bytes_freed += file_size
-                    files_deleted += 1
-                    logging.info(f"Deleted old checkpoint file: {file_path}")
-                except Exception as e:
-                    logging.error(f"Failed to delete file {file_path}: {str(e)}")
+            # Set title and format
+            price_ax.set_title(title)
+            fig.tight_layout()
        
-        # If aggressive cleanup is enabled, remove more files
-        if aggressive:
-            # Delete all backup files except the newest one
-            if backup_files:
-                backup_files.sort(key=lambda x: x[1], reverse=True)
-                for _, _, file_path in backup_files[1:]:  # Keep only the newest backup
-                    try:
-                        file_size = os.path.getsize(file_path)
-                        os.remove(file_path)
-                        bytes_freed += file_size
-                        files_deleted += 1
-                        logging.info(f"Deleted old backup file: {file_path}")
-                    except Exception as e:
-                        logging.error(f"Failed to delete file {file_path}: {str(e)}")
-            
-            # Delete all dated files (these are typically archived models)
-            for _, _, file_path in dated_files:
-                try:
-                    file_size = os.path.getsize(file_path)
-                    os.remove(file_path)
-                    bytes_freed += file_size
-                    files_deleted += 1
-                    logging.info(f"Deleted dated model file: {file_path}")
-                except Exception as e:
-                    logging.error(f"Failed to delete file {file_path}: {str(e)}")
+            # Convert to image
+            buf = io.BytesIO()
+            fig.savefig(buf, format='png')
+            buf.seek(0)
+            plt.close(fig)
+            img = Image.open(buf)
+            return img
        
-        logging.info(f"Cleanup complete. Deleted {files_deleted} files, freed {bytes_freed / (1024*1024):.2f} MB")
-        
-        # Check available disk space after cleanup
-        try:
-            if platform.system() == 'Windows':
-                free_bytes = ctypes.c_ulonglong(0)
-                ctypes.windll.kernel32.GetDiskFreeSpaceExW(ctypes.c_wchar_p(os.path.abspath(models_dir)), None, None, ctypes.pointer(free_bytes))
-                free_mb = free_bytes.value / (1024 * 1024)
-            else:
-                st = os.statvfs(os.path.abspath(models_dir))
-                free_mb = (st.f_bavail * st.f_frsize) / (1024 * 1024)
-            
-            logging.info(f"Available disk space after cleanup: {free_mb:.2f} MB")
-            
-            # If space is still low, recommend aggressive cleanup
-            if free_mb < 200 and not aggressive:  # Less than 200 MB available
-                logging.warning("Disk space still critically low. Consider using aggressive cleanup.")
-        except Exception as e:
-            logging.error(f"Error checking disk space: {str(e)}")
-    except Exception as e:
-        logging.error(f"Error during file cleanup: {str(e)}")
-        logging.error(traceback.format_exc())
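
The free-space probe being removed above is platform-specific (ctypes on Windows, os.statvfs elsewhere). For reference, shutil.disk_usage has covered both platforms since Python 3.3; a minimal sketch of the same check under that assumption, where free_space_mb is a hypothetical helper:

    import os
    import shutil

    def free_space_mb(directory="models"):
        """Return the free disk space for the given directory, in megabytes."""
        usage = shutil.disk_usage(os.path.abspath(directory))  # (total, used, free)
        return usage.free / (1024 * 1024)

    # e.g. mirror the 200 MB warning threshold used above
    if free_space_mb() < 200:
        print("Disk space critically low; consider aggressive cleanup")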
-
-def compact_save(model, optimizer, reward, epsilon, state_size, action_size, hidden_size, path, use_quantization=False):
-    """
-    Save a model in a compact format suitable for low disk space environments.
-    Includes fallbacks if the primary save method fails.
-    
-    Args:
-        model: The model to save
-        optimizer: The optimizer to save
-        reward: The current reward
-        epsilon: The current epsilon value
-        state_size: The state size
-        action_size: The action size
-        hidden_size: The hidden size
-        path: The path to save to
-        use_quantization: Whether to use quantization to reduce model size
-        
-    Returns:
-        bool: Whether the save was successful
-    """
-    try:
-        # Create minimal checkpoint with essential data only
-        checkpoint = {
-            'model_state_dict': model.state_dict(),
-            'epsilon': epsilon,
-            'state_size': state_size,
-            'action_size': action_size,
-            'hidden_size': hidden_size
-        }
-        
-        # Apply quantization if requested
-        if use_quantization:
-            try:
-                logging.info(f"Attempting quantized save to {path}")
-                # Quantize model to int8
-                quantized_model = torch.quantization.quantize_dynamic(
-                    model,  # the original model
-                    {torch.nn.Linear},  # the set of layers to dynamically quantize
-                    dtype=torch.qint8  # the target dtype for quantized weights
-                )
-                
-                # Create quantized checkpoint
-                quantized_checkpoint = {
-                    'model_state_dict': quantized_model.state_dict(),
-                    'epsilon': epsilon,
-                    'state_size': state_size,
-                    'action_size': action_size,
-                    'hidden_size': hidden_size,
-                    'is_quantized': True
-                }
-                
-                # Save with the older pickle protocol and new zipfile serialization disabled
-                torch.save(quantized_checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2)
-                logging.info(f"Quantized compact save successful to {path}")
-                return True
-            except Exception as e:
-                logging.warning(f"Quantized save failed, falling back to regular save: {str(e)}")
-                # Fall back to the regular save if quantization fails
-        
-        # Regular save with the older pickle protocol and no zipfile serialization
-        torch.save(checkpoint, path, _use_new_zipfile_serialization=False, pickle_protocol=2)
-        logging.info(f"Compact save successful to {path}")
-        return True
-    except Exception as e:
-        logging.error(f"Compact save failed: {str(e)}")
-        logging.error(traceback.format_exc())
-        
-        # Fallback: save just the parameters as JSON if we can't save the full model
-        try:
-            params = {
-                'epsilon': epsilon,
-                'state_size': state_size,
-                'action_size': action_size,
-                'hidden_size': hidden_size
-            }
-            json_path = f"{path}.params.json"
-            with open(json_path, 'w') as f:
-                json.dump(params, f)
-            logging.info(f"Saved minimal parameters to {json_path}")
-            return False
-        except Exception as json_e:
-            logging.error(f"JSON parameter save failed: {str(json_e)}")
-            return False
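
One caveat the quantized branch implies: a state_dict taken from a dynamically quantized model only loads back into a model that has been quantized the same way. A minimal load-side sketch, assuming the checkpoint keys above; load_compact_checkpoint and the build_model factory are hypothetical, not part of this codebase:

    import torch

    def load_compact_checkpoint(path, build_model):
        """Load a compact_save checkpoint; build_model returns the float network."""
        checkpoint = torch.load(path, map_location='cpu')
        model = build_model(checkpoint['state_size'],
                            checkpoint['action_size'],
                            checkpoint['hidden_size'])
        if checkpoint.get('is_quantized'):
            # Re-apply the same dynamic quantization before loading the weights
            model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8)
        model.load_state_dict(checkpoint['model_state_dict'])
        return model, checkpoint['epsilon']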
logging.info(f"Compact save successful to {path}") - return True - except Exception as e: - logging.error(f"Compact save failed: {str(e)}") - logging.error(traceback.format_exc()) - - # Fallback: Save just the parameters as JSON if we can't save the full model - try: - params = { - 'epsilon': epsilon, - 'state_size': state_size, - 'action_size': action_size, - 'hidden_size': hidden_size - } - json_path = f"{path}.params.json" - with open(json_path, 'w') as f: - json.dump(params, f) - logging.info(f"Saved minimal parameters to {json_path}") - return False - except Exception as json_e: - logging.error(f"JSON parameter save failed: {str(json_e)}") - return False + logger.error(f"Error creating chart: {str(e)}") + return None if __name__ == "__main__": - # Parse command line arguments - parser = argparse.ArgumentParser(description='Trading Bot') - parser.add_argument('--mode', type=str, default='train', help='Mode: train, test, live') - parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train') - parser.add_argument('--max_steps', type=int, default=1000, help='Maximum steps per episode') - parser.add_argument('--update_interval', type=int, default=10, help='Target network update interval') - parser.add_argument('--training_iterations', type=int, default=10, help='Number of training iterations per step') - parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol') - parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for candlestick data') - parser.add_argument('--compact_save', action='store_true', help='Use compact save to reduce disk usage') - parser.add_argument('--use_quantization', action='store_true', help='Use model quantization for even smaller file sizes') - parser.add_argument('--cleanup', action='store_true', help='Clean up old model files before training') - parser.add_argument('--aggressive_cleanup', action='store_true', help='Perform aggressive cleanup to free more space') - parser.add_argument('--keep_latest', type=int, default=5, help='Number of latest checkpoint files to keep when cleaning up') - - args = parser.parse_args() - - # Import platform and ctypes for disk space checking - import platform - import ctypes - - # Run cleanup if requested - if args.cleanup: - cleanup_model_files(keep_best=True, keep_latest_n=args.keep_latest, aggressive=args.aggressive_cleanup) - try: asyncio.run(main()) except KeyboardInterrupt: