wip - better training loop; realtime scaffold
This commit is contained in:
parent
4de6352468
commit
5e9e6360af
@ -31,6 +31,7 @@ import matplotlib.gridspec as gridspec
|
|||||||
import datetime
|
import datetime
|
||||||
from datetime import datetime as dt
|
from datetime import datetime as dt
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from gym.spaces import Discrete, Box
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@ -267,70 +268,253 @@ class PricePredictionModel(nn.Module):
|
|||||||
return total_loss / epochs
|
return total_loss / epochs
|
||||||
|
|
||||||
class TradingEnvironment:
|
class TradingEnvironment:
|
||||||
def __init__(self, data=None, features=None, feature_extractors=None, initial_balance=10000, leverage=50,
             window_size=100, commission=0.0004, api_key=None, api_secret=None, exchange_id='binance',
             symbol='ETH/USDT', timeframe='1m', init_length=5000, max_steps=10000):
    """Initialize the trading environment.

    Candle data comes from one of three places, checked in this order:
    the exchange (when both ``api_key`` and ``api_secret`` are given),
    the ``data`` argument, or an empty list to be filled in later.

    Args:
        data: Pre-loaded candles, indexable by position; only used when no
            API credentials are supplied. May be empty.
        features: Optional sequence of precomputed features. When omitted,
            a dict of empty per-indicator lists is created instead.
        feature_extractors: Optional list of extractors; defaults to ``[]``.
        initial_balance: Starting account balance.
        leverage: Multiplier stored on ``self.leverage``.
        window_size: Number of candles in one observation window.
        commission: Commission rate stored on ``self.commission``.
        api_key: Exchange API key; fetching requires the secret as well.
        api_secret: Exchange API secret.
        exchange_id: Exchange identifier passed to ``initialize_exchange``.
        symbol: Market symbol passed to ``fetch_candles``.
        timeframe: Candle timeframe passed to ``fetch_candles``.
        init_length: Number of candles requested from the exchange.
        max_steps: Stored on ``self.max_steps``.

    Raises:
        ValueError: If the exchange returns no candles.
        Exception: Any error raised while loading data is logged and
            re-raised unchanged.
    """
    self.api_key = api_key
    self.api_secret = api_secret
    self.exchange_id = exchange_id
    self.symbol = symbol
    self.timeframe = timeframe
    self.init_length = init_length

    # TODO: For 1s/ticks timeframes, implement WebSocket API integration for real-time data

    try:
        if api_key and api_secret:
            # Credentials supplied: pull the initial history from the exchange.
            self.exchange = initialize_exchange(exchange_id, api_key, api_secret)
            logger.info(f"Exchange initialized: {exchange_id}")
            self.data = fetch_candles(self.exchange, self.symbol, self.timeframe, limit=self.init_length)
            if not self.data:
                raise ValueError(f"No data fetched for {self.symbol} on {self.exchange_id}")
            self.data_format_is_list = isinstance(self.data[0], list)
            logger.info(f"Loaded {len(self.data)} candles from exchange")
        elif data is not None:
            # Caller-provided data. An empty container is treated like the
            # "no data yet" case below instead of failing on data[0].
            self.data = data
            self.data_format_is_list = len(self.data) == 0 or isinstance(self.data[0], list)
            logger.info(f"Using provided data with {len(self.data)} candles")
        else:
            # Initialize with empty data, we'll load it later with fetch_initial_data
            logger.warning("No data provided, initializing with empty data")
            self.data = []
            self.data_format_is_list = True
    except Exception as e:
        logger.error(f"Error initializing environment: {e}")
        raise

    # Features: either the caller's sequence (also exposed by index-based
    # name in features_dict) or a dict of empty per-indicator series.
    if features is not None:
        self.features = features
        self.features_dict = {f"feature_{i}": feature for i, feature in enumerate(features)}
    else:
        self.features = {
            'price': [],
            'volume': [],
            'rsi': [],
            'macd': [],
            'macd_signal': [],
            'macd_hist': [],
            'bollinger_upper': [],
            'bollinger_mid': [],
            'bollinger_lower': [],
            'stoch_k': [],
            'stoch_d': [],
            'ema_9': [],
            'ema_21': [],
            'atr': [],
        }
        self.features_dict = {}

    if feature_extractors is None:
        feature_extractors = []
    self.feature_extractors = feature_extractors

    # Account and position state
    self.initial_balance = initial_balance
    self.balance = initial_balance
    self.leverage = leverage
    self.position = 'flat'  # 'flat', 'long', or 'short'
    self.position_size = 0
    self.entry_price = 0
    self.entry_index = 0
    self.stop_loss = 0
    self.take_profit = 0
    self.commission = commission
    self.total_pnl = 0
    self.total_fees = 0.0  # Track total fees paid
    self.trades = []
    self.trade_signals = []
    self.current_step = 0
    self.window_size = window_size
    self.max_steps = max_steps
    self.peak_balance = initial_balance
    self.max_drawdown = 0
    self.current_price = 0
    self.win_count = 0
    self.loss_count = 0
    self.min_position_size = 100  # Minimum position size in USD

    # Track candle patterns and reversal points
    self.patterns = {}
    self.reversal_points = []

    # Observation/action spaces. self.features is always set above, so its
    # length can be taken directly (14 keys for the default dict).
    num_features = len(self.features)
    state_dim = window_size * 5 + 5 + num_features  # OHLCV + position info + features

    self.action_space = Discrete(4)  # 0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE
    self.observation_space = Box(low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32)

    # Warn (but do not fail) when there is not yet a full window of candles.
    if len(self.data) < self.window_size:
        logger.warning(f"Data length {len(self.data)} is less than window size {self.window_size}")
|
def calculate_reward(self, action):
    """Calculate the shaping reward for ``action`` given the current position.

    Reads ``self.position``, ``self.position_size``, ``self.entry_price``,
    ``self.current_price`` and, if present, ``self.cnn_patterns``; it does
    not modify any state.

    Args:
        action: 0 = HOLD, 1 = BUY/LONG, 2 = SELL/SHORT, 3 = CLOSE.

    Returns:
        The reward as a number. Unknown actions or positions yield 0.
    """
    # Fee penalty: the fee is 1 USD per 1900 of position size, scaled by
    # 1/100 and capped at 0.05. It only applies while a position size is set.
    position_size = self.position_size
    fee_penalty = min(0.05, (position_size / 1900) / 100) if position_size > 0 else 0.0

    # A missing or None cnn_patterns simply means "no confidence available".
    patterns = getattr(self, 'cnn_patterns', None) or {}

    if self.position == 'flat':
        if action == 0:
            return 0.01  # Small reward for holding when no position
        if action == 1 or action == 2:
            key = 'long_confidence' if action == 1 else 'short_confidence'
            # Scale by confidence when available, else default entry reward
            reward = 0.1 * patterns[key] * 10 if key in patterns else 0.1
            return reward - fee_penalty
        if action == 3:
            return -0.1  # Penalty for trying to close no position
        return 0

    if self.position != 'long' and self.position != 'short':
        return 0

    is_long = self.position == 'long'
    entry_price = self.entry_price
    if is_long:
        price_delta = self.current_price - entry_price
    else:
        price_delta = entry_price - self.current_price
    # An unset (zero) entry price gives no meaningful PnL; treat it as
    # break-even instead of raising ZeroDivisionError mid-episode.
    pnl_percent = price_delta / entry_price if entry_price else 0.0

    same_side_action = 1 if is_long else 2
    reversal_action = 2 if is_long else 1

    if action == 0:
        # Holding: profits are rewarded at 10x, losses penalized at only 5x
        return pnl_percent * 10 if pnl_percent > 0 else pnl_percent * 5

    if action == same_side_action:
        return -0.1  # Penalty for redundant action

    if action == reversal_action:
        reversal_key = 'short_confidence' if is_long else 'long_confidence'
        if pnl_percent > 0:
            reward = -0.5  # Penalty for flipping out of a profitable position
        elif reversal_key in patterns:
            reward = 0.2 * patterns[reversal_key] * 10  # Reward for correct reversal
        else:
            reward = 0.2  # Default reward for cutting loss
        return reward - fee_penalty

    if action == 3:
        # Closing: taking profit is rewarded at 15x, cutting a loss costs 5x
        reward = pnl_percent * 15 if pnl_percent > 0 else pnl_percent * 5
        return reward - fee_penalty

    return 0
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset the environment to initial state"""
|
"""Reset the environment to its initial state and return the initial observation"""
|
||||||
self.balance = self.initial_balance
|
self.balance = self.initial_balance
|
||||||
self.position = 'flat'
|
self.position = 'flat'
|
||||||
self.position_size = 0
|
self.position_size = 0
|
||||||
@ -338,24 +522,15 @@ class TradingEnvironment:
|
|||||||
self.entry_index = 0
|
self.entry_index = 0
|
||||||
self.stop_loss = 0
|
self.stop_loss = 0
|
||||||
self.take_profit = 0
|
self.take_profit = 0
|
||||||
|
self.current_step = 0
|
||||||
self.trades = []
|
self.trades = []
|
||||||
self.win_count = 0
|
self.trade_signals = []
|
||||||
self.loss_count = 0
|
self.total_pnl = 0.0
|
||||||
self.episode_pnl = 0.0
|
self.total_fees = 0.0
|
||||||
self.peak_balance = self.initial_balance
|
self.peak_balance = self.initial_balance
|
||||||
self.max_drawdown = 0.0
|
self.max_drawdown = 0.0
|
||||||
self.current_step = 0
|
self.win_count = 0
|
||||||
|
self.loss_count = 0
|
||||||
# Keep data but reset current position
|
|
||||||
if len(self.data) > self.window_size:
|
|
||||||
self.current_step = self.window_size
|
|
||||||
if self.data_format_is_list:
|
|
||||||
self.current_price = self.data[self.current_step][4] # Close price is at index 4
|
|
||||||
else:
|
|
||||||
self.current_price = self.data[self.current_step]['close']
|
|
||||||
|
|
||||||
# Reset trade signals
|
|
||||||
self.trade_signals = []
|
|
||||||
|
|
||||||
return self.get_state()
|
return self.get_state()
|
||||||
|
|
||||||
@ -492,6 +667,227 @@ class TradingEnvironment:
|
|||||||
# Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
|
# Process action (0: HOLD, 1: BUY/LONG, 2: SELL/SHORT, 3: CLOSE)
|
||||||
reward = self.calculate_reward(action)
|
reward = self.calculate_reward(action)
|
||||||
|
|
||||||
|
# Execute the action
|
||||||
|
initial_balance = self.balance # Store initial balance to calculate PnL
|
||||||
|
|
||||||
|
# Open long position
|
||||||
|
if action == 1 and self.position != 'long':
|
||||||
|
if self.position == 'short':
|
||||||
|
# Close short position first
|
||||||
|
if self.position_size > 0:
|
||||||
|
# Calculate PnL
|
||||||
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price
|
||||||
|
pnl_dollar = pnl_percent * self.position_size * self.leverage
|
||||||
|
|
||||||
|
# Update balance and record trade
|
||||||
|
self.balance += pnl_dollar
|
||||||
|
self.total_pnl += pnl_dollar
|
||||||
|
|
||||||
|
# Apply trading fee (1 USD per 1.9k position)
|
||||||
|
fee = (self.position_size / 1900) * 1
|
||||||
|
self.balance -= fee
|
||||||
|
self.total_fees += fee
|
||||||
|
|
||||||
|
# Record trade
|
||||||
|
trade_duration = self.current_step - self.entry_index
|
||||||
|
if self.data_format_is_list:
|
||||||
|
timestamp = self.data[self.current_step][0] # Timestamp
|
||||||
|
else:
|
||||||
|
timestamp = self.data[self.current_step]['timestamp']
|
||||||
|
|
||||||
|
self.trades.append({
|
||||||
|
'type': 'short',
|
||||||
|
'entry': self.entry_price,
|
||||||
|
'exit': self.current_price,
|
||||||
|
'pnl_percent': pnl_percent,
|
||||||
|
'pnl_dollar': pnl_dollar,
|
||||||
|
'fee': fee,
|
||||||
|
'net_pnl': pnl_dollar - fee,
|
||||||
|
'duration': trade_duration,
|
||||||
|
'timestamp': timestamp,
|
||||||
|
'reason': 'action_change'
|
||||||
|
})
|
||||||
|
|
||||||
|
# Update win/loss count
|
||||||
|
if pnl_dollar > 0:
|
||||||
|
self.win_count += 1
|
||||||
|
else:
|
||||||
|
self.loss_count += 1
|
||||||
|
|
||||||
|
# Now open long position
|
||||||
|
self.position = 'long'
|
||||||
|
self.entry_price = self.current_price
|
||||||
|
self.entry_index = self.current_step
|
||||||
|
|
||||||
|
# Calculate position size with risk management
|
||||||
|
self.position_size = self.calculate_position_size()
|
||||||
|
|
||||||
|
# Apply trading fee (1 USD per 1.9k position)
|
||||||
|
fee = (self.position_size / 1900) * 1
|
||||||
|
self.balance -= fee
|
||||||
|
self.total_fees += fee
|
||||||
|
|
||||||
|
# Set stop loss and take profit
|
||||||
|
sl_percent = 0.02 # 2% stop loss
|
||||||
|
tp_percent = 0.04 # 4% take profit
|
||||||
|
|
||||||
|
self.stop_loss = self.entry_price * (1 - sl_percent)
|
||||||
|
self.take_profit = self.entry_price * (1 + tp_percent)
|
||||||
|
|
||||||
|
# Open short position
|
||||||
|
elif action == 2 and self.position != 'short':
|
||||||
|
if self.position == 'long':
|
||||||
|
# Close long position first
|
||||||
|
if self.position_size > 0:
|
||||||
|
# Calculate PnL
|
||||||
|
pnl_percent = (self.current_price - self.entry_price) / self.entry_price
|
||||||
|
pnl_dollar = pnl_percent * self.position_size * self.leverage
|
||||||
|
|
||||||
|
# Update balance and record trade
|
||||||
|
self.balance += pnl_dollar
|
||||||
|
self.total_pnl += pnl_dollar
|
||||||
|
|
||||||
|
# Apply trading fee (1 USD per 1.9k position)
|
||||||
|
fee = (self.position_size / 1900) * 1
|
||||||
|
self.balance -= fee
|
||||||
|
self.total_fees += fee
|
||||||
|
|
||||||
|
# Record trade
|
||||||
|
trade_duration = self.current_step - self.entry_index
|
||||||
|
if self.data_format_is_list:
|
||||||
|
timestamp = self.data[self.current_step][0] # Timestamp
|
||||||
|
else:
|
||||||
|
timestamp = self.data[self.current_step]['timestamp']
|
||||||
|
|
||||||
|
self.trades.append({
|
||||||
|
'type': 'long',
|
||||||
|
'entry': self.entry_price,
|
||||||
|
'exit': self.current_price,
|
||||||
|
'pnl_percent': pnl_percent,
|
||||||
|
'pnl_dollar': pnl_dollar,
|
||||||
|
'fee': fee,
|
||||||
|
'net_pnl': pnl_dollar - fee,
|
||||||
|
'duration': trade_duration,
|
||||||
|
'timestamp': timestamp,
|
||||||
|
'reason': 'action_change'
|
||||||
|
})
|
||||||
|
|
||||||
|
# Update win/loss count
|
||||||
|
if pnl_dollar > 0:
|
||||||
|
self.win_count += 1
|
||||||
|
else:
|
||||||
|
self.loss_count += 1
|
||||||
|
|
||||||
|
# Now open short position
|
||||||
|
self.position = 'short'
|
||||||
|
self.entry_price = self.current_price
|
||||||
|
self.entry_index = self.current_step
|
||||||
|
|
||||||
|
# Calculate position size with risk management
|
||||||
|
self.position_size = self.calculate_position_size()
|
||||||
|
|
||||||
|
# Apply trading fee (1 USD per 1.9k position)
|
||||||
|
fee = (self.position_size / 1900) * 1
|
||||||
|
self.balance -= fee
|
||||||
|
self.total_fees += fee
|
||||||
|
|
||||||
|
# Set stop loss and take profit
|
||||||
|
sl_percent = 0.02 # 2% stop loss
|
||||||
|
tp_percent = 0.04 # 4% take profit
|
||||||
|
|
||||||
|
self.stop_loss = self.entry_price * (1 + sl_percent)
|
||||||
|
self.take_profit = self.entry_price * (1 - tp_percent)
|
||||||
|
|
||||||
|
# Close position
|
||||||
|
elif action == 3 and self.position != 'flat':
|
||||||
|
if self.position == 'long':
|
||||||
|
# Calculate PnL
|
||||||
|
pnl_percent = (self.current_price - self.entry_price) / self.entry_price
|
||||||
|
pnl_dollar = pnl_percent * self.position_size * self.leverage
|
||||||
|
|
||||||
|
# Update balance and record trade
|
||||||
|
self.balance += pnl_dollar
|
||||||
|
self.total_pnl += pnl_dollar
|
||||||
|
|
||||||
|
# Apply trading fee (1 USD per 1.9k position)
|
||||||
|
fee = (self.position_size / 1900) * 1
|
||||||
|
self.balance -= fee
|
||||||
|
self.total_fees += fee
|
||||||
|
|
||||||
|
# Record trade
|
||||||
|
trade_duration = self.current_step - self.entry_index
|
||||||
|
if self.data_format_is_list:
|
||||||
|
timestamp = self.data[self.current_step][0] # Timestamp
|
||||||
|
else:
|
||||||
|
timestamp = self.data[self.current_step]['timestamp']
|
||||||
|
|
||||||
|
self.trades.append({
|
||||||
|
'type': 'long',
|
||||||
|
'entry': self.entry_price,
|
||||||
|
'exit': self.current_price,
|
||||||
|
'pnl_percent': pnl_percent,
|
||||||
|
'pnl_dollar': pnl_dollar,
|
||||||
|
'fee': fee,
|
||||||
|
'net_pnl': pnl_dollar - fee,
|
||||||
|
'duration': trade_duration,
|
||||||
|
'timestamp': timestamp,
|
||||||
|
'reason': 'close_action'
|
||||||
|
})
|
||||||
|
|
||||||
|
# Update win/loss count
|
||||||
|
if pnl_dollar > 0:
|
||||||
|
self.win_count += 1
|
||||||
|
else:
|
||||||
|
self.loss_count += 1
|
||||||
|
|
||||||
|
elif self.position == 'short':
|
||||||
|
# Calculate PnL
|
||||||
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price
|
||||||
|
pnl_dollar = pnl_percent * self.position_size * self.leverage
|
||||||
|
|
||||||
|
# Update balance and record trade
|
||||||
|
self.balance += pnl_dollar
|
||||||
|
self.total_pnl += pnl_dollar
|
||||||
|
|
||||||
|
# Apply trading fee (1 USD per 1.9k position)
|
||||||
|
fee = (self.position_size / 1900) * 1
|
||||||
|
self.balance -= fee
|
||||||
|
self.total_fees += fee
|
||||||
|
|
||||||
|
# Record trade
|
||||||
|
trade_duration = self.current_step - self.entry_index
|
||||||
|
if self.data_format_is_list:
|
||||||
|
timestamp = self.data[self.current_step][0] # Timestamp
|
||||||
|
else:
|
||||||
|
timestamp = self.data[self.current_step]['timestamp']
|
||||||
|
|
||||||
|
self.trades.append({
|
||||||
|
'type': 'short',
|
||||||
|
'entry': self.entry_price,
|
||||||
|
'exit': self.current_price,
|
||||||
|
'pnl_percent': pnl_percent,
|
||||||
|
'pnl_dollar': pnl_dollar,
|
||||||
|
'fee': fee,
|
||||||
|
'net_pnl': pnl_dollar - fee,
|
||||||
|
'duration': trade_duration,
|
||||||
|
'timestamp': timestamp,
|
||||||
|
'reason': 'close_action'
|
||||||
|
})
|
||||||
|
|
||||||
|
# Update win/loss count
|
||||||
|
if pnl_dollar > 0:
|
||||||
|
self.win_count += 1
|
||||||
|
else:
|
||||||
|
self.loss_count += 1
|
||||||
|
|
||||||
|
# Reset position
|
||||||
|
self.position = 'flat'
|
||||||
|
self.position_size = 0
|
||||||
|
self.entry_price = 0
|
||||||
|
self.entry_index = 0
|
||||||
|
self.stop_loss = 0
|
||||||
|
self.take_profit = 0
|
||||||
|
|
||||||
# Record trade signal for visualization
|
# Record trade signal for visualization
|
||||||
if action > 0: # If not HOLD
|
if action > 0: # If not HOLD
|
||||||
signal_type = None
|
signal_type = None
|
||||||
@ -529,13 +925,22 @@ class TradingEnvironment:
|
|||||||
# Get new state
|
# Get new state
|
||||||
next_state = self.get_state()
|
next_state = self.get_state()
|
||||||
|
|
||||||
|
# Update peak balance and drawdown
|
||||||
|
if self.balance > self.peak_balance:
|
||||||
|
self.peak_balance = self.balance
|
||||||
|
|
||||||
|
current_drawdown = (self.peak_balance - self.balance) / self.peak_balance if self.peak_balance > 0 else 0
|
||||||
|
self.max_drawdown = max(self.max_drawdown, current_drawdown)
|
||||||
|
|
||||||
# Create info dictionary
|
# Create info dictionary
|
||||||
info = {
|
info = {
|
||||||
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
|
'action': 'hold' if action == 0 else 'buy' if action == 1 else 'sell' if action == 2 else 'close',
|
||||||
'price': self.current_price,
|
'price': self.current_price,
|
||||||
'balance': self.balance,
|
'balance': self.balance,
|
||||||
'position': self.position,
|
'position': self.position,
|
||||||
'pnl': self.total_pnl
|
'pnl': self.total_pnl,
|
||||||
|
'fees': self.total_fees,
|
||||||
|
'net_pnl': self.total_pnl - self.total_fees
|
||||||
}
|
}
|
||||||
|
|
||||||
return next_state, reward, done, info
|
return next_state, reward, done, info
|
||||||
@ -999,6 +1404,33 @@ class TradingEnvironment:
|
|||||||
else:
|
else:
|
||||||
# Small negative reward for holding in the wrong direction
|
# Small negative reward for holding in the wrong direction
|
||||||
reward -= 0.1
|
reward -= 0.1
|
||||||
|
elif action == 1 or action == 2: # BUY or SELL
|
||||||
|
# Apply trading fee as negative reward (1 USD per 1.9k position size)
|
||||||
|
position_size = self.calculate_position_size()
|
||||||
|
fee = (position_size / 1900) * 1 # Trading fee in USD
|
||||||
|
|
||||||
|
# Penalty for fee
|
||||||
|
fee_penalty = fee / 10 # Scale down to make it a reasonable penalty
|
||||||
|
reward -= fee_penalty
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
if hasattr(self, 'total_fees'):
|
||||||
|
self.total_fees += fee
|
||||||
|
else:
|
||||||
|
self.total_fees = fee
|
||||||
|
elif action == 3: # CLOSE
|
||||||
|
# Apply trading fee as negative reward (1 USD per 1.9k position size)
|
||||||
|
fee = (self.position_size / 1900) * 1 # Trading fee in USD
|
||||||
|
|
||||||
|
# Penalty for fee
|
||||||
|
fee_penalty = fee / 10 # Scale down to make it a reasonable penalty
|
||||||
|
reward -= fee_penalty
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
if hasattr(self, 'total_fees'):
|
||||||
|
self.total_fees += fee
|
||||||
|
else:
|
||||||
|
self.total_fees = fee
|
||||||
|
|
||||||
# Add CNN pattern confidence to reward
|
# Add CNN pattern confidence to reward
|
||||||
reward += pattern_confidence * 10
|
reward += pattern_confidence * 10
|
||||||
@ -1360,6 +1792,74 @@ class TradingEnvironment:
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def add_chart_to_tensorboard(self, writer, step, title='Trading Chart'):
    """
    Add a candlestick chart and metrics to TensorBoard

    Logging is best-effort: any failure is logged and swallowed so the
    caller's loop keeps running.

    Parameters:
    - writer: TensorBoard writer; when None a temporary SummaryWriter is
      created for this call and closed before returning
    - step: Current step
    - title: Title for the chart
    """
    owns_writer = False
    try:
        # Initialize writer if not provided
        if writer is None:
            from torch.utils.tensorboard import SummaryWriter
            writer = SummaryWriter()
            owns_writer = True

        # Log basic metrics
        writer.add_scalar('Balance', self.balance, step)
        writer.add_scalar('Total_PnL', self.total_pnl, step)

        # Log total fees if available
        if hasattr(self, 'total_fees'):
            writer.add_scalar('Total_Fees', self.total_fees, step)
            writer.add_scalar('Net_PnL', self.total_pnl - self.total_fees, step)

        # Log position info
        writer.add_scalar('Position_Size', self.position_size, step)

        # Log drawdown and win rate
        writer.add_scalar('Max_Drawdown', self.max_drawdown, step)

        closed_trades = self.win_count + self.loss_count
        win_rate = self.win_count / closed_trades if closed_trades > 0 else 0
        writer.add_scalar('Win_Rate', win_rate, step)

        # Log trade count
        writer.add_scalar('Trade_Count', len(self.trades), step)

        # Check if we have enough data for candlestick chart
        if len(self.data) <= 0:
            logger.warning("No data available for candlestick chart")
            return

        # Chart window: up to the last 100 candles, current step inclusive
        start_idx = max(0, self.current_step - 100)
        end_idx = self.current_step

        # Last 10 trades for visualization (empty list when there are none)
        recent_trades = self.trades[-10:]

        try:
            fig = create_candlestick_figure(
                self.data[start_idx:end_idx + 1],
                title=title,
                trades=recent_trades
            )
            try:
                # Add figure to TensorBoard
                writer.add_figure('Candlestick_Chart', fig, step)
            finally:
                # Close the figure even when add_figure fails, otherwise
                # matplotlib keeps it alive and figures pile up over time.
                plt.close(fig)
        except Exception as e:
            logger.error(f"Error creating candlestick chart: {e}")

    except Exception as e:
        logger.error(f"Error adding chart to TensorBoard: {e}")
        # Continue execution even if chart fails
    finally:
        # A writer created here is not visible to the caller, so it must be
        # flushed and closed here or its event file handle leaks.
        if owns_writer:
            writer.close()
|
||||||
|
|
||||||
# Ensure GPU usage if available
|
# Ensure GPU usage if available
|
||||||
def get_device():
|
def get_device():
|
||||||
"""Get the best available device (CUDA GPU or CPU)"""
|
"""Get the best available device (CUDA GPU or CPU)"""
|
||||||
@ -1752,7 +2252,7 @@ class Agent:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def add_chart_to_tensorboard(self, env, step):
|
def add_chart_to_tensorboard(self, env, step):
|
||||||
"""Add candlestick chart to tensorboard"""
|
"""Add candlestick chart to tensorboard and various metrics"""
|
||||||
try:
|
try:
|
||||||
# Initialize writer if it doesn't exist
|
# Initialize writer if it doesn't exist
|
||||||
if not hasattr(self, 'writer') or self.writer is None:
|
if not hasattr(self, 'writer') or self.writer is None:
|
||||||
@ -1791,22 +2291,37 @@ class Agent:
|
|||||||
if hasattr(env, 'trade_count'):
|
if hasattr(env, 'trade_count'):
|
||||||
self.writer.add_scalar('Trading/Trade_Count', env.trade_count, step)
|
self.writer.add_scalar('Trading/Trade_Count', env.trade_count, step)
|
||||||
|
|
||||||
# Get recent trades if available
|
# Log trading fees
|
||||||
recent_trades = []
|
if hasattr(env, 'total_fees'):
|
||||||
if hasattr(env, 'trades') and env.trades:
|
self.writer.add_scalar('Trading/Total_Fees', env.total_fees, step)
|
||||||
recent_trades = env.trades[-10:] # Last 10 trades
|
# Also log net PnL (after fees)
|
||||||
|
if hasattr(env, 'total_pnl'):
|
||||||
|
self.writer.add_scalar('Trading/Net_PnL_After_Fees', env.total_pnl - env.total_fees, step)
|
||||||
|
|
||||||
# Create candlestick figure with the last 100 candles and recent trades
|
# Add candlestick chart if we have enough data
|
||||||
fig = create_candlestick_figure(env.data[-100:], recent_trades)
|
if len(env.data) >= 100:
|
||||||
|
try:
|
||||||
|
# Use the last 100 candles for the chart
|
||||||
|
recent_data = env.data[-100:]
|
||||||
|
|
||||||
# Add figure to tensorboard
|
# Get recent trades if available
|
||||||
self.writer.add_figure('Trading/Chart', fig, step)
|
recent_trades = None
|
||||||
|
if hasattr(env, 'trades') and len(env.trades) > 0:
|
||||||
|
recent_trades = env.trades[-10:] # Last 10 trades
|
||||||
|
|
||||||
# Close figure to free resources
|
# Create candlestick figure
|
||||||
plt.close(fig)
|
fig = create_candlestick_figure(recent_data, recent_trades, f"Trading Chart - Step {step}")
|
||||||
|
|
||||||
|
if fig:
|
||||||
|
# Add to tensorboard
|
||||||
|
self.writer.add_figure('Trading/Chart', fig, step)
|
||||||
|
|
||||||
|
# Close figure to free memory
|
||||||
|
plt.close(fig)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error creating candlestick chart: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error adding chart to tensorboard: {e}")
|
logger.error(f"Error in add_chart_to_tensorboard: {e}")
|
||||||
|
|
||||||
async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
|
async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
|
||||||
"""Get live price data using websockets"""
|
"""Get live price data using websockets"""
|
||||||
@ -1888,12 +2403,15 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
|
|||||||
'cumulative_pnl': [],
|
'cumulative_pnl': [],
|
||||||
'drawdowns': [],
|
'drawdowns': [],
|
||||||
'trade_counts': [],
|
'trade_counts': [],
|
||||||
'loss_values': []
|
'loss_values': [],
|
||||||
|
'fees': [], # Track fees
|
||||||
|
'net_pnl_after_fees': [] # Track net PnL after fees
|
||||||
}
|
}
|
||||||
|
|
||||||
# Track best models
|
# Track best models
|
||||||
best_reward = float('-inf')
|
best_reward = float('-inf')
|
||||||
best_pnl = float('-inf')
|
best_pnl = float('-inf')
|
||||||
|
best_net_pnl = float('-inf') # Track best net PnL (after fees)
|
||||||
|
|
||||||
# Make directory for models if it doesn't exist
|
# Make directory for models if it doesn't exist
|
||||||
os.makedirs('models', exist_ok=True)
|
os.makedirs('models', exist_ok=True)
|
||||||
@ -1997,6 +2515,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
|
|||||||
# Calculate statistics from this episode
|
# Calculate statistics from this episode
|
||||||
balance = env.balance
|
balance = env.balance
|
||||||
pnl = balance - env.initial_balance if hasattr(env, 'initial_balance') else 0
|
pnl = balance - env.initial_balance if hasattr(env, 'initial_balance') else 0
|
||||||
|
fees = env.total_fees if hasattr(env, 'total_fees') else 0
|
||||||
|
net_pnl = pnl - fees # Calculate net PnL after fees
|
||||||
|
|
||||||
# Get trading statistics
|
# Get trading statistics
|
||||||
trade_analysis = None
|
trade_analysis = None
|
||||||
@ -2015,6 +2535,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
|
|||||||
writer.add_scalar('Reward/episode', episode_reward, episode)
|
writer.add_scalar('Reward/episode', episode_reward, episode)
|
||||||
writer.add_scalar('Balance/episode', balance, episode)
|
writer.add_scalar('Balance/episode', balance, episode)
|
||||||
writer.add_scalar('PnL/episode', pnl, episode)
|
writer.add_scalar('PnL/episode', pnl, episode)
|
||||||
|
writer.add_scalar('NetPnL/episode', net_pnl, episode)
|
||||||
|
writer.add_scalar('Fees/episode', fees, episode)
|
||||||
writer.add_scalar('WinRate/episode', win_rate, episode)
|
writer.add_scalar('WinRate/episode', win_rate, episode)
|
||||||
writer.add_scalar('TradeCount/episode', trade_count, episode)
|
writer.add_scalar('TradeCount/episode', trade_count, episode)
|
||||||
writer.add_scalar('Drawdown/episode', max_drawdown, episode)
|
writer.add_scalar('Drawdown/episode', max_drawdown, episode)
|
||||||
@ -2030,6 +2552,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
|
|||||||
stats['drawdowns'].append(max_drawdown)
|
stats['drawdowns'].append(max_drawdown)
|
||||||
stats['trade_counts'].append(trade_count)
|
stats['trade_counts'].append(trade_count)
|
||||||
stats['loss_values'].append(avg_loss)
|
stats['loss_values'].append(avg_loss)
|
||||||
|
stats['fees'].append(fees)
|
||||||
|
stats['net_pnl_after_fees'].append(net_pnl)
|
||||||
|
|
||||||
# Calculate and update cumulative PnL
|
# Calculate and update cumulative PnL
|
||||||
if len(stats['episode_pnls']) > 0:
|
if len(stats['episode_pnls']) > 0:
|
||||||
@ -2039,6 +2563,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
|
|||||||
stats['cumulative_pnl'].append(cumulative_pnl)
|
stats['cumulative_pnl'].append(cumulative_pnl)
|
||||||
if writer:
|
if writer:
|
||||||
writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode)
|
writer.add_scalar('CumulativePnL/episode', cumulative_pnl, episode)
|
||||||
|
writer.add_scalar('CumulativeNetPnL/episode', sum(stats['net_pnl_after_fees']), episode)
|
||||||
|
|
||||||
# Save model if this is the best reward or PnL
|
# Save model if this is the best reward or PnL
|
||||||
if episode_reward > best_reward:
|
if episode_reward > best_reward:
|
||||||
@ -2051,6 +2576,12 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
|
|||||||
agent.save('models/trading_agent_best_pnl.pt')
|
agent.save('models/trading_agent_best_pnl.pt')
|
||||||
logging.info(f"New best PnL: ${best_pnl:.2f}")
|
logging.info(f"New best PnL: ${best_pnl:.2f}")
|
||||||
|
|
||||||
|
# Save model if this is the best net PnL (after fees)
|
||||||
|
if net_pnl > best_net_pnl:
|
||||||
|
best_net_pnl = net_pnl
|
||||||
|
agent.save('models/trading_agent_best_net_pnl.pt')
|
||||||
|
logging.info(f"New best Net PnL: ${best_net_pnl:.2f}")
|
||||||
|
|
||||||
# Save checkpoint periodically
|
# Save checkpoint periodically
|
||||||
if episode % 10 == 0:
|
if episode % 10 == 0:
|
||||||
agent.save(f'models/trading_agent_checkpoint_{episode}.pt')
|
agent.save(f'models/trading_agent_checkpoint_{episode}.pt')
|
||||||
@ -2063,6 +2594,8 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
|
|||||||
f"Reward: {episode_reward:.2f} | " +
|
f"Reward: {episode_reward:.2f} | " +
|
||||||
f"Balance: ${balance:.2f} | " +
|
f"Balance: ${balance:.2f} | " +
|
||||||
f"PnL: ${pnl:.2f} | " +
|
f"PnL: ${pnl:.2f} | " +
|
||||||
|
f"Fees: ${fees:.2f} | " +
|
||||||
|
f"Net PnL: ${net_pnl:.2f} | " +
|
||||||
f"Win Rate: {win_rate:.2f} | " +
|
f"Win Rate: {win_rate:.2f} | " +
|
||||||
f"Trades: {trade_count} | " +
|
f"Trades: {trade_count} | " +
|
||||||
f"Loss: {avg_loss:.5f} | " +
|
f"Loss: {avg_loss:.5f} | " +
|
||||||
@ -2608,12 +3141,19 @@ async def main():
|
|||||||
# Initialize exchange
|
# Initialize exchange
|
||||||
exchange = await initialize_exchange()
|
exchange = await initialize_exchange()
|
||||||
|
|
||||||
# Create environment
|
# Create environment with updated parameters
|
||||||
env = TradingEnvironment(initial_balance=INITIAL_BALANCE, window_size=30, demo=demo_mode)
|
env = TradingEnvironment(
|
||||||
|
initial_balance=INITIAL_BALANCE,
|
||||||
|
window_size=30,
|
||||||
|
leverage=args.leverage,
|
||||||
|
exchange_id='mexc',
|
||||||
|
symbol=args.symbol,
|
||||||
|
timeframe=args.timeframe
|
||||||
|
)
|
||||||
|
|
||||||
if args.mode == 'train':
|
if args.mode == 'train':
|
||||||
# Fetch initial data for training
|
# Fetch initial data for training
|
||||||
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 1000)
|
await env.fetch_initial_data(exchange, args.symbol, args.timeframe, 1000)
|
||||||
|
|
||||||
# Create agent with consistent parameters
|
# Create agent with consistent parameters
|
||||||
# Note: Using STATE_SIZE and action_size=4 for consistency
|
# Note: Using STATE_SIZE and action_size=4 for consistency
|
||||||
@ -2957,6 +3497,9 @@ async def fetch_multi_timeframe_data(exchange, symbol, candle_cache):
|
|||||||
'1d': 86400 # Update every 1 day
|
'1d': 86400 # Update every 1 day
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# TODO: For 1s/tick timeframes, we'll need to use the exchange's WebSocket API
|
||||||
|
# for real-time data streaming instead of REST API. Implement this in the future.
|
||||||
|
|
||||||
limits = {
|
limits = {
|
||||||
'1s': 1000,
|
'1s': 1000,
|
||||||
'1m': 1000,
|
'1m': 1000,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user