Add layers_to_transform for lora_config (#1118)
This commit is contained in:
@@ -677,7 +677,8 @@ lora_target_modules:
|
||||
# - gate_proj
|
||||
# - down_proj
|
||||
# - up_proj
|
||||
lora_target_linear: # If true, will target all linear layers
|
||||
lora_target_linear: # If true, will target all linear modules
|
||||
peft_layers_to_transform: # The layer indices to transform, otherwise, apply to all layers
|
||||
|
||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||
|
||||
@@ -257,6 +257,11 @@ def validate_config(cfg):
|
||||
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
||||
raise ValueError("Fused modules are not supported with LoRA")
|
||||
|
||||
if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
|
||||
raise ValueError(
|
||||
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
|
||||
)
|
||||
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter not in ("lora", "qlora"):
|
||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||
|
||||
@@ -733,6 +733,7 @@ def load_lora(model, cfg, inference=False):
|
||||
r=cfg.lora_r,
|
||||
lora_alpha=cfg.lora_alpha,
|
||||
target_modules=lora_target_modules,
|
||||
layers_to_transform=cfg.peft_layers_to_transform,
|
||||
lora_dropout=cfg.lora_dropout,
|
||||
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
||||
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
||||
|
||||
@@ -694,6 +694,21 @@ class ValidationTest(BaseValidation):
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_unfrozen_parameters_w_peft_layers_to_transform(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"adapter": "lora",
|
||||
"unfrozen_parameters": ["model.layers.2[0-9]+.block_sparse_moe.gate.*"],
|
||||
"peft_layers_to_transform": [0, 1],
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*can have unexpected behavior*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
|
||||
class ValidationCheckModelConfig(BaseValidation):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user