diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 7fce928f0..3d11601ba 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -798,7 +798,7 @@ class ModelLoader: if before_kbit_train_or_finetune: if name.endswith(".gate"): module.to(dist_dtype) - if self.model_config.model_type == "btlm": + if self.model_config.model_type == "btlm" and "lm_head" in name: # don't upcast lm_head for btlm continue if any(m in name for m in embedding_modules) and hasattr(module, "weight"):