model metadata
@@ -1723,15 +1723,9 @@ class RealTrainingAdapter:
                 else:
                     logger.warning(f" Batch {i + 1} returned None result - skipping")
 
-                # CRITICAL FIX: Delete batch tensors immediately to free GPU memory
-                # This prevents memory accumulation during gradient accumulation
-                for key in list(batch.keys()):
-                    if isinstance(batch[key], torch.Tensor):
-                        del batch[key]
-                del batch
-
                 # CRITICAL: Clear CUDA cache after EVERY batch to prevent memory accumulation
                 # This is essential with large models and limited GPU memory
+                # NOTE: We don't delete the batch dict itself because it's reused across epochs
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
 
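The hunk above is standard PyTorch memory hygiene: torch.cuda.empty_cache() only returns unoccupied cached blocks to the driver, so tensors that are still referenced (such as the batch dict this adapter reuses across epochs) are unaffected, which is why the commit drops the per-key deletion. A minimal, self-contained sketch of the loop shape, with a hypothetical model, optimizer, and batch layout (not the adapter's actual code):

import torch

def train_epoch(model, batches, optimizer, loss_fn, device):
    """Run one epoch, clearing the CUDA allocator cache after every batch."""
    for i, batch in enumerate(batches):
        # Hypothetical batch layout: a dict of CPU tensors moved to the GPU.
        inputs = batch['inputs'].to(device)
        targets = batch['targets'].to(device)

        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Return unused cached blocks to the driver after every batch.
        # This does NOT free live tensors; it only lowers the peak
        # reserved footprint, which matters on memory-limited GPUs.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

Calling empty_cache() every batch adds allocator overhead, so the pattern is usually reserved for memory-constrained setups like the one this adapter targets.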
@@ -1786,6 +1780,41 @@ class RealTrainingAdapter:
             logger.info(f" Saved checkpoint: {checkpoint_path}")
 
+            # Save metadata to database for easy retrieval
+            try:
+                from utils.database_manager import DatabaseManager
+
+                db_manager = DatabaseManager()
+                checkpoint_id = f"transformer_e{epoch+1}_{timestamp}"
+
+                # Create metadata object
+                from utils.database_manager import CheckpointMetadata
+                metadata = CheckpointMetadata(
+                    checkpoint_id=checkpoint_id,
+                    model_name="transformer",
+                    model_type="transformer",
+                    timestamp=datetime.now(),
+                    performance_metrics={
+                        'loss': float(avg_loss),
+                        'accuracy': float(avg_accuracy),
+                        'epoch': epoch + 1,
+                        'learning_rate': float(trainer.scheduler.get_last_lr()[0])
+                    },
+                    training_metadata={
+                        'num_samples': len(training_data),
+                        'num_batches': num_batches,
+                        'training_id': training_id
+                    },
+                    file_path=checkpoint_path,
+                    performance_score=float(avg_accuracy),  # Use accuracy as score
+                    is_active=True
+                )
+
+                if db_manager.save_checkpoint_metadata(metadata):
+                    logger.info(f" Saved checkpoint metadata to database: {checkpoint_id}")
+            except Exception as meta_error:
+                logger.warning(f" Could not save checkpoint metadata: {meta_error}")
+
             # Keep only best 5 checkpoints based on accuracy
             self._cleanup_old_checkpoints(checkpoint_dir, keep_best=5, metric='accuracy')
 
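CheckpointMetadata and DatabaseManager live in the project's utils/database_manager.py, which is not part of this diff. A plausible dataclass shape, inferred purely from the constructor call above (the field names come from the call-site; the types are assumptions):

from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict

@dataclass
class CheckpointMetadata:
    # Inferred from the call-site in the hunk above; types are guesses,
    # not the real definition from utils/database_manager.py.
    checkpoint_id: str
    model_name: str
    model_type: str
    timestamp: datetime
    performance_metrics: Dict[str, Any]
    training_metadata: Dict[str, Any]
    file_path: str
    performance_score: float
    is_active: bool = True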
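The _cleanup_old_checkpoints helper invoked at the end of the hunk is likewise not shown in this diff. A hypothetical sketch matching its call signature, assuming each checkpoint file is a torch.save'd dict that records the metric (the real helper may rank checkpoints differently, e.g. via the metadata database):

import glob
import os
import torch

def _cleanup_old_checkpoints(checkpoint_dir, keep_best=5, metric='accuracy'):
    """Keep the keep_best checkpoints with the highest `metric`; delete the rest.

    Assumes each .pt file is a dict that stores the metric under its name --
    a hypothetical layout, since the adapter's real helper isn't in this diff.
    """
    scored = []
    for path in glob.glob(os.path.join(checkpoint_dir, '*.pt')):
        try:
            ckpt = torch.load(path, map_location='cpu')
            scored.append((float(ckpt.get(metric, float('-inf'))), path))
        except Exception:
            # Unreadable checkpoints rank last rather than aborting cleanup.
            scored.append((float('-inf'), path))

    scored.sort(reverse=True)  # highest metric first
    for _, path in scored[keep_best:]:
        os.remove(path)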