Fix: Update model loading logic to conditionally upcast based on lm_head presence for btlm models
This commit is contained in:
@@ -798,7 +798,7 @@ class ModelLoader:
|
|||||||
if before_kbit_train_or_finetune:
|
if before_kbit_train_or_finetune:
|
||||||
if name.endswith(".gate"):
|
if name.endswith(".gate"):
|
||||||
module.to(dist_dtype)
|
module.to(dist_dtype)
|
||||||
if self.model_config.model_type == "btlm":
|
if self.model_config.model_type == "btlm" and "lm_head" in name:
|
||||||
# don't upcast lm_head for btlm
|
# don't upcast lm_head for btlm
|
||||||
continue
|
continue
|
||||||
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
||||||
|
|||||||
Reference in New Issue
Block a user