Compare commits
10 Commits: feat/soap-... → ia3-peft
| SHA1 |
|---|
| d0b534292f |
| 0bd89b38c6 |
| 481ef187a5 |
| d645b19fcf |
| 203369411e |
| ba85308720 |
| 998763bade |
| c8e42a0f4f |
| 1da328eb9a |
| 2d7cccfc8e |
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
 disable=missing-function-docstring, line-too-long, import-error,
         too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
         too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
+        too-many-boolean-expressions,
README.md (31 changed lines)
@@ -96,7 +96,7 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
 
 # inference
 accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
-    --lora_model_dir="./lora-out"
+    --peft_model_dir="./lora-out"
 ```
 
 ## Installation
@@ -384,10 +384,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
 - lora
 ```yaml
 adapter: lora # qlora or leave blank for full finetune
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
+peft_r: 8
+peft_alpha: 16
+peft_dropout: 0.05
+peft_target_modules:
   - q_proj
   - v_proj
 ```
@@ -531,15 +531,15 @@ total_num_tokens:
 adapter: lora
 # If you already have a lora model trained that you want to load, put that here.
 # This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
-lora_model_dir:
+peft_model_dir:
 
 # LoRA hyperparameters
 # For more details about the following options, see:
 # https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
+peft_r: 8
+peft_alpha: 16
+peft_dropout: 0.05
+peft_target_modules:
   - q_proj
   - v_proj
 # - k_proj
@@ -547,13 +547,13 @@ lora_target_modules:
 # - gate_proj
 # - down_proj
 # - up_proj
-lora_target_linear: # If true, will target all linear layers
+peft_target_linear: # if true, will target all linear layers
 
 # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
 # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
 # `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
 # https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
-lora_modules_to_save:
+peft_modules_to_save:
 # - embed_tokens
 # - lm_head
 
@@ -561,7 +561,8 @@ lora_modules_to_save:
 # If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
 # Make sure `lora_model_dir` points to this directory if you want to use the trained model.
 lora_out_dir:
-lora_fan_in_fan_out: false
+peft_fan_in_fan_out: false
+peft_feedforward_modules: # ffn modules for IA3, for llama down projection
 
 # ReLoRA configuration
 # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
@@ -869,7 +870,7 @@ Pass the appropriate flag to the train command:
 
 - Pretrained LORA:
   ```bash
-  python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
+  python -m axolotl.cli.inference examples/your_config.yml --peft_model_dir="./lora-output-dir"
   ```
 - Full weights finetune:
   ```bash
@@ -890,7 +891,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
 Add below flag to train command above
 
 ```bash
-python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
+python3 -m axolotl.cli.merge_lora examples/your_config.yml --peft_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
 ```
 
 If you run out of CUDA memory, you can try to merge in system RAM with
@@ -18,7 +18,7 @@ dataset_prepared_path: last_prepared_run
 val_set_size: 0.01
 
 adapter:
-lora_model_dir:
+peft_model_dir:
 sequence_len: 2048
 max_packed_sequence_len:
 sample_packing: false
@@ -10,7 +10,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.01
 adapter: qlora
-lora_model_dir:
+peft_model_dir:
 sequence_len: 2048
 max_packed_sequence_len: 2048
 lora_r: 16
@@ -20,7 +20,7 @@ sample_packing: true
 pad_to_sequence_len: true
 
 adapter: lora
-lora_model_dir:
+peft_model_dir:
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
@@ -16,7 +16,7 @@ val_set_size: 0.01
 output_dir: ./qlora-out
 
 adapter: qlora
-lora_model_dir:
+peft_model_dir:
 
 sequence_len: 4096
 sample_packing: true
@@ -20,7 +20,7 @@ sample_packing: true
 pad_to_sequence_len: true
 
 adapter: lora
-lora_model_dir:
+peft_model_dir:
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
@@ -16,7 +16,7 @@ val_set_size: 0.01
 output_dir: ./qlora-out
 
 adapter: qlora
-lora_model_dir:
+peft_model_dir:
 
 sequence_len: 4096
 sample_packing: true
@@ -20,7 +20,7 @@ sample_packing: true
 pad_to_sequence_len: true
 
 adapter: lora
-lora_model_dir:
+peft_model_dir:
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
@@ -16,7 +16,7 @@ val_set_size: 0.01
 output_dir: ./qlora-out
 
 adapter: qlora
-lora_model_dir:
+peft_model_dir:
 
 sequence_len: 4096
 sample_packing: true
@@ -15,7 +15,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.01
 adapter: lora
-lora_model_dir:
+peft_model_dir:
 sequence_len: 2048
 max_packed_sequence_len:
 lora_r: 16
@@ -22,7 +22,7 @@ dataset_prepared_path:
 val_set_size: 0.01
 # enable QLoRA
 adapter: qlora
-lora_model_dir:
+peft_model_dir:
 sequence_len: 2048
 max_packed_sequence_len:
 
@@ -15,7 +15,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.01
 adapter:
-lora_model_dir:
+peft_model_dir:
 sequence_len: 2048
 max_packed_sequence_len:
 lora_r: 64
@@ -10,7 +10,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.01
 adapter: qlora
-lora_model_dir:
+peft_model_dir:
 sequence_len: 2048
 max_packed_sequence_len:
 lora_r: 8
@@ -9,7 +9,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.02
 adapter:
-lora_model_dir:
+peft_model_dir:
 sequence_len: 512
 max_packed_sequence_len:
 lora_r:
@@ -18,7 +18,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.01
 adapter: lora
-lora_model_dir:
+peft_model_dir:
 sequence_len: 4096
 sample_packing:
 lora_r: 8
examples/llama-2/ia3.yml (new file, 72 lines)
@@ -0,0 +1,72 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./ia3-out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

adapter: ia3
peft_model_dir:
peft_target_modules:
  - k_proj
  - v_proj
  - down_proj
peft_feedforward_modules:
  - down_proj
peft_fan_in_fan_out: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 5
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens:
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"
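A usage sketch for this new example (hedged: it simply mirrors the train/inference commands shown in the README hunks above; the adapter directory `./ia3-out` comes from the config itself):

```bash
# Train the IA3 example, then run inference against the saved adapter.
# Commands mirror the README examples in this PR; paths come from ia3.yml.
accelerate launch -m axolotl.cli.train examples/llama-2/ia3.yml

accelerate launch -m axolotl.cli.inference examples/llama-2/ia3.yml \
    --peft_model_dir="./ia3-out"
```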
@@ -20,7 +20,7 @@ sample_packing: true
 pad_to_sequence_len: true
 
 adapter: lora
-lora_model_dir:
+peft_model_dir:
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
@@ -16,7 +16,7 @@ val_set_size: 0.01
 output_dir: ./qlora-out
 
 adapter: qlora
-lora_model_dir:
+peft_model_dir:
 
 sequence_len: 4096
 sample_packing: true
@@ -16,7 +16,7 @@ val_set_size: 0.01
 output_dir: ./relora-out
 
 adapter: qlora
-lora_model_dir:
+peft_model_dir:
 
 sequence_len: 4096
 sample_packing: true
@@ -20,7 +20,7 @@ sequence_len: 4096
 sample_packing: true
 
 adapter: lora
-lora_model_dir:
+peft_model_dir:
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
@@ -9,7 +9,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.02
 adapter:
-lora_model_dir:
+peft_model_dir:
 sequence_len: 2048
 max_packed_sequence_len:
 lora_r: 8
@@ -12,7 +12,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.02
 adapter:
-lora_model_dir:
+peft_model_dir:
 sequence_len: 1024
 sample_packing: true
 lora_r:
@@ -12,7 +12,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.02
 adapter: lora
-lora_model_dir:
+peft_model_dir:
 sequence_len: 1024
 sample_packing: true
 lora_r: 8
@@ -12,7 +12,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.01
 adapter: qlora
-lora_model_dir:
+peft_model_dir:
 sequence_len: 1024
 sample_packing: true
 lora_r: 8
@@ -22,7 +22,7 @@ sample_packing: true
 pad_to_sequence_len:
 
 adapter:
-lora_model_dir:
+peft_model_dir:
 lora_r:
 lora_alpha:
 lora_dropout:
@@ -22,7 +22,7 @@ sample_packing: false # not CURRENTLY compatible with LoRAs
 pad_to_sequence_len:
 
 adapter: qlora
-lora_model_dir:
+peft_model_dir:
 lora_r: 64
 lora_alpha: 32
 lora_dropout: 0.05
@@ -13,7 +13,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.05
 adapter:
-lora_model_dir:
+peft_model_dir:
 sequence_len: 2048
 max_packed_sequence_len: 2048
 lora_r: 64
@@ -7,7 +7,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.05
 adapter: lora
-lora_model_dir:
+peft_model_dir:
 sequence_len: 512
 lora_r: 16
 lora_alpha: 32
@@ -10,7 +10,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.02
 adapter:
-lora_model_dir:
+peft_model_dir:
 sequence_len: 2048
 max_packed_sequence_len:
 lora_r: 8
@@ -8,7 +8,7 @@ datasets:
 dataset_prepared_path:
 val_set_size: 0.05
 adapter: lora
-lora_model_dir:
+peft_model_dir:
 sequence_len: 2048
 max_packed_sequence_len:
 lora_r: 8
@@ -20,7 +20,7 @@ dataset_prepared_path:
 val_set_size: 0.01
 # enable QLoRA
 adapter: qlora
-lora_model_dir:
+peft_model_dir:
 sequence_len: 8192
 max_packed_sequence_len:
 
@@ -116,6 +116,8 @@ def flashattn_forward(
         attention_mask: [bsz, q_len]
     """
     # pylint: disable=duplicate-code
+    original_dtype = hidden_states.dtype
+
     bsz, q_len, _ = hidden_states.size()
 
     if not hasattr(self, "pretraining_tp"):
@@ -151,6 +153,13 @@ def flashattn_forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
+    if query_states.dtype == torch.float32:
+        query_states = query_states.to(dtype=original_dtype)
+    if key_states.dtype == torch.float32:
+        key_states = key_states.to(dtype=original_dtype)
+    if value_states.dtype == torch.float32:
+        value_states = value_states.to(dtype=original_dtype)
+
     query_states = query_states.view(
         bsz, q_len, self.num_heads, self.head_dim
     ).transpose(1, 2)
@@ -309,6 +318,10 @@ def flashattn_forward(
     else:
         attn_output = self.o_proj(attn_output)
 
+    # handle conversion back for IA3
+    if attn_output.dtype == torch.float32:
+        attn_output = attn_output.to(dtype=original_dtype)
+
     return attn_output, None, past_key_value
 
 
@@ -502,6 +515,7 @@ def llama_model_forward(
         )
 
     hidden_states = inputs_embeds
+    original_dtype = hidden_states.dtype
 
     if self.gradient_checkpointing and self.training:
         if use_cache:
@@ -559,6 +573,10 @@
 
         hidden_states = layer_outputs[0]
 
+        # handle conversion back for IA3
+        if hidden_states.dtype == torch.float32:
+            hidden_states = hidden_states.to(dtype=original_dtype)
+
         if use_cache:
             next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
 
@@ -121,6 +121,18 @@ def normalize_config(cfg):
 
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
+    if cfg.adapter is not None:
+        for key in list(cfg.keys()):
+            if key.startswith("lora_"):
+                new_key = key.replace("lora_", "peft_")
+                LOG.warning(
+                    PendingDeprecationWarning(
+                        f"{key} soon to be deprecated. please use {new_key}"
+                    )
+                )
+                cfg[new_key] = cfg[key]
+                del cfg[key]
+
 
 def validate_config(cfg):
     if is_torch_bf16_gpu_available():
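For clarity, here is a minimal standalone sketch of the legacy-key migration added above, using a plain dict and the warnings module instead of axolotl's DictDefault and LOG (illustrative only, not the project's code):

```python
# Standalone sketch of the lora_* -> peft_* migration performed in normalize_config.
import warnings


def migrate_legacy_lora_keys(cfg: dict) -> dict:
    if cfg.get("adapter") is None:
        return cfg
    for key in list(cfg.keys()):
        if key.startswith("lora_"):
            new_key = key.replace("lora_", "peft_")
            warnings.warn(f"{key} soon to be deprecated. please use {new_key}",
                          PendingDeprecationWarning)
            cfg[new_key] = cfg.pop(key)
    return cfg


legacy = {"adapter": "lora", "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.05}
print(migrate_legacy_lora_keys(legacy))
# {'adapter': 'lora', 'peft_r': 8, 'peft_alpha': 16, 'peft_dropout': 0.05}
```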
@@ -190,7 +202,10 @@ def validate_config(cfg):
         raise ValueError("Require cfg.load_in_4bit to be True for qlora")
 
     if not cfg.load_in_8bit and cfg.adapter == "lora":
-        LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
+        LOG.warning("We recommend setting `load_in_8bit: true` for LoRA finetuning")
+
+    if not cfg.load_in_8bit and cfg.adapter == "ia3":
+        LOG.warning("We recommend setting `load_in_8bit: true` for IA3 finetuning")
 
     if cfg.relora_steps:
         if cfg.adapter not in ("lora", "qlora"):
@@ -406,21 +406,21 @@ def load_model(
             if hasattr(module, "weight"):
                 module.to(torch.float32)
 
-    needs_fa2_dtype = cfg.adapter or cfg.fsdp
-    if (cfg.adapter == "lora" and load_in_8bit) or (
-        cfg.adapter == "qlora" and cfg.load_in_4bit
-    ):
+    require_peft: bool = False
+    if cfg.adapter in ["lora", "qlora", "ia3"]:
+        require_peft = True
+
+    if require_peft:
         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
         if cfg.gradient_checkpointing:
             model.gradient_checkpointing_enable()
         model = prepare_model_for_kbit_training(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
-        needs_fa2_dtype = True
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
     # convert them back to fp16/bf16 for flash-attn compatibility.
-    if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
+    if require_peft or cfg.fsdp or (cfg.flash_attention and cfg.is_llama_derived_model):
         LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
         for name, module in model.named_modules():
             if "norm" in name:
@@ -429,7 +429,7 @@
             if hasattr(module, "weight"):
                 module.to(cfg.torch_dtype)
 
-    model, lora_config = load_adapter(model, cfg, cfg.adapter)
+    model, peft_config = load_adapter(model, cfg, cfg.adapter)
 
     if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
@@ -460,7 +460,7 @@
     log_gpu_memory_usage(LOG, "after adapters", model.device)
 
     # TODO resume_from_checkpoint handling
-    return model, lora_config
+    return model, peft_config
 
 
 def load_adapter(model, cfg, adapter, inference=False):
@@ -470,6 +470,8 @@ def load_adapter(model, cfg, adapter, inference=False):
         return model, None
     if hasattr(model, "enable_input_require_grads"):
         model.enable_input_require_grads()
+    if adapter == "ia3":
+        return load_ia3(model, cfg, inference=inference)
     if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg, inference=inference)
    if adapter == "llama-adapter":
@@ -488,11 +490,11 @@ def load_llama_adapter(model, cfg):
         task_type="CAUSAL_LM",
     )
 
-    if cfg.lora_model_dir:
+    if cfg.peft_model_dir:
         LOG.debug("Loading pretained PEFT - llama_adapter")
         model = PeftModel.from_pretrained(
             model,
-            cfg.lora_model_dir,
+            cfg.peft_model_dir,
             torch_dtype=torch.float16,
         )
     else:
@@ -505,7 +507,7 @@ def load_llama_adapter(model, cfg):
 
 def find_all_linear_names(model):
     cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
-    lora_module_names = set()
+    peft_module_names = set()
     for name, module in model.named_modules():
         if (
             isinstance(module, cls)
@@ -513,12 +515,12 @@ def find_all_linear_names(model):
             and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
         ):
             names = name.split(".")
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+            peft_module_names.add(names[0] if len(names) == 1 else names[-1])
 
-    if "lm_head" in lora_module_names:  # needed for 16-bit
-        lora_module_names.remove("lm_head")
+    if "lm_head" in peft_module_names:  # needed for 16-bit
+        peft_module_names.remove("lm_head")
 
-    return list(lora_module_names)
+    return list(peft_module_names)
 
 
 def load_lora(model, cfg, inference=False):
@@ -526,34 +528,68 @@ def load_lora(model, cfg, inference=False):
 
     from peft import LoraConfig, PeftModel, get_peft_model
 
-    lora_target_modules = list(cfg.lora_target_modules or [])
+    peft_target_modules = list(cfg.peft_target_modules or [])
 
-    if cfg.lora_target_linear:
+    if cfg.peft_target_linear:
         linear_names = find_all_linear_names(model)
         LOG.info(f"found linear modules: {repr(linear_names)}")
-        lora_target_modules = list(set(lora_target_modules + linear_names))
+        peft_target_modules = list(set(peft_target_modules + linear_names))
 
-    lora_config = LoraConfig(
-        r=cfg.lora_r,
-        lora_alpha=cfg.lora_alpha,
-        target_modules=lora_target_modules,
-        lora_dropout=cfg.lora_dropout,
-        fan_in_fan_out=cfg.lora_fan_in_fan_out,
-        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
+    peft_config = LoraConfig(
+        r=cfg.peft_r,
+        lora_alpha=cfg.peft_alpha,
+        target_modules=peft_target_modules,
+        lora_dropout=cfg.peft_dropout,
+        fan_in_fan_out=cfg.peft_fan_in_fan_out,
+        modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None,
         bias="none",
         task_type="CAUSAL_LM",
     )
 
-    if cfg.lora_model_dir:
+    if cfg.peft_model_dir:
         LOG.debug("Loading pretained PEFT - LoRA")
         model = PeftModel.from_pretrained(
             model,
-            cfg.lora_model_dir,
+            cfg.peft_model_dir,
             is_trainable=(not inference),
         )
     else:
-        model = get_peft_model(model, lora_config)
+        model = get_peft_model(model, peft_config)
 
     model.print_trainable_parameters()
 
-    return model, lora_config
+    return model, peft_config
+
+
+def load_ia3(model, cfg, inference=False):
+    # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+
+    from peft import IA3Config, PeftModel, get_peft_model
+
+    peft_config_kwargs = {}
+    if cfg.peft_init_ia3_weights is not None:
+        peft_config_kwargs["init_ia3_weights"] = cfg.peft_init_ia3_weights
+    if cfg.peft_fan_in_fan_out is not None:
+        peft_config_kwargs["fan_in_fan_out"] = cfg.peft_fan_in_fan_out
+
+    peft_config = IA3Config(
+        target_modules=cfg.peft_target_modules,
+        feedforward_modules=cfg.peft_feedforward_modules,
+        modules_to_save=cfg.peft_modules_to_save,
+        task_type="CAUSAL_LM",
+        **peft_config_kwargs,
+    )
+
+    if cfg.peft_model_dir:
+        LOG.debug("Loading pretained PEFT - IA3")
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.peft_model_dir,
+            is_trainable=(not inference),
+        )
+    else:
+        model = get_peft_model(model, peft_config)
+
+    model.print_trainable_parameters()
+
+    return model, peft_config
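As a standalone sanity check of the IA3 settings wired up above (hedged: assumes the `peft` package is installed; module names are taken from examples/llama-2/ia3.yml), the same IA3Config can be constructed directly. Note that PEFT expects `feedforward_modules` to be a subset of `target_modules`:

```python
# Sketch: build the IA3Config that load_ia3 would create from examples/llama-2/ia3.yml.
from peft import IA3Config

ia3_config = IA3Config(
    target_modules=["k_proj", "v_proj", "down_proj"],
    feedforward_modules=["down_proj"],  # must be a subset of target_modules
    task_type="CAUSAL_LM",
)
print(ia3_config)
```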
@@ -24,6 +24,10 @@ class TestLoraLlama(unittest.TestCase):
     """
 
     def test_lora(self):
+        """
+        support for legacy lora_ configs
+        :return:
+        """
         # pylint: disable=duplicate-code
         output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
@@ -66,6 +70,101 @@ class TestLoraLlama(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(output_dir) / "adapter_model.bin").exists()
 
+    def test_lora_peft(self):
+        """
+        support for legacy lora_ configs
+        :return:
+        """
+        # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "base_model_config": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "peft_r": 32,
+                "peft_alpha": 64,
+                "peft_dropout": 0.05,
+                "peft_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": output_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "adapter_model.bin").exists()
+
+    def test_ia3_peft(self):
+        """
+        support for IA3 peft
+        :return:
+        """
+        # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "base_model_config": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "ia3",
+                "peft_r": 32,
+                "peft_alpha": 64,
+                "peft_dropout": 0.05,
+                "peft_target_modules": ["k_proj", "v_proj", "down_proj"],
+                "peft_feedforward_modules": ["down_proj"],
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": output_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "adapter_model.bin").exists()
+
     def test_lora_packing(self):
         # pylint: disable=duplicate-code
         output_dir = tempfile.mkdtemp()
tests/test_cfg_normalization.py (new file, 48 lines)
@@ -0,0 +1,48 @@
"""Module for testing the validation module"""

import logging
import unittest
from typing import Optional

import pytest

from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault


class NormalizationTest(unittest.TestCase):
    """
    Test the cfg normalization module
    """

    _caplog: Optional[pytest.LogCaptureFixture] = None

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        self._caplog = caplog

    def test_lora_to_peft(self):
        base_cfg = DictDefault(
            {
                "gradient_accumulation_steps": 1,
                "micro_batch_size": 1,
                "base_model": "NousResearch/Llama-2-7b-hf",
                "base_model_config": "NousResearch/Llama-2-7b-hf",
            }
        )
        cfg = base_cfg | DictDefault(
            {
                "adapter": "lora",
                "lora_r": 128,
                "lora_alpha": 64,
            }
        )
        with self._caplog.at_level(logging.WARNING):
            normalize_config(cfg)
            assert any(
                "soon to be deprecated. please use peft_" in record.message
                for record in self._caplog.records
            )

        assert cfg.peft_r == 128
        assert cfg.peft_alpha == 64
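A quick way to exercise the new normalization test locally (assuming the repo's standard pytest setup):

```bash
pytest tests/test_cfg_normalization.py -k test_lora_to_peft
```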