enhancements
NN/train_rl.py (296 changed lines)
@@ -63,6 +63,9 @@ class RLTradingEnvironment(gym.Env):
         # State variables
         self.reset()

+        # Callback for visualization or external monitoring
+        self.action_callback = None
+
     def reset(self):
         """Reset the environment to initial state"""
@@ -145,6 +148,7 @@ class RLTradingEnvironment(gym.Env):
         # Default reward is slightly negative to discourage inaction
         reward = -0.0001
         done = False
+        profit_pct = None # Initialize profit_pct variable

         # Execute action
         if action == 0: # BUY
@@ -218,214 +222,188 @@ class RLTradingEnvironment(gym.Env):
             'total_value': total_value,
             'gain': gain,
             'trades': self.trades,
-            'win_rate': self.win_rate
+            'win_rate': self.win_rate,
+            'profit_pct': profit_pct if action == 1 and self.position == 0 else None,
+            'current_price': current_price,
+            'next_price': next_price
         }

+        # Call the callback if it exists
+        if self.action_callback:
+            self.action_callback(action, current_price, reward, info)
+
        return observation, reward, done, info

-def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent"):
+    def set_action_callback(self, callback):
+        """
+        Set a callback function to be called after each action
+
+        Args:
+            callback: Function with signature (action, price, reward, info)
+        """
+        self.action_callback = callback
+
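The new set_action_callback hook lets an external monitor observe every simulated trade as it happens. A minimal usage sketch, not part of this commit (the import path, array shapes and five-column feature layout are illustrative assumptions):

import numpy as np
from NN.train_rl import RLTradingEnvironment  # path assumed from this file's location

# Dummy multi-timeframe feature arrays; real features come from DataInterface
f1m = np.random.rand(500, 5)
f5m = np.random.rand(250, 5)
f15m = np.random.rand(100, 5)
env = RLTradingEnvironment(f1m, f5m, f15m)

def log_action(action, price, reward, info):
    # action 0 is BUY; judging by the profit_pct logic above, action 1 is the closing SELL
    if info.get('profit_pct') is not None:
        print(f"closed trade at {price:.2f}, profit_pct={info['profit_pct']:.4f}")

env.set_action_callback(log_action)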
+def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
+             action_callback=None, episode_callback=None, symbol="BTC/USDT"):
     """
     Train DQN agent for RL-based trading with extended training and monitoring

     Args:
         env_class: Optional environment class to use, defaults to RLTradingEnvironment
         num_episodes: Number of episodes to train
         max_steps: Maximum steps per episode
         save_path: Path to save the model
+        action_callback: Optional callback for each action (step, action, price, reward, info)
+        episode_callback: Optional callback after each episode (episode, reward, info)
+        symbol: Trading pair symbol (e.g., "BTC/USDT")

     Returns:
         DQNAgent: The trained agent
     """
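Driving the extended signature from a caller could look like the sketch below; the values are placeholders, and a per-episode callback example follows further down where that hook is invoked:

from NN.train_rl import train_rl  # import path assumed

agent = train_rl(
    num_episodes=100,          # small values for a smoke test
    max_steps=500,
    save_path="NN/models/saved/dqn_agent",
    action_callback=lambda step, action, price, reward, info: None,  # per-step hook
    episode_callback=None,     # see the per-episode sketch further down
    symbol="ETH/USDT",
)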
logger.info("Starting extended RL training for trading...")
|
||||
import pandas as pd
|
||||
from NN.utils.data_interface import DataInterface
|
||||
|
||||
# Environment setup
|
||||
window_size = 20
|
||||
timeframes = ["1m", "5m", "15m"]
|
||||
trading_fee = 0.001
|
||||
logger.info("Starting DQN training for RL trading")
|
||||
|
||||
# Ensure save directory exists
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
# Create data interface with specified symbol
|
||||
data_interface = DataInterface(symbol=symbol)
|
||||
|
||||
# Setup TensorBoard for monitoring
|
||||
writer = SummaryWriter(f'runs/rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
|
||||
|
||||
# Data loading
|
||||
data_interface = DataInterface(
|
||||
symbol="BTC/USDT",
|
||||
timeframes=timeframes
|
||||
)
|
||||
|
||||
# Get training data for each timeframe with more data
|
||||
logger.info("Loading training data...")
|
||||
features_1m = data_interface.get_training_data("1m", n_candles=5000)
|
||||
if features_1m is not None:
|
||||
logger.info(f"Loaded {len(features_1m)} 1m candles")
|
||||
else:
|
||||
logger.error("Failed to load 1m data")
|
||||
return None
|
||||
|
||||
features_5m = data_interface.get_training_data("5m", n_candles=2500)
|
||||
if features_5m is not None:
|
||||
logger.info(f"Loaded {len(features_5m)} 5m candles")
|
||||
else:
|
||||
logger.error("Failed to load 5m data")
|
||||
return None
|
||||
|
||||
features_15m = data_interface.get_training_data("15m", n_candles=2500)
|
||||
if features_15m is not None:
|
||||
logger.info(f"Loaded {len(features_15m)} 15m candles")
|
||||
else:
|
||||
logger.error("Failed to load 15m data")
|
||||
return None
|
||||
# Load and preprocess data
|
||||
logger.info(f"Loading data from multiple timeframes for {symbol}")
|
||||
features_1m = data_interface.get_training_data("1m", n_candles=2000)
|
||||
features_5m = data_interface.get_training_data("5m", n_candles=1000)
|
||||
features_15m = data_interface.get_training_data("15m", n_candles=500)
|
||||
|
||||
# Check if we have all the data
|
||||
if features_1m is None or features_5m is None or features_15m is None:
|
||||
logger.error("Failed to load training data")
|
||||
logger.error("Failed to load training data from one or more timeframes")
|
||||
return None
|
||||
|
||||
# Convert DataFrames to numpy arrays, excluding timestamp column
|
||||
features_1m = features_1m.drop('timestamp', axis=1, errors='ignore').values
|
||||
features_5m = features_5m.drop('timestamp', axis=1, errors='ignore').values
|
||||
features_15m = features_15m.drop('timestamp', axis=1, errors='ignore').values
|
||||
# If data is a DataFrame, convert to numpy array excluding the timestamp column
|
||||
if isinstance(features_1m, pd.DataFrame):
|
||||
features_1m = features_1m.drop('timestamp', axis=1, errors='ignore').values
|
||||
if isinstance(features_5m, pd.DataFrame):
|
||||
features_5m = features_5m.drop('timestamp', axis=1, errors='ignore').values
|
||||
if isinstance(features_15m, pd.DataFrame):
|
||||
features_15m = features_15m.drop('timestamp', axis=1, errors='ignore').values
|
||||
|
||||
# Calculate number of features per timeframe
|
||||
num_features = features_1m.shape[1] # Number of features after dropping timestamp
|
||||
# Initialize environment or use provided class
|
||||
if env_class is None:
|
||||
env = RLTradingEnvironment(features_1m, features_5m, features_15m)
|
||||
else:
|
||||
env = env_class(features_1m, features_5m, features_15m)
|
||||
|
||||
# Create environment
|
||||
env = RLTradingEnvironment(
|
||||
features_1m=features_1m,
|
||||
features_5m=features_5m,
|
||||
features_15m=features_15m,
|
||||
window_size=window_size,
|
||||
trading_fee=trading_fee
|
||||
)
|
||||
# Set action callback if provided
|
||||
if action_callback:
|
||||
def step_callback(action, price, reward, info):
|
||||
action_callback(env.current_step, action, price, reward, info)
|
||||
env.set_action_callback(step_callback)
|
||||
|
||||
     # Initialize agent
+    window_size = env.window_size
+    num_features = env.num_features * env.num_timeframes
+    action_size = env.action_space.n
+    timeframes = ['1m', '5m', '15m'] # Match the timeframes from the environment

-    # Create agent with adjusted parameters for longer training
-    state_size = window_size
-    action_size = 3
     agent = DQNAgent(
-        state_size=state_size,
+        state_size=window_size * num_features,
         action_size=action_size,
         window_size=window_size,
-        num_features=num_features,
+        num_features=env.num_features,
         timeframes=timeframes,
-        learning_rate=0.0005, # Reduced learning rate for stability
-        gamma=0.99, # Increased discount factor
+        memory_size=100000,
+        batch_size=64,
+        learning_rate=0.0001,
+        gamma=0.99,
         epsilon=1.0,
         epsilon_min=0.01,
-        epsilon_decay=0.999, # Slower epsilon decay
-        memory_size=50000, # Increased memory size
-        batch_size=128 # Increased batch size
+        epsilon_decay=0.995
     )

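The agent's input dimension is the flattened observation window across all timeframes (state_size = window_size * num_features, where num_features already folds in env.num_timeframes). A quick arithmetic sketch with assumed sizes, since the actual per-timeframe feature count is not visible in this diff:

window_size = 20        # bars per observation (the value the old code hard-coded)
features_per_tf = 5     # assumed, e.g. OHLCV
num_timeframes = 3      # 1m, 5m, 15m
num_features = features_per_tf * num_timeframes  # mirrors env.num_features * env.num_timeframes
state_size = window_size * num_features          # 20 * 15 = 300 inputs to the DQN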
-    # Variables to track best performance
-    best_reward = float('-inf')
-    best_episode = 0
-    best_pnl = float('-inf')
-    best_win_rate = 0.0
-
-    # Training metrics
+    # Training variables
+    best_reward = -float('inf')
     episode_rewards = []
-    episode_pnls = []
-    episode_win_rates = []
-    episode_trades = []

-    # Check if previous best model exists and load it
-    best_model_path = f"{save_path}_best"
-    if os.path.exists(f"{best_model_path}_policy.pt"):
-        try:
-            logger.info(f"Loading previous best model from {best_model_path}")
-            agent.load(best_model_path)
-            metadata_path = f"{best_model_path}_metadata.json"
-            if os.path.exists(metadata_path):
-                with open(metadata_path, 'r') as f:
-                    metadata = json.load(f)
-                    best_reward = metadata.get('best_reward', best_reward)
-                    best_episode = metadata.get('best_episode', best_episode)
-                    best_pnl = metadata.get('best_pnl', best_pnl)
-                    best_win_rate = metadata.get('best_win_rate', best_win_rate)
-                    logger.info(f"Loaded previous best metrics - Reward: {best_reward:.4f}, PnL: {best_pnl:.4f}, Win Rate: {best_win_rate:.4f}")
-        except Exception as e:
-            logger.error(f"Error loading previous best model: {e}")
+    # TensorBoard writer for logging
+    writer = SummaryWriter(log_dir=f'runs/rl_trading_{int(time.time())}')

-    # Main training loop
-    logger.info(f"Starting training for {num_episodes} episodes...")
+    logger.info(f"Starting training on device: {agent.device}")

+    # Training loop
     try:
-        for episode in range(1, num_episodes + 1):
+        for episode in range(num_episodes):
             state = env.reset()
             total_reward = 0
-            done = False
-            steps = 0

-            while not done and steps < max_steps:
+            for step in range(max_steps):
                 # Select action
                 action = agent.act(state)

                 # Take action and observe next state and reward
                 next_state, reward, done, info = env.step(action)

                 # Store the experience in memory
                 agent.remember(state, action, reward, next_state, done)

-                # Learn from experience
-                loss = agent.replay()
-
                 # Update state and reward
                 state = next_state
                 total_reward += reward
-                steps += 1

+                # Train the agent by sampling from memory
+                if len(agent.memory) >= agent.batch_size:
+                    loss = agent.replay()
+
+                if done or step == max_steps - 1:
+                    break

-            # Calculate episode metrics
+            # Track rewards
             episode_rewards.append(total_reward)
-            episode_pnls.append(info['gain'])
-            episode_win_rates.append(info['win_rate'])
-            episode_trades.append(info['trades'])

+            # Log progress
+            avg_reward = np.mean(episode_rewards[-100:])
+            logger.info(f"Episode {episode}/{num_episodes} - Reward: {total_reward:.4f}, " +
+                        f"Avg (100): {avg_reward:.4f}, Epsilon: {agent.epsilon:.4f}")
+
+            # Calculate trading metrics
+            win_rate = env.win_rate if hasattr(env, 'win_rate') else 0
+            trades = env.trades if hasattr(env, 'trades') else 0

             # Log to TensorBoard
-            writer.add_scalar('Reward/episode', total_reward, episode)
-            writer.add_scalar('PnL/episode', info['gain'], episode)
-            writer.add_scalar('WinRate/episode', info['win_rate'], episode)
-            writer.add_scalar('Trades/episode', info['trades'], episode)
-            writer.add_scalar('Epsilon/episode', agent.epsilon, episode)
+            writer.add_scalar('Reward/Episode', total_reward, episode)
+            writer.add_scalar('Reward/Average100', avg_reward, episode)
+            writer.add_scalar('Trade/WinRate', win_rate, episode)
+            writer.add_scalar('Trade/Count', trades, episode)

-            # Save the best model based on multiple metrics (only every 50 episodes)
-            is_better = False
-            if episode % 50 == 0: # Only check for saving every 50 episodes
-                if (info['gain'] > best_pnl and info['win_rate'] > 0.5) or \
-                   (info['gain'] > best_pnl * 1.1) or \
-                   (info['win_rate'] > best_win_rate * 1.1):
-                    best_reward = total_reward
-                    best_episode = episode
-                    best_pnl = info['gain']
-                    best_win_rate = info['win_rate']
-                    agent.save(best_model_path)
-                    is_better = True
-
-                    # Save metadata about the best model
-                    metadata = {
-                        'best_reward': best_reward,
-                        'best_episode': best_episode,
-                        'best_pnl': best_pnl,
-                        'best_win_rate': best_win_rate,
-                        'date': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
-                    }
-                    with open(f"{best_model_path}_metadata.json", 'w') as f:
-                        json.dump(metadata, f)
+            # Save best model
+            if avg_reward > best_reward and episode > 10:
+                logger.info(f"New best average reward: {avg_reward:.4f}, saving model")
+                agent.save(save_path)
+                best_reward = avg_reward

-            # Log training progress
-            if episode % 10 == 0:
-                avg_reward = sum(episode_rewards[-10:]) / 10
-                avg_pnl = sum(episode_pnls[-10:]) / 10
-                avg_win_rate = sum(episode_win_rates[-10:]) / 10
-                avg_trades = sum(episode_trades[-10:]) / 10
+            # Periodic save every 100 episodes
+            if episode % 100 == 0 and episode > 0:
+                agent.save(f"{save_path}_episode_{episode}")

-                status = "NEW BEST!" if is_better else ""
-                logger.info(f"Episode {episode}/{num_episodes} {status}")
-                logger.info(f"Metrics (last 10 episodes):")
-                logger.info(f"  Reward: {avg_reward:.4f}")
-                logger.info(f"  PnL: {avg_pnl:.4f}")
-                logger.info(f"  Win Rate: {avg_win_rate:.4f}")
-                logger.info(f"  Trades: {avg_trades:.2f}")
-                logger.info(f"  Epsilon: {agent.epsilon:.4f}")
-                logger.info(f"Best so far - PnL: {best_pnl:.4f}, Win Rate: {best_win_rate:.4f}")
-
-    except KeyboardInterrupt:
-        logger.info("Training interrupted by user. Saving best model...")
+            # Call episode callback if provided
+            if episode_callback:
+                # Add environment to info dict to use for extrema training
+                info_with_env = info.copy()
+                info_with_env['env'] = env
+                episode_callback(episode, total_reward, info_with_env)

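An episode_callback might use the injected environment handle like this (hedged sketch; only attributes already referenced in this file are used):

def on_episode(episode, total_reward, info):
    # 'env' is added by train_rl so downstream logic (e.g. extrema training) can
    # inspect the finished episode; the other keys come from env.step()'s info dict.
    env = info.get('env')
    print(f"episode {episode}: reward={total_reward:.4f}, "
          f"gain={info.get('gain', 0.0):.4f}, win_rate={info.get('win_rate', 0.0):.2f}, "
          f"steps={getattr(env, 'current_step', '?')}")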
+        # Final save
+        logger.info("Training completed, saving final model")
+        agent.save(f"{save_path}_final")
+
+    except Exception as e:
+        logger.error(f"Training failed: {str(e)}")
+        import traceback
+        logger.error(traceback.format_exc())

     # Close TensorBoard writer
     writer.close()

-    # Final logs
-    logger.info(f"Training completed. Best model from episode {best_episode}")
-    logger.info(f"Best metrics:")
-    logger.info(f"  Reward: {best_reward:.4f}")
-    logger.info(f"  PnL: {best_pnl:.4f}")
-    logger.info(f"  Win Rate: {best_win_rate:.4f}")
-
     # Return the agent for potential further use
     return agent

 if __name__ == "__main__":