checkbox manager and handling

This commit is contained in:
Dobromir Popov
2025-06-24 21:59:23 +03:00
parent 706eb13912
commit ab8c94d735
8 changed files with 1170 additions and 29 deletions

View File

@ -19,6 +19,10 @@ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_sc
import torch.nn.functional as F
from typing import Dict, Any, Optional, Tuple
# Import checkpoint management
from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
from utils.training_integration import get_training_integration
# Configure logging
logger = logging.getLogger(__name__)
@ -507,38 +511,140 @@ class EnhancedCNNModel(nn.Module):
return self.to(torch.device(device))
class CNNModelTrainer:
"""Enhanced trainer for the beefed-up CNN model"""
"""Enhanced CNN trainer with checkpoint management integration"""
def __init__(self, model: EnhancedCNNModel, learning_rate: float = 0.0001, device: str = 'cuda'):
self.model = model.to(device)
self.device = device
self.learning_rate = learning_rate
def __init__(self, model: EnhancedCNNModel, learning_rate: float = 0.0001, device: str = 'cuda',
model_name: str = "enhanced_cnn", enable_checkpoints: bool = True):
self.model = model
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
# Use AdamW optimizer with weight decay
self.optimizer = torch.optim.AdamW(
model.parameters(),
lr=learning_rate,
# Checkpoint management
self.model_name = model_name
self.enable_checkpoints = enable_checkpoints
self.training_integration = get_training_integration() if enable_checkpoints else None
self.epoch_count = 0
self.best_val_accuracy = 0.0
self.best_val_loss = float('inf')
self.checkpoint_frequency = 10 # Save checkpoint every 10 epochs
# Optimizers and criteria
self.optimizer = optim.AdamW(
self.model.parameters(),
lr=learning_rate,
weight_decay=0.01,
betas=(0.9, 0.999)
)
# Learning rate scheduler
self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
self.scheduler = optim.lr_scheduler.OneCycleLR(
self.optimizer,
max_lr=learning_rate * 10,
total_steps=10000, # Will be updated based on actual training
total_steps=1000,
pct_start=0.1,
anneal_strategy='cos'
)
# Multi-task loss functions
# Loss functions
self.main_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
self.confidence_criterion = nn.BCELoss()
self.confidence_criterion = nn.MSELoss()
self.regime_criterion = nn.CrossEntropyLoss()
self.volatility_criterion = nn.MSELoss()
self.training_history = []
# Training history
self.training_history = {
'train_loss': [],
'val_loss': [],
'train_accuracy': [],
'val_accuracy': [],
'learning_rates': []
}
# Load best checkpoint if available
if self.enable_checkpoints:
self.load_best_checkpoint()
logger.info(f"CNN Trainer initialized with checkpoint management: {enable_checkpoints}")
if enable_checkpoints:
logger.info(f"Model name: {model_name}, Checkpoint frequency: {self.checkpoint_frequency}")
def load_best_checkpoint(self):
"""Load the best checkpoint for this CNN model"""
try:
if not self.enable_checkpoints:
return
result = load_best_checkpoint(self.model_name)
if result:
file_path, metadata = result
checkpoint = torch.load(file_path, map_location=self.device)
# Load model state
if 'model_state_dict' in checkpoint:
self.model.load_state_dict(checkpoint['model_state_dict'])
if 'optimizer_state_dict' in checkpoint:
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# Load training state
if 'epoch_count' in checkpoint:
self.epoch_count = checkpoint['epoch_count']
if 'best_val_accuracy' in checkpoint:
self.best_val_accuracy = checkpoint['best_val_accuracy']
if 'best_val_loss' in checkpoint:
self.best_val_loss = checkpoint['best_val_loss']
if 'training_history' in checkpoint:
self.training_history = checkpoint['training_history']
logger.info(f"Loaded CNN checkpoint: {metadata.checkpoint_id}")
logger.info(f"Epoch: {self.epoch_count}, Best val accuracy: {self.best_val_accuracy:.4f}")
except Exception as e:
logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
def save_checkpoint(self, train_accuracy: float, val_accuracy: float,
train_loss: float, val_loss: float, force_save: bool = False):
"""Save checkpoint if performance improved or forced"""
try:
if not self.enable_checkpoints:
return False
self.epoch_count += 1
# Update best metrics
improved = False
if val_accuracy > self.best_val_accuracy:
self.best_val_accuracy = val_accuracy
improved = True
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
improved = True
# Save checkpoint if improved, forced, or at regular intervals
should_save = (
force_save or
improved or
self.epoch_count % self.checkpoint_frequency == 0
)
if should_save and self.training_integration:
return self.training_integration.save_cnn_checkpoint(
cnn_model=self.model,
model_name=self.model_name,
epoch=self.epoch_count,
train_accuracy=train_accuracy,
val_accuracy=val_accuracy,
train_loss=train_loss,
val_loss=val_loss,
training_time_hours=0.0 # Can be calculated by calling code
)
return False
except Exception as e:
logger.error(f"Error saving CNN checkpoint: {e}")
return False
def reset_computational_graph(self):
"""Reset the computational graph to prevent in-place operation issues"""
try:
@ -648,6 +754,13 @@ class CNNModelTrainer:
accuracy = (predictions == y_train).float().mean().item()
losses['accuracy'] = accuracy
# Update training history
if 'train_loss' in self.training_history:
self.training_history['train_loss'].append(losses['total_loss'])
self.training_history['train_accuracy'].append(accuracy)
current_lr = self.optimizer.param_groups[0]['lr']
self.training_history['learning_rates'].append(current_lr)
return losses
except Exception as e: