CNN training first working

This commit is contained in:
Dobromir Popov
2025-07-23 22:39:00 +03:00
parent 26e6ba2e1d
commit 94ee7389c4
2 changed files with 456 additions and 8 deletions

View File

@ -51,10 +51,68 @@ class EnhancedCNNAdapter:
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
# Initialize model
# Initialize the model
self._initialize_model()
logger.info(f"EnhancedCNNAdapter initialized with device: {self.device}")
# Load checkpoint if available
if model_path and os.path.exists(model_path):
self._load_checkpoint(model_path)
else:
self._load_best_checkpoint()
logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
def _load_checkpoint(self, checkpoint_path: str) -> bool:
    """Load model weights from an explicit checkpoint path.

    Args:
        checkpoint_path: Filesystem path of the checkpoint to load.

    Returns:
        True only when ``self.model.load(checkpoint_path)`` reports success.
        False when the model has not been created, the path does not exist,
        the load reports failure, or any exception is raised while loading.
    """
    try:
        # Report a missing model separately from a missing file so the log
        # points at the real cause of the failure.
        if self.model is None:
            logger.warning(
                f"Model is not initialized; cannot load checkpoint: {checkpoint_path}"
            )
            return False

        if not os.path.exists(checkpoint_path):
            logger.warning(f"Checkpoint path does not exist: {checkpoint_path}")
            return False

        if self.model.load(checkpoint_path):
            logger.info(f"Loaded model from {checkpoint_path}")
            return True

        logger.warning(f"Failed to load model from {checkpoint_path}")
        return False
    except Exception as e:
        # Best-effort load: a bad checkpoint must not abort adapter start-up,
        # so the error is logged and reported to the caller as False.
        logger.error(f"Error loading checkpoint: {e}")
        return False
def _load_best_checkpoint(self) -> bool:
"""Load the best available checkpoint"""
try:
return self.load_best_checkpoint()
except Exception as e:
logger.error(f"Error loading best checkpoint: {e}")
return False
def _create_default_output(self, symbol: str) -> ModelOutput:
    """Build the fallback output returned when a prediction fails.

    Args:
        symbol: Symbol the failed prediction was requested for.

    Returns:
        The result of ``create_model_output`` for a 'HOLD' action with zero
        confidence, tagged with an error note in its metadata.
    """
    failure_note = {'error': 'Prediction failed, using default output'}
    return create_model_output(
        model_type='cnn',
        model_name=self.model_name,
        symbol=symbol,
        action='HOLD',
        confidence=0.0,
        metadata=failure_note,
    )
def _process_hidden_states(self, hidden_states: Dict[str, Any]) -> Dict[str, Any]:
"""Process hidden states for cross-model feeding"""
processed_states = {}
for key, value in hidden_states.items():
if isinstance(value, torch.Tensor):
# Convert tensor to numpy array
processed_states[key] = value.cpu().numpy().tolist()
else:
processed_states[key] = value
return processed_states
def _initialize_model(self):
"""Initialize the EnhancedCNN model"""