Support --lora-on-cpu flag for DPO model merging (#2766) [skip ci]
* Support --lora-on-cpu flag for DPO model merging * fix: use device=cpu in _convert_embedding_modules_dtype when lora_on_cpu is set
This commit is contained in:
@@ -776,6 +776,9 @@ class ModelLoader:
|
|||||||
dist_dtype: torch.dtype,
|
dist_dtype: torch.dtype,
|
||||||
before_kbit_train_or_finetune: bool,
|
before_kbit_train_or_finetune: bool,
|
||||||
):
|
):
|
||||||
|
dest = {"dtype": dist_dtype}
|
||||||
|
if self.cfg.lora_on_cpu:
|
||||||
|
dest["device"] = "cpu"
|
||||||
for name, module in self.model.named_modules():
|
for name, module in self.model.named_modules():
|
||||||
if "norm" in name:
|
if "norm" in name:
|
||||||
module.to(dist_dtype)
|
module.to(dist_dtype)
|
||||||
@@ -786,4 +789,4 @@ class ModelLoader:
|
|||||||
# don't upcast lm_head for btlm
|
# don't upcast lm_head for btlm
|
||||||
continue
|
continue
|
||||||
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
||||||
module.to(dist_dtype)
|
module.to(**dest)
|
||||||
|
|||||||
Reference in New Issue
Block a user