Fix: Update model loading logic to conditionally upcast based on lm_head presence for btlm models

This commit is contained in:
mhenrhcsen
2025-07-16 21:16:47 +02:00
parent 84ad69afad
commit 2f670a5988

View File

@@ -798,7 +798,7 @@ class ModelLoader:
if before_kbit_train_or_finetune:
if name.endswith(".gate"):
module.to(dist_dtype)
-if self.model_config.model_type == "btlm":
+if self.model_config.model_type == "btlm" and "lm_head" in name:
# don't upcast lm_head for btlm
continue
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):