Enable or disable bf16 support based on availability (#1116)
This commit is contained in:
@@ -61,6 +61,14 @@ def normalize_config(cfg):
|
||||
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
|
||||
cfg.batch_size = cfg.batch_size * cfg.world_size
|
||||
|
||||
if cfg.bf16 == "auto":
|
||||
if is_torch_bf16_gpu_available():
|
||||
LOG.debug("bf16 support detected, enabling for this configuration.")
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
LOG.debug("bf16 support not detected, disabling for this configuration.")
|
||||
cfg.bf16 = False
|
||||
|
||||
if cfg.device == "mps":
|
||||
cfg.load_in_8bit = False
|
||||
cfg.tf32 = False
|
||||
|
||||
Reference in New Issue
Block a user