wip - better training loop; realtime scaffold
This commit is contained in:
parent 4de6352468
commit 5e9e6360af
@@ -31,6 +31,7 @@ import matplotlib.gridspec as gridspec
import datetime
from datetime import datetime as dt
from collections import defaultdict
from gym.spaces import Discrete, Box

# Configure logging
logging.basicConfig(
@@ -267,33 +268,50 @@ class PricePredictionModel(nn.Module):
        return total_loss / epochs

class TradingEnvironment:
    def __init__(self, initial_balance=INITIAL_BALANCE, window_size=30, demo=True):
    def __init__(self, data=None, features=None, feature_extractors=None, initial_balance=10000, leverage=50,
                 window_size=100, commission=0.0004, api_key=None, api_secret=None, exchange_id='binance',
                 symbol='ETH/USDT', timeframe='1m', init_length=5000, max_steps=10000):
        """Initialize the trading environment"""
        self.initial_balance = initial_balance
        self.balance = initial_balance
        self.window_size = window_size
        self.demo = demo
        self.api_key = api_key
        self.api_secret = api_secret
        self.exchange_id = exchange_id
        self.symbol = symbol
        self.timeframe = timeframe
        self.init_length = init_length

        # TODO: For 1s/tick timeframes, implement WebSocket API integration for real-time data

        try:
            # Initialize exchange if API credentials are provided
            if api_key and api_secret:
                self.exchange = initialize_exchange(exchange_id, api_key, api_secret)
                logger.info(f"Exchange initialized: {exchange_id}")
                # Fetch historical data
                self.data = fetch_candles(self.exchange, self.symbol, self.timeframe, limit=self.init_length)
                if not self.data:
                    raise ValueError(f"No data fetched for {self.symbol} on {self.exchange_id}")
                self.data_format_is_list = isinstance(self.data[0], list)
                logger.info(f"Loaded {len(self.data)} candles from exchange")
            elif data is not None:  # Use provided data
                self.data = data
                self.data_format_is_list = isinstance(self.data[0], list)
                logger.info(f"Using provided data with {len(self.data)} candles")
            else:
                # Initialize with empty data; we'll load it later with fetch_initial_data
                logger.warning("No data provided, initializing with empty data")
                self.data = []
            self.position = 'flat'  # 'flat', 'long', or 'short'
            self.position_size = 0
            self.entry_price = 0
            self.entry_index = 0
            self.stop_loss = 0
            self.take_profit = 0
            self.trades = []
            self.win_count = 0
            self.loss_count = 0
            self.total_pnl = 0.0
            self.episode_pnl = 0.0
            self.peak_balance = initial_balance
            self.max_drawdown = 0.0
            self.current_step = 0
            self.current_price = 0
            self.data_format_is_list = True
        except Exception as e:
            logger.error(f"Error initializing environment: {e}")
            raise

        # For tracking signals for visualization
        self.trade_signals = []

        # Initialize features
        # Initialize features and feature extractors
        if features is not None:
            self.features = features
            # Create a dictionary of features
            self.features_dict = {f"feature_{i}": feature for i, feature in enumerate(features)}
        else:
            # Initialize features as a dictionary, not a list
            self.features = {
                'price': [],
                'volume': [],
@@ -310,27 +328,193 @@ class TradingEnvironment:
                'ema_21': [],
                'atr': []
            }
            self.features_dict = {}

        # Initialize price predictor
        self.price_predictor = None
        self.predicted_prices = np.array([])
        if feature_extractors is None:
            feature_extractors = []
        self.feature_extractors = feature_extractors

        # Initialize optimal trade tracking
        self.optimal_bottoms = []
        self.optimal_tops = []
        self.optimal_signals = np.array([])
        # Environment parameters
        self.initial_balance = initial_balance
        self.balance = initial_balance
        self.leverage = leverage
        self.position = 'flat'  # 'flat', 'long', or 'short'
        self.position_size = 0
        self.entry_price = 0
        self.entry_index = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.commission = commission
        self.total_pnl = 0
        self.total_fees = 0.0  # Track total fees paid
        self.trades = []
        self.trade_signals = []
        self.current_step = 0
        self.window_size = window_size
        self.max_steps = max_steps
        self.peak_balance = initial_balance
        self.max_drawdown = 0
        self.current_price = 0
        self.win_count = 0
        self.loss_count = 0
        self.min_position_size = 100  # Minimum position size in USD

        # Add these new attributes
        self.leverage = MAX_LEVERAGE
        self.futures_symbol = "ETH_USDT"  # Example futures symbol
        self.position_mode = "hedge"  # For simultaneous long/short positions
        self.margin_mode = "cross"  # Cross margin mode
        # Track candle patterns and reversal points
        self.patterns = {}
        self.reversal_points = []

        # Initialize data format indicator (list or dict)
        self.data_format_is_list = True
        # Define observation and action spaces
        num_features = len(self.features) if hasattr(self, 'features') and self.features else 0
        state_dim = window_size * 5 + 5 + num_features  # OHLCV + position info + features

        self.action_space = Discrete(4)  # 0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE
        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32)

        # Check if we have enough data
        if len(self.data) < self.window_size:
            logger.warning(f"Data length {len(self.data)} is less than window size {self.window_size}")
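Quick check on the spaces defined above: the action space is Discrete(4) and the state vector packs window_size OHLCV rows, five position/account fields, and any extra features, so with the default window_size=100 and no feature extractors state_dim comes out to 505. A small illustrative check follows; the helper name is ours, not the file's:

# Hypothetical helper mirroring the state_dim arithmetic above
def expected_state_dim(window_size: int, num_features: int) -> int:
    return window_size * 5 + 5 + num_features  # OHLCV window + position info + features

assert expected_state_dim(100, 0) == 505  # constructor default
assert expected_state_dim(30, 0) == 155   # the window_size=30 used in main()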
    def calculate_reward(self, action):
        """Calculate reward based on the action taken"""
        reward = 0

        # Base reward structure
        if self.position == 'flat':
            if action == 0:  # HOLD when flat
                reward = 0.01  # Small reward for holding when no position
            elif action == 1:  # BUY/LONG
                # Check for buy signal in CNN patterns
                if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
                    buy_confidence = self.cnn_patterns['long_confidence']
                    # Scale by confidence
                    reward = 0.1 * buy_confidence * 10
                else:
                    reward = 0.1  # Default reward for taking a position

                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
                    reward -= fee_penalty
            elif action == 2:  # SELL/SHORT
                # Check for sell signal in CNN patterns
                if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
                    sell_confidence = self.cnn_patterns['short_confidence']
                    # Scale by confidence
                    reward = 0.1 * sell_confidence * 10
                else:
                    reward = 0.1  # Default reward for taking a position

                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
                    reward -= fee_penalty
            elif action == 3:  # CLOSE when no position
                reward = -0.1  # Penalty for trying to close no position

        elif self.position == 'long':
            if action == 0:  # HOLD long position
                # Calculate price change since entry
                price_change = (self.current_price - self.entry_price) / self.entry_price

                # Reward or penalize based on price movement
                if price_change > 0:
                    reward = price_change * 10  # Reward for holding profitable position
                else:
                    reward = price_change * 5  # Smaller penalty for holding losing position

            elif action == 1:  # BUY when already long
                reward = -0.1  # Penalty for redundant action

            elif action == 2:  # SELL when long (reversal)
                # Calculate PnL
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price

                if pnl_percent > 0:
                    reward = -0.5  # Penalty for closing profitable long position to go short
                else:
                    # Check for sell signal in CNN patterns
                    if hasattr(self, 'cnn_patterns') and 'short_confidence' in self.cnn_patterns:
                        sell_confidence = self.cnn_patterns['short_confidence']
                        reward = 0.2 * sell_confidence * 10  # Reward for correct reversal
                    else:
                        reward = 0.2  # Default reward for cutting loss

                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
                    reward -= fee_penalty

            elif action == 3:  # CLOSE long position
                # Calculate PnL
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price

                if pnl_percent > 0:
                    reward = pnl_percent * 15  # Higher reward for taking profit
                else:
                    reward = pnl_percent * 5  # Smaller penalty for cutting loss

                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
                    reward -= fee_penalty

        elif self.position == 'short':
            if action == 0:  # HOLD short position
                # Calculate price change since entry
                price_change = (self.entry_price - self.current_price) / self.entry_price

                # Reward or penalize based on price movement
                if price_change > 0:
                    reward = price_change * 10  # Reward for holding profitable position
                else:
                    reward = price_change * 5  # Smaller penalty for holding losing position

            elif action == 1:  # BUY when short (reversal)
                # Calculate PnL
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price

                if pnl_percent > 0:
                    reward = -0.5  # Penalty for closing profitable short position to go long
                else:
                    # Check for buy signal in CNN patterns
                    if hasattr(self, 'cnn_patterns') and 'long_confidence' in self.cnn_patterns:
                        buy_confidence = self.cnn_patterns['long_confidence']
                        reward = 0.2 * buy_confidence * 10  # Reward for correct reversal
                    else:
                        reward = 0.2  # Default reward for cutting loss

                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
                    reward -= fee_penalty

            elif action == 2:  # SELL when already short
                reward = -0.1  # Penalty for redundant action

            elif action == 3:  # CLOSE short position
                # Calculate PnL
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price

                if pnl_percent > 0:
                    reward = pnl_percent * 15  # Higher reward for taking profit
                else:
                    reward = pnl_percent * 5  # Smaller penalty for cutting loss

                # Apply fee penalty
                if self.position_size > 0:
                    fee = (self.position_size / 1900) * 1
                    fee_penalty = min(0.05, fee / 100)  # Scale fee to a small penalty, max 0.05
                    reward -= fee_penalty

        return reward
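The fee penalty repeated in the branches above encodes a flat cost of 1 USD per 1,900 USD of position size (about 0.053% of notional, separate from the commission=0.0004 constructor argument), capped at a 0.05 reward penalty. A sketch that centralizes the arithmetic; flat_fee is a hypothetical name, not a function in this file:

def flat_fee(position_size_usd: float) -> float:
    # 1 USD of fee per 1,900 USD of notional, as used in calculate_reward and step
    return (position_size_usd / 1900) * 1

fee = flat_fee(3800)                # 2.0 USD on a 3,800 USD position
fee_penalty = min(0.05, fee / 100)  # capped at 0.05, here 0.02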
    def reset(self):
        """Reset the environment to initial state"""
        """Reset the environment to its initial state and return the initial observation"""
        self.balance = self.initial_balance
        self.position = 'flat'
        self.position_size = 0
@@ -338,24 +522,15 @@ class TradingEnvironment:
        self.entry_index = 0
        self.stop_loss = 0
        self.take_profit = 0
        self.current_step = 0
        self.trades = []
        self.win_count = 0
        self.loss_count = 0
        self.episode_pnl = 0.0
        self.trade_signals = []
        self.total_pnl = 0.0
        self.total_fees = 0.0
        self.peak_balance = self.initial_balance
        self.max_drawdown = 0.0
        self.current_step = 0

        # Keep data but reset current position
        if len(self.data) > self.window_size:
            self.current_step = self.window_size
            if self.data_format_is_list:
                self.current_price = self.data[self.current_step][4]  # Close price is at index 4
            else:
                self.current_price = self.data[self.current_step]['close']

        # Reset trade signals
        self.trade_signals = []
        self.win_count = 0
        self.loss_count = 0

        return self.get_state()
@@ -492,6 +667,227 @@ class TradingEnvironment:
        # Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
        reward = self.calculate_reward(action)

        # Execute the action
        initial_balance = self.balance  # Store initial balance to calculate PnL

        # Open long position
        if action == 1 and self.position != 'long':
            if self.position == 'short':
                # Close short position first
                if self.position_size > 0:
                    # Calculate PnL
                    pnl_percent = (self.entry_price - self.current_price) / self.entry_price
                    pnl_dollar = pnl_percent * self.position_size * self.leverage

                    # Update balance and record trade
                    self.balance += pnl_dollar
                    self.total_pnl += pnl_dollar

                    # Apply trading fee (1 USD per 1.9k position)
                    fee = (self.position_size / 1900) * 1
                    self.balance -= fee
                    self.total_fees += fee

                    # Record trade
                    trade_duration = self.current_step - self.entry_index
                    if self.data_format_is_list:
                        timestamp = self.data[self.current_step][0]  # Timestamp
                    else:
                        timestamp = self.data[self.current_step]['timestamp']

                    self.trades.append({
                        'type': 'short',
                        'entry': self.entry_price,
                        'exit': self.current_price,
                        'pnl_percent': pnl_percent,
                        'pnl_dollar': pnl_dollar,
                        'fee': fee,
                        'net_pnl': pnl_dollar - fee,
                        'duration': trade_duration,
                        'timestamp': timestamp,
                        'reason': 'action_change'
                    })

                    # Update win/loss count
                    if pnl_dollar > 0:
                        self.win_count += 1
                    else:
                        self.loss_count += 1

            # Now open long position
            self.position = 'long'
            self.entry_price = self.current_price
            self.entry_index = self.current_step

            # Calculate position size with risk management
            self.position_size = self.calculate_position_size()

            # Apply trading fee (1 USD per 1.9k position)
            fee = (self.position_size / 1900) * 1
            self.balance -= fee
            self.total_fees += fee

            # Set stop loss and take profit
            sl_percent = 0.02  # 2% stop loss
            tp_percent = 0.04  # 4% take profit

            self.stop_loss = self.entry_price * (1 - sl_percent)
            self.take_profit = self.entry_price * (1 + tp_percent)

        # Open short position
        elif action == 2 and self.position != 'short':
            if self.position == 'long':
                # Close long position first
                if self.position_size > 0:
                    # Calculate PnL
                    pnl_percent = (self.current_price - self.entry_price) / self.entry_price
                    pnl_dollar = pnl_percent * self.position_size * self.leverage

                    # Update balance and record trade
                    self.balance += pnl_dollar
                    self.total_pnl += pnl_dollar

                    # Apply trading fee (1 USD per 1.9k position)
                    fee = (self.position_size / 1900) * 1
                    self.balance -= fee
                    self.total_fees += fee

                    # Record trade
                    trade_duration = self.current_step - self.entry_index
                    if self.data_format_is_list:
                        timestamp = self.data[self.current_step][0]  # Timestamp
                    else:
                        timestamp = self.data[self.current_step]['timestamp']

                    self.trades.append({
                        'type': 'long',
                        'entry': self.entry_price,
                        'exit': self.current_price,
                        'pnl_percent': pnl_percent,
                        'pnl_dollar': pnl_dollar,
                        'fee': fee,
                        'net_pnl': pnl_dollar - fee,
                        'duration': trade_duration,
                        'timestamp': timestamp,
                        'reason': 'action_change'
                    })

                    # Update win/loss count
                    if pnl_dollar > 0:
                        self.win_count += 1
                    else:
                        self.loss_count += 1

            # Now open short position
            self.position = 'short'
            self.entry_price = self.current_price
            self.entry_index = self.current_step

            # Calculate position size with risk management
            self.position_size = self.calculate_position_size()

            # Apply trading fee (1 USD per 1.9k position)
            fee = (self.position_size / 1900) * 1
            self.balance -= fee
            self.total_fees += fee

            # Set stop loss and take profit
            sl_percent = 0.02  # 2% stop loss
            tp_percent = 0.04  # 4% take profit

            self.stop_loss = self.entry_price * (1 + sl_percent)
            self.take_profit = self.entry_price * (1 - tp_percent)

        # Close position
        elif action == 3 and self.position != 'flat':
            if self.position == 'long':
                # Calculate PnL
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price
                pnl_dollar = pnl_percent * self.position_size * self.leverage

                # Update balance and record trade
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar

                # Apply trading fee (1 USD per 1.9k position)
                fee = (self.position_size / 1900) * 1
                self.balance -= fee
                self.total_fees += fee

                # Record trade
                trade_duration = self.current_step - self.entry_index
                if self.data_format_is_list:
                    timestamp = self.data[self.current_step][0]  # Timestamp
                else:
                    timestamp = self.data[self.current_step]['timestamp']

                self.trades.append({
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'fee': fee,
                    'net_pnl': pnl_dollar - fee,
                    'duration': trade_duration,
                    'timestamp': timestamp,
                    'reason': 'close_action'
                })

                # Update win/loss count
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1

            elif self.position == 'short':
                # Calculate PnL
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price
                pnl_dollar = pnl_percent * self.position_size * self.leverage

                # Update balance and record trade
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar

                # Apply trading fee (1 USD per 1.9k position)
                fee = (self.position_size / 1900) * 1
                self.balance -= fee
                self.total_fees += fee

                # Record trade
                trade_duration = self.current_step - self.entry_index
                if self.data_format_is_list:
                    timestamp = self.data[self.current_step][0]  # Timestamp
                else:
                    timestamp = self.data[self.current_step]['timestamp']

                self.trades.append({
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar,
                    'fee': fee,
                    'net_pnl': pnl_dollar - fee,
                    'duration': trade_duration,
                    'timestamp': timestamp,
                    'reason': 'close_action'
                })

                # Update win/loss count
                if pnl_dollar > 0:
                    self.win_count += 1
                else:
                    self.loss_count += 1

            # Reset position
            self.position = 'flat'
            self.position_size = 0
            self.entry_price = 0
            self.entry_index = 0
            self.stop_loss = 0
            self.take_profit = 0

        # Record trade signal for visualization
        if action > 0:  # If not HOLD
            signal_type = None
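For reference, the position handling above turns a percentage price move into dollar PnL by scaling with notional and leverage, and places stop-loss/take-profit bands 2% and 4% away from the entry (mirrored for shorts). A worked example with made-up numbers under the defaults:

entry_price = 2000.0
position_size = 100.0  # USD notional
leverage = 50

# Short closed at 1980: a 1% move in the trade's favor
pnl_percent = (entry_price - 1980.0) / entry_price   # 0.01
pnl_dollar = pnl_percent * position_size * leverage  # 50.0 USD
fee = (position_size / 1900) * 1                     # ~0.05 USD

# Long protective levels: 2% below and 4% above entry
stop_loss = entry_price * (1 - 0.02)    # 1960.0
take_profit = entry_price * (1 + 0.04)  # 2080.0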
@@ -529,13 +925,22 @@ class TradingEnvironment:
        # Get new state
        next_state = self.get_state()

        # Update peak balance and drawdown
        if self.balance > self.peak_balance:
            self.peak_balance = self.balance

        current_drawdown = (self.peak_balance - self.balance) / self.peak_balance if self.peak_balance > 0 else 0
        self.max_drawdown = max(self.max_drawdown, current_drawdown)

        # Create info dictionary
        info = {
            'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
            'price': self.current_price,
            'balance': self.balance,
            'position': self.position,
            'pnl': self.total_pnl
            'pnl': self.total_pnl,
            'fees': self.total_fees,
            'net_pnl': self.total_pnl - self.total_fees
        }

        return next_state, reward, done, info
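step() now returns the usual (next_state, reward, done, info) tuple, with fee-aware fields in info, so a driver can run the environment gym-style. A minimal sketch, assuming candles is a pre-fetched OHLCV list (for example from fetch_candles) and using a random policy as a stand-in for the agent:

import random

env = TradingEnvironment(data=candles, window_size=100)  # candles: pre-fetched OHLCV list (assumption)
state = env.reset()
done = False
while not done:
    action = random.randrange(4)  # 0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE
    state, reward, done, info = env.step(action)
print(info['net_pnl'], env.max_drawdown)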
@@ -999,6 +1404,33 @@ class TradingEnvironment:
            else:
                # Small negative reward for holding in the wrong direction
                reward -= 0.1
        elif action == 1 or action == 2:  # BUY or SELL
            # Apply trading fee as negative reward (1 USD per 1.9k position size)
            position_size = self.calculate_position_size()
            fee = (position_size / 1900) * 1  # Trading fee in USD

            # Penalty for fee
            fee_penalty = fee / 10  # Scale down to make it a reasonable penalty
            reward -= fee_penalty

            # Logging
            if hasattr(self, 'total_fees'):
                self.total_fees += fee
            else:
                self.total_fees = fee
        elif action == 3:  # CLOSE
            # Apply trading fee as negative reward (1 USD per 1.9k position size)
            fee = (self.position_size / 1900) * 1  # Trading fee in USD

            # Penalty for fee
            fee_penalty = fee / 10  # Scale down to make it a reasonable penalty
            reward -= fee_penalty

            # Logging
            if hasattr(self, 'total_fees'):
                self.total_fees += fee
            else:
                self.total_fees = fee

        # Add CNN pattern confidence to reward
        reward += pattern_confidence * 10
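Note that this reward path scales the fee penalty as fee / 10, while the calculate_reward branches earlier use min(0.05, fee / 100); for the same position the two penalties differ by an order of magnitude. Illustrative numbers:

fee = (1900 / 1900) * 1                 # 1.0 USD on a 1,900 USD position
penalty_here = fee / 10                 # 0.10
penalty_earlier = min(0.05, fee / 100)  # 0.01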
@@ -1360,6 +1792,74 @@ class TradingEnvironment:

        return False

    def add_chart_to_tensorboard(self, writer, step, title='Trading Chart'):
        """
        Add a candlestick chart and metrics to TensorBoard

        Parameters:
        - writer: TensorBoard writer
        - step: Current step
        - title: Title for the chart
        """
        try:
            # Initialize writer if not provided
            if writer is None:
                from torch.utils.tensorboard import SummaryWriter
                writer = SummaryWriter()

            # Log basic metrics
            writer.add_scalar('Balance', self.balance, step)
            writer.add_scalar('Total_PnL', self.total_pnl, step)

            # Log total fees if available
            if hasattr(self, 'total_fees'):
                writer.add_scalar('Total_Fees', self.total_fees, step)
                writer.add_scalar('Net_PnL', self.total_pnl - self.total_fees, step)

            # Log position info
            writer.add_scalar('Position_Size', self.position_size, step)

            # Log drawdown and win rate
            writer.add_scalar('Max_Drawdown', self.max_drawdown, step)

            win_rate = self.win_count / (self.win_count + self.loss_count) if (self.win_count + self.loss_count) > 0 else 0
            writer.add_scalar('Win_Rate', win_rate, step)

            # Log trade count
            writer.add_scalar('Trade_Count', len(self.trades), step)

            # Check if we have enough data for candlestick chart
            if len(self.data) <= 0:
                logger.warning("No data available for candlestick chart")
                return

            # Create figure for candlestick chart (last 100 data points)
            start_idx = max(0, self.current_step - 100)
            end_idx = self.current_step

            # Get recent trades for visualization (last 10 trades)
            recent_trades = self.trades[-10:] if self.trades else []

            try:
                fig = create_candlestick_figure(
                    self.data[start_idx:end_idx+1],
                    title=title,
                    trades=recent_trades
                )

                # Add figure to TensorBoard
                writer.add_figure('Candlestick_Chart', fig, step)

                # Close figure to free memory
                plt.close(fig)

            except Exception as e:
                logger.error(f"Error creating candlestick chart: {e}")

        except Exception as e:
            logger.error(f"Error adding chart to TensorBoard: {e}")
            # Continue execution even if chart fails

# Ensure GPU usage if available
def get_device():
    """Get the best available device (CUDA GPU or CPU)"""
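The hunk ends right after the get_device docstring. A typical body for such a helper is sketched below for context; this is a guess, not necessarily what the file actually does:

import torch

def get_device():
    """Get the best available device (CUDA GPU or CPU)"""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")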
@@ -1752,7 +2252,7 @@ class Agent:
            raise

    def add_chart_to_tensorboard(self, env, step):
        """Add candlestick chart to tensorboard"""
        """Add candlestick chart to tensorboard and various metrics"""
        try:
            # Initialize writer if it doesn't exist
            if not hasattr(self, 'writer') or self.writer is None:
@@ -1791,22 +2291,37 @@ class Agent:
            if hasattr(env, 'trade_count'):
                self.writer.add_scalar('Trading/Trade_Count', env.trade_count, step)

            # Log trading fees
            if hasattr(env, 'total_fees'):
                self.writer.add_scalar('Trading/Total_Fees', env.total_fees, step)
                # Also log net PnL (after fees)
                if hasattr(env, 'total_pnl'):
                    self.writer.add_scalar('Trading/Net_PnL_After_Fees', env.total_pnl - env.total_fees, step)

            # Add candlestick chart if we have enough data
            if len(env.data) >= 100:
                try:
                    # Use the last 100 candles for the chart
                    recent_data = env.data[-100:]

                    # Get recent trades if available
                    recent_trades = []
                    if hasattr(env, 'trades') and env.trades:
                    recent_trades = None
                    if hasattr(env, 'trades') and len(env.trades) > 0:
                        recent_trades = env.trades[-10:]  # Last 10 trades

                    # Create candlestick figure with the last 100 candles and recent trades
                    fig = create_candlestick_figure(env.data[-100:], recent_trades)
                    # Create candlestick figure
                    fig = create_candlestick_figure(recent_data, recent_trades, f"Trading Chart - Step {step}")

                    # Add figure to tensorboard
                    if fig:
                        # Add to tensorboard
                        self.writer.add_figure('Trading/Chart', fig, step)

                        # Close figure to free resources
                        # Close figure to free memory
                        plt.close(fig)

                except Exception as e:
                    logger.warning(f"Error adding chart to tensorboard: {e}")
                    logger.warning(f"Error creating candlestick chart: {e}")
        except Exception as e:
            logger.error(f"Error in add_chart_to_tensorboard: {e}")

async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
    """Get live price data using websockets"""
@@ -1888,12 +2403,15 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
        'cumulative_pnl': [],
        'drawdowns': [],
        'trade_counts': [],
        'loss_values': []
        'loss_values': [],
        'fees': [],  # Track fees
        'net_pnl_after_fees': []  # Track net PnL after fees
    }

    # Track best models
    best_reward = float('-inf')
    best_pnl = float('-inf')
    best_net_pnl = float('-inf')  # Track best net PnL (after fees)

    # Make directory for models if it doesn't exist
    os.makedirs('models', exist_ok=True)
@@ -1997,6 +2515,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
        # Calculate statistics from this episode
        balance = env.balance
        pnl = balance - env.initial_balance if hasattr(env, 'initial_balance') else 0
        fees = env.total_fees if hasattr(env, 'total_fees') else 0
        net_pnl = pnl - fees  # Calculate net PnL after fees

        # Get trading statistics
        trade_analysis = None
@@ -2015,6 +2535,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
            writer.add_scalar('Reward/episode', episode_reward, episode)
            writer.add_scalar('Balance/episode', balance, episode)
            writer.add_scalar('PnL/episode', pnl, episode)
            writer.add_scalar('NetPnL/episode', net_pnl, episode)
            writer.add_scalar('Fees/episode', fees, episode)
            writer.add_scalar('WinRate/episode', win_rate, episode)
            writer.add_scalar('TradeCount/episode', trade_count, episode)
            writer.add_scalar('Drawdown/episode', max_drawdown, episode)
@@ -2030,6 +2552,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
        stats['drawdowns'].append(max_drawdown)
        stats['trade_counts'].append(trade_count)
        stats['loss_values'].append(avg_loss)
        stats['fees'].append(fees)
        stats['net_pnl_after_fees'].append(net_pnl)

        # Calculate and update cumulative PnL
        if len(stats['episode_pnls']) > 0:
@@ -2039,6 +2563,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
            stats['cumulative_pnl'].append(cumulative_pnl)
            if writer:
                writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode)
                writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode)

        # Save model if this is the best reward or PnL
        if episode_reward > best_reward:
@@ -2051,6 +2576,12 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
            agent.save('models/trading_agent_best_pnl.pt')
            logging.info(f"New best PnL: ${best_pnl:.2f}")

        # Save model if this is the best net PnL (after fees)
        if net_pnl > best_net_pnl:
            best_net_pnl = net_pnl
            agent.save('models/trading_agent_best_net_pnl.pt')
            logging.info(f"New best Net PnL: ${best_net_pnl:.2f}")

        # Save checkpoint periodically
        if episode % 10 == 0:
            agent.save(f'models/trading_agent_checkpoint_{episode}.pt')
@@ -2063,6 +2594,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
            f"Reward: {episode_reward:.2f} | " +
            f"Balance: ${balance:.2f} | " +
            f"PnL: ${pnl:.2f} | " +
            f"Fees: ${fees:.2f} | " +
            f"Net PnL: ${net_pnl:.2f} | " +
            f"Win Rate: {win_rate:.2f} | " +
            f"Trades: {trade_count} | " +
            f"Loss: {avg_loss:.5f} | " +
@@ -2608,12 +3141,19 @@ async def main():
    # Initialize exchange
    exchange = await initialize_exchange()

    # Create environment
    env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)
    # Create environment with updated parameters
    env = TradingEnvironment(
        initial_balance=INITIAL_BALANCE,
        window_size=30,
        leverage=args.leverage,
        exchange_id='mexc',
        symbol=args.symbol,
        timeframe=args.timeframe
    )

    if args.mode == 'train':
        # Fetch initial data for training
        await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
        await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)

        # Create agent with consistent parameters
        # Note: Using STATE_SIZE and action_size=4 for consistency
@@ -2957,6 +3497,9 @@ async def fetch_multi_timeframe_data(exchange, symbol, candle_cache):
        '1d': 86400  # Update every 1 day
    }

    # TODO: For 1s/tick timeframes, we'll need to use the exchange's WebSocket API
    # for real-time data streaming instead of REST API. Implement this in the future.

    limits = {
        '1s': 1000,
        '1m': 1000,
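Both TODOs in this commit point at streaming sub-minute/tick data over the exchange's WebSocket API instead of REST polling. A minimal scaffold using the websockets library is sketched below; the URL and payload fields follow Binance's public kline stream and would need adapting for MEXC or another exchange:

import json
import websockets

async def stream_klines(symbol="ethusdt", interval="1m"):
    # Binance-style public kline stream (assumption); other exchanges use different URLs and payloads.
    url = f"wss://stream.binance.com:9443/ws/{symbol}@kline_{interval}"
    async with websockets.connect(url) as ws:
        while True:
            msg = json.loads(await ws.recv())
            k = msg.get("k", {})
            if k.get("x"):  # candle closed
                yield [k["t"], float(k["o"]), float(k["h"]), float(k["l"]), float(k["c"]), float(k["v"])]

# Usage: async for candle in stream_klines(): append it to env.data / the candle cache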