linter stuff

This commit is contained in:
Maxime
2024-02-04 18:27:38 +01:00
parent eb300b6c57
commit 67c70d1954
3 changed files with 5 additions and 3 deletions

View File

@@ -1,8 +1,9 @@
"""setup.py for axolotl"""
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version, parse
import platform
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version, parse
from setuptools import find_packages, setup

View File

@@ -46,6 +46,7 @@ def gpu_memory_usage_all(device=0):
smi = gpu_memory_usage_smi(device)
return usage, reserved - usage, max(0, smi - reserved)
def mps_memory_usage_all():
usage = torch.mps.current_allocated_memory() / 1024.0**3
reserved = torch.mps.driver_allocated_memory() / 1024.0**3

View File

@@ -672,7 +672,7 @@ def load_model(
):
model.config.eos_token_id = tokenizer.eos_token_id
if hasattr(model, "device") and (model.device.type == "cuda" or model.device.type == "mps"):
if hasattr(model, "device") and model.device.type in ("cuda", "mps"):
log_gpu_memory_usage(LOG, "after model load", model.device)
# make sure these are fp32 per Ramesh et al. (2021)