diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 0de87fa5c..a96cc1286 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -70,6 +70,9 @@ def resolve_dtype(cfg): if cfg.fp16 is None and not cfg.float16: cfg.fp16 = True + if cfg.fp16 and cfg.bf16 == "auto": + cfg.bf16 = False + if cfg.device == "mps": cfg.load_in_8bit = False cfg.tf32 = False