allow the optimizer prune ratio for ReLoRA to be configurable (#1287)
* allow the optimizer prune ration for relora to be configurable * update docs for relora * prevent circular imports
This commit is contained in:
@@ -734,6 +734,8 @@ peft:
|
|||||||
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||||
relora_steps: # Number of steps per ReLoRA restart
|
relora_steps: # Number of steps per ReLoRA restart
|
||||||
relora_warmup_steps: # Number of per-restart warmup steps
|
relora_warmup_steps: # Number of per-restart warmup steps
|
||||||
|
relora_anneal_steps: # Number of anneal steps for each relora cycle
|
||||||
|
relora_prune_ratio: # threshold for optimizer magnitude when pruning
|
||||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||||
|
|
||||||
# wandb configuration if you're using it
|
# wandb configuration if you're using it
|
||||||
|
|||||||
@@ -131,6 +131,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||||
)
|
)
|
||||||
|
relora_prune_ratio: Optional[float] = field(
|
||||||
|
default=0.9,
|
||||||
|
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||||
|
)
|
||||||
bench_split: Optional[str] = field(
|
bench_split: Optional[str] = field(
|
||||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
)
|
)
|
||||||
@@ -900,9 +904,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"sample_packing_seq_len_multiplier"
|
"sample_packing_seq_len_multiplier"
|
||||||
] = self.cfg.micro_batch_size
|
] = self.cfg.micro_batch_size
|
||||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
if self.cfg.relora_steps:
|
||||||
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
|
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||||
training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
|
training_arguments_kwargs[
|
||||||
|
"relora_warmup_steps"
|
||||||
|
] = self.cfg.relora_warmup_steps
|
||||||
|
if self.cfg.relora_anneal_steps:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"relora_anneal_steps"
|
||||||
|
] = self.cfg.relora_anneal_steps
|
||||||
|
if self.cfg.relora_prune_ratio:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"relora_prune_ratio"
|
||||||
|
] = self.cfg.relora_prune_ratio
|
||||||
|
|
||||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||||
training_arguments_kwargs
|
training_arguments_kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -46,8 +46,9 @@ def reset_optimizer(
|
|||||||
*,
|
*,
|
||||||
reset_params: list[str], # where str is the key to a torch.nn.Parameter
|
reset_params: list[str], # where str is the key to a torch.nn.Parameter
|
||||||
optimizer_state_keys: list[str],
|
optimizer_state_keys: list[str],
|
||||||
|
prune_ratio: float = 0.9,
|
||||||
):
|
):
|
||||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
|
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
||||||
n_zeros = 0
|
n_zeros = 0
|
||||||
n_total = 0
|
n_total = 0
|
||||||
|
|
||||||
@@ -159,6 +160,7 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
optimizer,
|
optimizer,
|
||||||
reset_params=lora_params,
|
reset_params=lora_params,
|
||||||
optimizer_state_keys=optimizer_state_keys,
|
optimizer_state_keys=optimizer_state_keys,
|
||||||
|
prune_ratio=args.relora_prune_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.quantized:
|
if self.quantized:
|
||||||
|
|||||||
Reference in New Issue
Block a user