models overhaul
@@ -1,201 +1,201 @@
-"""
-Legacy CNN Model Compatibility Layer
+# """
+# Legacy CNN Model Compatibility Layer
 
-This module provides compatibility redirects to the unified StandardizedCNN model.
-All legacy models (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been retired
-in favor of the StandardizedCNN architecture.
-"""
+# This module provides compatibility redirects to the unified StandardizedCNN model.
+# All legacy models (EnhancedCNNModel, CNNModelTrainer, CNNModel) have been retired
+# in favor of the StandardizedCNN architecture.
+# """
 
-import logging
-import warnings
-from typing import Tuple, Dict, Any, Optional
-import torch
-import numpy as np
+# import logging
+# import warnings
+# from typing import Tuple, Dict, Any, Optional
+# import torch
+# import numpy as np
 
-# Import the standardized CNN model
-from .standardized_cnn import StandardizedCNN
+# # Import the standardized CNN model
+# from .standardized_cnn import StandardizedCNN
 
-logger = logging.getLogger(__name__)
+# logger = logging.getLogger(__name__)
 
-# Compatibility aliases and wrappers
-class EnhancedCNNModel:
-    """Legacy compatibility wrapper - redirects to StandardizedCNN"""
+# # Compatibility aliases and wrappers
+# class EnhancedCNNModel:
+# """Legacy compatibility wrapper - redirects to StandardizedCNN"""
 
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            "EnhancedCNNModel is deprecated. Use StandardizedCNN instead.",
-            DeprecationWarning,
-            stacklevel=2
-        )
-        # Create StandardizedCNN with default parameters
-        self.standardized_cnn = StandardizedCNN()
-        logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN")
+# def __init__(self, *args, **kwargs):
+# warnings.warn(
+# "EnhancedCNNModel is deprecated. Use StandardizedCNN instead.",
+# DeprecationWarning,
+# stacklevel=2
+# )
+# # Create StandardizedCNN with default parameters
+# self.standardized_cnn = StandardizedCNN()
+# logger.warning("EnhancedCNNModel compatibility wrapper created - please migrate to StandardizedCNN")
 
-    def __getattr__(self, name):
-        """Delegate all method calls to StandardizedCNN"""
-        return getattr(self.standardized_cnn, name)
+# def __getattr__(self, name):
+# """Delegate all method calls to StandardizedCNN"""
+# return getattr(self.standardized_cnn, name)
 
 
-class CNNModelTrainer:
-    """Legacy compatibility wrapper for CNN training"""
+# class CNNModelTrainer:
+# """Legacy compatibility wrapper for CNN training"""
 
-    def __init__(self, model=None, *args, **kwargs):
-        warnings.warn(
-            "CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.",
-            DeprecationWarning,
-            stacklevel=2
-        )
-        if isinstance(model, EnhancedCNNModel):
-            self.model = model.standardized_cnn
-        else:
-            self.model = StandardizedCNN()
-        logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()")
+# def __init__(self, model=None, *args, **kwargs):
+# warnings.warn(
+# "CNNModelTrainer is deprecated. Use StandardizedCNN.train_step() instead.",
+# DeprecationWarning,
+# stacklevel=2
+# )
+# if isinstance(model, EnhancedCNNModel):
+# self.model = model.standardized_cnn
+# else:
+# self.model = StandardizedCNN()
+# logger.warning("CNNModelTrainer compatibility wrapper created - please use StandardizedCNN.train_step()")
 
-    def train_step(self, x, y, *args, **kwargs):
-        """Legacy train step wrapper"""
-        try:
-            # Convert to BaseDataInput format if needed
-            if hasattr(x, 'get_feature_vector'):
-                # Already BaseDataInput
-                base_input = x
-            else:
-                # Create mock BaseDataInput for legacy compatibility
-                from core.data_models import BaseDataInput
-                base_input = BaseDataInput()
-                # Set mock feature vector
-                if isinstance(x, torch.Tensor):
-                    feature_vector = x.flatten().cpu().numpy()
-                else:
-                    feature_vector = np.array(x).flatten()
+# def train_step(self, x, y, *args, **kwargs):
+# """Legacy train step wrapper"""
+# try:
+# # Convert to BaseDataInput format if needed
+# if hasattr(x, 'get_feature_vector'):
+# # Already BaseDataInput
+# base_input = x
+# else:
+# # Create mock BaseDataInput for legacy compatibility
+# from core.data_models import BaseDataInput
+# base_input = BaseDataInput()
+# # Set mock feature vector
+# if isinstance(x, torch.Tensor):
+# feature_vector = x.flatten().cpu().numpy()
+# else:
+# feature_vector = np.array(x).flatten()
 
-                # Pad or truncate to expected size
-                expected_size = self.model.expected_feature_dim
-                if len(feature_vector) < expected_size:
-                    padding = np.zeros(expected_size - len(feature_vector))
-                    feature_vector = np.concatenate([feature_vector, padding])
-                else:
-                    feature_vector = feature_vector[:expected_size]
+# # Pad or truncate to expected size
+# expected_size = self.model.expected_feature_dim
+# if len(feature_vector) < expected_size:
+# padding = np.zeros(expected_size - len(feature_vector))
+# feature_vector = np.concatenate([feature_vector, padding])
+# else:
+# feature_vector = feature_vector[:expected_size]
 
-                base_input._feature_vector = feature_vector
+# base_input._feature_vector = feature_vector
 
-            # Convert target to string format
-            if isinstance(y, torch.Tensor):
-                y_val = y.item() if y.numel() == 1 else y.argmax().item()
-            else:
-                y_val = int(y) if np.isscalar(y) else int(np.argmax(y))
+# # Convert target to string format
+# if isinstance(y, torch.Tensor):
+# y_val = y.item() if y.numel() == 1 else y.argmax().item()
+# else:
+# y_val = int(y) if np.isscalar(y) else int(np.argmax(y))
 
-            target_map = {0: 'BUY', 1: 'SELL', 2: 'HOLD'}
-            target = target_map.get(y_val, 'HOLD')
+# target_map = {0: 'BUY', 1: 'SELL', 2: 'HOLD'}
+# target = target_map.get(y_val, 'HOLD')
 
-            # Use StandardizedCNN training
-            optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
-            loss = self.model.train_step([base_input], [target], optimizer)
+# # Use StandardizedCNN training
+# optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
+# loss = self.model.train_step([base_input], [target], optimizer)
 
-            return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5}
+# return {'total_loss': loss, 'main_loss': loss, 'accuracy': 0.5}
 
-        except Exception as e:
-            logger.error(f"Legacy train_step error: {e}")
-            return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}
+# except Exception as e:
+# logger.error(f"Legacy train_step error: {e}")
+# return {'total_loss': 0.0, 'main_loss': 0.0, 'accuracy': 0.5}
 
 
-class CNNModel:
-    """Legacy compatibility wrapper for CNN model interface"""
+# # class CNNModel:
+# # """Legacy compatibility wrapper for CNN model interface"""
 
-    def __init__(self, input_shape=(900, 50), output_size=3, model_path=None):
-        warnings.warn(
-            "CNNModel is deprecated. Use StandardizedCNN directly.",
-            DeprecationWarning,
-            stacklevel=2
-        )
-        self.input_shape = input_shape
-        self.output_size = output_size
-        self.standardized_cnn = StandardizedCNN()
-        self.trainer = CNNModelTrainer(self.standardized_cnn)
-        logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN")
+# # def __init__(self, input_shape=(900, 50), output_size=3, model_path=None):
+# # warnings.warn(
+# # "CNNModel is deprecated. Use StandardizedCNN directly.",
+# # DeprecationWarning,
+# # stacklevel=2
+# # )
+# # self.input_shape = input_shape
+# # self.output_size = output_size
+# # self.standardized_cnn = StandardizedCNN()
+# # self.trainer = CNNModelTrainer(self.standardized_cnn)
+# # logger.warning("CNNModel compatibility wrapper created - please migrate to StandardizedCNN")
 
-    def build_model(self, **kwargs):
-        """Legacy build method - no-op for StandardizedCNN"""
-        return self
+# # def build_model(self, **kwargs):
+# # """Legacy build method - no-op for StandardizedCNN"""
+# # return self
 
-    def predict(self, X):
-        """Legacy predict method"""
-        try:
-            # Convert input to BaseDataInput
-            from core.data_models import BaseDataInput
-            base_input = BaseDataInput()
+# # def predict(self, X):
+# # """Legacy predict method"""
+# # try:
+# # # Convert input to BaseDataInput
+# # from core.data_models import BaseDataInput
+# # base_input = BaseDataInput()
 
-            if isinstance(X, np.ndarray):
-                feature_vector = X.flatten()
-            else:
-                feature_vector = np.array(X).flatten()
+# # if isinstance(X, np.ndarray):
+# # feature_vector = X.flatten()
+# # else:
+# # feature_vector = np.array(X).flatten()
 
-            # Pad or truncate to expected size
-            expected_size = self.standardized_cnn.expected_feature_dim
-            if len(feature_vector) < expected_size:
-                padding = np.zeros(expected_size - len(feature_vector))
-                feature_vector = np.concatenate([feature_vector, padding])
-            else:
-                feature_vector = feature_vector[:expected_size]
+# # # Pad or truncate to expected size
+# # expected_size = self.standardized_cnn.expected_feature_dim
+# # if len(feature_vector) < expected_size:
+# # padding = np.zeros(expected_size - len(feature_vector))
+# # feature_vector = np.concatenate([feature_vector, padding])
+# # else:
+# # feature_vector = feature_vector[:expected_size]
 
-            base_input._feature_vector = feature_vector
+# # base_input._feature_vector = feature_vector
 
-            # Get prediction from StandardizedCNN
-            result = self.standardized_cnn.predict_from_base_input(base_input)
+# # # Get prediction from StandardizedCNN
+# # result = self.standardized_cnn.predict_from_base_input(base_input)
 
-            # Convert to legacy format
-            action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
-            pred_class = np.array([action_map.get(result.predictions['action'], 2)])
-            pred_proba = np.array([result.predictions['action_probabilities']])
+# # # Convert to legacy format
+# # action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
+# # pred_class = np.array([action_map.get(result.predictions['action'], 2)])
+# # pred_proba = np.array([result.predictions['action_probabilities']])
 
-            return pred_class, pred_proba
+# # return pred_class, pred_proba
 
-        except Exception as e:
-            logger.error(f"Legacy predict error: {e}")
-            # Return safe defaults
-            pred_class = np.array([2])  # HOLD
-            pred_proba = np.array([[0.33, 0.33, 0.34]])
-            return pred_class, pred_proba
+# # except Exception as e:
+# # logger.error(f"Legacy predict error: {e}")
+# # # Return safe defaults
+# # pred_class = np.array([2])  # HOLD
+# # pred_proba = np.array([[0.33, 0.33, 0.34]])
+# # return pred_class, pred_proba
 
-    def fit(self, X, y, **kwargs):
-        """Legacy fit method"""
-        try:
-            return self.trainer.train_step(X, y)
-        except Exception as e:
-            logger.error(f"Legacy fit error: {e}")
-            return self
+# # def fit(self, X, y, **kwargs):
+# # """Legacy fit method"""
+# # try:
+# # return self.trainer.train_step(X, y)
+# # except Exception as e:
+# # logger.error(f"Legacy fit error: {e}")
+# # return self
 
-    def save(self, filepath: str):
-        """Legacy save method"""
-        try:
-            torch.save(self.standardized_cnn.state_dict(), filepath)
-            logger.info(f"StandardizedCNN saved to {filepath}")
-        except Exception as e:
-            logger.error(f"Error saving model: {e}")
+# # def save(self, filepath: str):
+# # """Legacy save method"""
+# # try:
+# # torch.save(self.standardized_cnn.state_dict(), filepath)
+# # logger.info(f"StandardizedCNN saved to {filepath}")
+# # except Exception as e:
+# # logger.error(f"Error saving model: {e}")
 
 
-def create_enhanced_cnn_model(input_size: int = 60,
-                              feature_dim: int = 50,
-                              output_size: int = 3,
-                              base_channels: int = 256,
-                              device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]:
-    """Legacy compatibility function - returns StandardizedCNN"""
-    warnings.warn(
-        "create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.",
-        DeprecationWarning,
-        stacklevel=2
-    )
+# def create_enhanced_cnn_model(input_size: int = 60,
+# feature_dim: int = 50,
+# output_size: int = 3,
+# base_channels: int = 256,
+# device: str = 'cuda') -> Tuple[StandardizedCNN, CNNModelTrainer]:
+# """Legacy compatibility function - returns StandardizedCNN"""
+# warnings.warn(
+# "create_enhanced_cnn_model is deprecated. Use StandardizedCNN() directly.",
+# DeprecationWarning,
+# stacklevel=2
+# )
 
-    model = StandardizedCNN()
-    trainer = CNNModelTrainer(model)
+# model = StandardizedCNN()
+# trainer = CNNModelTrainer(model)
 
-    logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly")
-    return model, trainer
+# logger.warning("Legacy create_enhanced_cnn_model called - please use StandardizedCNN directly")
+# return model, trainer
 
 
-# Export compatibility symbols
-__all__ = [
-    'EnhancedCNNModel',
-    'CNNModelTrainer',
-    'CNNModel',
-    'create_enhanced_cnn_model'
-]
+# # Export compatibility symbols
+# __all__ = [
+# 'EnhancedCNNModel',
+# 'CNNModelTrainer',
+# # 'CNNModel',
+# 'create_enhanced_cnn_model'
+# ]
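The hunk above retires the legacy CNN compatibility layer by commenting it out wholesale; remaining callers should go through StandardizedCNN directly. The hunks that follow switch to the DQN network. For reference, a minimal migration sketch for old call sites, assuming BaseDataInput and StandardizedCNN expose exactly the interface the wrapper code above relied on (expected_feature_dim, the private _feature_vector hook, train_step, predict_from_base_input):

import numpy as np
import torch

from core.data_models import BaseDataInput
from .standardized_cnn import StandardizedCNN  # same relative import the retired module used

model = StandardizedCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Legacy: CNNModelTrainer(model).train_step(x, y)
# Now: pad/truncate your features to the model's expected width and train directly.
features = np.zeros(model.expected_feature_dim, dtype=np.float32)  # placeholder features
base_input = BaseDataInput()
base_input._feature_vector = features  # the same private hook the wrapper used
loss = model.train_step([base_input], ['HOLD'], optimizer)

# Legacy: CNNModel().predict(X) -> (pred_class, pred_proba)
# Now: query the standardized interface and map the action back if an int is needed.
result = model.predict_from_base_input(base_input)
action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
pred_class = action_map.get(result.predictions['action'], 2)
pred_proba = result.predictions['action_probabilities']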
@@ -23,11 +23,11 @@ logger = logging.getLogger(__name__)
 
 class DQNNetwork(nn.Module):
     """
-    Massive Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
+    Configurable Deep Q-Network specifically designed for RL trading with unified BaseDataInput features
     Handles 7850 input features from multi-timeframe, multi-asset data
-    TARGET: 50M parameters for enhanced learning capacity
+    Architecture is configurable via config.yaml
     """
-    def __init__(self, input_dim: int, n_actions: int):
+    def __init__(self, input_dim: int, n_actions: int, config: dict = None):
         super(DQNNetwork, self).__init__()
 
         # Handle different input dimension formats
@@ -41,59 +41,65 @@ class DQNNetwork(nn.Module):
 
         self.n_actions = n_actions
 
-        # MASSIVE network architecture optimized for trading features
-        # Target: ~50M parameters
-        self.feature_extractor = nn.Sequential(
-            # Initial feature extraction with massive width
-            nn.Linear(self.input_size, 8192),  # 7850 -> 8192 = ~64M weights
-            nn.LayerNorm(8192),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-
-            # Deep feature processing layers
-            nn.Linear(8192, 6144),  # 8192 -> 6144 = ~50M weights
-            nn.LayerNorm(6144),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-
-            nn.Linear(6144, 4096),  # 6144 -> 4096 = ~25M weights
-            nn.LayerNorm(4096),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-
-            nn.Linear(4096, 3072),  # 4096 -> 3072 = ~12M weights
-            nn.LayerNorm(3072),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-
-            nn.Linear(3072, 2048),  # 3072 -> 2048 = ~6M weights
-            nn.LayerNorm(2048),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-        )
+        # Get network architecture from config or use defaults
+        if config and 'network_architecture' in config:
+            arch_config = config['network_architecture']
+            feature_layers = arch_config.get('feature_layers', [4096, 3072, 2048, 1536, 1024])
+            regime_head = arch_config.get('regime_head', [512, 256])
+            price_direction_head = arch_config.get('price_direction_head', [512, 256])
+            volatility_head = arch_config.get('volatility_head', [512, 128])
+            value_head = arch_config.get('value_head', [512, 256])
+            advantage_head = arch_config.get('advantage_head', [512, 256])
+            dropout_rate = arch_config.get('dropout_rate', 0.1)
+            use_layer_norm = arch_config.get('use_layer_norm', True)
+        else:
+            # Default reduced architecture (half the original size)
+            feature_layers = [4096, 3072, 2048, 1536, 1024]
+            regime_head = [512, 256]
+            price_direction_head = [512, 256]
+            volatility_head = [512, 128]
+            value_head = [512, 256]
+            advantage_head = [512, 256]
+            dropout_rate = 0.1
+            use_layer_norm = True
+
+        # Build configurable feature extractor
+        feature_layers_list = []
+        prev_size = self.input_size
+
+        for layer_size in feature_layers:
+            feature_layers_list.append(nn.Linear(prev_size, layer_size))
+            if use_layer_norm:
+                feature_layers_list.append(nn.LayerNorm(layer_size))
+            feature_layers_list.append(nn.ReLU(inplace=True))
+            feature_layers_list.append(nn.Dropout(dropout_rate))
+            prev_size = layer_size
+
+        self.feature_extractor = nn.Sequential(*feature_layers_list)
+        self.feature_size = feature_layers[-1]  # Final feature size
+
+        # Build configurable network heads
+        def build_head_layers(input_size, layer_sizes, output_size):
+            layers = []
+            prev_size = input_size
+            for layer_size in layer_sizes:
+                layers.append(nn.Linear(prev_size, layer_size))
+                if use_layer_norm:
+                    layers.append(nn.LayerNorm(layer_size))
+                layers.append(nn.ReLU(inplace=True))
+                layers.append(nn.Dropout(dropout_rate))
+                prev_size = layer_size
+            layers.append(nn.Linear(prev_size, output_size))
+            return nn.Sequential(*layers)
 
         # Market regime detection head
-        self.regime_head = nn.Sequential(
-            nn.Linear(2048, 1024),
-            nn.LayerNorm(1024),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-            nn.Linear(1024, 512),
-            nn.LayerNorm(512),
-            nn.ReLU(inplace=True),
-            nn.Linear(512, 4)  # trending, ranging, volatile, mixed
+        self.regime_head = build_head_layers(
+            self.feature_size, regime_head, 4  # trending, ranging, volatile, mixed
         )
 
         # Price direction prediction head - outputs direction and confidence
-        self.price_direction_head = nn.Sequential(
-            nn.Linear(2048, 1024),
-            nn.LayerNorm(1024),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-            nn.Linear(1024, 512),
-            nn.LayerNorm(512),
-            nn.ReLU(inplace=True),
-            nn.Linear(512, 2)  # [direction, confidence]
+        self.price_direction_head = build_head_layers(
+            self.feature_size, price_direction_head, 2  # [direction, confidence]
        )
 
         # Direction activation (tanh for -1 to 1)
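The constructor above reads its layer sizes from config['network_architecture']. A sketch of a matching config dict follows; the keys mirror the arch_config.get(...) calls exactly, while the specific sizes are illustrative placeholders, not project defaults:

# from the agent module above: DQNNetwork (module path not shown in this diff)

# Illustrative config mirroring the keys DQNNetwork reads in its constructor.
config = {
    'network_architecture': {
        'feature_layers': [2048, 1024, 512],  # smaller than the default stack
        'regime_head': [256, 128],
        'price_direction_head': [256, 128],
        'volatility_head': [256, 64],
        'value_head': [256, 128],
        'advantage_head': [256, 128],
        'dropout_rate': 0.1,
        'use_layer_norm': True,
    }
}

# 7850 input features per the class docstring; 3 trading actions.
net = DQNNetwork(input_dim=7850, n_actions=3, config=config)
print(f"parameters: {sum(p.numel() for p in net.parameters()):,}")  # inspect resulting size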
@@ -102,38 +108,18 @@ class DQNNetwork(nn.Module):
         self.confidence_activation = nn.Sigmoid()
 
         # Volatility prediction head
-        self.volatility_head = nn.Sequential(
-            nn.Linear(2048, 1024),
-            nn.LayerNorm(1024),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-            nn.Linear(1024, 256),
-            nn.LayerNorm(256),
-            nn.ReLU(inplace=True),
-            nn.Linear(256, 4)  # predicted volatility for 4 timeframes
+        self.volatility_head = build_head_layers(
+            self.feature_size, volatility_head, 4  # predicted volatility for 4 timeframes
         )
 
         # Main Q-value head (dueling architecture)
-        self.value_head = nn.Sequential(
-            nn.Linear(2048, 1024),
-            nn.LayerNorm(1024),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-            nn.Linear(1024, 512),
-            nn.LayerNorm(512),
-            nn.ReLU(inplace=True),
-            nn.Linear(512, 1)  # State value
+        self.value_head = build_head_layers(
+            self.feature_size, value_head, 1  # Single value for dueling architecture
         )
 
-        self.advantage_head = nn.Sequential(
-            nn.Linear(2048, 1024),
-            nn.LayerNorm(1024),
-            nn.ReLU(inplace=True),
-            nn.Dropout(0.1),
-            nn.Linear(1024, 512),
-            nn.LayerNorm(512),
-            nn.ReLU(inplace=True),
-            nn.Linear(512, n_actions)  # Action advantages
+        # Advantage head (dueling architecture)
+        self.advantage_head = build_head_layers(
+            self.feature_size, advantage_head, n_actions  # Action advantages
         )
 
         # Initialize weights
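The value and advantage heads form the usual dueling-DQN pair. The forward pass is outside this diff, so treat the following as the conventional combination Q(s,a) = V(s) + A(s,a) - mean_a' A(s,a'), not as the file's own code:

import torch

def dueling_q(value: torch.Tensor, advantage: torch.Tensor) -> torch.Tensor:
    """Standard dueling combination: Q(s,a) = V(s) + A(s,a) - mean over actions of A."""
    return value + advantage - advantage.mean(dim=1, keepdim=True)

# Head shapes as built above: value (batch, 1), advantage (batch, n_actions).
q = dueling_q(torch.zeros(4, 1), torch.randn(4, 3))  # -> (4, 3) Q-values

Subtracting the mean advantage keeps the V/A split identifiable, which is the point of building the two heads separately.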
@@ -248,7 +234,8 @@ class DQNAgent:
                  priority_memory: bool = True,
                  device=None,
                  model_name: str = "dqn_agent",
-                 enable_checkpoints: bool = True):
+                 enable_checkpoints: bool = True,
+                 config: dict = None):
 
         # Checkpoint management
         self.model_name = model_name
@@ -292,8 +279,8 @@ class DQNAgent:
         logger.info(f"DQN Agent using device: {self.device}")
 
         # Initialize models with RL-specific network architecture
-        self.policy_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
-        self.target_net = DQNNetwork(self.state_dim, self.n_actions).to(self.device)
+        self.policy_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device)
+        self.target_net = DQNNetwork(self.state_dim, self.n_actions, config).to(self.device)
 
         # Ensure models are on the correct device
         self.policy_net = self.policy_net.to(self.device)
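With config threaded through DQNAgent into both DQNNetwork instances, policy and target nets are always built with identical shapes. A hypothetical construction call; state_dim and n_actions are assumed constructor parameters, since only the tail of the signature appears in the hunks above:

arch = {'network_architecture': {'feature_layers': [2048, 1024], 'dropout_rate': 0.1}}

agent = DQNAgent(
    state_dim=7850,        # assumed parameter, not shown in the visible hunks
    n_actions=3,           # assumed parameter
    model_name="dqn_agent",
    enable_checkpoints=True,
    config=arch,
)

# Same config on both nets keeps their state_dicts shape-compatible,
# so the usual target-network synchronization still works unchanged.
agent.target_net.load_state_dict(agent.policy_net.state_dict())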