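"""
Reinforcement learning training script for a DQN trading agent.

Builds a multi-timeframe trading environment (1m features plus real or
synthetic higher-timeframe features), optionally pre-trains the agent's
price-prediction heads on labeled historical returns, and then runs the main
DQN training loop with TensorBoard logging and periodic checkpointing.
"""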
import os
import sys
import time
import json
import random
import logging
import contextlib
from datetime import datetime

import numpy as np
import pandas as pd
import gym
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from NN.utils.data_interface import DataInterface
from NN.utils.trading_env import TradingEnvironment
from NN.models.dqn_agent import DQNAgent
from NN.utils.signal_interpreter import SignalInterpreter

# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('rl_training.log'),
        logging.StreamHandler()
    ]
)

# Set up device for PyTorch (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Log GPU status
if torch.cuda.is_available():
    gpu_count = torch.cuda.device_count()
    gpu_names = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
    logger.info(f"Using GPU: {gpu_names}")

    # Check for BFloat16 support (available on NVIDIA Ampere and newer GPUs)
    if hasattr(torch.cuda, 'amp') and torch.cuda.is_bf16_supported():
        logger.info("BFloat16 precision is supported - will use for faster training")
else:
    logger.warning("GPU not available. Using CPU for training (slower).")


class RLTradingEnvironment(gym.Env):
    """
    Reinforcement Learning environment for trading with technical indicators
    from multiple timeframes
    """
    def __init__(self, features_1m, features_1h=None, features_1d=None, window_size=20, trading_fee=0.0025, min_trade_interval=15):
        super().__init__()

        # Initialize environment attributes
        self.window_size = window_size
        self.num_features = features_1m.shape[1] - 1  # Exclude close price

        # Count available timeframes
        self.num_timeframes = 1  # Always have 1m
        if features_1h is not None:
            self.num_timeframes += 1
        if features_1d is not None:
            self.num_timeframes += 1

        # Observations always stack three timeframe windows (missing ones are
        # synthesized below), each including the close column
        self.feature_dim = features_1m.shape[1] * 3

        # Store features from different timeframes
        self.features_1m = features_1m
        self.features_1h = features_1h
        self.features_1d = features_1d

        # Create synthetic 1s data from 1m (for demo purposes)
        self.features_1s = self._create_synthetic_1s_data(features_1m)

        # If higher timeframes are missing, create synthetic data
        if self.features_1h is None:
            self.features_1h = self._create_synthetic_hourly_data(features_1m)

        if self.features_1d is None:
            self.features_1d = self._create_synthetic_daily_data(self.features_1h)

        # Trading parameters
        self.initial_balance = 1.0
        self.trading_fee = trading_fee  # Default increased from 0.001 to 0.0025 (0.25%)
        self.min_trade_interval = min_trade_interval  # Minimum steps between trades

        # Define action and observation spaces
        self.action_space = gym.spaces.Discrete(3)  # 0: Buy, 1: Sell, 2: Hold
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(self.window_size, self.feature_dim),
            dtype=np.float32
        )

        # State variables
        self.reset()

        # Callback for visualization or external monitoring
        self.action_callback = None
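
    # Example usage (hypothetical shapes): if each timeframe has 10 indicator
    # columns plus a close column, features_1m has shape (N, 11) and the
    # observation returned by reset()/step() has shape (window_size, 33).
    #
    #     env = RLTradingEnvironment(features_1m, features_1h, features_1d)
    #     obs = env.reset()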

    def _create_synthetic_1s_data(self, features_1m):
        """Create synthetic 1-second data from 1-minute data"""
        # Simple approach: duplicate each 1m candle for 60 seconds with some noise
        num_samples = features_1m.shape[0]
        synthetic_1s = np.zeros((num_samples * 60, features_1m.shape[1]))

        for i in range(num_samples):
            for j in range(60):
                idx = i * 60 + j
                if idx < synthetic_1s.shape[0]:
                    # Copy the 1m data with small random noise
                    synthetic_1s[idx] = features_1m[i] * (1 + np.random.normal(0, 0.0001, features_1m.shape[1]))

        return synthetic_1s

    def _create_synthetic_hourly_data(self, features_1m):
        """Create synthetic hourly data from minute data"""
        # Group by hour, taking every 60th candle
        num_samples = features_1m.shape[0] // 60
        synthetic_1h = np.zeros((num_samples, features_1m.shape[1]))

        for i in range(num_samples):
            if i * 60 < features_1m.shape[0]:
                synthetic_1h[i] = features_1m[i * 60]

        return synthetic_1h

    def _create_synthetic_daily_data(self, features_1h):
        """Create synthetic daily data from hourly data"""
        # Group by day, taking every 24th candle
        num_samples = features_1h.shape[0] // 24
        synthetic_1d = np.zeros((num_samples, features_1h.shape[1]))

        for i in range(num_samples):
            if i * 24 < features_1h.shape[0]:
                synthetic_1d[i] = features_1h[i * 24]

        return synthetic_1d

    def reset(self):
        """Reset the environment to initial state"""
        self.balance = self.initial_balance
        self.position = 0.0  # Amount of asset held
        self.current_step = self.window_size
        self.trades = 0
        self.wins = 0
        self.losses = 0
        self.trade_history = []
        self.last_trade_step = -self.min_trade_interval  # Initialize to allow an immediate first trade

        # Get initial observation
        observation = self._get_observation()
        return observation

    def _get_observation(self):
        """
        Get the current state observation.
        Combine features from multiple timeframes, reshaped for the CNN.
        """
        # Calculate indices for each timeframe
        idx_1m = min(self.current_step, self.features_1m.shape[0] - 1)
        idx_1h = idx_1m // 60  # 60 minutes in an hour
        idx_1d = idx_1h // 24  # 24 hours in a day

        # Cap indices to prevent out of bounds
        idx_1h = min(idx_1h, self.features_1h.shape[0] - 1)
        idx_1d = min(idx_1d, self.features_1d.shape[0] - 1)

        # Extract feature windows from each timeframe
        window_1m = self.features_1m[max(0, idx_1m - self.window_size):idx_1m]

        # Handle hourly timeframe
        start_1h = max(0, idx_1h - self.window_size)
        window_1h = self.features_1h[start_1h:idx_1h]

        # Handle daily timeframe
        start_1d = max(0, idx_1d - self.window_size)
        window_1d = self.features_1d[start_1d:idx_1d]

        # Pad if needed (for higher timeframes)
        if len(window_1m) < self.window_size:
            padding = np.zeros((self.window_size - len(window_1m), window_1m.shape[1]))
            window_1m = np.vstack([padding, window_1m])

        if len(window_1h) < self.window_size:
            padding = np.zeros((self.window_size - len(window_1h), window_1h.shape[1]))
            window_1h = np.vstack([padding, window_1h])

        if len(window_1d) < self.window_size:
            padding = np.zeros((self.window_size - len(window_1d), window_1d.shape[1]))
            window_1d = np.vstack([padding, window_1d])

        # Combine features from all timeframes
        combined_features = np.hstack([
            window_1m.reshape(self.window_size, -1),
            window_1h.reshape(self.window_size, -1),
            window_1d.reshape(self.window_size, -1)
        ])

        # Convert to float32 and handle any NaN values
        combined_features = np.nan_to_num(combined_features, nan=0.0).astype(np.float32)

        return combined_features

    def step(self, action):
        """
        Take an action in the environment and return the next state, reward, done flag, and info

        Args:
            action (int): 0 = Buy, 1 = Sell, 2 = Hold

        Returns:
            tuple: (observation, reward, done, info)
        """
        # Get current and next price
        current_price = self.features_1m[self.current_step, -1]  # Close price is last column

        # Check if we're at the end of the data
        if self.current_step + 1 >= len(self.features_1m):
            next_price = current_price  # Use current price if at the end
            done = True
        else:
            next_price = self.features_1m[self.current_step + 1, -1]
            done = False

        # Handle zero or negative prices
        if current_price <= 0:
            current_price = 1e-8  # Small positive number
        if next_price <= 0:
            next_price = current_price  # Use current price if next price is invalid

        price_change = (next_price - current_price) / current_price

        # Default reward is slightly negative to discourage inaction
        reward = -0.0001
        profit_pct = None  # Initialize profit_pct variable

        # Check if enough time has passed since the last trade
        trade_interval = self.current_step - self.last_trade_step
        trade_interval_penalty = 0

        # Execute action
        if action == 0:  # BUY
            if self.position == 0:  # Only buy if not already in position
                # Apply extra penalty for trading too frequently
                if trade_interval < self.min_trade_interval:
                    trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)
                    # Still allow the trade but with penalty

                self.position = self.balance * (1 - self.trading_fee)
                self.balance = 0
                self.trades += 1
                reward = -0.001 + trade_interval_penalty  # Small transaction cost plus potential penalty
                self.trade_entry_price = current_price
                self.last_trade_step = self.current_step

        elif action == 1:  # SELL
            if self.position > 0:  # Only sell if in position
                # Apply extra penalty for trading too frequently
                if trade_interval < self.min_trade_interval:
                    trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)
                    # Still allow the trade but with penalty

                # Calculate position value at current price
                position_value = self.position * (1 + price_change)
                self.balance = position_value * (1 - self.trading_fee)

                # Calculate profit/loss from trade
                profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price
                # Scale reward by profit percentage and apply trade interval penalty
                reward = (profit_pct * 10) + trade_interval_penalty

                # Update win/loss count
                if profit_pct > 0:
                    self.wins += 1
                else:
                    self.losses += 1

                # Record trade
                self.trade_history.append({
                    'entry_price': self.trade_entry_price,
                    'exit_price': next_price,
                    'profit_pct': profit_pct,
                    'trade_interval': trade_interval
                })

                # Reset position and update last trade step
                self.position = 0
                self.last_trade_step = self.current_step

        # else: (action == 2 - HOLD) - no position change

        # Move to next step
        self.current_step += 1

        # Check if done (reached end of data)
        if self.current_step >= len(self.features_1m) - 1:
            done = True

            # Apply final evaluation
            if self.position > 0:
                # Force close position at the end
                position_value = self.position * (1 + price_change)
                self.balance = position_value * (1 - self.trading_fee)
                profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price
                reward += profit_pct * 10

                # Update win/loss count
                if profit_pct > 0:
                    self.wins += 1
                else:
                    self.losses += 1

        # Get the next observation
        observation = self._get_observation()

        # Calculate metrics for info
        total_value = self.balance + self.position * next_price
        gain = (total_value - self.initial_balance) / self.initial_balance
        self.win_rate = self.wins / max(1, self.trades)

        # Check if we have prediction data for future timeframes
        future_price_1h = None
        future_price_1d = None

        # Get hourly index
        idx_1h = self.current_step // 60
        if idx_1h + 1 < len(self.features_1h):
            hourly_close_idx = self.features_1h.shape[1] - 1  # Assuming close is last column
            current_1h_price = self.features_1h[idx_1h, hourly_close_idx]
            next_1h_price = self.features_1h[idx_1h + 1, hourly_close_idx]
            future_price_1h = (next_1h_price - current_1h_price) / current_1h_price

        # Get daily index
        idx_1d = idx_1h // 24
        if idx_1d + 1 < len(self.features_1d):
            daily_close_idx = self.features_1d.shape[1] - 1  # Assuming close is last column
            current_1d_price = self.features_1d[idx_1d, daily_close_idx]
            next_1d_price = self.features_1d[idx_1d + 1, daily_close_idx]
            future_price_1d = (next_1d_price - current_1d_price) / current_1d_price

        info = {
            'balance': self.balance,
            'position': self.position,
            'total_value': total_value,
            'gain': gain,
            'trades': self.trades,
            'win_rate': self.win_rate,
            'profit_pct': profit_pct if action == 1 and self.position == 0 else None,
            'current_price': current_price,
            'next_price': next_price,
            'future_price_1h': future_price_1h,  # Actual future hourly price change
            'future_price_1d': future_price_1d   # Actual future daily price change
        }

        # Call the callback if it exists
        if self.action_callback:
            self.action_callback(action, current_price, reward, info)

        return observation, reward, done, info

    def set_action_callback(self, callback):
        """
        Set a callback function to be called after each action

        Args:
            callback: Function with signature (action, price, reward, info)
        """
        self.action_callback = callback
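
    # A minimal sketch of an action callback (hypothetical handler name):
    #
    #     def on_action(action, price, reward, info):
    #         print(f"action={action} price={price:.2f} reward={reward:.4f} "
    #               f"gain={info['gain']:.2%} trades={info['trades']}")
    #
    #     env.set_action_callback(on_action)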


def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
             action_callback=None, episode_callback=None, symbol="BTC/USDT",
             pretrain_price_prediction_enabled=True, pretrain_epochs=10):
    """
    Train a reinforcement learning agent for trading

    Args:
        env_class: Optional environment class override
        num_episodes: Number of episodes to train for
        max_steps: Maximum steps per episode
        save_path: Path to save the trained model
        action_callback: Callback function for monitoring actions
        episode_callback: Callback function for monitoring episodes
        symbol: Trading symbol to use
        pretrain_price_prediction_enabled: Whether to pre-train price prediction
        pretrain_epochs: Number of epochs for pre-training

    Returns:
        tuple: (trained agent, environment)
    """
    # Load data for the selected symbol
    data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m'])

    try:
        # Try to load data for the requested symbol using get_historical_data
        data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
        data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
        data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)

        if data_1m is None or data_5m is None or data_15m is None:
            raise FileNotFoundError("Could not retrieve data for specified symbol")
    except Exception as e:
        logger.warning(f"Data for {symbol} not available: {str(e)}. Using default data.")
        # Try to use cached data if available
        symbol = "BTC/USDT"
        data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m'])
        data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
        data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
        data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)

        if data_1m is None or data_5m is None or data_15m is None:
            logger.error("Failed to retrieve any data. Cannot continue training.")
            raise ValueError("No data available for training")

    # Create features by adding technical indicators and converting to numpy format
    if data_1m is not None:
        data_1m = data_interface.add_technical_indicators(data_1m)
        # Convert to numpy array with close price as the last column
        features_1m = np.hstack([
            data_1m.drop(['timestamp', 'close'], axis=1).values,
            data_1m['close'].values.reshape(-1, 1)
        ])
    else:
        features_1m = None

    if data_5m is not None:
        data_5m = data_interface.add_technical_indicators(data_5m)
        # Convert to numpy array with close price as the last column
        features_5m = np.hstack([
            data_5m.drop(['timestamp', 'close'], axis=1).values,
            data_5m['close'].values.reshape(-1, 1)
        ])
    else:
        features_5m = None

    if data_15m is not None:
        data_15m = data_interface.add_technical_indicators(data_15m)
        # Convert to numpy array with close price as the last column
        features_15m = np.hstack([
            data_15m.drop(['timestamp', 'close'], axis=1).values,
            data_15m['close'].values.reshape(-1, 1)
        ])
    else:
        features_15m = None

    # Check that we have all the required features
    if features_1m is None or features_5m is None or features_15m is None:
        logger.error("Failed to create features for all timeframes.")
        raise ValueError("Could not create features for training")

    # Create the environment (the 5m and 15m features fill the environment's
    # higher-timeframe slots)
    if env_class:
        # Use provided environment class
        env = env_class(features_1m, features_5m, features_15m)
    else:
        # Use the default environment
        env = RLTradingEnvironment(features_1m, features_5m, features_15m)

    # Set action callback if provided
    if action_callback:
        env.set_action_callback(action_callback)

    # Get environment properties for agent creation
    input_shape = env.observation_space.shape
    n_actions = env.action_space.n

    # Create the agent
    agent = DQNAgent(
        state_shape=input_shape,
        n_actions=n_actions,
        epsilon=1.0,
        epsilon_decay=0.995,
        epsilon_min=0.01,
        learning_rate=0.0001,
        gamma=0.99,
        buffer_size=10000,
        batch_size=64,
        device=device  # Pass device to agent for GPU usage
    )

    # Check if a model file exists and load it
    model_file = f"{save_path}_model.pth"
    if os.path.exists(model_file):
        try:
            agent.load(model_file)
            logger.info(f"Loaded existing model from {model_file}")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
    else:
        logger.info("No existing model found. Starting with a new model.")

    # Pre-train price prediction if enabled and we have a new model
    if pretrain_price_prediction_enabled:
        if not os.path.exists(model_file) or input("Pre-train price prediction? (y/n): ").lower() == 'y':
            logger.info("Pre-training price prediction capability...")
            # Attempt to load hourly and daily data for pre-training
            try:
                data_interface.add_timeframe('1h')
                data_interface.add_timeframe('1d')

                # Run pre-training
                agent = pretrain_price_prediction(
                    agent=agent,
                    data_interface=data_interface,
                    n_epochs=pretrain_epochs,
                    batch_size=128
                )

                # Save the pre-trained model
                agent.save(f"{save_path}_pretrained")
                logger.info("Pre-trained model saved.")
            except Exception as e:
                logger.error(f"Error during pre-training: {e}")
                logger.warning("Continuing with RL training without pre-training.")

    # Create TensorBoard writer
    writer = SummaryWriter(log_dir=f'runs/dqn_{int(time.time())}')

    # Log GPU status to TensorBoard
    writer.add_text("hardware/device", str(device), 0)
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            writer.add_text(f"hardware/gpu_{i}", torch.cuda.get_device_name(i), 0)

    # Training loop
    total_rewards = []
    trade_win_rates = []
    best_reward = -np.inf

    # Move models to the appropriate device if not already there
    agent.move_models_to_device(device)

    # Enable mixed precision if GPU and feature are available
    use_mixed_precision = False
    if torch.cuda.is_available() and hasattr(torch.cuda, 'amp'):
        logger.info("Enabling mixed precision training")
        use_mixed_precision = True
        scaler = torch.cuda.amp.GradScaler()

    # Define step callback for TensorBoard logging and model tracking
    def step_callback(action, price, reward, info):
        # Pass to external callback if provided
        if action_callback:
            action_callback(env.current_step, action, price, reward, info)

    # Main training loop
    logger.info(f"Starting training for {num_episodes} episodes...")
    logger.info(f"Starting training on device: {agent.device}")

    try:
        for episode in range(num_episodes):
            state = env.reset()
            total_reward = 0

            for step in range(max_steps):
                # Select action
                action = agent.act(state)

                # Take action and observe next state and reward
                next_state, reward, done, info = env.step(action)

                # Store the experience in memory
                agent.remember(state, action, reward, next_state, done)

                # Update state and reward
                state = next_state
                total_reward += reward

                # Train the agent by sampling from memory
                if len(agent.memory) >= agent.batch_size:
                    loss = agent.replay()

                if done or step == max_steps - 1:
                    break

            # Track rewards
            total_rewards.append(total_reward)

            # Calculate trading metrics
            win_rate = env.win_rate if hasattr(env, 'win_rate') else 0
            trades = env.trades if hasattr(env, 'trades') else 0

            # Log to TensorBoard
            writer.add_scalar('Reward/Episode', total_reward, episode)
            writer.add_scalar('Trade/WinRate', win_rate, episode)
            writer.add_scalar('Trade/Count', trades, episode)

            # Save best model
            if total_reward > best_reward and episode > 10:
                logger.info(f"New best episode reward: {total_reward:.4f}, saving model")
                agent.save(save_path)
                best_reward = total_reward

            # Periodic save every 100 episodes
            if episode % 100 == 0 and episode > 0:
                agent.save(f"{save_path}_episode_{episode}")

            # Call episode callback if provided
            if episode_callback:
                # Add environment to info dict to use for extrema training
                info_with_env = info.copy()
                info_with_env['env'] = env
                episode_callback(episode, total_reward, info_with_env)

        # Final save
        logger.info("Training completed, saving final model")
        agent.save(f"{save_path}_final")

    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())

    # Close TensorBoard writer
    writer.close()

    return agent, env


def generate_price_prediction_training_data(data_1m, data_1h, data_1d, window_size=20):
    """
    Generate labeled training data for price prediction at different timeframes

    Args:
        data_1m: DataFrame with 1-minute data
        data_1h: DataFrame with 1-hour data
        data_1d: DataFrame with 1-day data
        window_size: Size of the input window

    Returns:
        tuple: (X, y_immediate, y_midterm, y_longterm, y_values)
            - X: input features (window sequences)
            - y_immediate: immediate direction labels (0=down, 1=sideways, 2=up)
            - y_midterm: mid-term direction labels
            - y_longterm: long-term direction labels
            - y_values: actual percentage changes for each timeframe
    """
    logger.info("Generating price prediction training data from historical prices")

    # Prepare data structures
    X = []
    y_immediate = []  # 1m
    y_midterm = []    # 1h
    y_longterm = []   # 1d
    y_values = []     # Actual percentage changes

    # Calculate future returns for labeling
    data_1m['future_return_1m'] = data_1m['close'].pct_change(1).shift(-1)     # Next candle
    data_1m['future_return_10m'] = data_1m['close'].pct_change(10).shift(-10)  # Next 10 candles

    # Add indices to align data
    data_1m['index'] = range(len(data_1m))
    data_1h['index'] = range(len(data_1h))
    data_1d['index'] = range(len(data_1d))

    # Define thresholds for direction labels
    immediate_threshold = 0.0005
    midterm_threshold = 0.001
    longterm_threshold = 0.002

    # Loop through 1m data to create training samples
    max_idx = len(data_1m) - window_size - 10  # Ensure we have future data for labels
    sample_indices = random.sample(range(window_size, max_idx), min(10000, max_idx - window_size))

    for idx in sample_indices:
        # Get window of 1m data
        window_1m = data_1m.iloc[idx-window_size:idx].drop(['timestamp', 'future_return_1m', 'future_return_10m', 'index'], axis=1, errors='ignore')

        # Skip if window contains NaN
        if window_1m.isnull().values.any():
            continue

        # Get future returns for labeling
        future_return_1m = data_1m.iloc[idx]['future_return_1m']
        future_return_10m = data_1m.iloc[idx]['future_return_10m']

        # Find corresponding row in 1h data (closest timestamp)
        current_timestamp = data_1m.iloc[idx]['timestamp']

        # Find 1h candle for mid-term prediction
        if 'timestamp' in data_1h.columns:
            # Find closest 1h candle
            closest_1h_idx = data_1h['timestamp'].searchsorted(current_timestamp)
            if closest_1h_idx >= len(data_1h):
                closest_1h_idx = len(data_1h) - 1

            # Get future 1h return (next candle)
            if closest_1h_idx < len(data_1h) - 1:
                future_return_1h = (data_1h.iloc[closest_1h_idx + 1]['close'] - data_1h.iloc[closest_1h_idx]['close']) / data_1h.iloc[closest_1h_idx]['close']
            else:
                future_return_1h = 0
        else:
            future_return_1h = future_return_10m  # Fallback

        # Find 1d candle for long-term prediction
        if 'timestamp' in data_1d.columns:
            # Find closest 1d candle
            closest_1d_idx = data_1d['timestamp'].searchsorted(current_timestamp)
            if closest_1d_idx >= len(data_1d):
                closest_1d_idx = len(data_1d) - 1

            # Get future 1d return (next candle)
            if closest_1d_idx < len(data_1d) - 1:
                future_return_1d = (data_1d.iloc[closest_1d_idx + 1]['close'] - data_1d.iloc[closest_1d_idx]['close']) / data_1d.iloc[closest_1d_idx]['close']
            else:
                future_return_1d = 0
        else:
            future_return_1d = future_return_1h * 2  # Fallback

        # Create direction labels: 0=down, 1=sideways, 2=up

        # Immediate (1m)
        if future_return_1m > immediate_threshold:
            immediate_label = 2  # UP
        elif future_return_1m < -immediate_threshold:
            immediate_label = 0  # DOWN
        else:
            immediate_label = 1  # SIDEWAYS

        # Mid-term (1h)
        if future_return_1h > midterm_threshold:
            midterm_label = 2  # UP
        elif future_return_1h < -midterm_threshold:
            midterm_label = 0  # DOWN
        else:
            midterm_label = 1  # SIDEWAYS

        # Long-term (1d)
        if future_return_1d > longterm_threshold:
            longterm_label = 2  # UP
        elif future_return_1d < -longterm_threshold:
            longterm_label = 0  # DOWN
        else:
            longterm_label = 1  # SIDEWAYS

        # Store data
        X.append(window_1m.values)
        y_immediate.append(immediate_label)
        y_midterm.append(midterm_label)
        y_longterm.append(longterm_label)
        y_values.append([future_return_1m, future_return_1h, future_return_1d, future_return_1d * 1.5])  # Add weekly estimate

    # Convert to numpy arrays
    X = np.array(X)
    y_immediate = np.array(y_immediate)
    y_midterm = np.array(y_midterm)
    y_longterm = np.array(y_longterm)
    y_values = np.array(y_values)

    logger.info(f"Generated {len(X)} price prediction training samples")

    # Log class distribution
    for name, y in [("Immediate", y_immediate), ("Mid-term", y_midterm), ("Long-term", y_longterm)]:
        down = (y == 0).sum()
        sideways = (y == 1).sum()
        up = (y == 2).sum()
        logger.info(f"{name} direction distribution: DOWN={down} ({down/len(y)*100:.1f}%), "
                    f"SIDEWAYS={sideways} ({sideways/len(y)*100:.1f}%), "
                    f"UP={up} ({up/len(y)*100:.1f}%)")

    return X, y_immediate, y_midterm, y_longterm, y_values
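

# A minimal sketch of how the generated arrays line up (assumed shapes):
#
#     X, y_imm, y_mid, y_long, y_vals = generate_price_prediction_training_data(
#         data_1m, data_1h, data_1d, window_size=20)
#     # X.shape      -> (n_samples, 20, n_feature_columns)
#     # y_imm.shape  -> (n_samples,)   direction classes {0, 1, 2}
#     # y_vals.shape -> (n_samples, 4) [1m, 1h, 1d, weekly-estimate] returns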


def pretrain_price_prediction(agent, data_interface, n_epochs=10, batch_size=128):
    """
    Pre-train the agent's price prediction capability on historical data

    Args:
        agent: DQNAgent instance to train
        data_interface: DataInterface instance for accessing data
        n_epochs: Number of epochs for training
        batch_size: Batch size for training

    Returns:
        The agent with pre-trained price prediction capabilities
    """
    logger.info("Starting supervised pre-training of price prediction")

    try:
        # Load data for all required timeframes
        data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=10000)
        data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
        data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)

        # Check if data is available
        if data_1m is None:
            logger.warning("1m data not available for pre-training")
            return agent

        if data_1h is None:
            logger.warning("1h data not available, using synthesized data")
            # Create synthetic 1h data from 1m data (take every 60th record)
            data_1h = data_1m.iloc[::60].reset_index(drop=True).copy()

        if data_1d is None:
            logger.warning("1d data not available, using synthesized data")
            # Create synthetic 1d data from 1h data (take every 24th record)
            data_1d = data_1h.iloc[::24].reset_index(drop=True).copy()

        # Add technical indicators to all data
        data_1m = data_interface.add_technical_indicators(data_1m)
        data_1h = data_interface.add_technical_indicators(data_1h)
        data_1d = data_interface.add_technical_indicators(data_1d)

        # Generate labeled training data
        X, y_immediate, y_midterm, y_longterm, y_values = generate_price_prediction_training_data(
            data_1m, data_1h, data_1d, window_size=20
        )

        # Split data into training and validation sets
        from sklearn.model_selection import train_test_split
        X_train, X_val, y_imm_train, y_imm_val, y_mid_train, y_mid_val, y_long_train, y_long_val, y_val_train, y_val_val = train_test_split(
            X, y_immediate, y_midterm, y_longterm, y_values, test_size=0.2, random_state=42
        )

        # Convert to torch tensors
        X_train_tensor = torch.FloatTensor(X_train).to(agent.device)
        y_imm_train_tensor = torch.LongTensor(y_imm_train).to(agent.device)
        y_mid_train_tensor = torch.LongTensor(y_mid_train).to(agent.device)
        y_long_train_tensor = torch.LongTensor(y_long_train).to(agent.device)
        y_val_train_tensor = torch.FloatTensor(y_val_train).to(agent.device)

        X_val_tensor = torch.FloatTensor(X_val).to(agent.device)
        y_imm_val_tensor = torch.LongTensor(y_imm_val).to(agent.device)
        y_mid_val_tensor = torch.LongTensor(y_mid_val).to(agent.device)
        y_long_val_tensor = torch.LongTensor(y_long_val).to(agent.device)
        y_val_val_tensor = torch.FloatTensor(y_val_val).to(agent.device)

        # Calculate class weights for imbalanced data
        def get_class_weights(labels):
            counts = np.bincount(labels, minlength=3)  # Ensure we have all 3 classes
            weights = 1.0 / np.maximum(counts, 1)      # Avoid division by zero for empty classes
            weights = weights / np.sum(weights)        # Normalize
            return weights

        imm_weights = torch.FloatTensor(get_class_weights(y_imm_train)).to(agent.device)
        mid_weights = torch.FloatTensor(get_class_weights(y_mid_train)).to(agent.device)
        long_weights = torch.FloatTensor(get_class_weights(y_long_train)).to(agent.device)

        # Create DataLoader for batch training
        from torch.utils.data import TensorDataset, DataLoader

        train_dataset = TensorDataset(
            X_train_tensor, y_imm_train_tensor, y_mid_train_tensor,
            y_long_train_tensor, y_val_train_tensor
        )
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        # Set up loss functions with class weights
        imm_criterion = nn.CrossEntropyLoss(weight=imm_weights)
        mid_criterion = nn.CrossEntropyLoss(weight=mid_weights)
        long_criterion = nn.CrossEntropyLoss(weight=long_weights)
        value_criterion = nn.MSELoss()

        # Set up optimizer (separate from agent's optimizer)
        pretrain_optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=0.0002)
        pretrain_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            pretrain_optimizer, mode='min', factor=0.5, patience=3, verbose=True
        )

        # Set model to training mode
        agent.policy_net.train()

        # Training loop with early stopping
        best_val_loss = float('inf')
        patience = 5
        patience_counter = 0

        for epoch in range(n_epochs):
            # Training phase
            train_loss = 0.0
            imm_correct, mid_correct, long_correct = 0, 0, 0
            total = 0

            for X_batch, y_imm_batch, y_mid_batch, y_long_batch, y_val_batch in train_loader:
                # Zero gradients
                pretrain_optimizer.zero_grad()

                # Forward pass - we only need the price predictions
                with torch.cuda.amp.autocast() if agent.use_mixed_precision else contextlib.nullcontext():
                    _, _, price_preds = agent.policy_net(X_batch)

                    # Calculate losses for each prediction head
                    imm_loss = imm_criterion(price_preds['immediate'], y_imm_batch)
                    mid_loss = mid_criterion(price_preds['midterm'], y_mid_batch)
                    long_loss = long_criterion(price_preds['longterm'], y_long_batch)
                    value_loss = value_criterion(price_preds['values'], y_val_batch)

                    # Combined loss (weighted by importance)
                    total_loss = imm_loss + 0.7 * mid_loss + 0.5 * long_loss + 0.3 * value_loss

                # Backward pass and optimize
                if agent.use_mixed_precision:
                    agent.scaler.scale(total_loss).backward()
                    agent.scaler.unscale_(pretrain_optimizer)
                    torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
                    agent.scaler.step(pretrain_optimizer)
                    agent.scaler.update()
                else:
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
                    pretrain_optimizer.step()

                # Accumulate metrics
                train_loss += total_loss.item()
                total += X_batch.size(0)

                # Calculate accuracy
                _, imm_pred = torch.max(price_preds['immediate'], 1)
                _, mid_pred = torch.max(price_preds['midterm'], 1)
                _, long_pred = torch.max(price_preds['longterm'], 1)

                imm_correct += (imm_pred == y_imm_batch).sum().item()
                mid_correct += (mid_pred == y_mid_batch).sum().item()
                long_correct += (long_pred == y_long_batch).sum().item()

            # Calculate epoch metrics
            train_loss /= len(train_loader)
            imm_acc = imm_correct / total
            mid_acc = mid_correct / total
            long_acc = long_correct / total

            # Validation phase
            agent.policy_net.eval()
            val_loss = 0.0
            imm_val_correct, mid_val_correct, long_val_correct = 0, 0, 0

            with torch.no_grad():
                # Forward pass on validation data
                _, _, val_price_preds = agent.policy_net(X_val_tensor)

                # Calculate validation losses
                val_imm_loss = imm_criterion(val_price_preds['immediate'], y_imm_val_tensor)
                val_mid_loss = mid_criterion(val_price_preds['midterm'], y_mid_val_tensor)
                val_long_loss = long_criterion(val_price_preds['longterm'], y_long_val_tensor)
                val_value_loss = value_criterion(val_price_preds['values'], y_val_val_tensor)

                val_total_loss = val_imm_loss + 0.7 * val_mid_loss + 0.5 * val_long_loss + 0.3 * val_value_loss
                val_loss = val_total_loss.item()

                # Calculate validation accuracy
                _, imm_val_pred = torch.max(val_price_preds['immediate'], 1)
                _, mid_val_pred = torch.max(val_price_preds['midterm'], 1)
                _, long_val_pred = torch.max(val_price_preds['longterm'], 1)

                imm_val_correct = (imm_val_pred == y_imm_val_tensor).sum().item()
                mid_val_correct = (mid_val_pred == y_mid_val_tensor).sum().item()
                long_val_correct = (long_val_pred == y_long_val_tensor).sum().item()

                imm_val_acc = imm_val_correct / len(X_val_tensor)
                mid_val_acc = mid_val_correct / len(X_val_tensor)
                long_val_acc = long_val_correct / len(X_val_tensor)

            # Learning rate scheduling
            pretrain_scheduler.step(val_loss)

            # Early stopping check
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Copy policy_net weights to target_net
                agent.target_net.load_state_dict(agent.policy_net.state_dict())
                logger.info(f"Saved best model with validation loss: {val_loss:.4f}")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    logger.info(f"Early stopping triggered after {epoch+1} epochs")
                    break

            # Log progress
            logger.info(f"Epoch {epoch+1}/{n_epochs}: "
                        f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                        f"Imm Acc: {imm_acc:.4f}/{imm_val_acc:.4f}, "
                        f"Mid Acc: {mid_acc:.4f}/{mid_val_acc:.4f}, "
                        f"Long Acc: {long_acc:.4f}/{long_val_acc:.4f}")

            # Set model back to training mode for next epoch
            agent.policy_net.train()

        logger.info("Price prediction pre-training complete")
        return agent

    except Exception as e:
        logger.error(f"Error during price prediction pre-training: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return agent
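

# A minimal sketch of a customized invocation (argument values are
# illustrative, not tuned defaults):
#
#     agent, env = train_rl(
#         num_episodes=1000,
#         max_steps=500,
#         symbol="ETH/USDT",
#         save_path="NN/models/saved/dqn_agent",
#         pretrain_price_prediction_enabled=False,
#     )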


if __name__ == "__main__":
    train_rl()