From 4b997c3e1af2d71f6c032db84aedabaf372f63fb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 12 Feb 2024 11:39:51 -0800 Subject: [PATCH] allow the optimizer prune ratio for ReLoRA to be configurable (#1287) * allow the optimizer prune ration for relora to be configurable * update docs for relora * prevent circular imports --- README.md | 2 ++ src/axolotl/core/trainer_builder.py | 21 ++++++++++++++++++--- src/axolotl/monkeypatch/relora.py | 4 +++- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 47888862b..cf3865064 100644 --- a/README.md +++ b/README.md @@ -734,6 +734,8 @@ peft: # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed relora_steps: # Number of steps per ReLoRA restart relora_warmup_steps: # Number of per-restart warmup steps +relora_anneal_steps: # Number of anneal steps for each relora cycle +relora_prune_ratio: # threshold for optimizer magnitude when pruning relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings # wandb configuration if you're using it diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index b1c37d8ba..fd7aeef53 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -131,6 +131,10 @@ class AxolotlTrainingArguments(TrainingArguments): default=None, metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, ) + relora_prune_ratio: Optional[float] = field( + default=0.9, + metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, + ) bench_split: Optional[str] = field( default="eval", metadata={"help": "The benchmark split to run on"} ) @@ -900,9 +904,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs[ "sample_packing_seq_len_multiplier" ] = self.cfg.micro_batch_size - training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps - training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps - training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps + if self.cfg.relora_steps: + training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps + training_arguments_kwargs[ + "relora_warmup_steps" + ] = self.cfg.relora_warmup_steps + if self.cfg.relora_anneal_steps: + training_arguments_kwargs[ + "relora_anneal_steps" + ] = self.cfg.relora_anneal_steps + if self.cfg.relora_prune_ratio: + training_arguments_kwargs[ + "relora_prune_ratio" + ] = self.cfg.relora_prune_ratio + training_arguments_kwargs = self.hook_pre_create_training_args( training_arguments_kwargs ) diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index 2d396e080..f9f861ba5 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -46,8 +46,9 @@ def reset_optimizer( *, reset_params: list[str], # where str is the key to a torch.nn.Parameter optimizer_state_keys: list[str], + prune_ratio: float = 0.9, ): - pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9) + pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio) n_zeros = 0 n_total = 0 @@ -159,6 +160,7 @@ class ReLoRACallback(TrainerCallback): optimizer, reset_params=lora_params, optimizer_state_keys=optimizer_state_keys, + prune_ratio=args.relora_prune_ratio, ) if self.quantized: