Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
080612219b | ||
|
|
f95858d369 |
@@ -12,4 +12,3 @@ generated-members=numpy.*, torch.*
|
|||||||
disable=missing-function-docstring, line-too-long, import-error,
|
disable=missing-function-docstring, line-too-long, import-error,
|
||||||
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
||||||
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
||||||
too-many-boolean-expressions,
|
|
||||||
|
|||||||
54
README.md
54
README.md
@@ -96,7 +96,7 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
|||||||
|
|
||||||
# inference
|
# inference
|
||||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||||
--peft_model_dir="./lora-out"
|
--lora_model_dir="./lora-out"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
@@ -297,24 +297,25 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
|
|
||||||
#### How to add custom prompts
|
#### How to add custom prompts
|
||||||
|
|
||||||
For a dataset that is preprocessed for instruction purposes:
|
Using yaml. Example:
|
||||||
|
|
||||||
```json
|
|
||||||
{"instruction": "...", "output": "..."}
|
|
||||||
```
|
|
||||||
|
|
||||||
You can use this example in your YAML config:
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
datasets:
|
datasets:
|
||||||
- path: repo
|
- path: repo
|
||||||
type:
|
type:
|
||||||
system_prompt: ""
|
system_prompt: ""
|
||||||
field_system: system
|
no_input_format: |-
|
||||||
format: "[INST] {instruction} [/INST]"
|
User: {instruction}<|end_of_turn|>
|
||||||
no_input_format: "[INST] {instruction} [/INST]"
|
Assistant:
|
||||||
|
format: |-
|
||||||
|
User: {instruction}
|
||||||
|
{input}<|end_of_turn|>
|
||||||
|
Assistant:
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Using file:
|
||||||
|
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
||||||
|
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
|
||||||
|
|
||||||
#### How to use your custom pretokenized dataset
|
#### How to use your custom pretokenized dataset
|
||||||
|
|
||||||
- Do not pass a `type:`
|
- Do not pass a `type:`
|
||||||
@@ -384,10 +385,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
- lora
|
- lora
|
||||||
```yaml
|
```yaml
|
||||||
adapter: lora # qlora or leave blank for full finetune
|
adapter: lora # qlora or leave blank for full finetune
|
||||||
peft_r: 8
|
lora_r: 8
|
||||||
peft_alpha: 16
|
lora_alpha: 16
|
||||||
peft_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
peft_target_modules:
|
lora_target_modules:
|
||||||
- q_proj
|
- q_proj
|
||||||
- v_proj
|
- v_proj
|
||||||
```
|
```
|
||||||
@@ -531,15 +532,15 @@ total_num_tokens:
|
|||||||
adapter: lora
|
adapter: lora
|
||||||
# If you already have a lora model trained that you want to load, put that here.
|
# If you already have a lora model trained that you want to load, put that here.
|
||||||
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
|
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
# LoRA hyperparameters
|
# LoRA hyperparameters
|
||||||
# For more details about the following options, see:
|
# For more details about the following options, see:
|
||||||
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
||||||
peft_r: 8
|
lora_r: 8
|
||||||
peft_alpha: 16
|
lora_alpha: 16
|
||||||
peft_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
peft_target_modules:
|
lora_target_modules:
|
||||||
- q_proj
|
- q_proj
|
||||||
- v_proj
|
- v_proj
|
||||||
# - k_proj
|
# - k_proj
|
||||||
@@ -547,13 +548,13 @@ peft_target_modules:
|
|||||||
# - gate_proj
|
# - gate_proj
|
||||||
# - down_proj
|
# - down_proj
|
||||||
# - up_proj
|
# - up_proj
|
||||||
peft_target_linear: # if true, will target all linear layers
|
lora_target_linear: # If true, will target all linear layers
|
||||||
|
|
||||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||||
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
||||||
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
||||||
peft_modules_to_save:
|
lora_modules_to_save:
|
||||||
# - embed_tokens
|
# - embed_tokens
|
||||||
# - lm_head
|
# - lm_head
|
||||||
|
|
||||||
@@ -561,8 +562,7 @@ peft_modules_to_save:
|
|||||||
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
|
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
|
||||||
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
|
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
|
||||||
lora_out_dir:
|
lora_out_dir:
|
||||||
peft_fan_in_fan_out: false
|
lora_fan_in_fan_out: false
|
||||||
peft_feedforward_modules: # ffn modules for IA3, for llama down projection
|
|
||||||
|
|
||||||
# ReLoRA configuration
|
# ReLoRA configuration
|
||||||
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||||
@@ -870,7 +870,7 @@ Pass the appropriate flag to the train command:
|
|||||||
|
|
||||||
- Pretrained LORA:
|
- Pretrained LORA:
|
||||||
```bash
|
```bash
|
||||||
python -m axolotl.cli.inference examples/your_config.yml --peft_model_dir="./lora-output-dir"
|
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
|
||||||
```
|
```
|
||||||
- Full weights finetune:
|
- Full weights finetune:
|
||||||
```bash
|
```bash
|
||||||
@@ -891,7 +891,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
|
|||||||
Add below flag to train command above
|
Add below flag to train command above
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python3 -m axolotl.cli.merge_lora examples/your_config.yml --peft_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||||
```
|
```
|
||||||
|
|
||||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ dataset_prepared_path: last_prepared_run
|
|||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
|
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len: 2048
|
max_packed_sequence_len: 2048
|
||||||
lora_r: 16
|
lora_r: 16
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ sample_packing: true
|
|||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ sample_packing: true
|
|||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ sample_packing: true
|
|||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r: 16
|
lora_r: 16
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ dataset_prepared_path:
|
|||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
# enable QLoRA
|
# enable QLoRA
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r: 64
|
lora_r: 64
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 512
|
sequence_len: 512
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r:
|
lora_r:
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing:
|
sample_packing:
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
base_model: meta-llama/Llama-2-7b-hf
|
|
||||||
base_model_config: meta-llama/Llama-2-7b-hf
|
|
||||||
model_type: LlamaForCausalLM
|
|
||||||
tokenizer_type: LlamaTokenizer
|
|
||||||
is_llama_derived_model: true
|
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.01
|
|
||||||
output_dir: ./ia3-out
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
adapter: ia3
|
|
||||||
peft_model_dir:
|
|
||||||
peft_target_modules:
|
|
||||||
- k_proj
|
|
||||||
- v_proj
|
|
||||||
- down_proj
|
|
||||||
peft_feedforward_modules:
|
|
||||||
- down_proj
|
|
||||||
peft_fan_in_fan_out: false
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_run_id:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 5
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16: false
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
eval_steps: 0.05
|
|
||||||
eval_table_size:
|
|
||||||
eval_table_max_new_tokens:
|
|
||||||
save_steps:
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.1
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
bos_token: "<s>"
|
|
||||||
eos_token: "</s>"
|
|
||||||
unk_token: "<unk>"
|
|
||||||
@@ -20,7 +20,7 @@ sample_packing: true
|
|||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./relora-out
|
output_dir: ./relora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ sequence_len: 4096
|
|||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 1024
|
sequence_len: 1024
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
lora_r:
|
lora_r:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 1024
|
sequence_len: 1024
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 1024
|
sequence_len: 1024
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ sample_packing: true
|
|||||||
pad_to_sequence_len:
|
pad_to_sequence_len:
|
||||||
|
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r:
|
lora_r:
|
||||||
lora_alpha:
|
lora_alpha:
|
||||||
lora_dropout:
|
lora_dropout:
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ sample_packing: false # not CURRENTLY compatible with LoRAs
|
|||||||
pad_to_sequence_len:
|
pad_to_sequence_len:
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 64
|
lora_r: 64
|
||||||
lora_alpha: 32
|
lora_alpha: 32
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len: 2048
|
max_packed_sequence_len: 2048
|
||||||
lora_r: 64
|
lora_r: 64
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 512
|
sequence_len: 512
|
||||||
lora_r: 16
|
lora_r: 16
|
||||||
lora_alpha: 32
|
lora_alpha: 32
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ dataset_prepared_path:
|
|||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
# enable QLoRA
|
# enable QLoRA
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 8192
|
sequence_len: 8192
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 370 KiB |
2
setup.py
2
setup.py
@@ -46,7 +46,7 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn>=2.3.0",
|
"flash-attn>=2.2.1",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed",
|
"deepspeed",
|
||||||
|
|||||||
@@ -42,11 +42,21 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
packed: Optional[bool] = False,
|
packed: Optional[bool] = False,
|
||||||
cross_entropy: Optional[bool] = False,
|
cross_entropy: Optional[bool] = False,
|
||||||
rms_norm: Optional[bool] = False,
|
rms_norm: Optional[bool] = False,
|
||||||
|
noisy_embeddings_alpha: Optional[int] = False,
|
||||||
):
|
):
|
||||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||||
_prepare_decoder_attention_mask
|
_prepare_decoder_attention_mask
|
||||||
)
|
)
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
|
||||||
|
if noisy_embeddings_alpha:
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = partial(
|
||||||
|
llama_model_get_inputs_embeds, noisy_embeddings_alpha=noisy_embeddings_alpha
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = (
|
||||||
|
llama_model_get_inputs_embeds
|
||||||
|
)
|
||||||
|
|
||||||
if packed:
|
if packed:
|
||||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
||||||
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
||||||
@@ -116,8 +126,6 @@ def flashattn_forward(
|
|||||||
attention_mask: [bsz, q_len]
|
attention_mask: [bsz, q_len]
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
original_dtype = hidden_states.dtype
|
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
if not hasattr(self, "pretraining_tp"):
|
if not hasattr(self, "pretraining_tp"):
|
||||||
@@ -153,13 +161,6 @@ def flashattn_forward(
|
|||||||
key_states = self.k_proj(hidden_states)
|
key_states = self.k_proj(hidden_states)
|
||||||
value_states = self.v_proj(hidden_states)
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
if query_states.dtype == torch.float32:
|
|
||||||
query_states = query_states.to(dtype=original_dtype)
|
|
||||||
if key_states.dtype == torch.float32:
|
|
||||||
key_states = key_states.to(dtype=original_dtype)
|
|
||||||
if value_states.dtype == torch.float32:
|
|
||||||
value_states = value_states.to(dtype=original_dtype)
|
|
||||||
|
|
||||||
query_states = query_states.view(
|
query_states = query_states.view(
|
||||||
bsz, q_len, self.num_heads, self.head_dim
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
).transpose(1, 2)
|
).transpose(1, 2)
|
||||||
@@ -318,10 +319,6 @@ def flashattn_forward(
|
|||||||
else:
|
else:
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
# handle conversion back for IA3
|
|
||||||
if attn_output.dtype == torch.float32:
|
|
||||||
attn_output = attn_output.to(dtype=original_dtype)
|
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
@@ -424,6 +421,28 @@ def generate_qkv(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def llama_model_get_inputs_embeds(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
noisy_embeddings_alpha: Optional[int] = None,
|
||||||
|
):
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if noisy_embeddings_alpha:
|
||||||
|
input_mask = attention_mask.to(inputs_embeds) # B x L
|
||||||
|
input_lengths = torch.sum(input_mask, 1) # B
|
||||||
|
|
||||||
|
noise_ = torch.zeros_like(inputs_embeds).uniform_(-1, 1)
|
||||||
|
delta = noise_ * input_mask.unsqueeze(2)
|
||||||
|
dims = input_lengths * inputs_embeds.size(-1)
|
||||||
|
mag = noisy_embeddings_alpha / torch.sqrt(dims)
|
||||||
|
delta = (delta * mag.view(-1, 1, 1)).detach()
|
||||||
|
inputs_embeds += delta
|
||||||
|
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
|
||||||
def llama_model_forward(
|
def llama_model_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
@@ -490,7 +509,8 @@ def llama_model_forward(
|
|||||||
cu_seqlens = cu_seqlens.squeeze()
|
cu_seqlens = cu_seqlens.squeeze()
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.get_inputs_embeds(input_ids, attention_mask)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(
|
attention_mask = torch.ones(
|
||||||
@@ -515,7 +535,6 @@ def llama_model_forward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
original_dtype = hidden_states.dtype
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
@@ -573,10 +592,6 @@ def llama_model_forward(
|
|||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
# handle conversion back for IA3
|
|
||||||
if hidden_states.dtype == torch.float32:
|
|
||||||
hidden_states = hidden_states.to(dtype=original_dtype)
|
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||||
|
|
||||||
|
|||||||
@@ -14,9 +14,6 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
|
|||||||
flash_attn_varlen_qkvpacked_func,
|
flash_attn_varlen_qkvpacked_func,
|
||||||
)
|
)
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
|
||||||
MistralAttention as OriginalMistralAttention,
|
|
||||||
)
|
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||||
)
|
)
|
||||||
@@ -45,44 +42,6 @@ def replace_mistral_attn_with_flash_attn(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def _make_sliding_window_causal_mask(
|
|
||||||
bsz: int,
|
|
||||||
tgt_len: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: torch.device,
|
|
||||||
past_key_values_length: int = 0,
|
|
||||||
sliding_window: int = 4096,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Make causal mask used for sliding window attention
|
|
||||||
"""
|
|
||||||
tensor = torch.full(
|
|
||||||
(tgt_len, tgt_len),
|
|
||||||
fill_value=1,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
mask = torch.tril(tensor, diagonal=0)
|
|
||||||
# make the mask banded to account for sliding window
|
|
||||||
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
|
||||||
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
|
||||||
mask = torch.log(mask).to(dtype)
|
|
||||||
|
|
||||||
if past_key_values_length > 0:
|
|
||||||
mask = torch.cat(
|
|
||||||
[
|
|
||||||
torch.zeros(
|
|
||||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
|
||||||
),
|
|
||||||
mask,
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
return mask[None, None, :, :].expand(
|
|
||||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
# requires the attention mask to be the same as the key_padding_mask
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
def _prepare_decoder_attention_mask(
|
def _prepare_decoder_attention_mask(
|
||||||
@@ -94,29 +53,11 @@ def _prepare_decoder_attention_mask(
|
|||||||
sliding_window,
|
sliding_window,
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
# [bsz, seq_len]
|
# [bsz, seq_len]
|
||||||
if attention_mask is None:
|
|
||||||
return attention_mask
|
|
||||||
|
|
||||||
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
|
||||||
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
|
|
||||||
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
|
||||||
sliding_window_mask = _make_sliding_window_causal_mask(
|
|
||||||
bsz=input_shape[0],
|
|
||||||
tgt_len=input_shape[1],
|
|
||||||
dtype=inputs_embeds.dtype,
|
|
||||||
device=inputs_embeds.device,
|
|
||||||
past_key_values_length=past_key_values_length,
|
|
||||||
sliding_window=sliding_window,
|
|
||||||
)
|
|
||||||
attention_mask = attention_mask + sliding_window_mask
|
|
||||||
else:
|
|
||||||
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
|
||||||
|
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
def flashattn_forward(
|
def flashattn_forward(
|
||||||
self: OriginalMistralAttention,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
@@ -150,41 +91,10 @@ def flashattn_forward(
|
|||||||
query_states, key_states, cos, sin, position_ids
|
query_states, key_states, cos, sin, position_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
use_sliding_windows = (
|
|
||||||
hasattr(self.config, "sliding_window") is not None
|
|
||||||
and kv_seq_len > self.config.sliding_window
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_sliding_windows:
|
|
||||||
window_size = (self.config.sliding_window, self.config.sliding_window)
|
|
||||||
else:
|
|
||||||
window_size = (-1, -1)
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
# reuse k, v, self_attention
|
||||||
if (
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
hasattr(self.config, "sliding_window")
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
and kv_seq_len > self.config.sliding_window
|
|
||||||
):
|
|
||||||
slicing_tokens = kv_seq_len - self.config.sliding_window
|
|
||||||
|
|
||||||
past_key = past_key_value[0]
|
|
||||||
past_value = past_key_value[1]
|
|
||||||
|
|
||||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
|
||||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
|
||||||
|
|
||||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
|
||||||
f" {past_key.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
past_key_value = (past_key, past_value) if use_cache else None
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
@@ -210,13 +120,7 @@ def flashattn_forward(
|
|||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
output = flash_attn_varlen_qkvpacked_func(
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv,
|
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||||
cu_seqlens,
|
|
||||||
max_seqlen,
|
|
||||||
0.0,
|
|
||||||
softmax_scale=None,
|
|
||||||
causal=True,
|
|
||||||
window_size=window_size,
|
|
||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
elif query_states.shape == key_states.shape:
|
elif query_states.shape == key_states.shape:
|
||||||
@@ -242,7 +146,6 @@ def flashattn_forward(
|
|||||||
0.0,
|
0.0,
|
||||||
softmax_scale=None,
|
softmax_scale=None,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
window_size=window_size,
|
|
||||||
)
|
)
|
||||||
output = output_pad_fn(output_unpad)
|
output = output_pad_fn(output_unpad)
|
||||||
else:
|
else:
|
||||||
@@ -254,7 +157,6 @@ def flashattn_forward(
|
|||||||
query_states,
|
query_states,
|
||||||
torch.stack([key_states, value_states], 2),
|
torch.stack([key_states, value_states], 2),
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
window_size=window_size,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
( # pylint: disable=unbalanced-tuple-unpacking
|
( # pylint: disable=unbalanced-tuple-unpacking
|
||||||
@@ -289,7 +191,6 @@ def flashattn_forward(
|
|||||||
0.0,
|
0.0,
|
||||||
softmax_scale=None,
|
softmax_scale=None,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
window_size=window_size,
|
|
||||||
)
|
)
|
||||||
output = output_pad_fn(output_unpad)
|
output = output_pad_fn(output_unpad)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Module for Alpaca prompt strategy classes"""
|
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
AlpacaPromptTokenizingStrategy,
|
AlpacaPromptTokenizingStrategy,
|
||||||
@@ -9,13 +9,9 @@ from axolotl.prompt_tokenizers import (
|
|||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def load(tokenizer, cfg):
|
||||||
prompt_style = PromptStyle.CHAT.value
|
|
||||||
if ds_cfg and "conversation" in ds_cfg:
|
|
||||||
prompt_style = ds_cfg["conversation"]
|
|
||||||
|
|
||||||
return AlpacaPromptTokenizingStrategy(
|
return AlpacaPromptTokenizingStrategy(
|
||||||
AlpacaPrompter(prompt_style),
|
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
|
|||||||
@@ -121,18 +121,6 @@ def normalize_config(cfg):
|
|||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
if cfg.adapter is not None:
|
|
||||||
for key in list(cfg.keys()):
|
|
||||||
if key.startswith("lora_"):
|
|
||||||
new_key = key.replace("lora_", "peft_")
|
|
||||||
LOG.warning(
|
|
||||||
PendingDeprecationWarning(
|
|
||||||
f"{key} soon to be deprecated. please use {new_key}"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cfg[new_key] = cfg[key]
|
|
||||||
del cfg[key]
|
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
if is_torch_bf16_gpu_available():
|
if is_torch_bf16_gpu_available():
|
||||||
@@ -202,10 +190,7 @@ def validate_config(cfg):
|
|||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
|
|
||||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||||
LOG.warning("We recommend setting `load_in_8bit: true` for LoRA finetuning")
|
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||||
|
|
||||||
if not cfg.load_in_8bit and cfg.adapter == "ia3":
|
|
||||||
LOG.warning("We recommend setting `load_in_8bit: true` for IA3 finetuning")
|
|
||||||
|
|
||||||
if cfg.relora_steps:
|
if cfg.relora_steps:
|
||||||
if cfg.adapter not in ("lora", "qlora"):
|
if cfg.adapter not in ("lora", "qlora"):
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except (FileNotFoundError, ConnectionError):
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# prefer local dataset, even if hub exists
|
# prefer local dataset, even if hub exists
|
||||||
|
|||||||
@@ -136,7 +136,11 @@ def load_model(
|
|||||||
|
|
||||||
replace_stablelm_attn_with_flash_attn(cfg.base_model)
|
replace_stablelm_attn_with_flash_attn(cfg.base_model)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
|
if (
|
||||||
|
cfg.is_llama_derived_model
|
||||||
|
and cfg.flash_attention
|
||||||
|
and (cfg.noisy_embeddings_alpha or cfg.sample_packing)
|
||||||
|
):
|
||||||
if cfg.device not in ["mps", "cpu"] and not inference:
|
if cfg.device not in ["mps", "cpu"] and not inference:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
replace_llama_attn_with_flash_attn,
|
replace_llama_attn_with_flash_attn,
|
||||||
@@ -147,6 +151,7 @@ def load_model(
|
|||||||
packed=cfg.sample_packing,
|
packed=cfg.sample_packing,
|
||||||
cross_entropy=cfg.flash_attn_cross_entropy,
|
cross_entropy=cfg.flash_attn_cross_entropy,
|
||||||
rms_norm=cfg.flash_attn_rms_norm,
|
rms_norm=cfg.flash_attn_rms_norm,
|
||||||
|
noisy_embeddings_alpha=cfg.noisy_embeddings_alpha,
|
||||||
)
|
)
|
||||||
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
@@ -180,16 +185,16 @@ def load_model(
|
|||||||
LOG.info("patching with flash attention")
|
LOG.info("patching with flash attention")
|
||||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
|
# if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
|
||||||
from axolotl.monkeypatch.llama_embeddings_hijack import (
|
# from axolotl.monkeypatch.llama_embeddings_hijack import (
|
||||||
replace_llama_embeddings_with_uniform_distribution,
|
# replace_llama_embeddings_with_uniform_distribution,
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
LOG.info("patching with noisy embeddings")
|
# LOG.info("patching with noisy embeddings")
|
||||||
replace_llama_embeddings_with_uniform_distribution(
|
# replace_llama_embeddings_with_uniform_distribution(
|
||||||
noise_alpha=cfg.noisy_embedding_alpha
|
# noise_alpha=cfg.noisy_embedding_alpha
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
|
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
|
||||||
from axolotl.monkeypatch.mistral_embeddings_hijack import (
|
from axolotl.monkeypatch.mistral_embeddings_hijack import (
|
||||||
replace_mistral_embeddings_with_uniform_distribution,
|
replace_mistral_embeddings_with_uniform_distribution,
|
||||||
@@ -406,21 +411,21 @@ def load_model(
|
|||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
module.to(torch.float32)
|
module.to(torch.float32)
|
||||||
|
|
||||||
require_peft: bool = False
|
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||||
if cfg.adapter in ["lora", "qlora", "ia3"]:
|
if (cfg.adapter == "lora" and load_in_8bit) or (
|
||||||
require_peft = True
|
cfg.adapter == "qlora" and cfg.load_in_4bit
|
||||||
|
):
|
||||||
if require_peft:
|
|
||||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
model = prepare_model_for_kbit_training(
|
model = prepare_model_for_kbit_training(
|
||||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||||
)
|
)
|
||||||
|
needs_fa2_dtype = True
|
||||||
|
|
||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||||
if require_peft or cfg.fsdp or (cfg.flash_attention and cfg.is_llama_derived_model):
|
if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
|
||||||
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
|
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if "norm" in name:
|
if "norm" in name:
|
||||||
@@ -429,7 +434,7 @@ def load_model(
|
|||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
module.to(cfg.torch_dtype)
|
module.to(cfg.torch_dtype)
|
||||||
|
|
||||||
model, peft_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|
||||||
if cfg.ddp and not load_in_8bit:
|
if cfg.ddp and not load_in_8bit:
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
@@ -460,7 +465,7 @@ def load_model(
|
|||||||
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return model, peft_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|
||||||
def load_adapter(model, cfg, adapter, inference=False):
|
def load_adapter(model, cfg, adapter, inference=False):
|
||||||
@@ -470,8 +475,6 @@ def load_adapter(model, cfg, adapter, inference=False):
|
|||||||
return model, None
|
return model, None
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
model.enable_input_require_grads()
|
model.enable_input_require_grads()
|
||||||
if adapter == "ia3":
|
|
||||||
return load_ia3(model, cfg, inference=inference)
|
|
||||||
if adapter in ["lora", "qlora"]:
|
if adapter in ["lora", "qlora"]:
|
||||||
return load_lora(model, cfg, inference=inference)
|
return load_lora(model, cfg, inference=inference)
|
||||||
if adapter == "llama-adapter":
|
if adapter == "llama-adapter":
|
||||||
@@ -490,11 +493,11 @@ def load_llama_adapter(model, cfg):
|
|||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.peft_model_dir:
|
if cfg.lora_model_dir:
|
||||||
LOG.debug("Loading pretained PEFT - llama_adapter")
|
LOG.debug("Loading pretained PEFT - llama_adapter")
|
||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
cfg.peft_model_dir,
|
cfg.lora_model_dir,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -507,20 +510,16 @@ def load_llama_adapter(model, cfg):
|
|||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
||||||
peft_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if (
|
if isinstance(module, cls) or "Linear" in module.__class__.__name__:
|
||||||
isinstance(module, cls)
|
|
||||||
or "Linear" in module.__class__.__name__
|
|
||||||
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
|
|
||||||
):
|
|
||||||
names = name.split(".")
|
names = name.split(".")
|
||||||
peft_module_names.add(names[0] if len(names) == 1 else names[-1])
|
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||||
|
|
||||||
if "lm_head" in peft_module_names: # needed for 16-bit
|
if "lm_head" in lora_module_names: # needed for 16-bit
|
||||||
peft_module_names.remove("lm_head")
|
lora_module_names.remove("lm_head")
|
||||||
|
|
||||||
return list(peft_module_names)
|
return list(lora_module_names)
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model, cfg, inference=False):
|
def load_lora(model, cfg, inference=False):
|
||||||
@@ -528,68 +527,34 @@ def load_lora(model, cfg, inference=False):
|
|||||||
|
|
||||||
from peft import LoraConfig, PeftModel, get_peft_model
|
from peft import LoraConfig, PeftModel, get_peft_model
|
||||||
|
|
||||||
peft_target_modules = list(cfg.peft_target_modules or [])
|
lora_target_modules = list(cfg.lora_target_modules or [])
|
||||||
|
|
||||||
if cfg.peft_target_linear:
|
if cfg.lora_target_linear:
|
||||||
linear_names = find_all_linear_names(model)
|
linear_names = find_all_linear_names(model)
|
||||||
LOG.info(f"found linear modules: {repr(linear_names)}")
|
LOG.info(f"found linear modules: {repr(linear_names)}")
|
||||||
peft_target_modules = list(set(peft_target_modules + linear_names))
|
lora_target_modules = list(set(lora_target_modules + linear_names))
|
||||||
|
|
||||||
peft_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
r=cfg.peft_r,
|
r=cfg.lora_r,
|
||||||
lora_alpha=cfg.peft_alpha,
|
lora_alpha=cfg.lora_alpha,
|
||||||
target_modules=peft_target_modules,
|
target_modules=lora_target_modules,
|
||||||
lora_dropout=cfg.peft_dropout,
|
lora_dropout=cfg.lora_dropout,
|
||||||
fan_in_fan_out=cfg.peft_fan_in_fan_out,
|
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
||||||
modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None,
|
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
||||||
bias="none",
|
bias="none",
|
||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.peft_model_dir:
|
if cfg.lora_model_dir:
|
||||||
LOG.debug("Loading pretained PEFT - LoRA")
|
LOG.debug("Loading pretained PEFT - LoRA")
|
||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
cfg.peft_model_dir,
|
cfg.lora_model_dir,
|
||||||
is_trainable=(not inference),
|
is_trainable=(not inference),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = get_peft_model(model, peft_config)
|
model = get_peft_model(model, lora_config)
|
||||||
|
|
||||||
model.print_trainable_parameters()
|
model.print_trainable_parameters()
|
||||||
|
|
||||||
return model, peft_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|
||||||
def load_ia3(model, cfg, inference=False):
|
|
||||||
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
|
||||||
|
|
||||||
from peft import IA3Config, PeftModel, get_peft_model
|
|
||||||
|
|
||||||
peft_config_kwargs = {}
|
|
||||||
if cfg.peft_init_ia3_weights is not None:
|
|
||||||
peft_config_kwargs["init_ia3_weights"] = cfg.peft_init_ia3_weights
|
|
||||||
if cfg.peft_fan_in_fan_out is not None:
|
|
||||||
peft_config_kwargs["fan_in_fan_out"] = cfg.peft_fan_in_fan_out
|
|
||||||
|
|
||||||
peft_config = IA3Config(
|
|
||||||
target_modules=cfg.peft_target_modules,
|
|
||||||
feedforward_modules=cfg.peft_feedforward_modules,
|
|
||||||
modules_to_save=cfg.peft_modules_to_save,
|
|
||||||
task_type="CAUSAL_LM",
|
|
||||||
**peft_config_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.peft_model_dir:
|
|
||||||
LOG.debug("Loading pretained PEFT - IA3")
|
|
||||||
model = PeftModel.from_pretrained(
|
|
||||||
model,
|
|
||||||
cfg.peft_model_dir,
|
|
||||||
is_trainable=(not inference),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model = get_peft_model(model, peft_config)
|
|
||||||
|
|
||||||
model.print_trainable_parameters()
|
|
||||||
|
|
||||||
return model, peft_config
|
|
||||||
|
|||||||
@@ -423,9 +423,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Phi doesn't want the attention_mask feature when training
|
# Phi doesn't want the attention_mask feature when training
|
||||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
|
if "CodeGenTokenizer" in tokenizer.__class__.__name__:
|
||||||
cfg.is_mistral_derived_model and cfg.flash_attention
|
|
||||||
):
|
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||||
|
|||||||
@@ -24,10 +24,6 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def test_lora(self):
|
def test_lora(self):
|
||||||
"""
|
|
||||||
support for legacy lora_ configs
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
output_dir = tempfile.mkdtemp()
|
output_dir = tempfile.mkdtemp()
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -70,101 +66,6 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
def test_lora_peft(self):
|
|
||||||
"""
|
|
||||||
support for legacy lora_ configs
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
output_dir = tempfile.mkdtemp()
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"base_model_config": "JackFram/llama-68m",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"adapter": "lora",
|
|
||||||
"peft_r": 32,
|
|
||||||
"peft_alpha": 64,
|
|
||||||
"peft_dropout": 0.05,
|
|
||||||
"peft_target_linear": True,
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"special_tokens": {
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 2,
|
|
||||||
"micro_batch_size": 8,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": output_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
|
||||||
|
|
||||||
def test_ia3_peft(self):
|
|
||||||
"""
|
|
||||||
support for IA3 peft
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
output_dir = tempfile.mkdtemp()
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"base_model_config": "JackFram/llama-68m",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"adapter": "ia3",
|
|
||||||
"peft_r": 32,
|
|
||||||
"peft_alpha": 64,
|
|
||||||
"peft_dropout": 0.05,
|
|
||||||
"peft_target_modules": ["k_proj", "v_proj", "down_proj"],
|
|
||||||
"peft_feedforward_modules": ["down_proj"],
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"special_tokens": {
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 2,
|
|
||||||
"micro_batch_size": 8,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": output_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
|
||||||
|
|
||||||
def test_lora_packing(self):
|
def test_lora_packing(self):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
output_dir = tempfile.mkdtemp()
|
output_dir = tempfile.mkdtemp()
|
||||||
|
|||||||
@@ -1,48 +0,0 @@
|
|||||||
"""Module for testing the validation module"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import unittest
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.utils.config import normalize_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
|
||||||
class NormalizationTest(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
Test the cfg normalization module
|
|
||||||
"""
|
|
||||||
|
|
||||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def inject_fixtures(self, caplog):
|
|
||||||
self._caplog = caplog
|
|
||||||
|
|
||||||
def test_lora_to_peft(self):
|
|
||||||
base_cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"base_model": "NousResearch/Llama-2-7b-hf",
|
|
||||||
"base_model_config": "NousResearch/Llama-2-7b-hf",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
cfg = base_cfg | DictDefault(
|
|
||||||
{
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 128,
|
|
||||||
"lora_alpha": 64,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
normalize_config(cfg)
|
|
||||||
assert any(
|
|
||||||
"soon to be deprecated. please use peft_" in record.message
|
|
||||||
for record in self._caplog.records
|
|
||||||
)
|
|
||||||
|
|
||||||
assert cfg.peft_r == 128
|
|
||||||
assert cfg.peft_alpha == 64
|
|
||||||
Reference in New Issue
Block a user