improvements
commit 643bc154a2 (parent 66a2c41338)
@@ -17,3 +17,8 @@ C:\Users\popov\miniforge3\Lib\site-packages\torch\amp\grad_scaler.py:132: UserWa
 2025-03-10 12:11:30,928 - ERROR - Training failed: 'TradingEnvironment' object has no attribute 'initialize_price_predictor'
 2025-03-10 12:11:30,928 - INFO - Exchange connection closed
 Backend tkagg is interactive backend. Turning interactive mode on.
+
+
+
+
+2025-03-10 12:35:14,489 - INFO - Episode 34: Reward=232.41, Balance=$98.47, Win Rate=70.6%, Trades=17, Episode PnL=$-1.33, Total PnL=$-559.78, Max Drawdown=7.0%, Pred Accuracy=99.9%
crypto/gogo2/cuda.py (new file, 4 lines)
@@ -0,0 +1,4 @@
+import torch
+print(f"PyTorch version: {torch.__version__}")
+print(f"CUDA available: {torch.cuda.is_available()}")
+print(f"CUDA version: {torch.version.cuda if torch.cuda.is_available() else 'Not available'}")
@@ -22,6 +22,7 @@ from sklearn.preprocessing import MinMaxScaler
 import copy
 import argparse
 import traceback
+import math
 
 # Configure logging
 logging.basicConfig(
@@ -93,6 +94,9 @@ class DQN(nn.Module):
     def __init__(self, state_size, action_size, hidden_size=384, lstm_layers=2, attention_heads=4):
         super(DQN, self).__init__()
         
+        # Ensure model parameters are float32
+        self.float()
+        
         self.state_size = state_size
         self.hidden_size = hidden_size
         self.lstm_layers = lstm_layers
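Note on the `self.float()` call above: `nn.Module.float()` casts parameters that already exist on the module, which is why the commit also casts the finished networks in the Agent constructor further down. A minimal standalone check (illustrative only, not part of the commit):

```python
import torch.nn as nn

layer = nn.Linear(4, 2).double()          # parameters start out as float64
layer.float()                             # cast in place to float32
print(next(layer.parameters()).dtype)     # torch.float32
```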
@@ -224,7 +228,12 @@ class PricePredictionModel(nn.Module):
         predictions = self.postprocess(scaled_predictions)
         return predictions
     
-    def train_on_new_data(self, price_history, optimizer, epochs=10):
+    def train_on_new_data(self, price_history, optimizer, epochs=5):
+        """Train the model on new price data"""
+        # Convert to numpy array if it's not already
+        if isinstance(price_history, list):
+            price_history = np.array(price_history, dtype=np.float32)  # Force float32
+        
         if len(price_history) < 35:  # Need enough history for training
             return 0.0
         
@@ -320,6 +329,9 @@ class TradingEnvironment:
         self.optimal_tops = []
         self.optimal_signals = np.array([])
         
+        # Add risk factor for curriculum learning
+        self.risk_factor = 1.0  # Default risk factor
+        
     def reset(self):
         """Reset the environment to initial state"""
         self.balance = self.initial_balance
@@ -635,14 +647,14 @@ class TradingEnvironment:
         """Create state representation for the agent"""
         if len(self.data) < 30 or len(self.features['price']) == 0:
             # Return zeros if not enough data
-            return np.zeros(STATE_SIZE)
+            return np.zeros(STATE_SIZE, dtype=np.float32)  # Ensure float32
         
         # Create a normalized state vector with recent price action and indicators
         state_components = []
         
         # Price features (normalize recent prices by the latest price)
         latest_price = self.features['price'][-1]
-        price_features = np.array(self.features['price'][-10:]) / latest_price - 1.0
+        price_features = np.array(self.features['price'][-10:], dtype=np.float32) / latest_price - 1.0
         state_components.append(price_features)
         
         # Volume features (normalize by max volume)
@@ -732,43 +744,46 @@ class TradingEnvironment:
             state = state[:STATE_SIZE]
         elif len(state) < STATE_SIZE:
             # Pad with zeros if too short
-            padding = np.zeros(STATE_SIZE - len(state))
+            padding = np.zeros(STATE_SIZE - len(state), dtype=np.float32)  # Ensure float32
             state = np.concatenate([state, padding])
         
-        return state
+        # Ensure float32 type
+        return state.astype(np.float32)
     
     def calculate_reward(self, action):
         """Calculate reward for the given action with improved penalties for losing trades"""
         reward = 0
         
+        # Store previous balance for direct PnL calculation
+        prev_balance = self.balance
+        
         # Base reward for actions
         if action == 0:  # HOLD
-            reward = -0.01  # Small penalty for doing nothing
+            # Small penalty for doing nothing to encourage action
+            # But make it context-dependent - holding during high volatility should be penalized more
+            volatility = self.get_recent_volatility()
+            reward = -0.01 * (1 + volatility)
         
         elif action == 1:  # BUY/LONG
             if self.position == 'flat':
                 # Opening a long position
                 self.position = 'long'
                 self.entry_price = self.current_price
+                self.entry_index = self.current_step
                 self.position_size = self.calculate_position_size()
-                self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100)
-                self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100)
                 
-                # Check if this is an optimal buy point (bottom)
-                current_idx = len(self.features['price']) - 1
-                if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms:
-                    reward += 2.0  # Bonus for buying at a bottom
-                else:
-                    # Check if we're buying in a downtrend (bad)
-                    if self.is_downtrend():
-                        reward -= 0.5  # Penalty for buying in downtrend
-                    else:
-                        reward += 0.1  # Small reward for opening a position
+                # Calculate stop loss and take profit levels
+                self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT / 100)
+                self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT / 100)
+                
+                # Check if this is a good entry point based on technical indicators
+                entry_quality = self.evaluate_entry_quality('long')
+                reward += entry_quality * 0.5  # Scale the reward based on entry quality
                 
                 logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
             
             elif self.position == 'short':
-                # Close short and open long
+                # Closing a short position
                 pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
                 pnl_dollar = pnl_percent / 100 * self.position_size
                 
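The stop-loss and take-profit levels above are plain percentage offsets from the entry price; `STOP_LOSS_PERCENT` and `TAKE_PROFIT_PERCENT` are module-level constants that a later hunk also adjusts for volatility. A quick arithmetic check with hypothetical values (not the ones defined in the file):

```python
# Hypothetical values for illustration only
entry_price = 2000.0
STOP_LOSS_PERCENT = 0.5
TAKE_PROFIT_PERCENT = 1.5

stop_loss = entry_price * (1 - STOP_LOSS_PERCENT / 100)      # 1990.0
take_profit = entry_price * (1 + TAKE_PROFIT_PERCENT / 100)  # 2030.0
print(stop_loss, take_profit)
```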
@@ -778,46 +793,39 @@ class TradingEnvironment:
                 # Update balance
                 self.balance += pnl_dollar
                 self.total_pnl += pnl_dollar
+                self.episode_pnl += pnl_dollar
                 
                 # Record trade
-                trade_duration = len(self.features['price']) - self.entry_index
                 self.trades.append({
                     'type': 'short',
                     'entry': self.entry_price,
                     'exit': self.current_price,
                     'pnl_percent': pnl_percent,
                     'pnl_dollar': pnl_dollar,
-                    'duration': trade_duration,
-                    'market_direction': self.get_market_direction()
+                    'duration': self.current_step - self.entry_index,
+                    'market_direction': self.get_market_direction(),
+                    'reason': 'manual_close'
                 })
                 
-                # Reward based on PnL with stronger penalties for losses
+                # Update win/loss count
                 if pnl_dollar > 0:
-                    reward += 1.0 + pnl_dollar / 10  # Positive reward for profit
                     self.win_count += 1
+                    reward += 1.0 + (pnl_percent / 2)  # Bonus for winning trade
                 else:
-                    # Stronger penalty for losses, scaled by the size of the loss
-                    loss_penalty = 1.0 + abs(pnl_dollar) / 5
-                    reward -= loss_penalty
                     self.loss_count += 1
-                    # Extra penalty for closing a losing trade too quickly
-                    if trade_duration < 5:
-                        reward -= 0.5  # Penalty for very short losing trades
+                    reward -= 1.0 + (abs(pnl_percent) / 2)  # Penalty for losing trade
                 
                 logger.info(f"CLOSED short at {self.current_price} | PnL: {pnl_percent:.2f}% | ${pnl_dollar:.2f}")
                 
-                # Now open long
+                # Reset position and open new long
                 self.position = 'long'
                 self.entry_price = self.current_price
-                self.entry_index = len(self.features['price']) - 1
+                self.entry_index = self.current_step
                 self.position_size = self.calculate_position_size()
-                self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT/100)
-                self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT/100)
                 
-                # Check if this is an optimal buy point
-                if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
-                    reward += 2.0  # Bonus for buying at a bottom
+                # Calculate stop loss and take profit levels
+                self.stop_loss = self.entry_price * (1 - STOP_LOSS_PERCENT / 100)
+                self.take_profit = self.entry_price * (1 + TAKE_PROFIT_PERCENT / 100)
                 
                 logger.info(f"OPENED LONG at {self.entry_price} | Stop loss: {self.stop_loss} | Take profit: {self.take_profit}")
                 
@@ -975,7 +983,14 @@ class TradingEnvironment:
             self.stop_loss = 0
             self.take_profit = 0
         
-        # Add prediction accuracy component to reward
+        # Add reward based on direct PnL change
+        balance_change = self.balance - prev_balance
+        if balance_change > 0:
+            reward += balance_change * 0.5  # Positive reward for making money
+        else:
+            reward += balance_change * 1.0  # Stronger negative reward for losing money
+        
+        # Add reward for predicted price movement alignment
         if hasattr(self, 'predicted_prices') and len(self.predicted_prices) > 0:
             # Compare the first prediction with actual price
             if len(self.data) > 1:
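The direct PnL term added above is asymmetric: gains are scaled by 0.5 while losses keep their full weight, so losing a dollar hurts the reward twice as much as earning one helps it. A minimal sketch of that term in isolation (hypothetical balances):

```python
def pnl_reward(prev_balance: float, balance: float) -> float:
    change = balance - prev_balance
    # Positive changes are damped, negative changes pass through at full weight
    return change * 0.5 if change > 0 else change * 1.0

print(pnl_reward(100.0, 102.0))  # 1.0
print(pnl_reward(100.0, 98.0))   # -2.0
```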
@@ -985,12 +1000,106 @@ class TradingEnvironment:
                 
                 # Reward accurate predictions, penalize bad ones
                 if prediction_error < 0.005:  # Less than 0.5% error
-                    reward += 0.5
+                    reward += 0.2
                 elif prediction_error > 0.02:  # More than 2% error
-                    reward -= 0.5
+                    reward -= 0.2
         
+        # Add reward/penalty based on market trend alignment
+        market_direction = self.get_market_direction()
+        if (self.position == 'long' and market_direction == 'uptrend') or \
+           (self.position == 'short' and market_direction == 'downtrend'):
+            reward += 0.2  # Reward for trading with the trend
+        elif (self.position == 'long' and market_direction == 'downtrend') or \
+             (self.position == 'short' and market_direction == 'uptrend'):
+            reward -= 0.3  # Stronger penalty for trading against the trend
+        
         return reward
     
+    def evaluate_entry_quality(self, position_type):
+        """Evaluate the quality of an entry point based on technical indicators"""
+        score = 0
+        
+        # Get current indicators
+        rsi = self.features['rsi'][-1] if len(self.features['rsi']) > 0 else 50
+        macd = self.features['macd'][-1] if len(self.features['macd']) > 0 else 0
+        macd_signal = self.features['macd_signal'][-1] if len(self.features['macd_signal']) > 0 else 0
+        stoch_k = self.features['stoch_k'][-1] if len(self.features['stoch_k']) > 0 else 50
+        stoch_d = self.features['stoch_d'][-1] if len(self.features['stoch_d']) > 0 else 50
+        
+        if position_type == 'long':
+            # RSI oversold condition (good for long)
+            if rsi < 30:
+                score += 0.5
+            elif rsi < 40:
+                score += 0.2
+            elif rsi > 70:
+                score -= 0.5  # Overbought, bad for long
+            
+            # MACD crossover (bullish)
+            if macd > macd_signal and macd > 0:
+                score += 0.3
+            elif macd < macd_signal and macd < 0:
+                score -= 0.3
+            
+            # Stochastic oversold
+            if stoch_k < 20 and stoch_d < 20:
+                score += 0.3
+            elif stoch_k > 80 and stoch_d > 80:
+                score -= 0.3
+        
+        elif position_type == 'short':
+            # RSI overbought condition (good for short)
+            if rsi > 70:
+                score += 0.5
+            elif rsi > 60:
+                score += 0.2
+            elif rsi < 30:
+                score -= 0.5  # Oversold, bad for short
+            
+            # MACD crossover (bearish)
+            if macd < macd_signal and macd < 0:
+                score += 0.3
+            elif macd > macd_signal and macd > 0:
+                score -= 0.3
+            
+            # Stochastic overbought
+            if stoch_k > 80 and stoch_d > 80:
+                score += 0.3
+            elif stoch_k < 20 and stoch_d < 20:
+                score -= 0.3
+        
+        # Check price relative to moving averages
+        if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0:
+            ema_9 = self.features['ema_9'][-1]
+            ema_21 = self.features['ema_21'][-1]
+            
+            if position_type == 'long':
+                if self.current_price > ema_9 > ema_21:  # Strong uptrend
+                    score += 0.4
+                elif self.current_price < ema_9 < ema_21:  # Strong downtrend
+                    score -= 0.4
+            elif position_type == 'short':
+                if self.current_price < ema_9 < ema_21:  # Strong downtrend
+                    score += 0.4
+                elif self.current_price > ema_9 > ema_21:  # Strong uptrend
+                    score -= 0.4
+        
+        return score
+    
+    def get_recent_volatility(self):
+        """Calculate recent price volatility"""
+        if len(self.features['price']) < 10:
+            return 0
+        
+        # Use ATR if available
+        if len(self.features['atr']) > 0:
+            return self.features['atr'][-1] / self.current_price
+        
+        # Otherwise calculate simple volatility
+        recent_prices = self.features['price'][-10:]
+        returns = [recent_prices[i] / recent_prices[i-1] - 1 for i in range(1, len(recent_prices))]
+        return np.std(returns) * 100  # Volatility as percentage
+    
     def is_downtrend(self):
         """Check if the market is in a downtrend"""
         if len(self.features['price']) < 20:
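`get_recent_volatility()` above prefers ATR relative to price and falls back to the standard deviation of the last ten one-step returns expressed as a percentage; note that the ATR branch returns a fraction of price while the fallback returns a percentage, so the downstream scaling differs between the two paths. A standalone version of just the fallback branch (assumption: the ATR path is omitted here):

```python
import numpy as np

def recent_volatility(prices):
    if len(prices) < 10:
        return 0
    recent = list(prices)[-10:]
    returns = [recent[i] / recent[i - 1] - 1 for i in range(1, len(recent))]
    return np.std(returns) * 100  # volatility as a percentage

print(recent_volatility([100, 101, 99, 102, 101, 103, 102, 104, 103, 105]))
```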
@@ -1016,13 +1125,49 @@ class TradingEnvironment:
         return short_ema > long_ema
     
     def get_market_direction(self):
-        """Get the current market direction"""
-        if self.is_uptrend():
-            return "uptrend"
-        elif self.is_downtrend():
-            return "downtrend"
-        else:
-            return "sideways"
+        """Determine the current market direction (uptrend, downtrend, or sideways)"""
+        if len(self.features['price']) < 20:
+            return 'unknown'
+        
+        # Use EMAs to determine trend
+        if len(self.features['ema_9']) > 0 and len(self.features['ema_21']) > 0:
+            ema_9 = self.features['ema_9'][-5:]
+            ema_21 = self.features['ema_21'][-5:]
+            price = self.features['price'][-5:]
+            
+            # Check if price is above/below EMAs
+            price_above_ema9 = sum(p > e for p, e in zip(price, ema_9))
+            price_above_ema21 = sum(p > e for p, e in zip(price, ema_21))
+            ema9_above_ema21 = sum(e9 > e21 for e9, e21 in zip(ema_9, ema_21))
+            
+            # Strong uptrend: price > EMA9 > EMA21
+            if price_above_ema9 >= 4 and price_above_ema21 >= 4 and ema9_above_ema21 >= 4:
+                return 'uptrend'
+            
+            # Strong downtrend: price < EMA9 < EMA21
+            elif price_above_ema9 <= 1 and price_above_ema21 <= 1 and ema9_above_ema21 <= 1:
+                return 'downtrend'
+        
+        # Check price action
+        price_data = self.features['price'][-20:]
+        price_change = (price_data[-1] / price_data[0] - 1) * 100
+        
+        if price_change > 1.0:
+            return 'uptrend'
+        elif price_change < -1.0:
+            return 'downtrend'
+        
+        # Check RSI for trend confirmation
+        if len(self.features['rsi']) > 0:
+            rsi = self.features['rsi'][-5:]
+            avg_rsi = sum(rsi) / len(rsi)
+            
+            if avg_rsi > 60:
+                return 'uptrend'
+            elif avg_rsi < 40:
+                return 'downtrend'
+        
+        return 'sideways'
     
     def analyze_trades(self):
         """Analyze completed trades to identify patterns"""
@@ -1119,12 +1264,17 @@ class TradingEnvironment:
         logger.info(f"Identified {len(bottoms)} optimal buy points and {len(tops)} optimal sell points")
     
     def calculate_position_size(self):
-        """Calculate position size based on current balance and risk parameters"""
-        # Use a fixed percentage of balance for each trade
-        risk_percent = 5.0  # Risk 5% of balance per trade
+        """Calculate position size based on current balance, volatility and risk parameters"""
+        # Base risk percentage (adjust based on volatility)
+        volatility = self.get_recent_volatility()
+        
+        # Reduce risk during high volatility
+        base_risk = 5.0  # Base risk percentage
+        adjusted_risk = base_risk / (1 + volatility * 5)  # Reduce risk as volatility increases
+        adjusted_risk = max(1.0, min(adjusted_risk, base_risk))  # Cap between 1% and base_risk
         
         # Calculate position size with leverage
-        position_size = self.balance * (risk_percent / 100) * MAX_LEVERAGE
+        position_size = self.balance * (adjusted_risk / 100) * MAX_LEVERAGE
         
         # Apply a safety factor to avoid liquidation
         safety_factor = 0.8
@@ -1138,6 +1288,14 @@ class TradingEnvironment:
         max_position = self.balance * MAX_LEVERAGE
         position_size = min(position_size, max_position)
         
+        # Adjust stop loss based on volatility
+        global STOP_LOSS_PERCENT, TAKE_PROFIT_PERCENT
+        STOP_LOSS_PERCENT = 0.5 * (1 + volatility)  # Wider stop loss during high volatility
+        TAKE_PROFIT_PERCENT = 1.5 * (1 + volatility * 0.5)  # Higher take profit during high volatility
+        
+        # Apply risk factor from curriculum learning
+        position_size *= self.risk_factor
+        
         return position_size
     
     def calculate_fees(self, position_size):
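Taken together, the two hunks above size positions from a volatility-adjusted risk percentage, apply the 0.8 safety factor and the leverage cap, and then scale by the curriculum risk factor. A condensed sketch of that flow; the `MAX_LEVERAGE` value here is a hypothetical stand-in for the module constant:

```python
MAX_LEVERAGE = 50  # hypothetical stand-in for the module-level constant

def position_size(balance, volatility, risk_factor=1.0, base_risk=5.0):
    adjusted_risk = base_risk / (1 + volatility * 5)          # shrink risk as volatility rises
    adjusted_risk = max(1.0, min(adjusted_risk, base_risk))   # clamp to [1%, base_risk]
    size = balance * (adjusted_risk / 100) * MAX_LEVERAGE
    size *= 0.8                                               # safety factor
    size = min(size, balance * MAX_LEVERAGE)                  # never exceed max leverage
    return size * risk_factor                                 # curriculum scaling

print(position_size(100.0, volatility=0.0))  # 200.0
print(position_size(100.0, volatility=2.0))  # 40.0
```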
@@ -1152,15 +1310,15 @@ class TradingEnvironment:
 
 # Ensure GPU usage if available
 def get_device():
-    """Get the best available device (CUDA GPU or CPU)"""
+    """Get the device to use (GPU or CPU)"""
     if torch.cuda.is_available():
         device = torch.device("cuda")
+        # Set default tensor type to float32 for CUDA
+        torch.set_default_tensor_type(torch.FloatTensor)
         logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
-        # Set up for mixed precision training
-        torch.backends.cudnn.benchmark = True
     else:
         device = torch.device("cpu")
-        logger.info("GPU not available, using CPU")
+        logger.info("Using CPU")
     return device
 
 # Update Agent class to use GPU properly
@@ -1180,6 +1338,8 @@ class Agent:
         # Initialize policy and target networks
         self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
         self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
+        ensure_float32(self.policy_net)
+        ensure_float32(self.target_net)
         self.target_net.load_state_dict(self.policy_net.state_dict())
         self.target_net.eval()
         
@@ -1195,6 +1355,16 @@ class Agent:
         # Create models directory if it doesn't exist
         os.makedirs("models", exist_ok=True)
         
+        # Use pinned memory for faster CPU-to-GPU transfers
+        if self.device.type == "cuda":
+            self.use_pinned_memory = True
+        else:
+            self.use_pinned_memory = False
+        
+        # Ensure models are using float32
+        self.policy_net.float()
+        self.target_net.float()
+        
     def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
         """Expand the model to handle more features or increase capacity"""
         logger.info(f"Expanding model: {self.state_size} → {new_state_size}, "
@@ -1245,20 +1415,34 @@ class Agent:
         return True
     
     def select_action(self, state, training=True):
+        """Select an action using epsilon-greedy policy"""
         sample = random.random()
+        eps_threshold = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
+            math.exp(-1. * self.steps_done / self.epsilon_decay)
+        
         if training:
-            # Epsilon decay
-            self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \
-                np.exp(-1. * self.steps_done / EPSILON_DECAY)
+            self.epsilon = eps_threshold
+        else:
+            self.epsilon = 0.0  # No exploration during evaluation/live trading
         
         self.steps_done += 1
         
-        if sample > self.epsilon or not training:
+        if sample > self.epsilon:
             with torch.no_grad():
+                # Convert state to tensor and ensure it's float32 (not double/float64)
                 state_tensor = torch.FloatTensor(state).to(self.device)
-                action_values = self.policy_net(state_tensor)
-                return action_values.max(1)[1].item()
+                
+                # Ensure state has correct shape
+                if state_tensor.dim() == 1:
+                    state_tensor = state_tensor.unsqueeze(0)
+                
+                # Get Q values
+                q_values = self.policy_net(state_tensor)
+                
+                # Return action with highest Q value
+                return q_values.max(1)[1].item()
         else:
+            # Random action
            return random.randrange(self.action_size)
    
     def learn(self):
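The rewritten `select_action` above computes an exponentially decaying epsilon from `steps_done` and disables exploration entirely outside of training. The schedule in isolation; the start, end, and decay values here are assumed hyperparameters, not the ones defined in the file:

```python
import math

def epsilon_at(step, epsilon_start=1.0, epsilon_end=0.05, epsilon_decay=10_000):
    return epsilon_end + (epsilon_start - epsilon_end) * math.exp(-step / epsilon_decay)

print(round(epsilon_at(0), 3))        # 1.0
print(round(epsilon_at(10_000), 3))   # 0.399
print(round(epsilon_at(50_000), 3))   # 0.056
```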
@@ -1270,12 +1454,27 @@ class Agent:
         # Sample a batch of experiences
         experiences = self.memory.sample(BATCH_SIZE)
         
-        # Convert experiences to tensors
-        states = torch.FloatTensor([e.state for e in experiences]).to(self.device)
-        actions = torch.LongTensor([e.action for e in experiences]).to(self.device)
-        rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device)
-        next_states = torch.FloatTensor([e.next_state for e in experiences]).to(self.device)
-        dones = torch.FloatTensor([e.done for e in experiences]).to(self.device)
+        # Convert experiences to tensors more efficiently
+        # First create numpy arrays, then convert to tensors
+        states_np = np.array([e.state for e in experiences], dtype=np.float32)  # Ensure float32
+        actions_np = np.array([e.action for e in experiences], dtype=np.int64)  # Ensure int64
+        rewards_np = np.array([e.reward for e in experiences], dtype=np.float32)  # Ensure float32
+        next_states_np = np.array([e.next_state for e in experiences], dtype=np.float32)  # Ensure float32
+        dones_np = np.array([e.done for e in experiences], dtype=np.float32)  # Ensure float32
+        
+        # Convert numpy arrays to tensors with pinned memory if using GPU
+        if self.use_pinned_memory:
+            states = torch.from_numpy(states_np).pin_memory().to(self.device, non_blocking=True)
+            actions = torch.from_numpy(actions_np).long().pin_memory().to(self.device, non_blocking=True)
+            rewards = torch.from_numpy(rewards_np).pin_memory().to(self.device, non_blocking=True)
+            next_states = torch.from_numpy(next_states_np).pin_memory().to(self.device, non_blocking=True)
+            dones = torch.from_numpy(dones_np).pin_memory().to(self.device, non_blocking=True)
+        else:
+            states = torch.FloatTensor(states_np).to(self.device)
+            actions = torch.LongTensor(actions_np).to(self.device)
+            rewards = torch.FloatTensor(rewards_np).to(self.device)
+            next_states = torch.FloatTensor(next_states_np).to(self.device)
+            dones = torch.FloatTensor(dones_np).to(self.device)
         
         # Use mixed precision for forward/backward passes
         if self.device.type == "cuda":
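The batching rewrite above first builds contiguous NumPy arrays (rather than handing Python lists straight to the tensor constructors) and then, on CUDA, stages them in page-locked memory so the host-to-device copy can run asynchronously. A minimal sketch of the same pattern; shapes and the device check are illustrative:

```python
import numpy as np
import torch

batch = np.random.rand(32, 64).astype(np.float32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

t = torch.from_numpy(batch)
if device.type == "cuda":
    t = t.pin_memory().to(device, non_blocking=True)  # async copy from pinned host memory
else:
    t = t.to(device)
print(t.dtype, t.device)
```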
@@ -1346,29 +1545,60 @@ class Agent:
     def update_target_network(self):
         self.target_net.load_state_dict(self.policy_net.state_dict())
     
-    def save(self, path="models/trading_agent.pt"):
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-        torch.save({
-            'policy_net': self.policy_net.state_dict(),
-            'target_net': self.target_net.state_dict(),
-            'optimizer': self.optimizer.state_dict(),
-            'epsilon': self.epsilon,
-            'steps_done': self.steps_done
-        }, path)
-        logger.info(f"Model saved to {path}")
-    
-    def load(self, path="models/trading_agent.pt"):
-        if os.path.isfile(path):
-            checkpoint = torch.load(path)
-            self.policy_net.load_state_dict(checkpoint['policy_net'])
-            self.target_net.load_state_dict(checkpoint['target_net'])
-            self.optimizer.load_state_dict(checkpoint['optimizer'])
-            self.epsilon = checkpoint['epsilon']
-            self.steps_done = checkpoint['steps_done']
-            logger.info(f"Model loaded from {path}")
-            return True
-        logger.warning(f"No model found at {path}")
-        return False
+    def save(self, path):
+        """Save model to path"""
+        try:
+            # Create directory if it doesn't exist
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            
+            # Save model state
+            torch.save({
+                'policy_net': self.policy_net.state_dict(),
+                'target_net': self.target_net.state_dict(),
+                'optimizer': self.optimizer.state_dict(),
+                'steps_done': self.steps_done
+            }, path)
+            
+            logger.info(f"Model saved to {path}")
+        except Exception as e:
+            logger.error(f"Failed to save model: {e}")
+            logger.error(f"Traceback: {traceback.format_exc()}")
+    
+    def load(self, path):
+        """Load model from path with proper error handling for PyTorch 2.6+"""
+        try:
+            logger.info(f"Loading model from {path}")
+            
+            # First try with weights_only=True (safer)
+            try:
+                # Add numpy scalar to safe globals first
+                import torch.serialization
+                torch.serialization.add_safe_globals(['numpy._core.multiarray.scalar'])
+                
+                # Load the model
+                checkpoint = torch.load(path, map_location=self.device)
+                self.policy_net.load_state_dict(checkpoint['policy_net'])
+                self.target_net.load_state_dict(checkpoint['target_net'])
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
+                self.steps_done = checkpoint.get('steps_done', 0)
+                logger.info(f"Model loaded successfully with weights_only=True")
+            except Exception as e:
+                logger.warning(f"Could not load with weights_only=True: {e}")
+                logger.warning("Attempting to load with weights_only=False (less secure)")
+                
+                # Fall back to weights_only=False (less secure but more compatible)
+                checkpoint = torch.load(path, map_location=self.device, weights_only=False)
+                self.policy_net.load_state_dict(checkpoint['policy_net'])
+                self.target_net.load_state_dict(checkpoint['target_net'])
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
+                self.steps_done = checkpoint.get('steps_done', 0)
+                logger.info(f"Model loaded successfully with weights_only=False")
+        
+        except Exception as e:
+            logger.error(f"Failed to load model: {e}")
+            logger.error(f"Traceback: {traceback.format_exc()}")
+            raise
 
 async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
     """Get live price data using websockets"""
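The new `load` above works around the PyTorch 2.6 change that made `weights_only=True` the default for `torch.load`: it first tries the safe path and only then falls back to full unpickling. A reduced sketch of that fallback, with path handling and logging omitted (`weights_only` requires a reasonably recent PyTorch, and the fallback should only be used on checkpoints you trust):

```python
import torch

def load_checkpoint(path, device):
    try:
        # Safe path: only plain tensors and containers are unpickled
        return torch.load(path, map_location=device, weights_only=True)
    except Exception as exc:
        print(f"weights_only=True failed ({exc}); retrying with full unpickling")
        return torch.load(path, map_location=device, weights_only=False)
```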
@@ -1408,10 +1638,28 @@ async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
                 await asyncio.sleep(5)
                 break
 
-async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000):
+async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None):
     """Train the agent using historical and live data with GPU acceleration"""
     logger.info(f"Starting training on device: {agent.device}")
     
+    # Add early stopping based on performance
+    patience = 50  # Episodes to wait for improvement
+    best_pnl = -float('inf')
+    episodes_without_improvement = 0
+    
+    # Add adaptive learning rate
+    initial_lr = LEARNING_RATE
+    min_lr = LEARNING_RATE / 10
+    
+    # Add curriculum learning
+    curriculum_stages = [
+        {"episodes": 100, "risk_factor": 0.5, "exploration": 0.3},   # Conservative trading
+        {"episodes": 200, "risk_factor": 0.75, "exploration": 0.2},  # Moderate risk
+        {"episodes": 300, "risk_factor": 1.0, "exploration": 0.1},   # Normal risk
+        {"episodes": 400, "risk_factor": 1.25, "exploration": 0.05}  # Aggressive trading
+    ]
+    current_stage = 0
+    
     stats = {
         'episode_rewards': [],
         'episode_lengths': [],
@@ -1420,19 +1668,45 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
         'episode_pnls': [],
         'cumulative_pnl': [],
         'drawdowns': [],
-        'prediction_accuracy': [],
-        'trade_analysis': []
+        'prediction_accuracy': []
     }
     
-    best_reward = -float('inf')
-    best_pnl = -float('inf')
-    
     try:
         # Initialize price predictor
         env.initialize_price_predictor(agent.device)
         
         for episode in range(num_episodes):
             try:
+                # Update curriculum stage if needed
+                if current_stage < len(curriculum_stages) - 1 and episode >= curriculum_stages[current_stage]["episodes"]:
+                    current_stage += 1
+                    logger.info(f"Moving to curriculum stage {current_stage+1}: "
+                                f"risk_factor={curriculum_stages[current_stage]['risk_factor']}, "
+                                f"exploration={curriculum_stages[current_stage]['exploration']}")
+                
+                # Apply curriculum settings
+                risk_factor = curriculum_stages[current_stage]["risk_factor"]
+                exploration = curriculum_stages[current_stage]["exploration"]
+                
+                # Set exploration rate for this episode
+                agent.epsilon = exploration
+                
+                # Set risk factor for this episode
+                env.risk_factor = risk_factor
+                
+                # Refresh data with latest candles if exchange is provided
+                if exchange is not None:
+                    try:
+                        logger.info(f"Fetching latest data for episode {episode}")
+                        latest_data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 100)
+                        if latest_data:
+                            # Add new data to environment
+                            for candle in latest_data:
+                                env.add_data(candle)
+                            logger.info(f"Added {len(latest_data)} new candles for episode {episode}")
+                    except Exception as e:
+                        logger.error(f"Error refreshing data: {e}")
+                
                 # Reset environment
                 state = env.reset()
                 episode_reward = 0
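The curriculum block above advances one stage at a time as the episode counter passes each stage's `episodes` threshold, then copies that stage's exploration rate onto the agent and its risk factor onto the environment. The stage-advance rule on its own, with the same thresholds as the diff:

```python
curriculum_stages = [
    {"episodes": 100, "risk_factor": 0.5, "exploration": 0.3},
    {"episodes": 200, "risk_factor": 0.75, "exploration": 0.2},
    {"episodes": 300, "risk_factor": 1.0, "exploration": 0.1},
    {"episodes": 400, "risk_factor": 1.25, "exploration": 0.05},
]

def advance_stage(current_stage, episode):
    # Move to the next stage once the episode count reaches this stage's threshold
    if current_stage < len(curriculum_stages) - 1 and episode >= curriculum_stages[current_stage]["episodes"]:
        current_stage += 1
    return current_stage

print(advance_stage(0, 99))   # 0
print(advance_stage(0, 100))  # 1
print(advance_stage(1, 250))  # 2
```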
@@ -1457,49 +1731,34 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
                     # Store experience
                     agent.memory.push(state, action, reward, next_state, done)
                     
+                    # Learn from experience
+                    loss = agent.learn()
+                    
+                    # Update state and reward
                     state = next_state
                     episode_reward += reward
                     
-                    # Learn from experience with mixed precision
-                    try:
-                        loss = agent.learn()
-                        if loss is not None:
-                            agent.writer.add_scalar('Loss/train', loss, agent.steps_done)
-                    except Exception as e:
-                        logger.error(f"Learning error in episode {episode}, step {step}: {e}")
-                    
-                    # Update price predictions periodically
-                    if step % 10 == 0:
-                        env.update_price_predictions()
-                    
+                    # Break if done
                     if done:
                         break
                 
-                # Update target network
-                if episode % TARGET_UPDATE == 0:
-                    agent.target_net.load_state_dict(agent.policy_net.state_dict())
-                
                 # Calculate win rate
-                if len(env.trades) > 0:
-                    wins = sum(1 for trade in env.trades if trade.get('pnl_percent', 0) > 0)
-                    win_rate = wins / len(env.trades) * 100
-                else:
-                    win_rate = 0
-                
-                # Analyze trades
-                trade_analysis = env.analyze_trades()
-                stats['trade_analysis'].append(trade_analysis)
+                total_trades = env.win_count + env.loss_count
+                win_rate = (env.win_count / total_trades * 100) if total_trades > 0 else 0
                 
                 # Calculate prediction accuracy
-                prediction_accuracy = 0.0
                 if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0:
-                    if len(env.data) > 5:
-                        actual_prices = [candle['close'] for candle in env.data[-5:]]
-                        predicted = env.predicted_prices[:min(5, len(actual_prices))]
-                        errors = [abs(p - a) / a for p, a in zip(predicted, actual_prices[:len(predicted)])]
-                        prediction_accuracy = 100 * (1 - sum(errors) / len(errors))
+                    # Compare predictions with actual prices
+                    actual_prices = env.features['price'][-len(env.predicted_prices):]
+                    prediction_errors = np.abs(env.predicted_prices - actual_prices) / actual_prices
+                    prediction_accuracy = 100 * (1 - np.mean(prediction_errors))
+                else:
+                    prediction_accuracy = 0
                 
-                # Log statistics
+                # Analyze trades
+                trade_analysis = env.analyze_trades() if hasattr(env, 'analyze_trades') else {}
+                
+                # Update stats
                 stats['episode_rewards'].append(episode_reward)
                 stats['episode_lengths'].append(step + 1)
                 stats['balances'].append(env.balance)
@@ -1534,18 +1793,60 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
                 if episode_reward > best_reward:
                     best_reward = episode_reward
                     agent.save("models/trading_agent_best_reward.pt")
+                    logger.info(f"New best reward model saved: {episode_reward:.2f}")
                 
                 # Save best model by PnL
                 if env.episode_pnl > best_pnl:
                     best_pnl = env.episode_pnl
                     agent.save("models/trading_agent_best_pnl.pt")
+                    logger.info(f"New best PnL model saved: ${env.episode_pnl:.2f}")
                 
-                # Save checkpoint
+                # Save best model by win rate (if enough trades)
+                if total_trades >= 10 and win_rate > best_win_rate:
+                    best_win_rate = win_rate
+                    agent.save("models/trading_agent_best_winrate.pt")
+                    logger.info(f"New best win rate model saved: {win_rate:.1f}%")
+                
+                # Save checkpoint every 10 episodes
                 if episode % 10 == 0:
-                    agent.save(f"models/trading_agent_episode_{episode}.pt")
+                    checkpoint_path = f"checkpoints/trading_agent_episode_{episode}.pt"
+                    agent.save(checkpoint_path)
+                    
+                    # Save best metrics to resume training if interrupted
+                    best_metrics = {
+                        'best_reward': float(best_reward),
+                        'best_pnl': float(best_pnl),
+                        'best_win_rate': float(best_win_rate),
+                        'last_episode': episode,
+                        'timestamp': datetime.datetime.now().isoformat()
+                    }
+                    with open("checkpoints/best_metrics.json", 'w') as f:
+                        json.dump(best_metrics, f)
+                    
+                    logger.info(f"Checkpoint saved at episode {episode}")
+                
+                # Check for early stopping
+                if env.episode_pnl > best_pnl:
+                    best_pnl = env.episode_pnl
+                    episodes_without_improvement = 0
+                else:
+                    episodes_without_improvement += 1
+                
+                # Adjust learning rate based on performance
+                if episodes_without_improvement > 20:
+                    # Reduce learning rate
+                    for param_group in agent.optimizer.param_groups:
+                        param_group['lr'] = max(param_group['lr'] * 0.9, min_lr)
+                    logger.info(f"Reducing learning rate to {agent.optimizer.param_groups[0]['lr']:.6f}")
+                
+                # Early stopping check
+                if episodes_without_improvement >= patience:
+                    logger.info(f"Early stopping triggered after {episode+1} episodes without improvement")
+                    break
                 
             except Exception as e:
                 logger.error(f"Error in episode {episode}: {e}")
+                logger.error(f"Traceback: {traceback.format_exc()}")
                 continue
         
         # Save final model
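The new plateau handling above mixes three mechanisms: a patience counter driven by `episode_pnl`, a 10% learning-rate decay floored at `min_lr` once 20 stale episodes accumulate, and a hard early stop after `patience` episodes. Note that `best_pnl` is already raised by the best-PnL save earlier in the loop, so the second strict comparison never resets the counter on an improving episode. A compact sketch of the decay and stop logic around a stand-in optimizer:

```python
import torch

def adjust_on_plateau(optimizer, episodes_without_improvement, min_lr, patience=50):
    if episodes_without_improvement > 20:
        for group in optimizer.param_groups:
            group["lr"] = max(group["lr"] * 0.9, min_lr)  # decay, but never below min_lr
    return episodes_without_improvement >= patience       # True means stop training

opt = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)
print(adjust_on_plateau(opt, 25, min_lr=1e-4), opt.param_groups[0]["lr"])  # False 0.0009
```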
@@ -1556,7 +1857,14 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
     
     except Exception as e:
         logger.error(f"Training failed: {e}")
-        raise
+        logger.error(f"Traceback: {traceback.format_exc()}")
+        
+        # Save emergency checkpoint
+        try:
+            agent.save("models/trading_agent_emergency.pt")
+            logger.info("Emergency model saved due to training failure")
+        except Exception as save_error:
+            logger.error(f"Failed to save emergency model: {save_error}")
     
     return stats
 
@@ -1901,6 +2209,7 @@ async def main():
                         help='Mode to run the bot in')
     parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
     parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
+    parser.add_argument('--refresh-data', action='store_true', help='Refresh data during training')
     args = parser.parse_args()
     
     # Get device (GPU or CPU)
@@ -1924,6 +2233,12 @@ async def main():
     if args.mode == 'train':
         # Train the agent
         logger.info(f"Starting training for {args.episodes} episodes...")
+        
+        # Pass exchange to training function if refresh-data is enabled
+        if args.refresh_data:
+            logger.info("Data refresh enabled during training")
+            stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange)
+        else:
             stats = await train_agent(agent, env, num_episodes=args.episodes)
     
     elif args.mode == 'evaluate':
@@ -1955,6 +2270,11 @@ async def main():
         except Exception as e:
             logger.warning(f"Could not properly close exchange connection: {e}")
 
+def ensure_float32(model):
+    """Ensure all model parameters are float32"""
+    for param in model.parameters():
+        param.data = param.data.float()  # Convert to float32
+
 if __name__ == "__main__":
     try:
         asyncio.run(main())
@@ -46,6 +46,12 @@ pip install -r requirements.txt
 ```bash
 MEXC_API_KEY=your_api_key
 MEXC_API_SECRET=your_api_secret
+
+
+cuda support
+
+```bash
+pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
 ```
 
 ## Usage