improvements -
.gitignore (vendored)
@@ -32,3 +32,4 @@ crypto/sol/.vs/*
crypto/brian/models/best/*
crypto/brian/models/last/*
crypto/brian/live_chart.html
crypto/gogo2/models/*
@@ -18,6 +18,8 @@ import ccxt
import websockets
from torch.utils.tensorboard import SummaryWriter
import torch.cuda.amp as amp # Add this import at the top
from sklearn.preprocessing import MinMaxScaler
import copy

# Configure logging
logging.basicConfig(
@@ -73,8 +75,9 @@ class DQN(nn.Module):

        # Initial feature extraction
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.dropout1 = nn.Dropout(0.2) # Add dropout for regularization
        # Use LayerNorm instead of BatchNorm for more stability with varying batch sizes
        self.ln1 = nn.LayerNorm(hidden_size)
        self.dropout1 = nn.Dropout(0.2)

        # LSTM layer for sequential data
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=lstm_layers, batch_first=True, dropout=0.2)
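The swap above from BatchNorm1d to LayerNorm is motivated by batch-size sensitivity; a minimal standalone check (illustrative only, not part of the commit) shows the failure mode with a single sample in training mode:

import torch
import torch.nn as nn

hidden_size = 256
x = torch.randn(1, hidden_size)  # a single state vector, batch size 1

bn = nn.BatchNorm1d(hidden_size).train()
try:
    bn(x)
except ValueError as err:
    print("BatchNorm1d rejects batch size 1 in training mode:", err)

ln = nn.LayerNorm(hidden_size)
print(ln(x).shape)  # torch.Size([1, 256]) - LayerNorm is independent of batch size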
@@ -84,7 +87,7 @@ class DQN(nn.Module):

        # Output layers with increased capacity
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.bn2 = nn.BatchNorm1d(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size) # LayerNorm instead of BatchNorm
        self.dropout2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(hidden_size, hidden_size // 2)

@@ -115,7 +118,7 @@ class DQN(nn.Module):

        # Initial feature extraction
        x = self.fc1(x)
        x = F.relu(self.bn1(x) if batch_size > 1 else self.bn1(x.unsqueeze(0)).squeeze(0))
        x = F.relu(self.ln1(x)) # LayerNorm works with any batch size
        x = self.dropout1(x)

        # Reshape for LSTM
@@ -135,7 +138,7 @@ class DQN(nn.Module):

        # Final layers
        x = self.fc2(x)
        x = F.relu(self.bn2(x) if batch_size > 1 else self.bn2(x.unsqueeze(0)).squeeze(0))
        x = F.relu(self.ln2(x)) # LayerNorm works with any batch size
        x = self.dropout2(x)
        x = F.relu(self.fc3(x))

@@ -146,6 +149,96 @@ class DQN(nn.Module):

        return qvals

class PricePredictionModel(nn.Module):
    def __init__(self, input_size=30, hidden_size=128, output_size=5, num_layers=2):
        super(PricePredictionModel, self).__init__()
        self.lstm = nn.LSTM(1, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.2)
        self.fc = nn.Linear(hidden_size, output_size)
        self.scaler = MinMaxScaler(feature_range=(0, 1))
        self.is_fitted = False

    def forward(self, x):
        # x shape: [batch_size, seq_len, 1]
        lstm_out, _ = self.lstm(x)
        # Use the last time step output
        predictions = self.fc(lstm_out[:, -1, :])
        return predictions

    def preprocess(self, data):
        # Reshape data for scaler
        data_reshaped = np.array(data).reshape(-1, 1)

        # Fit scaler if not already fitted
        if not self.is_fitted:
            self.scaler.fit(data_reshaped)
            self.is_fitted = True

        # Transform data
        scaled_data = self.scaler.transform(data_reshaped)
        return scaled_data

    def postprocess(self, scaled_predictions):
        # Inverse transform to get actual price values
        return self.scaler.inverse_transform(scaled_predictions.reshape(-1, 1)).flatten()

    def predict_next_candles(self, price_history, num_candles=5):
        if len(price_history) < 30: # Need enough history
            return np.zeros(num_candles)

        # Preprocess data
        scaled_data = self.preprocess(price_history)

        # Create sequence
        sequence = scaled_data[-30:].reshape(1, 30, 1)
        sequence_tensor = torch.FloatTensor(sequence).to(next(self.parameters()).device)

        # Get predictions
        with torch.no_grad():
            scaled_predictions = self(sequence_tensor).cpu().numpy()[0]

        # Postprocess predictions
        predictions = self.postprocess(scaled_predictions)
        return predictions

    def train_on_new_data(self, price_history, optimizer, epochs=10):
        if len(price_history) < 35: # Need enough history for training
            return 0.0

        # Preprocess data
        scaled_data = self.preprocess(price_history)

        # Create sequences and targets
        sequences = []
        targets = []

        for i in range(len(scaled_data) - 35):
            seq = scaled_data[i:i+30]
            target = scaled_data[i+30:i+35]
            sequences.append(seq)
            targets.append(target)

        if not sequences:
            return 0.0

        sequences = np.array(sequences).reshape(-1, 30, 1)
        targets = np.array(targets).reshape(-1, 5)

        # Convert to tensors
        sequences_tensor = torch.FloatTensor(sequences).to(next(self.parameters()).device)
        targets_tensor = torch.FloatTensor(targets).to(next(self.parameters()).device)

        # Train
        total_loss = 0
        for _ in range(epochs):
            optimizer.zero_grad()
            predictions = self(sequences_tensor)
            loss = F.mse_loss(predictions, targets_tensor)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        return total_loss / epochs

class TradingEnvironment:
    def __init__(self, exchange, symbol="ETH/USDT", timeframe="1m", leverage=MAX_LEVERAGE,
                 initial_balance=INITIAL_BALANCE, window_size=60, is_demo=True):
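A minimal usage sketch of the new PricePredictionModel (illustrative only, not part of the commit; the synthetic price series and the Adam learning rate here are assumptions):

import numpy as np
import torch.optim as optim

model = PricePredictionModel()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

prices = 1900 + np.cumsum(np.random.randn(200))  # synthetic close-price series
avg_loss = model.train_on_new_data(prices, optimizer, epochs=5)
next_candles = model.predict_next_candles(prices, num_candles=5)
print(f"avg training loss: {avg_loss:.6f}, next 5 predicted closes: {next_candles}")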
@@ -173,6 +266,22 @@ class TradingEnvironment:

        self._initialize_features()

        # Add price prediction model
        self.price_predictor = None
        self.predicted_prices = []

        # Add statistics tracking
        self.episode_pnl = 0.0
        self.total_pnl = 0.0
        self.win_count = 0
        self.loss_count = 0
        self.trade_count = 0
        self.max_drawdown = 0.0
        self.peak_balance = initial_balance

        # For backtesting optimal trades
        self.optimal_trades = []

    def _initialize_features(self):
        """Initialize technical indicators and features"""
        self.features = {
@@ -397,6 +506,30 @@ class TradingEnvironment:
        state_components.append(self.features['stoch_k'][-3:] / 100.0)
        state_components.append(self.features['stoch_d'][-3:] / 100.0)

        # Add predicted prices (if available)
        if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
            # Normalize predictions relative to current price
            pred_norm = np.array(self.predicted_prices[:3]) / latest_price - 1.0
            # Pad if needed
            if len(pred_norm) < 3:
                pred_norm = np.pad(pred_norm, (0, 3 - len(pred_norm)), 'constant')
            state_components.append(pred_norm)
        else:
            # Add zeros if no predictions
            state_components.append(np.zeros(3))

        # Add extrema signals (if available)
        if hasattr(self, 'optimal_signals'):
            # Get recent signals
            recent_signals = self.optimal_signals[-5:]
            # Pad if needed
            if len(recent_signals) < 5:
                recent_signals = np.pad(recent_signals, (0, 5 - len(recent_signals)), 'constant')
            state_components.append(recent_signals)
        else:
            # Add zeros if no signals
            state_components.append(np.zeros(5))

        # Position info
        position_info = np.zeros(5)
        if self.position == 'long':
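Worked example of the prediction features appended to the state above (illustrative numbers): the predicted prices become returns relative to the latest price and are padded to a fixed length of three:

import numpy as np

latest_price = 2000.0
predicted_prices = [2004.0, 1998.0]  # only two predictions available
pred_norm = np.array(predicted_prices[:3]) / latest_price - 1.0
pred_norm = np.pad(pred_norm, (0, 3 - len(pred_norm)), 'constant')
print(pred_norm)  # [ 0.002 -0.001  0.   ]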
@@ -467,6 +600,26 @@ class TradingEnvironment:
        if self.current_step >= 10000 or self.balance <= 0.1 * self.initial_balance:
            done = True

        # Update statistics
        if reward != 0: # If a trade was closed
            self.episode_pnl += reward
            self.total_pnl += reward

            if reward > 0:
                self.win_count += 1
            else:
                self.loss_count += 1

            self.trade_count += 1

        # Update peak balance and drawdown
        if self.balance > self.peak_balance:
            self.peak_balance = self.balance

        drawdown = (self.peak_balance - self.balance) / self.peak_balance
        if drawdown > self.max_drawdown:
            self.max_drawdown = drawdown

        return self.get_state(), reward, done

    def _open_long_position(self):
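The drawdown bookkeeping added above reduces to a simple peak-relative ratio; a tiny worked example with assumed numbers:

peak_balance, balance = 120.0, 102.0
drawdown = (peak_balance - balance) / peak_balance
print(f"drawdown = {drawdown:.2%}")  # 15.00%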
@@ -634,8 +787,319 @@ class TradingEnvironment:
        self.take_profit = 0.0
        self.current_step = 0

        # Reset episode statistics
        self.episode_pnl = 0.0

        # Find optimal trades for bootstrapping
        self.find_optimal_trades()

        # Update price predictions
        self.update_price_predictions()

        return self.get_state()

    def initialize_price_predictor(self, device):
        """Initialize the price prediction model"""
        self.price_predictor = PricePredictionModel().to(device)
        self.price_predictor_optimizer = optim.Adam(self.price_predictor.parameters(), lr=1e-4)

    def update_price_predictions(self):
        """Update price predictions based on current data"""
        if self.price_predictor is not None and len(self.features['price']) > 30:
            self.predicted_prices = self.price_predictor.predict_next_candles(
                self.features['price'][-100:], num_candles=5
            )

    def train_price_predictor(self):
        """Train the price prediction model on new data"""
        if self.price_predictor is not None and len(self.features['price']) > 35:
            loss = self.price_predictor.train_on_new_data(
                self.features['price'], self.price_predictor_optimizer
            )
            return loss
        return 0.0

    def find_optimal_trades(self):
        """Find optimal buy/sell points for bootstrapping"""
        if len(self.features['price']) < 30:
            return

        prices = np.array(self.features['price'])
        window = 10 # Window to look for local minima/maxima

        self.optimal_trades = np.zeros(len(prices))

        for i in range(window, len(prices) - window):
            # Check for local minimum (buy signal)
            if prices[i] == min(prices[i-window:i+window+1]):
                self.optimal_trades[i] = 1 # Buy signal

            # Check for local maximum (sell signal)
            if prices[i] == max(prices[i-window:i+window+1]):
                self.optimal_trades[i] = 2 # Sell signal

    def identify_optimal_trades(self):
        """Identify optimal entry and exit points based on local extrema"""
        if len(self.features['price']) < 20:
            return

        # Find local bottoms and tops
        bottoms, tops = find_local_extrema(self.features['price'], window=5)

        # Store optimal trade points
        self.optimal_bottoms = [i for i in bottoms] # Buy points
        self.optimal_tops = [i for i in tops] # Sell points

        # Create optimal trade signals
        self.optimal_signals = np.zeros(len(self.features['price']))
        for i in self.optimal_bottoms:
            self.optimal_signals[i] = 1 # Buy signal
        for i in self.optimal_tops:
            self.optimal_signals[i] = -1 # Sell signal

        logger.info(f"Identified {len(self.optimal_bottoms)} optimal buy points and {len(self.optimal_tops)} optimal sell points")

    def calculate_reward(self, action):
        """Calculate reward for the given action"""
        reward = 0

        # Base reward for actions
        if action == 0: # HOLD
            reward = -0.01 # Small penalty for doing nothing

        elif action == 1: # BUY/LONG
            if self.position == 'flat':
                # Opening a long position
                self.position = 'long'
                self.entry_price = self.current_price
                self.position_size = self.calculate_position_size()
                self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100)
                self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100)

                # Check if this is an optimal buy point (bottom)
                current_idx = len(self.features['price']) - 1
                if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms:
                    reward += 2.0 # Bonus for buying at a bottom
                else:
                    reward += 0.1 # Small reward for opening a position

                logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

            elif self.position == 'short':
                # Close short and open long
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar

                # Record trade
                self.trades.append({
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar
                })

                # Reward based on PnL
                if pnl_dollar > 0:
                    reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
                    self.win_count += 1
                else:
                    reward -= 1.0 # Negative reward for loss
                    self.loss_count += 1

                logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Now open long
                self.position = 'long'
                self.entry_price = self.current_price
                self.position_size = self.calculate_position_size()
                self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100)
                self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100)

                # Check if this is an optimal buy point
                current_idx = len(self.features['price']) - 1
                if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms:
                    reward += 2.0 # Bonus for buying at a bottom

                logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

        elif action == 2: # SELL/SHORT
            if self.position == 'flat':
                # Opening a short position
                self.position = 'short'
                self.entry_price = self.current_price
                self.position_size = self.calculate_position_size()
                self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100)
                self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100)

                # Check if this is an optimal sell point (top)
                current_idx = len(self.features['price']) - 1
                if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
                    reward += 2.0 # Bonus for selling at a top
                else:
                    reward += 0.1 # Small reward for opening a position

                logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

            elif self.position == 'long':
                # Close long and open short
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar

                # Record trade
                self.trades.append({
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar
                })

                # Reward based on PnL
                if pnl_dollar > 0:
                    reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
                    self.win_count += 1
                else:
                    reward -= 1.0 # Negative reward for loss
                    self.loss_count += 1

                logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Now open short
                self.position = 'short'
                self.entry_price = self.current_price
                self.position_size = self.calculate_position_size()
                self.stop_loss = self.entry_price * (1 + STOP_LOSS_PERCENT/100)
                self.take_profit = self.entry_price * (1 - TAKE_PROFIT_PERCENT/100)

                # Check if this is an optimal sell point
                current_idx = len(self.features['price']) - 1
                if hasattr(self, 'optimal_tops') and current_idx in self.optimal_tops:
                    reward += 2.0 # Bonus for selling at a top

                logger.info(f"OPENED SHORT at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")

        elif action == 3: # CLOSE
            if self.position == 'long':
                # Close long position
                pnl_percent = (self.current_price - self.entry_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                self.episode_pnl += pnl_dollar

                # Update max drawdown
                if self.balance > self.peak_balance:
                    self.peak_balance = self.balance
                drawdown = (self.peak_balance - self.balance) / self.peak_balance
                self.max_drawdown = max(self.max_drawdown, drawdown)

                # Record trade
                self.trades.append({
                    'type': 'long',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar
                })

                # Reward based on PnL
                if pnl_dollar > 0:
                    reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
                    self.win_count += 1
                else:
                    reward -= 1.0 # Negative reward for loss
                    self.loss_count += 1

                logger.info(f"CLOSED long at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

            elif self.position == 'short':
                # Close short position
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
                self.episode_pnl += pnl_dollar

                # Update max drawdown
                if self.balance > self.peak_balance:
                    self.peak_balance = self.balance
                drawdown = (self.peak_balance - self.balance) / self.peak_balance
                self.max_drawdown = max(self.max_drawdown, drawdown)

                # Record trade
                self.trades.append({
                    'type': 'short',
                    'entry': self.entry_price,
                    'exit': self.current_price,
                    'pnl_percent': pnl_percent,
                    'pnl_dollar': pnl_dollar
                })

                # Reward based on PnL
                if pnl_dollar > 0:
                    reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
                    self.win_count += 1
                else:
                    reward -= 1.0 # Negative reward for loss
                    self.loss_count += 1

                logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")

                # Reset position
                self.position = 'flat'
                self.entry_price = 0
                self.position_size = 0
                self.stop_loss = 0
                self.take_profit = 0

        # Add prediction accuracy component to reward
        if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
            # Compare the first prediction with actual price
            if len(self.data) > 1:
                actual_price = self.data[-1]['close']
                predicted_price = self.predicted_prices[0]
                prediction_error = abs(predicted_price - actual_price) / actual_price

                # Reward accurate predictions, penalize bad ones
                if prediction_error < 0.005: # Less than 0.5% error
                    reward += 0.5
                elif prediction_error > 0.02: # More than 2% error
                    reward -= 0.5

        return reward

class Agent:
    def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4,
                 device="cuda" if torch.cuda.is_available() else "cpu"):
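The prediction-accuracy term at the end of calculate_reward is a three-way threshold; a short worked example with assumed prices:

actual_price, predicted_price = 2000.0, 2006.0
prediction_error = abs(predicted_price - actual_price) / actual_price  # 0.003
if prediction_error < 0.005:      # under 0.5% error
    reward_adjustment = 0.5
elif prediction_error > 0.02:     # over 2% error
    reward_adjustment = -0.5
else:
    reward_adjustment = 0.0
print(prediction_error, reward_adjustment)  # 0.003 0.5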
@@ -889,10 +1353,18 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
        'episode_rewards': [],
        'episode_lengths': [],
        'balances': [],
        'win_rates': []
        'win_rates': [],
        'episode_pnls': [],
        'cumulative_pnl': [],
        'drawdowns': [],
        'prediction_losses': []
    }

    best_reward = -float('inf')
    best_pnl = -float('inf')

    # Initialize price predictor
    env.initialize_price_predictor(agent.device)

    try:
        for episode in range(num_episodes):
@@ -900,6 +1372,10 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
            state = env.reset()
            episode_reward = 0

            # Train price predictor
            prediction_loss = env.train_price_predictor()
            stats['prediction_losses'].append(prediction_loss)

            for step in range(max_steps_per_episode):
                # Select action
                action = agent.select_action(state)
@@ -940,19 +1416,33 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
            stats['episode_lengths'].append(step + 1)
            stats['balances'].append(env.balance)
            stats['win_rates'].append(win_rate)
            stats['episode_pnls'].append(env.episode_pnl)
            stats['cumulative_pnl'].append(env.total_pnl)
            stats['drawdowns'].append(env.max_drawdown * 100)

            # Log to TensorBoard
            agent.writer.add_scalar('Reward/train', episode_reward, episode)
            agent.writer.add_scalar('Balance/train', env.balance, episode)
            agent.writer.add_scalar('WinRate/train', win_rate, episode)
            agent.writer.add_scalar('PnL/episode', env.episode_pnl, episode)
            agent.writer.add_scalar('PnL/cumulative', env.total_pnl, episode)
            agent.writer.add_scalar('Drawdown/percent', env.max_drawdown * 100, episode)
            agent.writer.add_scalar('PredictionLoss', prediction_loss, episode)

            logger.info(f"Episode {episode}: Reward={episode_reward:.2f}, Balance=${env.balance:.2f}, "
                        f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}")
                        f"Win Rate={win_rate:.1f}%, Trades={len(env.trades)}, "
                        f"Episode PnL=${env.episode_pnl:.2f}, Total PnL=${env.total_pnl:.2f}, "
                        f"Max Drawdown={env.max_drawdown*100:.1f}%")

            # Save best model
            # Save best model by reward
            if episode_reward > best_reward:
                best_reward = episode_reward
                agent.save("models/trading_agent_best.pt")
                agent.save("models/trading_agent_best_reward.pt")

            # Save best model by PnL
            if env.episode_pnl > best_pnl:
                best_pnl = env.episode_pnl
                agent.save("models/trading_agent_best_pnl.pt")

            # Save checkpoint
            if episode % 10 == 0:
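The scalars logged above can be inspected while training runs by pointing TensorBoard at the writer's log directory (assuming the agent's SummaryWriter uses its default runs/ directory): tensorboard --logdir runs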
@@ -975,36 +1465,59 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
    return stats

def plot_training_results(stats):
    """Plot training statistics"""
    plt.figure(figsize=(15, 10))
    """Plot detailed training results"""
    plt.figure(figsize=(20, 15))

    plt.subplot(2, 2, 1)
    # Plot rewards
    plt.subplot(3, 2, 1)
    plt.plot(stats['episode_rewards'])
    plt.title('Episode Rewards')
    plt.xlabel('Episode')
    plt.ylabel('Reward')

    plt.subplot(2, 2, 2)
    # Plot balance
    plt.subplot(3, 2, 2)
    plt.plot(stats['balances'])
    plt.title('Account Balance')
    plt.xlabel('Episode')
    plt.ylabel('Balance ($)')

    plt.subplot(2, 2, 3)
    plt.plot(stats['episode_lengths'])
    plt.title('Episode Length')
    plt.xlabel('Episode')
    plt.ylabel('Steps')

    plt.subplot(2, 2, 4)
    # Plot win rate
    plt.subplot(3, 2, 3)
    plt.plot(stats['win_rates'])
    plt.title('Win Rate')
    plt.xlabel('Episode')
    plt.ylabel('Win Rate (%)')

    # Plot episode PnL
    plt.subplot(3, 2, 4)
    plt.plot(stats['episode_pnls'])
    plt.title('Episode PnL')
    plt.xlabel('Episode')
    plt.ylabel('PnL ($)')

    # Plot cumulative PnL
    plt.subplot(3, 2, 5)
    plt.plot(stats['cumulative_pnl'])
    plt.title('Cumulative PnL')
    plt.xlabel('Episode')
    plt.ylabel('Cumulative PnL ($)')

    # Plot drawdown
    plt.subplot(3, 2, 6)
    plt.plot(stats['drawdowns'])
    plt.title('Maximum Drawdown')
    plt.xlabel('Episode')
    plt.ylabel('Drawdown (%)')

    plt.tight_layout()
    plt.savefig('training_results.png')
    plt.close()

    # Save statistics to CSV
    df = pd.DataFrame(stats)
    df.to_csv('training_stats.csv', index=False)

    logger.info("Training statistics saved to training_stats.csv and training_results.png")

def evaluate_agent(agent, env, num_episodes=10):
    """Evaluate the agent on test data"""
@@ -1235,4 +1748,26 @@ if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Program terminated by user")
        logger.info("Program terminated by user")

# Add these functions to identify tops and bottoms
def find_local_extrema(prices, window=5):
    """Find local minima (bottoms) and maxima (tops) in price data"""
    bottoms = []
    tops = []

    if len(prices) < window * 2 + 1:
        return bottoms, tops

    for i in range(window, len(prices) - window):
        # Check if this is a local minimum (bottom)
        if all(prices[i] <= prices[i-j] for j in range(1, window+1)) and \
           all(prices[i] <= prices[i+j] for j in range(1, window+1)):
            bottoms.append(i)

        # Check if this is a local maximum (top)
        if all(prices[i] >= prices[i-j] for j in range(1, window+1)) and \
           all(prices[i] >= prices[i+j] for j in range(1, window+1)):
            tops.append(i)

    return bottoms, tops
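A quick illustrative call of find_local_extrema on toy data (values chosen only to show the shape of the output):

prices = [5, 4, 3, 2, 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2, 1, 2, 3, 4, 5]
bottoms, tops = find_local_extrema(prices, window=2)
print(bottoms, tops)  # [4, 16] [10] - indices of the local bottoms and the local top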
@@ -5,4 +5,5 @@ torch>=1.9.0
python-dotenv>=0.19.0
ccxt>=2.0.0
websockets>=10.0
tensorboard>=2.6.0
tensorboard>=2.6.0
scikit-learn