checkpoint manager and handling
@@ -19,6 +19,10 @@ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
 import torch.nn.functional as F
 from typing import Dict, Any, Optional, Tuple
 
+# Import checkpoint management
+from utils.checkpoint_manager import save_checkpoint, load_best_checkpoint
+from utils.training_integration import get_training_integration
+
 # Configure logging
 logger = logging.getLogger(__name__)
 
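Note on the new imports: utils.checkpoint_manager and utils.training_integration are project-internal modules whose definitions are not part of this diff. The sketch below only restates the call shapes inferred from how CNNModelTrainer uses them further down in this commit; it is an assumption, not the modules' documented API.

# Assumption: call shapes inferred from usage later in this diff, not documented API.
from utils.checkpoint_manager import load_best_checkpoint
from utils.training_integration import get_training_integration

result = load_best_checkpoint("enhanced_cnn")   # appears to return None or (file_path, metadata)
if result:
    file_path, metadata = result
    print(f"best checkpoint: {metadata.checkpoint_id} -> {file_path}")

integration = get_training_integration()        # appears to return a shared integration object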
@@ -507,38 +511,140 @@ class EnhancedCNNModel(nn.Module):
         return self.to(torch.device(device))
 
 class CNNModelTrainer:
-    """Enhanced trainer for the beefed-up CNN model"""
+    """Enhanced CNN trainer with checkpoint management integration"""
 
-    def __init__(self, model: EnhancedCNNModel, learning_rate: float = 0.0001, device: str = 'cuda'):
-        self.model = model.to(device)
-        self.device = device
-        self.learning_rate = learning_rate
+    def __init__(self, model: EnhancedCNNModel, learning_rate: float = 0.0001, device: str = 'cuda',
+                 model_name: str = "enhanced_cnn", enable_checkpoints: bool = True):
+        self.model = model
+        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
+        self.model.to(self.device)
 
-        # Use AdamW optimizer with weight decay
-        self.optimizer = torch.optim.AdamW(
-            model.parameters(),
-            lr=learning_rate,
+        # Checkpoint management
+        self.model_name = model_name
+        self.enable_checkpoints = enable_checkpoints
+        self.training_integration = get_training_integration() if enable_checkpoints else None
+        self.epoch_count = 0
+        self.best_val_accuracy = 0.0
+        self.best_val_loss = float('inf')
+        self.checkpoint_frequency = 10  # Save checkpoint every 10 epochs
+
+        # Optimizers and criteria
+        self.optimizer = optim.AdamW(
+            self.model.parameters(),
+            lr=learning_rate,
             weight_decay=0.01,
             betas=(0.9, 0.999)
         )
 
         # Learning rate scheduler
-        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        self.scheduler = optim.lr_scheduler.OneCycleLR(
             self.optimizer,
             max_lr=learning_rate * 10,
-            total_steps=10000,  # Will be updated based on actual training
+            total_steps=1000,
             pct_start=0.1,
             anneal_strategy='cos'
         )
 
-        # Multi-task loss functions
+        # Loss functions
         self.main_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
-        self.confidence_criterion = nn.BCELoss()
+        self.confidence_criterion = nn.MSELoss()
         self.regime_criterion = nn.CrossEntropyLoss()
         self.volatility_criterion = nn.MSELoss()
 
-        self.training_history = []
+        # Training history
+        self.training_history = {
+            'train_loss': [],
+            'val_loss': [],
+            'train_accuracy': [],
+            'val_accuracy': [],
+            'learning_rates': []
+        }
+
+        # Load best checkpoint if available
+        if self.enable_checkpoints:
+            self.load_best_checkpoint()
+
+        logger.info(f"CNN Trainer initialized with checkpoint management: {enable_checkpoints}")
+        if enable_checkpoints:
+            logger.info(f"Model name: {model_name}, Checkpoint frequency: {self.checkpoint_frequency}")
+
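A note on the scheduler change above: PyTorch's OneCycleLR advances once per optimizer step and raises an error once it has been stepped more than total_steps times, so the hard-coded total_steps=1000 only covers 1000 batches. A minimal, self-contained sketch of the intended sizing; the model, epochs, and steps_per_epoch here are illustrative and not taken from this repo.

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 3)                      # stand-in model for illustration
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

epochs, steps_per_epoch = 10, 100             # illustrative sizes
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,                              # 10x the base LR, as in the trainer
    total_steps=epochs * steps_per_epoch,     # sized to the actual training run
    pct_start=0.1,
    anneal_strategy='cos',
)

for _ in range(epochs * steps_per_epoch):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()                          # OneCycleLR advances once per optimizer step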
+    def load_best_checkpoint(self):
+        """Load the best checkpoint for this CNN model"""
+        try:
+            if not self.enable_checkpoints:
+                return
+
+            result = load_best_checkpoint(self.model_name)
+            if result:
+                file_path, metadata = result
+                checkpoint = torch.load(file_path, map_location=self.device)
+
+                # Load model state
+                if 'model_state_dict' in checkpoint:
+                    self.model.load_state_dict(checkpoint['model_state_dict'])
+                if 'optimizer_state_dict' in checkpoint:
+                    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                if 'scheduler_state_dict' in checkpoint:
+                    self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+                # Load training state
+                if 'epoch_count' in checkpoint:
+                    self.epoch_count = checkpoint['epoch_count']
+                if 'best_val_accuracy' in checkpoint:
+                    self.best_val_accuracy = checkpoint['best_val_accuracy']
+                if 'best_val_loss' in checkpoint:
+                    self.best_val_loss = checkpoint['best_val_loss']
+                if 'training_history' in checkpoint:
+                    self.training_history = checkpoint['training_history']
+
+                logger.info(f"Loaded CNN checkpoint: {metadata.checkpoint_id}")
+                logger.info(f"Epoch: {self.epoch_count}, Best val accuracy: {self.best_val_accuracy:.4f}")
+
+        except Exception as e:
+            logger.warning(f"Failed to load checkpoint for {self.model_name}: {e}")
+
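The loader above expects a torch.load-able dict with optional 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict' and training-state keys. The code that actually writes these files lives in the checkpoint manager / training integration and is not part of this diff; the sketch below only shows one payload shape this loader would accept, as an assumption.

import torch

def write_compatible_checkpoint(path, trainer):
    """Sketch: persist a dict with the keys CNNModelTrainer.load_best_checkpoint reads."""
    payload = {
        'model_state_dict': trainer.model.state_dict(),
        'optimizer_state_dict': trainer.optimizer.state_dict(),
        'scheduler_state_dict': trainer.scheduler.state_dict(),
        'epoch_count': trainer.epoch_count,
        'best_val_accuracy': trainer.best_val_accuracy,
        'best_val_loss': trainer.best_val_loss,
        'training_history': trainer.training_history,
    }
    torch.save(payload, path)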
+    def save_checkpoint(self, train_accuracy: float, val_accuracy: float,
+                        train_loss: float, val_loss: float, force_save: bool = False):
+        """Save checkpoint if performance improved or forced"""
+        try:
+            if not self.enable_checkpoints:
+                return False
+
+            self.epoch_count += 1
+
+            # Update best metrics
+            improved = False
+            if val_accuracy > self.best_val_accuracy:
+                self.best_val_accuracy = val_accuracy
+                improved = True
+            if val_loss < self.best_val_loss:
+                self.best_val_loss = val_loss
+                improved = True
+
+            # Save checkpoint if improved, forced, or at regular intervals
+            should_save = (
+                force_save or
+                improved or
+                self.epoch_count % self.checkpoint_frequency == 0
+            )
+
+            if should_save and self.training_integration:
+                return self.training_integration.save_cnn_checkpoint(
+                    cnn_model=self.model,
+                    model_name=self.model_name,
+                    epoch=self.epoch_count,
+                    train_accuracy=train_accuracy,
+                    val_accuracy=val_accuracy,
+                    train_loss=train_loss,
+                    val_loss=val_loss,
+                    training_time_hours=0.0  # Can be calculated by calling code
+                )
+
+            return False
+
+        except Exception as e:
+            logger.error(f"Error saving CNN checkpoint: {e}")
+            return False
 
     def reset_computational_graph(self):
         """Reset the computational graph to prevent in-place operation issues"""
         try:
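For context on how the new save_checkpoint hook is meant to be driven: it increments epoch_count itself, so it is intended to be called once per epoch, and it persists when validation metrics improve, when force_save=True, or every checkpoint_frequency epochs. A hypothetical calling loop follows; `trainer`, the metric dicts, and num_epochs are placeholders, not code from this repo.

# Hypothetical driver loop; `trainer` is assumed to be an already-constructed
# CNNModelTrainer, and the metric dicts stand in for real per-epoch aggregates.
num_epochs = 10
for epoch in range(num_epochs):
    train_metrics = {'accuracy': 0.0, 'total_loss': 0.0}   # e.g. averaged over train_step() calls
    val_metrics = {'accuracy': 0.0, 'loss': 0.0}           # e.g. from a validation pass

    saved = trainer.save_checkpoint(
        train_accuracy=train_metrics['accuracy'],
        val_accuracy=val_metrics['accuracy'],
        train_loss=train_metrics['total_loss'],
        val_loss=val_metrics['loss'],
        force_save=(epoch == num_epochs - 1),  # always keep the final epoch
    )
    if saved:
        print(f"checkpoint written at epoch {trainer.epoch_count}")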
@@ -648,6 +754,13 @@ class CNNModelTrainer:
             accuracy = (predictions == y_train).float().mean().item()
             losses['accuracy'] = accuracy
 
+            # Update training history
+            if 'train_loss' in self.training_history:
+                self.training_history['train_loss'].append(losses['total_loss'])
+                self.training_history['train_accuracy'].append(accuracy)
+                current_lr = self.optimizer.param_groups[0]['lr']
+                self.training_history['learning_rates'].append(current_lr)
+
             return losses
 
         except Exception as e:
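Since the training method above appears to append to training_history on each call, the lists grow per training step rather than per epoch. A small sketch for summarizing a populated history dict, using only the keys defined in __init__:

# Sketch: summarize a populated training_history dict (keys as defined in __init__).
def summarize_history(history):
    if not history.get('train_loss'):
        return "no training steps recorded"
    best_idx = max(range(len(history['train_accuracy'])),
                   key=lambda i: history['train_accuracy'][i])
    return (f"steps={len(history['train_loss'])}, "
            f"best train acc={history['train_accuracy'][best_idx]:.4f} "
            f"at step {best_idx}, final lr={history['learning_rates'][-1]:.2e}")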