diff --git a/crypto/gogo2/main.py b/crypto/gogo2/main.py
index 354e5dd..5f0c0a5 100644
--- a/crypto/gogo2/main.py
+++ b/crypto/gogo2/main.py
@@ -31,6 +31,7 @@ import matplotlib.gridspec as gridspec
 import datetime
 from datetime import datetime as dt
 from collections import defaultdict
+from gym.spaces import Discrete, Box

 # Configure logging
 logging.basicConfig(
@@ -267,70 +268,253 @@ class PricePredictionModel(nn.Module):
         return total_loss / epochs

 class TradingEnvironment:
-    def __init__(self, initial_balance=INITIAL_BALANCE, window_size=30, demo=True):
+    def __init__(self, data=None, features=None, feature_extractors=None, initial_balance=10000, leverage=50,
+                 window_size=100, commission=0.0004, api_key=None, api_secret=None, exchange_id='binance',
+                 symbol='ETH/USDT', timeframe='1m', init_length=5000, max_steps=10000):
         """Initialize the trading environment"""
+        self.api_key = api_key
+        self.api_secret = api_secret
+        self.exchange_id = exchange_id
+        self.symbol = symbol
+        self.timeframe = timeframe
+        self.init_length = init_length
+
+        # TODO: For 1s/tick timeframes, implement WebSocket API integration for real-time data
+
+        try:
+            # Initialize exchange if API credentials are provided
+            if api_key and api_secret:
+                self.exchange = initialize_exchange(exchange_id, api_key, api_secret)
+                logger.info(f"Exchange initialized: {exchange_id}")
+                # Fetch historical data
+                self.data = fetch_candles(self.exchange, self.symbol, self.timeframe, limit=self.init_length)
+                if not self.data:
+                    raise ValueError(f"No data fetched for {self.symbol} on {self.exchange_id}")
+                self.data_format_is_list = isinstance(self.data[0], list)
+                logger.info(f"Loaded {len(self.data)} candles from exchange")
+            elif data is not None:  # Use provided data
+                self.data = data
+                self.data_format_is_list = isinstance(self.data[0], list)
+                logger.info(f"Using provided data with {len(self.data)} candles")
+            else:
+                # Initialize with empty data, we'll load it later with fetch_initial_data
+                logger.warning("No data provided, initializing with empty data")
+                self.data = []
+                self.data_format_is_list = True
+        except Exception as e:
+            logger.error(f"Error initializing environment: {e}")
+            raise
+
+        # Initialize features and feature extractors
+        if features is not None:
+            self.features = features
+            # Create a dictionary of features
+            self.features_dict = {f"feature_{i}": feature for i, feature in enumerate(features)}
+        else:
+            # Initialize features as a dictionary, not a list
+            self.features = {
+                'price': [],
+                'volume': [],
+                'rsi': [],
+                'macd': [],
+                'macd_signal': [],
+                'macd_hist': [],
+                'bollinger_upper': [],
+                'bollinger_mid': [],
+                'bollinger_lower': [],
+                'stoch_k': [],
+                'stoch_d': [],
+                'ema_9': [],
+                'ema_21': [],
+                'atr': []
+            }
+            self.features_dict = {}
+
+        if feature_extractors is None:
+            feature_extractors = []
+        self.feature_extractors = feature_extractors
+
+        # Environment parameters
         self.initial_balance = initial_balance
         self.balance = initial_balance
-        self.window_size = window_size
-        self.demo = demo
-        self.data = []
+        self.leverage = leverage
         self.position = 'flat'  # 'flat', 'long', or 'short'
         self.position_size = 0
         self.entry_price = 0
         self.entry_index = 0
         self.stop_loss = 0
         self.take_profit = 0
+        self.commission = commission
+        self.total_pnl = 0
+        self.total_fees = 0.0  # Track total fees paid
         self.trades = []
+        self.trade_signals = []
+        self.current_step = 0
+        self.window_size = window_size
+        self.max_steps = max_steps
+        self.peak_balance = initial_balance
+        self.max_drawdown = 0
+        self.current_price = 0
         self.win_count = 0
         self.loss_count = 0
-        self.total_pnl = 0.0
-        self.episode_pnl = 0.0
-        self.peak_balance = initial_balance
-        self.max_drawdown = 0.0
-        self.current_step = 0
-        self.current_price = 0
+        self.min_position_size = 100  # Minimum position size in USD

-        # For tracking signals for visualization
-        self.trade_signals = []
+        # Track candle patterns and reversal points
+        self.patterns = {}
+        self.reversal_points = []

-        # Initialize features
-        self.features = {
-            'price': [],
-            'volume': [],
-            'rsi': [],
-            'macd': [],
-            'macd_signal': [],
-            'macd_hist': [],
-            'bollinger_upper': [],
-            'bollinger_mid': [],
-            'bollinger_lower': [],
-            'stoch_k': [],
-            'stoch_d': [],
-            'ema_9': [],
-            'ema_21': [],
-            'atr': []
-        }
+        # Define observation and action spaces
+        num_features = len(self.features) if hasattr(self, 'features') and self.features else 0
+        state_dim = window_size * 5 + 5 + num_features  # OHLCV + position info + features

-        # Initialize price predictor
-        self.price_predictor = None
-        self.predicted_prices = np.array([])
+        self.action_space = Discrete(4)  # 0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE
+        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32)

-        # Initialize optimal trade tracking
-        self.optimal_bottoms = []
-        self.optimal_tops = []
-        self.optimal_signals = np.array([])
+        # Check if we have enough data
+        if len(self.data) < self.window_size:
+            logger.warning(f"Data length {len(self.data)} is less than window size {self.window_size}")
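The constructor now declares gym spaces. For reference, a minimal sketch of how `state_dim` works out, assuming the 14 indicator lists initialized above are the per-step features (the numbers are illustrative, not from the patch):

```python
# Sketch: dimensioning of the new observation/action spaces.
import numpy as np
from gym.spaces import Discrete, Box

window_size = 100
num_features = 14                                 # the 14 indicator keys above
state_dim = window_size * 5 + 5 + num_features    # OHLCV window + position info + indicators

action_space = Discrete(4)                        # 0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE
observation_space = Box(low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32)

assert observation_space.shape == (519,)          # 100 * 5 + 5 + 14
```

With the new defaults (window_size=100), the observation is a flat 519-float vector, which is exactly what the Box shape encodes.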
+
+    def calculate_reward(self, action):
+        """Calculate reward based on the action taken"""
+        reward = 0

-        # Add these new attributes
-        self.leverage = MAX_LEVERAGE
-        self.futures_symbol = "ETH_USDT"  # Example futures symbol
-        self.position_mode = "hedge"  # For simultaneous long/short positions
-        self.margin_mode = "cross"  # Cross margin mode
+        # Base reward structure
+        if self.position == 'flat':
+            if action == 0:  # HOLD when flat
+                reward = 0.01  # Small reward for holding when no position
+            elif action == 1:  # BUY/LONG
+                # Check for buy signal in CNN patterns
+                if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
+                    buy_confidence = self.cnn_patterns['long_confidence']
+                    # Scale by confidence
+                    reward = 0.1 * buy_confidence * 10
+                else:
+                    reward = 0.1  # Default reward for taking a position
+
+                # Apply fee penalty
+                if self.position_size > 0:
+                    fee = (self.position_size / 1900) * 1
+                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
+                    reward -= fee_penalty
+            elif action == 2:  # SELL/SHORT
+                # Check for sell signal in CNN patterns
+                if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
+                    sell_confidence = self.cnn_patterns['short_confidence']
+                    # Scale by confidence
+                    reward = 0.1 * sell_confidence * 10
+                else:
+                    reward = 0.1  # Default reward for taking a position
+
+                # Apply fee penalty
+                if self.position_size > 0:
+                    fee = (self.position_size / 1900) * 1
+                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
+                    reward -= fee_penalty
+            elif action == 3:  # CLOSE when no position
+                reward = -0.1  # Penalty for trying to close no position

-        # Initialize data format indicator (list or dict)
-        self.data_format_is_list = True
+        elif self.position == 'long':
+            if action == 0:  # HOLD long position
+                # Calculate price change since entry
+                price_change = (self.current_price - self.entry_price) / self.entry_price
+
+                # Reward or penalize based on price movement
+                if price_change > 0:
+                    reward = price_change * 10  # Reward for holding profitable position
+                else:
+                    reward = price_change * 5  # Smaller penalty for holding losing position
+
+            elif action == 1:  # BUY when already long
+                reward = -0.1  # Penalty for redundant action
+
+            elif action == 2:  # SELL when long (reversal)
+                # Calculate PnL
+                pnl_percent = (self.current_price - self.entry_price) / self.entry_price
+
+                if pnl_percent > 0:
+                    reward = -0.5  # Penalty for closing profitable long position to go short
+                else:
+                    # Check for sell signal in CNN patterns
+                    if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
+                        sell_confidence = self.cnn_patterns['short_confidence']
+                        reward = 0.2 * sell_confidence * 10  # Reward for correct reversal
+                    else:
+                        reward = 0.2  # Default reward for cutting loss
+
+                # Apply fee penalty
+                if self.position_size > 0:
+                    fee = (self.position_size / 1900) * 1
+                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
+                    reward -= fee_penalty
+
+            elif action == 3:  # CLOSE long position
+                # Calculate PnL
+                pnl_percent = (self.current_price - self.entry_price) / self.entry_price
+
+                if pnl_percent > 0:
+                    reward = pnl_percent * 15  # Higher reward for taking profit
+                else:
+                    reward = pnl_percent * 5  # Smaller penalty for cutting loss
+
+                # Apply fee penalty
+                if self.position_size > 0:
+                    fee = (self.position_size / 1900) * 1
+                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
+                    reward -= fee_penalty
+        elif self.position == 'short':
+            if action == 0:  # HOLD short position
+                # Calculate price change since entry
+                price_change = (self.entry_price - self.current_price) / self.entry_price
+
+                # Reward or penalize based on price movement
+                if price_change > 0:
+                    reward = price_change * 10  # Reward for holding profitable position
+                else:
+                    reward = price_change * 5  # Smaller penalty for holding losing position
+
+            elif action == 1:  # BUY when short (reversal)
+                # Calculate PnL
+                pnl_percent = (self.entry_price - self.current_price) / self.entry_price
+
+                if pnl_percent > 0:
+                    reward = -0.5  # Penalty for closing profitable short position to go long
+                else:
+                    # Check for buy signal in CNN patterns
+                    if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
+                        buy_confidence = self.cnn_patterns['long_confidence']
+                        reward = 0.2 * buy_confidence * 10  # Reward for correct reversal
+                    else:
+                        reward = 0.2  # Default reward for cutting loss
+
+                # Apply fee penalty
+                if self.position_size > 0:
+                    fee = (self.position_size / 1900) * 1
+                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
+                    reward -= fee_penalty
+
+            elif action == 2:  # SELL when already short
+                reward = -0.1  # Penalty for redundant action
+
+            elif action == 3:  # CLOSE short position
+                # Calculate PnL
+                pnl_percent = (self.entry_price - self.current_price) / self.entry_price
+
+                if pnl_percent > 0:
+                    reward = pnl_percent * 15  # Higher reward for taking profit
+                else:
+                    reward = pnl_percent * 5  # Smaller penalty for cutting loss
+
+                # Apply fee penalty
+                if self.position_size > 0:
+                    fee = (self.position_size / 1900) * 1
+                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
+                    reward -= fee_penalty
+
+        return reward
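The same fee penalty appears in every branch of `calculate_reward`. A worked example of that arithmetic, mirroring the patch's formula:

```python
# Worked example of the fee penalty used throughout calculate_reward.
def fee_penalty(position_size_usd: float) -> float:
    fee = (position_size_usd / 1900) * 1   # 1 USD of fees per $1.9k of position
    return min(0.05, fee / 100)            # scaled down, capped at 0.05

print(fee_penalty(1900))    # 0.01 -> a $1 fee costs 0.01 reward
print(fee_penalty(19000))   # 0.05 -> the penalty saturates for large positions
```

So a $1 fee translates to a 0.01 reward penalty, and the cap keeps fees from dominating the reward signal on large positions.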
     def reset(self):
-        """Reset the environment to initial state"""
+        """Reset the environment to its initial state and return the initial observation"""
         self.balance = self.initial_balance
         self.position = 'flat'
         self.position_size = 0
@@ -338,24 +522,15 @@ class TradingEnvironment:
         self.entry_index = 0
         self.stop_loss = 0
         self.take_profit = 0
+
         self.current_step = 0
         self.trades = []
-        self.win_count = 0
-        self.loss_count = 0
-        self.episode_pnl = 0.0
+        self.trade_signals = []
+        self.total_pnl = 0.0
+        self.total_fees = 0.0
         self.peak_balance = self.initial_balance
         self.max_drawdown = 0.0
-        self.current_step = 0
-
-        # Keep data but reset current position
-        if len(self.data) > self.window_size:
-            self.current_step = self.window_size
-            if self.data_format_is_list:
-                self.current_price = self.data[self.current_step][4]  # Close price is at index 4
-            else:
-                self.current_price = self.data[self.current_step]['close']
-
-        # Reset trade signals
-        self.trade_signals = []
+        self.win_count = 0
+        self.loss_count = 0

         return self.get_state()
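Since `reset()` now returns the initial observation, the environment drives a standard gym-style loop. A sketch of the intended usage; `candles` and `agent` are hypothetical placeholders, not names from the patch:

```python
# Sketch of the reset/step cycle for this environment.
env = TradingEnvironment(data=candles, window_size=100, leverage=50)
state = env.reset()                                # initial observation
done = False
while not done:
    action = agent.select_action(state)            # hypothetical agent API
    state, reward, done, info = env.step(action)   # actions 0..3: HOLD/BUY/SELL/CLOSE
print(info['balance'], info['net_pnl'])
```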
@@ -492,6 +667,227 @@ class TradingEnvironment:
         # Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
         reward = self.calculate_reward(action)

+        # Execute the action
+        initial_balance = self.balance  # Store initial balance to calculate PnL
+
+        # Open long position
+        if action == 1 and self.position != 'long':
+            if self.position == 'short':
+                # Close short position first
+                if self.position_size > 0:
+                    # Calculate PnL
+                    pnl_percent = (self.entry_price - self.current_price) / self.entry_price
+                    pnl_dollar = pnl_percent * self.position_size * self.leverage
+
+                    # Update balance and record trade
+                    self.balance += pnl_dollar
+                    self.total_pnl += pnl_dollar
+
+                    # Apply trading fee (1 USD per 1.9k position)
+                    fee = (self.position_size / 1900) * 1
+                    self.balance -= fee
+                    self.total_fees += fee
+
+                    # Record trade
+                    trade_duration = self.current_step - self.entry_index
+                    if self.data_format_is_list:
+                        timestamp = self.data[self.current_step][0]  # Timestamp
+                    else:
+                        timestamp = self.data[self.current_step]['timestamp']
+
+                    self.trades.append({
+                        'type': 'short',
+                        'entry': self.entry_price,
+                        'exit': self.current_price,
+                        'pnl_percent': pnl_percent,
+                        'pnl_dollar': pnl_dollar,
+                        'fee': fee,
+                        'net_pnl': pnl_dollar - fee,
+                        'duration': trade_duration,
+                        'timestamp': timestamp,
+                        'reason': 'action_change'
+                    })
+
+                    # Update win/loss count
+                    if pnl_dollar > 0:
+                        self.win_count += 1
+                    else:
+                        self.loss_count += 1
+
+            # Now open long position
+            self.position = 'long'
+            self.entry_price = self.current_price
+            self.entry_index = self.current_step
+
+            # Calculate position size with risk management
+            self.position_size = self.calculate_position_size()
+
+            # Apply trading fee (1 USD per 1.9k position)
+            fee = (self.position_size / 1900) * 1
+            self.balance -= fee
+            self.total_fees += fee
+
+            # Set stop loss and take profit
+            sl_percent = 0.02  # 2% stop loss
+            tp_percent = 0.04  # 4% take profit
+
+            self.stop_loss = self.entry_price * (1 - sl_percent)
+            self.take_profit = self.entry_price * (1 + tp_percent)
+
+        # Open short position
+        elif action == 2 and self.position != 'short':
+            if self.position == 'long':
+                # Close long position first
+                if self.position_size > 0:
+                    # Calculate PnL
+                    pnl_percent = (self.current_price - self.entry_price) / self.entry_price
+                    pnl_dollar = pnl_percent * self.position_size * self.leverage
+
+                    # Update balance and record trade
+                    self.balance += pnl_dollar
+                    self.total_pnl += pnl_dollar
+
+                    # Apply trading fee (1 USD per 1.9k position)
+                    fee = (self.position_size / 1900) * 1
+                    self.balance -= fee
+                    self.total_fees += fee
+
+                    # Record trade
+                    trade_duration = self.current_step - self.entry_index
+                    if self.data_format_is_list:
+                        timestamp = self.data[self.current_step][0]  # Timestamp
+                    else:
+                        timestamp = self.data[self.current_step]['timestamp']
+
+                    self.trades.append({
+                        'type': 'long',
+                        'entry': self.entry_price,
+                        'exit': self.current_price,
+                        'pnl_percent': pnl_percent,
+                        'pnl_dollar': pnl_dollar,
+                        'fee': fee,
+                        'net_pnl': pnl_dollar - fee,
+                        'duration': trade_duration,
+                        'timestamp': timestamp,
+                        'reason': 'action_change'
+                    })
+
+                    # Update win/loss count
+                    if pnl_dollar > 0:
+                        self.win_count += 1
+                    else:
+                        self.loss_count += 1
+
+            # Now open short position
+            self.position = 'short'
+            self.entry_price = self.current_price
+            self.entry_index = self.current_step
+
+            # Calculate position size with risk management
+            self.position_size = self.calculate_position_size()
+
+            # Apply trading fee (1 USD per 1.9k position)
+            fee = (self.position_size / 1900) * 1
+            self.balance -= fee
+            self.total_fees += fee
+
+            # Set stop loss and take profit
+            sl_percent = 0.02  # 2% stop loss
+            tp_percent = 0.04  # 4% take profit
+
+            self.stop_loss = self.entry_price * (1 + sl_percent)
+            self.take_profit = self.entry_price * (1 - tp_percent)
+
+        # Close position
+        elif action == 3 and self.position != 'flat':
+            if self.position == 'long':
+                # Calculate PnL
+                pnl_percent = (self.current_price - self.entry_price) / self.entry_price
+                pnl_dollar = pnl_percent * self.position_size * self.leverage
+
+                # Update balance and record trade
+                self.balance += pnl_dollar
+                self.total_pnl += pnl_dollar
+
+                # Apply trading fee (1 USD per 1.9k position)
+                fee = (self.position_size / 1900) * 1
+                self.balance -= fee
+                self.total_fees += fee
+
+                # Record trade
+                trade_duration = self.current_step - self.entry_index
+                if self.data_format_is_list:
+                    timestamp = self.data[self.current_step][0]  # Timestamp
+                else:
+                    timestamp = self.data[self.current_step]['timestamp']
+
+                self.trades.append({
+                    'type': 'long',
+                    'entry': self.entry_price,
+                    'exit': self.current_price,
+                    'pnl_percent': pnl_percent,
+                    'pnl_dollar': pnl_dollar,
+                    'fee': fee,
+                    'net_pnl': pnl_dollar - fee,
+                    'duration': trade_duration,
+                    'timestamp': timestamp,
+                    'reason': 'close_action'
+                })
+
+                # Update win/loss count
+                if pnl_dollar > 0:
+                    self.win_count += 1
+                else:
+                    self.loss_count += 1
+
+            elif self.position == 'short':
+                # Calculate PnL
+                pnl_percent = (self.entry_price - self.current_price) / self.entry_price
+                pnl_dollar = pnl_percent * self.position_size * self.leverage
+
+                # Update balance and record trade
+                self.balance += pnl_dollar
+                self.total_pnl += pnl_dollar
+
+                # Apply trading fee (1 USD per 1.9k position)
+                fee = (self.position_size / 1900) * 1
+                self.balance -= fee
+                self.total_fees += fee
+
+                # Record trade
+                trade_duration = self.current_step - self.entry_index
+                if self.data_format_is_list:
+                    timestamp = self.data[self.current_step][0]  # Timestamp
+                else:
+                    timestamp = self.data[self.current_step]['timestamp']
+
+                self.trades.append({
+                    'type': 'short',
+                    'entry': self.entry_price,
+                    'exit': self.current_price,
+                    'pnl_percent': pnl_percent,
+                    'pnl_dollar': pnl_dollar,
+                    'fee': fee,
+                    'net_pnl': pnl_dollar - fee,
+                    'duration': trade_duration,
+                    'timestamp': timestamp,
+                    'reason': 'close_action'
+                })
+
+                # Update win/loss count
+                if pnl_dollar > 0:
+                    self.win_count += 1
+                else:
+                    self.loss_count += 1
+
+            # Reset position
+            self.position = 'flat'
+            self.position_size = 0
+            self.entry_price = 0
+            self.entry_index = 0
+            self.stop_loss = 0
+            self.take_profit = 0
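All three branches above share the same PnL bookkeeping. A worked example with illustrative numbers:

```python
# Worked example of the trade-closing arithmetic in step().
entry_price, exit_price = 2000.0, 2020.0
position_size, leverage = 1900.0, 50

pnl_percent = (exit_price - entry_price) / entry_price   # 0.01 for a long
pnl_dollar = pnl_percent * position_size * leverage      # 950.0
fee = (position_size / 1900) * 1                         # 1.0 USD
net_pnl = pnl_dollar - fee                               # 949.0
```

Note that at 50x leverage the 2% stop loss corresponds to 100% of the position size in dollar terms (0.02 * 50 = 1.0), so the SL/TP percentages are doing a lot of risk-management work here.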
         # Record trade signal for visualization
         if action > 0:  # If not HOLD
             signal_type = None
@@ -529,13 +925,22 @@
         # Get new state
         next_state = self.get_state()

+        # Update peak balance and drawdown
+        if self.balance > self.peak_balance:
+            self.peak_balance = self.balance
+
+        current_drawdown = (self.peak_balance - self.balance) / self.peak_balance if self.peak_balance > 0 else 0
+        self.max_drawdown = max(self.max_drawdown, current_drawdown)
+
         # Create info dictionary
         info = {
             'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
             'price': self.current_price,
             'balance': self.balance,
             'position': self.position,
-            'pnl': self.total_pnl
+            'pnl': self.total_pnl,
+            'fees': self.total_fees,
+            'net_pnl': self.total_pnl - self.total_fees
         }

         return next_state, reward, done, info
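The drawdown tracking added to `step()` is a running peak-to-trough calculation. A worked example with an illustrative balance path:

```python
# Worked example of the running drawdown bookkeeping.
balances = [10000, 11000, 9900]
peak, max_dd = 0.0, 0.0
for b in balances:
    peak = max(peak, b)                      # track the high-water mark
    dd = (peak - b) / peak if peak > 0 else 0
    max_dd = max(max_dd, dd)
print(peak, max_dd)   # 11000, 0.1 -> a 10% drawdown from the peak
```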
@@ -999,6 +1404,33 @@ class TradingEnvironment:
             else:
                 # Small negative reward for holding in the wrong direction
                 reward -= 0.1
+        elif action == 1 or action == 2:  # BUY or SELL
+            # Apply trading fee as negative reward (1 USD per 1.9k position size)
+            position_size = self.calculate_position_size()
+            fee = (position_size / 1900) * 1  # Trading fee in USD
+
+            # Penalty for fee
+            fee_penalty = fee / 10  # Scale down to make it a reasonable penalty
+            reward -= fee_penalty
+
+            # Logging
+            if hasattr(self, 'total_fees'):
+                self.total_fees += fee
+            else:
+                self.total_fees = fee
+        elif action == 3:  # CLOSE
+            # Apply trading fee as negative reward (1 USD per 1.9k position size)
+            fee = (self.position_size / 1900) * 1  # Trading fee in USD
+
+            # Penalty for fee
+            fee_penalty = fee / 10  # Scale down to make it a reasonable penalty
+            reward -= fee_penalty
+
+            # Logging
+            if hasattr(self, 'total_fees'):
+                self.total_fees += fee
+            else:
+                self.total_fees = fee

         # Add CNN pattern confidence to reward
         reward += pattern_confidence * 10
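Both new branches initialize `total_fees` with a `hasattr` check before accumulating. An equivalent, more compact pattern (a sketch, not part of the patch):

```python
# getattr with a default collapses the four-line hasattr branch.
class FeeTracker:
    def add_fee(self, fee: float) -> None:
        # Initialize on first use, then accumulate.
        self.total_fees = getattr(self, 'total_fees', 0.0) + fee

t = FeeTracker()
t.add_fee(1.0)
t.add_fee(0.5)
print(t.total_fees)  # 1.5
```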
@@ -1360,6 +1792,74 @@ class TradingEnvironment:

         return False

+    def add_chart_to_tensorboard(self, writer, step, title='Trading Chart'):
+        """
+        Add a candlestick chart and metrics to TensorBoard
+
+        Parameters:
+        - writer: TensorBoard writer
+        - step: Current step
+        - title: Title for the chart
+        """
+        try:
+            # Initialize writer if not provided
+            if writer is None:
+                from torch.utils.tensorboard import SummaryWriter
+                writer = SummaryWriter()
+
+            # Log basic metrics
+            writer.add_scalar('Balance', self.balance, step)
+            writer.add_scalar('Total_PnL', self.total_pnl, step)
+
+            # Log total fees if available
+            if hasattr(self, 'total_fees'):
+                writer.add_scalar('Total_Fees', self.total_fees, step)
+                writer.add_scalar('Net_PnL', self.total_pnl - self.total_fees, step)
+
+            # Log position info
+            writer.add_scalar('Position_Size', self.position_size, step)
+
+            # Log drawdown and win rate
+            writer.add_scalar('Max_Drawdown', self.max_drawdown, step)
+
+            win_rate = self.win_count / (self.win_count + self.loss_count) if (self.win_count + self.loss_count) > 0 else 0
+            writer.add_scalar('Win_Rate', win_rate, step)
+
+            # Log trade count
+            writer.add_scalar('Trade_Count', len(self.trades), step)
+
+            # Check if we have enough data for candlestick chart
+            if len(self.data) <= 0:
+                logger.warning("No data available for candlestick chart")
+                return
+
+            # Create figure for candlestick chart (last 100 data points)
+            start_idx = max(0, self.current_step - 100)
+            end_idx = self.current_step
+
+            # Get recent trades for visualization (last 10 trades)
+            recent_trades = self.trades[-10:] if self.trades else []
+
+            try:
+                fig = create_candlestick_figure(
+                    self.data[start_idx:end_idx+1],
+                    title=title,
+                    trades=recent_trades
+                )
+
+                # Add figure to TensorBoard
+                writer.add_figure('Candlestick_Chart', fig, step)
+
+                # Close figure to free memory
+                plt.close(fig)
+
+            except Exception as e:
+                logger.error(f"Error creating candlestick chart: {e}")
+
+        except Exception as e:
+            logger.error(f"Error adding chart to TensorBoard: {e}")
+            # Continue execution even if chart fails
+
 # Ensure GPU usage if available
 def get_device():
     """Get the best available device (CUDA GPU or CPU)"""
@@ -1752,7 +2252,7 @@ class Agent:
             raise

     def add_chart_to_tensorboard(self, env, step):
-        """Add candlestick chart to tensorboard"""
+        """Add candlestick chart to tensorboard and various metrics"""
         try:
             # Initialize writer if it doesn't exist
             if not hasattr(self, 'writer') or self.writer is None:
@@ -1790,23 +2290,38 @@ class Agent:
             if hasattr(env, 'trade_count'):
                 self.writer.add_scalar('Trading/Trade_Count', env.trade_count, step)
-
-            # Get recent trades if available
-            recent_trades = []
-            if hasattr(env, 'trades') and env.trades:
-                recent_trades = env.trades[-10:]  # Last 10 trades

-            # Create candlestick figure with the last 100 candles and recent trades
-            fig = create_candlestick_figure(env.data[-100:], recent_trades)
-
-            # Add figure to tensorboard
-            self.writer.add_figure('Trading/Chart', fig, step)
-
-            # Close figure to free resources
-            plt.close(fig)
-
+            # Log trading fees
+            if hasattr(env, 'total_fees'):
+                self.writer.add_scalar('Trading/Total_Fees', env.total_fees, step)
+                # Also log net PnL (after fees)
+                if hasattr(env, 'total_pnl'):
+                    self.writer.add_scalar('Trading/Net_PnL_After_Fees', env.total_pnl - env.total_fees, step)
+
+            # Add candlestick chart if we have enough data
+            if len(env.data) >= 100:
+                try:
+                    # Use the last 100 candles for the chart
+                    recent_data = env.data[-100:]
+
+                    # Get recent trades if available
+                    recent_trades = None
+                    if hasattr(env, 'trades') and len(env.trades) > 0:
+                        recent_trades = env.trades[-10:]  # Last 10 trades
+
+                    # Create candlestick figure
+                    fig = create_candlestick_figure(recent_data, recent_trades, f"Trading Chart - Step {step}")
+
+                    if fig:
+                        # Add to tensorboard
+                        self.writer.add_figure('Trading/Chart', fig, step)
+
+                        # Close figure to free memory
+                        plt.close(fig)
+                except Exception as e:
+                    logger.warning(f"Error creating candlestick chart: {e}")
         except Exception as e:
-            logger.warning(f"Error adding chart to tensorboard: {e}")
+            logger.error(f"Error in add_chart_to_tensorboard: {e}")

 async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
     """Get live price data using websockets"""
@@ -1888,12 +2403,15 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
         'cumulative_pnl': [],
         'drawdowns': [],
         'trade_counts': [],
-        'loss_values': []
+        'loss_values': [],
+        'fees': [],  # Track fees
+        'net_pnl_after_fees': []  # Track net PnL after fees
     }

     # Track best models
     best_reward = float('-inf')
     best_pnl = float('-inf')
+    best_net_pnl = float('-inf')  # Track best net PnL (after fees)

     # Make directory for models if it doesn't exist
     os.makedirs('models', exist_ok=True)
@@ -1997,6 +2515,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
         # Calculate statistics from this episode
         balance = env.balance
         pnl = balance - env.initial_balance if hasattr(env, 'initial_balance') else 0
+        fees = env.total_fees if hasattr(env, 'total_fees') else 0
+        net_pnl = pnl - fees  # Calculate net PnL after fees

         # Get trading statistics
         trade_analysis = None
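The fee and net-PnL bookkeeping added to `train_agent` follows a simple pattern. Here it is isolated as a self-contained sketch (the helper name is ours, not the patch's):

```python
from torch.utils.tensorboard import SummaryWriter

def log_episode_pnl(writer: SummaryWriter, episode: int, balance: float,
                    initial_balance: float, total_fees: float) -> float:
    """Log gross PnL, fees, and net PnL for one episode; return the net figure."""
    pnl = balance - initial_balance
    net_pnl = pnl - total_fees              # what actually lands in the account
    writer.add_scalar('PnL/episode', pnl, episode)
    writer.add_scalar('Fees/episode', total_fees, episode)
    writer.add_scalar('NetPnL/episode', net_pnl, episode)
    return net_pnl
```

Tracking net PnL separately matters because the patch also checkpoints a `best_net_pnl` model below: a strategy that trades often can look good gross and poor net.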
@@ -2015,6 +2535,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
             writer.add_scalar('Reward/episode', episode_reward, episode)
             writer.add_scalar('Balance/episode', balance, episode)
             writer.add_scalar('PnL/episode', pnl, episode)
+            writer.add_scalar('NetPnL/episode', net_pnl, episode)
+            writer.add_scalar('Fees/episode', fees, episode)
             writer.add_scalar('WinRate/episode', win_rate, episode)
             writer.add_scalar('TradeCount/episode', trade_count, episode)
             writer.add_scalar('Drawdown/episode', max_drawdown, episode)
@@ -2030,6 +2552,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
         stats['drawdowns'].append(max_drawdown)
         stats['trade_counts'].append(trade_count)
         stats['loss_values'].append(avg_loss)
+        stats['fees'].append(fees)
+        stats['net_pnl_after_fees'].append(net_pnl)

         # Calculate and update cumulative PnL
         if len(stats['episode_pnls']) > 0:
@@ -2039,6 +2563,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
             stats['cumulative_pnl'].append(cumulative_pnl)
             if writer:
                 writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode)
+                writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode)

         # Save model if this is the best reward or PnL
         if episode_reward > best_reward:
@@ -2050,6 +2575,12 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
             best_pnl = pnl
             agent.save('models/trading_agent_best_pnl.pt')
             logging.info(f"New best PnL: ${best_pnl:.2f}")
+
+        # Save model if this is the best net PnL (after fees)
+        if net_pnl > best_net_pnl:
+            best_net_pnl = net_pnl
+            agent.save('models/trading_agent_best_net_pnl.pt')
+            logging.info(f"New best Net PnL: ${best_net_pnl:.2f}")

         # Save checkpoint periodically
         if episode % 10 == 0:
@@ -2063,6 +2594,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
                 f"Reward: {episode_reward:.2f} | " +
                 f"Balance: ${balance:.2f} | " +
                 f"PnL: ${pnl:.2f} | " +
+                f"Fees: ${fees:.2f} | " +
+                f"Net PnL: ${net_pnl:.2f} | " +
                 f"Win Rate: {win_rate:.2f} | " +
                 f"Trades: {trade_count} | " +
                 f"Loss: {avg_loss:.5f} | " +
@@ -2608,12 +3141,19 @@ async def main():
     # Initialize exchange
     exchange = await initialize_exchange()

-    # Create environment
-    env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)
+    # Create environment with updated parameters
+    env = TradingEnvironment(
+        initial_balance=INITIAL_BALANCE,
+        window_size=30,
+        leverage=args.leverage,
+        exchange_id='mexc',
+        symbol=args.symbol,
+        timeframe=args.timeframe
+    )

     if args.mode == 'train':
         # Fetch initial data for training
-        await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
+        await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)

         # Create agent with consistent parameters
         # Note: Using STATE_SIZE and action_size=4 for consistency
@@ -2957,6 +3497,9 @@ async def fetch_multi_timeframe_data(exchange, symbol, candle_cache):
         '1d': 86400   # Update every 1 day
     }

+    # TODO: For 1s/tick timeframes, we'll need to use the exchange's WebSocket API
+    # for real-time data streaming instead of REST API. Implement this in the future.
+
     limits = {
         '1s': 1000,
         '1m': 1000,
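The diff is truncated here. Based on the `update_intervals` and `limits` tables above, a hypothetical sketch of how per-timeframe REST polling might be throttled; the names and gating logic are assumptions, not from the patch:

```python
import time

update_intervals = {'1m': 60, '1h': 3600, '1d': 86400}   # seconds between refreshes
last_update = {tf: 0.0 for tf in update_intervals}

def due_for_update(timeframe: str) -> bool:
    """Return True when the timeframe's candle cache should be refreshed."""
    now = time.time()
    if now - last_update[timeframe] >= update_intervals[timeframe]:
        last_update[timeframe] = now
        return True
    return False
```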