From c8f7213bc6c14021ac7d30fae198a81991beee84 Mon Sep 17 00:00:00 2001 From: Charles Goddard Date: Mon, 24 Jul 2023 21:07:36 -0700 Subject: [PATCH] Add CPU offload --- src/axolotl/monkeypatch/relora.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index cee90c48f..4eb86b6a2 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -42,8 +42,9 @@ def reset_optimizer(optimizer: torch.optim.Optimizer): class ReLoRACallback(TrainerCallback): def __init__(self, cfg: DictDefault): self.relora_steps = cfg.relora_steps - self.last_full_model = cfg.base_model + self.cpu_offload = cfg.relora_cpu_offload self.quantised = cfg.load_in_4bit or cfg.load_in_8bit + self.last_full_model = cfg.base_model assert os.path.exists( self.last_full_model @@ -185,6 +186,7 @@ def merge_and_save( model_dst: str, reinit: bool = False, quantized: bool = False, + cpu_offload: bool = False, ): key_list = [key for key, _ in model.model.named_modules() if "lora" not in key] @@ -235,9 +237,10 @@ def merge_and_save( if isinstance(target, peft.tuners.lora.LoraLayer): orig_weight = in_tensors[key + ".weight"] old_dev = target.weight.device + math_dev = "cpu" if cpu_offload else old_dev - update = lora_delta_weight(target).detach() - new_weight = (orig_weight.to(old_dev) + update.to(old_dev)).cpu() + update = lora_delta_weight(target).detach().to(math_dev) + new_weight = orig_weight.to(math_dev) + update out_tensors[key + ".weight"] = new_weight if reinit: