diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 224e3a258..52a81ea2c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -676,7 +676,7 @@ def load_model( if not cfg.fsdp: # FSDP doesn't like mixed Float and BFloat16 for name, module in model.named_modules(): - if any(m in name for m in ["norm", "gate"]): + if "norm" in name or name.endswith(".gate"): module.to(torch.float32) if model_config.model_type == "btlm": # don't upcast lm_head for btlm