integrating new CNN model
This commit is contained in:
@@ -289,11 +289,9 @@ class TradingOrchestrator:
|
||||
|
||||
# Initialize CNN Model
|
||||
try:
|
||||
from NN.models.enhanced_cnn import EnhancedCNN
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
|
||||
cnn_input_shape = self.config.cnn.get('input_shape', 100)
|
||||
cnn_n_actions = self.config.cnn.get('n_actions', 3)
|
||||
self.cnn_model = EnhancedCNN(input_shape=cnn_input_shape, n_actions=cnn_n_actions)
|
||||
self.cnn_model = StandardizedCNN()
|
||||
self.cnn_model.to(self.device) # Move CNN model to the determined device
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
|
||||
|
||||
@@ -325,8 +323,8 @@ class TradingOrchestrator:
|
||||
logger.info("Enhanced CNN model initialized")
|
||||
except ImportError:
|
||||
try:
|
||||
from NN.models.cnn_model import CNNModel
|
||||
self.cnn_model = CNNModel()
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
self.cnn_model = StandardizedCNN()
|
||||
self.cnn_model.to(self.device) # Move basic CNN model to the determined device
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN
|
||||
|
||||
|
Reference in New Issue
Block a user