CNN training first working
This commit is contained in:
@ -51,10 +51,68 @@ class EnhancedCNNAdapter:
|
||||
# Create checkpoint directory if it doesn't exist
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# Initialize model
|
||||
# Initialize the model
|
||||
self._initialize_model()
|
||||
|
||||
logger.info(f"EnhancedCNNAdapter initialized with device: {self.device}")
|
||||
# Load checkpoint if available
|
||||
if model_path and os.path.exists(model_path):
|
||||
self._load_checkpoint(model_path)
|
||||
else:
|
||||
self._load_best_checkpoint()
|
||||
|
||||
logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
|
||||
|
||||
def _load_checkpoint(self, checkpoint_path: str) -> bool:
|
||||
"""Load model from checkpoint path"""
|
||||
try:
|
||||
if self.model and os.path.exists(checkpoint_path):
|
||||
success = self.model.load(checkpoint_path)
|
||||
if success:
|
||||
logger.info(f"Loaded model from {checkpoint_path}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Failed to load model from {checkpoint_path}")
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"Checkpoint path does not exist: {checkpoint_path}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading checkpoint: {e}")
|
||||
return False
|
||||
|
||||
def _load_best_checkpoint(self) -> bool:
|
||||
"""Load the best available checkpoint"""
|
||||
try:
|
||||
return self.load_best_checkpoint()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading best checkpoint: {e}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def _create_default_output(self, symbol: str) -> ModelOutput:
|
||||
"""Create default output when prediction fails"""
|
||||
return create_model_output(
|
||||
model_type='cnn',
|
||||
model_name=self.model_name,
|
||||
symbol=symbol,
|
||||
action='HOLD',
|
||||
confidence=0.0,
|
||||
metadata={'error': 'Prediction failed, using default output'}
|
||||
)
|
||||
|
||||
def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Process hidden states for cross-model feeding"""
|
||||
processed_states = {}
|
||||
|
||||
for key, value in hidden_states.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Convert tensor to numpy array
|
||||
processed_states[key] = value.cpu().numpy().tolist()
|
||||
else:
|
||||
processed_states[key] = value
|
||||
|
||||
return processed_states
|
||||
|
||||
def _initialize_model(self):
|
||||
"""Initialize the EnhancedCNN model"""
|
||||
|
Reference in New Issue
Block a user