import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import logging
import time
from datetime import datetime
import os
import sys
import pandas as pd
import gym
import json
import random
import torch.nn as nn
import contextlib

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from NN.utils.data_interface import DataInterface
from NN.utils.trading_env import TradingEnvironment
from NN.models.dqn_agent import DQNAgent
from NN.utils.signal_interpreter import SignalInterpreter

# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('rl_training.log'),
        logging.StreamHandler()
    ]
)

# Set up device for PyTorch (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Log GPU status
if torch.cuda.is_available():
    gpu_count = torch.cuda.device_count()
    gpu_names = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
    logger.info(f"Using GPU: {gpu_names}")

    # Check for BFloat16 support (available on NVIDIA Ampere and newer GPUs) for faster training
    if hasattr(torch.cuda, 'amp') and torch.cuda.is_bf16_supported():
        logger.info("BFloat16 precision is supported - will use for faster training")
else:
    logger.warning("GPU not available. Using CPU for training (slower).")


class RLTradingEnvironment(gym.Env):
    """
    Reinforcement learning environment for trading with technical indicators
    from multiple timeframes.
    """

    def __init__(self, features_1m, features_1h, features_1d, window_size=20, trading_fee=0.0025, min_trade_interval=15):
        super().__init__()

        # Initialize attributes needed before the initial reset() call below
        self.window_size = window_size
        self.num_features = features_1m.shape[1] - 1  # Exclude close price

        # Count available timeframes
        self.num_timeframes = 3  # All three timeframes are required
        self.feature_dim = self.num_features * self.num_timeframes

        # Store features from different timeframes
        self.features_1m = features_1m
        self.features_1h = features_1h
        self.features_1d = features_1d

        # Trading parameters
        self.initial_balance = 1.0
        self.trading_fee = trading_fee  # Increased from 0.001 to 0.0025 (0.25%)
        self.min_trade_interval = min_trade_interval  # Minimum steps between trades

        # Define action and observation spaces
        self.action_space = gym.spaces.Discrete(3)  # 0: Buy, 1: Sell, 2: Hold
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.window_size, self.feature_dim),
            dtype=np.float32
        )

        # State variables
        self.reset()

        # Callback for visualization or external monitoring
        self.action_callback = None

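    # Shape note (illustrative example, not derived from this file): with
    # window_size=20 and, say, 10 indicator columns per timeframe, each
    # observation returned by _get_observation() is a (20, 30) float32 array;
    # shorter windows are zero-padded up to window_size rows.
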
    def reset(self):
        """Reset the environment to initial state"""
        self.balance = self.initial_balance
        self.position = 0.0  # Amount of asset held
        self.current_step = self.window_size
        self.trades = 0
        self.wins = 0
        self.losses = 0
        self.trade_history = []
        self.last_trade_step = -self.min_trade_interval  # Initialize to allow immediate first trade

        # Get initial observation
        observation = self._get_observation()
        return observation

    def _get_observation(self):
        """
        Get the current state observation.
        Combine features from multiple timeframes, reshaped for the CNN.
        """
        # Calculate indices for each timeframe
        idx_1m = min(self.current_step, self.features_1m.shape[0] - 1)
        idx_1h = idx_1m // 60  # 60 minutes in an hour
        idx_1d = idx_1h // 24  # 24 hours in a day

        # Cap indices to prevent out of bounds
        idx_1h = min(idx_1h, self.features_1h.shape[0] - 1)
        idx_1d = min(idx_1d, self.features_1d.shape[0] - 1)

        # Extract feature windows from each timeframe
        window_1m = self.features_1m[max(0, idx_1m - self.window_size):idx_1m]

        # Handle hourly timeframe
        start_1h = max(0, idx_1h - self.window_size)
        window_1h = self.features_1h[start_1h:idx_1h]

        # Handle daily timeframe
        start_1d = max(0, idx_1d - self.window_size)
        window_1d = self.features_1d[start_1d:idx_1d]

        # Pad if needed (for higher timeframes)
        if len(window_1m) < self.window_size:
            padding = np.zeros((self.window_size - len(window_1m), window_1m.shape[1]))
            window_1m = np.vstack([padding, window_1m])

        if len(window_1h) < self.window_size:
            padding = np.zeros((self.window_size - len(window_1h), window_1h.shape[1]))
            window_1h = np.vstack([padding, window_1h])

        if len(window_1d) < self.window_size:
            padding = np.zeros((self.window_size - len(window_1d), window_1d.shape[1]))
            window_1d = np.vstack([padding, window_1d])

        # Combine features from all timeframes
        combined_features = np.hstack([
            window_1m.reshape(self.window_size, -1),
            window_1h.reshape(self.window_size, -1),
            window_1d.reshape(self.window_size, -1)
        ])

        # Convert to float32 and handle any NaN values
        combined_features = np.nan_to_num(combined_features, nan=0.0).astype(np.float32)

        return combined_features

    def step(self, action):
        """Take an action and return the next state, reward, done flag, and info"""
        # Initialize info dictionary for additional data
        info = {
            'trade_executed': False,
            'price_change': 0.0,
            'position_change': 0,
            'current_price': 0.0,
            'next_price': 0.0,
            'balance_change': 0.0,
            'reward_components': {},
            'future_prices': {}
        }

        # Get the current and next price
        current_price = self.features_1m[self.current_step, -1]

        # Handle edge case at the end of the data
        if self.current_step >= len(self.features_1m) - 1:
            next_price = current_price  # Use current price as next price
            done = True
        else:
            next_price = self.features_1m[self.current_step + 1, -1]
            done = False

        # Handle zero or negative price (data error)
        if current_price <= 0:
            current_price = 0.01  # Set to a small positive number
            logger.warning(f"Zero or negative price detected at step {self.current_step}. Setting to 0.01.")

        if next_price <= 0:
            next_price = current_price  # Use current price instead
            logger.warning(f"Zero or negative next price detected at step {self.current_step + 1}. Using current price.")

        # Calculate price change as percentage
        price_change_pct = ((next_price - current_price) / current_price) * 100

        # Store prices in info
        info['current_price'] = current_price
        info['next_price'] = next_price
        info['price_change'] = price_change_pct

        # Initialize reward components dictionary
        reward_components = {
            'holding_reward': 0.0,
            'action_reward': 0.0,
            'profit_reward': 0.0,
            'trade_freq_penalty': 0.0
        }

        # Default small negative reward to discourage inaction
        reward = -0.01
        reward_components['holding_reward'] = -0.01

        # Track previous balance for changes
        previous_balance = self.balance

        # Execute action (0: Buy, 1: Sell, 2: Hold)
        if action == 0:  # Buy
            if self.position == 0:  # Only buy if we don't already have a position
                # Calculate how much of the asset we can buy with 100% of balance
                self.position = self.balance / current_price
                self.balance = 0  # All balance used

                # If price goes up after buying, that's good
                expected_profit = price_change_pct
                # Scale reward based on expected profit
                if expected_profit > 0:
                    # Positive reward for profitable buy decision
                    action_reward = 0.1 + (expected_profit * 0.05)  # Base reward + profit-based bonus
                    reward_components['action_reward'] = action_reward
                    reward += action_reward
                else:
                    # Small negative reward for unprofitable buy
                    action_reward = -0.1 + (expected_profit * 0.03)  # Smaller penalty for small losses
                    reward_components['action_reward'] = action_reward
                    reward += action_reward

                # Check if we've traded too frequently
                if len(self.trade_history) > 0:
                    last_trade_step = self.trade_history[-1]['step']
                    if self.current_step - last_trade_step < 5:  # If less than 5 steps since last trade
                        freq_penalty = -0.2  # Penalty for trading too frequently
                        reward += freq_penalty
                        reward_components['trade_freq_penalty'] = freq_penalty

                # Record the trade
                self.trade_history.append({
                    'step': self.current_step,
                    'action': 'buy',
                    'price': current_price,
                    'position': self.position,
                    'balance': self.balance
                })

                info['trade_executed'] = True
                logger.info(f"Buy at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")

        elif action == 1:  # Sell
            if self.position > 0:  # Only sell if we have a position
                # Calculate sale proceeds
                sale_value = self.position * current_price

                # Calculate profit or loss percentage from last buy
                last_buy_price = None
                for trade in reversed(self.trade_history):
                    if trade['action'] == 'buy':
                        last_buy_price = trade['price']
                        break

                # If we found the last buy price, calculate profit
                if last_buy_price is not None:
                    profit_pct = ((current_price - last_buy_price) / last_buy_price) * 100

                    # Track completed round-trip trades so win-rate reporting has data
                    self.trades += 1
                    if profit_pct > 0:
                        self.wins += 1
                    else:
                        self.losses += 1

                    # Highly reward profitable trades
                    if profit_pct > 0:
                        # Progressive reward based on profit percentage
                        profit_reward = min(5.0, profit_pct * 0.2)  # Cap at 5.0 to prevent exploitation
                        reward_components['profit_reward'] = profit_reward
                        reward += profit_reward
                        logger.info(f"Profitable trade! {profit_pct:.2f}% profit, reward: {profit_reward:.4f}")
                    else:
                        # Penalize losses more heavily based on size of loss
                        loss_penalty = max(-3.0, profit_pct * 0.15)  # Cap at -3.0 to prevent excessive punishment
                        reward_components['profit_reward'] = loss_penalty
                        reward += loss_penalty
                        logger.info(f"Loss trade! {profit_pct:.2f}% loss, penalty: {loss_penalty:.4f}")

                # If price goes down after selling, that's good
                if price_change_pct < 0:
                    # Reward for good timing on sell (avoiding future loss)
                    timing_reward = min(1.0, abs(price_change_pct) * 0.05)
                    reward_components['action_reward'] = timing_reward
                    reward += timing_reward

                # Check for trading too frequently
                if len(self.trade_history) > 0:
                    last_trade_step = self.trade_history[-1]['step']
                    if self.current_step - last_trade_step < 5:  # If less than 5 steps since last trade
                        freq_penalty = -0.2  # Penalty for trading too frequently
                        reward += freq_penalty
                        reward_components['trade_freq_penalty'] = freq_penalty

                # Update balance and position
                self.balance = sale_value
                position_change = self.position
                self.position = 0

                # Record the trade
                self.trade_history.append({
                    'step': self.current_step,
                    'action': 'sell',
                    'price': current_price,
                    'position': self.position,
                    'balance': self.balance
                })

                info['trade_executed'] = True
                info['position_change'] = position_change
                logger.info(f"Sell at step {self.current_step}, price: {current_price:.4f}, new balance: {self.balance:.4f}")

        elif action == 2:  # Hold
            # Small reward if holding was a good decision
            if self.position > 0 and price_change_pct > 0:  # Holding long position during price increase
                hold_reward = price_change_pct * 0.01  # Small reward proportional to price increase
                reward += hold_reward
                reward_components['holding_reward'] = hold_reward
            elif self.position == 0 and price_change_pct < 0:  # Holding cash during price decrease
                hold_reward = abs(price_change_pct) * 0.01  # Small reward for avoiding loss
                reward += hold_reward
                reward_components['holding_reward'] = hold_reward

        # Move to the next step
        self.current_step += 1

        # Update current portfolio value
        if self.position > 0:
            self.current_value = self.balance + (self.position * next_price)
        else:
            self.current_value = self.balance

        # Calculate balance change
        balance_change = self.current_value - previous_balance
        info['balance_change'] = balance_change

        # Check if we've reached the end of the data
        if self.current_step >= len(self.features_1m) - 1:
            done = True

            # Final evaluation if we have a position
            if self.position > 0:
                # Value the remaining position at the final price
                final_balance = self.balance + (self.position * next_price)

                # Calculate final portfolio value and return
                final_return_pct = ((final_balance - self.initial_balance) / self.initial_balance) * 100

                # Add big reward/penalty based on overall performance
                performance_reward = final_return_pct * 0.1
                reward += performance_reward
                reward_components['final_performance'] = performance_reward

                logger.info(f"Episode ended. Final balance: {final_balance:.4f}, Return: {final_return_pct:.2f}%")

        # Get future prices for evaluation (1-hour and 1-day ahead)
        info['future_prices'] = {}

        # 1-hour future price if hourly data is available
        if hasattr(self, 'features_1h') and self.features_1h is not None:
            # Find the closest hourly data point
            if self.current_step < len(self.features_1m):
                current_time = self.current_step  # Use as index for simplicity
                hourly_idx = min(current_time // 60, len(self.features_1h) - 1)  # Assuming 60 minutes per hour
                if hourly_idx < len(self.features_1h) - 1:
                    future_1h_price = self.features_1h[hourly_idx + 1, -1]
                    info['future_prices']['1h'] = future_1h_price

        # 1-day future price if daily data is available
        if hasattr(self, 'features_1d') and self.features_1d is not None:
            # Find the closest daily data point
            if self.current_step < len(self.features_1m):
                current_time = self.current_step  # Use as index for simplicity
                daily_idx = min(current_time // 1440, len(self.features_1d) - 1)  # Assuming 1440 minutes per day
                if daily_idx < len(self.features_1d) - 1:
                    future_1d_price = self.features_1d[daily_idx + 1, -1]
                    info['future_prices']['1d'] = future_1d_price

        # Get next observation
        next_state = self._get_observation()

        # Store reward components in info
        info['reward_components'] = reward_components

        # Clip reward to prevent extreme values
        reward = np.clip(reward, -10.0, 10.0)

        return next_state, reward, done, info

    def set_action_callback(self, callback):
        """
        Set a callback function to be called after each action.

        Args:
            callback: Function with signature (action, price, reward, info)
        """
        self.action_callback = callback

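
# Illustrative sketch only (not called anywhere in this script): shows how the
# environment's reset()/step() API can be exercised with random actions. The
# feature arrays are assumed to already be numpy arrays with the close price in
# the last column, as prepared in train_rl() below.
def _example_random_rollout(features_1m, features_1h, features_1d, n_steps=50):
    env = RLTradingEnvironment(features_1m, features_1h, features_1d)
    state = env.reset()
    total_reward = 0.0
    for _ in range(n_steps):
        action = env.action_space.sample()  # random Buy/Sell/Hold
        state, reward, done, info = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward
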
def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
             action_callback=None, episode_callback=None, symbol="BTC/USDT",
             pretrain_price_prediction_enabled=False, pretrain_epochs=10):
    """
    Train a reinforcement learning agent for trading using ONLY real market data.

    Args:
        env_class: Optional environment class override
        num_episodes: Number of episodes to train for
        max_steps: Maximum steps per episode
        save_path: Path to save the trained model
        action_callback: Callback function for monitoring actions
        episode_callback: Callback function for monitoring episodes
        symbol: Trading symbol to use
        pretrain_price_prediction_enabled: DEPRECATED - no longer supported (synthetic data is not used)
        pretrain_epochs: DEPRECATED - no longer supported (synthetic data is not used)

    Returns:
        tuple: (trained agent, environment)
    """
    # Load data for the selected symbol
    data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m', '1h', '1d'])

    try:
        # Try to load data for the requested symbol using the get_historical_data method
        data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
        data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
        data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)
        data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
        data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)

        if data_1m is None or data_5m is None or data_15m is None or data_1h is None or data_1d is None:
            raise FileNotFoundError("Could not retrieve data for all required timeframes for the specified symbol")
    except Exception as e:
        logger.warning(f"Data for {symbol} not available: {str(e)}. Using default cached data.")
        # Fall back to cached data if available
        symbol = "BTC/USDT"
        data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m', '1h', '1d'])
        data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
        data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
        data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)
        data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
        data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)

        if data_1m is None or data_5m is None or data_15m is None or data_1h is None or data_1d is None:
            logger.error("Failed to retrieve data for all required timeframes. Cannot continue training.")
            raise ValueError("No data available for training")

    # Create features from the data by adding technical indicators and converting to numpy format
    if data_1m is not None:
        data_1m = data_interface.add_technical_indicators(data_1m)
        # Convert to a numpy array with the close price as the last column
        features_1m = np.hstack([
            data_1m.drop(['timestamp', 'close'], axis=1).values,
            data_1m['close'].values.reshape(-1, 1)
        ])
    else:
        features_1m = None

    if data_5m is not None:
        data_5m = data_interface.add_technical_indicators(data_5m)
        # Convert to a numpy array with the close price as the last column
        features_5m = np.hstack([
            data_5m.drop(['timestamp', 'close'], axis=1).values,
            data_5m['close'].values.reshape(-1, 1)
        ])
    else:
        features_5m = None

    if data_15m is not None:
        data_15m = data_interface.add_technical_indicators(data_15m)
        # Convert to a numpy array with the close price as the last column
        features_15m = np.hstack([
            data_15m.drop(['timestamp', 'close'], axis=1).values,
            data_15m['close'].values.reshape(-1, 1)
        ])
    else:
        features_15m = None

    if data_1h is not None:
        data_1h = data_interface.add_technical_indicators(data_1h)
        # Convert to a numpy array with the close price as the last column
        features_1h = np.hstack([
            data_1h.drop(['timestamp', 'close'], axis=1).values,
            data_1h['close'].values.reshape(-1, 1)
        ])
    else:
        features_1h = None

    if data_1d is not None:
        data_1d = data_interface.add_technical_indicators(data_1d)
        # Convert to a numpy array with the close price as the last column
        features_1d = np.hstack([
            data_1d.drop(['timestamp', 'close'], axis=1).values,
            data_1d['close'].values.reshape(-1, 1)
        ])
    else:
        features_1d = None

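    # The five conversion blocks above repeat the same pattern; a possible helper
    # (sketch only, not applied here) could look like:
    #
    #     def _to_features(df):
    #         if df is None:
    #             return None
    #         df = data_interface.add_technical_indicators(df)
    #         return np.hstack([
    #             df.drop(['timestamp', 'close'], axis=1).values,
    #             df['close'].values.reshape(-1, 1)
    #         ])
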
    # Check that we have all the required features
    if features_1m is None or features_5m is None or features_15m is None or features_1h is None or features_1d is None:
        logger.error("Failed to create features for all timeframes.")
        raise ValueError("Could not create features for training")

    # Create the environment
    if env_class:
        # Use the provided environment class
        env = env_class(features_1m, features_1h, features_1d)
    else:
        # Use the default environment
        env = RLTradingEnvironment(features_1m, features_1h, features_1d)

    # Set the action callback if provided
    if action_callback:
        env.set_action_callback(action_callback)

    # Get environment properties for agent creation
    input_shape = env.observation_space.shape
    n_actions = env.action_space.n

    # Create the agent
    agent = DQNAgent(
        state_shape=input_shape,
        n_actions=n_actions,
        epsilon=1.0,
        epsilon_decay=0.995,
        epsilon_min=0.01,
        learning_rate=0.0001,
        gamma=0.99,
        buffer_size=10000,
        batch_size=64,
        device=device  # Pass device to the agent for GPU usage
    )

    # Check whether a saved model exists and load it
    model_file = f"{save_path}_model.pth"
    if os.path.exists(model_file):
        try:
            agent.load(model_file)
            logger.info(f"Loaded existing model from {model_file}")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
    else:
        logger.info("No existing model found. Starting with a new model.")

    # Pre-training was removed because it relied on synthetic data;
    # pre-training with real data would require a separate implementation.
    if pretrain_price_prediction_enabled:
        logger.warning("Pre-training with synthetic data is no longer supported. Continuing with RL training only.")

    # Create TensorBoard writer
    writer = SummaryWriter(log_dir=f'runs/dqn_{int(time.time())}')

    # Log GPU status to TensorBoard
    writer.add_text("hardware/device", str(device), 0)
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            writer.add_text(f"hardware/gpu_{i}", torch.cuda.get_device_name(i), 0)

    # Training loop bookkeeping
    total_rewards = []
    trade_win_rates = []
    best_reward = -np.inf

    # Move models to the appropriate device if they are not already there
    agent.move_models_to_device(device)

    # Enable mixed precision if a GPU and the amp feature are available
    use_mixed_precision = False
    if torch.cuda.is_available() and hasattr(torch.cuda, 'amp'):
        logger.info("Enabling mixed precision training")
        use_mixed_precision = True
        scaler = torch.cuda.amp.GradScaler()

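    # Note (sketch, not the agent's actual implementation): if mixed precision is
    # used, the GradScaler created above would typically wrap the optimizer step
    # inside the agent's training code, along the lines of:
    #
    #     with torch.cuda.amp.autocast():
    #         loss = compute_loss(batch)      # hypothetical loss computation
    #     scaler.scale(loss).backward()
    #     scaler.step(optimizer)
    #     scaler.update()
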
    # Define step callback for TensorBoard logging and model tracking
    def step_callback(action, price, reward, info):
        # Pass to the external callback if provided
        if action_callback:
            action_callback(env.current_step, action, price, reward, info)

    # Main training loop
    logger.info(f"Starting training for {num_episodes} episodes...")
    logger.info(f"Starting training on device: {agent.device}")

    try:
        for episode in range(num_episodes):
            state = env.reset()
            total_reward = 0

            for step in range(max_steps):
                # Select action
                action = agent.act(state)

                # Take action and observe next state and reward
                next_state, reward, done, info = env.step(action)

                # Store the experience in memory
                agent.remember(state, action, reward, next_state, done)

                # Update state and reward
                state = next_state
                total_reward += reward

                # Train the agent by sampling from memory
                if len(agent.memory) >= agent.batch_size:
                    loss = agent.replay()

                if done or step == max_steps - 1:
                    break

            # Track rewards
            total_rewards.append(total_reward)

            # Calculate trading metrics
            win_rate = env.wins / max(1, env.trades)
            trades = env.trades

            # Log to TensorBoard
            writer.add_scalar('Reward/Episode', total_reward, episode)
            writer.add_scalar('Trade/WinRate', win_rate, episode)
            writer.add_scalar('Trade/Count', trades, episode)

            # Save the best model
            if total_reward > best_reward and episode > 10:
                logger.info(f"New best episode reward: {total_reward:.4f}, saving model")
                agent.save(save_path)
                best_reward = total_reward

            # Periodic save every 100 episodes
            if episode % 100 == 0 and episode > 0:
                agent.save(f"{save_path}_episode_{episode}")

            # Call the episode callback if provided
            if episode_callback:
                # Add the environment to the info dict so it can be used for extrema training
                info_with_env = info.copy()
                info_with_env['env'] = env
                episode_callback(episode, total_reward, info_with_env)

        # Final save
        logger.info("Training completed, saving final model")
        agent.save(f"{save_path}_final")

    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())

    # Close TensorBoard writer
    writer.close()

    return agent, env


if __name__ == "__main__":
    train_rl()
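    # Example invocation with custom settings (hypothetical values shown):
    # train_rl(num_episodes=1000, max_steps=500, symbol="ETH/USDT",
    #          save_path="NN/models/saved/dqn_agent_eth")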