diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 71e27a2bc..ed917d963 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -355,7 +355,7 @@ def load_model( if hasattr(module, "weight"): module.to(torch.float32) - fix_dtype = False + fix_dtype = not cfg.adapter if not cfg.gptq and ( (cfg.adapter == "lora" and load_in_8bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit)