try to fix input dimentions
This commit is contained in:
@ -250,6 +250,12 @@ class COBRLModelInterface(ModelInterface):
|
||||
|
||||
logger.info(f"COB RL Model Interface initialized on {self.device}")
|
||||
|
||||
def to(self, device):
|
||||
"""PyTorch-style device movement method"""
|
||||
self.device = device
|
||||
self.model = self.model.to(device)
|
||||
return self
|
||||
|
||||
def predict(self, cob_features: np.ndarray) -> Dict[str, Any]:
|
||||
"""Make prediction using the model"""
|
||||
self.model.eval()
|
||||
|
Reference in New Issue
Block a user