diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index bbc532fb9..9897399e3 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -776,6 +776,9 @@ class ModelLoader: dist_dtype: torch.dtype, before_kbit_train_or_finetune: bool, ): + dest = {"dtype": dist_dtype} + if self.cfg.lora_on_cpu: + dest["device"] = "cpu" for name, module in self.model.named_modules(): if "norm" in name: module.to(dist_dtype) @@ -786,4 +789,4 @@ class ModelLoader: # don't upcast lm_head for btlm continue if any(m in name for m in embedding_modules) and hasattr(module, "weight"): - module.to(dist_dtype) + module.to(**dest)