From 0cfdb2c90cbd915273f21cf3bff3b216f00303a0 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 5 Mar 2024 21:20:15 -0500
Subject: [PATCH] support for DoRA w/ PEFT (#1363)

---
 requirements.txt                                 |  4 ++--
 .../utils/config/models/input/v0_4_1/__init__.py | 12 ++++++++++++
 src/axolotl/utils/models.py                      |  2 ++
 3 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 02cde5add..cd5171ebd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
-peft @ git+https://github.com/huggingface/peft.git
-transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632
+peft==0.9.0
+transformers==4.38.2
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.26.1
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index b881b1605..79fffe9cd 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -178,6 +178,7 @@ class LoraConfig(BaseModel):
     lora_dropout: Optional[float] = None
     peft_layers_to_transform: Optional[List[int]] = None
     peft: Optional[PeftConfig] = None
+    peft_use_dora: Optional[bool] = None
     lora_on_cpu: Optional[bool] = None
     gptq: Optional[bool] = None
 
@@ -233,6 +234,17 @@ class LoraConfig(BaseModel):
             raise ValueError("Require cfg.load_in_4bit to be True for qlora")
         return self
 
+    @model_validator(mode="before")
+    @classmethod
+    def validate_quantized_dora(cls, data):
+        if data.get("peft_use_dora") and (
+            data.get("load_in_8bit") or data.get("load_in_4bit")
+        ):
+            raise ValueError(
+                "`peft_use_dora` is not currently compatible with quantized weights."
+            )
+        return data
+
 
 class ReLoRAConfig(BaseModel):
     """ReLoRA configuration subset"""
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index aa2e9539b..5407245ac 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -830,6 +830,8 @@ def load_lora(model, cfg, inference=False, config_only=False):
     if loftq_bits:
         lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
         lora_config_kwargs["init_lora_weights"] = "loftq"
+    if cfg.peft_use_dora:
+        lora_config_kwargs["use_dora"] = cfg.peft_use_dora
 
     lora_config = LoraConfig(
         r=cfg.lora_r,
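
Usage note (not part of the patch): with this change, setting `peft_use_dora: true` in an axolotl config forwards `use_dora=True` to PEFT's `LoraConfig`, and the new validator rejects the flag whenever `load_in_8bit` or `load_in_4bit` is set, since DoRA is not yet compatible with quantized base weights. Below is a minimal sketch of the equivalent direct PEFT call, assuming the peft==0.9.0 pin above (the release that introduced `use_dora`); the base model name and LoRA hyperparameters are illustrative only:

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    # Illustrative base model; must be loaded unquantized for DoRA.
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        use_dora=True,  # what cfg.peft_use_dora toggles in load_lora above
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()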