Add CPU offload
This commit is contained in:
committed by
Wing Lian
parent
b57238ecec
commit
c8f7213bc6
@@ -42,8 +42,9 @@ def reset_optimizer(optimizer: torch.optim.Optimizer):
|
|||||||
class ReLoRACallback(TrainerCallback):
|
class ReLoRACallback(TrainerCallback):
|
||||||
def __init__(self, cfg: DictDefault):
|
def __init__(self, cfg: DictDefault):
|
||||||
self.relora_steps = cfg.relora_steps
|
self.relora_steps = cfg.relora_steps
|
||||||
self.last_full_model = cfg.base_model
|
self.cpu_offload = cfg.relora_cpu_offload
|
||||||
self.quantised = cfg.load_in_4bit or cfg.load_in_8bit
|
self.quantised = cfg.load_in_4bit or cfg.load_in_8bit
|
||||||
|
self.last_full_model = cfg.base_model
|
||||||
|
|
||||||
assert os.path.exists(
|
assert os.path.exists(
|
||||||
self.last_full_model
|
self.last_full_model
|
||||||
@@ -185,6 +186,7 @@ def merge_and_save(
|
|||||||
model_dst: str,
|
model_dst: str,
|
||||||
reinit: bool = False,
|
reinit: bool = False,
|
||||||
quantized: bool = False,
|
quantized: bool = False,
|
||||||
|
cpu_offload: bool = False,
|
||||||
):
|
):
|
||||||
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
|
key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
|
||||||
|
|
||||||
@@ -235,9 +237,10 @@ def merge_and_save(
|
|||||||
if isinstance(target, peft.tuners.lora.LoraLayer):
|
if isinstance(target, peft.tuners.lora.LoraLayer):
|
||||||
orig_weight = in_tensors[key + ".weight"]
|
orig_weight = in_tensors[key + ".weight"]
|
||||||
old_dev = target.weight.device
|
old_dev = target.weight.device
|
||||||
|
math_dev = "cpu" if cpu_offload else old_dev
|
||||||
|
|
||||||
update = lora_delta_weight(target).detach()
|
update = lora_delta_weight(target).detach().to(math_dev)
|
||||||
new_weight = (orig_weight.to(old_dev) + update.to(old_dev)).cpu()
|
new_weight = orig_weight.to(math_dev) + update
|
||||||
out_tensors[key + ".weight"] = new_weight
|
out_tensors[key + ".weight"] = new_weight
|
||||||
|
|
||||||
if reinit:
|
if reinit:
|
||||||
|
|||||||
Reference in New Issue
Block a user