migrate lora_ to peft_
This commit is contained in:
26
README.md
26
README.md
@@ -384,10 +384,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
- lora
|
- lora
|
||||||
```yaml
|
```yaml
|
||||||
adapter: lora # qlora or leave blank for full finetune
|
adapter: lora # qlora or leave blank for full finetune
|
||||||
lora_r: 8
|
peft_r: 8
|
||||||
lora_alpha: 16
|
peft_alpha: 16
|
||||||
lora_dropout: 0.05
|
peft_dropout: 0.05
|
||||||
lora_target_modules:
|
peft_target_modules:
|
||||||
- q_proj
|
- q_proj
|
||||||
- v_proj
|
- v_proj
|
||||||
```
|
```
|
||||||
@@ -536,10 +536,10 @@ peft_model_dir:
|
|||||||
# LoRA hyperparameters
|
# LoRA hyperparameters
|
||||||
# For more details about the following options, see:
|
# For more details about the following options, see:
|
||||||
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
||||||
lora_r: 8
|
peft_r: 8
|
||||||
lora_alpha: 16
|
peft_alpha: 16
|
||||||
lora_dropout: 0.05
|
peft_dropout: 0.05
|
||||||
lora_target_modules:
|
peft_target_modules:
|
||||||
- q_proj
|
- q_proj
|
||||||
- v_proj
|
- v_proj
|
||||||
# - k_proj
|
# - k_proj
|
||||||
@@ -547,13 +547,13 @@ lora_target_modules:
|
|||||||
# - gate_proj
|
# - gate_proj
|
||||||
# - down_proj
|
# - down_proj
|
||||||
# - up_proj
|
# - up_proj
|
||||||
lora_target_linear: # If true, will target all linear layers
|
peft_target_linear: # if true, will target all linear layers
|
||||||
|
|
||||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||||
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
||||||
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
||||||
lora_modules_to_save:
|
peft_modules_to_save:
|
||||||
# - embed_tokens
|
# - embed_tokens
|
||||||
# - lm_head
|
# - lm_head
|
||||||
|
|
||||||
@@ -561,10 +561,8 @@ lora_modules_to_save:
|
|||||||
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
|
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
|
||||||
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
|
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
|
||||||
lora_out_dir:
|
lora_out_dir:
|
||||||
lora_fan_in_fan_out: false
|
peft_fan_in_fan_out: false
|
||||||
ia3_target_modules: # target modules for IA3, for llama, k, v, and down projections
|
peft_feedforward_modules: # ffn modules for IA3, for llama down projection
|
||||||
ia3_feedforward_modules: # ffn modules for IA3, for llama down projection
|
|
||||||
ia3_fan_in_fan_out:
|
|
||||||
|
|
||||||
# ReLoRA configuration
|
# ReLoRA configuration
|
||||||
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||||
|
|||||||
@@ -21,13 +21,13 @@ pad_to_sequence_len: true
|
|||||||
|
|
||||||
adapter: ia3
|
adapter: ia3
|
||||||
peft_model_dir:
|
peft_model_dir:
|
||||||
ia3_target_modules:
|
peft_target_modules:
|
||||||
- k_proj
|
- k_proj
|
||||||
- v_proj
|
- v_proj
|
||||||
- down_proj
|
- down_proj
|
||||||
ia3_feedforward_modules:
|
peft_feedforward_modules:
|
||||||
- down_proj
|
- down_proj
|
||||||
ia3_fan_in_fan_out: false
|
peft_fan_in_fan_out: false
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
|
|||||||
@@ -121,6 +121,18 @@ def normalize_config(cfg):
|
|||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
|
if cfg.adapter is not None:
|
||||||
|
for key in list(cfg.keys()):
|
||||||
|
if key.startswith("lora_"):
|
||||||
|
new_key = key.replace("lora_", "peft_")
|
||||||
|
LOG.warning(
|
||||||
|
PendingDeprecationWarning(
|
||||||
|
f"{key} soon to be deprecated. please use {new_key}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cfg[new_key] = cfg[key]
|
||||||
|
del cfg[key]
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
if is_torch_bf16_gpu_available():
|
if is_torch_bf16_gpu_available():
|
||||||
|
|||||||
@@ -490,11 +490,11 @@ def load_llama_adapter(model, cfg):
|
|||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.peft_model_dir or cfg.lora_model_dir:
|
if cfg.peft_model_dir:
|
||||||
LOG.debug("Loading pretained PEFT - llama_adapter")
|
LOG.debug("Loading pretained PEFT - llama_adapter")
|
||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
cfg.peft_model_dir or cfg.lora_model_dir,
|
cfg.peft_model_dir,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -507,7 +507,7 @@ def load_llama_adapter(model, cfg):
|
|||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
||||||
lora_module_names = set()
|
peft_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if (
|
if (
|
||||||
isinstance(module, cls)
|
isinstance(module, cls)
|
||||||
@@ -515,12 +515,12 @@ def find_all_linear_names(model):
|
|||||||
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
|
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
|
||||||
):
|
):
|
||||||
names = name.split(".")
|
names = name.split(".")
|
||||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
peft_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||||
|
|
||||||
if "lm_head" in lora_module_names: # needed for 16-bit
|
if "lm_head" in peft_module_names: # needed for 16-bit
|
||||||
lora_module_names.remove("lm_head")
|
peft_module_names.remove("lm_head")
|
||||||
|
|
||||||
return list(lora_module_names)
|
return list(peft_module_names)
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model, cfg, inference=False):
|
def load_lora(model, cfg, inference=False):
|
||||||
@@ -528,20 +528,20 @@ def load_lora(model, cfg, inference=False):
|
|||||||
|
|
||||||
from peft import LoraConfig, PeftModel, get_peft_model
|
from peft import LoraConfig, PeftModel, get_peft_model
|
||||||
|
|
||||||
lora_target_modules = list(cfg.lora_target_modules or [])
|
peft_target_modules = list(cfg.peft_target_modules or [])
|
||||||
|
|
||||||
if cfg.lora_target_linear:
|
if cfg.peft_target_linear:
|
||||||
linear_names = find_all_linear_names(model)
|
linear_names = find_all_linear_names(model)
|
||||||
LOG.info(f"found linear modules: {repr(linear_names)}")
|
LOG.info(f"found linear modules: {repr(linear_names)}")
|
||||||
lora_target_modules = list(set(lora_target_modules + linear_names))
|
peft_target_modules = list(set(peft_target_modules + linear_names))
|
||||||
|
|
||||||
lora_config = LoraConfig(
|
peft_config = LoraConfig(
|
||||||
r=cfg.lora_r,
|
r=cfg.peft_r,
|
||||||
lora_alpha=cfg.lora_alpha,
|
lora_alpha=cfg.peft_alpha,
|
||||||
target_modules=lora_target_modules,
|
target_modules=peft_target_modules,
|
||||||
lora_dropout=cfg.lora_dropout,
|
lora_dropout=cfg.peft_dropout,
|
||||||
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
fan_in_fan_out=cfg.peft_fan_in_fan_out,
|
||||||
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None,
|
||||||
bias="none",
|
bias="none",
|
||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
)
|
)
|
||||||
@@ -554,11 +554,11 @@ def load_lora(model, cfg, inference=False):
|
|||||||
is_trainable=(not inference),
|
is_trainable=(not inference),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = get_peft_model(model, lora_config)
|
model = get_peft_model(model, peft_config)
|
||||||
|
|
||||||
model.print_trainable_parameters()
|
model.print_trainable_parameters()
|
||||||
|
|
||||||
return model, lora_config
|
return model, peft_config
|
||||||
|
|
||||||
|
|
||||||
def load_ia3(model, cfg, inference=False):
|
def load_ia3(model, cfg, inference=False):
|
||||||
@@ -566,18 +566,18 @@ def load_ia3(model, cfg, inference=False):
|
|||||||
|
|
||||||
from peft import IA3Config, PeftModel, get_peft_model
|
from peft import IA3Config, PeftModel, get_peft_model
|
||||||
|
|
||||||
ia3_config_kwargs = {}
|
peft_config_kwargs = {}
|
||||||
if cfg.ia3_init_ia3_weights is not None:
|
if cfg.peft_init_ia3_weights is not None:
|
||||||
ia3_config_kwargs["init_ia3_weights"] = cfg.ia3_init_ia3_weights
|
peft_config_kwargs["init_ia3_weights"] = cfg.peft_init_ia3_weights
|
||||||
if cfg.ia3_fan_in_fan_out is not None:
|
if cfg.peft_fan_in_fan_out is not None:
|
||||||
ia3_config_kwargs["fan_in_fan_out"] = cfg.ia3_fan_in_fan_out
|
peft_config_kwargs["fan_in_fan_out"] = cfg.peft_fan_in_fan_out
|
||||||
|
|
||||||
ia3_config = IA3Config(
|
peft_config = IA3Config(
|
||||||
target_modules=cfg.ia3_target_modules,
|
target_modules=cfg.peft_target_modules,
|
||||||
feedforward_modules=cfg.ia3_feedforward_modules,
|
feedforward_modules=cfg.peft_feedforward_modules,
|
||||||
modules_to_save=cfg.ia3_modules_to_save,
|
modules_to_save=cfg.peft_modules_to_save,
|
||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
**ia3_config_kwargs,
|
**peft_config_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.peft_model_dir:
|
if cfg.peft_model_dir:
|
||||||
@@ -588,8 +588,8 @@ def load_ia3(model, cfg, inference=False):
|
|||||||
is_trainable=(not inference),
|
is_trainable=(not inference),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = get_peft_model(model, ia3_config)
|
model = get_peft_model(model, peft_config)
|
||||||
|
|
||||||
model.print_trainable_parameters()
|
model.print_trainable_parameters()
|
||||||
|
|
||||||
return model, ia3_config
|
return model, peft_config
|
||||||
|
|||||||
48
tests/test_cfg_normalization.py
Normal file
48
tests/test_cfg_normalization.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Module for testing the validation module"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import unittest
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizationTest(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test the cfg normalization module
|
||||||
|
"""
|
||||||
|
|
||||||
|
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def inject_fixtures(self, caplog):
|
||||||
|
self._caplog = caplog
|
||||||
|
|
||||||
|
def test_lora_to_peft(self):
|
||||||
|
base_cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"base_model": "NousResearch/Llama-2-7b-hf",
|
||||||
|
"base_model_config": "NousResearch/Llama-2-7b-hf",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
cfg = base_cfg | DictDefault(
|
||||||
|
{
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 128,
|
||||||
|
"lora_alpha": 64,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
normalize_config(cfg)
|
||||||
|
assert any(
|
||||||
|
"soon to be deprecated. please use peft_" in record.message
|
||||||
|
for record in self._caplog.records
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cfg.peft_r == 128
|
||||||
|
assert cfg.peft_alpha == 64
|
||||||
Reference in New Issue
Block a user