Compare commits

...

10 Commits

Author      SHA1         Message                                                   Date
Wing Lian   d0b534292f   Add e2e test for ia3 ft                                   2023-10-19 09:27:55 -04:00
Wing Lian   0bd89b38c6   migrate lora_ to peft_                                    2023-10-18 22:22:54 -04:00
Wing Lian   481ef187a5   update README for IA3 peft                                2023-10-18 22:18:39 -04:00
Wing Lian   d645b19fcf   include task type for ia3 config                          2023-10-18 22:18:39 -04:00
Wing Lian   203369411e   consolidate as peft_model_dir                             2023-10-18 22:18:37 -04:00
Wing Lian   ba85308720   Update src/axolotl/utils/models.py                        2023-10-18 22:17:38 -04:00
                         (Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>)
Wing Lian   998763bade   ia3 keeps casting to float32, handle it here for now      2023-10-18 22:17:38 -04:00
Wing Lian   c8e42a0f4f   fix load_in_8bit check                                    2023-10-18 22:17:38 -04:00
Wing Lian   1da328eb9a   prepare ia3 for 8bit                                      2023-10-18 22:17:38 -04:00
Wing Lian   2d7cccfc8e   add ia3 peft support                                      2023-10-18 22:17:38 -04:00
37 changed files with 364 additions and 74 deletions

View File

@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-boolean-expressions,

View File

@@ -96,7 +96,7 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-  --lora_model_dir="./lora-out"
+  --peft_model_dir="./lora-out"
```

## Installation
@@ -384,10 +384,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- lora
```yaml
adapter: lora # qlora or leave blank for full finetune
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
+peft_r: 8
+peft_alpha: 16
+peft_dropout: 0.05
+peft_target_modules:
- q_proj
- v_proj
```
@@ -531,15 +531,15 @@ total_num_tokens:
adapter: lora
# If you already have a lora model trained that you want to load, put that here.
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
-lora_model_dir:
+peft_model_dir:
# LoRA hyperparameters
# For more details about the following options, see:
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
+peft_r: 8
+peft_alpha: 16
+peft_dropout: 0.05
+peft_target_modules:
- q_proj
- v_proj
# - k_proj
@@ -547,13 +547,13 @@ lora_target_modules:
# - gate_proj
# - down_proj
# - up_proj
-lora_target_linear: # If true, will target all linear layers
+peft_target_linear: # if true, will target all linear layers
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
-lora_modules_to_save:
+peft_modules_to_save:
# - embed_tokens
# - lm_head
@@ -561,7 +561,8 @@ lora_modules_to_save:
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
lora_out_dir:
-lora_fan_in_fan_out: false
+peft_fan_in_fan_out: false
+peft_feedforward_modules: # ffn modules for IA3, for llama down projection
# ReLoRA configuration
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
@@ -869,7 +870,7 @@ Pass the appropriate flag to the train command:
- Pretrained LORA:
```bash
-python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
+python -m axolotl.cli.inference examples/your_config.yml --peft_model_dir="./lora-output-dir"
```
- Full weights finetune:
```bash
@@ -890,7 +891,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
Add below flag to train command above
```bash
-python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
+python3 -m axolotl.cli.merge_lora examples/your_config.yml --peft_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
```
If you run out of CUDA memory, you can try to merge in system RAM with

View File

@@ -18,7 +18,7 @@ dataset_prepared_path: last_prepared_run
val_set_size: 0.01
adapter:
-lora_model_dir:
+peft_model_dir:
sequence_len: 2048
max_packed_sequence_len:
sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
-lora_model_dir:
+peft_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 16

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true
adapter: lora
-lora_model_dir:
+peft_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
-lora_model_dir:
+peft_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true
adapter: lora
-lora_model_dir:
+peft_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
-lora_model_dir:
+peft_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true
adapter: lora
-lora_model_dir:
+peft_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
-lora_model_dir:
+peft_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -15,7 +15,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: lora
-lora_model_dir:
+peft_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 16

View File

@@ -22,7 +22,7 @@ dataset_prepared_path:
val_set_size: 0.01
# enable QLoRA
adapter: qlora
-lora_model_dir:
+peft_model_dir:
sequence_len: 2048
max_packed_sequence_len:

View File

@@ -15,7 +15,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter:
-lora_model_dir:
+peft_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 64

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
-lora_model_dir:
+peft_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8

View File

@@ -9,7 +9,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter:
-lora_model_dir:
+peft_model_dir:
sequence_len: 512
max_packed_sequence_len:
lora_r:

View File

@@ -18,7 +18,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: lora
-lora_model_dir:
+peft_model_dir:
sequence_len: 4096
sample_packing:
lora_r: 8

examples/llama-2/ia3.yml (new file, 72 lines)
View File

@@ -0,0 +1,72 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./ia3-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: ia3
peft_model_dir:
peft_target_modules:
- k_proj
- v_proj
- down_proj
peft_feedforward_modules:
- down_proj
peft_fan_in_fan_out: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 5
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens:
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true
adapter: lora
-lora_model_dir:
+peft_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
-lora_model_dir:
+peft_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./relora-out
adapter: qlora
-lora_model_dir:
+peft_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -20,7 +20,7 @@ sequence_len: 4096
sample_packing: true
adapter: lora
-lora_model_dir:
+peft_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -9,7 +9,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter:
-lora_model_dir:
+peft_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter:
-lora_model_dir:
+peft_model_dir:
sequence_len: 1024
sample_packing: true
lora_r:

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter: lora
-lora_model_dir:
+peft_model_dir:
sequence_len: 1024
sample_packing: true
lora_r: 8

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
-lora_model_dir:
+peft_model_dir:
sequence_len: 1024
sample_packing: true
lora_r: 8

View File

@@ -22,7 +22,7 @@ sample_packing: true
pad_to_sequence_len:
adapter:
-lora_model_dir:
+peft_model_dir:
lora_r:
lora_alpha:
lora_dropout:

View File

@@ -22,7 +22,7 @@ sample_packing: false # not CURRENTLY compatible with LoRAs
pad_to_sequence_len:
adapter: qlora
-lora_model_dir:
+peft_model_dir:
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05

View File

@@ -13,7 +13,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.05
adapter:
-lora_model_dir:
+peft_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 64

View File

@@ -7,7 +7,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.05
adapter: lora
-lora_model_dir:
+peft_model_dir:
sequence_len: 512
lora_r: 16
lora_alpha: 32

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter:
-lora_model_dir:
+peft_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8

View File

@@ -8,7 +8,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.05
adapter: lora
-lora_model_dir:
+peft_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8

View File

@@ -20,7 +20,7 @@ dataset_prepared_path:
val_set_size: 0.01
# enable QLoRA
adapter: qlora
-lora_model_dir:
+peft_model_dir:
sequence_len: 8192
max_packed_sequence_len:

View File

@@ -116,6 +116,8 @@ def flashattn_forward(
        attention_mask: [bsz, q_len]
    """
    # pylint: disable=duplicate-code
    original_dtype = hidden_states.dtype

    bsz, q_len, _ = hidden_states.size()
    if not hasattr(self, "pretraining_tp"):
@@ -151,6 +153,13 @@ def flashattn_forward(
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    if query_states.dtype == torch.float32:
        query_states = query_states.to(dtype=original_dtype)
    if key_states.dtype == torch.float32:
        key_states = key_states.to(dtype=original_dtype)
    if value_states.dtype == torch.float32:
        value_states = value_states.to(dtype=original_dtype)

    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
@@ -309,6 +318,10 @@ def flashattn_forward(
    else:
        attn_output = self.o_proj(attn_output)

    # handle conversion back for IA3
    if attn_output.dtype == torch.float32:
        attn_output = attn_output.to(dtype=original_dtype)

    return attn_output, None, past_key_value
@@ -502,6 +515,7 @@ def llama_model_forward(
    )

    hidden_states = inputs_embeds
    original_dtype = hidden_states.dtype

    if self.gradient_checkpointing and self.training:
        if use_cache:
@@ -559,6 +573,10 @@ def llama_model_forward(
        hidden_states = layer_outputs[0]

        # handle conversion back for IA3
        if hidden_states.dtype == torch.float32:
            hidden_states = hidden_states.to(dtype=original_dtype)

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
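The casts added above exist because IA3 adapters under 8-bit loading keep upcasting activations to float32, which the flash-attention path cannot consume. As a minimal sketch of the pattern (not the axolotl patch itself; the helper name here is illustrative), the idea is to remember the dtype the block received and cast any float32 output back to it:

```python
import torch


def restore_dtype(tensor: torch.Tensor, original_dtype: torch.dtype) -> torch.Tensor:
    """Cast a float32 activation back to the dtype the block originally received."""
    if tensor.dtype == torch.float32 and original_dtype != torch.float32:
        return tensor.to(dtype=original_dtype)
    return tensor


# usage mirroring the diff: record the incoming dtype, then re-cast projection outputs
# original_dtype = hidden_states.dtype
# query_states = restore_dtype(self.q_proj(hidden_states), original_dtype)
```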

View File

@@ -121,6 +121,18 @@ def normalize_config(cfg):
    log_gpu_memory_usage(LOG, "baseline", cfg.device)

    if cfg.adapter is not None:
        for key in list(cfg.keys()):
            if key.startswith("lora_"):
                new_key = key.replace("lora_", "peft_")
                LOG.warning(
                    PendingDeprecationWarning(
                        f"{key} soon to be deprecated. please use {new_key}"
                    )
                )
                cfg[new_key] = cfg[key]
                del cfg[key]


def validate_config(cfg):
    if is_torch_bf16_gpu_available():
@@ -190,7 +202,10 @@ def validate_config(cfg):
        raise ValueError("Require cfg.load_in_4bit to be True for qlora")

    if not cfg.load_in_8bit and cfg.adapter == "lora":
-        LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
+        LOG.warning("We recommend setting `load_in_8bit: true` for LoRA finetuning")

    if not cfg.load_in_8bit and cfg.adapter == "ia3":
        LOG.warning("We recommend setting `load_in_8bit: true` for IA3 finetuning")

    if cfg.relora_steps:
        if cfg.adapter not in ("lora", "qlora"):
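In effect, any legacy `lora_*` key is renamed to its `peft_*` counterpart at config-normalization time, with a deprecation warning. A minimal standalone sketch of that behaviour (simplified; it omits axolotl's `DictDefault` and logger plumbing, and the function name is illustrative):

```python
import warnings


def migrate_lora_keys(cfg: dict) -> dict:
    """Rename legacy lora_* config keys to their peft_* equivalents in place."""
    for key in list(cfg.keys()):
        if key.startswith("lora_"):
            new_key = key.replace("lora_", "peft_")
            warnings.warn(
                f"{key} is soon to be deprecated; please use {new_key}",
                PendingDeprecationWarning,
            )
            cfg[new_key] = cfg.pop(key)
    return cfg


# e.g. {"adapter": "lora", "lora_r": 8} becomes {"adapter": "lora", "peft_r": 8}
```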

View File

@@ -406,21 +406,21 @@ def load_model(
        if hasattr(module, "weight"):
            module.to(torch.float32)

-    needs_fa2_dtype = cfg.adapter or cfg.fsdp
-    if (cfg.adapter == "lora" and load_in_8bit) or (
-        cfg.adapter == "qlora" and cfg.load_in_4bit
-    ):
+    require_peft: bool = False
+    if cfg.adapter in ["lora", "qlora", "ia3"]:
+        require_peft = True
+
+    if require_peft:
        LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
        if cfg.gradient_checkpointing:
            model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=cfg.gradient_checkpointing
        )
-        needs_fa2_dtype = True

    # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
    # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
+    if require_peft or cfg.fsdp or (cfg.flash_attention and cfg.is_llama_derived_model):
        LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
        for name, module in model.named_modules():
            if "norm" in name:
@@ -429,7 +429,7 @@ def load_model(
            if hasattr(module, "weight"):
                module.to(cfg.torch_dtype)

-    model, lora_config = load_adapter(model, cfg, cfg.adapter)
+    model, peft_config = load_adapter(model, cfg, cfg.adapter)

    if cfg.ddp and not load_in_8bit:
        model.to(f"cuda:{cfg.local_rank}")
@@ -460,7 +460,7 @@ def load_model(
    log_gpu_memory_usage(LOG, "after adapters", model.device)

    # TODO resume_from_checkpoint handling
-    return model, lora_config
+    return model, peft_config


def load_adapter(model, cfg, adapter, inference=False):
@@ -470,6 +470,8 @@ def load_adapter(model, cfg, adapter, inference=False):
        return model, None
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    if adapter == "ia3":
        return load_ia3(model, cfg, inference=inference)
    if adapter in ["lora", "qlora"]:
        return load_lora(model, cfg, inference=inference)
    if adapter == "llama-adapter":
@@ -488,11 +490,11 @@ def load_llama_adapter(model, cfg):
        task_type="CAUSAL_LM",
    )

-    if cfg.lora_model_dir:
+    if cfg.peft_model_dir:
        LOG.debug("Loading pretained PEFT - llama_adapter")
        model = PeftModel.from_pretrained(
            model,
-            cfg.lora_model_dir,
+            cfg.peft_model_dir,
            torch_dtype=torch.float16,
        )
    else:
@@ -505,7 +507,7 @@ def load_llama_adapter(model, cfg):

def find_all_linear_names(model):
    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
-    lora_module_names = set()
+    peft_module_names = set()
    for name, module in model.named_modules():
        if (
            isinstance(module, cls)
@@ -513,12 +515,12 @@ def find_all_linear_names(model):
            and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
        ):
            names = name.split(".")
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+            peft_module_names.add(names[0] if len(names) == 1 else names[-1])

-    if "lm_head" in lora_module_names:  # needed for 16-bit
-        lora_module_names.remove("lm_head")
+    if "lm_head" in peft_module_names:  # needed for 16-bit
+        peft_module_names.remove("lm_head")

-    return list(lora_module_names)
+    return list(peft_module_names)


def load_lora(model, cfg, inference=False):
@@ -526,34 +528,68 @@ def load_lora(model, cfg, inference=False):
    from peft import LoraConfig, PeftModel, get_peft_model

-    lora_target_modules = list(cfg.lora_target_modules or [])
+    peft_target_modules = list(cfg.peft_target_modules or [])

-    if cfg.lora_target_linear:
+    if cfg.peft_target_linear:
        linear_names = find_all_linear_names(model)
        LOG.info(f"found linear modules: {repr(linear_names)}")
-        lora_target_modules = list(set(lora_target_modules + linear_names))
+        peft_target_modules = list(set(peft_target_modules + linear_names))

-    lora_config = LoraConfig(
-        r=cfg.lora_r,
-        lora_alpha=cfg.lora_alpha,
-        target_modules=lora_target_modules,
-        lora_dropout=cfg.lora_dropout,
-        fan_in_fan_out=cfg.lora_fan_in_fan_out,
-        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
+    peft_config = LoraConfig(
+        r=cfg.peft_r,
+        lora_alpha=cfg.peft_alpha,
+        target_modules=peft_target_modules,
+        lora_dropout=cfg.peft_dropout,
+        fan_in_fan_out=cfg.peft_fan_in_fan_out,
+        modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None,
        bias="none",
        task_type="CAUSAL_LM",
    )

-    if cfg.lora_model_dir:
+    if cfg.peft_model_dir:
        LOG.debug("Loading pretained PEFT - LoRA")
        model = PeftModel.from_pretrained(
            model,
-            cfg.lora_model_dir,
+            cfg.peft_model_dir,
            is_trainable=(not inference),
        )
    else:
-        model = get_peft_model(model, lora_config)
+        model = get_peft_model(model, peft_config)

    model.print_trainable_parameters()

-    return model, lora_config
+    return model, peft_config


def load_ia3(model, cfg, inference=False):
    # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
    from peft import IA3Config, PeftModel, get_peft_model

    peft_config_kwargs = {}
    if cfg.peft_init_ia3_weights is not None:
        peft_config_kwargs["init_ia3_weights"] = cfg.peft_init_ia3_weights
    if cfg.peft_fan_in_fan_out is not None:
        peft_config_kwargs["fan_in_fan_out"] = cfg.peft_fan_in_fan_out
    peft_config = IA3Config(
        target_modules=cfg.peft_target_modules,
        feedforward_modules=cfg.peft_feedforward_modules,
        modules_to_save=cfg.peft_modules_to_save,
        task_type="CAUSAL_LM",
        **peft_config_kwargs,
    )

    if cfg.peft_model_dir:
        LOG.debug("Loading pretained PEFT - IA3")
        model = PeftModel.from_pretrained(
            model,
            cfg.peft_model_dir,
            is_trainable=(not inference),
        )
    else:
        model = get_peft_model(model, peft_config)

    model.print_trainable_parameters()

    return model, peft_config
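For reference, the new `ia3` adapter path boils down to wrapping the base model with peft's `IA3Config`. A minimal sketch of equivalent direct peft usage, mirroring the module lists in `examples/llama-2/ia3.yml` (the tiny base model here is chosen only for illustration, not taken from the PR):

```python
from peft import IA3Config, get_peft_model
from transformers import AutoModelForCausalLM

# small model purely for illustration; the example config targets Llama-2-7b
model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m")

ia3_config = IA3Config(
    task_type="CAUSAL_LM",
    target_modules=["k_proj", "v_proj", "down_proj"],  # peft_target_modules
    feedforward_modules=["down_proj"],                 # peft_feedforward_modules
)

model = get_peft_model(model, ia3_config)
model.print_trainable_parameters()  # IA3 trains only small per-module scaling vectors
```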

View File

@@ -24,6 +24,10 @@ class TestLoraLlama(unittest.TestCase):
""" """
def test_lora(self): def test_lora(self):
"""
support for legacy lora_ configs
:return:
"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp() output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
@@ -66,6 +70,101 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists() assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_peft(self):
"""
support for legacy lora_ configs
:return:
"""
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"peft_r": 32,
"peft_alpha": 64,
"peft_dropout": 0.05,
"peft_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_ia3_peft(self):
"""
support for IA3 peft
:return:
"""
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "ia3",
"peft_r": 32,
"peft_alpha": 64,
"peft_dropout": 0.05,
"peft_target_modules": ["k_proj", "v_proj", "down_proj"],
"peft_feedforward_modules": ["down_proj"],
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_packing(self): def test_lora_packing(self):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp() output_dir = tempfile.mkdtemp()

View File

@@ -0,0 +1,48 @@
"""Module for testing the validation module"""
import logging
import unittest
from typing import Optional
import pytest
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
class NormalizationTest(unittest.TestCase):
"""
Test the cfg normalization module
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def test_lora_to_peft(self):
base_cfg = DictDefault(
{
"gradient_accumulation_steps": 1,
"micro_batch_size": 1,
"base_model": "NousResearch/Llama-2-7b-hf",
"base_model_config": "NousResearch/Llama-2-7b-hf",
}
)
cfg = base_cfg | DictDefault(
{
"adapter": "lora",
"lora_r": 128,
"lora_alpha": 64,
}
)
with self._caplog.at_level(logging.WARNING):
normalize_config(cfg)
assert any(
"soon to be deprecated. please use peft_" in record.message
for record in self._caplog.records
)
assert cfg.peft_r == 128
assert cfg.peft_alpha == 64