wip - better training loop; realtime scaffold

Dobromir Popov 2025-03-17 19:17:56 +02:00
parent 4de6352468
commit 5e9e6360af


@@ -31,6 +31,7 @@ import matplotlib.gridspec as gridspec
import datetime
from datetime import datetime as dt
from collections import defaultdict
from gym.spaces import Discrete, Box

# Configure logging
logging.basicConfig(
@@ -267,70 +268,253 @@ class PricePredictionModel(nn.Module):
        return total_loss / epochs

class TradingEnvironment:
    def __init__(self, data=None, features=None, feature_extractors=None, initial_balance=10000, leverage=50,
                 window_size=100, commission=0.0004, api_key=None, api_secret=None, exchange_id='binance',
                 symbol='ETH/USDT', timeframe='1m', init_length=5000, max_steps=10000):
        """Initialize the trading environment"""
        self.api_key = api_key
        self.api_secret = api_secret
        self.exchange_id = exchange_id
        self.symbol = symbol
        self.timeframe = timeframe
        self.init_length = init_length

        # TODO: For 1s/tick timeframes, implement WebSocket API integration for real-time data
        try:
            # Initialize exchange if API credentials are provided
            if api_key and api_secret:
                self.exchange = initialize_exchange(exchange_id, api_key, api_secret)
                logger.info(f"Exchange initialized: {exchange_id}")

                # Fetch historical data
                self.data = fetch_candles(self.exchange, self.symbol, self.timeframe, limit=self.init_length)
                if not self.data:
                    raise ValueError(f"No data fetched for {self.symbol} on {self.exchange_id}")
                self.data_format_is_list = isinstance(self.data[0], list)
                logger.info(f"Loaded {len(self.data)} candles from exchange")
            elif data is not None:  # Use provided data
                self.data = data
                self.data_format_is_list = isinstance(self.data[0], list)
                logger.info(f"Using provided data with {len(self.data)} candles")
            else:
                # Initialize with empty data; we'll load it later with fetch_initial_data
                logger.warning("No data provided, initializing with empty data")
                self.data = []
                self.data_format_is_list = True
        except Exception as e:
            logger.error(f"Error initializing environment: {e}")
            raise

        # Initialize features and feature extractors
        if features is not None:
            self.features = features
            # Create a dictionary of features
            self.features_dict = {f"feature_{i}": feature for i, feature in enumerate(features)}
        else:
            # Initialize features as a dictionary, not a list
            self.features = {
                'price': [],
                'volume': [],
                'rsi': [],
                'macd': [],
                'macd_signal': [],
                'macd_hist': [],
                'bollinger_upper': [],
                'bollinger_mid': [],
                'bollinger_lower': [],
                'stoch_k': [],
                'stoch_d': [],
                'ema_9': [],
                'ema_21': [],
                'atr': []
            }
            self.features_dict = {}

        if feature_extractors is None:
            feature_extractors = []
        self.feature_extractors = feature_extractors
        # Environment parameters
        self.initial_balance = initial_balance
        self.balance = initial_balance
        self.leverage = leverage
        self.position = 'flat'  # 'flat', 'long', or 'short'
        self.position_size = 0
        self.entry_price = 0
        self.entry_index = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.commission = commission
        self.total_pnl = 0
        self.total_fees = 0.0  # Track total fees paid
        self.trades = []
        self.trade_signals = []
        self.current_step = 0
        self.window_size = window_size
        self.max_steps = max_steps
        self.peak_balance = initial_balance
        self.max_drawdown = 0
        self.current_price = 0
        self.win_count = 0
        self.loss_count = 0
        self.min_position_size = 100  # Minimum position size in USD

        # Track candle patterns and reversal points
        self.patterns = {}
        self.reversal_points = []

        # Define observation and action spaces
        num_features = len(self.features) if hasattr(self, 'features') and self.features else 0
        state_dim = window_size * 5 + 5 + num_features  # OHLCV + position info + features
        self.action_space = Discrete(4)  # 0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE
        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32)

        # Check if we have enough data
        if len(self.data) < self.window_size:
            logger.warning(f"Data length {len(self.data)} is less than window size {self.window_size}")

    def calculate_reward(self, action):
        """Calculate reward based on the action taken"""
        reward = 0

        # Base reward structure
        if self.position == 'flat':
            if action == 0:  # HOLD when flat
                reward = 0.01  # Small reward for holding when no position
            elif action == 1:  # BUY/LONG
                # Check for buy signal in CNN patterns
                if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
                    buy_confidence = self.cnn_patterns['long_confidence']
                    # Scale by confidence
                    reward = 0.1 * buy_confidence * 10
                else:
                    reward = 0.1  # Default reward for taking a position

                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
                    reward -= fee_penalty
            elif action == 2:  # SELL/SHORT
                # Check for sell signal in CNN patterns
                if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
                    sell_confidence = self.cnn_patterns['short_confidence']
                    # Scale by confidence
                    reward = 0.1 * sell_confidence * 10
                else:
                    reward = 0.1  # Default reward for taking a position

                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
                    reward -= fee_penalty
            elif action == 3:  # CLOSE when no position
                reward = -0.1  # Penalty for trying to close no position
        elif self.position == 'long':
            if action == 0:  # HOLD long position
                # Calculate price change since entry
                price_change = (self.current_price - self.entry_price) / self.entry_price

                # Reward or penalize based on price movement
                if price_change > 0:
                    reward = price_change * 10  # Reward for holding profitable position
                else:
                    reward = price_change * 5  # Smaller penalty for holding losing position
            elif action == 1:  # BUY when already long
                reward = -0.1  # Penalty for redundant action
            elif action == 2:  # SELL when long (reversal)
                # Calculate PnL
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price

                if pnl_percent > 0:
                    reward = -0.5  # Penalty for closing profitable long position to go short
                else:
                    # Check for sell signal in CNN patterns
                    if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
                        sell_confidence = self.cnn_patterns['short_confidence']
                        reward = 0.2 * sell_confidence * 10  # Reward for correct reversal
                    else:
                        reward = 0.2  # Default reward for cutting loss

                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
                    reward -= fee_penalty
            elif action == 3:  # CLOSE long position
                # Calculate PnL
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price

                if pnl_percent > 0:
                    reward = pnl_percent * 15  # Higher reward for taking profit
                else:
                    reward = pnl_percent * 5  # Smaller penalty for cutting loss

                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
                    reward -= fee_penalty
        elif self.position == 'short':
            if action == 0:  # HOLD short position
                # Calculate price change since entry
                price_change = (self.entry_price - self.current_price) / self.entry_price

                # Reward or penalize based on price movement
                if price_change > 0:
                    reward = price_change * 10  # Reward for holding profitable position
                else:
                    reward = price_change * 5  # Smaller penalty for holding losing position
            elif action == 1:  # BUY when short (reversal)
                # Calculate PnL
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price

                if pnl_percent > 0:
                    reward = -0.5  # Penalty for closing profitable short position to go long
                else:
                    # Check for buy signal in CNN patterns
                    if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
                        buy_confidence = self.cnn_patterns['long_confidence']
                        reward = 0.2 * buy_confidence * 10  # Reward for correct reversal
                    else:
                        reward = 0.2  # Default reward for cutting loss

                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
                    reward -= fee_penalty
            elif action == 2:  # SELL when already short
                reward = -0.1  # Penalty for redundant action
            elif action == 3:  # CLOSE short position
                # Calculate PnL
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price

                if pnl_percent > 0:
                    reward = pnl_percent * 15  # Higher reward for taking profit
                else:
                    reward = pnl_percent * 5  # Smaller penalty for cutting loss

                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
                    reward -= fee_penalty

        return reward
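
    # Worked example of the fee penalty used above (illustrative, using the
    # same constants -- a flat 1 USD charge per 1,900 USD of notional):
    #   a 1,900 USD position pays fee = (1900 / 1900) * 1 = 1.0 USD,
    #   giving fee_penalty = min(0.05, 1.0 / 100) = 0.01;
    #   the 0.05 cap is reached once the position exceeds 9,500 USD (fee >= 5 USD).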
    def reset(self):
        """Reset the environment to its initial state and return the initial observation"""
        self.balance = self.initial_balance
        self.position = 'flat'
        self.position_size = 0
@@ -338,24 +522,15 @@ class TradingEnvironment:
        self.entry_index = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.current_step = 0
        self.trades = []
        self.trade_signals = []
        self.total_pnl = 0.0
        self.total_fees = 0.0
        self.peak_balance = self.initial_balance
        self.max_drawdown = 0.0
        self.win_count = 0
        self.loss_count = 0

        return self.get_state()
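
A minimal loop for driving this environment might look like the following (an
illustrative sketch, not part of this commit; it assumes `data` is a list of
[timestamp, open, high, low, close, volume] candles, as fetch_candles appears
to return):

    env = TradingEnvironment(data=data, initial_balance=10000, leverage=50)
    state = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()  # 0: HOLD, 1: LONG, 2: SHORT, 3: CLOSE
        state, reward, done, info = env.step(action)
    print(info['balance'], info['net_pnl'])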
@@ -492,6 +667,227 @@ class TradingEnvironment:
        # Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
        reward = self.calculate_reward(action)
        # Execute the action
        initial_balance = self.balance  # Store initial balance to calculate PnL

        # Open long position
        if action == 1 and self.position != 'long':
            if self.position == 'short':
                # Close short position first
                if self.position_size > 0:
                    # Calculate PnL
                    pnl_percent = (self.entry_price - self.current_price) / self.entry_price
                    pnl_dollar = pnl_percent * self.position_size * self.leverage

                    # Update balance and record trade
                    self.balance += pnl_dollar
                    self.total_pnl += pnl_dollar

                    # Apply trading fee (1 USD per 1.9k position)
                    fee = (self.position_size / 1900) * 1
                    self.balance -= fee
                    self.total_fees += fee

                    # Record trade
                    trade_duration = self.current_step - self.entry_index
                    if self.data_format_is_list:
                        timestamp = self.data[self.current_step][0]  # Timestamp
                    else:
                        timestamp = self.data[self.current_step]['timestamp']
                    self.trades.append({
                        'type': 'short',
                        'entry': self.entry_price,
                        'exit': self.current_price,
                        'pnl_percent': pnl_percent,
                        'pnl_dollar': pnl_dollar,
                        'fee': fee,
                        'net_pnl': pnl_dollar - fee,
                        'duration': trade_duration,
                        'timestamp': timestamp,
                        'reason': 'action_change'
                    })

                    # Update win/loss count
                    if pnl_dollar > 0:
                        self.win_count += 1
                    else:
                        self.loss_count += 1

            # Now open long position
            self.position = 'long'
            self.entry_price = self.current_price
            self.entry_index = self.current_step

            # Calculate position size with risk management
            self.position_size = self.calculate_position_size()

            # Apply trading fee (1 USD per 1.9k position)
            fee = (self.position_size / 1900) * 1
            self.balance -= fee
            self.total_fees += fee

            # Set stop loss and take profit
            sl_percent = 0.02  # 2% stop loss
            tp_percent = 0.04  # 4% take profit
            self.stop_loss = self.entry_price * (1 - sl_percent)
            self.take_profit = self.entry_price * (1 + tp_percent)

        # Open short position
        elif action == 2 and self.position != 'short':
            if self.position == 'long':
                # Close long position first
                if self.position_size > 0:
                    # Calculate PnL
                    pnl_percent = (self.current_price - self.entry_price) / self.entry_price
                    pnl_dollar = pnl_percent * self.position_size * self.leverage

                    # Update balance and record trade
                    self.balance += pnl_dollar
                    self.total_pnl += pnl_dollar

                    # Apply trading fee (1 USD per 1.9k position)
                    fee = (self.position_size / 1900) * 1
                    self.balance -= fee
                    self.total_fees += fee

                    # Record trade
                    trade_duration = self.current_step - self.entry_index
                    if self.data_format_is_list:
                        timestamp = self.data[self.current_step][0]  # Timestamp
                    else:
                        timestamp = self.data[self.current_step]['timestamp']
                    self.trades.append({
                        'type': 'long',
                        'entry': self.entry_price,
                        'exit': self.current_price,
                        'pnl_percent': pnl_percent,
                        'pnl_dollar': pnl_dollar,
                        'fee': fee,
                        'net_pnl': pnl_dollar - fee,
                        'duration': trade_duration,
                        'timestamp': timestamp,
                        'reason': 'action_change'
                    })

                    # Update win/loss count
                    if pnl_dollar > 0:
                        self.win_count += 1
                    else:
                        self.loss_count += 1

            # Now open short position
            self.position = 'short'
            self.entry_price = self.current_price
            self.entry_index = self.current_step

            # Calculate position size with risk management
            self.position_size = self.calculate_position_size()

            # Apply trading fee (1 USD per 1.9k position)
            fee = (self.position_size / 1900) * 1
            self.balance -= fee
            self.total_fees += fee

            # Set stop loss and take profit
            sl_percent = 0.02  # 2% stop loss
            tp_percent = 0.04  # 4% take profit
            self.stop_loss = self.entry_price * (1 + sl_percent)
            self.take_profit = self.entry_price * (1 - tp_percent)

        # Close position
        elif action == 3 and self.position != 'flat':
            if self.position == 'long':
                # Calculate PnL
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price
                pnl_dollar = pnl_percent * self.position_size * self.leverage

                # Update balance and record trade
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar

                # Apply trading fee (1 USD per 1.9k position)
                fee = (self.position_size / 1900) * 1
                self.balance -= fee
                self.total_fees += fee

                # Record trade
                trade_duration = self.current_step - self.entry_index
                if self.data_format_is_list:
                    timestamp = self.data[self.current_step][0]  # Timestamp
                else:
                    timestamp = self.data[self.current_step]['timestamp']
                self.trades.append({
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'fee': fee,
                    'net_pnl': pnl_dollar - fee,
                    'duration': trade_duration,
                    'timestamp': timestamp,
                    'reason': 'close_action'
                })

                # Update win/loss count
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1
            elif self.position == 'short':
                # Calculate PnL
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price
                pnl_dollar = pnl_percent * self.position_size * self.leverage

                # Update balance and record trade
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar

                # Apply trading fee (1 USD per 1.9k position)
                fee = (self.position_size / 1900) * 1
                self.balance -= fee
                self.total_fees += fee

                # Record trade
                trade_duration = self.current_step - self.entry_index
                if self.data_format_is_list:
                    timestamp = self.data[self.current_step][0]  # Timestamp
                else:
                    timestamp = self.data[self.current_step]['timestamp']
                self.trades.append({
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'fee': fee,
                    'net_pnl': pnl_dollar - fee,
                    'duration': trade_duration,
                    'timestamp': timestamp,
                    'reason': 'close_action'
                })

                # Update win/loss count
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1

            # Reset position
            self.position = 'flat'
            self.position_size = 0
            self.entry_price = 0
            self.entry_index = 0
            self.stop_loss = 0
            self.take_profit = 0

        # Record trade signal for visualization
        if action > 0:  # If not HOLD
            signal_type = None
@@ -529,13 +925,22 @@ class TradingEnvironment:
        # Get new state
        next_state = self.get_state()

        # Update peak balance and drawdown
        if self.balance > self.peak_balance:
            self.peak_balance = self.balance
        current_drawdown = (self.peak_balance - self.balance) / self.peak_balance if self.peak_balance > 0 else 0
        self.max_drawdown = max(self.max_drawdown, current_drawdown)

        # Create info dictionary
        info = {
            'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
            'price': self.current_price,
            'balance': self.balance,
            'position': self.position,
            'pnl': self.total_pnl,
            'fees': self.total_fees,
            'net_pnl': self.total_pnl - self.total_fees
        }

        return next_state, reward, done, info
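
To make the leveraged accounting above concrete, here is a worked example using
the defaults (leverage=50) and the 1 USD per 1,900 USD fee; the numbers follow
directly from the formulas in step():

    entry_price, exit_price = 2000.0, 2020.0                 # long trade
    position_size, leverage = 100.0, 50                      # USD notional, 50x
    pnl_percent = (exit_price - entry_price) / entry_price   # 0.01 (1%)
    pnl_dollar = pnl_percent * position_size * leverage      # 50.0 USD
    fee = (position_size / 1900) * 1                         # ~0.05 USD
    net_pnl = pnl_dollar - fee                               # ~49.95 USD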
@@ -999,6 +1404,33 @@ class TradingEnvironment:
            else:
                # Small negative reward for holding in the wrong direction
                reward -= 0.1
        elif action == 1 or action == 2:  # BUY or SELL
            # Apply trading fee as negative reward (1 USD per 1.9k position size)
            position_size = self.calculate_position_size()
            fee = (position_size / 1900) * 1  # Trading fee in USD

            # Penalty for fee
            fee_penalty = fee / 10  # Scale down to make it a reasonable penalty
            reward -= fee_penalty

            # Logging
            if hasattr(self, 'total_fees'):
                self.total_fees += fee
            else:
                self.total_fees = fee
        elif action == 3:  # CLOSE
            # Apply trading fee as negative reward (1 USD per 1.9k position size)
            fee = (self.position_size / 1900) * 1  # Trading fee in USD

            # Penalty for fee
            fee_penalty = fee / 10  # Scale down to make it a reasonable penalty
            reward -= fee_penalty

            # Logging
            if hasattr(self, 'total_fees'):
                self.total_fees += fee
            else:
                self.total_fees = fee

        # Add CNN pattern confidence to reward
        reward += pattern_confidence * 10
@@ -1360,6 +1792,74 @@ class TradingEnvironment:
        return False
    def add_chart_to_tensorboard(self, writer, step, title='Trading Chart'):
        """
        Add a candlestick chart and metrics to TensorBoard

        Parameters:
        - writer: TensorBoard writer
        - step: Current step
        - title: Title for the chart
        """
        try:
            # Initialize writer if not provided
            if writer is None:
                from torch.utils.tensorboard import SummaryWriter
                writer = SummaryWriter()

            # Log basic metrics
            writer.add_scalar('Balance', self.balance, step)
            writer.add_scalar('Total_PnL', self.total_pnl, step)

            # Log total fees if available
            if hasattr(self, 'total_fees'):
                writer.add_scalar('Total_Fees', self.total_fees, step)
                writer.add_scalar('Net_PnL', self.total_pnl - self.total_fees, step)

            # Log position info
            writer.add_scalar('Position_Size', self.position_size, step)

            # Log drawdown and win rate
            writer.add_scalar('Max_Drawdown', self.max_drawdown, step)
            win_rate = self.win_count / (self.win_count + self.loss_count) if (self.win_count + self.loss_count) > 0 else 0
            writer.add_scalar('Win_Rate', win_rate, step)

            # Log trade count
            writer.add_scalar('Trade_Count', len(self.trades), step)

            # Check if we have enough data for candlestick chart
            if len(self.data) <= 0:
                logger.warning("No data available for candlestick chart")
                return

            # Create figure for candlestick chart (last 100 data points)
            start_idx = max(0, self.current_step - 100)
            end_idx = self.current_step

            # Get recent trades for visualization (last 10 trades)
            recent_trades = self.trades[-10:] if self.trades else []

            try:
                fig = create_candlestick_figure(
                    self.data[start_idx:end_idx+1],
                    title=title,
                    trades=recent_trades
                )

                # Add figure to TensorBoard
                writer.add_figure('Candlestick_Chart', fig, step)

                # Close figure to free memory
                plt.close(fig)
            except Exception as e:
                logger.error(f"Error creating candlestick chart: {e}")
        except Exception as e:
            logger.error(f"Error adding chart to TensorBoard: {e}")
            # Continue execution even if chart fails
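
The create_candlestick_figure helper used above is defined elsewhere in this
file and is not part of this diff; as a rough mental model, a minimal stand-in
could look like this (illustrative sketch only, assuming list-format
[timestamp, open, high, low, close, volume] candles):

    import matplotlib.pyplot as plt

    def create_candlestick_figure(candles, trades=None, title="Trading Chart"):
        fig, ax = plt.subplots(figsize=(10, 4))
        for i, c in enumerate(candles):
            o, h, l, cl = c[1], c[2], c[3], c[4]
            color = 'green' if cl >= o else 'red'
            ax.vlines(i, l, h, color=color, linewidth=0.5)                  # wick
            ax.vlines(i, min(o, cl), max(o, cl), color=color, linewidth=3)  # body
        ax.set_title(title)
        return fig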
# Ensure GPU usage if available
def get_device():
    """Get the best available device (CUDA GPU or CPU)"""
@@ -1752,7 +2252,7 @@ class Agent:
            raise

    def add_chart_to_tensorboard(self, env, step):
        """Add candlestick chart to tensorboard and various metrics"""
        try:
            # Initialize writer if it doesn't exist
            if not hasattr(self, 'writer') or self.writer is None:
@@ -1791,22 +2291,37 @@ class Agent:
            if hasattr(env, 'trade_count'):
                self.writer.add_scalar('Trading/Trade_Count', env.trade_count, step)

            # Log trading fees
            if hasattr(env, 'total_fees'):
                self.writer.add_scalar('Trading/Total_Fees', env.total_fees, step)
                # Also log net PnL (after fees)
                if hasattr(env, 'total_pnl'):
                    self.writer.add_scalar('Trading/Net_PnL_After_Fees', env.total_pnl - env.total_fees, step)

            # Add candlestick chart if we have enough data
            if len(env.data) >= 100:
                try:
                    # Use the last 100 candles for the chart
                    recent_data = env.data[-100:]

                    # Get recent trades if available
                    recent_trades = None
                    if hasattr(env, 'trades') and len(env.trades) > 0:
                        recent_trades = env.trades[-10:]  # Last 10 trades

                    # Create candlestick figure
                    fig = create_candlestick_figure(recent_data, recent_trades, f"Trading Chart - Step {step}")

                    if fig:
                        # Add to tensorboard
                        self.writer.add_figure('Trading/Chart', fig, step)

                        # Close figure to free memory
                        plt.close(fig)
                except Exception as e:
                    logger.warning(f"Error creating candlestick chart: {e}")
        except Exception as e:
            logger.error(f"Error in add_chart_to_tensorboard: {e}")

async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
    """Get live price data using websockets"""
@@ -1888,12 +2403,15 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
        'cumulative_pnl': [],
        'drawdowns': [],
        'trade_counts': [],
        'loss_values': [],
        'fees': [],  # Track fees
        'net_pnl_after_fees': []  # Track net PnL after fees
    }

    # Track best models
    best_reward = float('-inf')
    best_pnl = float('-inf')
    best_net_pnl = float('-inf')  # Track best net PnL (after fees)

    # Make directory for models if it doesn't exist
    os.makedirs('models', exist_ok=True)
@@ -1997,6 +2515,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
        # Calculate statistics from this episode
        balance = env.balance
        pnl = balance - env.initial_balance if hasattr(env, 'initial_balance') else 0
        fees = env.total_fees if hasattr(env, 'total_fees') else 0
        net_pnl = pnl - fees  # Calculate net PnL after fees

        # Get trading statistics
        trade_analysis = None
@@ -2015,6 +2535,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
            writer.add_scalar('Reward/episode', episode_reward, episode)
            writer.add_scalar('Balance/episode', balance, episode)
            writer.add_scalar('PnL/episode', pnl, episode)
            writer.add_scalar('NetPnL/episode', net_pnl, episode)
            writer.add_scalar('Fees/episode', fees, episode)
            writer.add_scalar('WinRate/episode', win_rate, episode)
            writer.add_scalar('TradeCount/episode', trade_count, episode)
            writer.add_scalar('Drawdown/episode', max_drawdown, episode)
@@ -2030,6 +2552,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
        stats['drawdowns'].append(max_drawdown)
        stats['trade_counts'].append(trade_count)
        stats['loss_values'].append(avg_loss)
        stats['fees'].append(fees)
        stats['net_pnl_after_fees'].append(net_pnl)

        # Calculate and update cumulative PnL
        if len(stats['episode_pnls']) > 0:
@@ -2039,6 +2563,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
            stats['cumulative_pnl'].append(cumulative_pnl)
            if writer:
                writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode)
                writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode)

        # Save model if this is the best reward or PnL
        if episode_reward > best_reward:
@@ -2051,6 +2576,12 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
            agent.save('models/trading_agent_best_pnl.pt')
            logging.info(f"New best PnL: ${best_pnl:.2f}")

        # Save model if this is the best net PnL (after fees)
        if net_pnl > best_net_pnl:
            best_net_pnl = net_pnl
            agent.save('models/trading_agent_best_net_pnl.pt')
            logging.info(f"New best Net PnL: ${best_net_pnl:.2f}")

        # Save checkpoint periodically
        if episode % 10 == 0:
            agent.save(f'models/trading_agent_checkpoint_{episode}.pt')
@@ -2063,6 +2594,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
            f"Reward: {episode_reward:.2f} | " +
            f"Balance: ${balance:.2f} | " +
            f"PnL: ${pnl:.2f} | " +
            f"Fees: ${fees:.2f} | " +
            f"Net PnL: ${net_pnl:.2f} | " +
            f"Win Rate: {win_rate:.2f} | " +
            f"Trades: {trade_count} | " +
            f"Loss: {avg_loss:.5f} | " +
@@ -2608,12 +3141,19 @@ async def main():
    # Initialize exchange
    exchange = await initialize_exchange()

    # Create environment with updated parameters
    env = TradingEnvironment(
        initial_balance=INITIAL_BALANCE,
        window_size=30,
        leverage=args.leverage,
        exchange_id='mexc',
        symbol=args.symbol,
        timeframe=args.timeframe
    )

    if args.mode == 'train':
        # Fetch initial data for training
        await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)

        # Create agent with consistent parameters
        # Note: Using STATE_SIZE and action_size=4 for consistency
@@ -2957,6 +3497,9 @@ async def fetch_multi_timeframe_data(exchange, symbol, candle_cache):
        '1d': 86400  # Update every 1 day
    }

    # TODO: For 1s/tick timeframes, we'll need to use the exchange's WebSocket API
    # for real-time data streaming instead of REST API. Implement this in the future.
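    # A possible shape for that WebSocket path (untested sketch, not part of this
    # commit; assumes the `websockets` package and Binance's public trade stream --
    # other exchanges expose similar endpoints):
    #
    #   async def stream_ticks(symbol="ethusdt"):
    #       url = f"wss://stream.binance.com:9443/ws/{symbol}@trade"
    #       async with websockets.connect(url) as ws:
    #           while True:
    #               tick = json.loads(await ws.recv())
    #               yield float(tick["p"]), float(tick["q"])  # price, quantity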
    limits = {
        '1s': 1000,
        '1m': 1000,