From 4cb7900a567e97b278cc713ec6bd8af616d2ebf7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 28 Jan 2024 18:50:08 -0500 Subject: [PATCH] Peft lotfq (#1222) * loftq support for lora * fix loftq check * update readme for loftq * readability cleanup * use peft main for loftq fixes, remove unnecessary special tokens * remove unused test from older deprecation --- README.md | 6 +++ examples/llama-2/fft_optimized.yml | 3 -- examples/llama-2/loftq.yml | 70 ++++++++++++++++++++++++++++++ examples/llama-2/lora.yml | 3 -- examples/llama-2/qlora.yml | 3 -- requirements.txt | 2 +- src/axolotl/utils/config.py | 6 +-- src/axolotl/utils/models.py | 24 +++++++--- tests/test_validation.py | 10 ----- 9 files changed, 97 insertions(+), 30 deletions(-) create mode 100644 examples/llama-2/loftq.yml diff --git a/README.md b/README.md index 95e5f530c..422185ed6 100644 --- a/README.md +++ b/README.md @@ -696,6 +696,12 @@ lora_modules_to_save: lora_fan_in_fan_out: false +peft: + # Configuration options for loftq initialization for LoRA + # https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization + loftq_config: + loftq_bits: # typically 4 bits + # ReLoRA configuration # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed relora_steps: # Number of steps per ReLoRA restart diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml index c2639050e..a7e2a6310 100644 --- a/examples/llama-2/fft_optimized.yml +++ b/examples/llama-2/fft_optimized.yml @@ -67,6 +67,3 @@ weight_decay: 0.1 fsdp: fsdp_config: special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml new file mode 100644 index 000000000..2abbb7847 --- /dev/null +++ b/examples/llama-2/loftq.yml @@ -0,0 +1,70 @@ +base_model: NousResearch/Llama-2-7b-hf +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer +is_llama_derived_model: true + +load_in_8bit: false +load_in_4bit: false +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: +val_set_size: 0.05 +output_dir: ./lora-out + +sequence_len: 4096 +sample_packing: true +pad_to_sequence_len: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: +peft: + loftq_config: + loftq_bits: 4 + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true +s2_attention: + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_table_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index 61d82a403..90a9cfd2c 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -65,6 +65,3 @@ weight_decay: 0.0 fsdp: fsdp_config: special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index 41810d56d..badb67ac3 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -65,6 +65,3 @@ weight_decay: 0.0 fsdp: fsdp_config: special_tokens: - bos_token: "" - eos_token: "" - unk_token: "" diff --git a/requirements.txt b/requirements.txt index b23c2509b..075f63622 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 -peft==0.7.1 +peft @ git+https://github.com/huggingface/peft.git transformers==4.37.0 tokenizers==0.15.0 bitsandbytes>=0.41.1 diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 3bc01fc7f..3ea48f6bb 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -232,9 +232,6 @@ def validate_config(cfg): "eval_batch_size != micro_batch_size. This can lead to VRAM instability." ) - if cfg.load_4bit: - raise ValueError("cfg.load_4bit parameter has been deprecated") - if cfg.adapter == "qlora": if cfg.merge_lora: # can't merge qlora if loaded in 8bit or 4bit @@ -260,7 +257,8 @@ def validate_config(cfg): if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp: raise ValueError("Fused modules are not supported with QLoRA") - if not cfg.load_in_8bit and cfg.adapter == "lora": + loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits + if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq: LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 72427f645..e2401b7fe 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -9,7 +9,7 @@ import bitsandbytes as bnb import torch import transformers from optimum.bettertransformer import BetterTransformer -from peft import PeftConfig, prepare_model_for_kbit_training +from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training from peft.tuners.lora import QuantLinear from transformers import ( # noqa: F401 AddedToken, @@ -667,13 +667,17 @@ def load_model( # Qwen doesn't play nicely with LoRA if this is enabled skip_prepare_model_for_kbit_training = True - if (cfg.adapter == "lora" and load_in_8bit) or ( - cfg.adapter == "qlora" and cfg.load_in_4bit - ): - LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") + loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits + if cfg.adapter == "lora" and loftq_bits: + skip_prepare_model_for_kbit_training = True + + if cfg.adapter in ["lora", "qlora"]: if cfg.gradient_checkpointing: model.gradient_checkpointing_enable() - if not skip_prepare_model_for_kbit_training: + if ( + cfg.load_in_8bit or cfg.load_in_4bit + ) and not skip_prepare_model_for_kbit_training: + LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") model = prepare_model_for_kbit_training( model, use_gradient_checkpointing=cfg.gradient_checkpointing ) @@ -700,6 +704,7 @@ def load_model( model, lora_config = load_adapter(model, cfg, cfg.adapter) if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit): + # TODO revaldate this conditional model.to(f"cuda:{cfg.local_rank}") if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: @@ -797,6 +802,12 @@ def load_lora(model, cfg, inference=False, config_only=False): LOG.info(f"found linear modules: {repr(linear_names)}") lora_target_modules = list(set(lora_target_modules + linear_names)) + lora_config_kwargs = {} + loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits + if loftq_bits: + lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits) + lora_config_kwargs["init_lora_weights"] = "loftq" + lora_config = LoraConfig( r=cfg.lora_r, lora_alpha=cfg.lora_alpha, @@ -807,6 +818,7 @@ def load_lora(model, cfg, inference=False, config_only=False): modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, bias="none", task_type="CAUSAL_LM", + **lora_config_kwargs, ) if config_only: diff --git a/tests/test_validation.py b/tests/test_validation.py index d73ae34eb..e5a74394c 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -32,16 +32,6 @@ class ValidationTest(BaseValidation): Test the validation module """ - def test_load_4bit_deprecate(self): - cfg = DictDefault( - { - "load_4bit": True, - } - ) - - with pytest.raises(ValueError): - validate_config(cfg) - def test_batch_size_unused_warning(self): cfg = DictDefault( {