From 67c70d19543228f5cfc0718e7e4166441b040656 Mon Sep 17 00:00:00 2001 From: Maxime <672982+maximegmd@users.noreply.github.com> Date: Sun, 4 Feb 2024 18:27:38 +0100 Subject: [PATCH] linter stuff --- setup.py | 5 +++-- src/axolotl/utils/bench.py | 1 + src/axolotl/utils/models.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 39d705711..f6a530dfa 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,9 @@ """setup.py for axolotl""" -from importlib.metadata import PackageNotFoundError, version -from packaging.version import Version, parse import platform +from importlib.metadata import PackageNotFoundError, version + +from packaging.version import Version, parse from setuptools import find_packages, setup diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py index 8e3bc7264..c039e790a 100644 --- a/src/axolotl/utils/bench.py +++ b/src/axolotl/utils/bench.py @@ -46,6 +46,7 @@ def gpu_memory_usage_all(device=0): smi = gpu_memory_usage_smi(device) return usage, reserved - usage, max(0, smi - reserved) + def mps_memory_usage_all(): usage = torch.mps.current_allocated_memory() / 1024.0**3 reserved = torch.mps.driver_allocated_memory() / 1024.0**3 diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6efad3843..1df6228ab 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -672,7 +672,7 @@ def load_model( ): model.config.eos_token_id = tokenizer.eos_token_id - if hasattr(model, "device") and (model.device.type == "cuda" or model.device.type == "mps"): + if hasattr(model, "device") and model.device.type in ("cuda", "mps"): log_gpu_memory_usage(LOG, "after model load", model.device) # make sure these are fp32 per Ramesh et al. (2021)