From ec15a7a6916f5c856a824e7b5f23ef760e8242c9 Mon Sep 17 00:00:00 2001 From: kallewoof Date: Sat, 28 Jun 2025 00:19:24 +0900 Subject: [PATCH] Support --lora-on-cpu flag for DPO model merging (#2766) [skip ci] * Support --lora-on-cpu flag for DPO model merging * fix: use device=cpu in _convert_embedding_modules_dtype when lora_on_cpu is set --- src/axolotl/loaders/model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index bbc532fb9..9897399e3 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -776,6 +776,9 @@ class ModelLoader: dist_dtype: torch.dtype, before_kbit_train_or_finetune: bool, ): + dest = {"dtype": dist_dtype} + if self.cfg.lora_on_cpu: + dest["device"] = "cpu" for name, module in self.model.named_modules(): if "norm" in name: module.to(dist_dtype) @@ -786,4 +789,4 @@ class ModelLoader: # don't upcast lm_head for btlm continue if any(m in name for m in embedding_modules) and hasattr(module, "weight"): - module.to(dist_dtype) + module.to(**dest)