try to fix input dimentions

This commit is contained in:
Dobromir Popov
2025-07-13 23:41:47 +03:00
parent bcc13a5db3
commit ebf65494a8
7 changed files with 358 additions and 145 deletions

View File

@ -250,6 +250,12 @@ class COBRLModelInterface(ModelInterface):
logger.info(f"COB RL Model Interface initialized on {self.device}")
def to(self, device):
"""PyTorch-style device movement method"""
self.device = device
self.model = self.model.to(device)
return self
def predict(self, cob_features: np.ndarray) -> Dict[str, Any]:
"""Make prediction using the model"""
self.model.eval()