gogo2/NN/train_rl.py
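"""
Reinforcement-learning training script: wraps multi-timeframe market features in a
gym-compatible RLTradingEnvironment, trains a DQNAgent on it, and can optionally
pre-train the agent's price-prediction heads on labeled historical data before RL
training begins.
"""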
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import logging
import time
from datetime import datetime
import os
import sys
import pandas as pd
import gym
import json
import random
import torch.nn as nn
import contextlib
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from NN.utils.data_interface import DataInterface
from NN.utils.trading_env import TradingEnvironment
from NN.models.dqn_agent import DQNAgent
from NN.utils.signal_interpreter import SignalInterpreter
# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('rl_training.log'),
logging.StreamHandler()
]
)
# Set up device for PyTorch (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Log GPU status
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
gpu_names = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
logger.info(f"Using GPU: {gpu_names}")
    # Check for BFloat16 support (available on NVIDIA Ampere and newer GPUs)
if hasattr(torch.cuda, 'amp') and torch.cuda.is_bf16_supported():
logger.info("BFloat16 precision is supported - will use for faster training")
else:
logger.warning("GPU not available. Using CPU for training (slower).")
class RLTradingEnvironment(gym.Env):
"""
Reinforcement Learning environment for trading with technical indicators
from multiple timeframes
"""
def __init__(self, features_1m, features_1h=None, features_1d=None, window_size=20, trading_fee=0.0025, min_trade_interval=15):
super().__init__()
# Initialize attributes before parent class
self.window_size = window_size
        self.num_features = features_1m.shape[1] - 1  # Feature columns excluding the close price
        # Observations always stack three timeframes (1m, 1h, 1d) side by side;
        # synthetic series are generated below for any that are missing. Each
        # window keeps every column (including close), and all timeframes are
        # assumed to share the same column count, as produced by train_rl().
        self.num_timeframes = 3
        self.feature_dim = features_1m.shape[1] * self.num_timeframes
# Store features from different timeframes
self.features_1m = features_1m
self.features_1h = features_1h
self.features_1d = features_1d
# Create synthetic 1s data from 1m (for demo purposes)
self.features_1s = self._create_synthetic_1s_data(features_1m)
# If higher timeframes are missing, create synthetic data
if self.features_1h is None:
self.features_1h = self._create_synthetic_hourly_data(features_1m)
        if self.features_1d is None:
            # Use self.features_1h so a freshly synthesized hourly series is picked up as well
            self.features_1d = self._create_synthetic_daily_data(self.features_1h)
# Trading parameters
self.initial_balance = 1.0
self.trading_fee = trading_fee # Increased from 0.001 to 0.0025 (0.25%)
self.min_trade_interval = min_trade_interval # Minimum steps between trades
# Define action and observation spaces
self.action_space = gym.spaces.Discrete(3) # 0: Buy, 1: Sell, 2: Hold
self.observation_space = gym.spaces.Box(
low=-np.inf,
high=np.inf,
shape=(self.window_size, self.feature_dim),
dtype=np.float32
)
# State variables
self.reset()
# Callback for visualization or external monitoring
self.action_callback = None
def _create_synthetic_1s_data(self, features_1m):
"""Create synthetic 1-second data from 1-minute data"""
# Simple approach: duplicate each 1m candle for 60 seconds with some noise
num_samples = features_1m.shape[0]
synthetic_1s = np.zeros((num_samples * 60, features_1m.shape[1]))
for i in range(num_samples):
for j in range(60):
idx = i * 60 + j
if idx < synthetic_1s.shape[0]:
# Copy the 1m data with small random noise
synthetic_1s[idx] = features_1m[i] * (1 + np.random.normal(0, 0.0001, features_1m.shape[1]))
return synthetic_1s
def _create_synthetic_hourly_data(self, features_1m):
"""Create synthetic hourly data from minute data"""
# Group by hour, taking every 60th candle
num_samples = features_1m.shape[0] // 60
synthetic_1h = np.zeros((num_samples, features_1m.shape[1]))
for i in range(num_samples):
if i * 60 < features_1m.shape[0]:
synthetic_1h[i] = features_1m[i * 60]
return synthetic_1h
def _create_synthetic_daily_data(self, features_1h):
"""Create synthetic daily data from hourly data"""
# Group by day, taking every 24th candle
num_samples = features_1h.shape[0] // 24
synthetic_1d = np.zeros((num_samples, features_1h.shape[1]))
for i in range(num_samples):
if i * 24 < features_1h.shape[0]:
synthetic_1d[i] = features_1h[i * 24]
return synthetic_1d
def reset(self):
"""Reset the environment to initial state"""
self.balance = self.initial_balance
self.position = 0.0 # Amount of asset held
self.current_step = self.window_size
self.trades = 0
self.wins = 0
self.losses = 0
self.trade_history = []
self.last_trade_step = -self.min_trade_interval # Initialize to allow immediate first trade
# Get initial observation
observation = self._get_observation()
return observation
def _get_observation(self):
"""
Get the current state observation.
Combine features from multiple timeframes, reshaped for the CNN.
"""
# Calculate indices for each timeframe
idx_1m = min(self.current_step, self.features_1m.shape[0] - 1)
idx_1h = idx_1m // 60 # 60 minutes in an hour
idx_1d = idx_1h // 24 # 24 hours in a day
# Cap indices to prevent out of bounds
idx_1h = min(idx_1h, self.features_1h.shape[0] - 1)
idx_1d = min(idx_1d, self.features_1d.shape[0] - 1)
# Extract feature windows from each timeframe
window_1m = self.features_1m[max(0, idx_1m - self.window_size):idx_1m]
# Handle hourly timeframe
start_1h = max(0, idx_1h - self.window_size)
window_1h = self.features_1h[start_1h:idx_1h]
# Handle daily timeframe
start_1d = max(0, idx_1d - self.window_size)
window_1d = self.features_1d[start_1d:idx_1d]
# Pad if needed (for higher timeframes)
if len(window_1m) < self.window_size:
padding = np.zeros((self.window_size - len(window_1m), window_1m.shape[1]))
window_1m = np.vstack([padding, window_1m])
if len(window_1h) < self.window_size:
padding = np.zeros((self.window_size - len(window_1h), window_1h.shape[1]))
window_1h = np.vstack([padding, window_1h])
if len(window_1d) < self.window_size:
padding = np.zeros((self.window_size - len(window_1d), window_1d.shape[1]))
window_1d = np.vstack([padding, window_1d])
# Combine features from all timeframes
combined_features = np.hstack([
window_1m.reshape(self.window_size, -1),
window_1h.reshape(self.window_size, -1),
window_1d.reshape(self.window_size, -1)
])
# Convert to float32 and handle any NaN values
combined_features = np.nan_to_num(combined_features, nan=0.0).astype(np.float32)
return combined_features
def step(self, action):
"""
Take an action in the environment and return the next state, reward, done flag, and info
Args:
action (int): 0 = Buy, 1 = Sell, 2 = Hold
Returns:
tuple: (observation, reward, done, info)
"""
# Get current and next price
current_price = self.features_1m[self.current_step, -1] # Close price is last column
# Check if we're at the end of the data
if self.current_step + 1 >= len(self.features_1m):
next_price = current_price # Use current price if at the end
done = True
else:
next_price = self.features_1m[self.current_step + 1, -1]
done = False
# Handle zero or negative prices
if current_price <= 0:
current_price = 1e-8 # Small positive number
if next_price <= 0:
next_price = current_price # Use current price if next price is invalid
price_change = (next_price - current_price) / current_price
# Default reward is slightly negative to discourage inaction
reward = -0.0001
profit_pct = None # Initialize profit_pct variable
# Check if enough time has passed since last trade
trade_interval = self.current_step - self.last_trade_step
trade_interval_penalty = 0
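        # Reward scheme applied below: holding keeps the small negative default,
        # opening a position costs -0.001 plus a frequency penalty when trades are
        # closer together than min_trade_interval steps, and closing a position is
        # rewarded with 10x the realized profit percentage (plus any penalty).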
# Execute action
if action == 0: # BUY
if self.position == 0: # Only buy if not already in position
# Apply extra penalty for trading too frequently
if trade_interval < self.min_trade_interval:
trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)
# Still allow the trade but with penalty
self.position = self.balance * (1 - self.trading_fee)
self.balance = 0
self.trades += 1
reward = -0.001 + trade_interval_penalty # Small cost for transaction + potential penalty
self.trade_entry_price = current_price
self.last_trade_step = self.current_step
elif action == 1: # SELL
if self.position > 0: # Only sell if in position
# Apply extra penalty for trading too frequently
if trade_interval < self.min_trade_interval:
trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)
# Still allow the trade but with penalty
# Calculate position value at current price
position_value = self.position * (1 + price_change)
self.balance = position_value * (1 - self.trading_fee)
# Calculate profit/loss from trade
profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price
# Scale reward by profit percentage and apply trade interval penalty
reward = (profit_pct * 10) + trade_interval_penalty
# Update win/loss count
if profit_pct > 0:
self.wins += 1
else:
self.losses += 1
# Record trade
self.trade_history.append({
'entry_price': self.trade_entry_price,
'exit_price': next_price,
'profit_pct': profit_pct,
'trade_interval': trade_interval
})
# Reset position and update last trade step
self.position = 0
self.last_trade_step = self.current_step
# else: (action == 2 - HOLD) - no position change
# Move to next step
self.current_step += 1
# Check if done (reached end of data)
if self.current_step >= len(self.features_1m) - 1:
done = True
# Apply final evaluation
            if self.position > 0:
                # Force close the position at the end of the data
                position_value = self.position * (1 + price_change)
                self.balance = position_value * (1 - self.trading_fee)
                profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price
                reward += profit_pct * 10
                # Update win/loss count
                if profit_pct > 0:
                    self.wins += 1
                else:
                    self.losses += 1
                # Clear the position so total_value below does not double-count it
                self.position = 0.0
# Get the next observation
observation = self._get_observation()
# Calculate metrics for info
total_value = self.balance + self.position * next_price
gain = (total_value - self.initial_balance) / self.initial_balance
self.win_rate = self.wins / max(1, self.trades)
# Check if we have prediction data for future timeframes
future_price_1h = None
future_price_1d = None
# Get hourly index
idx_1h = self.current_step // 60
if idx_1h + 1 < len(self.features_1h):
hourly_close_idx = self.features_1h.shape[1] - 1 # Assuming close is last column
current_1h_price = self.features_1h[idx_1h, hourly_close_idx]
next_1h_price = self.features_1h[idx_1h + 1, hourly_close_idx]
future_price_1h = (next_1h_price - current_1h_price) / current_1h_price
# Get daily index
idx_1d = idx_1h // 24
if idx_1d + 1 < len(self.features_1d):
daily_close_idx = self.features_1d.shape[1] - 1 # Assuming close is last column
current_1d_price = self.features_1d[idx_1d, daily_close_idx]
next_1d_price = self.features_1d[idx_1d + 1, daily_close_idx]
future_price_1d = (next_1d_price - current_1d_price) / current_1d_price
info = {
'balance': self.balance,
'position': self.position,
'total_value': total_value,
'gain': gain,
'trades': self.trades,
'win_rate': self.win_rate,
'profit_pct': profit_pct if action == 1 and self.position == 0 else None,
'current_price': current_price,
'next_price': next_price,
'future_price_1h': future_price_1h, # Actual future hourly price change
'future_price_1d': future_price_1d # Actual future daily price change
}
# Call the callback if it exists
if self.action_callback:
self.action_callback(action, current_price, reward, info)
return observation, reward, done, info
def set_action_callback(self, callback):
"""
Set a callback function to be called after each action
Args:
callback: Function with signature (action, price, reward, info)
"""
self.action_callback = callback
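# Illustrative usage (not executed): stepping the environment manually with random
# actions. Assumes `features_1m` is an (N, F) numpy array whose last column is the
# close price.
#
#   env = RLTradingEnvironment(features_1m, window_size=20)
#   obs = env.reset()
#   done = False
#   while not done:
#       obs, reward, done, info = env.step(env.action_space.sample())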
def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
action_callback=None, episode_callback=None, symbol="BTC/USDT",
pretrain_price_prediction_enabled=True, pretrain_epochs=10):
"""
Train a reinforcement learning agent for trading
Args:
env_class: Optional environment class override
num_episodes: Number of episodes to train for
max_steps: Maximum steps per episode
save_path: Path to save the trained model
action_callback: Callback function for monitoring actions
episode_callback: Callback function for monitoring episodes
symbol: Trading symbol to use
pretrain_price_prediction_enabled: Whether to pre-train price prediction
pretrain_epochs: Number of epochs for pre-training
Returns:
tuple: (trained agent, environment)
"""
# Load data for the selected symbol
data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m'])
try:
# Try to load data for the requested symbol using get_historical_data method
data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)
if data_1m is None or data_5m is None or data_15m is None:
raise FileNotFoundError("Could not retrieve data for specified symbol")
except Exception as e:
logger.warning(f"Data for {symbol} not available: {str(e)}. Using default data.")
# Try to use cached data if available
symbol = "BTC/USDT"
data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m'])
data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)
if data_1m is None or data_5m is None or data_15m is None:
logger.error("Failed to retrieve any data. Cannot continue training.")
raise ValueError("No data available for training")
# Create features from the data by adding technical indicators and converting to numpy format
if data_1m is not None:
data_1m = data_interface.add_technical_indicators(data_1m)
# Convert to numpy array with close price as the last column
features_1m = np.hstack([
data_1m.drop(['timestamp', 'close'], axis=1).values,
data_1m['close'].values.reshape(-1, 1)
])
else:
features_1m = None
if data_5m is not None:
data_5m = data_interface.add_technical_indicators(data_5m)
# Convert to numpy array with close price as the last column
features_5m = np.hstack([
data_5m.drop(['timestamp', 'close'], axis=1).values,
data_5m['close'].values.reshape(-1, 1)
])
else:
features_5m = None
if data_15m is not None:
data_15m = data_interface.add_technical_indicators(data_15m)
# Convert to numpy array with close price as the last column
features_15m = np.hstack([
data_15m.drop(['timestamp', 'close'], axis=1).values,
data_15m['close'].values.reshape(-1, 1)
])
else:
features_15m = None
# Check if we have all the required features
if features_1m is None or features_5m is None or features_15m is None:
logger.error("Failed to create features for all timeframes.")
raise ValueError("Could not create features for training")
# Create the environment
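    # Note: the default setup below passes the 5m and 15m feature arrays into the
    # environment's hourly/daily slots, so RLTradingEnvironment treats them as its
    # higher timeframes when building observations.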
if env_class:
# Use provided environment class
env = env_class(features_1m, features_5m, features_15m)
else:
# Use the default environment
env = RLTradingEnvironment(features_1m, features_5m, features_15m)
# Set action callback if provided
if action_callback:
env.set_action_callback(action_callback)
# Get environment properties for agent creation
input_shape = env.observation_space.shape
n_actions = env.action_space.n
# Create the agent
agent = DQNAgent(
state_shape=input_shape,
n_actions=n_actions,
epsilon=1.0,
epsilon_decay=0.995,
epsilon_min=0.01,
learning_rate=0.0001,
gamma=0.99,
buffer_size=10000,
batch_size=64,
device=device # Pass device to agent for GPU usage
)
# Check if model file exists and load it
model_file = f"{save_path}_model.pth"
if os.path.exists(model_file):
try:
agent.load(model_file)
logger.info(f"Loaded existing model from {model_file}")
except Exception as e:
logger.error(f"Error loading model: {e}")
else:
logger.info("No existing model found. Starting with a new model.")
# Pre-train price prediction if enabled and we have a new model
if pretrain_price_prediction_enabled:
if not os.path.exists(model_file) or input("Pre-train price prediction? (y/n): ").lower() == 'y':
logger.info("Pre-training price prediction capability...")
# Attempt to load hourly and daily data for pre-training
try:
data_interface.add_timeframe('1h')
data_interface.add_timeframe('1d')
# Run pre-training
agent = pretrain_price_prediction(
agent=agent,
data_interface=data_interface,
n_epochs=pretrain_epochs,
batch_size=128
)
# Save the pre-trained model
agent.save(f"{save_path}_pretrained")
logger.info("Pre-trained model saved.")
except Exception as e:
logger.error(f"Error during pre-training: {e}")
logger.warning("Continuing with RL training without pre-training.")
# Create TensorBoard writer
writer = SummaryWriter(log_dir=f'runs/dqn_{int(time.time())}')
# Log GPU status to TensorBoard
writer.add_text("hardware/device", str(device), 0)
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
writer.add_text(f"hardware/gpu_{i}", torch.cuda.get_device_name(i), 0)
# Training loop
total_rewards = []
trade_win_rates = []
best_reward = -np.inf
# Move models to the appropriate device if not already there
agent.move_models_to_device(device)
# Enable mixed precision if GPU and feature is available
use_mixed_precision = False
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp'):
logger.info("Enabling mixed precision training")
use_mixed_precision = True
scaler = torch.cuda.amp.GradScaler()
# Define step callback for tensorboard logging and model tracking
def step_callback(action, price, reward, info):
# Pass to external callback if provided
if action_callback:
action_callback(env.current_step, action, price, reward, info)
# Main training loop
logger.info(f"Starting training for {num_episodes} episodes...")
logger.info(f"Starting training on device: {agent.device}")
try:
for episode in range(num_episodes):
state = env.reset()
total_reward = 0
for step in range(max_steps):
# Select action
action = agent.act(state)
# Take action and observe next state and reward
next_state, reward, done, info = env.step(action)
# Store the experience in memory
agent.remember(state, action, reward, next_state, done)
# Update state and reward
state = next_state
total_reward += reward
# Train the agent by sampling from memory
if len(agent.memory) >= agent.batch_size:
loss = agent.replay()
if done or step == max_steps - 1:
break
# Track rewards
total_rewards.append(total_reward)
# Calculate trading metrics
win_rate = env.win_rate if hasattr(env, 'win_rate') else 0
trades = env.trades if hasattr(env, 'trades') else 0
# Log to TensorBoard
writer.add_scalar('Reward/Episode', total_reward, episode)
writer.add_scalar('Trade/WinRate', win_rate, episode)
writer.add_scalar('Trade/Count', trades, episode)
# Save best model
if total_reward > best_reward and episode > 10:
logger.info(f"New best average reward: {total_reward:.4f}, saving model")
agent.save(save_path)
best_reward = total_reward
# Periodic save every 100 episodes
if episode % 100 == 0 and episode > 0:
agent.save(f"{save_path}_episode_{episode}")
# Call episode callback if provided
if episode_callback:
# Add environment to info dict to use for extrema training
info_with_env = info.copy()
info_with_env['env'] = env
episode_callback(episode, total_reward, info_with_env)
# Final save
logger.info("Training completed, saving final model")
agent.save(f"{save_path}_final")
except Exception as e:
logger.error(f"Training failed: {str(e)}")
import traceback
logger.error(traceback.format_exc())
# Close TensorBoard writer
writer.close()
return agent, env
def generate_price_prediction_training_data(data_1m, data_1h, data_1d, window_size=20):
"""
Generate labeled training data for price prediction at different timeframes
Args:
data_1m: DataFrame with 1-minute data
data_1h: DataFrame with 1-hour data
data_1d: DataFrame with 1-day data
window_size: Size of the input window
Returns:
tuple: (X, y_immediate, y_midterm, y_longterm, y_values)
- X: input features (window sequences)
- y_immediate: immediate direction labels (0=down, 1=sideways, 2=up)
- y_midterm: mid-term direction labels
- y_longterm: long-term direction labels
- y_values: actual percentage changes for each timeframe
"""
logger.info("Generating price prediction training data from historical prices")
# Prepare data structures
X = []
y_immediate = [] # 1m
y_midterm = [] # 1h
y_longterm = [] # 1d
y_values = [] # Actual percentage changes
# Calculate future returns for labeling
data_1m['future_return_1m'] = data_1m['close'].pct_change(1).shift(-1) # Next candle
data_1m['future_return_10m'] = data_1m['close'].pct_change(10).shift(-10) # Next 10 candles
# Add indices to align data
data_1m['index'] = range(len(data_1m))
data_1h['index'] = range(len(data_1h))
data_1d['index'] = range(len(data_1d))
# Define thresholds for direction labels
immediate_threshold = 0.0005
midterm_threshold = 0.001
longterm_threshold = 0.002
# Loop through 1m data to create training samples
max_idx = len(data_1m) - window_size - 10 # Ensure we have future data for labels
sample_indices = random.sample(range(window_size, max_idx), min(10000, max_idx - window_size))
for idx in sample_indices:
# Get window of 1m data
window_1m = data_1m.iloc[idx-window_size:idx].drop(['timestamp', 'future_return_1m', 'future_return_10m', 'index'], axis=1, errors='ignore')
# Skip if window contains NaN
if window_1m.isnull().values.any():
continue
# Get future returns for labeling
future_return_1m = data_1m.iloc[idx]['future_return_1m']
future_return_10m = data_1m.iloc[idx]['future_return_10m']
# Find corresponding row in 1h data (closest timestamp)
current_timestamp = data_1m.iloc[idx]['timestamp']
# Find 1h candle for mid-term prediction
if 'timestamp' in data_1h.columns:
# Find closest 1h candle
closest_1h_idx = data_1h['timestamp'].searchsorted(current_timestamp)
if closest_1h_idx >= len(data_1h):
closest_1h_idx = len(data_1h) - 1
# Get future 1h return (next candle)
if closest_1h_idx < len(data_1h) - 1:
future_return_1h = (data_1h.iloc[closest_1h_idx + 1]['close'] - data_1h.iloc[closest_1h_idx]['close']) / data_1h.iloc[closest_1h_idx]['close']
else:
future_return_1h = 0
else:
future_return_1h = future_return_10m # Fallback
# Find 1d candle for long-term prediction
if 'timestamp' in data_1d.columns:
# Find closest 1d candle
closest_1d_idx = data_1d['timestamp'].searchsorted(current_timestamp)
if closest_1d_idx >= len(data_1d):
closest_1d_idx = len(data_1d) - 1
# Get future 1d return (next candle)
if closest_1d_idx < len(data_1d) - 1:
future_return_1d = (data_1d.iloc[closest_1d_idx + 1]['close'] - data_1d.iloc[closest_1d_idx]['close']) / data_1d.iloc[closest_1d_idx]['close']
else:
future_return_1d = 0
else:
future_return_1d = future_return_1h * 2 # Fallback
# Create direction labels
# 0=down, 1=sideways, 2=up
# Immediate (1m)
if future_return_1m > immediate_threshold:
immediate_label = 2 # UP
elif future_return_1m < -immediate_threshold:
immediate_label = 0 # DOWN
else:
immediate_label = 1 # SIDEWAYS
# Mid-term (1h)
if future_return_1h > midterm_threshold:
midterm_label = 2 # UP
elif future_return_1h < -midterm_threshold:
midterm_label = 0 # DOWN
else:
midterm_label = 1 # SIDEWAYS
# Long-term (1d)
if future_return_1d > longterm_threshold:
longterm_label = 2 # UP
elif future_return_1d < -longterm_threshold:
longterm_label = 0 # DOWN
else:
longterm_label = 1 # SIDEWAYS
# Store data
X.append(window_1m.values)
y_immediate.append(immediate_label)
y_midterm.append(midterm_label)
y_longterm.append(longterm_label)
y_values.append([future_return_1m, future_return_1h, future_return_1d, future_return_1d * 1.5]) # Add weekly estimate
# Convert to numpy arrays
X = np.array(X)
y_immediate = np.array(y_immediate)
y_midterm = np.array(y_midterm)
y_longterm = np.array(y_longterm)
y_values = np.array(y_values)
logger.info(f"Generated {len(X)} price prediction training samples")
# Log class distribution
for name, y in [("Immediate", y_immediate), ("Mid-term", y_midterm), ("Long-term", y_longterm)]:
down = (y == 0).sum()
sideways = (y == 1).sum()
up = (y == 2).sum()
logger.info(f"{name} direction distribution: DOWN={down} ({down/len(y)*100:.1f}%), "
f"SIDEWAYS={sideways} ({sideways/len(y)*100:.1f}%), "
f"UP={up} ({up/len(y)*100:.1f}%)")
return X, y_immediate, y_midterm, y_longterm, y_values
def pretrain_price_prediction(agent, data_interface, n_epochs=10, batch_size=128):
"""
Pre-train the agent's price prediction capability on historical data
Args:
agent: DQNAgent instance to train
data_interface: DataInterface instance for accessing data
n_epochs: Number of epochs for training
batch_size: Batch size for training
Returns:
The agent with pre-trained price prediction capabilities
"""
logger.info("Starting supervised pre-training of price prediction")
try:
# Load data for all required timeframes
data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=10000)
data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)
# Check if data is available
if data_1m is None:
logger.warning("1m data not available for pre-training")
return agent
if data_1h is None:
logger.warning("1h data not available, using synthesized data")
# Create synthetic 1h data from 1m data
data_1h = data_1m.iloc[::60].reset_index(drop=True).copy() # Take every 60th record
if data_1d is None:
logger.warning("1d data not available, using synthesized data")
# Create synthetic 1d data from 1h data
data_1d = data_1h.iloc[::24].reset_index(drop=True).copy() # Take every 24th record
# Add technical indicators to all data
data_1m = data_interface.add_technical_indicators(data_1m)
data_1h = data_interface.add_technical_indicators(data_1h)
data_1d = data_interface.add_technical_indicators(data_1d)
# Generate labeled training data
X, y_immediate, y_midterm, y_longterm, y_values = generate_price_prediction_training_data(
data_1m, data_1h, data_1d, window_size=20
)
# Split data into training and validation sets
from sklearn.model_selection import train_test_split
X_train, X_val, y_imm_train, y_imm_val, y_mid_train, y_mid_val, y_long_train, y_long_val, y_val_train, y_val_val = train_test_split(
X, y_immediate, y_midterm, y_longterm, y_values, test_size=0.2, random_state=42
)
# Convert to torch tensors
X_train_tensor = torch.FloatTensor(X_train).to(agent.device)
y_imm_train_tensor = torch.LongTensor(y_imm_train).to(agent.device)
y_mid_train_tensor = torch.LongTensor(y_mid_train).to(agent.device)
y_long_train_tensor = torch.LongTensor(y_long_train).to(agent.device)
y_val_train_tensor = torch.FloatTensor(y_val_train).to(agent.device)
X_val_tensor = torch.FloatTensor(X_val).to(agent.device)
y_imm_val_tensor = torch.LongTensor(y_imm_val).to(agent.device)
y_mid_val_tensor = torch.LongTensor(y_mid_val).to(agent.device)
y_long_val_tensor = torch.LongTensor(y_long_val).to(agent.device)
y_val_val_tensor = torch.FloatTensor(y_val_val).to(agent.device)
        # Calculate class weights for imbalanced data
        def get_class_weights(labels):
            counts = np.bincount(labels, minlength=3)  # Ensure all 3 classes are present
            # Guard against division by zero for classes that never occur
            weights = 1.0 / np.maximum(counts, 1)
            weights = weights / np.sum(weights)  # Normalize so the weights sum to 1
            return weights
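        # Example: label counts [100, 300, 100] give weights ~[0.43, 0.14, 0.43],
        # so the rarer DOWN/UP classes contribute more to the cross-entropy loss.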
imm_weights = torch.FloatTensor(get_class_weights(y_imm_train)).to(agent.device)
mid_weights = torch.FloatTensor(get_class_weights(y_mid_train)).to(agent.device)
long_weights = torch.FloatTensor(get_class_weights(y_long_train)).to(agent.device)
# Create DataLoader for batch training
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(
X_train_tensor, y_imm_train_tensor, y_mid_train_tensor,
y_long_train_tensor, y_val_train_tensor
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Set up loss functions with class weights
imm_criterion = nn.CrossEntropyLoss(weight=imm_weights)
mid_criterion = nn.CrossEntropyLoss(weight=mid_weights)
long_criterion = nn.CrossEntropyLoss(weight=long_weights)
value_criterion = nn.MSELoss()
# Set up optimizer (separate from agent's optimizer)
pretrain_optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=0.0002)
pretrain_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
pretrain_optimizer, mode='min', factor=0.5, patience=3, verbose=True
)
# Set model to training mode
agent.policy_net.train()
# Training loop
best_val_loss = float('inf')
patience = 5
patience_counter = 0
for epoch in range(n_epochs):
# Training phase
train_loss = 0.0
imm_correct, mid_correct, long_correct = 0, 0, 0
total = 0
for X_batch, y_imm_batch, y_mid_batch, y_long_batch, y_val_batch in train_loader:
# Zero gradients
pretrain_optimizer.zero_grad()
# Forward pass - we only need the price predictions
with torch.cuda.amp.autocast() if agent.use_mixed_precision else contextlib.nullcontext():
_, _, price_preds = agent.policy_net(X_batch)
# Calculate losses for each prediction head
imm_loss = imm_criterion(price_preds['immediate'], y_imm_batch)
mid_loss = mid_criterion(price_preds['midterm'], y_mid_batch)
long_loss = long_criterion(price_preds['longterm'], y_long_batch)
value_loss = value_criterion(price_preds['values'], y_val_batch)
# Combined loss (weighted by importance)
total_loss = imm_loss + 0.7 * mid_loss + 0.5 * long_loss + 0.3 * value_loss
# Backward pass and optimize
if agent.use_mixed_precision:
agent.scaler.scale(total_loss).backward()
agent.scaler.unscale_(pretrain_optimizer)
torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
agent.scaler.step(pretrain_optimizer)
agent.scaler.update()
else:
total_loss.backward()
torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
pretrain_optimizer.step()
# Accumulate metrics
train_loss += total_loss.item()
total += X_batch.size(0)
# Calculate accuracy
_, imm_pred = torch.max(price_preds['immediate'], 1)
_, mid_pred = torch.max(price_preds['midterm'], 1)
_, long_pred = torch.max(price_preds['longterm'], 1)
imm_correct += (imm_pred == y_imm_batch).sum().item()
mid_correct += (mid_pred == y_mid_batch).sum().item()
long_correct += (long_pred == y_long_batch).sum().item()
# Calculate epoch metrics
train_loss /= len(train_loader)
imm_acc = imm_correct / total
mid_acc = mid_correct / total
long_acc = long_correct / total
# Validation phase
agent.policy_net.eval()
val_loss = 0.0
imm_val_correct, mid_val_correct, long_val_correct = 0, 0, 0
with torch.no_grad():
# Forward pass on validation data
_, _, val_price_preds = agent.policy_net(X_val_tensor)
# Calculate validation losses
val_imm_loss = imm_criterion(val_price_preds['immediate'], y_imm_val_tensor)
val_mid_loss = mid_criterion(val_price_preds['midterm'], y_mid_val_tensor)
val_long_loss = long_criterion(val_price_preds['longterm'], y_long_val_tensor)
val_value_loss = value_criterion(val_price_preds['values'], y_val_val_tensor)
val_total_loss = val_imm_loss + 0.7 * val_mid_loss + 0.5 * val_long_loss + 0.3 * val_value_loss
val_loss = val_total_loss.item()
# Calculate validation accuracy
_, imm_val_pred = torch.max(val_price_preds['immediate'], 1)
_, mid_val_pred = torch.max(val_price_preds['midterm'], 1)
_, long_val_pred = torch.max(val_price_preds['longterm'], 1)
imm_val_correct = (imm_val_pred == y_imm_val_tensor).sum().item()
mid_val_correct = (mid_val_pred == y_mid_val_tensor).sum().item()
long_val_correct = (long_val_pred == y_long_val_tensor).sum().item()
imm_val_acc = imm_val_correct / len(X_val_tensor)
mid_val_acc = mid_val_correct / len(X_val_tensor)
long_val_acc = long_val_correct / len(X_val_tensor)
# Learning rate scheduling
pretrain_scheduler.step(val_loss)
# Early stopping check
if val_loss < best_val_loss:
best_val_loss = val_loss
patience_counter = 0
                # Keep the best weights so far in the target network
                agent.target_net.load_state_dict(agent.policy_net.state_dict())
                logger.info(f"New best validation loss: {val_loss:.4f} (policy weights copied to target network)")
else:
patience_counter += 1
if patience_counter >= patience:
logger.info(f"Early stopping triggered after {epoch+1} epochs")
break
# Log progress
logger.info(f"Epoch {epoch+1}/{n_epochs}: "
f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
f"Imm Acc: {imm_acc:.4f}/{imm_val_acc:.4f}, "
f"Mid Acc: {mid_acc:.4f}/{mid_val_acc:.4f}, "
f"Long Acc: {long_acc:.4f}/{long_val_acc:.4f}")
# Set model back to training mode for next epoch
agent.policy_net.train()
logger.info("Price prediction pre-training complete")
return agent
except Exception as e:
logger.error(f"Error during price prediction pre-training: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return agent
if __name__ == "__main__":
train_rl()
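    # Example: a short smoke-test run (parameter values here are illustrative):
    # train_rl(num_episodes=50, max_steps=500, pretrain_price_prediction_enabled=False)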