Enable or disable bf16 support based on availability (#1116)

This commit is contained in:
Simon Hällqvist
2024-01-14 18:06:56 +01:00
committed by GitHub
parent 2202a20f60
commit 086561326f
2 changed files with 29 additions and 0 deletions

View File

@@ -61,6 +61,14 @@ def normalize_config(cfg):
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
if cfg.bf16 == "auto":
if is_torch_bf16_gpu_available():
LOG.debug("bf16 support detected, enabling for this configuration.")
cfg.bf16 = True
else:
LOG.debug("bf16 support not detected, disabling for this configuration.")
cfg.bf16 = False
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False