wip - better training loop; realtime scaffold

Dobromir Popov 2025-03-17 19:17:56 +02:00
parent 4de6352468
commit 5e9e6360af


@ -31,6 +31,7 @@ import matplotlib.gridspec as gridspec
import datetime
from datetime import datetime as dt
from collections import defaultdict
from gym.spaces import Discrete, Box
# Configure logging
logging.basicConfig(
@ -267,70 +268,253 @@ class PricePredictionModel(nn.Module):
return total_loss / epochs
class TradingEnvironment:
def __init__(self, data=None, features=None, feature_extractors=None, initial_balance=10000, leverage=50,
window_size=100, commission=0.0004, api_key=None, api_secret=None, exchange_id='binance',
symbol='ETH/USDT', timeframe='1m', init_length=5000, max_steps=10000):
"""Initialize the trading environment"""
self.api_key = api_key
self.api_secret = api_secret
self.exchange_id = exchange_id
self.symbol = symbol
self.timeframe = timeframe
self.init_length = init_length
# TODO: For 1s/tick timeframes, implement WebSocket API integration for real-time data (see the sketch after __init__ below)
try:
# Initialize exchange if API credentials are provided
if api_key and api_secret:
self.exchange = initialize_exchange(exchange_id, api_key, api_secret)
logger.info(f"Exchange initialized: {exchange_id}")
# Fetch historical data
self.data = fetch_candles(self.exchange, self.symbol, self.timeframe, limit=self.init_length)
if not self.data:
raise ValueError(f"No data fetched for {self.symbol} on {self.exchange_id}")
self.data_format_is_list = isinstance(self.data[0], list)
logger.info(f"Loaded {len(self.data)} candles from exchange")
elif data is not None: # Use provided data
self.data = data
self.data_format_is_list = isinstance(self.data[0], list)
logger.info(f"Using provided data with {len(self.data)} candles")
else:
# Initialize with empty data, we'll load it later with fetch_initial_data
logger.warning("No data provided, initializing with empty data")
self.data = []
self.data_format_is_list = True
except Exception as e:
logger.error(f"Error initializing environment: {e}")
raise
# Initialize features and feature extractors
if features is not None:
self.features = features
# Create a dictionary of features
self.features_dict = {f"feature_{i}": feature for i, feature in enumerate(features)}
else:
# Initialize features as a dictionary, not a list
self.features = {
'price': [],
'volume': [],
'rsi': [],
'macd': [],
'macd_signal': [],
'macd_hist': [],
'bollinger_upper': [],
'bollinger_mid': [],
'bollinger_lower': [],
'stoch_k': [],
'stoch_d': [],
'ema_9': [],
'ema_21': [],
'atr': []
}
self.features_dict = {}
if feature_extractors is None:
feature_extractors = []
self.feature_extractors = feature_extractors
# Environment parameters
self.initial_balance = initial_balance
self.balance = initial_balance
self.window_size = window_size
self.leverage = leverage
self.position = 'flat' # 'flat', 'long', or 'short'
self.position_size = 0
self.entry_price = 0
self.entry_index = 0
self.stop_loss = 0
self.take_profit = 0
self.commission = commission
self.max_steps = max_steps
self.total_pnl = 0.0
self.episode_pnl = 0.0
self.total_fees = 0.0 # Track total fees paid
self.trades = []
self.current_step = 0
self.current_price = 0
self.peak_balance = initial_balance
self.max_drawdown = 0.0
self.win_count = 0
self.loss_count = 0
self.min_position_size = 100 # Minimum position size in USD
# For tracking signals for visualization
self.trade_signals = []
# Track candle patterns and reversal points
self.patterns = {}
self.reversal_points = []
# Define observation and action spaces
num_features = len(self.features) if hasattr(self, 'features') and self.features else 0
state_dim = window_size * 5 + 5 + num_features # OHLCV + position info + features
# Initialize price predictor
self.price_predictor = None
self.predicted_prices = np.array([])
self.action_space = Discrete(4) # 0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE
self.observation_space = Box(low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32)
# Initialize optimal trade tracking
self.optimal_bottoms = []
self.optimal_tops = []
self.optimal_signals = np.array([])
# Check if we have enough data
if len(self.data) < self.window_size:
logger.warning(f"Data length {len(self.data)} is less than window size {self.window_size}")
# Futures trading attributes (leverage is already set from the constructor above)
self.futures_symbol = "ETH_USDT" # Example futures symbol
self.position_mode = "hedge" # For simultaneous long/short positions
self.margin_mode = "cross" # Cross margin mode
def calculate_reward(self, action):
"""Calculate reward based on the action taken"""
reward = 0
# Base reward structure
if self.position == 'flat':
if action == 0: # HOLD when flat
reward = 0.01 # Small reward for holding when no position
elif action == 1: # BUY/LONG
# Check for buy signal in CNN patterns
if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
buy_confidence = self.cnn_patterns['long_confidence']
# Scale by confidence
reward = 0.1 * buy_confidence * 10
else:
reward = 0.1 # Default reward for taking a position
# Apply fee penalty
if self.position_size > 0:
fee = (self.position_size / 1900) * 1
fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05
reward -= fee_penalty
elif action == 2: # SELL/SHORT
# Check for sell signal in CNN patterns
if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
sell_confidence = self.cnn_patterns['short_confidence']
# Scale by confidence
reward = 0.1 * sell_confidence * 10
else:
reward = 0.1 # Default reward for taking a position
# Apply fee penalty
if self.position_size > 0:
fee = (self.position_size / 1900) * 1
fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05
reward -= fee_penalty
elif action == 3: # CLOSE when no position
reward = -0.1 # Penalty for trying to close no position
elif self.position == 'long':
if action == 0: # HOLD long position
# Calculate price change since entry
price_change = (self.current_price - self.entry_price) / self.entry_price
# Reward or penalize based on price movement
if price_change > 0:
reward = price_change * 10 # Reward for holding profitable position
else:
reward = price_change * 5 # Smaller penalty for holding losing position
elif action == 1: # BUY when already long
reward = -0.1 # Penalty for redundant action
elif action == 2: # SELL when long (reversal)
# Calculate PnL
pnl_percent = (self.current_price - self.entry_price) / self.entry_price
if pnl_percent > 0:
reward = -0.5 # Penalty for closing profitable long position to go short
else:
# Check for sell signal in CNN patterns
if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
sell_confidence = self.cnn_patterns['short_confidence']
reward = 0.2 * sell_confidence * 10 # Reward for correct reversal
else:
reward = 0.2 # Default reward for cutting loss
# Apply fee penalty
if self.position_size > 0:
fee = (self.position_size / 1900) * 1
fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05
reward -= fee_penalty
elif action == 3: # CLOSE long position
# Calculate PnL
pnl_percent = (self.current_price - self.entry_price) / self.entry_price
if pnl_percent > 0:
reward = pnl_percent * 15 # Higher reward for taking profit
else:
reward = pnl_percent * 5 # Smaller penalty for cutting loss
# Apply fee penalty
if self.position_size > 0:
fee = (self.position_size / 1900) * 1
fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05
reward -= fee_penalty
elif self.position == 'short':
if action == 0: # HOLD short position
# Calculate price change since entry
price_change = (self.entry_price - self.current_price) / self.entry_price
# Reward or penalize based on price movement
if price_change > 0:
reward = price_change * 10 # Reward for holding profitable position
else:
reward = price_change * 5 # Smaller penalty for holding losing position
elif action == 1: # BUY when short (reversal)
# Calculate PnL
pnl_percent = (self.entry_price - self.current_price) / self.entry_price
if pnl_percent > 0:
reward = -0.5 # Penalty for closing profitable short position to go long
else:
# Check for buy signal in CNN patterns
if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
buy_confidence = self.cnn_patterns['long_confidence']
reward = 0.2 * buy_confidence * 10 # Reward for correct reversal
else:
reward = 0.2 # Default reward for cutting loss
# Apply fee penalty
if self.position_size > 0:
fee = (self.position_size / 1900) * 1
fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05
reward -= fee_penalty
elif action == 2: # SELL when already short
reward = -0.1 # Penalty for redundant action
elif action == 3: # CLOSE short position
# Calculate PnL
pnl_percent = (self.entry_price - self.current_price) / self.entry_price
if pnl_percent > 0:
reward = pnl_percent * 15 # Higher reward for taking profit
else:
reward = pnl_percent * 5 # Smaller penalty for cutting loss
# Apply fee penalty
if self.position_size > 0:
fee = (self.position_size / 1900) * 1
fee_penalty = min(0.05, fee / 100) # Scale fee to a small penalty, max 0.05
reward -= fee_penalty
return reward
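# Worked example of the fee penalty used throughout calculate_reward (a sketch, not
# original code): a $1,900 position pays a $1 fee, and the penalty is capped at 0.05.
#
#   position_size = 1900   ->  fee = (1900 / 1900) * 1 = 1.0 USD
#                              fee_penalty = min(0.05, 1.0 / 100) = 0.01
#   position_size = 19000  ->  fee = 10.0 USD
#                              fee_penalty = min(0.05, 10.0 / 100) = 0.05 (capped)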
def reset(self):
"""Reset the environment to initial state"""
"""Reset the environment to its initial state and return the initial observation"""
self.balance = self.initial_balance
self.position = 'flat'
self.position_size = 0
@ -338,24 +522,15 @@ class TradingEnvironment:
self.entry_index = 0
self.stop_loss = 0
self.take_profit = 0
self.current_step = 0
self.trades = []
self.trade_signals = []
self.win_count = 0
self.loss_count = 0
self.episode_pnl = 0.0
self.total_pnl = 0.0
self.total_fees = 0.0
self.peak_balance = self.initial_balance
self.max_drawdown = 0.0
# Keep data but reset the current position to the start of the window
if len(self.data) > self.window_size:
    self.current_step = self.window_size
    if self.data_format_is_list:
        self.current_price = self.data[self.current_step][4]  # Close price is at index 4
    else:
        self.current_price = self.data[self.current_step]['close']
return self.get_state()
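# A minimal interaction-loop sketch for this environment (illustrative only; assumes
# `candles` is a list of [timestamp, open, high, low, close, volume] rows):
#
#   env = TradingEnvironment(data=candles, initial_balance=10000, window_size=100)
#   state = env.reset()
#   done = False
#   while not done:
#       action = env.action_space.sample()  # random policy stand-in for the agent
#       state, reward, done, info = env.step(action)
#   print(f"Balance: {env.balance:.2f}, net PnL: {info['net_pnl']:.2f}")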
@ -492,6 +667,227 @@ class TradingEnvironment:
# Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
reward = self.calculate_reward(action)
# Execute the action
initial_balance = self.balance # Store initial balance to calculate PnL
# Open long position
if action == 1 and self.position != 'long':
if self.position == 'short':
# Close short position first
if self.position_size > 0:
# Calculate PnL
pnl_percent = (self.entry_price - self.current_price) / self.entry_price
pnl_dollar = pnl_percent * self.position_size * self.leverage
# Update balance and record trade
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
# Apply trading fee (1 USD per 1.9k position)
fee = (self.position_size / 1900) * 1
self.balance -= fee
self.total_fees += fee
# Record trade
trade_duration = self.current_step - self.entry_index
if self.data_format_is_list:
timestamp = self.data[self.current_step][0] # Timestamp
else:
timestamp = self.data[self.current_step]['timestamp']
self.trades.append({
'type': 'short',
'entry': self.entry_price,
'exit': self.current_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'fee': fee,
'net_pnl': pnl_dollar - fee,
'duration': trade_duration,
'timestamp': timestamp,
'reason': 'action_change'
})
# Update win/loss count
if pnl_dollar > 0:
self.win_count += 1
else:
self.loss_count += 1
# Now open long position
self.position = 'long'
self.entry_price = self.current_price
self.entry_index = self.current_step
# Calculate position size with risk management
self.position_size = self.calculate_position_size()
# Apply trading fee (1 USD per 1.9k position)
fee = (self.position_size / 1900) * 1
self.balance -= fee
self.total_fees += fee
# Set stop loss and take profit
sl_percent = 0.02 # 2% stop loss
tp_percent = 0.04 # 4% take profit
self.stop_loss = self.entry_price * (1 - sl_percent)
self.take_profit = self.entry_price * (1 + tp_percent)
# Open short position
elif action == 2 and self.position != 'short':
if self.position == 'long':
# Close long position first
if self.position_size > 0:
# Calculate PnL
pnl_percent = (self.current_price - self.entry_price) / self.entry_price
pnl_dollar = pnl_percent * self.position_size * self.leverage
# Update balance and record trade
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
# Apply trading fee (1 USD per 1.9k position)
fee = (self.position_size / 1900) * 1
self.balance -= fee
self.total_fees += fee
# Record trade
trade_duration = self.current_step - self.entry_index
if self.data_format_is_list:
timestamp = self.data[self.current_step][0] # Timestamp
else:
timestamp = self.data[self.current_step]['timestamp']
self.trades.append({
'type': 'long',
'entry': self.entry_price,
'exit': self.current_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'fee': fee,
'net_pnl': pnl_dollar - fee,
'duration': trade_duration,
'timestamp': timestamp,
'reason': 'action_change'
})
# Update win/loss count
if pnl_dollar > 0:
self.win_count += 1
else:
self.loss_count += 1
# Now open short position
self.position = 'short'
self.entry_price = self.current_price
self.entry_index = self.current_step
# Calculate position size with risk management
self.position_size = self.calculate_position_size()
# Apply trading fee (1 USD per 1.9k position)
fee = (self.position_size / 1900) * 1
self.balance -= fee
self.total_fees += fee
# Set stop loss and take profit
sl_percent = 0.02 # 2% stop loss
tp_percent = 0.04 # 4% take profit
self.stop_loss = self.entry_price * (1 + sl_percent)
self.take_profit = self.entry_price * (1 - tp_percent)
# Close position
elif action == 3 and self.position != 'flat':
if self.position == 'long':
# Calculate PnL
pnl_percent = (self.current_price - self.entry_price) / self.entry_price
pnl_dollar = pnl_percent * self.position_size * self.leverage
# Update balance and record trade
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
# Apply trading fee (1 USD per 1.9k position)
fee = (self.position_size / 1900) * 1
self.balance -= fee
self.total_fees += fee
# Record trade
trade_duration = self.current_step - self.entry_index
if self.data_format_is_list:
timestamp = self.data[self.current_step][0] # Timestamp
else:
timestamp = self.data[self.current_step]['timestamp']
self.trades.append({
'type': 'long',
'entry': self.entry_price,
'exit': self.current_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'fee': fee,
'net_pnl': pnl_dollar - fee,
'duration': trade_duration,
'timestamp': timestamp,
'reason': 'close_action'
})
# Update win/loss count
if pnl_dollar > 0:
self.win_count += 1
else:
self.loss_count += 1
elif self.position == 'short':
# Calculate PnL
pnl_percent = (self.entry_price - self.current_price) / self.entry_price
pnl_dollar = pnl_percent * self.position_size * self.leverage
# Update balance and record trade
self.balance += pnl_dollar
self.total_pnl += pnl_dollar
# Apply trading fee (1 USD per 1.9k position)
fee = (self.position_size / 1900) * 1
self.balance -= fee
self.total_fees += fee
# Record trade
trade_duration = self.current_step - self.entry_index
if self.data_format_is_list:
timestamp = self.data[self.current_step][0] # Timestamp
else:
timestamp = self.data[self.current_step]['timestamp']
self.trades.append({
'type': 'short',
'entry': self.entry_price,
'exit': self.current_price,
'pnl_percent': pnl_percent,
'pnl_dollar': pnl_dollar,
'fee': fee,
'net_pnl': pnl_dollar - fee,
'duration': trade_duration,
'timestamp': timestamp,
'reason': 'close_action'
})
# Update win/loss count
if pnl_dollar > 0:
self.win_count += 1
else:
self.loss_count += 1
# Reset position
self.position = 'flat'
self.position_size = 0
self.entry_price = 0
self.entry_index = 0
self.stop_loss = 0
self.take_profit = 0
# Record trade signal for visualization
if action > 0: # If not HOLD
signal_type = None
@ -529,13 +925,22 @@ class TradingEnvironment:
# Get new state
next_state = self.get_state()
# Update peak balance and drawdown
if self.balance > self.peak_balance:
self.peak_balance = self.balance
current_drawdown = (self.peak_balance - self.balance) / self.peak_balance if self.peak_balance > 0 else 0
self.max_drawdown = max(self.max_drawdown, current_drawdown)
# Create info dictionary
info = {
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
'price': self.current_price,
'balance': self.balance,
'position': self.position,
'pnl': self.total_pnl,
'fees': self.total_fees,
'net_pnl': self.total_pnl - self.total_fees
}
return next_state, reward, done, info
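# Worked example of the leveraged PnL math used in step() above (a sketch, not
# original code): a long entered at 2000 and closed at 2020 with position_size=1900
# and leverage=50 gives
#
#   pnl_percent = (2020 - 2000) / 2000 = 0.01
#   pnl_dollar  = 0.01 * 1900 * 50    = 950.0
#   fee         = (1900 / 1900) * 1   = 1.0   ->  net_pnl = 949.0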
@ -999,6 +1404,33 @@ class TradingEnvironment:
else:
# Small negative reward for holding in the wrong direction
reward -= 0.1
elif action == 1 or action == 2: # BUY or SELL
# Apply trading fee as negative reward (1 USD per 1.9k position size)
position_size = self.calculate_position_size()
fee = (position_size / 1900) * 1 # Trading fee in USD
# Penalty for fee
fee_penalty = fee / 10 # Scale down to make it a reasonable penalty
reward -= fee_penalty
# Logging
if hasattr(self, 'total_fees'):
self.total_fees += fee
else:
self.total_fees = fee
elif action == 3: # CLOSE
# Apply trading fee as negative reward (1 USD per 1.9k position size)
fee = (self.position_size / 1900) * 1 # Trading fee in USD
# Penalty for fee
fee_penalty = fee / 10 # Scale down to make it a reasonable penalty
reward -= fee_penalty
# Logging
if hasattr(self, 'total_fees'):
self.total_fees += fee
else:
self.total_fees = fee
# Add CNN pattern confidence to reward
reward += pattern_confidence * 10
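# Sketch of the cnn_patterns structure these reward terms assume (hypothetical values;
# only the 'long_confidence'/'short_confidence' keys are read in calculate_reward):
#
#   self.cnn_patterns = {
#       'long_confidence': 0.8,   # e.g. opening a long from flat: 0.1 * 0.8 * 10 = 0.8
#       'short_confidence': 0.2,
#   }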
@ -1360,6 +1792,74 @@ class TradingEnvironment:
return False
def add_chart_to_tensorboard(self, writer, step, title='Trading Chart'):
"""
Add a candlestick chart and metrics to TensorBoard
Parameters:
- writer: TensorBoard writer
- step: Current step
- title: Title for the chart
"""
try:
# Initialize writer if not provided
if writer is None:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
# Log basic metrics
writer.add_scalar('Balance', self.balance, step)
writer.add_scalar('Total_PnL', self.total_pnl, step)
# Log total fees if available
if hasattr(self, 'total_fees'):
writer.add_scalar('Total_Fees', self.total_fees, step)
writer.add_scalar('Net_PnL', self.total_pnl - self.total_fees, step)
# Log position info
writer.add_scalar('Position_Size', self.position_size, step)
# Log drawdown and win rate
writer.add_scalar('Max_Drawdown', self.max_drawdown, step)
win_rate = self.win_count / (self.win_count + self.loss_count) if (self.win_count + self.loss_count) > 0 else 0
writer.add_scalar('Win_Rate', win_rate, step)
# Log trade count
writer.add_scalar('Trade_Count', len(self.trades), step)
# Check if we have enough data for candlestick chart
if len(self.data) <= 0:
logger.warning("No data available for candlestick chart")
return
# Create figure for candlestick chart (last 100 data points)
start_idx = max(0, self.current_step - 100)
end_idx = self.current_step
# Get recent trades for visualization (last 10 trades)
recent_trades = self.trades[-10:] if self.trades else []
try:
fig = create_candlestick_figure(
self.data[start_idx:end_idx+1],
title=title,
trades=recent_trades
)
# Add figure to TensorBoard
writer.add_figure('Candlestick_Chart', fig, step)
# Close figure to free memory
plt.close(fig)
except Exception as e:
logger.error(f"Error creating candlestick chart: {e}")
except Exception as e:
logger.error(f"Error adding chart to TensorBoard: {e}")
# Continue execution even if chart fails
# Ensure GPU usage if available
def get_device():
"""Get the best available device (CUDA GPU or CPU)"""
@ -1752,7 +2252,7 @@ class Agent:
raise
def add_chart_to_tensorboard(self, env, step):
"""Add candlestick chart to tensorboard"""
"""Add candlestick chart to tensorboard and various metrics"""
try:
# Initialize writer if it doesn't exist
if not hasattr(self, 'writer') or self.writer is None:
@ -1791,22 +2291,37 @@ class Agent:
if hasattr(env, 'trade_count'):
self.writer.add_scalar('Trading/Trade_Count', env.trade_count, step)
# Log trading fees
if hasattr(env, 'total_fees'):
    self.writer.add_scalar('Trading/Total_Fees', env.total_fees, step)
    # Also log net PnL (after fees)
    if hasattr(env, 'total_pnl'):
        self.writer.add_scalar('Trading/Net_PnL_After_Fees', env.total_pnl - env.total_fees, step)
# Add candlestick chart if we have enough data
if len(env.data) >= 100:
    try:
        # Use the last 100 candles for the chart
        recent_data = env.data[-100:]
        # Get recent trades if available
        recent_trades = None
        if hasattr(env, 'trades') and len(env.trades) > 0:
            recent_trades = env.trades[-10:]  # Last 10 trades
        # Create candlestick figure
        fig = create_candlestick_figure(recent_data, recent_trades, f"Trading Chart - Step {step}")
        if fig:
            # Add to tensorboard
            self.writer.add_figure('Trading/Chart', fig, step)
            # Close figure to free memory
            plt.close(fig)
    except Exception as e:
        logger.warning(f"Error creating candlestick chart: {e}")
except Exception as e:
    logger.error(f"Error in add_chart_to_tensorboard: {e}")
async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
"""Get live price data using websockets"""
@ -1888,12 +2403,15 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
'cumulative_pnl': [],
'drawdowns': [],
'trade_counts': [],
'loss_values': [],
'fees': [], # Track fees
'net_pnl_after_fees': [] # Track net PnL after fees
}
# Track best models
best_reward = float('-inf')
best_pnl = float('-inf')
best_net_pnl = float('-inf') # Track best net PnL (after fees)
# Make directory for models if it doesn't exist
os.makedirs('models', exist_ok=True)
@ -1997,6 +2515,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
# Calculate statistics from this episode
balance = env.balance
pnl = balance - env.initial_balance if hasattr(env, 'initial_balance') else 0
fees = env.total_fees if hasattr(env, 'total_fees') else 0
net_pnl = pnl - fees # Calculate net PnL after fees
# Get trading statistics
trade_analysis = None
@ -2015,6 +2535,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
writer.add_scalar('Reward/episode', episode_reward, episode)
writer.add_scalar('Balance/episode', balance, episode)
writer.add_scalar('PnL/episode', pnl, episode)
writer.add_scalar('NetPnL/episode', net_pnl, episode)
writer.add_scalar('Fees/episode', fees, episode)
writer.add_scalar('WinRate/episode', win_rate, episode)
writer.add_scalar('TradeCount/episode', trade_count, episode)
writer.add_scalar('Drawdown/episode', max_drawdown, episode)
@ -2030,6 +2552,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
stats['drawdowns'].append(max_drawdown)
stats['trade_counts'].append(trade_count)
stats['loss_values'].append(avg_loss)
stats['fees'].append(fees)
stats['net_pnl_after_fees'].append(net_pnl)
# Calculate and update cumulative PnL
if len(stats['episode_pnls']) > 0:
@ -2039,6 +2563,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
stats['cumulative_pnl'].append(cumulative_pnl)
if writer:
writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode)
writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode)
# Save model if this is the best reward or PnL
if episode_reward > best_reward:
@ -2051,6 +2576,12 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
agent.save('models/trading_agent_best_pnl.pt')
logging.info(f"New best PnL: ${best_pnl:.2f}")
# Save model if this is the best net PnL (after fees)
if net_pnl > best_net_pnl:
best_net_pnl = net_pnl
agent.save('models/trading_agent_best_net_pnl.pt')
logging.info(f"New best Net PnL: ${best_net_pnl:.2f}")
# Save checkpoint periodically
if episode % 10 == 0:
agent.save(f'models/trading_agent_checkpoint_{episode}.pt')
@ -2063,6 +2594,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
f"Reward: {episode_reward:.2f} | " +
f"Balance: ${balance:.2f} | " +
f"PnL: ${pnl:.2f} | " +
f"Fees: ${fees:.2f} | " +
f"Net PnL: ${net_pnl:.2f} | " +
f"Win Rate: {win_rate:.2f} | " +
f"Trades: {trade_count} | " +
f"Loss: {avg_loss:.5f} | " +
@ -2608,12 +3141,19 @@ async def main():
# Initialize exchange
exchange = await initialize_exchange()
# Create environment with updated parameters
env = TradingEnvironment(
initial_balance=INITIAL_BALANCE,
window_size=30,
leverage=args.leverage,
exchange_id='mexc',
symbol=args.symbol,
timeframe=args.timeframe
)
if args.mode == 'train':
# Fetch initial data for training
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)
# Create agent with consistent parameters
# Note: Using STATE_SIZE and action_size=4 for consistency
@ -2957,6 +3497,9 @@ async def fetch_multi_timeframe_data(exchange, symbol, candle_cache):
'1d': 86400 # Update every 1 day
}
# TODO: For 1s/tick timeframes, use the exchange's WebSocket API for real-time data
# streaming instead of the REST API (see the WebSocket sketch near TradingEnvironment.__init__).
limits = {
'1s': 1000,
'1m': 1000,