models overhaul
@@ -23,11 +23,11 @@ logger = logging.getLogger(__name__)
 
 class DQNNetwork(nn.Module):
     """
-    Massive Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
+    Configurable Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
     Handles 7850 input features from multi-timeframe, multi-asset data
-    TARGET: 50M parameters for enhanced learning capacity
+    Architecture is configurable via config.yaml
     """
-    def __init__(self, input_dim: int, n_actions: int):
+    def __init__(self, input_dim: int, n_actions: int, config: dict = None):
         super(DQNNetwork, self).__init__()
 
         # Handle different input dimension formats
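Note: since config defaults to None, existing two-argument call sites keep working and fall back to the default reduced architecture defined in the next hunk. A minimal sketch (n_actions=3 is illustrative, not taken from this diff):

# Both call forms are valid after this change:
net_default = DQNNetwork(7850, 3)          # config=None -> built-in defaults
net_custom = DQNNetwork(7850, 3, config)   # reads config['network_architecture'] (see config sketch below)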
@@ -41,59 +41,65 @@ class DQNNetwork(nn.Module):
 
         self.n_actions = n_actions
 
-        # MASSIVE network architecture optimized for trading features
-        # Target: ~50M parameters
-        self.feature_extractor = nn.Sequential(
-            # Initial feature extraction with massive width
-            nn.Linear(self.input_size, 8192), # 7850 -> 8192 = ~64M weights
-            nn.LayerNorm(8192),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-
-            # Deep feature processing layers
-            nn.Linear(8192, 6144), # 8192 -> 6144 = ~50M weights
-            nn.LayerNorm(6144),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-
-            nn.Linear(6144, 4096), # 6144 -> 4096 = ~25M weights
-            nn.LayerNorm(4096),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-
-            nn.Linear(4096, 3072), # 4096 -> 3072 = ~12M weights
-            nn.LayerNorm(3072),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-
-            nn.Linear(3072, 2048), # 3072 -> 2048 = ~6M weights
-            nn.LayerNorm(2048),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-        )
+        # Get network architecture from config or use defaults
+        if config and 'network_architecture' in config:
+            arch_config = config['network_architecture']
+            feature_layers = arch_config.get('feature_layers', [4096, 3072, 2048, 1536, 1024])
+            regime_head = arch_config.get('regime_head', [512, 256])
+            price_direction_head = arch_config.get('price_direction_head', [512, 256])
+            volatility_head = arch_config.get('volatility_head', [512, 128])
+            value_head = arch_config.get('value_head', [512, 256])
+            advantage_head = arch_config.get('advantage_head', [512, 256])
+            dropout_rate = arch_config.get('dropout_rate', 0.1)
+            use_layer_norm = arch_config.get('use_layer_norm', True)
+        else:
+            # Default reduced architecture (half the original size)
+            feature_layers = [4096, 3072, 2048, 1536, 1024]
+            regime_head = [512, 256]
+            price_direction_head = [512, 256]
+            volatility_head = [512, 128]
+            value_head = [512, 256]
+            advantage_head = [512, 256]
+            dropout_rate = 0.1
+            use_layer_norm = True
+
+        # Build configurable feature extractor
+        feature_layers_list = []
+        prev_size = self.input_size
+
+        for layer_size in feature_layers:
+            feature_layers_list.append(nn.Linear(prev_size, layer_size))
+            if use_layer_norm:
+                feature_layers_list.append(nn.LayerNorm(layer_size))
+            feature_layers_list.append(nn.ReLU(inplace=True))
+            feature_layers_list.append(nn.Dropout(dropout_rate))
+            prev_size = layer_size
+
+        self.feature_extractor = nn.Sequential(*feature_layers_list)
+        self.feature_size = feature_layers[-1]  # Final feature size
+
+        # Build configurable network heads
+        def build_head_layers(input_size, layer_sizes, output_size):
+            layers = []
+            prev_size = input_size
+            for layer_size in layer_sizes:
+                layers.append(nn.Linear(prev_size, layer_size))
+                if use_layer_norm:
+                    layers.append(nn.LayerNorm(layer_size))
+                layers.append(nn.ReLU(inplace=True))
+                layers.append(nn.Dropout(dropout_rate))
+                prev_size = layer_size
+            layers.append(nn.Linear(prev_size, output_size))
+            return nn.Sequential(*layers)
 
         # Market regime detection head
-        self.regime_head = nn.Sequential(
-            nn.Linear(2048, 1024),
-            nn.LayerNorm(1024),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-            nn.Linear(1024, 512),
-            nn.LayerNorm(512),
-            nn.ReLU(inplace=True),
-            nn.Linear(512, 4) # trending, ranging, volatile, mixed
+        self.regime_head = build_head_layers(
+            self.feature_size, regime_head, 4 # trending, ranging, volatile, mixed
         )
 
         # Price direction prediction head - outputs direction and confidence
-        self.price_direction_head = nn.Sequential(
-            nn.Linear(2048, 1024),
-            nn.LayerNorm(1024),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-            nn.Linear(1024, 512),
-            nn.LayerNorm(512),
-            nn.ReLU(inplace=True),
-            nn.Linear(512, 2) # [direction, confidence]
+        self.price_direction_head = build_head_layers(
+            self.feature_size, price_direction_head, 2 # [direction, confidence]
        )
 
         # Direction activation (tanh for -1 to 1)
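For reference, a config dict matching the keys read above. The values mirror the in-code defaults; only the network_architecture block is grounded in this diff, and any surrounding config.yaml structure is an assumption:

config = {
    'network_architecture': {
        'feature_layers': [4096, 3072, 2048, 1536, 1024],
        'regime_head': [512, 256],
        'price_direction_head': [512, 256],
        'volatility_head': [512, 128],
        'value_head': [512, 256],
        'advantage_head': [512, 256],
        'dropout_rate': 0.1,
        'use_layer_norm': True,
    }
}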
@@ -102,38 +108,18 @@ class DQNNetwork(nn.Module):
         self.confidence_activation = nn.Sigmoid()
 
         # Volatility prediction head
-        self.volatility_head = nn.Sequential(
-            nn.Linear(2048, 1024),
-            nn.LayerNorm(1024),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-            nn.Linear(1024, 256),
-            nn.LayerNorm(256),
-            nn.ReLU(inplace=True),
-            nn.Linear(256, 4) # predicted volatility for 4 timeframes
+        self.volatility_head = build_head_layers(
+            self.feature_size, volatility_head, 4 # predicted volatility for 4 timeframes
         )
 
         # Main Q-value head (dueling architecture)
-        self.value_head = nn.Sequential(
-            nn.Linear(2048, 1024),
-            nn.LayerNorm(1024),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-            nn.Linear(1024, 512),
-            nn.LayerNorm(512),
-            nn.ReLU(inplace=True),
-            nn.Linear(512, 1) # State value
+        self.value_head = build_head_layers(
+            self.feature_size, value_head, 1 # Single value for dueling architecture
         )
 
-        self.advantage_head = nn.Sequential(
-            nn.Linear(2048, 1024),
-            nn.LayerNorm(1024),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-            nn.Linear(1024, 512),
-            nn.LayerNorm(512),
-            nn.ReLU(inplace=True),
-            nn.Linear(512, n_actions) # Action advantages
+        # Advantage head (dueling architecture)
+        self.advantage_head = build_head_layers(
+            self.feature_size, advantage_head, n_actions # Action advantages
         )
 
         # Initialize weights
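The value and advantage heads form a dueling architecture. The aggregation step is outside this hunk; in a standard dueling DQN the forward pass combines them as Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a'). A sketch under that assumption, not the file's actual forward method:

features = self.feature_extractor(x)
value = self.value_head(features)          # (batch, 1) state value
advantage = self.advantage_head(features)  # (batch, n_actions) action advantages
q_values = value + advantage - advantage.mean(dim=1, keepdim=True)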
@@ -248,7 +234,8 @@ class DQNAgent:
                  priority_memory: bool = True,
                  device=None,
                  model_name: str = "dqn_agent",
-                 enable_checkpoints: bool = True):
+                 enable_checkpoints: bool = True,
+                 config: dict = None):
 
         # Checkpoint management
         self.model_name = model_name
@@ -292,8 +279,8 @@ class DQNAgent:
         logger.info(f"DQN Agent using device: {self.device}")
 
         # Initialize models with RL-specific network architecture
-        self.policy_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
-        self.target_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
+        self.policy_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device)
+        self.target_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device)
 
         # Ensure models are on the correct device
         self.policy_net = self.policy_net.to(self.device)
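Both networks receive the same config, so their parameter shapes stay matched and the usual hard target-network sync remains valid. A sketch of that standard pattern; the file's actual sync code is outside this diff:

self.target_net.load_state_dict(self.policy_net.state_dict())
self.target_net.eval()  # target network is used for evaluation only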