misc
Binary file not shown.
6
NN/environments/__init__.py
Normal file
@@ -0,0 +1,6 @@
# Trading environments for reinforcement learning
# This module contains environments for training trading agents

from NN.environments.trading_env import TradingEnvironment

__all__ = ['TradingEnvironment']
484
NN/environments/trading_env.py
Normal file
@@ -0,0 +1,484 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Dict, Tuple, List, Any, Optional
|
||||
import logging
|
||||
import gym
|
||||
from gym import spaces
|
||||
import random
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TradingEnvironment(gym.Env):
|
||||
"""
|
||||
Trading environment implementing gym interface for reinforcement learning
|
||||
|
||||
Actions:
|
||||
- 0: Buy
|
||||
- 1: Sell
|
||||
- 2: Hold
|
||||
|
||||
State:
|
||||
- OHLCV data from multiple timeframes
|
||||
- Technical indicators
|
||||
- Position data
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_interface,
|
||||
initial_balance: float = 10000.0,
|
||||
transaction_fee: float = 0.0002,
|
||||
window_size: int = 20,
|
||||
max_position: float = 1.0,
|
||||
reward_scaling: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Initialize the trading environment.
|
||||
|
||||
Args:
|
||||
data_interface: DataInterface instance to get market data
|
||||
initial_balance: Initial balance in the base currency
|
||||
transaction_fee: Fee for each transaction as a fraction of trade value
|
||||
window_size: Number of candles in the observation window
|
||||
max_position: Maximum position size as a fraction of balance
|
||||
reward_scaling: Scale factor for rewards
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.data_interface = data_interface
|
||||
self.initial_balance = initial_balance
|
||||
self.transaction_fee = transaction_fee
|
||||
self.window_size = window_size
|
||||
self.max_position = max_position
|
||||
self.reward_scaling = reward_scaling
|
||||
|
||||
# Load data for primary timeframe (assuming the first one is primary)
|
||||
self.timeframe = self.data_interface.timeframes[0]
|
||||
self.reset_data()
|
||||
|
||||
# Define action and observation spaces
|
||||
self.action_space = spaces.Discrete(3) # Buy, Sell, Hold
|
||||
|
||||
# For observation space, we consider multiple timeframes with OHLCV data
|
||||
# and additional features like technical indicators, position info, etc.
|
||||
n_timeframes = len(self.data_interface.timeframes)
|
||||
n_features = 5 # OHLCV data by default
|
||||
|
||||
# Add additional features for position, balance, etc.
|
||||
additional_features = 3 # position, balance, unrealized_pnl
|
||||
|
||||
# Calculate total feature dimension
|
||||
total_features = (n_timeframes * n_features * self.window_size) + additional_features
|
||||
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(total_features,), dtype=np.float32
|
||||
)
|
||||
|
||||
# Use tuple for state_shape that EnhancedCNN expects
|
||||
self.state_shape = (total_features,)
|
||||
|
||||
# Initialize state
|
||||
self.reset()
|
||||
|
||||
def reset_data(self):
|
||||
"""Reset data and generate a new set of price data for training"""
|
||||
# Get data for each timeframe
|
||||
self.data = {}
|
||||
for tf in self.data_interface.timeframes:
|
||||
df = self.data_interface.dataframes[tf]
|
||||
if df is not None and not df.empty:
|
||||
self.data[tf] = df
|
||||
|
||||
if not self.data:
|
||||
raise ValueError("No data available for training")
|
||||
|
||||
# Use the primary timeframe for step count
|
||||
self.prices = self.data[self.timeframe]['close'].values
|
||||
self.timestamps = self.data[self.timeframe].index.values
|
||||
self.max_steps = len(self.prices) - self.window_size - 1
|
||||
|
||||
def reset(self):
|
||||
"""Reset the environment to initial state"""
|
||||
# Reset trading variables
|
||||
self.balance = self.initial_balance
|
||||
self.position = 0.0 # No position initially
|
||||
self.entry_price = 0.0
|
||||
self.total_pnl = 0.0
|
||||
self.trades = []
|
||||
self.rewards = []
|
||||
|
||||
# Reset step counter
|
||||
self.current_step = self.window_size
|
||||
|
||||
# Get initial observation
|
||||
observation = self._get_observation()
|
||||
|
||||
return observation
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
Take a step in the environment.
|
||||
|
||||
Args:
|
||||
action: Action to take (0: Buy, 1: Sell, 2: Hold)
|
||||
|
||||
Returns:
|
||||
tuple: (observation, reward, done, info)
|
||||
"""
|
||||
# Get current state before taking action
|
||||
prev_balance = self.balance
|
||||
prev_position = self.position
|
||||
prev_price = self.prices[self.current_step]
|
||||
|
||||
# Take action
|
||||
info = {}
|
||||
reward = 0
|
||||
last_position_info = None
|
||||
|
||||
# Get current price
|
||||
current_price = self.prices[self.current_step]
|
||||
next_price = self.prices[self.current_step + 1] if self.current_step + 1 < len(self.prices) else current_price
|
||||
|
||||
# Process the action
|
||||
if action == 0: # Buy
|
||||
if self.position <= 0: # Only buy if not already long
|
||||
# Close any existing short position
|
||||
if self.position < 0:
|
||||
close_pnl, last_position_info = self._close_position(current_price)
|
||||
reward += close_pnl * self.reward_scaling
|
||||
|
||||
# Open new long position
|
||||
self._open_position(1.0 * self.max_position, current_price)
|
||||
logger.info(f"Buy at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")
|
||||
|
||||
elif action == 1: # Sell
|
||||
if self.position >= 0: # Only sell if not already short
|
||||
# Close any existing long position
|
||||
if self.position > 0:
|
||||
close_pnl, last_position_info = self._close_position(current_price)
|
||||
reward += close_pnl * self.reward_scaling
|
||||
|
||||
# Open new short position
|
||||
self._open_position(-1.0 * self.max_position, current_price)
|
||||
logger.info(f"Sell at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")
|
||||
|
||||
elif action == 2: # Hold
|
||||
# No action, but still calculate unrealized PnL for reward
|
||||
pass
|
||||
|
||||
# Calculate unrealized PnL and add to reward
|
||||
if self.position != 0:
|
||||
unrealized_pnl = self._calculate_unrealized_pnl(next_price)
|
||||
reward += unrealized_pnl * self.reward_scaling * 0.1 # Scale down unrealized PnL
|
||||
|
||||
# Apply penalties for holding a position
|
||||
if self.position != 0:
|
||||
# Small holding fee/interest
|
||||
holding_penalty = abs(self.position) * 0.0001 # 0.01% per step
|
||||
reward -= holding_penalty * self.reward_scaling
|
||||
|
||||
# Move to next step
|
||||
self.current_step += 1
|
||||
|
||||
# Get new observation
|
||||
observation = self._get_observation()
|
||||
|
||||
# Check if episode is done
|
||||
done = self.current_step >= len(self.prices) - 1
|
||||
|
||||
# If done, close any remaining positions
|
||||
if done and self.position != 0:
|
||||
final_pnl, last_position_info = self._close_position(current_price)
|
||||
reward += final_pnl * self.reward_scaling
|
||||
info['final_pnl'] = final_pnl
|
||||
info['final_balance'] = self.balance
|
||||
logger.info(f"Episode ended. Final balance: {self.balance:.4f}, Return: {(self.balance/self.initial_balance-1)*100:.2f}%")
|
||||
|
||||
# Track trade result if position changed or position was closed
|
||||
if prev_position != self.position or last_position_info is not None:
|
||||
# Calculate realized PnL if position was closed
|
||||
realized_pnl = 0
|
||||
position_info = {}
|
||||
|
||||
if last_position_info is not None:
|
||||
# Use the position information from closing
|
||||
realized_pnl = last_position_info['pnl']
|
||||
position_info = last_position_info
|
||||
else:
|
||||
# Calculate manually based on balance change
|
||||
realized_pnl = self.balance - prev_balance if prev_position != 0 else 0
|
||||
|
||||
# Record detailed trade information
|
||||
trade_result = {
|
||||
'step': self.current_step,
|
||||
'timestamp': self.timestamps[self.current_step],
|
||||
'action': action,
|
||||
'action_name': ['BUY', 'SELL', 'HOLD'][action],
|
||||
'price': current_price,
|
||||
'position_changed': prev_position != self.position,
|
||||
'prev_position': prev_position,
|
||||
'new_position': self.position,
|
||||
'position_size': abs(self.position) if self.position != 0 else abs(prev_position),
|
||||
'entry_price': position_info.get('entry_price', self.entry_price),
|
||||
'exit_price': position_info.get('exit_price', current_price),
|
||||
'realized_pnl': realized_pnl,
|
||||
'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0,
|
||||
'pnl': realized_pnl, # Total PnL (realized for this step)
|
||||
'balance_before': prev_balance,
|
||||
'balance_after': self.balance,
|
||||
'trade_fee': position_info.get('fee', abs(self.position - prev_position) * current_price * self.transaction_fee)
|
||||
}
|
||||
info['trade_result'] = trade_result
|
||||
self.trades.append(trade_result)
|
||||
|
||||
# Log trade details
|
||||
logger.info(f"Trade executed - Action: {['BUY', 'SELL', 'HOLD'][action]}, "
|
||||
f"Price: {current_price:.4f}, PnL: {realized_pnl:.4f}, "
|
||||
f"Balance: {self.balance:.4f}")
|
||||
|
||||
# Store reward
|
||||
self.rewards.append(reward)
|
||||
|
||||
# Update info dict with current state
|
||||
info.update({
|
||||
'step': self.current_step,
|
||||
'price': current_price,
|
||||
'prev_price': prev_price,
|
||||
'price_change': (current_price - prev_price) / prev_price if prev_price != 0 else 0,
|
||||
'balance': self.balance,
|
||||
'position': self.position,
|
||||
'entry_price': self.entry_price,
|
||||
'unrealized_pnl': self._calculate_unrealized_pnl(current_price) if self.position != 0 else 0.0,
|
||||
'total_trades': len(self.trades),
|
||||
'total_pnl': self.total_pnl,
|
||||
'return_pct': (self.balance/self.initial_balance-1)*100
|
||||
})
|
||||
|
||||
return observation, reward, done, info
|
||||
|
||||
def _calculate_unrealized_pnl(self, current_price):
|
||||
"""Calculate unrealized PnL for current position"""
|
||||
if self.position == 0 or self.entry_price == 0:
|
||||
return 0.0
|
||||
|
||||
if self.position > 0: # Long position
|
||||
return self.position * (current_price / self.entry_price - 1.0)
|
||||
else: # Short position
|
||||
return -self.position * (1.0 - current_price / self.entry_price)
|
||||
|
||||
def _open_position(self, position_size, price):
|
||||
"""Open a new position"""
|
||||
self.position = position_size
|
||||
self.entry_price = price
|
||||
|
||||
def _close_position(self, price):
|
||||
"""Close the current position and return PnL"""
|
||||
pnl = self._calculate_unrealized_pnl(price)
|
||||
|
||||
# Apply transaction fee
|
||||
fee = abs(self.position) * price * self.transaction_fee
|
||||
pnl -= fee
|
||||
|
||||
# Update balance
|
||||
self.balance += pnl
|
||||
self.total_pnl += pnl
|
||||
|
||||
# Store position details before resetting
|
||||
last_position = {
|
||||
'position_size': self.position,
|
||||
'entry_price': self.entry_price,
|
||||
'exit_price': price,
|
||||
'pnl': pnl,
|
||||
'fee': fee
|
||||
}
|
||||
|
||||
# Reset position
|
||||
self.position = 0.0
|
||||
self.entry_price = 0.0
|
||||
|
||||
# Log position closure
|
||||
logger.info(f"Closed position - Size: {last_position['position_size']:.4f}, "
|
||||
f"Entry: {last_position['entry_price']:.4f}, Exit: {last_position['exit_price']:.4f}, "
|
||||
f"PnL: {last_position['pnl']:.4f}, Fee: {last_position['fee']:.4f}")
|
||||
|
||||
return pnl, last_position
|
||||
|
||||
def _get_observation(self):
|
||||
"""
|
||||
Get the current observation.
|
||||
|
||||
Returns:
|
||||
np.array: The observation vector
|
||||
"""
|
||||
observations = []
|
||||
|
||||
# Get data from each timeframe
|
||||
for tf in self.data_interface.timeframes:
|
||||
if tf in self.data:
|
||||
# Get the window of data for this timeframe
|
||||
df = self.data[tf]
|
||||
start_idx = self._align_timeframe_index(tf)
|
||||
|
||||
if start_idx is not None and start_idx >= 0 and start_idx + self.window_size <= len(df):
|
||||
window = df.iloc[start_idx:start_idx + self.window_size]
|
||||
|
||||
# Extract OHLCV data
|
||||
ohlcv = window[['open', 'high', 'low', 'close', 'volume']].values
|
||||
|
||||
# Normalize OHLCV data
|
||||
last_close = ohlcv[-1, 3] # Last close price
|
||||
ohlcv_normalized = np.zeros_like(ohlcv)
|
||||
ohlcv_normalized[:, 0] = ohlcv[:, 0] / last_close - 1.0 # open
|
||||
ohlcv_normalized[:, 1] = ohlcv[:, 1] / last_close - 1.0 # high
|
||||
ohlcv_normalized[:, 2] = ohlcv[:, 2] / last_close - 1.0 # low
|
||||
ohlcv_normalized[:, 3] = ohlcv[:, 3] / last_close - 1.0 # close
|
||||
|
||||
# Normalize volume (relative to moving average of volume)
|
||||
if 'volume' in window.columns:
|
||||
volume_ma = ohlcv[:, 4].mean()
|
||||
if volume_ma > 0:
|
||||
ohlcv_normalized[:, 4] = ohlcv[:, 4] / volume_ma - 1.0
|
||||
else:
|
||||
ohlcv_normalized[:, 4] = 0.0
|
||||
else:
|
||||
ohlcv_normalized[:, 4] = 0.0
|
||||
|
||||
# Flatten and add to observations
|
||||
observations.append(ohlcv_normalized.flatten())
|
||||
else:
|
||||
# Fill with zeros if not enough data
|
||||
observations.append(np.zeros(self.window_size * 5))
|
||||
|
||||
# Add position and balance information
|
||||
current_price = self.prices[self.current_step]
|
||||
position_info = np.array([
|
||||
self.position / self.max_position, # Normalized position (-1 to 1)
|
||||
self.balance / self.initial_balance - 1.0, # Normalized balance change
|
||||
self._calculate_unrealized_pnl(current_price) # Unrealized PnL
|
||||
])
|
||||
|
||||
observations.append(position_info)
|
||||
|
||||
# Concatenate all observations
|
||||
observation = np.concatenate(observations)
|
||||
return observation
|
||||
|
||||
def _align_timeframe_index(self, timeframe):
|
||||
"""
|
||||
Align the index of a higher timeframe with the current step in the primary timeframe.
|
||||
|
||||
Args:
|
||||
timeframe: The timeframe to align
|
||||
|
||||
Returns:
|
||||
int: The starting index in the higher timeframe
|
||||
"""
|
||||
if timeframe == self.timeframe:
|
||||
return self.current_step - self.window_size
|
||||
|
||||
# Get timestamps for current primary timeframe step
|
||||
primary_ts = self.timestamps[self.current_step]
|
||||
|
||||
# Find closest index in the higher timeframe
|
||||
higher_ts = self.data[timeframe].index.values
|
||||
idx = np.searchsorted(higher_ts, primary_ts)
|
||||
|
||||
# Adjust to get the starting index
|
||||
start_idx = max(0, idx - self.window_size)
|
||||
return start_idx
|
||||
|
||||
def get_last_positions(self, n=5):
|
||||
"""
|
||||
Get detailed information about the last n positions.
|
||||
|
||||
Args:
|
||||
n: Number of last positions to return
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing position details
|
||||
"""
|
||||
if not self.trades:
|
||||
return []
|
||||
|
||||
# Filter trades to only include those that closed positions
|
||||
position_trades = [t for t in self.trades if t.get('realized_pnl', 0) != 0 or (t.get('prev_position', 0) != 0 and t.get('new_position', 0) == 0)]
|
||||
|
||||
positions = []
|
||||
last_n_trades = position_trades[-n:] if len(position_trades) >= n else position_trades
|
||||
|
||||
for trade in last_n_trades:
|
||||
position_info = {
|
||||
'timestamp': trade.get('timestamp', self.timestamps[trade['step']]),
|
||||
'action': trade.get('action_name', ['BUY', 'SELL', 'HOLD'][trade['action']]),
|
||||
'entry_price': trade.get('entry_price', 0.0),
|
||||
'exit_price': trade.get('exit_price', trade['price']),
|
||||
'position_size': trade.get('position_size', self.max_position),
|
||||
'realized_pnl': trade.get('realized_pnl', 0.0),
|
||||
'fee': trade.get('trade_fee', 0.0),
|
||||
'pnl': trade.get('pnl', 0.0),
|
||||
'pnl_percentage': (trade.get('pnl', 0.0) / self.initial_balance) * 100,
|
||||
'balance_before': trade.get('balance_before', 0.0),
|
||||
'balance_after': trade.get('balance_after', 0.0),
|
||||
'duration': trade.get('duration', 'N/A')
|
||||
}
|
||||
positions.append(position_info)
|
||||
|
||||
return positions
|
||||
|
||||
def render(self, mode='human'):
|
||||
"""Render the environment"""
|
||||
current_step = self.current_step
|
||||
current_price = self.prices[current_step]
|
||||
|
||||
# Display basic information
|
||||
print(f"\nTrading Environment Status:")
|
||||
print(f"============================")
|
||||
print(f"Step: {current_step}/{len(self.prices)-1}")
|
||||
print(f"Current Price: {current_price:.4f}")
|
||||
print(f"Current Balance: {self.balance:.4f}")
|
||||
print(f"Current Position: {self.position:.4f}")
|
||||
|
||||
if self.position != 0:
|
||||
unrealized_pnl = self._calculate_unrealized_pnl(current_price)
|
||||
print(f"Entry Price: {self.entry_price:.4f}")
|
||||
print(f"Unrealized PnL: {unrealized_pnl:.4f} ({unrealized_pnl/self.balance*100:.2f}%)")
|
||||
|
||||
print(f"Total PnL: {self.total_pnl:.4f} ({self.total_pnl/self.initial_balance*100:.2f}%)")
|
||||
print(f"Total Trades: {len(self.trades)}")
|
||||
|
||||
if len(self.trades) > 0:
|
||||
win_trades = [t for t in self.trades if t.get('realized_pnl', 0) > 0]
|
||||
win_count = len(win_trades)
|
||||
# Count trades that closed positions (not just changed them)
|
||||
closed_positions = [t for t in self.trades if t.get('realized_pnl', 0) != 0]
|
||||
closed_count = len(closed_positions)
|
||||
win_rate = win_count / closed_count if closed_count > 0 else 0
|
||||
print(f"Positions Closed: {closed_count}")
|
||||
print(f"Winning Positions: {win_count}")
|
||||
print(f"Win Rate: {win_rate:.2f}")
|
||||
|
||||
# Display last 5 positions
|
||||
print("\nLast 5 Positions:")
|
||||
print("================")
|
||||
last_positions = self.get_last_positions(5)
|
||||
|
||||
if not last_positions:
|
||||
print("No closed positions yet.")
|
||||
|
||||
for pos in last_positions:
|
||||
print(f"Time: {pos['timestamp']}")
|
||||
print(f"Action: {pos['action']}")
|
||||
print(f"Entry: {pos['entry_price']:.4f}, Exit: {pos['exit_price']:.4f}")
|
||||
print(f"Size: {pos['position_size']:.4f}")
|
||||
print(f"PnL: {pos['realized_pnl']:.4f} ({pos['pnl_percentage']:.2f}%)")
|
||||
print(f"Fee: {pos['fee']:.4f}")
|
||||
print(f"Balance: {pos['balance_before']:.4f} -> {pos['balance_after']:.4f}")
|
||||
print("----------------")
|
||||
|
||||
return
|
||||
|
||||
def close(self):
|
||||
"""Close the environment"""
|
||||
pass
|
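For context, a minimal rollout sketch for the new environment (not part of the commit). It assumes a DataInterface-like object exposing `timeframes` and per-timeframe `dataframes` with OHLCV columns; the dummy class below is purely illustrative:

import numpy as np
import pandas as pd

from NN.environments.trading_env import TradingEnvironment

class DummyDataInterface:
    """Hypothetical stand-in for the repo's DataInterface (illustrative only)."""
    def __init__(self, n=500):
        idx = pd.date_range("2024-01-01", periods=n, freq="1min")
        close = 100 + np.cumsum(np.random.randn(n) * 0.1)
        self.timeframes = ["1m"]
        self.dataframes = {"1m": pd.DataFrame({
            "open": close, "high": close * 1.001, "low": close * 0.999,
            "close": close, "volume": np.random.rand(n) * 10,
        }, index=idx)}

env = TradingEnvironment(DummyDataInterface(), window_size=20)
obs = env.reset()
done = False
while not done:
    action = env.action_space.sample()          # replace with agent.act(obs)
    obs, reward, done, info = env.step(action)
env.render()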
@@ -78,17 +78,25 @@ class CNNPyTorch(nn.Module):
|
||||
window_size, num_features = input_shape
|
||||
self.window_size = window_size
|
||||
|
||||
# Simpler architecture with fewer layers and dropout
|
||||
# Increased complexity
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv1d(num_features, 32, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(32),
|
||||
nn.Conv1d(num_features, 64, kernel_size=3, padding=1), # Increased filters
|
||||
nn.BatchNorm1d(64),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2)
|
||||
)
|
||||
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv1d(32, 64, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(64),
|
||||
nn.Conv1d(64, 128, kernel_size=3, padding=1), # Increased filters
|
||||
nn.BatchNorm1d(128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2)
|
||||
)
|
||||
|
||||
# Added third conv layer
|
||||
self.conv3 = nn.Sequential(
|
||||
nn.Conv1d(128, 128, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2)
|
||||
)
|
||||
@@ -96,12 +104,12 @@ class CNNPyTorch(nn.Module):
|
||||
# Global average pooling to handle variable length sequences
|
||||
self.global_pool = nn.AdaptiveAvgPool1d(1)
|
||||
|
||||
# Fully connected layers
|
||||
# Fully connected layers (updated input size and hidden size)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(64, 32),
|
||||
nn.Linear(128, 64), # Updated input size from conv3, increased hidden size
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(32, output_size)
|
||||
nn.Linear(64, output_size)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -120,10 +128,11 @@ class CNNPyTorch(nn.Module):
|
||||
# Convolutional layers
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x) # Added conv3 pass
|
||||
|
||||
# Global pooling
|
||||
x = self.global_pool(x)
|
||||
x = x.squeeze(-1)
|
||||
x = x.squeeze(-1) # Shape becomes [batch, 128]
|
||||
|
||||
# Fully connected layers
|
||||
action_logits = self.fc(x)
|
||||
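As a quick sanity check on the widened stack in these hunks (a sketch mirroring the sizes shown: Conv1d 64 -> 128 -> 128, AdaptiveAvgPool1d, then Linear 128 -> 64 -> output; the batch, window and feature counts below are assumptions, not repo values):

import torch
import torch.nn as nn

num_features, window_size, output_size = 5, 20, 3   # assumed dimensions
conv = nn.Sequential(
    nn.Conv1d(num_features, 64, kernel_size=3, padding=1), nn.BatchNorm1d(64), nn.ReLU(),
    nn.Conv1d(64, 128, kernel_size=3, padding=1), nn.BatchNorm1d(128), nn.ReLU(),
    nn.Conv1d(128, 128, kernel_size=3, padding=1), nn.BatchNorm1d(128), nn.ReLU(),
)
pool = nn.AdaptiveAvgPool1d(1)
fc = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, output_size))

x = torch.randn(8, num_features, window_size)   # Conv1d expects [batch, channels, seq]
x = pool(conv(x)).squeeze(-1)                    # -> [8, 128], matching the fc input size
print(fc(x).shape)                               # torch.Size([8, 3])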
@@ -216,6 +225,8 @@ class CNNModelPyTorch:
|
||||
self.last_actions = [[] for _ in range(num_pairs)] # Track recent actions per pair
|
||||
|
||||
def train_epoch(self, X_train, y_train, future_prices, batch_size):
    """Train the model for one epoch with focus on short-term pattern recognition"""
    # Add a call to predict_extrema here
    self.predict_extrema(X_train)
|
||||
self.model.train()
|
||||
total_loss = 0
|
||||
@@ -321,7 +332,8 @@ class CNNModelPyTorch:
|
||||
|
||||
return avg_loss, 0, accuracy # Return 0 for price_loss as we're not using it
|
||||
|
||||
def predict(self, X):
|
||||
def predict_extrema(self, X):
|
||||
# Predict local extrema (lows and highs) based on input data
|
||||
"""Make predictions optimized for short-term high-leverage trading signals"""
|
||||
self.model.eval()
|
||||
|
||||
|
@@ -54,6 +54,7 @@ class DQNAgent:
|
||||
self.epsilon = epsilon
|
||||
self.epsilon_min = epsilon_min
|
||||
self.epsilon_decay = epsilon_decay
|
||||
self.epsilon_start = epsilon # Store initial epsilon value for resets/bumps
|
||||
self.buffer_size = buffer_size
|
||||
self.batch_size = batch_size
|
||||
self.target_update = target_update
|
||||
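The new `epsilon_start` field only stores the initial value; a small sketch of how it can be used (the decay rule matches the end of the replay code further down, while `bump_exploration` is a hypothetical helper, not something added by this commit):

def decay_epsilon(agent):
    # Same rule applied at the end of the replay step below
    agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)

def bump_exploration(agent, fraction=0.5):
    # Hypothetical reset: raise epsilon part-way back toward its starting value
    agent.epsilon = max(agent.epsilon, fraction * agent.epsilon_start)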
@@ -127,6 +128,28 @@ class DQNAgent:
|
||||
self.best_reward = -float('inf')
|
||||
self.no_improvement_count = 0
|
||||
|
||||
# Confidence tracking
|
||||
self.confidence_history = []
|
||||
self.avg_confidence = 0.0
|
||||
self.max_confidence = 0.0
|
||||
self.min_confidence = 1.0
|
||||
|
||||
# Trade action fee and confidence thresholds
|
||||
self.trade_action_fee = 0.0005 # Small fee to discourage unnecessary trading
|
||||
self.minimum_action_confidence = 0.5 # Minimum confidence to consider trading
|
||||
self.recent_actions = [] # Track recent actions to avoid oscillations
|
||||
|
||||
# Violent move detection
|
||||
self.price_history = []
|
||||
self.volatility_window = 20 # Window size for volatility calculation
|
||||
self.volatility_threshold = 0.0015 # Threshold for considering a move "violent"
|
||||
self.post_violent_move = False # Flag for recent violent move
|
||||
self.violent_move_cooldown = 0 # Cooldown after violent move
|
||||
|
||||
# Feature integration
|
||||
self.last_hidden_features = None # Store last extracted features
|
||||
self.feature_history = [] # Store history of features for analysis
|
||||
|
||||
# Check if mixed precision training should be used
|
||||
self.use_mixed_precision = False
|
||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
|
||||
@@ -146,6 +169,7 @@ class DQNAgent:
|
||||
self.timeframes = ["1m", "5m", "15m"][:self.state_dim[0]] # Default timeframes
|
||||
|
||||
logger.info(f"DQN Agent using device: {self.device}")
|
||||
logger.info(f"Trade action fee set to {self.trade_action_fee}, minimum confidence: {self.minimum_action_confidence}")
|
||||
|
||||
def move_models_to_device(self, device=None):
|
||||
"""Move models to the specified device (GPU/CPU)"""
|
||||
@@ -189,8 +213,20 @@ class DQNAgent:
|
||||
current_price = state[-1] # Last feature
|
||||
next_price = next_state[-1]
|
||||
|
||||
# Calculate price change
|
||||
price_change = (next_price - current_price) / current_price
|
||||
# Calculate price change - avoid division by zero
|
||||
if np.isscalar(current_price) and current_price != 0:
|
||||
price_change = (next_price - current_price) / current_price
|
||||
elif isinstance(current_price, np.ndarray):
|
||||
# Handle array case - protect against division by zero
|
||||
with np.errstate(divide='ignore', invalid='ignore'):
|
||||
price_change = (next_price - current_price) / current_price
|
||||
# Replace infinities and NaNs with zeros
|
||||
if isinstance(price_change, np.ndarray):
|
||||
price_change = np.nan_to_num(price_change, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
else:
|
||||
price_change = 0.0 if np.isnan(price_change) or np.isinf(price_change) else price_change
|
||||
else:
|
||||
price_change = 0.0
|
||||
|
||||
# Check if this is a significant price movement
|
||||
if abs(price_change) > 0.002: # Significant price change
|
||||
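The scalar/array branching above amounts to a safe relative-change helper; a sketch (the function name is illustrative, not part of the agent):

import numpy as np

def safe_price_change(current_price, next_price):
    """Relative change that never returns inf/NaN, for scalar or array inputs."""
    if np.isscalar(current_price):
        if current_price == 0 or np.isnan(current_price):
            return 0.0
        change = (next_price - current_price) / current_price
        return 0.0 if np.isnan(change) or np.isinf(change) else change
    with np.errstate(divide='ignore', invalid='ignore'):
        change = (next_price - current_price) / np.asarray(current_price, dtype=float)
    return np.nan_to_num(change, nan=0.0, posinf=0.0, neginf=0.0)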
@@ -264,9 +300,17 @@ class DQNAgent:
|
||||
|
||||
# Get predictions using the policy network
|
||||
self.policy_net.eval() # Set to evaluation mode for inference
|
||||
action_probs, extrema_pred, price_predictions = self.policy_net(state_tensor)
|
||||
action_probs, extrema_pred, price_predictions, hidden_features = self.policy_net(state_tensor)
|
||||
self.policy_net.train() # Back to training mode
|
||||
|
||||
# Store hidden features for integration
|
||||
self.last_hidden_features = hidden_features.cpu().numpy()
|
||||
|
||||
# Track feature history (limited size)
|
||||
self.feature_history.append(hidden_features.cpu().numpy())
|
||||
if len(self.feature_history) > 100:
|
||||
self.feature_history = self.feature_history[-100:]
|
||||
|
||||
# Get the predicted extrema class (0=bottom, 1=top, 2=neither)
|
||||
extrema_class = extrema_pred.argmax(dim=1).item()
|
||||
extrema_confidence = torch.softmax(extrema_pred, dim=1)[0, extrema_class].item()
|
||||
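The hidden-feature history above (and the confidence history in the next hunk) are plain lists trimmed by slicing to the last 100 entries; an equivalent bounded-buffer sketch (an alternative pattern, not what this commit does) uses collections.deque:

from collections import deque
import numpy as np

# maxlen makes append() drop the oldest entry automatically
feature_history = deque(maxlen=100)
confidence_history = deque(maxlen=100)

feature_history.append(np.zeros((1, 128)))   # stand-in for hidden_features.cpu().numpy()
confidence_history.append(0.73)              # stand-in for action_confidence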
@@ -336,17 +380,120 @@ class DQNAgent:
|
||||
# Get the action with highest Q-value
|
||||
action = action_probs.argmax().item()
|
||||
|
||||
# Calculate overall confidence in the action
|
||||
q_values_softmax = F.softmax(action_probs, dim=1)[0]
|
||||
action_confidence = q_values_softmax[action].item()
|
||||
|
||||
# Track confidence metrics
|
||||
self.confidence_history.append(action_confidence)
|
||||
if len(self.confidence_history) > 100:
|
||||
self.confidence_history = self.confidence_history[-100:]
|
||||
|
||||
# Update confidence metrics
|
||||
self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
|
||||
self.max_confidence = max(self.max_confidence, action_confidence)
|
||||
self.min_confidence = min(self.min_confidence, action_confidence)
|
||||
|
||||
# Log average confidence occasionally
|
||||
if random.random() < 0.01: # 1% of the time
|
||||
logger.info(f"Confidence metrics - Current: {action_confidence:.4f}, Avg: {self.avg_confidence:.4f}, " +
|
||||
f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")
|
||||
|
||||
# Track price for violent move detection
|
||||
try:
|
||||
# Extract current price from state (assuming it's in the last position)
|
||||
if len(state.shape) > 1: # For 2D state
|
||||
current_price = state[-1, -1]
|
||||
else: # For 1D state
|
||||
current_price = state[-1]
|
||||
|
||||
self.price_history.append(current_price)
|
||||
if len(self.price_history) > self.volatility_window:
|
||||
self.price_history = self.price_history[-self.volatility_window:]
|
||||
|
||||
# Detect violent price moves if we have enough price history
|
||||
if len(self.price_history) >= 5:
|
||||
# Calculate short-term volatility
|
||||
recent_prices = self.price_history[-5:]
|
||||
|
||||
# Make sure we're working with scalar values, not arrays
|
||||
if isinstance(recent_prices[0], np.ndarray):
|
||||
# If prices are arrays, extract the last value (current price)
|
||||
recent_prices = [p[-1] if isinstance(p, np.ndarray) and p.size > 0 else p for p in recent_prices]
|
||||
|
||||
# Calculate price changes with protection against division by zero
|
||||
price_changes = []
|
||||
for i in range(1, len(recent_prices)):
|
||||
if recent_prices[i-1] != 0 and not np.isnan(recent_prices[i-1]) and not np.isnan(recent_prices[i]):
|
||||
change = (recent_prices[i] - recent_prices[i-1]) / recent_prices[i-1]
|
||||
price_changes.append(change)
|
||||
else:
|
||||
price_changes.append(0.0)
|
||||
|
||||
# Calculate volatility as sum of absolute price changes
|
||||
volatility = sum([abs(change) for change in price_changes])
|
||||
|
||||
# Check if we've had a violent move
|
||||
if volatility > self.volatility_threshold:
|
||||
logger.info(f"Violent price move detected! Volatility: {volatility:.6f}")
|
||||
self.post_violent_move = True
|
||||
self.violent_move_cooldown = 10 # Set cooldown period
|
||||
|
||||
# Handle post-violent move period
|
||||
if self.post_violent_move:
|
||||
if self.violent_move_cooldown > 0:
|
||||
self.violent_move_cooldown -= 1
|
||||
# Increase confidence threshold temporarily after violent moves
|
||||
effective_threshold = self.minimum_action_confidence * 1.1
|
||||
logger.info(f"Post-violent move period: {self.violent_move_cooldown} steps remaining. " +
|
||||
f"Using higher confidence threshold: {effective_threshold:.4f}")
|
||||
else:
|
||||
self.post_violent_move = False
|
||||
logger.info("Post-violent move period ended")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in violent move detection: {str(e)}")
|
||||
|
||||
# Apply trade action fee to buy/sell actions but not to hold
|
||||
# This creates a threshold that must be exceeded to justify a trade
|
||||
action_values = action_probs.clone()
|
||||
|
||||
# If BUY or SELL, apply fee by reducing the Q-value
|
||||
if action == 0 or action == 1: # BUY or SELL
|
||||
# Check if confidence is above minimum threshold
|
||||
effective_threshold = self.minimum_action_confidence
|
||||
if self.post_violent_move:
|
||||
effective_threshold *= 1.1 # Higher threshold after violent moves
|
||||
|
||||
if action_confidence < effective_threshold:
|
||||
# If confidence is below threshold, force HOLD action
|
||||
logger.info(f"Action {action} confidence {action_confidence:.4f} below threshold {effective_threshold}, forcing HOLD")
|
||||
action = 2 # HOLD
|
||||
else:
|
||||
# Apply trade action fee to ensure we only trade when there's clear benefit
|
||||
fee_adjusted_action_values = action_values.clone()
|
||||
fee_adjusted_action_values[0, 0] -= self.trade_action_fee # Reduce BUY value
|
||||
fee_adjusted_action_values[0, 1] -= self.trade_action_fee # Reduce SELL value
|
||||
# Hold value remains unchanged
|
||||
|
||||
# Re-determine the action based on fee-adjusted values
|
||||
fee_adjusted_action = fee_adjusted_action_values.argmax().item()
|
||||
|
||||
# If the fee changes our decision, log this
|
||||
if fee_adjusted_action != action:
|
||||
logger.info(f"Trade action fee changed decision from {action} to {fee_adjusted_action}")
|
||||
action = fee_adjusted_action
|
||||
|
||||
# Adjust action based on extrema and price predictions
|
||||
# Prioritize short-term movement for trading decisions
|
||||
if immediate_conf > 0.8: # Only adjust for strong signals
|
||||
if immediate_direction == 2: # UP prediction
|
||||
# Bias toward BUY for strong up predictions
|
||||
if action != 0 and random.random() < 0.3 * immediate_conf:
|
||||
if action != 0 and action != 2 and random.random() < 0.3 * immediate_conf:
|
||||
logger.info(f"Adjusting action to BUY based on immediate UP prediction")
|
||||
action = 0 # BUY
|
||||
elif immediate_direction == 0: # DOWN prediction
|
||||
# Bias toward SELL for strong down predictions
|
||||
if action != 1 and random.random() < 0.3 * immediate_conf:
|
||||
if action != 1 and action != 2 and random.random() < 0.3 * immediate_conf:
|
||||
logger.info(f"Adjusting action to SELL based on immediate DOWN prediction")
|
||||
action = 1 # SELL
|
||||
|
||||
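The violent-move check above can be read as a standalone predicate; a sketch using the same 5-price window and 0.0015 volatility threshold (the helper name is illustrative):

import numpy as np

def is_violent_move(price_history, threshold=0.0015, lookback=5):
    """True when the sum of absolute step returns over the last `lookback`
    prices exceeds `threshold`, mirroring the logic in act() above."""
    recent = [float(np.asarray(p).ravel()[-1]) for p in price_history[-lookback:]]
    if len(recent) < 2:
        return False
    changes = [
        (b - a) / a if a != 0 and not np.isnan(a) and not np.isnan(b) else 0.0
        for a, b in zip(recent[:-1], recent[1:])
    ]
    return sum(abs(c) for c in changes) > threshold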
@@ -354,333 +501,217 @@ class DQNAgent:
|
||||
if extrema_confidence > 0.8: # Only adjust for strong signals
|
||||
if extrema_class == 0: # Bottom detected
|
||||
# Bias toward BUY at bottoms
|
||||
if action != 0 and random.random() < 0.3 * extrema_confidence:
|
||||
if action != 0 and action != 2 and random.random() < 0.3 * extrema_confidence:
|
||||
logger.info(f"Adjusting action to BUY based on bottom detection")
|
||||
action = 0 # BUY
|
||||
elif extrema_class == 1: # Top detected
|
||||
# Bias toward SELL at tops
|
||||
if action != 1 and random.random() < 0.3 * extrema_confidence:
|
||||
if action != 1 and action != 2 and random.random() < 0.3 * extrema_confidence:
|
||||
logger.info(f"Adjusting action to SELL based on top detection")
|
||||
action = 1 # SELL
|
||||
|
||||
# Finally, avoid action oscillation by checking recent history
|
||||
if len(self.recent_actions) >= 2:
|
||||
last_action = self.recent_actions[-1]
|
||||
if action != last_action and action != 2 and last_action != 2:
|
||||
# We're switching between BUY and SELL too quickly
|
||||
# Only allow this if we have very high confidence
|
||||
if action_confidence < 0.85:
|
||||
logger.info(f"Preventing oscillation from {last_action} to {action}, forcing HOLD")
|
||||
action = 2 # HOLD
|
||||
|
||||
# Update recent actions list
|
||||
self.recent_actions.append(action)
|
||||
if len(self.recent_actions) > 5:
|
||||
self.recent_actions = self.recent_actions[-5:]
|
||||
|
||||
return action
|
||||
|
||||
def replay(self, use_prioritized=True) -> float:
|
||||
"""Experience replay - learn from stored experiences
|
||||
|
||||
Args:
|
||||
use_prioritized: Whether to use prioritized experience replay
|
||||
|
||||
Returns:
|
||||
float: Training loss
|
||||
"""
|
||||
# Check if we have enough samples
|
||||
if len(self.memory) < self.batch_size:
|
||||
def replay(self, experiences=None):
|
||||
"""Train the model using experiences from memory"""
|
||||
|
||||
# Don't train if not in training mode
|
||||
if not self.training:
|
||||
return 0.0
|
||||
|
||||
# Check if mixed precision should be disabled
|
||||
if 'DISABLE_MIXED_PRECISION' in os.environ:
|
||||
self.use_mixed_precision = False
|
||||
# If no experiences provided, sample from memory
|
||||
if experiences is None:
|
||||
# Skip if memory is too small
|
||||
if len(self.memory) < self.batch_size:
|
||||
return 0.0
|
||||
|
||||
# Sample from memory with or without prioritization
|
||||
if use_prioritized and len(self.positive_memory) > self.batch_size // 4:
|
||||
# Use prioritized sampling: mix normal samples with positive reward samples
|
||||
positive_batch_size = min(self.batch_size // 4, len(self.positive_memory))
|
||||
regular_batch_size = self.batch_size - positive_batch_size
|
||||
|
||||
# Get positive examples
|
||||
positive_batch = random.sample(self.positive_memory, positive_batch_size)
|
||||
|
||||
# Get regular examples
|
||||
regular_batch = random.sample(self.memory, regular_batch_size)
|
||||
|
||||
# Combine batches
|
||||
minibatch = positive_batch + regular_batch
|
||||
else:
|
||||
# Use regular uniform sampling
|
||||
minibatch = random.sample(self.memory, self.batch_size)
|
||||
# Sample random mini-batch from memory
|
||||
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
|
||||
experiences = [self.memory[i] for i in indices]
|
||||
|
||||
# Extract batches with proper tensor conversion
|
||||
states = np.vstack([self._normalize_state(x[0]) for x in minibatch])
|
||||
actions = np.array([x[1] for x in minibatch])
|
||||
rewards = np.array([x[2] for x in minibatch])
|
||||
next_states = np.vstack([self._normalize_state(x[3]) for x in minibatch])
|
||||
dones = np.array([x[4] for x in minibatch], dtype=np.float32)
|
||||
|
||||
# Convert to torch tensors and move to device
|
||||
states_tensor = torch.FloatTensor(states).to(self.device)
|
||||
actions_tensor = torch.LongTensor(actions).to(self.device)
|
||||
rewards_tensor = torch.FloatTensor(rewards).to(self.device)
|
||||
next_states_tensor = torch.FloatTensor(next_states).to(self.device)
|
||||
dones_tensor = torch.FloatTensor(dones).to(self.device)
|
||||
|
||||
# First training step with mixed precision if available
|
||||
# Choose appropriate replay method
|
||||
if self.use_mixed_precision:
|
||||
loss = self._replay_mixed_precision(
|
||||
states_tensor, actions_tensor, rewards_tensor,
|
||||
next_states_tensor, dones_tensor
|
||||
)
|
||||
# Convert experiences to tensors for mixed precision
|
||||
states = torch.FloatTensor(np.array([e[0] for e in experiences])).to(self.device)
|
||||
actions = torch.LongTensor(np.array([e[1] for e in experiences])).to(self.device)
|
||||
rewards = torch.FloatTensor(np.array([e[2] for e in experiences])).to(self.device)
|
||||
next_states = torch.FloatTensor(np.array([e[3] for e in experiences])).to(self.device)
|
||||
dones = torch.FloatTensor(np.array([e[4] for e in experiences])).to(self.device)
|
||||
|
||||
# Use mixed precision replay
|
||||
loss = self._replay_mixed_precision(states, actions, rewards, next_states, dones)
|
||||
else:
|
||||
loss = self._replay_standard(
|
||||
states_tensor, actions_tensor, rewards_tensor,
|
||||
next_states_tensor, dones_tensor
|
||||
)
|
||||
# Pass experiences directly to standard replay method
|
||||
loss = self._replay_standard(experiences)
|
||||
|
||||
# Training focus selector - randomly focus on one of the specialized training types
|
||||
training_focus = random.random()
|
||||
|
||||
# Occasionally train specifically on extrema points
|
||||
if training_focus < 0.3 and hasattr(self, 'extrema_memory') and len(self.extrema_memory) >= self.batch_size // 2:
|
||||
# Sample from extrema memory
|
||||
extrema_batch_size = min(self.batch_size // 2, len(self.extrema_memory))
|
||||
extrema_batch = random.sample(self.extrema_memory, extrema_batch_size)
|
||||
|
||||
# Extract batches with proper tensor conversion
|
||||
extrema_states = np.vstack([self._normalize_state(x[0]) for x in extrema_batch])
|
||||
extrema_actions = np.array([x[1] for x in extrema_batch])
|
||||
extrema_rewards = np.array([x[2] for x in extrema_batch])
|
||||
extrema_next_states = np.vstack([self._normalize_state(x[3]) for x in extrema_batch])
|
||||
extrema_dones = np.array([x[4] for x in extrema_batch], dtype=np.float32)
|
||||
|
||||
# Convert to torch tensors and move to device
|
||||
extrema_states_tensor = torch.FloatTensor(extrema_states).to(self.device)
|
||||
extrema_actions_tensor = torch.LongTensor(extrema_actions).to(self.device)
|
||||
extrema_rewards_tensor = torch.FloatTensor(extrema_rewards).to(self.device)
|
||||
extrema_next_states_tensor = torch.FloatTensor(extrema_next_states).to(self.device)
|
||||
extrema_dones_tensor = torch.FloatTensor(extrema_dones).to(self.device)
|
||||
|
||||
# Additional training step focused on extrema points (with smaller learning rate)
|
||||
original_lr = self.optimizer.param_groups[0]['lr']
|
||||
# Temporarily reduce learning rate for fine-tuning on extrema
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = original_lr * 0.5
|
||||
|
||||
# Train on extrema
|
||||
if self.use_mixed_precision:
|
||||
extrema_loss = self._replay_mixed_precision(
|
||||
extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
|
||||
extrema_next_states_tensor, extrema_dones_tensor
|
||||
)
|
||||
else:
|
||||
extrema_loss = self._replay_standard(
|
||||
extrema_states_tensor, extrema_actions_tensor, extrema_rewards_tensor,
|
||||
extrema_next_states_tensor, extrema_dones_tensor
|
||||
)
|
||||
|
||||
# Restore original learning rate
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = original_lr
|
||||
|
||||
logger.info(f"Extra training on extrema points: loss={extrema_loss:.4f}")
|
||||
|
||||
# Average the loss
|
||||
loss = (loss + extrema_loss) / 2
|
||||
|
||||
# Occasionally train specifically on price movement data
|
||||
elif training_focus >= 0.3 and training_focus < 0.6 and hasattr(self, 'price_movement_memory') and len(self.price_movement_memory) >= self.batch_size // 2:
|
||||
# Sample from price movement memory
|
||||
price_batch_size = min(self.batch_size // 2, len(self.price_movement_memory))
|
||||
price_batch = random.sample(self.price_movement_memory, price_batch_size)
|
||||
|
||||
# Extract batches with proper tensor conversion
|
||||
price_states = np.vstack([self._normalize_state(x[0]) for x in price_batch])
|
||||
price_actions = np.array([x[1] for x in price_batch])
|
||||
price_rewards = np.array([x[2] for x in price_batch])
|
||||
price_next_states = np.vstack([self._normalize_state(x[3]) for x in price_batch])
|
||||
price_dones = np.array([x[4] for x in price_batch], dtype=np.float32)
|
||||
|
||||
# Convert to torch tensors and move to device
|
||||
price_states_tensor = torch.FloatTensor(price_states).to(self.device)
|
||||
price_actions_tensor = torch.LongTensor(price_actions).to(self.device)
|
||||
price_rewards_tensor = torch.FloatTensor(price_rewards).to(self.device)
|
||||
price_next_states_tensor = torch.FloatTensor(price_next_states).to(self.device)
|
||||
price_dones_tensor = torch.FloatTensor(price_dones).to(self.device)
|
||||
|
||||
# Additional training step focused on price movements (with smaller learning rate)
|
||||
original_lr = self.optimizer.param_groups[0]['lr']
|
||||
# Temporarily reduce learning rate
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = original_lr * 0.5
|
||||
|
||||
# Train on price movement data
|
||||
if self.use_mixed_precision:
|
||||
price_loss = self._replay_mixed_precision(
|
||||
price_states_tensor, price_actions_tensor, price_rewards_tensor,
|
||||
price_next_states_tensor, price_dones_tensor
|
||||
)
|
||||
else:
|
||||
price_loss = self._replay_standard(
|
||||
price_states_tensor, price_actions_tensor, price_rewards_tensor,
|
||||
price_next_states_tensor, price_dones_tensor
|
||||
)
|
||||
|
||||
# Restore original learning rate
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = original_lr
|
||||
|
||||
logger.info(f"Extra training on price movement data: loss={price_loss:.4f}")
|
||||
|
||||
# Average the loss
|
||||
loss = (loss + price_loss) / 2
|
||||
|
||||
# Store and return loss
|
||||
# Store loss for monitoring
|
||||
self.losses.append(loss)
|
||||
return loss
|
||||
|
||||
def _replay_standard(self, states, actions, rewards, next_states, dones):
|
||||
"""Standard precision training step"""
|
||||
# Zero gradients
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# Get current Q values and extrema predictions
|
||||
current_q_values, current_extrema_pred, current_price_pred = self.policy_net(states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Get next Q values from target network
|
||||
with torch.no_grad():
|
||||
next_q_values, next_extrema_pred, next_price_pred = self.target_net(next_states)
|
||||
next_q_values = next_q_values.max(1)[0]
|
||||
|
||||
# Check for dimension mismatch and fix it
|
||||
if rewards.shape[0] != next_q_values.shape[0]:
|
||||
# Log the shape mismatch for debugging
|
||||
logger.warning(f"Shape mismatch detected in standard replay: rewards {rewards.shape}, next_q_values {next_q_values.shape}")
|
||||
# Use the smaller size to prevent index errors
|
||||
min_size = min(rewards.shape[0], next_q_values.shape[0])
|
||||
rewards = rewards[:min_size]
|
||||
dones = dones[:min_size]
|
||||
next_q_values = next_q_values[:min_size]
|
||||
current_q_values = current_q_values[:min_size]
|
||||
|
||||
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
|
||||
|
||||
# Compute Q-value loss (primary task)
|
||||
q_loss = nn.MSELoss()(current_q_values, target_q_values)
|
||||
|
||||
# Initialize combined loss with Q-value loss
|
||||
loss = q_loss
|
||||
|
||||
# Try to extract price from current and next states
|
||||
try:
|
||||
# Extract price feature from sequence data (if available)
|
||||
if len(states.shape) == 3: # [batch, seq, features]
|
||||
current_prices = states[:, -1, -1] # Last timestep, last feature
|
||||
next_prices = next_states[:, -1, -1]
|
||||
else: # [batch, features]
|
||||
current_prices = states[:, -1] # Last feature
|
||||
next_prices = next_states[:, -1]
|
||||
|
||||
# Compute price changes for different timeframes
|
||||
immediate_changes = (next_prices - current_prices) / current_prices
|
||||
|
||||
# Create price direction labels - simplified for training
|
||||
# 0 = down, 1 = sideways, 2 = up
|
||||
immediate_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1 # Default: sideways
|
||||
midterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
|
||||
longterm_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 1
|
||||
|
||||
# Immediate term direction (1s, 1m)
|
||||
immediate_up = (immediate_changes > 0.0005)
|
||||
immediate_down = (immediate_changes < -0.0005)
|
||||
immediate_labels[immediate_up] = 2 # Up
|
||||
immediate_labels[immediate_down] = 0 # Down
|
||||
|
||||
# For mid and long term, we can only approximate during training
|
||||
# In a real system, we'd need historical data to validate these
|
||||
# Here we'll use the immediate term with increasing thresholds as approximation
|
||||
|
||||
# Mid-term (1h) - use slightly higher threshold
|
||||
midterm_up = (immediate_changes > 0.001)
|
||||
midterm_down = (immediate_changes < -0.001)
|
||||
midterm_labels[midterm_up] = 2 # Up
|
||||
midterm_labels[midterm_down] = 0 # Down
|
||||
|
||||
# Long-term (1d) - use even higher threshold
|
||||
longterm_up = (immediate_changes > 0.002)
|
||||
longterm_down = (immediate_changes < -0.002)
|
||||
longterm_labels[longterm_up] = 2 # Up
|
||||
longterm_labels[longterm_down] = 0 # Down
|
||||
|
||||
# Generate target values for price change regression
|
||||
# For simplicity, we'll use the immediate change and scaled versions for longer timeframes
|
||||
price_value_targets = torch.zeros((min_size, 4), device=self.device)
|
||||
price_value_targets[:, 0] = immediate_changes
|
||||
price_value_targets[:, 1] = immediate_changes * 2.0 # Approximate 1h change
|
||||
price_value_targets[:, 2] = immediate_changes * 4.0 # Approximate 1d change
|
||||
price_value_targets[:, 3] = immediate_changes * 6.0 # Approximate 1w change
|
||||
|
||||
# Calculate loss for price direction prediction (classification)
|
||||
if len(current_price_pred['immediate'].shape) > 1 and current_price_pred['immediate'].shape[0] >= min_size:
|
||||
# Slice predictions to match the adjusted batch size
|
||||
immediate_pred = current_price_pred['immediate'][:min_size]
|
||||
midterm_pred = current_price_pred['midterm'][:min_size]
|
||||
longterm_pred = current_price_pred['longterm'][:min_size]
|
||||
price_values_pred = current_price_pred['values'][:min_size]
|
||||
|
||||
# Compute losses for each task
|
||||
immediate_loss = nn.CrossEntropyLoss()(immediate_pred, immediate_labels)
|
||||
midterm_loss = nn.CrossEntropyLoss()(midterm_pred, midterm_labels)
|
||||
longterm_loss = nn.CrossEntropyLoss()(longterm_pred, longterm_labels)
|
||||
|
||||
# MSE loss for price value regression
|
||||
price_value_loss = nn.MSELoss()(price_values_pred, price_value_targets)
|
||||
|
||||
# Combine all price prediction losses
|
||||
price_loss = immediate_loss + 0.7 * midterm_loss + 0.5 * longterm_loss + 0.3 * price_value_loss
|
||||
|
||||
# Create extrema labels (same as before)
|
||||
extrema_labels = torch.ones(min_size, dtype=torch.long, device=self.device) * 2 # Default: neither
|
||||
|
||||
# Identify potential bottoms (significant negative change)
|
||||
bottoms = (immediate_changes < -0.003)
|
||||
extrema_labels[bottoms] = 0
|
||||
|
||||
# Identify potential tops (significant positive change)
|
||||
tops = (immediate_changes > 0.003)
|
||||
extrema_labels[tops] = 1
|
||||
|
||||
# Calculate extrema prediction loss
|
||||
if len(current_extrema_pred.shape) > 1 and current_extrema_pred.shape[0] >= min_size:
|
||||
current_extrema_pred = current_extrema_pred[:min_size]
|
||||
extrema_loss = nn.CrossEntropyLoss()(current_extrema_pred, extrema_labels)
|
||||
|
||||
# Combined loss with all components
|
||||
# Primary task: Q-value learning (RL objective)
|
||||
# Secondary tasks: extrema detection and price prediction (supervised objectives)
|
||||
loss = q_loss + 0.3 * extrema_loss + 0.3 * price_loss
|
||||
|
||||
# Log loss components occasionally
|
||||
if random.random() < 0.01: # Log 1% of the time
|
||||
logger.info(
|
||||
f"Training losses: Q-loss={q_loss.item():.4f}, "
|
||||
f"Extrema-loss={extrema_loss.item():.4f}, "
|
||||
f"Price-loss={price_loss.item():.4f}, "
|
||||
f"Imm-loss={immediate_loss.item():.4f}, "
|
||||
f"Mid-loss={midterm_loss.item():.4f}, "
|
||||
f"Long-loss={longterm_loss.item():.4f}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Fallback if price extraction fails
|
||||
logger.warning(f"Failed to calculate price prediction loss: {str(e)}. Using only Q-value loss.")
|
||||
# Just use Q-value loss
|
||||
loss = q_loss
|
||||
|
||||
# Backward pass and optimize
|
||||
loss.backward()
|
||||
|
||||
# Gradient clipping to prevent exploding gradients
|
||||
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
|
||||
self.optimizer.step()
|
||||
|
||||
# Update target network if needed
|
||||
self.update_count += 1
|
||||
if self.update_count % self.target_update == 0:
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
|
||||
# Track and decay epsilon
|
||||
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
|
||||
|
||||
return loss.item()
|
||||
# Randomly decide if we should train on extrema points from special memory
|
||||
if random.random() < 0.3 and len(self.extrema_memory) >= self.batch_size:
|
||||
# Train specifically on extrema memory examples
|
||||
extrema_indices = np.random.choice(len(self.extrema_memory), size=min(self.batch_size, len(self.extrema_memory)), replace=False)
|
||||
extrema_batch = [self.extrema_memory[i] for i in extrema_indices]
|
||||
|
||||
# Extract tensors from extrema batch
|
||||
extrema_states = torch.FloatTensor(np.array([e[0] for e in extrema_batch])).to(self.device)
|
||||
extrema_actions = torch.LongTensor(np.array([e[1] for e in extrema_batch])).to(self.device)
|
||||
extrema_rewards = torch.FloatTensor(np.array([e[2] for e in extrema_batch])).to(self.device)
|
||||
extrema_next_states = torch.FloatTensor(np.array([e[3] for e in extrema_batch])).to(self.device)
|
||||
extrema_dones = torch.FloatTensor(np.array([e[4] for e in extrema_batch])).to(self.device)
|
||||
|
||||
# Use a slightly reduced learning rate for extrema training
|
||||
old_lr = self.optimizer.param_groups[0]['lr']
|
||||
self.optimizer.param_groups[0]['lr'] = old_lr * 0.8
|
||||
|
||||
# Train on extrema memory
|
||||
if self.use_mixed_precision:
|
||||
extrema_loss = self._replay_mixed_precision(extrema_states, extrema_actions, extrema_rewards, extrema_next_states, extrema_dones)
|
||||
else:
|
||||
extrema_loss = self._replay_standard(extrema_batch)
|
||||
|
||||
# Reset learning rate
|
||||
self.optimizer.param_groups[0]['lr'] = old_lr
|
||||
|
||||
# Log extrema loss
|
||||
logger.info(f"Extra training on extrema points, loss: {extrema_loss:.4f}")
|
||||
|
||||
# Randomly train on price movement examples (similar to extrema)
|
||||
if random.random() < 0.3 and len(self.price_movement_memory) >= self.batch_size:
|
||||
# Train specifically on price movement memory examples
|
||||
price_indices = np.random.choice(len(self.price_movement_memory), size=min(self.batch_size, len(self.price_movement_memory)), replace=False)
|
||||
price_batch = [self.price_movement_memory[i] for i in price_indices]
|
||||
|
||||
# Extract tensors from price movement batch
|
||||
price_states = torch.FloatTensor(np.array([e[0] for e in price_batch])).to(self.device)
|
||||
price_actions = torch.LongTensor(np.array([e[1] for e in price_batch])).to(self.device)
|
||||
price_rewards = torch.FloatTensor(np.array([e[2] for e in price_batch])).to(self.device)
|
||||
price_next_states = torch.FloatTensor(np.array([e[3] for e in price_batch])).to(self.device)
|
||||
price_dones = torch.FloatTensor(np.array([e[4] for e in price_batch])).to(self.device)
|
||||
|
||||
# Use a slightly reduced learning rate for price movement training
|
||||
old_lr = self.optimizer.param_groups[0]['lr']
|
||||
self.optimizer.param_groups[0]['lr'] = old_lr * 0.75
|
||||
|
||||
# Train on price movement memory
|
||||
if self.use_mixed_precision:
|
||||
price_loss = self._replay_mixed_precision(price_states, price_actions, price_rewards, price_next_states, price_dones)
|
||||
else:
|
||||
price_loss = self._replay_standard(price_batch)
|
||||
|
||||
# Reset learning rate
|
||||
self.optimizer.param_groups[0]['lr'] = old_lr
|
||||
|
||||
# Log price movement loss
|
||||
logger.info(f"Extra training on price movement examples, loss: {price_loss:.4f}")
|
||||
|
||||
return loss
|
||||
|
||||
def _replay_standard(self, experiences=None):
|
||||
"""Standard training step without mixed precision"""
|
||||
try:
|
||||
# Use experiences if provided, otherwise sample from memory
|
||||
if experiences is None:
|
||||
# If memory is too small, skip training
|
||||
if len(self.memory) < self.batch_size:
|
||||
return 0.0
|
||||
|
||||
# Sample random mini-batch from memory
|
||||
indices = np.random.choice(len(self.memory), size=min(self.batch_size, len(self.memory)), replace=False)
|
||||
batch = [self.memory[i] for i in indices]
|
||||
experiences = batch
|
||||
|
||||
# Unpack experiences
|
||||
states, actions, rewards, next_states, dones = zip(*experiences)
|
||||
|
||||
# Convert to PyTorch tensors
|
||||
states = torch.FloatTensor(np.array(states)).to(self.device)
|
||||
actions = torch.LongTensor(np.array(actions)).to(self.device)
|
||||
rewards = torch.FloatTensor(np.array(rewards)).to(self.device)
|
||||
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
|
||||
dones = torch.FloatTensor(np.array(dones)).to(self.device)
|
||||
|
||||
# Get current Q values
|
||||
current_q_values, current_extrema_pred, current_price_pred, hidden_features = self.policy_net(states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Get next Q values with target network
|
||||
with torch.no_grad():
|
||||
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features = self.target_net(next_states)
|
||||
next_q_values = next_q_values.max(1)[0]
|
||||
|
||||
# Check for dimension mismatch between rewards and next_q_values
|
||||
if rewards.shape[0] != next_q_values.shape[0]:
|
||||
logger.warning(f"Shape mismatch detected in standard replay: rewards {rewards.shape}, next_q_values {next_q_values.shape}")
|
||||
# Use the smaller size to prevent index error
|
||||
min_size = min(rewards.shape[0], next_q_values.shape[0])
|
||||
rewards = rewards[:min_size]
|
||||
dones = dones[:min_size]
|
||||
next_q_values = next_q_values[:min_size]
|
||||
current_q_values = current_q_values[:min_size]
|
||||
|
||||
# Calculate target Q values
|
||||
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
|
||||
|
||||
# Compute loss for Q value
|
||||
q_loss = self.criterion(current_q_values, target_q_values)
|
||||
|
||||
# Try to compute extrema loss if possible
|
||||
try:
|
||||
# Get the target classes from extrema predictions
|
||||
extrema_targets = torch.argmax(current_extrema_pred, dim=1).long()
|
||||
|
||||
# Compute extrema loss using cross-entropy - this is an auxiliary task
|
||||
extrema_loss = F.cross_entropy(current_extrema_pred, extrema_targets)
|
||||
|
||||
# Combined loss with emphasis on Q-learning
|
||||
total_loss = q_loss + 0.1 * extrema_loss
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to calculate extrema loss: {str(e)}. Using only Q-value loss.")
|
||||
total_loss = q_loss
|
||||
|
||||
# Reset gradients
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# Backward pass
|
||||
total_loss.backward()
|
||||
|
||||
# Clip gradients to avoid exploding gradients
|
||||
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
|
||||
|
||||
# Update weights
|
||||
self.optimizer.step()
|
||||
|
||||
# Update target network if needed
|
||||
self.update_count += 1
|
||||
if self.update_count % self.target_update == 0:
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
|
||||
# Return loss
|
||||
return total_loss.item()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in replay standard: {str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return 0.0
|
||||
|
||||
def _replay_mixed_precision(self, states, actions, rewards, next_states, dones):
|
||||
"""Mixed precision training step for better GPU performance"""
|
||||
@@ -696,12 +727,12 @@ class DQNAgent:
|
||||
# Forward pass with amp autocasting
|
||||
with torch.cuda.amp.autocast():
|
||||
# Get current Q values and extrema predictions
|
||||
current_q_values, current_extrema_pred, current_price_pred = self.policy_net(states)
|
||||
current_q_values, current_extrema_pred, current_price_pred, hidden_features = self.policy_net(states)
|
||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Get next Q values from target network
|
||||
with torch.no_grad():
|
||||
next_q_values, next_extrema_pred, next_price_pred = self.target_net(next_states)
|
||||
next_q_values, next_extrema_pred, next_price_pred, next_hidden_features = self.target_net(next_states)
|
||||
next_q_values = next_q_values.max(1)[0]
|
||||
|
||||
# Check for dimension mismatch and fix it
|
||||
@ -733,7 +764,7 @@ class DQNAgent:
|
||||
current_prices = states[:, -1] # Last feature
|
||||
next_prices = next_states[:, -1]
|
||||
|
||||
# Compute price changes for different timeframes
|
||||
# Calculate price change for different timeframes
|
||||
immediate_changes = (next_prices - current_prices) / current_prices
|
||||
|
||||
# Create price direction labels - simplified for training
|
||||
|
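Both replay variants above build the same single-step Bellman target before computing the loss: the reward plus the discounted best next-state value from the target network, zeroed out on terminal transitions. A minimal, self-contained sketch of that calculation on made-up tensors (the shapes and values are illustrative only, not taken from this repository):

import torch

gamma = 0.95
rewards = torch.tensor([1.0, -0.5, 0.2])      # per-sample rewards
dones = torch.tensor([0.0, 0.0, 1.0])         # 1.0 marks a terminal transition
next_q = torch.tensor([[0.3, 0.7, 0.1],
                       [0.2, 0.4, 0.9],
                       [0.5, 0.5, 0.5]])      # target-net Q values for s'

target_q = rewards + (1 - dones) * gamma * next_q.max(1)[0]
# tensor([1.6650, 0.3550, 0.2000])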
329  NN/models/dqn_agent_enhanced.py  Normal file
@@ -0,0 +1,329 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
import random
|
||||
from typing import Tuple, List
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
# Import the EnhancedCNN model
|
||||
from NN.models.enhanced_cnn import EnhancedCNN, ExampleSiftingDataset
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EnhancedDQNAgent:
|
||||
"""
|
||||
Enhanced Deep Q-Network agent for trading
|
||||
Uses the improved EnhancedCNN model with residual connections and attention mechanisms
|
||||
"""
|
||||
def __init__(self,
|
||||
state_shape: Tuple[int, ...],
|
||||
n_actions: int,
|
||||
learning_rate: float = 0.0003, # Slightly reduced learning rate for stability
|
||||
gamma: float = 0.95, # Discount factor
|
||||
epsilon: float = 1.0,
|
||||
epsilon_min: float = 0.05,
|
||||
epsilon_decay: float = 0.995, # Slower decay for more exploration
|
||||
buffer_size: int = 50000, # Larger memory buffer
|
||||
batch_size: int = 128, # Larger batch size
|
||||
target_update: int = 10, # More frequent target updates
|
||||
confidence_threshold: float = 0.4, # Lower confidence threshold
|
||||
device=None):
|
||||
|
||||
# Extract state dimensions
|
||||
if isinstance(state_shape, tuple) and len(state_shape) > 1:
|
||||
# Multi-dimensional state (like image or sequence)
|
||||
self.state_dim = state_shape
|
||||
else:
|
||||
# 1D state
|
||||
if isinstance(state_shape, tuple):
|
||||
self.state_dim = state_shape[0]
|
||||
else:
|
||||
self.state_dim = state_shape
|
||||
|
||||
# Store parameters
|
||||
self.n_actions = n_actions
|
||||
self.learning_rate = learning_rate
|
||||
self.gamma = gamma
|
||||
self.epsilon = epsilon
|
||||
self.epsilon_min = epsilon_min
|
||||
self.epsilon_decay = epsilon_decay
|
||||
self.buffer_size = buffer_size
|
||||
self.batch_size = batch_size
|
||||
self.target_update = target_update
|
||||
self.confidence_threshold = confidence_threshold
|
||||
|
||||
# Set device for computation
|
||||
if device is None:
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
# Initialize models with the enhanced CNN
|
||||
self.policy_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)
|
||||
self.target_net = EnhancedCNN(self.state_dim, self.n_actions, self.confidence_threshold)
|
||||
|
||||
# Initialize the target network with the same weights as the policy network
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
|
||||
# Set models to eval mode (important for batch norm, dropout)
|
||||
self.target_net.eval()
|
||||
|
||||
# Optimization components
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)
|
||||
self.criterion = nn.MSELoss()
|
||||
|
||||
# Experience replay memory with example sifting
|
||||
self.memory = ExampleSiftingDataset(max_examples=buffer_size)
|
||||
self.update_count = 0
|
||||
|
||||
# Confidence tracking
|
||||
self.confidence_history = []
|
||||
self.avg_confidence = 0.0
|
||||
self.max_confidence = 0.0
|
||||
self.min_confidence = 1.0
|
||||
|
||||
# Performance tracking
|
||||
self.losses = []
|
||||
self.rewards = []
|
||||
self.avg_reward = 0.0
|
||||
|
||||
# Check if mixed precision training should be used
|
||||
self.use_mixed_precision = False
|
||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and 'DISABLE_MIXED_PRECISION' not in os.environ:
|
||||
self.use_mixed_precision = True
|
||||
self.scaler = torch.cuda.amp.GradScaler()
|
||||
logger.info("Mixed precision training enabled")
|
||||
else:
|
||||
logger.info("Mixed precision training disabled")
|
||||
|
||||
# For compatibility with old code
|
||||
self.action_size = n_actions
|
||||
|
||||
logger.info(f"Enhanced DQN Agent using device: {self.device}")
|
||||
logger.info(f"Confidence threshold set to {self.confidence_threshold}")
|
||||
|
||||
def move_models_to_device(self, device=None):
|
||||
"""Move models to the specified device (GPU/CPU)"""
|
||||
if device is not None:
|
||||
self.device = device
|
||||
|
||||
try:
|
||||
self.policy_net = self.policy_net.to(self.device)
|
||||
self.target_net = self.target_net.to(self.device)
|
||||
logger.info(f"Moved models to {self.device}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to move models to {self.device}: {str(e)}")
|
||||
return False
|
||||
|
||||
def _normalize_state(self, state):
|
||||
"""Normalize state for better training stability"""
|
||||
try:
|
||||
# Convert to numpy array if needed
|
||||
if isinstance(state, list):
|
||||
state = np.array(state, dtype=np.float32)
|
||||
|
||||
# Apply normalization based on state shape
|
||||
if len(state.shape) > 1:
|
||||
# Multi-dimensional state - normalize each feature dimension separately
|
||||
for i in range(state.shape[0]):
|
||||
# Skip if all zeros (to avoid division by zero)
|
||||
if np.sum(np.abs(state[i])) > 0:
|
||||
# Standardize each feature dimension
|
||||
mean = np.mean(state[i])
|
||||
std = np.std(state[i])
|
||||
if std > 0:
|
||||
state[i] = (state[i] - mean) / std
|
||||
else:
|
||||
# 1D state vector
|
||||
# Skip if all zeros
|
||||
if np.sum(np.abs(state)) > 0:
|
||||
mean = np.mean(state)
|
||||
std = np.std(state)
|
||||
if std > 0:
|
||||
state = (state - mean) / std
|
||||
|
||||
return state
|
||||
except Exception as e:
|
||||
logger.warning(f"Error normalizing state: {str(e)}")
|
||||
return state
|
||||
|
||||
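# --- Illustrative only (not part of dqn_agent_enhanced.py): the per-row z-score
# normalization performed by _normalize_state above, shown on a dummy 2D state.
# Rows that are all zeros are left untouched to avoid dividing by zero.
import numpy as np

demo_state = np.array([[100.0, 101.0, 99.0],
                       [0.0, 0.0, 0.0]], dtype=np.float32)
for i in range(demo_state.shape[0]):
    if np.sum(np.abs(demo_state[i])) > 0:
        std = np.std(demo_state[i])
        if std > 0:
            demo_state[i] = (demo_state[i] - np.mean(demo_state[i])) / std
# First row becomes zero-mean / unit-std; the all-zero row is unchanged.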
def remember(self, state, action, reward, next_state, done):
|
||||
"""Store experience in memory with example sifting"""
|
||||
self.memory.add_example(state, action, reward, next_state, done)
|
||||
|
||||
# Also track rewards for monitoring
|
||||
self.rewards.append(reward)
|
||||
if len(self.rewards) > 100:
|
||||
self.rewards = self.rewards[-100:]
|
||||
self.avg_reward = np.mean(self.rewards)
|
||||
|
||||
def act(self, state, explore=True):
|
||||
"""Choose action using epsilon-greedy policy with built-in confidence thresholding"""
|
||||
if explore and random.random() < self.epsilon:
|
||||
return random.randrange(self.n_actions), 0.0 # Return action and zero confidence
|
||||
|
||||
# Normalize state before inference
|
||||
normalized_state = self._normalize_state(state)
|
||||
|
||||
# Use the EnhancedCNN's act method which includes confidence thresholding
|
||||
action, confidence = self.policy_net.act(normalized_state, explore=explore)
|
||||
|
||||
# Track confidence metrics
|
||||
self.confidence_history.append(confidence)
|
||||
if len(self.confidence_history) > 100:
|
||||
self.confidence_history = self.confidence_history[-100:]
|
||||
|
||||
# Update confidence metrics
|
||||
self.avg_confidence = sum(self.confidence_history) / len(self.confidence_history)
|
||||
self.max_confidence = max(self.max_confidence, confidence)
|
||||
self.min_confidence = min(self.min_confidence, confidence)
|
||||
|
||||
# Log average confidence occasionally
|
||||
if random.random() < 0.01: # 1% of the time
|
||||
logger.info(f"Confidence metrics - Current: {confidence:.4f}, Avg: {self.avg_confidence:.4f}, " +
|
||||
f"Min: {self.min_confidence:.4f}, Max: {self.max_confidence:.4f}")
|
||||
|
||||
return action, confidence
|
||||
|
||||
def replay(self):
|
||||
"""Train the model using experience replay with high-quality examples"""
|
||||
# Check if enough samples in memory
|
||||
if len(self.memory) < self.batch_size:
|
||||
return 0.0
|
||||
|
||||
# Get batch of experiences
|
||||
batch = self.memory.get_batch(self.batch_size)
|
||||
if batch is None:
|
||||
return 0.0
|
||||
|
||||
states = torch.FloatTensor(batch['states']).to(self.device)
|
||||
actions = torch.LongTensor(batch['actions']).to(self.device)
|
||||
rewards = torch.FloatTensor(batch['rewards']).to(self.device)
|
||||
next_states = torch.FloatTensor(batch['next_states']).to(self.device)
|
||||
dones = torch.FloatTensor(batch['dones']).to(self.device)
|
||||
|
||||
# Compute Q values
|
||||
self.policy_net.train() # Set to training mode
|
||||
|
||||
# Get current Q values
|
||||
if self.use_mixed_precision:
|
||||
with torch.cuda.amp.autocast():
|
||||
# Get current Q values
|
||||
q_values, _, _, _ = self.policy_net(states)
|
||||
current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Compute target Q values
|
||||
with torch.no_grad():
|
||||
self.target_net.eval()
|
||||
next_q_values, _, _, _ = self.target_net(next_states)
|
||||
next_q = next_q_values.max(1)[0]
|
||||
target_q = rewards + (1 - dones) * self.gamma * next_q
|
||||
|
||||
# Compute loss
|
||||
loss = self.criterion(current_q, target_q)
|
||||
|
||||
# Perform backpropagation with mixed precision
|
||||
self.optimizer.zero_grad()
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
else:
|
||||
# Standard precision training
|
||||
# Get current Q values
|
||||
q_values, _, _, _ = self.policy_net(states)
|
||||
current_q = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
|
||||
|
||||
# Compute target Q values
|
||||
with torch.no_grad():
|
||||
self.target_net.eval()
|
||||
next_q_values, _, _, _ = self.target_net(next_states)
|
||||
next_q = next_q_values.max(1)[0]
|
||||
target_q = rewards + (1 - dones) * self.gamma * next_q
|
||||
|
||||
# Compute loss
|
||||
loss = self.criterion(current_q, target_q)
|
||||
|
||||
# Perform backpropagation
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
|
||||
self.optimizer.step()
|
||||
|
||||
# Track loss
|
||||
loss_value = loss.item()
|
||||
self.losses.append(loss_value)
|
||||
if len(self.losses) > 100:
|
||||
self.losses = self.losses[-100:]
|
||||
|
||||
# Update target network
|
||||
self.update_count += 1
|
||||
if self.update_count % self.target_update == 0:
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
logger.info(f"Updated target network (step {self.update_count})")
|
||||
|
||||
# Decay epsilon
|
||||
if self.epsilon > self.epsilon_min:
|
||||
self.epsilon *= self.epsilon_decay
|
||||
|
||||
return loss_value
|
||||
|
||||
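# --- Illustrative only (not part of dqn_agent_enhanced.py): the generic
# torch.cuda.amp update pattern that the mixed-precision branch above follows
# (autocast forward, scaled backward, unscale before clipping, then step/update).
# `model`, `optimizer`, `loss_fn`, `batch` and `targets` are placeholders,
# not objects from this repository.
import torch

def amp_step(model, optimizer, loss_fn, batch, targets, scaler):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(batch), targets)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)            # so clip_grad_norm_ sees true gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()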
def save(self, path):
|
||||
"""Save agent state and models"""
|
||||
self.policy_net.save(f"{path}_policy")
|
||||
self.target_net.save(f"{path}_target")
|
||||
|
||||
# Save agent state
|
||||
torch.save({
|
||||
'epsilon': self.epsilon,
|
||||
'confidence_threshold': self.confidence_threshold,
|
||||
'losses': self.losses,
|
||||
'rewards': self.rewards,
|
||||
'avg_reward': self.avg_reward,
|
||||
'confidence_history': self.confidence_history,
|
||||
'avg_confidence': self.avg_confidence,
|
||||
'max_confidence': self.max_confidence,
|
||||
'min_confidence': self.min_confidence,
|
||||
'update_count': self.update_count
|
||||
}, f"{path}_agent_state.pt")
|
||||
|
||||
logger.info(f"Agent state saved to {path}_agent_state.pt")
|
||||
|
||||
def load(self, path):
|
||||
"""Load agent state and models"""
|
||||
policy_loaded = self.policy_net.load(f"{path}_policy")
|
||||
target_loaded = self.target_net.load(f"{path}_target")
|
||||
|
||||
# Load agent state if available
|
||||
agent_state_path = f"{path}_agent_state.pt"
|
||||
if os.path.exists(agent_state_path):
|
||||
try:
|
||||
state = torch.load(agent_state_path)
|
||||
self.epsilon = state.get('epsilon', self.epsilon)
|
||||
self.confidence_threshold = state.get('confidence_threshold', self.confidence_threshold)
|
||||
self.policy_net.confidence_threshold = self.confidence_threshold
|
||||
self.target_net.confidence_threshold = self.confidence_threshold
|
||||
self.losses = state.get('losses', [])
|
||||
self.rewards = state.get('rewards', [])
|
||||
self.avg_reward = state.get('avg_reward', 0.0)
|
||||
self.confidence_history = state.get('confidence_history', [])
|
||||
self.avg_confidence = state.get('avg_confidence', 0.0)
|
||||
self.max_confidence = state.get('max_confidence', 0.0)
|
||||
self.min_confidence = state.get('min_confidence', 1.0)
|
||||
self.update_count = state.get('update_count', 0)
|
||||
logger.info(f"Agent state loaded from {agent_state_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading agent state: {str(e)}")
|
||||
|
||||
return policy_loaded and target_loaded
|
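A minimal smoke test for the agent defined above, using random vectors in place of real market states. The state width of 100, the three-action setup and the noise-based reward are placeholder assumptions for illustration, not values from the training configuration:

import numpy as np
from NN.models.dqn_agent_enhanced import EnhancedDQNAgent

agent = EnhancedDQNAgent(state_shape=(100,), n_actions=3)

state = np.random.randn(100).astype(np.float32)
for _ in range(256):
    action, confidence = agent.act(state)
    next_state = np.random.randn(100).astype(np.float32)
    reward = float(np.random.randn() * 0.01)   # stand-in for a PnL-based reward
    agent.remember(state, action, reward, next_state, done=False)
    state = next_state

loss = agent.replay()   # returns 0.0 until the sifted memory holds a full batch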
413  NN/models/enhanced_cnn.py  Normal file
@@ -0,0 +1,413 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
import os
|
||||
import logging
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple, Dict, Any, Optional, Union
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
"""
|
||||
Residual block with pre-activation (BatchNorm -> ReLU -> Conv)
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
self.bn1 = nn.BatchNorm1d(in_channels)
|
||||
self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm1d(out_channels)
|
||||
self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
|
||||
# Shortcut connection to match dimensions
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_channels != out_channels:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.relu(self.bn1(x))
|
||||
shortcut = self.shortcut(out)
|
||||
out = self.conv1(out)
|
||||
out = self.conv2(F.relu(self.bn2(out)))
|
||||
out += shortcut
|
||||
return out
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
"""
|
||||
Self-attention mechanism for sequential data
|
||||
"""
|
||||
def __init__(self, dim):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.query = nn.Linear(dim, dim)
|
||||
self.key = nn.Linear(dim, dim)
|
||||
self.value = nn.Linear(dim, dim)
|
||||
self.scale = torch.sqrt(torch.tensor(dim, dtype=torch.float32))
|
||||
|
||||
def forward(self, x):
|
||||
# x shape: [batch_size, seq_len, dim]
|
||||
batch_size, seq_len, dim = x.size()
|
||||
|
||||
q = self.query(x) # [batch_size, seq_len, dim]
|
||||
k = self.key(x) # [batch_size, seq_len, dim]
|
||||
v = self.value(x) # [batch_size, seq_len, dim]
|
||||
|
||||
# Calculate attention scores
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale # [batch_size, seq_len, seq_len]
|
||||
|
||||
# Apply softmax to get attention weights
|
||||
attention = F.softmax(scores, dim=-1) # [batch_size, seq_len, seq_len]
|
||||
|
||||
# Apply attention to values
|
||||
out = torch.matmul(attention, v) # [batch_size, seq_len, dim]
|
||||
|
||||
return out, attention
|
||||
|
||||
class EnhancedCNN(nn.Module):
|
||||
"""
|
||||
Enhanced CNN model with residual connections and attention mechanisms
|
||||
for improved trading decision making
|
||||
"""
|
||||
def __init__(self, input_shape, n_actions, confidence_threshold=0.5):
|
||||
super(EnhancedCNN, self).__init__()
|
||||
|
||||
# Store dimensions
|
||||
self.input_shape = input_shape
|
||||
self.n_actions = n_actions
|
||||
self.confidence_threshold = confidence_threshold
|
||||
|
||||
# Calculate input dimensions
|
||||
if isinstance(input_shape, (list, tuple)):
|
||||
if len(input_shape) == 3: # [channels, height, width]
|
||||
self.channels, self.height, self.width = input_shape
|
||||
self.feature_dim = self.height * self.width
|
||||
elif len(input_shape) == 2: # [timeframes, features]
|
||||
self.channels = input_shape[0]
|
||||
self.features = input_shape[1]
|
||||
self.feature_dim = self.features * self.channels
|
||||
elif len(input_shape) == 1: # [features]
|
||||
self.channels = 1
|
||||
self.features = input_shape[0]
|
||||
self.feature_dim = self.features
|
||||
else:
|
||||
raise ValueError(f"Unsupported input shape: {input_shape}")
|
||||
else: # single integer
|
||||
self.channels = 1
|
||||
self.features = input_shape
|
||||
self.feature_dim = input_shape
|
||||
|
||||
# Build network
|
||||
self._build_network()
|
||||
|
||||
# Initialize device
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.to(self.device)
|
||||
|
||||
logger.info(f"EnhancedCNN initialized with input shape: {input_shape}, actions: {n_actions}")
|
||||
|
||||
def _build_network(self):
|
||||
"""Build the enhanced neural network with current feature dimensions"""
|
||||
|
||||
# 1D CNN for sequential data
|
||||
if self.channels > 1:
|
||||
# Reshape expected: [batch, timeframes, features]
|
||||
self.conv_layers = nn.Sequential(
|
||||
nn.Conv1d(self.channels, 64, kernel_size=3, padding=1),
|
||||
nn.BatchNorm1d(64),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
|
||||
ResidualBlock(64, 128),
|
||||
nn.MaxPool1d(kernel_size=2, stride=2),
|
||||
nn.Dropout(0.3),
|
||||
|
||||
ResidualBlock(128, 256),
|
||||
nn.MaxPool1d(kernel_size=2, stride=2),
|
||||
nn.Dropout(0.4),
|
||||
|
||||
ResidualBlock(256, 512),
|
||||
nn.AdaptiveAvgPool1d(1) # Global average pooling
|
||||
)
|
||||
# Feature dimension after conv layers
|
||||
self.conv_features = 512
|
||||
else:
|
||||
# For 1D vectors, skip the convolutional part
|
||||
self.conv_layers = None
|
||||
self.conv_features = 0
|
||||
|
||||
# Fully connected layers for all cases
|
||||
# We'll use deeper layers with skip connections
|
||||
if self.conv_layers is None:
|
||||
# For 1D inputs without conv preprocessing
|
||||
self.fc1 = nn.Linear(self.feature_dim, 512)
|
||||
self.features_dim = 512
|
||||
else:
|
||||
# For data processed by conv layers
|
||||
self.fc1 = nn.Linear(self.conv_features, 512)
|
||||
self.features_dim = 512
|
||||
|
||||
# Common feature extraction layers
|
||||
self.fc_layers = nn.Sequential(
|
||||
self.fc1,
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.4),
|
||||
nn.Linear(512, 512),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.4),
|
||||
nn.Linear(512, 256),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
# Dueling architecture
|
||||
self.advantage_stream = nn.Sequential(
|
||||
nn.Linear(256, 128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, self.n_actions)
|
||||
)
|
||||
|
||||
self.value_stream = nn.Sequential(
|
||||
nn.Linear(256, 128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 1)
|
||||
)
|
||||
|
||||
# Extrema detection head with increased capacity
|
||||
self.extrema_head = nn.Sequential(
|
||||
nn.Linear(256, 128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(128, 3) # 0=bottom, 1=top, 2=neither
|
||||
)
|
||||
|
||||
# Price prediction heads with increased capacity
|
||||
self.price_pred_immediate = nn.Sequential(
|
||||
nn.Linear(256, 64),
|
||||
nn.ReLU(),
|
||||
nn.Linear(64, 3) # Up, Down, Sideways
|
||||
)
|
||||
|
||||
self.price_pred_midterm = nn.Sequential(
|
||||
nn.Linear(256, 64),
|
||||
nn.ReLU(),
|
||||
nn.Linear(64, 3) # Up, Down, Sideways
|
||||
)
|
||||
|
||||
self.price_pred_longterm = nn.Sequential(
|
||||
nn.Linear(256, 64),
|
||||
nn.ReLU(),
|
||||
nn.Linear(64, 3) # Up, Down, Sideways
|
||||
)
|
||||
|
||||
# Value prediction with increased capacity
|
||||
self.price_pred_value = nn.Sequential(
|
||||
nn.Linear(256, 128),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(128, 4) # % change for different timeframes
|
||||
)
|
||||
|
||||
# Additional attention layer for feature refinement
|
||||
self.attention = SelfAttention(256)
|
||||
|
||||
def _check_rebuild_network(self, features):
|
||||
"""Check if network needs to be rebuilt for different feature dimensions"""
|
||||
if features != self.feature_dim:
|
||||
logger.info(f"Rebuilding network for new feature dimension: {features} (was {self.feature_dim})")
|
||||
self.feature_dim = features
|
||||
self._build_network()
|
||||
# Move to device after rebuilding
|
||||
self.to(self.device)
|
||||
return True
|
||||
return False
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward pass through the network"""
|
||||
batch_size = x.size(0)
|
||||
|
||||
# Process different input shapes
|
||||
if len(x.shape) > 2:
|
||||
# Handle 3D input [batch, timeframes, features]
|
||||
if self.conv_layers is not None:
|
||||
# Reshape for 1D convolution:
|
||||
# [batch, timeframes, features] -> [batch, timeframes, features*1]
|
||||
if len(x.shape) == 3:
|
||||
x = x.permute(0, 1, 2) # Ensure shape is [batch, timeframes, features]
|
||||
x_reshaped = x.permute(0, 1, 2) # [batch, timeframes, features]
|
||||
|
||||
# Check if the feature dimension has changed and rebuild if necessary
|
||||
if x_reshaped.size(1) * x_reshaped.size(2) != self.feature_dim:
|
||||
total_features = x_reshaped.size(1) * x_reshaped.size(2)
|
||||
self._check_rebuild_network(total_features)
|
||||
|
||||
# Apply convolutions
|
||||
x_conv = self.conv_layers(x_reshaped)
|
||||
# Flatten: [batch, channels, 1] -> [batch, channels]
|
||||
x_flat = x_conv.view(batch_size, -1)
|
||||
else:
|
||||
# If no conv layers, just flatten
|
||||
x_flat = x.view(batch_size, -1)
|
||||
else:
|
||||
# For 2D input [batch, features]
|
||||
x_flat = x
|
||||
|
||||
# Check if dimensions have changed
|
||||
if x_flat.size(1) != self.feature_dim:
|
||||
self._check_rebuild_network(x_flat.size(1))
|
||||
|
||||
# Apply FC layers
|
||||
features = self.fc_layers(x_flat)
|
||||
|
||||
# Add attention for feature refinement
|
||||
features_3d = features.unsqueeze(1) # [batch, 1, features]
|
||||
features_attended, _ = self.attention(features_3d)
|
||||
features_refined = features_attended.squeeze(1) # [batch, features]
|
||||
|
||||
# Calculate advantage and value
|
||||
advantage = self.advantage_stream(features_refined)
|
||||
value = self.value_stream(features_refined)
|
||||
|
||||
# Combine for Q-values (Dueling architecture)
|
||||
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
|
||||
|
||||
# Get extrema predictions
|
||||
extrema_pred = self.extrema_head(features_refined)
|
||||
|
||||
# Price movement predictions
|
||||
price_immediate = self.price_pred_immediate(features_refined)
|
||||
price_midterm = self.price_pred_midterm(features_refined)
|
||||
price_longterm = self.price_pred_longterm(features_refined)
|
||||
price_values = self.price_pred_value(features_refined)
|
||||
|
||||
# Package price predictions
|
||||
price_predictions = {
|
||||
'immediate': price_immediate,
|
||||
'midterm': price_midterm,
|
||||
'longterm': price_longterm,
|
||||
'values': price_values
|
||||
}
|
||||
|
||||
return q_values, extrema_pred, price_predictions, features_refined
|
||||
|
||||
def act(self, state, explore=True):
|
||||
"""
|
||||
Choose action based on state with confidence thresholding
|
||||
"""
|
||||
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
q_values, _, _, _ = self(state_tensor)
|
||||
|
||||
# Apply softmax to get action probabilities
|
||||
action_probs = F.softmax(q_values, dim=1)
|
||||
|
||||
# Get action with highest probability
|
||||
action = action_probs.argmax(dim=1).item()
|
||||
action_confidence = action_probs[0, action].item()
|
||||
|
||||
# Check if confidence exceeds threshold
|
||||
if action_confidence < self.confidence_threshold:
|
||||
# Force HOLD action (typically action 2)
|
||||
action = 2 # Assume 2 is HOLD
|
||||
logger.info(f"Action {action} confidence {action_confidence:.4f} below threshold {self.confidence_threshold}, forcing HOLD")
|
||||
|
||||
return action, action_confidence
|
||||
|
||||
def save(self, path):
|
||||
"""Save model weights and architecture"""
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
torch.save({
|
||||
'state_dict': self.state_dict(),
|
||||
'input_shape': self.input_shape,
|
||||
'n_actions': self.n_actions,
|
||||
'feature_dim': self.feature_dim,
|
||||
'confidence_threshold': self.confidence_threshold
|
||||
}, f"{path}.pt")
|
||||
logger.info(f"Enhanced CNN model saved to {path}.pt")
|
||||
|
||||
def load(self, path):
|
||||
"""Load model weights and architecture"""
|
||||
try:
|
||||
checkpoint = torch.load(f"{path}.pt", map_location=self.device)
|
||||
self.input_shape = checkpoint['input_shape']
|
||||
self.n_actions = checkpoint['n_actions']
|
||||
self.feature_dim = checkpoint['feature_dim']
|
||||
if 'confidence_threshold' in checkpoint:
|
||||
self.confidence_threshold = checkpoint['confidence_threshold']
|
||||
self._build_network()
|
||||
self.load_state_dict(checkpoint['state_dict'])
|
||||
self.to(self.device)
|
||||
logger.info(f"Enhanced CNN model loaded from {path}.pt")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model: {str(e)}")
|
||||
return False
|
||||
|
||||
# Additional utility for example sifting
|
||||
class ExampleSiftingDataset:
|
||||
"""
|
||||
Dataset that selectively keeps high-quality examples for training
|
||||
to improve model performance
|
||||
"""
|
||||
def __init__(self, max_examples=50000):
|
||||
self.examples = []
|
||||
self.labels = []
|
||||
self.rewards = []
|
||||
self.max_examples = max_examples
|
||||
self.min_reward_threshold = -0.05 # Minimum reward to keep an example
|
||||
|
||||
def add_example(self, state, action, reward, next_state, done):
|
||||
"""Add a new training example with reward-based filtering"""
|
||||
# Only keep examples with rewards above the threshold
|
||||
if reward > self.min_reward_threshold:
|
||||
self.examples.append((state, action, reward, next_state, done))
|
||||
self.rewards.append(reward)
|
||||
|
||||
# Sort by reward and keep only the top examples
|
||||
if len(self.examples) > self.max_examples:
|
||||
# Sort by reward (highest first)
|
||||
sorted_indices = np.argsort(self.rewards)[::-1]
|
||||
# Keep top examples
|
||||
self.examples = [self.examples[i] for i in sorted_indices[:self.max_examples]]
|
||||
self.rewards = [self.rewards[i] for i in sorted_indices[:self.max_examples]]
|
||||
|
||||
# Update the minimum reward threshold to be the minimum in our kept examples
|
||||
self.min_reward_threshold = min(self.rewards)
|
||||
|
||||
def get_batch(self, batch_size):
|
||||
"""Get a batch of examples, prioritizing better examples"""
|
||||
if not self.examples:
|
||||
return None
|
||||
|
||||
# Calculate selection probabilities based on rewards
|
||||
rewards = np.array(self.rewards)
|
||||
# Shift rewards to be positive for probability calculation
|
||||
min_reward = min(rewards)
|
||||
shifted_rewards = rewards - min_reward + 0.1 # Add small constant
|
||||
probs = shifted_rewards / shifted_rewards.sum()
|
||||
|
||||
# Sample batch indices with reward-based probabilities
|
||||
indices = np.random.choice(
|
||||
len(self.examples),
|
||||
size=min(batch_size, len(self.examples)),
|
||||
p=probs,
|
||||
replace=False
|
||||
)
|
||||
|
||||
# Create batch
|
||||
batch = [self.examples[i] for i in indices]
|
||||
states, actions, rewards, next_states, dones = zip(*batch)
|
||||
|
||||
return {
|
||||
'states': np.array(states),
|
||||
'actions': np.array(actions),
|
||||
'rewards': np.array(rewards),
|
||||
'next_states': np.array(next_states),
|
||||
'dones': np.array(dones)
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
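ExampleSiftingDataset above samples batches in proportion to (shifted) reward, so higher-reward transitions are replayed more often. The weighting itself reduces to a few lines of NumPy; the reward values below are made up for illustration:

import numpy as np

rewards = np.array([0.8, 0.1, -0.03, 0.4])
shifted = rewards - rewards.min() + 0.1      # shift so every weight is positive
probs = shifted / shifted.sum()
indices = np.random.choice(len(rewards), size=2, p=probs, replace=False)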
1  NN/models/saved/dqn_agent_best_metadata.json  Normal file
@@ -0,0 +1 @@
|
||||
{"best_reward": 4791516.572471984, "best_episode": 3250, "best_pnl": 826842167451289.1, "best_win_rate": 0.47368421052631576, "date": "2025-04-01 10:19:16"}
|
20  NN/models/saved/hybrid_stats_latest.json  Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"supervised": {
|
||||
"epochs_completed": 22650,
|
||||
"best_val_pnl": 0.0,
|
||||
"best_epoch": 50,
|
||||
"best_win_rate": 0
|
||||
},
|
||||
"reinforcement": {
|
||||
"episodes_completed": 0,
|
||||
"best_reward": -Infinity,
|
||||
"best_episode": 0,
|
||||
"best_win_rate": 0
|
||||
},
|
||||
"hybrid": {
|
||||
"iterations_completed": 453,
|
||||
"best_combined_score": 0.0,
|
||||
"training_started": "2025-04-09T10:30:42.510856",
|
||||
"last_update": "2025-04-09T10:40:02.217840"
|
||||
}
|
||||
}
|
326  NN/models/saved/realtime_ticks_training_stats.json  Normal file
@@ -0,0 +1,326 @@
|
||||
{
|
||||
"epochs_completed": 8,
|
||||
"best_val_pnl": 0.0,
|
||||
"best_epoch": 1,
|
||||
"best_win_rate": 0.0,
|
||||
"training_started": "2025-04-02T10:43:58.946682",
|
||||
"last_update": "2025-04-02T10:44:10.940892",
|
||||
"epochs": [
|
||||
{
|
||||
"epoch": 1,
|
||||
"train_loss": 1.0950355529785156,
|
||||
"val_loss": 1.1657923062642415,
|
||||
"train_acc": 0.3255208333333333,
|
||||
"val_acc": 0.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:01.840889",
|
||||
"data_age": 2,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 2,
|
||||
"train_loss": 1.0831659038861592,
|
||||
"val_loss": 1.1212460199991863,
|
||||
"train_acc": 0.390625,
|
||||
"val_acc": 0.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:03.134833",
|
||||
"data_age": 4,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 3,
|
||||
"train_loss": 1.0740693012873332,
|
||||
"val_loss": 1.0992945830027263,
|
||||
"train_acc": 0.4739583333333333,
|
||||
"val_acc": 0.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:04.425272",
|
||||
"data_age": 5,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 4,
|
||||
"train_loss": 1.0747728943824768,
|
||||
"val_loss": 1.0821794271469116,
|
||||
"train_acc": 0.4609375,
|
||||
"val_acc": 0.3229166666666667,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:05.716421",
|
||||
"data_age": 6,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 5,
|
||||
"train_loss": 1.0489931503931682,
|
||||
"val_loss": 1.0669521888097127,
|
||||
"train_acc": 0.5833333333333334,
|
||||
"val_acc": 1.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:07.007935",
|
||||
"data_age": 8,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 6,
|
||||
"train_loss": 1.0533669590950012,
|
||||
"val_loss": 1.0505590836207073,
|
||||
"train_acc": 0.5104166666666666,
|
||||
"val_acc": 1.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:08.296061",
|
||||
"data_age": 9,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 7,
|
||||
"train_loss": 1.0456886688868205,
|
||||
"val_loss": 1.0351698795954387,
|
||||
"train_acc": 0.5651041666666666,
|
||||
"val_acc": 1.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:09.607584",
|
||||
"data_age": 10,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 8,
|
||||
"train_loss": 1.040040671825409,
|
||||
"val_loss": 1.0227736632029216,
|
||||
"train_acc": 0.6119791666666666,
|
||||
"val_acc": 1.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:10.940892",
|
||||
"data_age": 11,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
}
|
||||
],
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"total_wins": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
}
|
||||
}
|
192  NN/models/saved/realtime_training_stats.json  Normal file
@@ -0,0 +1,192 @@
|
||||
{
|
||||
"epochs_completed": 7,
|
||||
"best_val_pnl": 0.002028853100759435,
|
||||
"best_epoch": 6,
|
||||
"best_win_rate": 0.5157894736842106,
|
||||
"training_started": "2025-03-31T02:50:10.418670",
|
||||
"last_update": "2025-03-31T02:50:15.227593",
|
||||
"epochs": [
|
||||
{
|
||||
"epoch": 1,
|
||||
"train_loss": 1.1206786036491394,
|
||||
"val_loss": 1.0542699098587036,
|
||||
"train_acc": 0.11197916666666667,
|
||||
"val_acc": 0.25,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:12.881423",
|
||||
"data_age": 2
|
||||
},
|
||||
{
|
||||
"epoch": 2,
|
||||
"train_loss": 1.1266120672225952,
|
||||
"val_loss": 1.072133183479309,
|
||||
"train_acc": 0.1171875,
|
||||
"val_acc": 0.25,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:13.186840",
|
||||
"data_age": 2
|
||||
},
|
||||
{
|
||||
"epoch": 3,
|
||||
"train_loss": 1.1415620843569438,
|
||||
"val_loss": 1.1701548099517822,
|
||||
"train_acc": 0.1015625,
|
||||
"val_acc": 0.5208333333333334,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:13.442018",
|
||||
"data_age": 3
|
||||
},
|
||||
{
|
||||
"epoch": 4,
|
||||
"train_loss": 1.1331567962964375,
|
||||
"val_loss": 1.070081114768982,
|
||||
"train_acc": 0.09375,
|
||||
"val_acc": 0.22916666666666666,
|
||||
"train_pnl": 0.010650217327384765,
|
||||
"val_pnl": -0.0007049481907895126,
|
||||
"train_win_rate": 0.49279538904899134,
|
||||
"val_win_rate": 0.40625,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.9036458333333334,
|
||||
"HOLD": 0.09635416666666667
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.3333333333333333,
|
||||
"HOLD": 0.6666666666666666
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:13.739899",
|
||||
"data_age": 3
|
||||
},
|
||||
{
|
||||
"epoch": 5,
|
||||
"train_loss": 1.10965762535731,
|
||||
"val_loss": 1.0485950708389282,
|
||||
"train_acc": 0.12239583333333333,
|
||||
"val_acc": 0.17708333333333334,
|
||||
"train_pnl": 0.011924086862580204,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.5070422535211268,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.7395833333333334,
|
||||
"HOLD": 0.2604166666666667
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:14.073439",
|
||||
"data_age": 3
|
||||
},
|
||||
{
|
||||
"epoch": 6,
|
||||
"train_loss": 1.1272419293721516,
|
||||
"val_loss": 1.084235429763794,
|
||||
"train_acc": 0.1015625,
|
||||
"val_acc": 0.22916666666666666,
|
||||
"train_pnl": 0.014825159601390072,
|
||||
"val_pnl": 0.00405770620151887,
|
||||
"train_win_rate": 0.4908616187989556,
|
||||
"val_win_rate": 0.5157894736842106,
|
||||
"best_position_size": 2.0,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 1.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 1.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:14.658295",
|
||||
"data_age": 4
|
||||
},
|
||||
{
|
||||
"epoch": 7,
|
||||
"train_loss": 1.1171108484268188,
|
||||
"val_loss": 1.0741244554519653,
|
||||
"train_acc": 0.1171875,
|
||||
"val_acc": 0.22916666666666666,
|
||||
"train_pnl": 0.0059474696523706605,
|
||||
"val_pnl": 0.00405770620151887,
|
||||
"train_win_rate": 0.4838709677419355,
|
||||
"val_win_rate": 0.5157894736842106,
|
||||
"best_position_size": 2.0,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.7291666666666666,
|
||||
"HOLD": 0.2708333333333333
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 1.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:15.227593",
|
||||
"data_age": 4
|
||||
}
|
||||
]
|
||||
}
|
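The training-stats files added in this commit are plain JSON, so a short script is enough to inspect a run. A minimal sketch, assuming the file sits at the repository path shown above and uses the keys visible in the dump:

import json

with open("NN/models/saved/realtime_training_stats.json") as f:
    stats = json.load(f)

print("best val PnL:", stats["best_val_pnl"], "at epoch", stats["best_epoch"])
for epoch in stats["epochs"]:
    print(epoch["epoch"], round(epoch["train_loss"], 4), round(epoch["val_loss"], 4))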
@ -112,27 +112,33 @@ class SimpleCNN(nn.Module):
|
||||
def _build_network(self):
|
||||
"""Build the neural network with current feature dimensions"""
|
||||
# Create a flexible architecture that adapts to input dimensions
|
||||
# Increased complexity
|
||||
self.fc_layers = nn.Sequential(
|
||||
nn.Linear(self.feature_dim, 256),
|
||||
nn.Linear(self.feature_dim, 512), # Increased size
|
||||
nn.ReLU(),
|
||||
nn.Linear(256, 256),
|
||||
nn.ReLU()
|
||||
nn.Dropout(0.2), # Added dropout
|
||||
nn.Linear(512, 512), # Increased size
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2), # Added dropout
|
||||
nn.Linear(512, 512), # Added layer
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2) # Added dropout
|
||||
)
|
||||
|
||||
# Output heads (Dueling DQN architecture)
|
||||
self.advantage_head = nn.Linear(256, self.n_actions)
|
||||
self.value_head = nn.Linear(256, 1)
|
||||
self.advantage_head = nn.Linear(512, self.n_actions) # Updated input size
|
||||
self.value_head = nn.Linear(512, 1) # Updated input size
|
||||
|
||||
# Extrema detection head
|
||||
self.extrema_head = nn.Linear(256, 3) # 0=bottom, 1=top, 2=neither
|
||||
self.extrema_head = nn.Linear(512, 3) # 0=bottom, 1=top, 2=neither, Updated input size
|
||||
|
||||
# Price prediction heads for different timeframes
|
||||
self.price_pred_immediate = nn.Linear(256, 3) # Up, Down, Sideways for immediate term (1s, 1m)
|
||||
self.price_pred_midterm = nn.Linear(256, 3) # Up, Down, Sideways for mid-term (1h)
|
||||
self.price_pred_longterm = nn.Linear(256, 3) # Up, Down, Sideways for long-term (1d)
|
||||
self.price_pred_immediate = nn.Linear(512, 3) # Updated input size
|
||||
self.price_pred_midterm = nn.Linear(512, 3) # Updated input size
|
||||
self.price_pred_longterm = nn.Linear(512, 3) # Updated input size
|
||||
|
||||
# Regression heads for exact price prediction
|
||||
self.price_pred_value = nn.Linear(256, 4) # Predicts % change for each timeframe (1s, 1m, 1h, 1d)
|
||||
self.price_pred_value = nn.Linear(512, 4) # Updated input size
|
||||
|
||||
def _check_rebuild_network(self, features):
|
||||
"""Check if network needs to be rebuilt for different feature dimensions"""
|
||||
@ -146,58 +152,70 @@ class SimpleCNN(nn.Module):
|
||||
return False
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the network
|
||||
Returns action values, extrema predictions, and price movement predictions for multiple timeframes
|
||||
"""
|
||||
# Handle different input shapes
|
||||
if len(x.shape) == 2: # [batch_size, features]
|
||||
# Simple feature vector
|
||||
batch_size, features = x.shape
|
||||
# Check if we need to rebuild the network for new dimensions
|
||||
self._check_rebuild_network(features)
|
||||
|
||||
elif len(x.shape) == 3: # [batch_size, timeframes/channels, features]
|
||||
# Reshape to flatten timeframes/channels with features
|
||||
batch_size, timeframes, features = x.shape
|
||||
total_features = timeframes * features
|
||||
|
||||
# Check if we need to rebuild the network for new dimensions
|
||||
self._check_rebuild_network(total_features)
|
||||
|
||||
# Reshape tensor to [batch_size, total_features]
|
||||
x = x.reshape(batch_size, total_features)
|
||||
|
||||
# Apply fully connected layers
|
||||
fc_out = self.fc_layers(x)
|
||||
"""Forward pass through the network"""
|
||||
# Flatten input if needed to ensure it matches the expected feature dimension
|
||||
batch_size = x.size(0)
|
||||
|
||||
# Dueling architecture
|
||||
advantage = self.advantage_head(fc_out)
|
||||
value = self.value_head(fc_out)
|
||||
# Reshape input if needed
|
||||
if len(x.shape) > 2: # Handle multi-dimensional input
|
||||
# For 3D input: [batch, seq_len, features] or [batch, channels, features]
|
||||
x = x.reshape(batch_size, -1) # Flatten to [batch, seq_len*features]
|
||||
|
||||
# Q-values = value + (advantage - mean(advantage))
|
||||
action_values = value + advantage - advantage.mean(dim=1, keepdim=True)
|
||||
# Check if the feature dimension matches and rebuild if necessary
|
||||
if x.size(1) != self.feature_dim:
|
||||
self._check_rebuild_network(x.size(1))
|
||||
|
||||
# Extrema predictions
|
||||
extrema_pred = self.extrema_head(fc_out)
|
||||
# Apply fully connected layers with ReLU activation
|
||||
x = self.fc_layers(x)
|
||||
|
||||
# Price movement predictions for different timeframes
|
||||
price_immediate = self.price_pred_immediate(fc_out) # 1s, 1m
|
||||
price_midterm = self.price_pred_midterm(fc_out) # 1h
|
||||
price_longterm = self.price_pred_longterm(fc_out) # 1d
|
||||
# Branch 1: Action values (Q-values)
|
||||
action_values = self.advantage_head(x)
|
||||
|
||||
# Regression values for exact price predictions (percentage changes)
|
||||
price_values = self.price_pred_value(fc_out)
|
||||
# Branch 2: Extrema detection (market top/bottom classification)
|
||||
extrema_pred = self.extrema_head(x)
|
||||
|
||||
# Return all predictions in a structured dictionary
|
||||
# Branch 3: Price movement prediction over different timeframes
|
||||
# Split into three timeframes: immediate, midterm, longterm
|
||||
price_immediate = self.price_pred_immediate(x)
|
||||
price_midterm = self.price_pred_midterm(x)
|
||||
price_longterm = self.price_pred_longterm(x)
|
||||
|
||||
# Branch 4: Value prediction (regression for expected price changes)
|
||||
price_values = self.price_pred_value(x)
|
||||
|
||||
# Package price predictions
|
||||
price_predictions = {
|
||||
'immediate': price_immediate,
|
||||
'midterm': price_midterm,
|
||||
'longterm': price_longterm,
|
||||
'values': price_values
|
||||
'immediate': price_immediate, # Classification (up/down/sideways)
|
||||
'midterm': price_midterm, # Classification (up/down/sideways)
|
||||
'longterm': price_longterm, # Classification (up/down/sideways)
|
||||
'values': price_values # Regression (expected % change)
|
||||
}
|
||||
|
||||
return action_values, extrema_pred, price_predictions
|
||||
# Return all outputs and the hidden feature representation
|
||||
return action_values, extrema_pred, price_predictions, x
|
||||
|
||||
def extract_features(self, x):
|
||||
"""Extract hidden features from the input and return both action values and features"""
|
||||
# Flatten input if needed to ensure it matches the expected feature dimension
|
||||
batch_size = x.size(0)
|
||||
|
||||
# Reshape input if needed
|
||||
if len(x.shape) > 2: # Handle multi-dimensional input
|
||||
# For 3D input: [batch, seq_len, features] or [batch, channels, features]
|
||||
x = x.reshape(batch_size, -1) # Flatten to [batch, seq_len*features]
|
||||
|
||||
# Check if the feature dimension matches and rebuild if necessary
|
||||
if x.size(1) != self.feature_dim:
|
||||
self._check_rebuild_network(x.size(1))
|
||||
|
||||
# Apply fully connected layers with ReLU activation
|
||||
x_features = self.fc_layers(x)
|
||||
|
||||
# Branch 1: Action values (Q-values)
|
||||
action_values = self.advantage_head(x_features)
|
||||
|
||||
# Return action values and the hidden feature representation
|
||||
return action_values, x_features
|
||||
|
||||
def save(self, path):
|
||||
"""Save model weights and architecture"""
|
||||
@ -241,8 +259,10 @@ class CNNModelPyTorch(nn.Module):
|
||||
self.output_size = output_size
|
||||
self.timeframes = timeframes
|
||||
|
||||
# Calculate total input features across all timeframes
|
||||
self.total_features = num_features * len(timeframes)
|
||||
# num_features should already be the total features across all timeframes
|
||||
self.total_features = num_features
|
||||
logger.info(f"CNNModelPyTorch initialized with window_size={window_size}, num_features={num_features}, "
|
||||
f"total_features={self.total_features}, output_size={output_size}, timeframes={timeframes}")
|
||||
|
||||
# Device configuration
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
@ -317,6 +337,10 @@ class CNNModelPyTorch(nn.Module):
|
||||
# Ensure input is on the correct device
|
||||
x = x.to(self.device)
|
||||
|
||||
# Log input tensor shape for debugging
|
||||
input_shape = x.size()
|
||||
logger.debug(f"Input tensor shape: {input_shape}")
|
||||
|
||||
# Check input dimensions and reshape as needed
|
||||
if len(x.size()) == 2:
|
||||
# If input is [batch_size, features], reshape to [batch_size, features, 1]
|
||||
@ -324,8 +348,17 @@ class CNNModelPyTorch(nn.Module):
|
||||
|
||||
# Check and handle if input features don't match model expectations
|
||||
if feature_dim != self.total_features:
|
||||
logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers")
|
||||
self.rebuild_conv_layers(feature_dim)
|
||||
logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features})")
|
||||
if not hasattr(self, 'rebuild_warning_shown'):
|
||||
logger.error(f"Dimension mismatch: Expected {self.total_features} features but got {feature_dim}")
|
||||
self.rebuild_warning_shown = True
|
||||
# Don't rebuild - instead adapt the input
|
||||
# If features are fewer, pad with zeros. If more, truncate
|
||||
if feature_dim < self.total_features:
|
||||
padding = torch.zeros(batch_size, self.total_features - feature_dim, device=self.device)
|
||||
x = torch.cat([x, padding], dim=1)
|
||||
else:
|
||||
x = x[:, :self.total_features]
|
||||
|
||||
# For 1D input, use a sequence length of 1
|
||||
seq_len = 1
|
||||
@ -336,14 +369,26 @@ class CNNModelPyTorch(nn.Module):
|
||||
|
||||
# Check and handle if input dimensions don't match model expectations
|
||||
if feature_dim != self.total_features:
|
||||
logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers")
|
||||
self.rebuild_conv_layers(feature_dim)
|
||||
logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features})")
|
||||
if not hasattr(self, 'rebuild_warning_shown'):
|
||||
logger.error(f"Dimension mismatch: Expected {self.total_features} features but got {feature_dim}")
|
||||
self.rebuild_warning_shown = True
|
||||
# Don't rebuild - instead adapt the input
|
||||
# If features are fewer, pad with zeros. If more, truncate
|
||||
if feature_dim < self.total_features:
|
||||
padding = torch.zeros(batch_size, seq_len, self.total_features - feature_dim, device=self.device)
|
||||
x = torch.cat([x, padding], dim=2)
|
||||
else:
|
||||
x = x[:, :, :self.total_features]
|
||||
|
||||
# Reshape input: [batch, window_size, features] -> [batch, features, window_size]
|
||||
x = x.permute(0, 2, 1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected input shape: {x.size()}, expected 2D or 3D tensor")
|
||||
|
||||
# Log reshaped tensor for debugging
|
||||
logger.debug(f"Reshaped tensor for convolution: {x.size()}")
|
||||
|
||||
# Convolutional layers with dropout - safely handle small spatial dimensions
|
||||
try:
|
||||
x = self.dropout1(F.relu(self.norm1(self.conv1(x))))
|
||||
|
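The dimension-mismatch handling introduced above no longer rebuilds the convolutional layers; it pads the incoming feature axis with zeros or truncates it to the expected width. The same adaptation written as a standalone helper for clarity (the helper name and shapes are illustrative, not part of the model):

import torch

def adapt_features(x: torch.Tensor, expected: int) -> torch.Tensor:
    """Pad with zeros or truncate the last dimension to `expected` features."""
    current = x.size(-1)
    if current < expected:
        pad_shape = list(x.shape[:-1]) + [expected - current]
        return torch.cat([x, torch.zeros(pad_shape, device=x.device)], dim=-1)
    return x[..., :expected]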
@@ -375,7 +375,7 @@ def realtime(data_interface, model, args, chart=None, symbol=None):
    logger.info(f"Starting real-time inference mode for {symbol}...")

    try:
-       from NN.utils.realtime_analyzer import RealtimeAnalyzer
+       from NN.utils.realtime_analyzer import RealtimeAnalyzer

        # Load the latest model
        model_dir = os.path.join('models')

585  NN/train_enhanced.py  Normal file
@@ -0,0 +1,585 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
import contextlib
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Import our enhanced agent
|
||||
from NN.models.dqn_agent_enhanced import EnhancedDQNAgent
|
||||
from NN.utils.data_interface import DataInterface
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler('logs/enhanced_training.log')
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(description='Train enhanced RL trading agent')
|
||||
parser.add_argument('--episodes', type=int, default=100, help='Number of episodes to train')
|
||||
parser.add_argument('--max-steps', type=int, default=2000, help='Maximum steps per episode')
|
||||
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading symbol')
|
||||
parser.add_argument('--no-gpu', action='store_true', help='Disable GPU usage')
|
||||
parser.add_argument('--confidence', type=float, default=0.4, help='Confidence threshold')
|
||||
parser.add_argument('--load-model', type=str, default='', help='Load existing model')
|
||||
parser.add_argument('--batch-size', type=int, default=128, help='Training batch size')
|
||||
parser.add_argument('--learning-rate', type=float, default=0.0003, help='Learning rate')
|
||||
parser.add_argument('--no-pretrain', action='store_true', help='Skip pre-training')
|
||||
parser.add_argument('--pretrain-epochs', type=int, default=20, help='Number of pre-training epochs')
|
||||
return parser.parse_args()
|
||||
|
||||
def generate_price_prediction_training_data(data_1m, data_1h, data_1d, window_size=20):
|
||||
"""
|
||||
Generate labeled training data for price prediction pre-training
|
||||
|
||||
Args:
|
||||
data_1m: 1-minute candle data
|
||||
data_1h: 1-hour candle data
|
||||
data_1d: 1-day candle data
|
||||
window_size: Size of the observation window
|
||||
|
||||
Returns:
|
||||
X, y_immediate, y_midterm, y_longterm, y_values
|
||||
"""
|
||||
logger.info("Generating price prediction training data")
|
||||
|
||||
# Features to use
|
||||
ohlcv_columns = ['open', 'high', 'low', 'close', 'volume']
|
||||
|
||||
# Create feature sets
|
||||
X = []
|
||||
y_immediate = [] # 1m prediction (next 5min)
|
||||
y_midterm = [] # 1h prediction (next few hours)
|
||||
y_longterm = [] # 1d prediction (next day)
|
||||
y_values = [] # % change for each timeframe
|
||||
|
||||
# Need enough data for all timeframes
|
||||
if len(data_1m) < window_size + 5 or len(data_1h) < 2 or len(data_1d) < 2:
|
||||
logger.error("Not enough data for all timeframes")
|
||||
return np.array([]), np.array([]), np.array([]), np.array([]), np.array([])
|
||||
|
||||
# Generate examples
|
||||
for i in range(window_size, len(data_1m) - 5):
|
||||
# Skip if we can't align with higher timeframes
|
||||
if i % 60 != 0: # Only use minutes that align with hour boundaries
|
||||
continue
|
||||
|
||||
try:
|
||||
# Get window of 1m data as input
|
||||
window_1m = data_1m[i-window_size:i][ohlcv_columns].values
|
||||
|
||||
# Find corresponding indices in higher timeframes
|
||||
curr_timestamp = data_1m.index[i]
|
||||
h_idx = data_1h.index.get_indexer([curr_timestamp], method='nearest')[0]
|
||||
d_idx = data_1d.index.get_indexer([curr_timestamp], method='nearest')[0]
|
||||
|
||||
# Skip if indices are out of bounds
|
||||
if h_idx < 0 or h_idx >= len(data_1h) - 1 or d_idx < 0 or d_idx >= len(data_1d) - 1:
|
||||
continue
|
||||
|
||||
# Get future prices for label generation
|
||||
future_5m = data_1m.iloc[i+5]['close']
future_1h = data_1h.iloc[h_idx+1]['close']
future_1d = data_1d.iloc[d_idx+1]['close']

current_price = data_1m.iloc[i]['close']
|
||||
|
||||
# Calculate % change for each timeframe
|
||||
change_5m = (future_5m - current_price) / current_price * 100
|
||||
change_1h = (future_1h - current_price) / current_price * 100
|
||||
change_1d = (future_1d - current_price) / current_price * 100
|
||||
|
||||
# Determine price direction (0=down, 1=sideways, 2=up)
|
||||
def get_direction(change):
|
||||
if change < -0.5: # Down if less than -0.5%
|
||||
return 0
|
||||
elif change > 0.5: # Up if more than 0.5%
|
||||
return 2
|
||||
else: # Sideways if between -0.5% and 0.5%
|
||||
return 1
|
||||
|
||||
direction_5m = get_direction(change_5m)
|
||||
direction_1h = get_direction(change_1h)
|
||||
direction_1d = get_direction(change_1d)
|
||||
|
||||
# Add to dataset
|
||||
X.append(window_1m.flatten())
|
||||
y_immediate.append(direction_5m)
|
||||
y_midterm.append(direction_1h)
|
||||
y_longterm.append(direction_1d)
|
||||
y_values.append([change_5m, change_1h, change_1d, 0]) # Last value reserved
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error generating training example at index {i}: {str(e)}")
|
||||
|
||||
# Convert to numpy arrays
|
||||
X = np.array(X)
|
||||
y_immediate = np.array(y_immediate)
|
||||
y_midterm = np.array(y_midterm)
|
||||
y_longterm = np.array(y_longterm)
|
||||
y_values = np.array(y_values)
|
||||
|
||||
logger.info(f"Generated {len(X)} training examples")
|
||||
logger.info(f"Class distribution - Immediate: {np.bincount(y_immediate)}, "
|
||||
f"Midterm: {np.bincount(y_midterm)}, Long-term: {np.bincount(y_longterm)}")
|
||||
|
||||
return X, y_immediate, y_midterm, y_longterm, y_values
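A minimal sketch of the labelling convention used above (0=down, 1=sideways, 2=up, with a +/-0.5% band), applied to a few hypothetical percentage changes:

import numpy as np

def direction_label(change_pct, band=0.5):
    # Same rule as get_direction() above: below -band -> 0, above +band -> 2, else 1.
    if change_pct < -band:
        return 0
    if change_pct > band:
        return 2
    return 1

changes = np.array([-1.2, -0.1, 0.3, 0.8])
print([direction_label(c) for c in changes])  # expected: [0, 1, 1, 2]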
|
||||
|
||||
def pretrain_price_prediction(agent, data_interface, n_epochs=20, batch_size=128, device=None):
|
||||
"""
|
||||
Pre-train the price prediction capabilities of the agent
|
||||
|
||||
Args:
|
||||
agent: EnhancedDQNAgent instance
|
||||
data_interface: DataInterface instance
|
||||
n_epochs: Number of pre-training epochs
|
||||
batch_size: Batch size for pre-training
|
||||
device: Device to use for pre-training
|
||||
|
||||
Returns:
|
||||
The pre-trained agent
|
||||
"""
|
||||
logger.info("Starting price prediction pre-training")
|
||||
|
||||
try:
|
||||
# Ensure we have the necessary timeframes
|
||||
timeframes_needed = ['1m', '1h', '1d']
|
||||
for tf in timeframes_needed:
|
||||
if tf not in data_interface.timeframes:
|
||||
logger.info(f"Adding timeframe {tf} for pre-training")
|
||||
# Add timeframe to the list if not present
|
||||
if tf not in data_interface.timeframes:
|
||||
data_interface.timeframes.append(tf)
|
||||
data_interface.dataframes[tf] = None
|
||||
|
||||
# Get data for each timeframe
|
||||
data_1m = data_interface.get_historical_data(timeframe='1m')
|
||||
data_1h = data_interface.get_historical_data(timeframe='1h')
|
||||
data_1d = data_interface.get_historical_data(timeframe='1d')
|
||||
|
||||
# Generate labeled training data
|
||||
X, y_immediate, y_midterm, y_longterm, y_values = generate_price_prediction_training_data(
|
||||
data_1m, data_1h, data_1d, window_size=20
|
||||
)
|
||||
|
||||
if len(X) == 0:
|
||||
logger.error("No training examples generated. Skipping pre-training.")
|
||||
return agent
|
||||
|
||||
# Split data into training and validation sets
|
||||
X_train, X_val, y_imm_train, y_imm_val, y_mid_train, y_mid_val, y_long_train, y_long_val, y_val_train, y_val_val = train_test_split(
|
||||
X, y_immediate, y_midterm, y_longterm, y_values, test_size=0.2, random_state=42
|
||||
)
|
||||
|
||||
# Convert to torch tensors
|
||||
X_train_tensor = torch.FloatTensor(X_train).to(device)
|
||||
y_imm_train_tensor = torch.LongTensor(y_imm_train).to(device)
|
||||
y_mid_train_tensor = torch.LongTensor(y_mid_train).to(device)
|
||||
y_long_train_tensor = torch.LongTensor(y_long_train).to(device)
|
||||
y_val_train_tensor = torch.FloatTensor(y_val_train).to(device)
|
||||
|
||||
X_val_tensor = torch.FloatTensor(X_val).to(device)
|
||||
y_imm_val_tensor = torch.LongTensor(y_imm_val).to(device)
|
||||
y_mid_val_tensor = torch.LongTensor(y_mid_val).to(device)
|
||||
y_long_val_tensor = torch.LongTensor(y_long_val).to(device)
|
||||
y_val_val_tensor = torch.FloatTensor(y_val_val).to(device)
|
||||
|
||||
# Calculate class weights for imbalanced data
|
||||
def get_class_weights(labels):
|
||||
counts = np.bincount(labels)
|
||||
if len(counts) < 3: # Ensure we have 3 classes
|
||||
counts = np.append(counts, [0] * (3 - len(counts)))
|
||||
weights = 1.0 / np.array(counts)
|
||||
weights = weights / np.sum(weights) # Normalize
|
||||
return weights
|
||||
|
||||
imm_weights = torch.FloatTensor(get_class_weights(y_imm_train)).to(device)
|
||||
mid_weights = torch.FloatTensor(get_class_weights(y_mid_train)).to(device)
|
||||
long_weights = torch.FloatTensor(get_class_weights(y_long_train)).to(device)
|
||||
|
||||
# Create DataLoader for batch training
|
||||
train_dataset = TensorDataset(
|
||||
X_train_tensor, y_imm_train_tensor, y_mid_train_tensor,
|
||||
y_long_train_tensor, y_val_train_tensor
|
||||
)
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
# Set up loss functions with class weights
|
||||
imm_criterion = nn.CrossEntropyLoss(weight=imm_weights)
|
||||
mid_criterion = nn.CrossEntropyLoss(weight=mid_weights)
|
||||
long_criterion = nn.CrossEntropyLoss(weight=long_weights)
|
||||
value_criterion = nn.MSELoss()
|
||||
|
||||
# Set up optimizer (separate from agent's optimizer)
|
||||
pretrain_optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=0.0002)
|
||||
pretrain_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
pretrain_optimizer, mode='min', factor=0.5, patience=3, verbose=True
|
||||
)
|
||||
|
||||
# Set model to training mode
|
||||
agent.policy_net.train()
|
||||
|
||||
# Training loop
|
||||
best_val_loss = float('inf')
|
||||
patience = 5
|
||||
patience_counter = 0
|
||||
|
||||
# Create TensorBoard writer for pre-training
|
||||
writer = SummaryWriter(log_dir=f'runs/pretrain_{int(time.time())}')
|
||||
|
||||
for epoch in range(n_epochs):
|
||||
# Training phase
|
||||
train_loss = 0.0
|
||||
imm_correct, mid_correct, long_correct = 0, 0, 0
|
||||
total = 0
|
||||
|
||||
for X_batch, y_imm_batch, y_mid_batch, y_long_batch, y_val_batch in train_loader:
|
||||
# Zero gradients
|
||||
pretrain_optimizer.zero_grad()
|
||||
|
||||
# Forward pass
|
||||
with torch.cuda.amp.autocast() if agent.use_mixed_precision else contextlib.nullcontext():
|
||||
q_values, _, price_preds, _ = agent.policy_net(X_batch)
|
||||
|
||||
# Calculate losses for each prediction head
|
||||
imm_loss = imm_criterion(price_preds['immediate'], y_imm_batch)
|
||||
mid_loss = mid_criterion(price_preds['midterm'], y_mid_batch)
|
||||
long_loss = long_criterion(price_preds['longterm'], y_long_batch)
|
||||
value_loss = value_criterion(price_preds['values'], y_val_batch)
|
||||
|
||||
# Combined loss (weighted by importance)
|
||||
total_loss = imm_loss + 0.7 * mid_loss + 0.5 * long_loss + 0.3 * value_loss
|
||||
|
||||
# Backward pass and optimize
|
||||
if agent.use_mixed_precision:
|
||||
agent.scaler.scale(total_loss).backward()
|
||||
agent.scaler.unscale_(pretrain_optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
|
||||
agent.scaler.step(pretrain_optimizer)
|
||||
agent.scaler.update()
|
||||
else:
|
||||
total_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
|
||||
pretrain_optimizer.step()
|
||||
|
||||
# Accumulate metrics
|
||||
train_loss += total_loss.item()
|
||||
total += X_batch.size(0)
|
||||
|
||||
# Calculate accuracy
|
||||
_, imm_pred = torch.max(price_preds['immediate'], 1)
|
||||
_, mid_pred = torch.max(price_preds['midterm'], 1)
|
||||
_, long_pred = torch.max(price_preds['longterm'], 1)
|
||||
|
||||
imm_correct += (imm_pred == y_imm_batch).sum().item()
|
||||
mid_correct += (mid_pred == y_mid_batch).sum().item()
|
||||
long_correct += (long_pred == y_long_batch).sum().item()
|
||||
|
||||
# Calculate epoch metrics
|
||||
train_loss /= len(train_loader)
|
||||
imm_acc = imm_correct / total
|
||||
mid_acc = mid_correct / total
|
||||
long_acc = long_correct / total
|
||||
|
||||
# Validation phase
|
||||
agent.policy_net.eval()
|
||||
val_loss = 0.0
|
||||
imm_val_correct, mid_val_correct, long_val_correct = 0, 0, 0
|
||||
|
||||
with torch.no_grad():
|
||||
# Forward pass on validation data
|
||||
q_values, _, val_price_preds, _ = agent.policy_net(X_val_tensor)
|
||||
|
||||
# Calculate validation losses
|
||||
val_imm_loss = imm_criterion(val_price_preds['immediate'], y_imm_val_tensor)
|
||||
val_mid_loss = mid_criterion(val_price_preds['midterm'], y_mid_val_tensor)
|
||||
val_long_loss = long_criterion(val_price_preds['longterm'], y_long_val_tensor)
|
||||
val_value_loss = value_criterion(val_price_preds['values'], y_val_val_tensor)
|
||||
|
||||
val_total_loss = val_imm_loss + 0.7 * val_mid_loss + 0.5 * val_long_loss + 0.3 * val_value_loss
|
||||
val_loss = val_total_loss.item()
|
||||
|
||||
# Calculate validation accuracy
|
||||
_, imm_val_pred = torch.max(val_price_preds['immediate'], 1)
|
||||
_, mid_val_pred = torch.max(val_price_preds['midterm'], 1)
|
||||
_, long_val_pred = torch.max(val_price_preds['longterm'], 1)
|
||||
|
||||
imm_val_correct = (imm_val_pred == y_imm_val_tensor).sum().item()
|
||||
mid_val_correct = (mid_val_pred == y_mid_val_tensor).sum().item()
|
||||
long_val_correct = (long_val_pred == y_long_val_tensor).sum().item()
|
||||
|
||||
imm_val_acc = imm_val_correct / len(X_val_tensor)
|
||||
mid_val_acc = mid_val_correct / len(X_val_tensor)
|
||||
long_val_acc = long_val_correct / len(X_val_tensor)
|
||||
|
||||
# Log to TensorBoard
|
||||
writer.add_scalar('pretrain/train_loss', train_loss, epoch)
|
||||
writer.add_scalar('pretrain/val_loss', val_loss, epoch)
|
||||
writer.add_scalar('pretrain/imm_acc', imm_acc, epoch)
|
||||
writer.add_scalar('pretrain/mid_acc', mid_acc, epoch)
|
||||
writer.add_scalar('pretrain/long_acc', long_acc, epoch)
|
||||
writer.add_scalar('pretrain/imm_val_acc', imm_val_acc, epoch)
|
||||
writer.add_scalar('pretrain/mid_val_acc', mid_val_acc, epoch)
|
||||
writer.add_scalar('pretrain/long_val_acc', long_val_acc, epoch)
|
||||
|
||||
# Learning rate scheduling
|
||||
pretrain_scheduler.step(val_loss)
|
||||
|
||||
# Early stopping check
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
patience_counter = 0
|
||||
# Copy policy_net weights to target_net
|
||||
agent.target_net.load_state_dict(agent.policy_net.state_dict())
|
||||
logger.info(f"Saved best model with validation loss: {val_loss:.4f}")
|
||||
# Save pre-trained model
|
||||
agent.save("NN/models/saved/enhanced_dqn_pretrained")
|
||||
else:
|
||||
patience_counter += 1
|
||||
if patience_counter >= patience:
|
||||
logger.info(f"Early stopping triggered after {epoch+1} epochs")
|
||||
break
|
||||
|
||||
# Log progress
|
||||
logger.info(f"Epoch {epoch+1}/{n_epochs}: "
|
||||
f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
|
||||
f"Imm Acc: {imm_acc:.4f}/{imm_val_acc:.4f}, "
|
||||
f"Mid Acc: {mid_acc:.4f}/{mid_val_acc:.4f}, "
|
||||
f"Long Acc: {long_acc:.4f}/{long_val_acc:.4f}")
|
||||
|
||||
# Set model back to training mode for next epoch
|
||||
agent.policy_net.train()
|
||||
|
||||
writer.close()
|
||||
logger.info("Price prediction pre-training complete")
|
||||
return agent
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during price prediction pre-training: {str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return agent
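The combined pre-training objective above weights the three direction heads and the value head as 1.0 / 0.7 / 0.5 / 0.3. A toy sketch with random tensors makes the weighting concrete (shapes are assumptions for illustration; the real outputs come from agent.policy_net):

import torch
import torch.nn as nn

batch = 8
imm_logits = torch.randn(batch, 3)   # 3 classes: down / sideways / up
mid_logits = torch.randn(batch, 3)
long_logits = torch.randn(batch, 3)
value_out = torch.randn(batch, 4)    # [5m, 1h, 1d, reserved] % changes

y_imm = torch.randint(0, 3, (batch,))
y_mid = torch.randint(0, 3, (batch,))
y_long = torch.randint(0, 3, (batch,))
y_vals = torch.randn(batch, 4)

ce, mse = nn.CrossEntropyLoss(), nn.MSELoss()
total = ce(imm_logits, y_imm) + 0.7 * ce(mid_logits, y_mid) \
    + 0.5 * ce(long_logits, y_long) + 0.3 * mse(value_out, y_vals)
print(total.item())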
|
||||
|
||||
def train_enhanced_rl(args):
|
||||
"""
|
||||
Train the enhanced RL agent for trading
|
||||
|
||||
Args:
|
||||
args: Command line arguments
|
||||
"""
|
||||
# Setup device
|
||||
if args.no_gpu:
|
||||
device = torch.device('cpu')
|
||||
else:
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Set up data interface
|
||||
data_interface = DataInterface(symbol=args.symbol, timeframes=['1m', '5m', '15m'])
|
||||
|
||||
# Fetch historical data for each timeframe
|
||||
for timeframe in data_interface.timeframes:
|
||||
df = data_interface.get_historical_data(timeframe=timeframe)
|
||||
logger.info(f"Using data for {args.symbol} {timeframe} ({len(data_interface.dataframes[timeframe])} candles)")
|
||||
|
||||
# Create environment for training
|
||||
from NN.environments.trading_env import TradingEnvironment
|
||||
window_size = 20
|
||||
train_env = TradingEnvironment(
|
||||
data_interface=data_interface,
|
||||
initial_balance=10000.0,
|
||||
transaction_fee=0.0002,
|
||||
window_size=window_size,
|
||||
max_position=1.0,
|
||||
reward_scaling=100.0
|
||||
)
|
||||
|
||||
# Create agent with improved parameters
|
||||
state_shape = train_env.observation_space.shape
|
||||
n_actions = train_env.action_space.n
|
||||
|
||||
agent = EnhancedDQNAgent(
|
||||
state_shape=state_shape,
|
||||
n_actions=n_actions,
|
||||
learning_rate=args.learning_rate,
|
||||
gamma=0.95,
|
||||
epsilon=1.0,
|
||||
epsilon_min=0.05,
|
||||
epsilon_decay=0.995,
|
||||
buffer_size=50000,
|
||||
batch_size=args.batch_size,
|
||||
target_update=10,
|
||||
confidence_threshold=args.confidence,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Load existing model if specified
|
||||
if args.load_model:
|
||||
model_path = args.load_model
|
||||
if agent.load(model_path):
|
||||
logger.info(f"Loaded existing model from {model_path}")
|
||||
else:
|
||||
logger.error(f"Error loading model from {model_path}")
|
||||
|
||||
# Pre-training for price prediction
|
||||
if not args.no_pretrain and not args.load_model:
|
||||
logger.info("Starting pre-training phase")
|
||||
agent = pretrain_price_prediction(
|
||||
agent=agent,
|
||||
data_interface=data_interface,
|
||||
n_epochs=args.pretrain_epochs,
|
||||
batch_size=args.batch_size,
|
||||
device=device
|
||||
)
|
||||
logger.info("Pre-training completed")
|
||||
|
||||
# Setup TensorBoard
|
||||
writer = SummaryWriter(log_dir=f'runs/enhanced_rl_{int(time.time())}')
|
||||
|
||||
# Log hardware info
|
||||
writer.add_text("hardware/device", str(device), 0)
|
||||
if torch.cuda.is_available():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
writer.add_text(f"hardware/gpu_{i}", torch.cuda.get_device_name(i), 0)
|
||||
|
||||
# Move agent to device
|
||||
agent.move_models_to_device(device)
|
||||
|
||||
# Training loop
|
||||
logger.info(f"Starting enhanced training for {args.episodes} episodes")
|
||||
|
||||
total_rewards = []
|
||||
episode_losses = []
|
||||
trade_win_rates = []
|
||||
best_reward = -np.inf
|
||||
|
||||
try:
|
||||
for episode in range(args.episodes):
|
||||
# Reset environment for new episode
|
||||
state = train_env.reset()
|
||||
total_reward = 0.0
|
||||
done = False
|
||||
step = 0
|
||||
episode_start_time = time.time()
|
||||
|
||||
# Track trade statistics
|
||||
trades = []
|
||||
wins = 0
|
||||
losses = 0
|
||||
|
||||
# Run episode
|
||||
while not done and step < args.max_steps:
|
||||
# Choose action
|
||||
action, confidence = agent.act(state)
|
||||
|
||||
# Take action in environment
|
||||
next_state, reward, done, info = train_env.step(action)
|
||||
|
||||
# Remember experience
|
||||
agent.remember(state, action, reward, next_state, done)
|
||||
|
||||
# Track trade results
|
||||
if 'trade_result' in info and info['trade_result'] is not None:
|
||||
trade_result = info['trade_result']
|
||||
trade_pnl = trade_result['pnl']
|
||||
trades.append(trade_pnl)
|
||||
|
||||
if trade_pnl > 0:
|
||||
wins += 1
|
||||
logger.info(f"Profitable trade! {trade_pnl:.2f}% profit, reward: {reward:.4f}")
|
||||
else:
|
||||
losses += 1
|
||||
logger.info(f"Loss trade! {trade_pnl:.2f}% loss, penalty: {reward:.4f}")
|
||||
|
||||
# Update state and counters
|
||||
state = next_state
|
||||
total_reward += reward
|
||||
step += 1
|
||||
|
||||
# Train agent
|
||||
loss = agent.replay()
if loss is not None and loss > 0:
episode_losses.append(loss)
|
||||
|
||||
# Log training metrics for each episode
|
||||
episode_time = time.time() - episode_start_time
|
||||
total_rewards.append(total_reward)
|
||||
|
||||
# Calculate win rate
|
||||
win_rate = wins / max(1, (wins + losses))
|
||||
trade_win_rates.append(win_rate)
|
||||
|
||||
# Log to console and TensorBoard
|
||||
logger.info(f"Episode {episode}/{args.episodes} - Reward: {total_reward:.4f}, Win Rate: {win_rate:.2f}, "
|
||||
f"Trades: {len(trades)}, Balance: ${train_env.balance:.2f}, Epsilon: {agent.epsilon:.4f}, "
|
||||
f"Time: {episode_time:.2f}s")
|
||||
|
||||
writer.add_scalar('metrics/reward', total_reward, episode)
|
||||
writer.add_scalar('metrics/balance', train_env.balance, episode)
|
||||
writer.add_scalar('metrics/win_rate', win_rate, episode)
|
||||
writer.add_scalar('metrics/trades', len(trades), episode)
|
||||
writer.add_scalar('metrics/epsilon', agent.epsilon, episode)
|
||||
|
||||
if episode_losses:
|
||||
avg_loss = sum(episode_losses) / len(episode_losses)
|
||||
writer.add_scalar('metrics/loss', avg_loss, episode)
|
||||
|
||||
# Check if this is the best model so far
|
||||
if total_reward > best_reward:
|
||||
best_reward = total_reward
|
||||
# Save best model
|
||||
agent.save(f"NN/models/saved/enhanced_dqn_best")
|
||||
logger.info(f"New best model saved with reward: {best_reward:.4f}")
|
||||
|
||||
# Save checkpoint every 10 episodes
|
||||
if episode % 10 == 0 and episode > 0:
|
||||
agent.save(f"NN/models/saved/enhanced_dqn_checkpoint")
|
||||
logger.info(f"Checkpoint saved at episode {episode}")
|
||||
|
||||
# Reset episode losses
|
||||
episode_losses = []
|
||||
|
||||
# Final save
|
||||
agent.save(f"NN/models/saved/enhanced_dqn_final")
|
||||
logger.info("Enhanced training completed, final model saved")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Training interrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed: {str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
# Close TensorBoard writer
|
||||
writer.close()
|
||||
|
||||
return agent, train_env
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create logs directory if it doesn't exist
|
||||
os.makedirs("logs", exist_ok=True)
|
||||
os.makedirs("NN/models/saved", exist_ok=True)
|
||||
|
||||
# Parse arguments
|
||||
args = parse_args()
|
||||
|
||||
# Start training
|
||||
train_enhanced_rl(args)
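Training and pre-training metrics are written under runs/ by the SummaryWriter instances above; if TensorBoard is installed they can be inspected with, for example:

# tensorboard --logdir runs --port 6006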
|
836
NN/train_rl.py
@@ -52,7 +52,7 @@ class RLTradingEnvironment(gym.Env):
|
||||
Reinforcement Learning environment for trading with technical indicators
|
||||
from multiple timeframes
|
||||
"""
|
||||
def __init__(self, features_1m, features_1h=None, features_1d=None, window_size=20, trading_fee=0.0025, min_trade_interval=15):
|
||||
def __init__(self, features_1m, features_1h, features_1d, window_size=20, trading_fee=0.0025, min_trade_interval=15):
|
||||
super().__init__()
|
||||
|
||||
# Initialize attributes before parent class
|
||||
@@ -60,12 +60,7 @@ class RLTradingEnvironment(gym.Env):
|
||||
self.num_features = features_1m.shape[1] - 1 # Exclude close price
|
||||
|
||||
# Count available timeframes
|
||||
self.num_timeframes = 1 # Always have 1m
|
||||
if features_1h is not None:
|
||||
self.num_timeframes += 1
|
||||
if features_1d is not None:
|
||||
self.num_timeframes += 1
|
||||
|
||||
self.num_timeframes = 3 # We require all timeframes now
|
||||
self.feature_dim = self.num_features * self.num_timeframes
|
||||
|
||||
# Store features from different timeframes
|
||||
@@ -73,16 +68,6 @@ class RLTradingEnvironment(gym.Env):
|
||||
self.features_1h = features_1h
|
||||
self.features_1d = features_1d
|
||||
|
||||
# Create synthetic 1s data from 1m (for demo purposes)
|
||||
self.features_1s = self._create_synthetic_1s_data(features_1m)
|
||||
|
||||
# If higher timeframes are missing, create synthetic data
|
||||
if self.features_1h is None:
|
||||
self.features_1h = self._create_synthetic_hourly_data(features_1m)
|
||||
|
||||
if self.features_1d is None:
|
||||
self.features_1d = self._create_synthetic_daily_data(features_1h)
|
||||
|
||||
# Trading parameters
|
||||
self.initial_balance = 1.0
|
||||
self.trading_fee = trading_fee # Increased from 0.001 to 0.0025 (0.25%)
|
||||
@@ -103,45 +88,6 @@ class RLTradingEnvironment(gym.Env):
|
||||
# Callback for visualization or external monitoring
|
||||
self.action_callback = None
|
||||
|
||||
def _create_synthetic_1s_data(self, features_1m):
|
||||
"""Create synthetic 1-second data from 1-minute data"""
|
||||
# Simple approach: duplicate each 1m candle for 60 seconds with some noise
|
||||
num_samples = features_1m.shape[0]
|
||||
synthetic_1s = np.zeros((num_samples * 60, features_1m.shape[1]))
|
||||
|
||||
for i in range(num_samples):
|
||||
for j in range(60):
|
||||
idx = i * 60 + j
|
||||
if idx < synthetic_1s.shape[0]:
|
||||
# Copy the 1m data with small random noise
|
||||
synthetic_1s[idx] = features_1m[i] * (1 + np.random.normal(0, 0.0001, features_1m.shape[1]))
|
||||
|
||||
return synthetic_1s
|
||||
|
||||
def _create_synthetic_hourly_data(self, features_1m):
|
||||
"""Create synthetic hourly data from minute data"""
|
||||
# Group by hour, taking every 60th candle
|
||||
num_samples = features_1m.shape[0] // 60
|
||||
synthetic_1h = np.zeros((num_samples, features_1m.shape[1]))
|
||||
|
||||
for i in range(num_samples):
|
||||
if i * 60 < features_1m.shape[0]:
|
||||
synthetic_1h[i] = features_1m[i * 60]
|
||||
|
||||
return synthetic_1h
|
||||
|
||||
def _create_synthetic_daily_data(self, features_1h):
|
||||
"""Create synthetic daily data from hourly data"""
|
||||
# Group by day, taking every 24th candle
|
||||
num_samples = features_1h.shape[0] // 24
|
||||
synthetic_1d = np.zeros((num_samples, features_1h.shape[1]))
|
||||
|
||||
for i in range(num_samples):
|
||||
if i * 24 < features_1h.shape[0]:
|
||||
synthetic_1d[i] = features_1h[i * 24]
|
||||
|
||||
return synthetic_1d
|
||||
|
||||
def reset(self):
|
||||
"""Reset the environment to initial state"""
|
||||
self.balance = self.initial_balance
|
||||
@@ -208,161 +154,242 @@ class RLTradingEnvironment(gym.Env):
|
||||
return combined_features
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
Take an action in the environment and return the next state, reward, done flag, and info
|
||||
"""Take an action and return the next state, reward, done flag, and info"""
|
||||
# Initialize info dictionary for additional data
|
||||
info = {
|
||||
'trade_executed': False,
|
||||
'price_change': 0.0,
|
||||
'position_change': 0,
|
||||
'current_price': 0.0,
|
||||
'next_price': 0.0,
|
||||
'balance_change': 0.0,
|
||||
'reward_components': {},
|
||||
'future_prices': {}
|
||||
}
|
||||
|
||||
Args:
|
||||
action (int): 0 = Buy, 1 = Sell, 2 = Hold
|
||||
|
||||
Returns:
|
||||
tuple: (observation, reward, done, info)
|
||||
"""
|
||||
# Get current and next price
|
||||
current_price = self.features_1m[self.current_step, -1] # Close price is last column
|
||||
# Get the current and next price
|
||||
current_price = self.features_1m[self.current_step, -1]
|
||||
|
||||
# Check if we're at the end of the data
|
||||
if self.current_step + 1 >= len(self.features_1m):
|
||||
next_price = current_price # Use current price if at the end
|
||||
# Handle edge case at the end of the data
|
||||
if self.current_step >= len(self.features_1m) - 1:
|
||||
next_price = current_price # Use current price as next price
|
||||
done = True
|
||||
else:
|
||||
next_price = self.features_1m[self.current_step + 1, -1]
|
||||
done = False
|
||||
|
||||
# Handle zero or negative prices
|
||||
# Handle zero or negative price (data error)
|
||||
if current_price <= 0:
|
||||
current_price = 1e-8 # Small positive number
|
||||
current_price = 0.01 # Set to a small positive number
|
||||
logger.warning(f"Zero or negative price detected at step {self.current_step}. Setting to 0.01.")
|
||||
|
||||
if next_price <= 0:
|
||||
next_price = current_price # Use current price if next price is invalid
|
||||
|
||||
price_change = (next_price - current_price) / current_price
|
||||
next_price = current_price # Use current price instead
|
||||
logger.warning(f"Zero or negative next price detected at step {self.current_step + 1}. Using current price.")
|
||||
|
||||
# Default reward is slightly negative to discourage inaction
|
||||
reward = -0.0001
|
||||
profit_pct = None # Initialize profit_pct variable
|
||||
# Calculate price change as percentage
|
||||
price_change_pct = ((next_price - current_price) / current_price) * 100
|
||||
|
||||
# Check if enough time has passed since last trade
|
||||
trade_interval = self.current_step - self.last_trade_step
|
||||
trade_interval_penalty = 0
|
||||
# Store prices in info
|
||||
info['current_price'] = current_price
|
||||
info['next_price'] = next_price
|
||||
info['price_change'] = price_change_pct
|
||||
|
||||
# Execute action
|
||||
if action == 0: # BUY
|
||||
if self.position == 0: # Only buy if not already in position
|
||||
# Apply extra penalty for trading too frequently
|
||||
if trade_interval < self.min_trade_interval:
|
||||
trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)
|
||||
# Still allow the trade but with penalty
|
||||
# Initialize reward components dictionary
|
||||
reward_components = {
|
||||
'holding_reward': 0.0,
|
||||
'action_reward': 0.0,
|
||||
'profit_reward': 0.0,
|
||||
'trade_freq_penalty': 0.0
|
||||
}
|
||||
|
||||
# Default small negative reward to discourage inaction
|
||||
reward = -0.01
|
||||
reward_components['holding_reward'] = -0.01
|
||||
|
||||
# Track previous balance for changes
|
||||
previous_balance = self.balance
|
||||
|
||||
# Execute action (0: Buy, 1: Sell, 2: Hold)
|
||||
if action == 0: # Buy
|
||||
if self.position == 0: # Only buy if we don't already have a position
|
||||
# Calculate how much of the asset we can buy with 100% of balance
|
||||
self.position = self.balance / current_price
|
||||
self.balance = 0 # All balance used
|
||||
|
||||
self.position = self.balance * (1 - self.trading_fee)
|
||||
self.balance = 0
|
||||
self.trades += 1
|
||||
reward = -0.001 + trade_interval_penalty # Small cost for transaction + potential penalty
|
||||
self.trade_entry_price = current_price
|
||||
self.last_trade_step = self.current_step
|
||||
|
||||
elif action == 1: # SELL
|
||||
if self.position > 0: # Only sell if in position
|
||||
# Apply extra penalty for trading too frequently
|
||||
if trade_interval < self.min_trade_interval:
|
||||
trade_interval_penalty = -0.002 * (self.min_trade_interval - trade_interval)
|
||||
# Still allow the trade but with penalty
|
||||
|
||||
# Calculate position value at current price
|
||||
position_value = self.position * (1 + price_change)
|
||||
self.balance = position_value * (1 - self.trading_fee)
|
||||
|
||||
# Calculate profit/loss from trade
|
||||
profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price
|
||||
# Scale reward by profit percentage and apply trade interval penalty
|
||||
reward = (profit_pct * 10) + trade_interval_penalty
|
||||
|
||||
# Update win/loss count
|
||||
if profit_pct > 0:
|
||||
self.wins += 1
|
||||
# If price goes up after buying, that's good
|
||||
expected_profit = price_change_pct
|
||||
# Scale reward based on expected profit
|
||||
if expected_profit > 0:
|
||||
# Positive reward for profitable buy decision
|
||||
action_reward = 0.1 + (expected_profit * 0.05) # Base reward + profit-based bonus
|
||||
reward_components['action_reward'] = action_reward
|
||||
reward += action_reward
|
||||
else:
|
||||
self.losses += 1
|
||||
# Small negative reward for unprofitable buy
|
||||
action_reward = -0.1 + (expected_profit * 0.03) # Smaller penalty for small losses
|
||||
reward_components['action_reward'] = action_reward
|
||||
reward += action_reward
|
||||
|
||||
# Record trade
|
||||
# Check if we've traded too frequently
|
||||
if len(self.trade_history) > 0:
|
||||
last_trade_step = self.trade_history[-1]['step']
|
||||
if self.current_step - last_trade_step < 5: # If less than 5 steps since last trade
|
||||
freq_penalty = -0.2 # Penalty for trading too frequently
|
||||
reward += freq_penalty
|
||||
reward_components['trade_freq_penalty'] = freq_penalty
|
||||
|
||||
# Record the trade
|
||||
self.trade_history.append({
|
||||
'entry_price': self.trade_entry_price,
|
||||
'exit_price': next_price,
|
||||
'profit_pct': profit_pct,
|
||||
'trade_interval': trade_interval
|
||||
'step': self.current_step,
|
||||
'action': 'buy',
|
||||
'price': current_price,
|
||||
'position': self.position,
|
||||
'balance': self.balance
|
||||
})
|
||||
|
||||
# Reset position and update last trade step
|
||||
info['trade_executed'] = True
|
||||
logger.info(f"Buy at step {self.current_step}, price: {current_price:.4f}, position: {self.position:.6f}")
|
||||
|
||||
elif action == 1: # Sell
|
||||
if self.position > 0: # Only sell if we have a position
|
||||
# Calculate sale proceeds
|
||||
sale_value = self.position * current_price
|
||||
|
||||
# Calculate profit or loss percentage from last buy
|
||||
last_buy_price = None
|
||||
for trade in reversed(self.trade_history):
|
||||
if trade['action'] == 'buy':
|
||||
last_buy_price = trade['price']
|
||||
break
|
||||
|
||||
# If we found the last buy price, calculate profit
|
||||
if last_buy_price is not None:
|
||||
profit_pct = ((current_price - last_buy_price) / last_buy_price) * 100
|
||||
|
||||
# Highly reward profitable trades
|
||||
if profit_pct > 0:
|
||||
# Progressive reward based on profit percentage
|
||||
profit_reward = min(5.0, profit_pct * 0.2) # Cap at 5.0 to prevent exploitation
|
||||
reward_components['profit_reward'] = profit_reward
|
||||
reward += profit_reward
|
||||
logger.info(f"Profitable trade! {profit_pct:.2f}% profit, reward: {profit_reward:.4f}")
|
||||
else:
|
||||
# Penalize losses more heavily based on size of loss
|
||||
loss_penalty = max(-3.0, profit_pct * 0.15) # Cap at -3.0 to prevent excessive punishment
|
||||
reward_components['profit_reward'] = loss_penalty
|
||||
reward += loss_penalty
|
||||
logger.info(f"Loss trade! {profit_pct:.2f}% loss, penalty: {loss_penalty:.4f}")
|
||||
|
||||
# If price goes down after selling, that's good
|
||||
if price_change_pct < 0:
|
||||
# Reward for good timing on sell (avoiding future loss)
|
||||
timing_reward = min(1.0, abs(price_change_pct) * 0.05)
|
||||
reward_components['action_reward'] = timing_reward
|
||||
reward += timing_reward
|
||||
|
||||
# Check for trading too frequently
|
||||
if len(self.trade_history) > 0:
|
||||
last_trade_step = self.trade_history[-1]['step']
|
||||
if self.current_step - last_trade_step < 5: # If less than 5 steps since last trade
|
||||
freq_penalty = -0.2 # Penalty for trading too frequently
|
||||
reward += freq_penalty
|
||||
reward_components['trade_freq_penalty'] = freq_penalty
|
||||
|
||||
# Update balance and position
|
||||
self.balance = sale_value
|
||||
position_change = self.position
|
||||
self.position = 0
|
||||
self.last_trade_step = self.current_step
|
||||
|
||||
# Record the trade
|
||||
self.trade_history.append({
|
||||
'step': self.current_step,
|
||||
'action': 'sell',
|
||||
'price': current_price,
|
||||
'position': self.position,
|
||||
'balance': self.balance
|
||||
})
|
||||
|
||||
info['trade_executed'] = True
|
||||
info['position_change'] = position_change
|
||||
logger.info(f"Sell at step {self.current_step}, price: {current_price:.4f}, new balance: {self.balance:.4f}")
|
||||
|
||||
elif action == 2: # Hold
|
||||
# Small reward if holding was a good decision
|
||||
if self.position > 0 and price_change_pct > 0: # Holding long position during price increase
|
||||
hold_reward = price_change_pct * 0.01 # Small reward proportional to price increase
|
||||
reward += hold_reward
|
||||
reward_components['holding_reward'] = hold_reward
|
||||
elif self.position == 0 and price_change_pct < 0: # Holding cash during price decrease
|
||||
hold_reward = abs(price_change_pct) * 0.01 # Small reward for avoiding loss
|
||||
reward += hold_reward
|
||||
reward_components['holding_reward'] = hold_reward
|
||||
|
||||
# else: (action == 2 - HOLD) - no position change
|
||||
|
||||
# Move to next step
|
||||
# Move to the next step
|
||||
self.current_step += 1
|
||||
|
||||
# Check if done (reached end of data)
|
||||
# Update current portfolio value
|
||||
if self.position > 0:
|
||||
self.current_value = self.balance + (self.position * next_price)
|
||||
else:
|
||||
self.current_value = self.balance
|
||||
|
||||
# Calculate balance change
|
||||
balance_change = self.current_value - previous_balance
|
||||
info['balance_change'] = balance_change
|
||||
|
||||
# Check if we've reached the end of the data
|
||||
if self.current_step >= len(self.features_1m) - 1:
|
||||
done = True
|
||||
|
||||
# Apply final evaluation
|
||||
# Final evaluation if we have a position
|
||||
if self.position > 0:
|
||||
# Force close position at the end
|
||||
position_value = self.position * (1 + price_change)
|
||||
self.balance = position_value * (1 - self.trading_fee)
|
||||
profit_pct = (next_price - self.trade_entry_price) / self.trade_entry_price
|
||||
reward += profit_pct * 10
|
||||
# Sell remaining position at the final price
|
||||
final_balance = self.balance + (self.position * next_price)
|
||||
|
||||
# Update win/loss count
|
||||
if profit_pct > 0:
|
||||
self.wins += 1
|
||||
else:
|
||||
self.losses += 1
|
||||
# Calculate final portfolio value and return
|
||||
final_return_pct = ((final_balance - self.initial_balance) / self.initial_balance) * 100
|
||||
|
||||
# Add big reward/penalty based on overall performance
|
||||
performance_reward = final_return_pct * 0.1
|
||||
reward += performance_reward
|
||||
reward_components['final_performance'] = performance_reward
|
||||
|
||||
logger.info(f"Episode ended. Final balance: {final_balance:.4f}, Return: {final_return_pct:.2f}%")
|
||||
|
||||
# Get the next observation
|
||||
observation = self._get_observation()
|
||||
# Get future prices for evaluation (1-hour and 1-day ahead)
|
||||
info['future_prices'] = {}
|
||||
|
||||
# Calculate metrics for info
|
||||
total_value = self.balance + self.position * next_price
|
||||
gain = (total_value - self.initial_balance) / self.initial_balance
|
||||
self.win_rate = self.wins / max(1, self.trades)
|
||||
# 1-hour future price if hourly data is available
|
||||
if hasattr(self, 'features_1h') and self.features_1h is not None:
|
||||
# Find the closest hourly data point
|
||||
if self.current_step < len(self.features_1m):
|
||||
current_time = self.current_step # Use as index for simplicity
|
||||
hourly_idx = min(current_time // 60, len(self.features_1h) - 1) # Assuming 60 minutes per hour
|
||||
if hourly_idx < len(self.features_1h) - 1:
|
||||
future_1h_price = self.features_1h[hourly_idx + 1, -1]
|
||||
info['future_prices']['1h'] = future_1h_price
|
||||
|
||||
# Check if we have prediction data for future timeframes
|
||||
future_price_1h = None
|
||||
future_price_1d = None
|
||||
# 1-day future price if daily data is available
|
||||
if hasattr(self, 'features_1d') and self.features_1d is not None:
|
||||
# Find the closest daily data point
|
||||
if self.current_step < len(self.features_1m):
|
||||
current_time = self.current_step # Use as index for simplicity
|
||||
daily_idx = min(current_time // 1440, len(self.features_1d) - 1) # Assuming 1440 minutes per day
|
||||
if daily_idx < len(self.features_1d) - 1:
|
||||
future_1d_price = self.features_1d[daily_idx + 1, -1]
|
||||
info['future_prices']['1d'] = future_1d_price
|
||||
|
||||
# Get hourly index
|
||||
idx_1h = self.current_step // 60
|
||||
if idx_1h + 1 < len(self.features_1h):
|
||||
hourly_close_idx = self.features_1h.shape[1] - 1 # Assuming close is last column
|
||||
current_1h_price = self.features_1h[idx_1h, hourly_close_idx]
|
||||
next_1h_price = self.features_1h[idx_1h + 1, hourly_close_idx]
|
||||
future_price_1h = (next_1h_price - current_1h_price) / current_1h_price
|
||||
# Get next observation
|
||||
next_state = self._get_observation()
|
||||
|
||||
# Get daily index
|
||||
idx_1d = idx_1h // 24
|
||||
if idx_1d + 1 < len(self.features_1d):
|
||||
daily_close_idx = self.features_1d.shape[1] - 1 # Assuming close is last column
|
||||
current_1d_price = self.features_1d[idx_1d, daily_close_idx]
|
||||
next_1d_price = self.features_1d[idx_1d + 1, daily_close_idx]
|
||||
future_price_1d = (next_1d_price - current_1d_price) / current_1d_price
|
||||
# Store reward components in info
|
||||
info['reward_components'] = reward_components
|
||||
|
||||
info = {
|
||||
'balance': self.balance,
|
||||
'position': self.position,
|
||||
'total_value': total_value,
|
||||
'gain': gain,
|
||||
'trades': self.trades,
|
||||
'win_rate': self.win_rate,
|
||||
'profit_pct': profit_pct if action == 1 and self.position == 0 else None,
|
||||
'current_price': current_price,
|
||||
'next_price': next_price,
|
||||
'future_price_1h': future_price_1h, # Actual future hourly price change
|
||||
'future_price_1d': future_price_1d # Actual future daily price change
|
||||
}
|
||||
# Clip reward to prevent extreme values
|
||||
reward = np.clip(reward, -10.0, 10.0)
|
||||
|
||||
# Call the callback if it exists
|
||||
if self.action_callback:
|
||||
self.action_callback(action, current_price, reward, info)
|
||||
|
||||
return observation, reward, done, info
|
||||
return next_state, reward, done, info
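A minimal interaction sketch for this environment; the feature arrays here are random stand-ins with the close price as the last column (a real run would use the indicator features built in train_rl below), and it assumes reset() returns the initial observation in the usual gym style:

import numpy as np

f_1m = np.abs(np.random.randn(2000, 6)) + 1.0   # 5 indicator columns + close
f_1h = np.abs(np.random.randn(40, 6)) + 1.0
f_1d = np.abs(np.random.randn(5, 6)) + 1.0

env = RLTradingEnvironment(f_1m, f_1h, f_1d, window_size=20)
state = env.reset()
done = False
while not done:
    action = np.random.randint(0, 3)            # 0=buy, 1=sell, 2=hold
    state, reward, done, info = env.step(action)
    if info.get('trade_executed'):
        print(info['current_price'], info['balance_change'], reward)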
|
||||
|
||||
def set_action_callback(self, callback):
|
||||
"""
|
||||
@@ -375,9 +402,9 @@ class RLTradingEnvironment(gym.Env):
|
||||
|
||||
def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
|
||||
action_callback=None, episode_callback=None, symbol="BTC/USDT",
|
||||
pretrain_price_prediction_enabled=True, pretrain_epochs=10):
|
||||
pretrain_price_prediction_enabled=False, pretrain_epochs=10):
|
||||
"""
|
||||
Train a reinforcement learning agent for trading
|
||||
Train a reinforcement learning agent for trading using ONLY real market data
|
||||
|
||||
Args:
|
||||
env_class: Optional environment class override
|
||||
@@ -387,34 +414,38 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
|
||||
action_callback: Callback function for monitoring actions
|
||||
episode_callback: Callback function for monitoring episodes
|
||||
symbol: Trading symbol to use
|
||||
pretrain_price_prediction_enabled: Whether to pre-train price prediction
|
||||
pretrain_epochs: Number of epochs for pre-training
|
||||
pretrain_price_prediction_enabled: DEPRECATED - No longer supported (synthetic data not used)
|
||||
pretrain_epochs: DEPRECATED - No longer supported (synthetic data not used)
|
||||
|
||||
Returns:
|
||||
tuple: (trained agent, environment)
|
||||
"""
|
||||
# Load data for the selected symbol
|
||||
data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m'])
|
||||
data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m', '1h', '1d'])
|
||||
|
||||
try:
|
||||
# Try to load data for the requested symbol using get_historical_data method
|
||||
data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
|
||||
data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
|
||||
data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)
|
||||
data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
|
||||
data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)
|
||||
|
||||
if data_1m is None or data_5m is None or data_15m is None:
|
||||
raise FileNotFoundError("Could not retrieve data for specified symbol")
|
||||
if data_1m is None or data_5m is None or data_15m is None or data_1h is None or data_1d is None:
|
||||
raise FileNotFoundError("Could not retrieve all required timeframes data for specified symbol")
|
||||
except Exception as e:
|
||||
logger.warning(f"Data for {symbol} not available: {str(e)}. Using default data.")
|
||||
logger.warning(f"Data for {symbol} not available: {str(e)}. Using default cached data.")
|
||||
# Try to use cached data if available
|
||||
symbol = "BTC/USDT"
|
||||
data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m'])
|
||||
data_interface = DataInterface(symbol=symbol, timeframes=['1m', '5m', '15m', '1h', '1d'])
|
||||
data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=5000)
|
||||
data_5m = data_interface.get_historical_data(timeframe='5m', n_candles=5000)
|
||||
data_15m = data_interface.get_historical_data(timeframe='15m', n_candles=5000)
|
||||
data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
|
||||
data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)
|
||||
|
||||
if data_1m is None or data_5m is None or data_15m is None:
|
||||
logger.error("Failed to retrieve any data. Cannot continue training.")
|
||||
if data_1m is None or data_5m is None or data_15m is None or data_1h is None or data_1d is None:
|
||||
logger.error("Failed to retrieve all required timeframes data. Cannot continue training.")
|
||||
raise ValueError("No data available for training")
|
||||
|
||||
# Create features from the data by adding technical indicators and converting to numpy format
|
||||
@@ -447,19 +478,39 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
|
||||
])
|
||||
else:
|
||||
features_15m = None
|
||||
|
||||
if data_1h is not None:
|
||||
data_1h = data_interface.add_technical_indicators(data_1h)
|
||||
# Convert to numpy array with close price as the last column
|
||||
features_1h = np.hstack([
|
||||
data_1h.drop(['timestamp', 'close'], axis=1).values,
|
||||
data_1h['close'].values.reshape(-1, 1)
|
||||
])
|
||||
else:
|
||||
features_1h = None
|
||||
|
||||
if data_1d is not None:
|
||||
data_1d = data_interface.add_technical_indicators(data_1d)
|
||||
# Convert to numpy array with close price as the last column
|
||||
features_1d = np.hstack([
|
||||
data_1d.drop(['timestamp', 'close'], axis=1).values,
|
||||
data_1d['close'].values.reshape(-1, 1)
|
||||
])
|
||||
else:
|
||||
features_1d = None
|
||||
|
||||
# Check if we have all the required features
|
||||
if features_1m is None or features_5m is None or features_15m is None:
|
||||
if features_1m is None or features_5m is None or features_15m is None or features_1h is None or features_1d is None:
|
||||
logger.error("Failed to create features for all timeframes.")
|
||||
raise ValueError("Could not create features for training")
|
||||
|
||||
# Create the environment
|
||||
if env_class:
|
||||
# Use provided environment class
|
||||
env = env_class(features_1m, features_5m, features_15m)
|
||||
env = env_class(features_1m, features_1h, features_1d)
|
||||
else:
|
||||
# Use the default environment
|
||||
env = RLTradingEnvironment(features_1m, features_5m, features_15m)
|
||||
env = RLTradingEnvironment(features_1m, features_1h, features_1d)
|
||||
|
||||
# Set action callback if provided
|
||||
if action_callback:
|
||||
@@ -494,29 +545,10 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
|
||||
else:
|
||||
logger.info("No existing model found. Starting with a new model.")
|
||||
|
||||
# Pre-train price prediction if enabled and we have a new model
|
||||
# Remove pre-training code since it used synthetic data
|
||||
# Pre-training with real data would require a separate implementation
|
||||
if pretrain_price_prediction_enabled:
|
||||
if not os.path.exists(model_file) or input("Pre-train price prediction? (y/n): ").lower() == 'y':
|
||||
logger.info("Pre-training price prediction capability...")
|
||||
# Attempt to load hourly and daily data for pre-training
|
||||
try:
|
||||
data_interface.add_timeframe('1h')
|
||||
data_interface.add_timeframe('1d')
|
||||
|
||||
# Run pre-training
|
||||
agent = pretrain_price_prediction(
|
||||
agent=agent,
|
||||
data_interface=data_interface,
|
||||
n_epochs=pretrain_epochs,
|
||||
batch_size=128
|
||||
)
|
||||
|
||||
# Save the pre-trained model
|
||||
agent.save(f"{save_path}_pretrained")
|
||||
logger.info("Pre-trained model saved.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during pre-training: {e}")
|
||||
logger.warning("Continuing with RL training without pre-training.")
|
||||
logger.warning("Pre-training with synthetic data is no longer supported. Continuing with RL training only.")
|
||||
|
||||
# Create TensorBoard writer
|
||||
writer = SummaryWriter(log_dir=f'runs/dqn_{int(time.time())}')
|
||||
@@ -582,8 +614,8 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
|
||||
total_rewards.append(total_reward)
|
||||
|
||||
# Calculate trading metrics
|
||||
win_rate = env.win_rate if hasattr(env, 'win_rate') else 0
|
||||
trades = env.trades if hasattr(env, 'trades') else 0
|
||||
win_rate = env.wins / max(1, env.trades)
|
||||
trades = env.trades
|
||||
|
||||
# Log to TensorBoard
|
||||
writer.add_scalar('Reward/Episode', total_reward, episode)
|
||||
@@ -621,379 +653,5 @@ def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/mo
|
||||
|
||||
return agent, env
|
||||
|
||||
def generate_price_prediction_training_data(data_1m, data_1h, data_1d, window_size=20):
|
||||
"""
|
||||
Generate labeled training data for price prediction at different timeframes
|
||||
|
||||
Args:
|
||||
data_1m: DataFrame with 1-minute data
|
||||
data_1h: DataFrame with 1-hour data
|
||||
data_1d: DataFrame with 1-day data
|
||||
window_size: Size of the input window
|
||||
|
||||
Returns:
|
||||
tuple: (X, y_immediate, y_midterm, y_longterm, y_values)
|
||||
- X: input features (window sequences)
|
||||
- y_immediate: immediate direction labels (0=down, 1=sideways, 2=up)
|
||||
- y_midterm: mid-term direction labels
|
||||
- y_longterm: long-term direction labels
|
||||
- y_values: actual percentage changes for each timeframe
|
||||
"""
|
||||
logger.info("Generating price prediction training data from historical prices")
|
||||
|
||||
# Prepare data structures
|
||||
X = []
|
||||
y_immediate = [] # 1m
|
||||
y_midterm = [] # 1h
|
||||
y_longterm = [] # 1d
|
||||
y_values = [] # Actual percentage changes
|
||||
|
||||
# Calculate future returns for labeling
|
||||
data_1m['future_return_1m'] = data_1m['close'].pct_change(1).shift(-1) # Next candle
|
||||
data_1m['future_return_10m'] = data_1m['close'].pct_change(10).shift(-10) # Next 10 candles
|
||||
|
||||
# Add indices to align data
|
||||
data_1m['index'] = range(len(data_1m))
|
||||
data_1h['index'] = range(len(data_1h))
|
||||
data_1d['index'] = range(len(data_1d))
|
||||
|
||||
# Define thresholds for direction labels
|
||||
immediate_threshold = 0.0005
|
||||
midterm_threshold = 0.001
|
||||
longterm_threshold = 0.002
|
||||
|
||||
# Loop through 1m data to create training samples
|
||||
max_idx = len(data_1m) - window_size - 10 # Ensure we have future data for labels
|
||||
sample_indices = random.sample(range(window_size, max_idx), min(10000, max_idx - window_size))
|
||||
|
||||
for idx in sample_indices:
|
||||
# Get window of 1m data
|
||||
window_1m = data_1m.iloc[idx-window_size:idx].drop(['timestamp', 'future_return_1m', 'future_return_10m', 'index'], axis=1, errors='ignore')
|
||||
|
||||
# Skip if window contains NaN
|
||||
if window_1m.isnull().values.any():
|
||||
continue
|
||||
|
||||
# Get future returns for labeling
|
||||
future_return_1m = data_1m.iloc[idx]['future_return_1m']
|
||||
future_return_10m = data_1m.iloc[idx]['future_return_10m']
|
||||
|
||||
# Find corresponding row in 1h data (closest timestamp)
|
||||
current_timestamp = data_1m.iloc[idx]['timestamp']
|
||||
|
||||
# Find 1h candle for mid-term prediction
|
||||
if 'timestamp' in data_1h.columns:
|
||||
# Find closest 1h candle
|
||||
closest_1h_idx = data_1h['timestamp'].searchsorted(current_timestamp)
|
||||
if closest_1h_idx >= len(data_1h):
|
||||
closest_1h_idx = len(data_1h) - 1
|
||||
|
||||
# Get future 1h return (next candle)
|
||||
if closest_1h_idx < len(data_1h) - 1:
|
||||
future_return_1h = (data_1h.iloc[closest_1h_idx + 1]['close'] - data_1h.iloc[closest_1h_idx]['close']) / data_1h.iloc[closest_1h_idx]['close']
|
||||
else:
|
||||
future_return_1h = 0
|
||||
else:
|
||||
future_return_1h = future_return_10m # Fallback
|
||||
|
||||
# Find 1d candle for long-term prediction
|
||||
if 'timestamp' in data_1d.columns:
|
||||
# Find closest 1d candle
|
||||
closest_1d_idx = data_1d['timestamp'].searchsorted(current_timestamp)
|
||||
if closest_1d_idx >= len(data_1d):
|
||||
closest_1d_idx = len(data_1d) - 1
|
||||
|
||||
# Get future 1d return (next candle)
|
||||
if closest_1d_idx < len(data_1d) - 1:
|
||||
future_return_1d = (data_1d.iloc[closest_1d_idx + 1]['close'] - data_1d.iloc[closest_1d_idx]['close']) / data_1d.iloc[closest_1d_idx]['close']
|
||||
else:
|
||||
future_return_1d = 0
|
||||
else:
|
||||
future_return_1d = future_return_1h * 2 # Fallback
|
||||
|
||||
# Create direction labels
|
||||
# 0=down, 1=sideways, 2=up
|
||||
|
||||
# Immediate (1m)
|
||||
if future_return_1m > immediate_threshold:
|
||||
immediate_label = 2 # UP
|
||||
elif future_return_1m < -immediate_threshold:
|
||||
immediate_label = 0 # DOWN
|
||||
else:
|
||||
immediate_label = 1 # SIDEWAYS
|
||||
|
||||
# Mid-term (1h)
|
||||
if future_return_1h > midterm_threshold:
|
||||
midterm_label = 2 # UP
|
||||
elif future_return_1h < -midterm_threshold:
|
||||
midterm_label = 0 # DOWN
|
||||
else:
|
||||
midterm_label = 1 # SIDEWAYS
|
||||
|
||||
# Long-term (1d)
|
||||
if future_return_1d > longterm_threshold:
|
||||
longterm_label = 2 # UP
|
||||
elif future_return_1d < -longterm_threshold:
|
||||
longterm_label = 0 # DOWN
|
||||
else:
|
||||
longterm_label = 1 # SIDEWAYS
|
||||
|
||||
# Store data
|
||||
X.append(window_1m.values)
|
||||
y_immediate.append(immediate_label)
|
||||
y_midterm.append(midterm_label)
|
||||
y_longterm.append(longterm_label)
|
||||
y_values.append([future_return_1m, future_return_1h, future_return_1d, future_return_1d * 1.5]) # Add weekly estimate
|
||||
|
||||
# Convert to numpy arrays
|
||||
X = np.array(X)
|
||||
y_immediate = np.array(y_immediate)
|
||||
y_midterm = np.array(y_midterm)
|
||||
y_longterm = np.array(y_longterm)
|
||||
y_values = np.array(y_values)
|
||||
|
||||
logger.info(f"Generated {len(X)} price prediction training samples")
|
||||
|
||||
# Log class distribution
|
||||
for name, y in [("Immediate", y_immediate), ("Mid-term", y_midterm), ("Long-term", y_longterm)]:
|
||||
down = (y == 0).sum()
|
||||
sideways = (y == 1).sum()
|
||||
up = (y == 2).sum()
|
||||
logger.info(f"{name} direction distribution: DOWN={down} ({down/len(y)*100:.1f}%), "
|
||||
f"SIDEWAYS={sideways} ({sideways/len(y)*100:.1f}%), "
|
||||
f"UP={up} ({up/len(y)*100:.1f}%)")
|
||||
|
||||
return X, y_immediate, y_midterm, y_longterm, y_values
|
||||
|
||||
def pretrain_price_prediction(agent, data_interface, n_epochs=10, batch_size=128):
|
||||
"""
|
||||
Pre-train the agent's price prediction capability on historical data
|
||||
|
||||
Args:
|
||||
agent: DQNAgent instance to train
|
||||
data_interface: DataInterface instance for accessing data
|
||||
n_epochs: Number of epochs for training
|
||||
batch_size: Batch size for training
|
||||
|
||||
Returns:
|
||||
The agent with pre-trained price prediction capabilities
|
||||
"""
|
||||
logger.info("Starting supervised pre-training of price prediction")
|
||||
|
||||
try:
|
||||
# Load data for all required timeframes
|
||||
data_1m = data_interface.get_historical_data(timeframe='1m', n_candles=10000)
|
||||
data_1h = data_interface.get_historical_data(timeframe='1h', n_candles=1000)
|
||||
data_1d = data_interface.get_historical_data(timeframe='1d', n_candles=500)
|
||||
|
||||
# Check if data is available
|
||||
if data_1m is None:
|
||||
logger.warning("1m data not available for pre-training")
|
||||
return agent
|
||||
|
||||
if data_1h is None:
|
||||
logger.warning("1h data not available, using synthesized data")
|
||||
# Create synthetic 1h data from 1m data
|
||||
data_1h = data_1m.iloc[::60].reset_index(drop=True).copy() # Take every 60th record
|
||||
|
||||
if data_1d is None:
|
||||
logger.warning("1d data not available, using synthesized data")
|
||||
# Create synthetic 1d data from 1h data
|
||||
data_1d = data_1h.iloc[::24].reset_index(drop=True).copy() # Take every 24th record
|
||||
|
||||
# Add technical indicators to all data
|
||||
data_1m = data_interface.add_technical_indicators(data_1m)
|
||||
data_1h = data_interface.add_technical_indicators(data_1h)
|
||||
data_1d = data_interface.add_technical_indicators(data_1d)
|
||||
|
||||
# Generate labeled training data
|
||||
X, y_immediate, y_midterm, y_longterm, y_values = generate_price_prediction_training_data(
|
||||
data_1m, data_1h, data_1d, window_size=20
|
||||
)
|
||||
|
||||
# Split data into training and validation sets
|
||||
from sklearn.model_selection import train_test_split
|
||||
X_train, X_val, y_imm_train, y_imm_val, y_mid_train, y_mid_val, y_long_train, y_long_val, y_val_train, y_val_val = train_test_split(
|
||||
X, y_immediate, y_midterm, y_longterm, y_values, test_size=0.2, random_state=42
|
||||
)
|
||||
|
||||
# Convert to torch tensors
|
||||
X_train_tensor = torch.FloatTensor(X_train).to(agent.device)
|
||||
y_imm_train_tensor = torch.LongTensor(y_imm_train).to(agent.device)
|
||||
y_mid_train_tensor = torch.LongTensor(y_mid_train).to(agent.device)
|
||||
y_long_train_tensor = torch.LongTensor(y_long_train).to(agent.device)
|
||||
y_val_train_tensor = torch.FloatTensor(y_val_train).to(agent.device)
|
||||
|
||||
X_val_tensor = torch.FloatTensor(X_val).to(agent.device)
|
||||
y_imm_val_tensor = torch.LongTensor(y_imm_val).to(agent.device)
|
||||
y_mid_val_tensor = torch.LongTensor(y_mid_val).to(agent.device)
|
||||
y_long_val_tensor = torch.LongTensor(y_long_val).to(agent.device)
|
||||
y_val_val_tensor = torch.FloatTensor(y_val_val).to(agent.device)
|
||||
|
||||
# Calculate class weights for imbalanced data
|
||||
from torch.nn.functional import one_hot
|
||||
|
||||
# Function to calculate class weights
|
||||
def get_class_weights(labels):
|
||||
counts = np.bincount(labels)
|
||||
if len(counts) < 3: # Ensure we have 3 classes
|
||||
counts = np.append(counts, [0] * (3 - len(counts)))
|
||||
weights = 1.0 / np.array(counts)
|
||||
weights = weights / np.sum(weights) # Normalize
|
||||
return weights

        imm_weights = torch.FloatTensor(get_class_weights(y_imm_train)).to(agent.device)
        mid_weights = torch.FloatTensor(get_class_weights(y_mid_train)).to(agent.device)
        long_weights = torch.FloatTensor(get_class_weights(y_long_train)).to(agent.device)

        # Create DataLoader for batch training
        from torch.utils.data import TensorDataset, DataLoader

        train_dataset = TensorDataset(
            X_train_tensor, y_imm_train_tensor, y_mid_train_tensor,
            y_long_train_tensor, y_val_train_tensor
        )
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        # Set up loss functions with class weights
        imm_criterion = nn.CrossEntropyLoss(weight=imm_weights)
        mid_criterion = nn.CrossEntropyLoss(weight=mid_weights)
        long_criterion = nn.CrossEntropyLoss(weight=long_weights)
        value_criterion = nn.MSELoss()

        # Set up optimizer (separate from agent's optimizer)
        pretrain_optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=0.0002)
        pretrain_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            pretrain_optimizer, mode='min', factor=0.5, patience=3, verbose=True
        )
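        # With factor=0.5 and patience=3, the learning rate is halved once the
        # validation loss has gone three epochs without improving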

        # Set model to training mode
        agent.policy_net.train()

        # Training loop
        best_val_loss = float('inf')
        patience = 5
        patience_counter = 0

        for epoch in range(n_epochs):
            # Training phase
            train_loss = 0.0
            imm_correct, mid_correct, long_correct = 0, 0, 0
            total = 0

            for X_batch, y_imm_batch, y_mid_batch, y_long_batch, y_val_batch in train_loader:
                # Zero gradients
                pretrain_optimizer.zero_grad()

                # Forward pass - we only need the price predictions
                with torch.cuda.amp.autocast() if agent.use_mixed_precision else contextlib.nullcontext():
                    _, _, price_preds = agent.policy_net(X_batch)

                    # Calculate losses for each prediction head
                    imm_loss = imm_criterion(price_preds['immediate'], y_imm_batch)
                    mid_loss = mid_criterion(price_preds['midterm'], y_mid_batch)
                    long_loss = long_criterion(price_preds['longterm'], y_long_batch)
                    value_loss = value_criterion(price_preds['values'], y_val_batch)

                    # Combined loss (weighted by importance)
                    total_loss = imm_loss + 0.7 * mid_loss + 0.5 * long_loss + 0.3 * value_loss

                # Backward pass and optimize
                if agent.use_mixed_precision:
                    agent.scaler.scale(total_loss).backward()
                    agent.scaler.unscale_(pretrain_optimizer)
                    torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
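                    # Unscaling before clip_grad_norm_ means the 1.0 threshold is
                    # applied to the true gradient magnitudes, not the loss-scaled ones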
                    agent.scaler.step(pretrain_optimizer)
                    agent.scaler.update()
                else:
                    total_loss.backward()
                    torch.nn.utils.clip_grad_norm_(agent.policy_net.parameters(), 1.0)
                    pretrain_optimizer.step()

                # Accumulate metrics
                train_loss += total_loss.item()
                total += X_batch.size(0)

                # Calculate accuracy
                _, imm_pred = torch.max(price_preds['immediate'], 1)
                _, mid_pred = torch.max(price_preds['midterm'], 1)
                _, long_pred = torch.max(price_preds['longterm'], 1)

                imm_correct += (imm_pred == y_imm_batch).sum().item()
                mid_correct += (mid_pred == y_mid_batch).sum().item()
                long_correct += (long_pred == y_long_batch).sum().item()

            # Calculate epoch metrics
            train_loss /= len(train_loader)
            imm_acc = imm_correct / total
            mid_acc = mid_correct / total
            long_acc = long_correct / total

            # Validation phase
            agent.policy_net.eval()
            val_loss = 0.0
            imm_val_correct, mid_val_correct, long_val_correct = 0, 0, 0

            with torch.no_grad():
                # Forward pass on validation data
                _, _, val_price_preds = agent.policy_net(X_val_tensor)

                # Calculate validation losses
                val_imm_loss = imm_criterion(val_price_preds['immediate'], y_imm_val_tensor)
                val_mid_loss = mid_criterion(val_price_preds['midterm'], y_mid_val_tensor)
                val_long_loss = long_criterion(val_price_preds['longterm'], y_long_val_tensor)
                val_value_loss = value_criterion(val_price_preds['values'], y_val_val_tensor)

                val_total_loss = val_imm_loss + 0.7 * val_mid_loss + 0.5 * val_long_loss + 0.3 * val_value_loss
                val_loss = val_total_loss.item()

                # Calculate validation accuracy
                _, imm_val_pred = torch.max(val_price_preds['immediate'], 1)
                _, mid_val_pred = torch.max(val_price_preds['midterm'], 1)
                _, long_val_pred = torch.max(val_price_preds['longterm'], 1)

                imm_val_correct = (imm_val_pred == y_imm_val_tensor).sum().item()
                mid_val_correct = (mid_val_pred == y_mid_val_tensor).sum().item()
                long_val_correct = (long_val_pred == y_long_val_tensor).sum().item()

                imm_val_acc = imm_val_correct / len(X_val_tensor)
                mid_val_acc = mid_val_correct / len(X_val_tensor)
                long_val_acc = long_val_correct / len(X_val_tensor)

            # Learning rate scheduling
            pretrain_scheduler.step(val_loss)

            # Early stopping check
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Copy policy_net weights to target_net
                agent.target_net.load_state_dict(agent.policy_net.state_dict())
                logger.info(f"Saved best model with validation loss: {val_loss:.4f}")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    logger.info(f"Early stopping triggered after {epoch+1} epochs")
                    break

            # Log progress
            logger.info(f"Epoch {epoch+1}/{n_epochs}: "
                        f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                        f"Imm Acc: {imm_acc:.4f}/{imm_val_acc:.4f}, "
                        f"Mid Acc: {mid_acc:.4f}/{mid_val_acc:.4f}, "
                        f"Long Acc: {long_acc:.4f}/{long_val_acc:.4f}")

            # Set model back to training mode for next epoch
            agent.policy_net.train()

        logger.info("Price prediction pre-training complete")
        return agent

    except Exception as e:
        logger.error(f"Error during price prediction pre-training: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return agent


if __name__ == "__main__":
    train_rl()
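The loop above trains three class-weighted classification heads plus a regression head with a single combined loss. The sketch below illustrates that loss pattern in isolation, using made-up logits, targets, and class weights rather than the real model and data pipeline; it is not part of the committed script.

import torch
import torch.nn as nn

# Hypothetical normalized inverse-frequency class weights for a skewed label set
w = torch.tensor([0.09, 0.30, 0.61])
imm_criterion = nn.CrossEntropyLoss(weight=w)
mid_criterion = nn.CrossEntropyLoss(weight=w)
long_criterion = nn.CrossEntropyLoss(weight=w)
value_criterion = nn.MSELoss()

# Fake head outputs and targets for a batch of 8 samples
imm_logits, mid_logits, long_logits = (torch.randn(8, 3) for _ in range(3))
values, y_values = torch.randn(8, 1), torch.randn(8, 1)
y_imm, y_mid, y_long = (torch.randint(0, 3, (8,)) for _ in range(3))

# Same horizon weighting as the training loop: nearer-term heads dominate
total_loss = (imm_criterion(imm_logits, y_imm)
              + 0.7 * mid_criterion(mid_logits, y_mid)
              + 0.5 * long_criterion(long_logits, y_long)
              + 0.3 * value_criterion(values, y_values))
print(total_loss.item())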
Binary file not shown.
Binary file not shown.
@ -373,7 +373,7 @@ class DataInterface:

        return df_copy

    def calculate_pnl(self, predictions, actual_prices, position_size=1.0):
    def calculate_pnl(self, predictions, actual_prices, position_size=1.0, fee_rate=0.0002):
        """
        Robust PnL calculator that handles:
        - Action predictions (0=SELL, 1=HOLD, 2=BUY)
@ -384,6 +384,7 @@ class DataInterface:
            predictions: Array of predicted actions or probabilities
            actual_prices: Array of actual prices (can be 1D or 2D OHLC format)
            position_size: Position size multiplier
            fee_rate: Trading fee rate (default: 0.0002, i.e. 0.02% charged on entry and again on exit)

        Returns:
            tuple: (total_pnl, win_rate, trades)
@ -443,13 +444,33 @@ class DataInterface:
            price_change = (next_price - current_price) / current_price

            if action == 2:  # BUY
                trade_pnl = price_change * position_size
                # Calculate raw PnL
                raw_pnl = price_change * position_size

                # Calculate fees (entry and exit)
                entry_fee = position_size * fee_rate
                exit_fee = position_size * (1 + price_change) * fee_rate
                total_fees = entry_fee + exit_fee

                # Net PnL after fees
                trade_pnl = raw_pnl - total_fees

                trade_type = 'BUY'
                is_win = price_change > 0
                is_win = trade_pnl > 0
            elif action == 0:  # SELL
                trade_pnl = -price_change * position_size
                # Calculate raw PnL
                raw_pnl = -price_change * position_size

                # Calculate fees (entry and exit)
                entry_fee = position_size * fee_rate
                exit_fee = position_size * (1 - price_change) * fee_rate
                total_fees = entry_fee + exit_fee

                # Net PnL after fees
                trade_pnl = raw_pnl - total_fees

                trade_type = 'SELL'
                is_win = price_change < 0
                is_win = trade_pnl > 0
            else:
                continue  # Invalid action

@ -462,6 +483,8 @@ class DataInterface:
                'entry': current_price,
                'exit': next_price,
                'pnl': trade_pnl,
                'raw_pnl': price_change * position_size if trade_type == 'BUY' else -price_change * position_size,
                'fees': total_fees,
                'win': is_win,
                'duration': 1  # In number of candles
            })
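As a quick sanity check of the fee handling introduced above, with made-up prices and the default fee rate (not taken from the repository's tests):

# Hypothetical BUY trade: enter at 100, exit at 102, position_size = 1.0
fee_rate = 0.0002
price_change = (102.0 - 100.0) / 100.0           # 0.02
raw_pnl = price_change * 1.0                     # 0.02
entry_fee = 1.0 * fee_rate                       # 0.0002 on the entry notional
exit_fee = 1.0 * (1 + price_change) * fee_rate   # 0.000204 on the larger exit notional
trade_pnl = raw_pnl - (entry_fee + exit_fee)     # ~0.019596 net of fees
print(trade_pnl)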