From 0bd89b38c633fd117a0d48c25c5012992bd38df7 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 28 Sep 2023 11:58:23 -0400
Subject: [PATCH] migrate lora_ to peft_

---
 README.md                       | 26 +++++++-------
 examples/llama-2/ia3.yml        |  6 ++--
 src/axolotl/utils/config.py     | 12 +++++++
 src/axolotl/utils/models.py     | 62 ++++++++++++++++-----------------
 tests/test_cfg_normalization.py | 48 +++++++++++++++++++++++++
 5 files changed, 106 insertions(+), 48 deletions(-)
 create mode 100644 tests/test_cfg_normalization.py

diff --git a/README.md b/README.md
index bd0426b4c..3d5c2488f 100644
--- a/README.md
+++ b/README.md
@@ -384,10 +384,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
 - lora
   ```yaml
   adapter: lora # qlora or leave blank for full finetune
-  lora_r: 8
-  lora_alpha: 16
-  lora_dropout: 0.05
-  lora_target_modules:
+  peft_r: 8
+  peft_alpha: 16
+  peft_dropout: 0.05
+  peft_target_modules:
     - q_proj
     - v_proj
   ```
@@ -536,10 +536,10 @@ peft_model_dir:
 # LoRA hyperparameters
 # For more details about the following options, see:
 # https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
+peft_r: 8
+peft_alpha: 16
+peft_dropout: 0.05
+peft_target_modules:
   - q_proj
   - v_proj
 #  - k_proj
@@ -547,13 +547,13 @@ lora_target_modules:
 #  - gate_proj
 #  - down_proj
 #  - up_proj
-lora_target_linear: # If true, will target all linear layers
+peft_target_linear: # if true, will target all linear layers
 
 # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
 # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
 # `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
 # https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
-lora_modules_to_save:
+peft_modules_to_save:
 #  - embed_tokens
 #  - lm_head
 
@@ -561,10 +561,8 @@
 # If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
 # Make sure `lora_model_dir` points to this directory if you want to use the trained model.
 lora_out_dir:
-lora_fan_in_fan_out: false
-ia3_target_modules: # target modules for IA3, for llama, k, v, and down projections
-ia3_feedforward_modules: # ffn modules for IA3, for llama down projection
-ia3_fan_in_fan_out:
+peft_fan_in_fan_out: false
+peft_feedforward_modules: # ffn modules for IA3, for llama down projection
 
 # ReLoRA configuration
 # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
diff --git a/examples/llama-2/ia3.yml b/examples/llama-2/ia3.yml
index 48d81d5f0..409761669 100644
--- a/examples/llama-2/ia3.yml
+++ b/examples/llama-2/ia3.yml
@@ -21,13 +21,13 @@ pad_to_sequence_len: true
 
 adapter: ia3
 peft_model_dir:
-ia3_target_modules:
+peft_target_modules:
   - k_proj
   - v_proj
   - down_proj
-ia3_feedforward_modules:
+peft_feedforward_modules:
   - down_proj
-ia3_fan_in_fan_out: false
+peft_fan_in_fan_out: false
 
 wandb_project:
 wandb_entity:
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 1c0a15d67..10e987d2b 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -121,6 +121,18 @@ def normalize_config(cfg):
 
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
+    if cfg.adapter is not None:
+        for key in list(cfg.keys()):
+            if key.startswith("lora_"):
+                new_key = key.replace("lora_", "peft_")
+                LOG.warning(
+                    PendingDeprecationWarning(
+                        f"{key} soon to be deprecated. please use {new_key}"
+                    )
+                )
+                cfg[new_key] = cfg[key]
+                del cfg[key]
+
 
 def validate_config(cfg):
     if is_torch_bf16_gpu_available():
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 67d0aaac2..56c7008ad 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -490,11 +490,11 @@
         task_type="CAUSAL_LM",
     )
 
-    if cfg.peft_model_dir or cfg.lora_model_dir:
+    if cfg.peft_model_dir:
         LOG.debug("Loading pretained PEFT - llama_adapter")
         model = PeftModel.from_pretrained(
             model,
-            cfg.peft_model_dir or cfg.lora_model_dir,
+            cfg.peft_model_dir,
             torch_dtype=torch.float16,
         )
     else:
@@ -507,7 +507,7 @@ def load_llama_adapter(model, cfg):
 
 def find_all_linear_names(model):
     cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
-    lora_module_names = set()
+    peft_module_names = set()
     for name, module in model.named_modules():
         if (
             isinstance(module, cls)
@@ -515,12 +515,12 @@ def find_all_linear_names(model):
             and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
         ):
             names = name.split(".")
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+            peft_module_names.add(names[0] if len(names) == 1 else names[-1])
 
-    if "lm_head" in lora_module_names:  # needed for 16-bit
-        lora_module_names.remove("lm_head")
+    if "lm_head" in peft_module_names:  # needed for 16-bit
+        peft_module_names.remove("lm_head")
 
-    return list(lora_module_names)
+    return list(peft_module_names)
 
 
 def load_lora(model, cfg, inference=False):
@@ -528,20 +528,20 @@
 
     from peft import LoraConfig, PeftModel, get_peft_model
 
-    lora_target_modules = list(cfg.lora_target_modules or [])
+    peft_target_modules = list(cfg.peft_target_modules or [])
 
-    if cfg.lora_target_linear:
+    if cfg.peft_target_linear:
         linear_names = find_all_linear_names(model)
         LOG.info(f"found linear modules: {repr(linear_names)}")
-        lora_target_modules = list(set(lora_target_modules + linear_names))
+        peft_target_modules = list(set(peft_target_modules + linear_names))
 
-    lora_config = LoraConfig(
-        r=cfg.lora_r,
-        lora_alpha=cfg.lora_alpha,
-        target_modules=lora_target_modules,
-        lora_dropout=cfg.lora_dropout,
-        fan_in_fan_out=cfg.lora_fan_in_fan_out,
-        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
+    peft_config = LoraConfig(
+        r=cfg.peft_r,
+        lora_alpha=cfg.peft_alpha,
+        target_modules=peft_target_modules,
+        lora_dropout=cfg.peft_dropout,
+        fan_in_fan_out=cfg.peft_fan_in_fan_out,
+        modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None,
         bias="none",
         task_type="CAUSAL_LM",
     )
@@ -554,11 +554,11 @@
             is_trainable=(not inference),
         )
     else:
-        model = get_peft_model(model, lora_config)
+        model = get_peft_model(model, peft_config)
 
     model.print_trainable_parameters()
 
-    return model, lora_config
+    return model, peft_config
 
 
 def load_ia3(model, cfg, inference=False):
@@ -566,18 +566,18 @@
 
     from peft import IA3Config, PeftModel, get_peft_model
 
-    ia3_config_kwargs = {}
-    if cfg.ia3_init_ia3_weights is not None:
-        ia3_config_kwargs["init_ia3_weights"] = cfg.ia3_init_ia3_weights
-    if cfg.ia3_fan_in_fan_out is not None:
-        ia3_config_kwargs["fan_in_fan_out"] = cfg.ia3_fan_in_fan_out
+    peft_config_kwargs = {}
+    if cfg.peft_init_ia3_weights is not None:
+        peft_config_kwargs["init_ia3_weights"] = cfg.peft_init_ia3_weights
+    if cfg.peft_fan_in_fan_out is not None:
+        peft_config_kwargs["fan_in_fan_out"] = cfg.peft_fan_in_fan_out
 
-    ia3_config = IA3Config(
-        target_modules=cfg.ia3_target_modules,
-        feedforward_modules=cfg.ia3_feedforward_modules,
-        modules_to_save=cfg.ia3_modules_to_save,
+    peft_config = IA3Config(
+        target_modules=cfg.peft_target_modules,
+        feedforward_modules=cfg.peft_feedforward_modules,
+        modules_to_save=cfg.peft_modules_to_save,
         task_type="CAUSAL_LM",
-        **ia3_config_kwargs,
+        **peft_config_kwargs,
     )
 
     if cfg.peft_model_dir:
@@ -588,8 +588,8 @@
             is_trainable=(not inference),
         )
     else:
-        model = get_peft_model(model, ia3_config)
+        model = get_peft_model(model, peft_config)
 
     model.print_trainable_parameters()
 
-    return model, ia3_config
+    return model, peft_config
diff --git a/tests/test_cfg_normalization.py b/tests/test_cfg_normalization.py
new file mode 100644
index 000000000..90faae05d
--- /dev/null
+++ b/tests/test_cfg_normalization.py
@@ -0,0 +1,48 @@
+"""Module for testing the validation module"""
+
+import logging
+import unittest
+from typing import Optional
+
+import pytest
+
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+
+class NormalizationTest(unittest.TestCase):
+    """
+    Test the cfg normalization module
+    """
+
+    _caplog: Optional[pytest.LogCaptureFixture] = None
+
+    @pytest.fixture(autouse=True)
+    def inject_fixtures(self, caplog):
+        self._caplog = caplog
+
+    def test_lora_to_peft(self):
+        base_cfg = DictDefault(
+            {
+                "gradient_accumulation_steps": 1,
+                "micro_batch_size": 1,
+                "base_model": "NousResearch/Llama-2-7b-hf",
+                "base_model_config": "NousResearch/Llama-2-7b-hf",
+            }
+        )
+        cfg = base_cfg | DictDefault(
+            {
+                "adapter": "lora",
+                "lora_r": 128,
+                "lora_alpha": 64,
+            }
+        )
+        with self._caplog.at_level(logging.WARNING):
+            normalize_config(cfg)
+            assert any(
+                "soon to be deprecated. please use peft_" in record.message
+                for record in self._caplog.records
+            )
+
+        assert cfg.peft_r == 128
+        assert cfg.peft_alpha == 64
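
Note (not part of the patch): the following is a minimal, self-contained sketch of the key migration that the `normalize_config` hunk above performs, for readers who want to see the renaming behavior in isolation. The `migrate_lora_keys` helper and the plain `dict` are illustrative stand-ins; the real code operates on axolotl's `DictDefault` config object inside `normalize_config` and wraps the message in a `PendingDeprecationWarning`.

```python
import logging

logging.basicConfig(level=logging.WARNING)
LOG = logging.getLogger(__name__)


def migrate_lora_keys(cfg: dict) -> dict:
    """Copy legacy lora_* keys to their peft_* equivalents and drop the old keys."""
    # Iterate over a snapshot of the keys because the dict is mutated in the loop,
    # mirroring the `for key in list(cfg.keys())` loop added in config.py above.
    for key in list(cfg.keys()):
        if key.startswith("lora_"):
            new_key = key.replace("lora_", "peft_", 1)
            LOG.warning("%s soon to be deprecated. please use %s", key, new_key)
            cfg[new_key] = cfg[key]
            del cfg[key]
    return cfg


if __name__ == "__main__":
    legacy = {"adapter": "lora", "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.05}
    print(migrate_lora_keys(legacy))
    # {'adapter': 'lora', 'peft_r': 8, 'peft_alpha': 16, 'peft_dropout': 0.05}
```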