model metadata

Dobromir Popov
2025-11-12 14:37:23 +02:00
parent 4c04503f3e
commit 0c987c3557
4 changed files with 88 additions and 9 deletions


@@ -1723,15 +1723,9 @@ class RealTrainingAdapter:
                 else:
                     logger.warning(f" Batch {i + 1} returned None result - skipping")
 
-                # CRITICAL FIX: Delete batch tensors immediately to free GPU memory
-                # This prevents memory accumulation during gradient accumulation
-                for key in list(batch.keys()):
-                    if isinstance(batch[key], torch.Tensor):
-                        del batch[key]
-                del batch
 
                 # CRITICAL: Clear CUDA cache after EVERY batch to prevent memory accumulation
                 # This is essential with large models and limited GPU memory
+                # NOTE: We don't delete the batch dict itself because it's reused across epochs
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
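
What this hunk does: the per-key tensor deletion is dropped because the batch dicts are built once and reused on every epoch, while `torch.cuda.empty_cache()` still runs after each batch to trim PyTorch's allocator cache. A minimal sketch of that pattern, with an illustrative model and loop (the names here are assumptions, not the repo's actual trainer):

```python
import torch

def train(model, batches, optimizer, num_epochs=2):
    """Batch dicts are built once and reused every epoch, so they must not be
    deleted mid-loop; instead, trim the CUDA allocator cache after each batch."""
    loss_fn = torch.nn.MSELoss()
    for epoch in range(num_epochs):
        for batch in batches:  # same dicts on every epoch; deleting keys would break epoch 2
            optimizer.zero_grad()
            loss = loss_fn(model(batch["x"]), batch["y"])
            loss.backward()
            optimizer.step()
            # empty_cache() releases *cached* allocator blocks, not live tensors;
            # it lowers peak reserved memory at some throughput cost, since blocks
            # must be re-requested from the driver on the next allocation.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

if __name__ == "__main__":
    model = torch.nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    data = [{"x": torch.randn(8, 4), "y": torch.randn(8, 1)} for _ in range(3)]
    train(model, data, opt)
```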
@@ -1786,6 +1780,41 @@ class RealTrainingAdapter:
                     logger.info(f" Saved checkpoint: {checkpoint_path}")
 
+                    # Save metadata to database for easy retrieval
+                    try:
+                        from utils.database_manager import DatabaseManager
+                        db_manager = DatabaseManager()
+
+                        checkpoint_id = f"transformer_e{epoch+1}_{timestamp}"
+
+                        # Create metadata object
+                        from utils.database_manager import CheckpointMetadata
+                        metadata = CheckpointMetadata(
+                            checkpoint_id=checkpoint_id,
+                            model_name="transformer",
+                            model_type="transformer",
+                            timestamp=datetime.now(),
+                            performance_metrics={
+                                'loss': float(avg_loss),
+                                'accuracy': float(avg_accuracy),
+                                'epoch': epoch + 1,
+                                'learning_rate': float(trainer.scheduler.get_last_lr()[0])
+                            },
+                            training_metadata={
+                                'num_samples': len(training_data),
+                                'num_batches': num_batches,
+                                'training_id': training_id
+                            },
+                            file_path=checkpoint_path,
+                            performance_score=float(avg_accuracy),  # Use accuracy as score
+                            is_active=True
+                        )
+
+                        if db_manager.save_checkpoint_metadata(metadata):
+                            logger.info(f" Saved checkpoint metadata to database: {checkpoint_id}")
+                    except Exception as meta_error:
+                        logger.warning(f" Could not save checkpoint metadata: {meta_error}")
 
                 # Keep only best 5 checkpoints based on accuracy
                 self._cleanup_old_checkpoints(checkpoint_dir, keep_best=5, metric='accuracy')
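
`_cleanup_old_checkpoints` is invoked with `keep_best=5, metric='accuracy'` but its body is not part of this diff. A plausible shape for such a helper, assuming each checkpoint file stores its metrics alongside the weights (a hypothetical sketch, not the repo's implementation):

```python
import os
import torch

def cleanup_old_checkpoints(checkpoint_dir: str, keep_best: int = 5,
                            metric: str = "accuracy") -> None:
    """Hypothetical: rank *.pt checkpoints by a metric saved inside each file
    and delete everything below the top `keep_best`."""
    scored = []
    for name in os.listdir(checkpoint_dir):
        if not name.endswith(".pt"):
            continue
        path = os.path.join(checkpoint_dir, name)
        try:
            # Assumes a checkpoint dict shaped like
            # {"model_state_dict": ..., "accuracy": 0.87, "loss": 0.42}
            ckpt = torch.load(path, map_location="cpu")
            scored.append((float(ckpt.get(metric, 0.0)), path))
        except Exception:
            continue  # unreadable file: leave it in place
    scored.sort(key=lambda t: t[0], reverse=True)  # higher accuracy ranks first
    for _, path in scored[keep_best:]:
        os.remove(path)
```

Note that the ranking metric matches `performance_score` above (both use accuracy), so the checkpoints that survive on disk are the same ones the database scores highest.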