Compare commits

...

16 Commits

Author SHA1 Message Date
Wing Lian
d0b534292f Add e2e test for ia3 ft 2023-10-19 09:27:55 -04:00
Wing Lian
0bd89b38c6 migrate lora_ to peft_ 2023-10-18 22:22:54 -04:00
Wing Lian
481ef187a5 update README for IA3 peft 2023-10-18 22:18:39 -04:00
Wing Lian
d645b19fcf include task type for ia3 config 2023-10-18 22:18:39 -04:00
Wing Lian
203369411e consolidate as peft_model_dir 2023-10-18 22:18:37 -04:00
Wing Lian
ba85308720 Update src/axolotl/utils/models.py
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-10-18 22:17:38 -04:00
Wing Lian
998763bade ia3 keeps casting to float32, handle it here for now 2023-10-18 22:17:38 -04:00
Wing Lian
c8e42a0f4f fix load_in_8bit check 2023-10-18 22:17:38 -04:00
Wing Lian
1da328eb9a prepare ia3 for 8bit 2023-10-18 22:17:38 -04:00
Wing Lian
2d7cccfc8e add ia3 peft support 2023-10-18 22:17:38 -04:00
NanoCode012
440c3ab527 Fix(model): Linear detected and added to target module with rope linear (#738)
* Fix(model): Linear detected and added to target module with rope linear

* fix: exclude layer instead
2023-10-18 22:13:20 -04:00
Napuh
992d57f20a catch ConnectionError when checking dataset from HuggingFace (#743) 2023-10-18 22:11:54 -04:00
mhenrichsen
91a016f410 badge (#739)
* badge

* fixed text
2023-10-18 10:21:34 -04:00
Casper
a045db0214 Mistral: Sliding Window Attention with Flash Attention and Sample Packing (#732)
* Implement Mistral FA + SWA + Sample Packing

* Handle unbroadcastable tensor

* chore: lint

* Simplify _prepare_decoder_attention_mask

* Uncomment window size

* Upgrade flash-attn to minimum of 2.3.0 to support SWA

* Add original condition to avoid error during inference

* chore: lint

* use torchscript to prevent oom

* chore: pylint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-16 15:13:46 -04:00
Casper
e1b214c62b Clarify custom format example (#729)
* Clarify custom prompt format

* Simplify format
2023-10-14 09:28:12 -04:00
Wing Lian
3553172e3c fixes for alpaca w chatml, and don't include attention_mask w mistral for flash attention (#728) 2023-10-14 09:27:07 -04:00
43 changed files with 497 additions and 99 deletions

View File

@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error, disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods, too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation, too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-boolean-expressions,

View File

@@ -96,7 +96,7 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference # inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --peft_model_dir="./lora-out"
``` ```
## Installation ## Installation
@@ -297,25 +297,24 @@ Have dataset(s) in one of the following format (JSONL recommended):
#### How to add custom prompts #### How to add custom prompts
Using yaml. Example: For a dataset that is preprocessed for instruction purposes:
```json
{"instruction": "...", "output": "..."}
```
You can use this example in your YAML config:
```yaml ```yaml
datasets: datasets:
- path: repo - path: repo
type: type:
system_prompt: "" system_prompt: ""
no_input_format: |- field_system: system
User: {instruction}<|end_of_turn|> format: "[INST] {instruction} [/INST]"
Assistant: no_input_format: "[INST] {instruction} [/INST]"
format: |-
User: {instruction}
{input}<|end_of_turn|>
Assistant:
``` ```
Using file:
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
#### How to use your custom pretokenized dataset #### How to use your custom pretokenized dataset
- Do not pass a `type:` - Do not pass a `type:`
@@ -385,10 +384,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- lora - lora
```yaml ```yaml
adapter: lora # qlora or leave blank for full finetune adapter: lora # qlora or leave blank for full finetune
lora_r: 8 peft_r: 8
lora_alpha: 16 peft_alpha: 16
lora_dropout: 0.05 peft_dropout: 0.05
lora_target_modules: peft_target_modules:
- q_proj - q_proj
- v_proj - v_proj
``` ```
@@ -532,15 +531,15 @@ total_num_tokens:
adapter: lora adapter: lora
# If you already have a lora model trained that you want to load, put that here. # If you already have a lora model trained that you want to load, put that here.
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`. # This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
lora_model_dir: peft_model_dir:
# LoRA hyperparameters # LoRA hyperparameters
# For more details about the following options, see: # For more details about the following options, see:
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2 # https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
lora_r: 8 peft_r: 8
lora_alpha: 16 peft_alpha: 16
lora_dropout: 0.05 peft_dropout: 0.05
lora_target_modules: peft_target_modules:
- q_proj - q_proj
- v_proj - v_proj
# - k_proj # - k_proj
@@ -548,13 +547,13 @@ lora_target_modules:
# - gate_proj # - gate_proj
# - down_proj # - down_proj
# - up_proj # - up_proj
lora_target_linear: # If true, will target all linear layers peft_target_linear: # if true, will target all linear layers
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens. # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models. # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities. # `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994 # https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
lora_modules_to_save: peft_modules_to_save:
# - embed_tokens # - embed_tokens
# - lm_head # - lm_head
@@ -562,7 +561,8 @@ lora_modules_to_save:
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory. # If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# Make sure `lora_model_dir` points to this directory if you want to use the trained model. # Make sure `lora_model_dir` points to this directory if you want to use the trained model.
lora_out_dir: lora_out_dir:
lora_fan_in_fan_out: false peft_fan_in_fan_out: false
peft_feedforward_modules: # ffn modules for IA3, for llama down projection
# ReLoRA configuration # ReLoRA configuration
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
@@ -870,7 +870,7 @@ Pass the appropriate flag to the train command:
- Pretrained LORA: - Pretrained LORA:
```bash ```bash
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir" python -m axolotl.cli.inference examples/your_config.yml --peft_model_dir="./lora-output-dir"
``` ```
- Full weights finetune: - Full weights finetune:
```bash ```bash
@@ -891,7 +891,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
Add below flag to train command above Add below flag to train command above
```bash ```bash
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False python3 -m axolotl.cli.merge_lora examples/your_config.yml --peft_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
``` ```
If you run out of CUDA memory, you can try to merge in system RAM with If you run out of CUDA memory, you can try to merge in system RAM with

View File

@@ -18,7 +18,7 @@ dataset_prepared_path: last_prepared_run
val_set_size: 0.01 val_set_size: 0.01
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
sample_packing: false sample_packing: false

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: 2048 max_packed_sequence_len: 2048
lora_r: 16 lora_r: 16

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -15,7 +15,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 16 lora_r: 16

View File

@@ -22,7 +22,7 @@ dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
# enable QLoRA # enable QLoRA
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:

View File

@@ -15,7 +15,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 64 lora_r: 64

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 8 lora_r: 8

View File

@@ -9,7 +9,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 512 sequence_len: 512
max_packed_sequence_len: max_packed_sequence_len:
lora_r: lora_r:

View File

@@ -18,7 +18,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: sample_packing:
lora_r: 8 lora_r: 8

72
examples/llama-2/ia3.yml Normal file
View File

@@ -0,0 +1,72 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./ia3-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: ia3
peft_model_dir:
peft_target_modules:
- k_proj
- v_proj
- down_proj
peft_feedforward_modules:
- down_proj
peft_fan_in_fan_out: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 5
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens:
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -20,7 +20,7 @@ sample_packing: true
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -16,7 +16,7 @@ val_set_size: 0.01
output_dir: ./relora-out output_dir: ./relora-out
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true

View File

@@ -20,7 +20,7 @@ sequence_len: 4096
sample_packing: true sample_packing: true
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -9,7 +9,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 8 lora_r: 8

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 1024 sequence_len: 1024
sample_packing: true sample_packing: true
lora_r: lora_r:

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
sequence_len: 1024 sequence_len: 1024
sample_packing: true sample_packing: true
lora_r: 8 lora_r: 8

View File

@@ -12,7 +12,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 1024 sequence_len: 1024
sample_packing: true sample_packing: true
lora_r: 8 lora_r: 8

View File

@@ -22,7 +22,7 @@ sample_packing: true
pad_to_sequence_len: pad_to_sequence_len:
adapter: adapter:
lora_model_dir: peft_model_dir:
lora_r: lora_r:
lora_alpha: lora_alpha:
lora_dropout: lora_dropout:

View File

@@ -22,7 +22,7 @@ sample_packing: false # not CURRENTLY compatible with LoRAs
pad_to_sequence_len: pad_to_sequence_len:
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
lora_r: 64 lora_r: 64
lora_alpha: 32 lora_alpha: 32
lora_dropout: 0.05 lora_dropout: 0.05

View File

@@ -13,7 +13,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: 2048 max_packed_sequence_len: 2048
lora_r: 64 lora_r: 64

View File

@@ -7,7 +7,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
sequence_len: 512 sequence_len: 512
lora_r: 16 lora_r: 16
lora_alpha: 32 lora_alpha: 32

View File

@@ -10,7 +10,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 8 lora_r: 8

View File

@@ -8,7 +8,7 @@ datasets:
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.05 val_set_size: 0.05
adapter: lora adapter: lora
lora_model_dir: peft_model_dir:
sequence_len: 2048 sequence_len: 2048
max_packed_sequence_len: max_packed_sequence_len:
lora_r: 8 lora_r: 8

View File

@@ -20,7 +20,7 @@ dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.01
# enable QLoRA # enable QLoRA
adapter: qlora adapter: qlora
lora_model_dir: peft_model_dir:
sequence_len: 8192 sequence_len: 8192
max_packed_sequence_len: max_packed_sequence_len:

BIN
image/sticker_fixed.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 370 KiB

View File

@@ -46,7 +46,7 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn>=2.2.1", "flash-attn>=2.3.0",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed", "deepspeed",

View File

@@ -116,6 +116,8 @@ def flashattn_forward(
attention_mask: [bsz, q_len] attention_mask: [bsz, q_len]
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
original_dtype = hidden_states.dtype
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"): if not hasattr(self, "pretraining_tp"):
@@ -151,6 +153,13 @@ def flashattn_forward(
key_states = self.k_proj(hidden_states) key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states) value_states = self.v_proj(hidden_states)
if query_states.dtype == torch.float32:
query_states = query_states.to(dtype=original_dtype)
if key_states.dtype == torch.float32:
key_states = key_states.to(dtype=original_dtype)
if value_states.dtype == torch.float32:
value_states = value_states.to(dtype=original_dtype)
query_states = query_states.view( query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2) ).transpose(1, 2)
@@ -309,6 +318,10 @@ def flashattn_forward(
else: else:
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
# handle conversion back for IA3
if attn_output.dtype == torch.float32:
attn_output = attn_output.to(dtype=original_dtype)
return attn_output, None, past_key_value return attn_output, None, past_key_value
@@ -502,6 +515,7 @@ def llama_model_forward(
) )
hidden_states = inputs_embeds hidden_states = inputs_embeds
original_dtype = hidden_states.dtype
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
@@ -559,6 +573,10 @@ def llama_model_forward(
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
# handle conversion back for IA3
if hidden_states.dtype == torch.float32:
hidden_states = hidden_states.to(dtype=original_dtype)
if use_cache: if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

View File

@@ -14,6 +14,9 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_func,
) )
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention,
)
from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer, MistralDecoderLayer as OriginalMistralDecoderLayer,
) )
@@ -42,6 +45,44 @@ def replace_mistral_attn_with_flash_attn(
) )
@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,
tgt_len: int,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: int = 4096,
):
"""
Make causal mask used for sliding window attention
"""
tensor = torch.full(
(tgt_len, tgt_len),
fill_value=1,
device=device,
)
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
mask = torch.triu(mask, diagonal=-sliding_window + 1)
mask = torch.log(mask).to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention # Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask # requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask( def _prepare_decoder_attention_mask(
@@ -53,11 +94,29 @@ def _prepare_decoder_attention_mask(
sliding_window, sliding_window,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
# [bsz, seq_len] # [bsz, seq_len]
if attention_mask is None:
return attention_mask
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
sliding_window_mask = _make_sliding_window_causal_mask(
bsz=input_shape[0],
tgt_len=input_shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)
attention_mask = attention_mask + sliding_window_mask
else:
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
return attention_mask return attention_mask
def flashattn_forward( def flashattn_forward(
self, self: OriginalMistralAttention,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
@@ -91,10 +150,41 @@ def flashattn_forward(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
use_sliding_windows = (
hasattr(self.config, "sliding_window") is not None
and kv_seq_len > self.config.sliding_window
)
if use_sliding_windows:
window_size = (self.config.sliding_window, self.config.sliding_window)
else:
window_size = (-1, -1)
if past_key_value is not None: if past_key_value is not None:
# reuse k, v, self_attention # Activate slicing cache only if the config has a value `sliding_windows` attribute
key_states = torch.cat([past_key_value[0], key_states], dim=2) if (
value_states = torch.cat([past_key_value[1], value_states], dim=2) hasattr(self.config, "sliding_window")
and kv_seq_len > self.config.sliding_window
):
slicing_tokens = kv_seq_len - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
past_key_value = (past_key, past_value) if use_cache else None
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None past_key_value = (key_states, value_states) if use_cache else None
@@ -120,7 +210,13 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func( output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True qkv,
cu_seqlens,
max_seqlen,
0.0,
softmax_scale=None,
causal=True,
window_size=window_size,
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape: elif query_states.shape == key_states.shape:
@@ -146,6 +242,7 @@ def flashattn_forward(
0.0, 0.0,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
output = output_pad_fn(output_unpad) output = output_pad_fn(output_unpad)
else: else:
@@ -157,6 +254,7 @@ def flashattn_forward(
query_states, query_states,
torch.stack([key_states, value_states], 2), torch.stack([key_states, value_states], 2),
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
else: else:
( # pylint: disable=unbalanced-tuple-unpacking ( # pylint: disable=unbalanced-tuple-unpacking
@@ -191,6 +289,7 @@ def flashattn_forward(
0.0, 0.0,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
output = output_pad_fn(output_unpad) output = output_pad_fn(output_unpad)

View File

@@ -1,6 +1,6 @@
"""Module containing the AlpacaQAPromptTokenizingStrategy class""" """Module for Alpaca prompt strategy classes"""
from typing import Tuple from typing import Any, Dict, Optional, Tuple
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
@@ -9,9 +9,13 @@ from axolotl.prompt_tokenizers import (
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
prompt_style = PromptStyle.CHAT.value
if ds_cfg and "conversation" in ds_cfg:
prompt_style = ds_cfg["conversation"]
return AlpacaPromptTokenizingStrategy( return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.CHAT.value), AlpacaPrompter(prompt_style),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,

View File

@@ -121,6 +121,18 @@ def normalize_config(cfg):
log_gpu_memory_usage(LOG, "baseline", cfg.device) log_gpu_memory_usage(LOG, "baseline", cfg.device)
if cfg.adapter is not None:
for key in list(cfg.keys()):
if key.startswith("lora_"):
new_key = key.replace("lora_", "peft_")
LOG.warning(
PendingDeprecationWarning(
f"{key} soon to be deprecated. please use {new_key}"
)
)
cfg[new_key] = cfg[key]
del cfg[key]
def validate_config(cfg): def validate_config(cfg):
if is_torch_bf16_gpu_available(): if is_torch_bf16_gpu_available():
@@ -190,7 +202,10 @@ def validate_config(cfg):
raise ValueError("Require cfg.load_in_4bit to be True for qlora") raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if not cfg.load_in_8bit and cfg.adapter == "lora": if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") LOG.warning("We recommend setting `load_in_8bit: true` for LoRA finetuning")
if not cfg.load_in_8bit and cfg.adapter == "ia3":
LOG.warning("We recommend setting `load_in_8bit: true` for IA3 finetuning")
if cfg.relora_steps: if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"): if cfg.adapter not in ("lora", "qlora"):

View File

@@ -158,7 +158,7 @@ def load_tokenized_prepared_datasets(
token=use_auth_token, token=use_auth_token,
) )
ds_from_hub = True ds_from_hub = True
except FileNotFoundError: except (FileNotFoundError, ConnectionError):
pass pass
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists

View File

@@ -406,21 +406,21 @@ def load_model(
if hasattr(module, "weight"): if hasattr(module, "weight"):
module.to(torch.float32) module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp require_peft: bool = False
if (cfg.adapter == "lora" and load_in_8bit) or ( if cfg.adapter in ["lora", "qlora", "ia3"]:
cfg.adapter == "qlora" and cfg.load_in_4bit require_peft = True
):
if require_peft:
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing: if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training( model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing model, use_gradient_checkpointing=cfg.gradient_checkpointing
) )
needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility. # convert them back to fp16/bf16 for flash-attn compatibility.
if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model): if require_peft or cfg.fsdp or (cfg.flash_attention and cfg.is_llama_derived_model):
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype) LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules(): for name, module in model.named_modules():
if "norm" in name: if "norm" in name:
@@ -429,7 +429,7 @@ def load_model(
if hasattr(module, "weight"): if hasattr(module, "weight"):
module.to(cfg.torch_dtype) module.to(cfg.torch_dtype)
model, lora_config = load_adapter(model, cfg, cfg.adapter) model, peft_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit: if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")
@@ -460,7 +460,7 @@ def load_model(
log_gpu_memory_usage(LOG, "after adapters", model.device) log_gpu_memory_usage(LOG, "after adapters", model.device)
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
return model, lora_config return model, peft_config
def load_adapter(model, cfg, adapter, inference=False): def load_adapter(model, cfg, adapter, inference=False):
@@ -470,6 +470,8 @@ def load_adapter(model, cfg, adapter, inference=False):
return model, None return model, None
if hasattr(model, "enable_input_require_grads"): if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads() model.enable_input_require_grads()
if adapter == "ia3":
return load_ia3(model, cfg, inference=inference)
if adapter in ["lora", "qlora"]: if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference) return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter": if adapter == "llama-adapter":
@@ -488,11 +490,11 @@ def load_llama_adapter(model, cfg):
task_type="CAUSAL_LM", task_type="CAUSAL_LM",
) )
if cfg.lora_model_dir: if cfg.peft_model_dir:
LOG.debug("Loading pretained PEFT - llama_adapter") LOG.debug("Loading pretained PEFT - llama_adapter")
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.lora_model_dir, cfg.peft_model_dir,
torch_dtype=torch.float16, torch_dtype=torch.float16,
) )
else: else:
@@ -505,16 +507,20 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model): def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear) cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set() peft_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if isinstance(module, cls) or "Linear" in module.__class__.__name__: if (
isinstance(module, cls)
or "Linear" in module.__class__.__name__
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
):
names = name.split(".") names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1]) peft_module_names.add(names[0] if len(names) == 1 else names[-1])
if "lm_head" in lora_module_names: # needed for 16-bit if "lm_head" in peft_module_names: # needed for 16-bit
lora_module_names.remove("lm_head") peft_module_names.remove("lm_head")
return list(lora_module_names) return list(peft_module_names)
def load_lora(model, cfg, inference=False): def load_lora(model, cfg, inference=False):
@@ -522,34 +528,68 @@ def load_lora(model, cfg, inference=False):
from peft import LoraConfig, PeftModel, get_peft_model from peft import LoraConfig, PeftModel, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or []) peft_target_modules = list(cfg.peft_target_modules or [])
if cfg.lora_target_linear: if cfg.peft_target_linear:
linear_names = find_all_linear_names(model) linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(linear_names)}") LOG.info(f"found linear modules: {repr(linear_names)}")
lora_target_modules = list(set(lora_target_modules + linear_names)) peft_target_modules = list(set(peft_target_modules + linear_names))
lora_config = LoraConfig( peft_config = LoraConfig(
r=cfg.lora_r, r=cfg.peft_r,
lora_alpha=cfg.lora_alpha, lora_alpha=cfg.peft_alpha,
target_modules=lora_target_modules, target_modules=peft_target_modules,
lora_dropout=cfg.lora_dropout, lora_dropout=cfg.peft_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out, fan_in_fan_out=cfg.peft_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None, modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None,
bias="none", bias="none",
task_type="CAUSAL_LM", task_type="CAUSAL_LM",
) )
if cfg.lora_model_dir: if cfg.peft_model_dir:
LOG.debug("Loading pretained PEFT - LoRA") LOG.debug("Loading pretained PEFT - LoRA")
model = PeftModel.from_pretrained( model = PeftModel.from_pretrained(
model, model,
cfg.lora_model_dir, cfg.peft_model_dir,
is_trainable=(not inference), is_trainable=(not inference),
) )
else: else:
model = get_peft_model(model, lora_config) model = get_peft_model(model, peft_config)
model.print_trainable_parameters() model.print_trainable_parameters()
return model, lora_config return model, peft_config
def load_ia3(model, cfg, inference=False):
    # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
    """
    Wrap ``model`` with an IA3 PEFT adapter.

    An ``IA3Config`` is always built from the ``peft_*`` settings on ``cfg``.
    When ``cfg.peft_model_dir`` is set the adapter is loaded from that
    location (trainable unless ``inference`` is true); otherwise a new adapter
    is created from the freshly built config.

    :param model: the base model to attach the adapter to
    :param cfg: config object providing the ``peft_*`` attributes read below
    :param inference: when true, a loaded adapter is not marked trainable
    :return: tuple of the PEFT-wrapped model and the ``IA3Config`` built here
    """
    from peft import IA3Config, PeftModel, get_peft_model

    # Only forward the optional settings that were explicitly configured;
    # anything left as None is not passed to IA3Config at all.
    optional_settings = {}
    for config_key, cfg_value in (
        ("init_ia3_weights", cfg.peft_init_ia3_weights),
        ("fan_in_fan_out", cfg.peft_fan_in_fan_out),
    ):
        if cfg_value is not None:
            optional_settings[config_key] = cfg_value

    peft_config = IA3Config(
        target_modules=cfg.peft_target_modules,
        feedforward_modules=cfg.peft_feedforward_modules,
        modules_to_save=cfg.peft_modules_to_save,
        task_type="CAUSAL_LM",
        **optional_settings,
    )

    if not cfg.peft_model_dir:
        # no saved adapter to resume from: create a new one from the config
        model = get_peft_model(model, peft_config)
    else:
        # NOTE: peft_config is not passed to from_pretrained; it is only
        # returned to the caller alongside the loaded model.
        LOG.debug("Loading pretained PEFT - IA3")
        model = PeftModel.from_pretrained(
            model,
            cfg.peft_model_dir,
            is_trainable=(not inference),
        )

    model.print_trainable_parameters()

    return model, peft_config

View File

@@ -423,7 +423,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
) )
# Phi doesn't want the attention_mask feature when training # Phi doesn't want the attention_mask feature when training
if "CodeGenTokenizer" in tokenizer.__class__.__name__: if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
cfg.is_mistral_derived_model and cfg.flash_attention
):
train_dataset = train_dataset.remove_columns("attention_mask") train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask") eval_dataset = eval_dataset.remove_columns("attention_mask")

View File

@@ -24,6 +24,10 @@ class TestLoraLlama(unittest.TestCase):
""" """
def test_lora(self): def test_lora(self):
"""
support for legacy lora_ configs
:return:
"""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp() output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
@@ -66,6 +70,101 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists() assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_peft(self):
    """
    train a lora adapter configured through the peft_ prefixed options
    :return:
    """
    # pylint: disable=duplicate-code
    workdir = tempfile.mkdtemp()
    settings = {
        "base_model": "JackFram/llama-68m",
        "base_model_config": "JackFram/llama-68m",
        "tokenizer_type": "LlamaTokenizer",
        "sequence_len": 1024,
        "load_in_8bit": True,
        "adapter": "lora",
        "peft_r": 32,
        "peft_alpha": 64,
        "peft_dropout": 0.05,
        "peft_target_linear": True,
        "val_set_size": 0.1,
        "special_tokens": {
            "unk_token": "<unk>",
            "bos_token": "<s>",
            "eos_token": "</s>",
        },
        "datasets": [
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
            },
        ],
        "num_epochs": 2,
        "micro_batch_size": 8,
        "gradient_accumulation_steps": 1,
        "output_dir": workdir,
        "learning_rate": 0.00001,
        "optimizer": "adamw_torch",
        "lr_scheduler": "cosine",
    }
    cfg = DictDefault(settings)
    normalize_config(cfg)

    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

    # the run is considered successful once the adapter weights were written
    adapter_weights = Path(workdir) / "adapter_model.bin"
    assert adapter_weights.exists()
def test_ia3_peft(self):
    """
    train an ia3 adapter configured through the peft_ prefixed options
    :return:
    """
    # pylint: disable=duplicate-code
    workdir = tempfile.mkdtemp()
    settings = {
        "base_model": "JackFram/llama-68m",
        "base_model_config": "JackFram/llama-68m",
        "tokenizer_type": "LlamaTokenizer",
        "sequence_len": 1024,
        "load_in_8bit": True,
        "adapter": "ia3",
        "peft_r": 32,
        "peft_alpha": 64,
        "peft_dropout": 0.05,
        "peft_target_modules": ["k_proj", "v_proj", "down_proj"],
        "peft_feedforward_modules": ["down_proj"],
        "val_set_size": 0.1,
        "special_tokens": {
            "unk_token": "<unk>",
            "bos_token": "<s>",
            "eos_token": "</s>",
        },
        "datasets": [
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
            },
        ],
        "num_epochs": 2,
        "micro_batch_size": 8,
        "gradient_accumulation_steps": 1,
        "output_dir": workdir,
        "learning_rate": 0.00001,
        "optimizer": "adamw_torch",
        "lr_scheduler": "cosine",
    }
    cfg = DictDefault(settings)
    normalize_config(cfg)

    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

    # the run is considered successful once the adapter weights were written
    adapter_weights = Path(workdir) / "adapter_model.bin"
    assert adapter_weights.exists()
def test_lora_packing(self): def test_lora_packing(self):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp() output_dir = tempfile.mkdtemp()

View File

@@ -0,0 +1,48 @@
"""Module for testing the config normalization module"""
import logging
import unittest
from typing import Optional
import pytest
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
class NormalizationTest(unittest.TestCase):
    """
    Tests for the cfg handling done by normalize_config
    """

    _caplog: Optional[pytest.LogCaptureFixture] = None

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        # keep a handle on pytest's caplog fixture so that the test methods
        # can inspect the captured log records
        self._caplog = caplog

    def test_lora_to_peft(self):
        """
        legacy lora_ options trigger a deprecation notice in the logs and
        their values are available under the matching peft_ names
        """
        cfg = DictDefault(
            {
                "gradient_accumulation_steps": 1,
                "micro_batch_size": 1,
                "base_model": "NousResearch/Llama-2-7b-hf",
                "base_model_config": "NousResearch/Llama-2-7b-hf",
            }
        ) | DictDefault(
            {
                "adapter": "lora",
                "lora_r": 128,
                "lora_alpha": 64,
            }
        )

        with self._caplog.at_level(logging.WARNING):
            normalize_config(cfg)
            logged_messages = [record.message for record in self._caplog.records]
            assert any(
                "soon to be deprecated. please use peft_" in message
                for message in logged_messages
            )

        assert cfg.peft_r == 128
        assert cfg.peft_alpha == 64