unify model names
This commit is contained in:
@@ -37,16 +37,23 @@ import traceback
|
||||
import gc
|
||||
import time
|
||||
import psutil
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
# Try to import torch
|
||||
try:
|
||||
import torch
|
||||
HAS_TORCH = True
|
||||
except ImportError:
|
||||
torch = None
|
||||
HAS_TORCH = False
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def clear_gpu_memory():
|
||||
"""Clear GPU memory cache"""
|
||||
if torch.cuda.is_available():
|
||||
if HAS_TORCH and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user