Compare commits

..

18 Commits

Author SHA1 Message Date
Wing Lian
b885169229 handle load_model splat 2023-10-23 21:55:05 -04:00
Wing Lian
ab9d12ce34 handle dataset loading for multimodal 2023-10-23 21:44:07 -04:00
Wing Lian
866774737b WIP llaval support 2023-10-23 20:29:49 -04:00
Wing Lian
6c81c61bc4 refactor setup trainer so we can add more hooks (#773)
* refactor setup trainer so we can add more hooks

* Remove stray comma
2023-10-23 17:38:41 -04:00
Wing Lian
9b43e7ea15 disable eval table w sample packing in examples (#778) 2023-10-23 09:18:44 -04:00
Wing Lian
2d8def68dc simplify by removing duplicate base_model_config (#772) 2023-10-23 01:42:38 -04:00
NanoCode012
44c9d0151a Fix: Warn when fullfinetune without adapter (#770) 2023-10-22 15:41:43 -04:00
Wing Lian
ca84cca2c0 convert exponential notation lr to floats (#771) 2023-10-22 15:37:03 -04:00
Casper
32eeeb5b64 Hotfix for not saving correctly (#762) 2023-10-22 13:22:32 -04:00
NanoCode012
afedc470bd Fix: Cannot tokenize with bf16 and on cpu (#766) 2023-10-23 01:32:26 +09:00
NanoCode012
9923b72649 Fix: eval table conflict with eval_sample_packing (#769) 2023-10-23 01:18:12 +09:00
Wing Lian
21cf09b608 remove lora fused packing test (#758) 2023-10-21 22:59:35 -04:00
Casper
15d3a654bf Implement fused modules (#747)
* MLP: Memory saving

* Remove RMSNorm restrictions

* Map packed weights to original

* FusedAttention module

* Simplify code

* Move fused modules

* Fix critical typo

* Split inplace

* Add FFT config

* Add validation of fused arguments

* Add fused arguments to config

* Update docs

* Fix validation logic

* Add fused modules to flash attn

* Only fuse during training

* Remove timing

* Formatting

* Formatting

* Formatting

* chore: lint

* chore: lint

* add e2e tests for fused llama

* no lora for tests

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-21 16:08:25 -04:00
Wing Lian
a21935f07a add to docs (#703) 2023-10-19 21:32:30 -04:00
NanoCode012
8966a6f566 chore: bump transformers to v4.34.1 to fix tokenizer issue (#745) 2023-10-19 20:18:22 -04:00
Motoki Wu
e4d1585c4e Fix DeepSpeed Zero 3 Saving (#709)
* Update train.py

* add zero3 check

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-19 19:18:24 -04:00
Wing Lian
70157ccb8f add a latest tag for regular axolotl image, cleanup extraneous print statement (#746) 2023-10-19 12:28:29 -04:00
seungduk.kim.2304
3a99495b05 improve: Enhance code readability of prompt_tokenizers.py (#707) 2023-10-19 08:12:17 -04:00
63 changed files with 1819 additions and 1010 deletions

View File

@@ -23,6 +23,7 @@ jobs:
python_version: "3.10"
pytorch: 2.0.1
axolotl_extras:
is_latest: true
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
@@ -54,7 +55,9 @@ jobs:
PYTORCH_VERSION=${{ matrix.pytorch }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-runpod:
needs: build-axolotl

View File

@@ -12,4 +12,3 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-boolean-expressions,

View File

@@ -96,7 +96,7 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--peft_model_dir="./lora-out"
--lora_model_dir="./lora-out"
```
## Installation
@@ -384,10 +384,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- lora
```yaml
adapter: lora # qlora or leave blank for full finetune
peft_r: 8
peft_alpha: 16
peft_dropout: 0.05
peft_target_modules:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
```
@@ -531,15 +531,15 @@ total_num_tokens:
adapter: lora
# If you already have a lora model trained that you want to load, put that here.
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
peft_model_dir:
lora_model_dir:
# LoRA hyperparameters
# For more details about the following options, see:
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
peft_r: 8
peft_alpha: 16
peft_dropout: 0.05
peft_target_modules:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
# - k_proj
@@ -547,13 +547,13 @@ peft_target_modules:
# - gate_proj
# - down_proj
# - up_proj
peft_target_linear: # if true, will target all linear layers
lora_target_linear: # If true, will target all linear layers
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
peft_modules_to_save:
lora_modules_to_save:
# - embed_tokens
# - lm_head
@@ -561,8 +561,7 @@ peft_modules_to_save:
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
lora_out_dir:
peft_fan_in_fan_out: false
peft_feedforward_modules: # ffn modules for IA3, for llama down projection
lora_fan_in_fan_out: false
# ReLoRA configuration
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
@@ -685,6 +684,8 @@ xformers_attention:
flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
@@ -818,7 +819,7 @@ accelerate launch -m axolotl.cli.train your_config.yml
You can optionally pre-tokenize dataset with the following before finetuning:
```bash
CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
```
##### Config
@@ -870,7 +871,7 @@ Pass the appropriate flag to the train command:
- Pretrained LORA:
```bash
python -m axolotl.cli.inference examples/your_config.yml --peft_model_dir="./lora-output-dir"
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
```
- Full weights finetune:
```bash
@@ -891,7 +892,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
Add below flag to train command above
```bash
python3 -m axolotl.cli.merge_lora examples/your_config.yml --peft_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
```
If you run out of CUDA memory, you can try to merge in system RAM with
@@ -902,6 +903,8 @@ CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
## Common Errors 🧰
See also the [FAQ's](./docs/faq.md).
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
Please reduce any below

14
docs/faq.md Normal file
View File

@@ -0,0 +1,14 @@
# Axolotl FAQ's
> The trainer stopped and hasn't progressed in several minutes.
Usually an issue with the GPU's communicating with each other. See the [NCCL doc](../docs/nccl.md)
> Exitcode -9
This usually happens when you run out of system RAM.
> Exitcode -7 while using deepspeed
Try upgrading deepspeed w: `pip install -U deepspeed`

View File

@@ -1,5 +1,4 @@
base_model: cerebras/btlm-3b-8k-base
base_model_config: cerebras/btlm-3b-8k-base
model_type: AutoModelForCausalLM
tokenizer_type: GPT2Tokenizer
trust_remote_code: true
@@ -18,7 +17,7 @@ dataset_prepared_path: last_prepared_run
val_set_size: 0.01
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
sample_packing: false

View File

@@ -1,5 +1,4 @@
base_model: cerebras/Cerebras-GPT-1.3B
base_model_config: cerebras/Cerebras-GPT-1.3B
load_in_8bit: false
load_in_4bit: true
strict: false
@@ -10,7 +9,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 16

View File

@@ -1,5 +1,4 @@
base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
@@ -20,7 +19,7 @@ sample_packing: true
pad_to_sequence_len: true
adapter: lora
peft_model_dir:
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -1,5 +1,4 @@
base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
@@ -16,7 +15,7 @@ val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -1,5 +1,4 @@
base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
@@ -20,7 +19,7 @@ sample_packing: true
pad_to_sequence_len: true
adapter: lora
peft_model_dir:
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -1,5 +1,4 @@
base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
@@ -16,7 +15,7 @@ val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -1,5 +1,4 @@
base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
@@ -20,7 +19,7 @@ sample_packing: true
pad_to_sequence_len: true
adapter: lora
peft_model_dir:
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -1,5 +1,4 @@
base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
@@ -16,7 +15,7 @@ val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -1,5 +1,4 @@
base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
@@ -15,7 +14,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: lora
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 16

View File

@@ -1,7 +1,6 @@
# 1b: tiiuae/falcon-rw-1b
# 40b: tiiuae/falcon-40b
base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
trust_remote_code: true
model_type: AutoModelForCausalLM
@@ -22,7 +21,7 @@ dataset_prepared_path:
val_set_size: 0.01
# enable QLoRA
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:

View File

@@ -1,5 +1,4 @@
base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
@@ -15,7 +14,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 64

View File

@@ -1,5 +1,4 @@
base_model: EleutherAI/gpt-j-6b
base_model_config: EleutherAI/gpt-j-6b
load_in_8bit: false
load_in_4bit: true
strict: false
@@ -10,7 +9,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8

View File

@@ -1,5 +1,4 @@
base_model: huggyllama/llama-7b
base_model_config: huggyllama/llama-7b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
@@ -9,7 +8,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 512
max_packed_sequence_len:
lora_r:

View File

@@ -9,12 +9,16 @@ gradient_accumulation_steps: 2
micro_batch_size: 1
```shell
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
```
or
```shell
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
```
To launch a full finetuning with 16-bit precision:
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
```

View File

@@ -1,10 +1,9 @@
base_model: meta-llama/Llama-2-7b-hf
base_model_config: meta-llama/Llama-2-7b-hf
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_8bit: false
load_in_4bit: false
strict: false
@@ -13,21 +12,19 @@ datasets:
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./ia3-out
output_dir: ./out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: ia3
peft_model_dir:
peft_target_modules:
- k_proj
- v_proj
- down_proj
peft_feedforward_modules:
- down_proj
peft_fan_in_fan_out: false
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
@@ -36,8 +33,8 @@ wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 5
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
@@ -55,14 +52,17 @@ local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 10
warmup_steps: 100
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens:
save_steps:
debug:
deepspeed:
deepspeed: #deepspeed/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:

View File

@@ -1,5 +1,4 @@
base_model: TheBloke/Llama-2-7B-GPTQ
base_model_config: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true
gptq_disable_exllama: true
@@ -18,7 +17,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: lora
peft_model_dir:
lora_model_dir:
sequence_len: 4096
sample_packing:
lora_r: 8

View File

@@ -1,5 +1,4 @@
base_model: NousResearch/Llama-2-7b-hf
base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
@@ -20,7 +19,7 @@ sample_packing: true
pad_to_sequence_len: true
adapter: lora
peft_model_dir:
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -1,5 +1,4 @@
base_model: NousResearch/Llama-2-7b-hf
base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
@@ -16,7 +15,7 @@ val_set_size: 0.01
output_dir: ./qlora-out
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -1,5 +1,4 @@
base_model: NousResearch/Llama-2-7b-hf
base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
@@ -16,7 +15,7 @@ val_set_size: 0.01
output_dir: ./relora-out
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 4096
sample_packing: true

View File

@@ -1,5 +1,4 @@
base_model: PY007/TinyLlama-1.1B-step-50K-105b
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
@@ -20,7 +19,7 @@ sequence_len: 4096
sample_packing: true
adapter: lora
peft_model_dir:
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -1,5 +1,4 @@
base_model: mistralai/Mistral-7B-v0.1
base_model_config: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
@@ -48,7 +47,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:

View File

@@ -1,5 +1,4 @@
base_model: mistralai/Mistral-7B-v0.1
base_model_config: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
@@ -65,7 +64,7 @@ flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:

View File

@@ -1,5 +1,4 @@
base_model: mosaicml/mpt-7b
base_model_config: mosaicml/mpt-7b
tokenizer_type: AutoTokenizer
trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
load_in_8bit: false
@@ -9,7 +8,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8

View File

@@ -0,0 +1,63 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
multimodal: true
vision_tower: openai/clip-vit-large-patch14
tune_mm_mlp_adapter: true
mm_vision_select_layer: -2
mm_projector_type: mlp2x_gelu
mm_image_folder: ./llava/
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: liuhaotian/LLaVA-CC3M-Pretrain-595K
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<unk>"

View File

@@ -1,5 +1,4 @@
base_model: openlm-research/open_llama_3b_v2
base_model_config: openlm-research/open_llama_3b_v2
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
@@ -12,7 +11,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 1024
sample_packing: true
lora_r:

View File

@@ -1,5 +1,4 @@
base_model: openlm-research/open_llama_3b_v2
base_model_config: openlm-research/open_llama_3b_v2
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
@@ -12,7 +11,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter: lora
peft_model_dir:
lora_model_dir:
sequence_len: 1024
sample_packing: true
lora_r: 8

View File

@@ -1,5 +1,4 @@
base_model: openlm-research/open_llama_3b_v2
base_model_config: openlm-research/open_llama_3b_v2
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
@@ -12,7 +11,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.01
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 1024
sample_packing: true
lora_r: 8

View File

@@ -1,5 +1,4 @@
base_model: microsoft/phi-1_5
base_model_config: microsoft/phi-1_5
model_type: MixFormerSequentialForCausalLM
tokenizer_type: AutoTokenizer
is_llama_derived_model: false
@@ -22,7 +21,7 @@ sample_packing: true
pad_to_sequence_len:
adapter:
peft_model_dir:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:

View File

@@ -1,5 +1,4 @@
base_model: microsoft/phi-1_5
base_model_config: microsoft/phi-1_5
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_llama_derived_model: false
@@ -22,7 +21,7 @@ sample_packing: false # not CURRENTLY compatible with LoRAs
pad_to_sequence_len:
adapter: qlora
peft_model_dir:
lora_model_dir:
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05

View File

@@ -1,5 +1,4 @@
base_model: EleutherAI/pythia-12b-deduped
base_model_config: EleutherAI/pythia-12b-deduped
base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
@@ -13,7 +12,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.05
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 64

View File

@@ -1,5 +1,4 @@
base_model: EleutherAI/pythia-1.4b-deduped
base_model_config: EleutherAI/pythia-1.4b-deduped
load_in_8bit: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
@@ -7,7 +6,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.05
adapter: lora
peft_model_dir:
lora_model_dir:
sequence_len: 512
lora_r: 16
lora_alpha: 32

View File

@@ -1,5 +1,4 @@
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code:
@@ -10,7 +9,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.02
adapter:
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8

View File

@@ -1,5 +1,4 @@
base_model: replit/replit-code-v1-3b
base_model_config: replit/replit-code-v1-3b
trust_remote_code: true
load_in_8bit: false
datasets:
@@ -8,7 +7,7 @@ datasets:
dataset_prepared_path:
val_set_size: 0.05
adapter: lora
peft_model_dir:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8

View File

@@ -1,7 +1,6 @@
# An example finetuning Saleforce's XGen-7b model with 8k context using qlora
# on Tim Dettmer's Guanaco dataset.
base_model: Salesforce/xgen-7b-8k-base
base_model_config: Salesforce/xgen-7b-8k-base
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
@@ -20,7 +19,7 @@ dataset_prepared_path:
val_set_size: 0.01
# enable QLoRA
adapter: qlora
peft_model_dir:
lora_model_dir:
sequence_len: 8192
max_packed_sequence_len:

View File

@@ -4,7 +4,7 @@ torch==2.0.1
auto-gptq
packaging
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@bd6205919aad4d3a2300a39a98a642f1cc3a5348
transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
deepspeed

View File

@@ -2,6 +2,7 @@
import importlib
import logging
import math
import os
import random
import sys
@@ -215,6 +216,45 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
return cfg
def load_mm_dataset(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
model,
):
# pylint: disable=duplicate-code
from llava.train.train import DataArguments, LazySupervisedDataset
vision_tower = model.get_vision_tower()
data_args = DataArguments(
data_path=cfg.datasets[0]["path"],
lazy_preprocess=cfg.mm_lazy_preprocess
if cfg.mm_lazy_preprocess is not None
else True,
is_multimodal=True,
image_folder=cfg.mm_image_folder or None,
image_aspect_ratio=cfg.mm_image_aspect_ratio or "square",
image_grid_pinpoints=cfg.mm_image_grid_pinpoints or None,
)
data_args.image_processor = vision_tower.image_processor
tokenizer = load_tokenizer(cfg)
train_dataset = LazySupervisedDataset(
tokenizer=tokenizer,
data_path=data_args["data_path"],
data_args=data_args,
)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=None,
total_num_steps=total_num_steps,
)
def load_datasets(
*,
cfg: DictDefault,

View File

@@ -0,0 +1,56 @@
"""
CLI to run training on a model
"""
import logging
from pathlib import Path
import fire
import transformers
from colorama import Fore
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_mm_dataset,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.train import train
from axolotl.utils.models import load_model, load_tokenizer
LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "--prepare_ds_only called without dataset_prepared_path set."
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
tokenizer = load_tokenizer(parsed_cfg)
model, _ = load_model(parsed_cfg, tokenizer)
dataset_meta = load_mm_dataset(
cfg=parsed_cfg, cli_args=parsed_cli_args, model=model
)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":
fire.Fire(do_cli)

View File

View File

@@ -0,0 +1,746 @@
"""
Builder for the training args and trainer
"""
import abc
import importlib
import logging
import math
import os
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional, Union
import torch
import transformers
from datasets import Dataset
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
pass
try:
from llava.train.llava_trainer import get_mm_adapter_state_maybe_zero_3
except ImportError:
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
raise ImportError("missing LLaVA package")
LOG = logging.getLogger("axolotl.core.trainer_builder")
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Extend the base TrainingArguments for axolotl helpers
"""
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
sample_packing_seq_len_multiplier: int = field(
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
class AxolotlTrainer(Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
def __init__(self, *args, bench_data_collator=None, **kwargs):
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return DistributedSampler(
self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if (
self.args.world_size > 1
and self.args.sample_packing
and self.args.eval_sample_packing is not False
):
return SequentialDistributedSampler(
eval_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
batch_size=self.args.per_device_eval_batch_size,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
return self.accelerator.prepare(
MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_train_dataloader()
def get_eval_dataloader(
self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
return self.accelerator.prepare(
MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.eval_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> Union[DataLoader, MultipackDistributedDataloader]:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
def _save_checkpoint(self, model, trial, metrics=None):
if getattr(self.args, "tune_mm_mlp_adapter", False):
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
# Only save Adapter
keys_to_match = ["mm_projector", "vision_resampler"]
if getattr(self.args, "use_im_start_end", False):
keys_to_match.extend(["embed_tokens", "embed_in"])
weight_to_save = get_mm_adapter_state_maybe_zero_3(
self.model.named_parameters(), keys_to_match
)
if self.args.local_rank in (0, -1):
self.model.config.save_pretrained(output_dir)
torch.save(weight_to_save, os.path.join(output_dir, "mm_projector.bin"))
else:
super()._save_checkpoint(model, trial, metrics)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
if getattr(self.args, "tune_mm_mlp_adapter", False):
pass
else:
super()._save(output_dir, state_dict)
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
pct_start=pct_start,
div_factor=6,
)
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
"""
_train_dataset = None
_eval_dataset = None
def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
@property
def train_dataset(self):
return self._train_dataset
@train_dataset.setter
def train_dataset(self, dataset):
self._train_dataset = dataset
@property
def eval_dataset(self):
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, dataset):
self._eval_dataset = dataset
@abstractmethod
def build(self, total_num_steps):
pass
@abstractmethod
def get_callbacks(self):
pass
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for Causal models
"""
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
return training_arguments_kwargs
def hook_post_create_training_args(self, training_arguments):
# TODO
return training_arguments
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
# TODO
return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer):
# TODO
return trainer
def get_callbacks(self):
callbacks = []
callbacks.append(GPUStatsCallback(self.cfg))
callbacks.append(EvalFirstStepCallback)
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
self.cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
return callbacks
def _get_trainer_cls(self):
if self.cfg.lr_scheduler == "one_cycle" and (
self.cfg.fsdp or self.cfg.adapter == "qlora"
):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
return AxolotlTrainer
def build(self, total_num_steps):
warmup_steps = (
self.cfg.warmup_steps
if self.cfg.warmup_steps is not None
else min(int(0.03 * total_num_steps), 100)
)
logging_steps = (
self.cfg.logging_steps
if self.cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_arguments_kwargs = {}
if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = self.cfg.bf16
training_arguments_kwargs["fp16"] = (
self.cfg.fp16 and not self.cfg.bf16
) or False
training_arguments_kwargs["tf32"] = self.cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if self.cfg.seed:
training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
training_arguments_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs[
"lr_quadratic_warmup"
] = self.cfg.lr_quadratic_warmup
if self.cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
if self.cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
if self.cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
if self.cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
if self.cfg.hub_model_id:
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if self.cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = self.cfg.sample_packing_eff_est
if self.cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.evaluation_strategy:
training_arguments_kwargs[
"evaluation_strategy"
] = self.cfg.evaluation_strategy
elif self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if self.cfg.save_steps:
training_arguments_kwargs["save_strategy"] = "steps"
training_arguments_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_arguments_kwargs["save_strategy"] = "epoch"
if self.cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
] = self.cfg.metric_for_best_model
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
if self.cfg.torch_compile:
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
elif torch._dynamo: # pylint: disable=protected-access
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = self.cfg.torch_compile_backend
# DDP Config
if self.cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs[
"ddp_broadcast_buffers"
] = self.cfg.ddp_broadcast_buffers
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1
)
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs[
"per_device_train_batch_size"
] = self.cfg.micro_batch_size
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs[
"eval_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
training_arguments_kwargs["save_total_limit"] = (
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
)
training_arguments_kwargs["load_best_model_at_end"] = (
(
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and self.cfg.val_set_size > 0
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
training_arguments_kwargs["ddp_find_unused_parameters"] = (
False if self.cfg.ddp else None
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_run_id if self.cfg.use_wandb else None
)
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler
if self.cfg.lr_scheduler
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine"
)
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
training_arguments_kwargs["sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs[
"sample_packing_seq_len_multiplier"
] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
)
training_args = self.hook_post_create_training_args(training_args)
trainer_kwargs = {}
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
set_model_mem_id,
)
set_model_mem_id(self.model, self.tokenizer)
LOG.info("Adding landmark attention tokens to dataset")
for dataset in [self.train_dataset, self.eval_dataset]:
dataset = dataset.map(
partial(
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
),
batched=False,
num_proc=32,
)
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
trainer_collator_kwargs = self.build_data_collator()
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
callbacks=self.get_callbacks(),
**trainer_collator_kwargs,
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
return trainer
def build_data_collator(self):
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
if self.cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
self.cfg.sequence_len / 64
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
collator_kwargs = {}
if self.cfg.multimodal:
from llava.train.train import DataCollatorForSupervisedDataset
collator_kwargs["data_collator"] = DataCollatorForSupervisedDataset(
tokenizer=self.tokenizer,
)
else:
collator_kwargs["data_collator"] = DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
if self.cfg.do_bench_eval:
collator_kwargs[
"bench_data_collator"
] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
)
return collator_kwargs

View File

View File

@@ -0,0 +1,167 @@
"""
LLaVA Mistral classes
"""
from typing import List, Optional, Tuple, Union
import torch
from llava.model.llava_arch import LlavaMetaForCausalLM, LlavaMetaModel
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (
AutoConfig,
AutoModelForCausalLM,
MistralConfig,
MistralForCausalLM,
MistralModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
class LlavaMistralConfig(MistralConfig):
"""
HF Transformers Config for Mistral w LLaVA
"""
model_type = "llava_mistral"
class LlavaMistralModel(LlavaMetaModel, MistralModel):
"""
HF Transformers Model for Mistral w LLaVA
"""
config_class = LlavaMistralConfig
def __init__(
self, config: LlavaMistralConfig
): # pylint: disable=useless-parent-delegation
super().__init__(config)
class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
"""
HF Transformers Causal Model for Mistral w LLaVA
"""
config_class = LlavaMistralConfig
def __init__(self, config: LlavaMistralConfig):
super().__init__(config)
self.model = LlavaMistralModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
(
input_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
) = self.prepare_inputs_labels_for_multimodal(
input_ids, attention_mask, past_key_values, labels, images
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model/pipeline parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"images": kwargs.get("images", None),
}
)
return model_inputs
AutoConfig.register("llava_mistral", LlavaMistralConfig)
AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)

View File

@@ -13,12 +13,18 @@ import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.models.llama.modeling_llama import (
LlamaMLP,
apply_rotary_pos_emb,
repeat_kv,
)
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
@@ -38,6 +44,28 @@ except ImportError:
LOG = logging.getLogger("axolotl")
def replace_llama_mlp_with_swiglu(model):
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
mlp = FusedMLP(
module.config, module.gate_proj, module.up_proj, module.down_proj
)
set_module_name(model, name, mlp)
def replace_llama_qkv_with_fused(model):
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
qkv = FusedAttention(
module.config,
module.q_proj,
module.k_proj,
module.v_proj,
module.o_proj,
)
set_module_name(model, name, qkv)
def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
@@ -86,6 +114,92 @@ def replace_llama_attn_with_flash_attn(
)
class FusedAttention(LlamaAttention):
"""
Fused QKV Attention layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
q: torch.nn.Linear, # pylint: disable=invalid-name
k: torch.nn.Linear, # pylint: disable=invalid-name
v: torch.nn.Linear, # pylint: disable=invalid-name
o: torch.nn.Linear, # pylint: disable=invalid-name
):
super().__init__(config)
self.config = config
self.init_device = next(iter(q.state_dict().values())).device
# define equivalent fused qkv projection
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
self.qkv_proj = torch.nn.Linear(
q.in_features, sum(self.out_features), device=self.init_device, bias=False
)
self.o_proj = o
# overwrite initialized weights with pretrained weights
self.qkv_proj.weight.data = torch.cat(
(q.weight.data, k.weight.data, v.weight.data), dim=0
)
def _post_training(self, model, name):
q_proj, k_proj, v_proj = torch.split(
self.qkv_proj.weight.data, self.out_features, dim=0
)
new_attn = LlamaAttention(self.config)
new_attn.q_proj.weight.data = q_proj
new_attn.k_proj.weight.data = k_proj
new_attn.v_proj.weight.data = v_proj
new_attn.o_proj.weight.data = self.o_proj.weight.data
set_module_name(model, name, new_attn)
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
@@ -116,8 +230,6 @@ def flashattn_forward(
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
original_dtype = hidden_states.dtype
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
@@ -149,16 +261,14 @@ def flashattn_forward(
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if query_states.dtype == torch.float32:
query_states = query_states.to(dtype=original_dtype)
if key_states.dtype == torch.float32:
key_states = key_states.to(dtype=original_dtype)
if value_states.dtype == torch.float32:
value_states = value_states.to(dtype=original_dtype)
if isinstance(self, FusedAttention):
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
self.out_features, dim=-1
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
@@ -318,10 +428,6 @@ def flashattn_forward(
else:
attn_output = self.o_proj(attn_output)
# handle conversion back for IA3
if attn_output.dtype == torch.float32:
attn_output = attn_output.to(dtype=original_dtype)
return attn_output, None, past_key_value
@@ -515,7 +621,6 @@ def llama_model_forward(
)
hidden_states = inputs_embeds
original_dtype = hidden_states.dtype
if self.gradient_checkpointing and self.training:
if use_cache:
@@ -573,10 +678,6 @@ def llama_model_forward(
hidden_states = layer_outputs[0]
# handle conversion back for IA3
if hidden_states.dtype == torch.float32:
hidden_states = hidden_states.to(dtype=original_dtype)
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

View File

@@ -101,3 +101,16 @@ def get_cu_seqlens_from_pos_ids(position_ids):
max_seq_lens.append(max_seq_len)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def set_module_name(model, name, value):
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model
child_name = name
setattr(parent, child_name, value)

View File

@@ -45,6 +45,8 @@ class PromptTokenizingStrategy(abc.ABC):
self.prompter = prompter
self.tokenizer: PreTrainedTokenizer = tokenizer
self.train_on_inputs = train_on_inputs
# sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
# TODO: Document how they are different.
self.sequence_len = sequence_len
self.max_length = sequence_len
@@ -59,34 +61,31 @@ class PromptTokenizingStrategy(abc.ABC):
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
result: BatchEncoding
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
if not prompt:
LOG.warning("Empty text requested for tokenization.")
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
else:
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.max_length,
padding=False,
return_tensors=None,
)
return empty
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.max_length,
padding=False,
return_tensors=None,
)
if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
return empty
if (
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.max_length
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if (
len(result["input_ids"]) > 0
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
@@ -122,7 +121,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_prompt["labels"] = [-100] * user_prompt_len
tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
tokenized_res_prompt = self._tokenize(
response, strip_bos_token=True, add_eos_token=True
)
@@ -270,7 +269,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
user_prompt_len = len(tokenized_user_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_full_prompt["labels"] = [
-100
IGNORE_INDEX
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
return tokenized_full_prompt
@@ -334,6 +333,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
return prompt["conversations"]
def tokenize_prompt(self, prompt):
# Initial values. We will append to these as we go through the conversation.
result, current_len = tokenize_prompt_default()
conversation: Conversation = (
self.prompter._conversation.copy() # pylint: disable=protected-access
@@ -355,62 +355,67 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
if conversation.roles[0] in part[0]:
role = (
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else part[0]
)
turn = role + part[1]
# this is still the user query, we should
if not part[1].strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif conversation.roles[1] in part[0]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
role = (
part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else part[0]
)
turn = role + part[1]
# this should be the assistant response, should end with an eos token
if not part[1].strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=True,
strip_bos_token=True,
)
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
elif part[0] == "":
turn = part[1]
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {part[0]}")
continue
if not isinstance(part, tuple):
LOG.warning(f"expected tuple, got {part}")
continue
user, assistant = conversation.roles
role, content = part
# Uses "in" because role contains extra characters
if user in role:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else role
)
turn = role + content
# this is still the user query, we should
if not content.strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant in role:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else role
)
turn = role + content
# this should be the assistant response, should end with an eos token
if not content.strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=True,
strip_bos_token=True,
)
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
elif role == "":
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
@@ -424,38 +429,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
if not prompt.strip():
LOG.warning("Empty text requested for tokenization.")
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
else:
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if (
len(result["input_ids"]) > 0
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
result["labels"] = result["input_ids"].copy()
return result
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
"""

View File

@@ -12,6 +12,7 @@ import torch
import transformers.modelcard
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
from transformers.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
@@ -19,6 +20,14 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
try:
from llava.train.train import safe_save_model_for_hf_trainer
except ImportError:
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
raise ImportError("missing LLaVA package")
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
@@ -39,10 +48,7 @@ class TrainDatasetMeta:
def train(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
dataset_meta: TrainDatasetMeta,
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
):
# load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
@@ -119,6 +125,11 @@ def train(
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# post training
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
@@ -134,6 +145,24 @@ def train(
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
elif cfg.multimodal:
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=cfg.output_dir)
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)

View File

@@ -37,7 +37,7 @@ from axolotl.utils.distributed import (
)
if TYPE_CHECKING:
from axolotl.utils.trainer import AxolotlTrainingArguments
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100

View File

@@ -79,6 +79,9 @@ def normalize_config(cfg):
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
if not cfg.base_model_config:
cfg.base_model_config = cfg.base_model
model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type
@@ -119,19 +122,10 @@ def normalize_config(cfg):
or (cfg.model_type and "mistral" in cfg.model_type.lower())
)
log_gpu_memory_usage(LOG, "baseline", cfg.device)
if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)
if cfg.adapter is not None:
for key in list(cfg.keys()):
if key.startswith("lora_"):
new_key = key.replace("lora_", "peft_")
LOG.warning(
PendingDeprecationWarning(
f"{key} soon to be deprecated. please use {new_key}"
)
)
cfg[new_key] = cfg[key]
del cfg[key]
log_gpu_memory_usage(LOG, "baseline", cfg.device)
def validate_config(cfg):
@@ -201,11 +195,14 @@ def validate_config(cfg):
if not cfg.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LoRA finetuning")
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with QLoRA")
if not cfg.load_in_8bit and cfg.adapter == "ia3":
LOG.warning("We recommend setting `load_in_8bit: true` for IA3 finetuning")
if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
raise ValueError("Fused modules are not supported with LoRA")
if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"):
@@ -220,6 +217,9 @@ def validate_config(cfg):
if cfg.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA")
if cfg.trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
@@ -354,6 +354,30 @@ def validate_config(cfg):
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)
if (
cfg.sample_packing
and cfg.eval_table_size
and cfg.eval_sample_packing is not False
):
raise ValueError(
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
)
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
if cfg.multimodal:
try:
import llava # noqa: F401 # pylint:disable=unused-import
except ImportError as exc:
LOG.warning(
"LLaVA package required for multimodal training. See docs/llava.md for more information."
)
raise exc
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -54,8 +54,19 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
def prepare_dataset(cfg, tokenizer):
if not cfg.pretraining_dataset:
def prepare_dataset(cfg, tokenizer, model=None):
if cfg.multimodal:
if not model:
raise ValueError("missing model argument")
from llava.train.train import LazySupervisedDataset
with zero_first(is_main_process()):
eval_dataset = None
train_dataset = LazySupervisedDataset(
tokenizer=tokenizer,
)
elif not cfg.pretraining_dataset:
with zero_first(is_main_process()):
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH

View File

@@ -255,7 +255,93 @@ def load_model(
model_kwargs["use_flash_attention_2"] = True
try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
if cfg.multimodal:
from llava.train.train import DataArguments, ModelArguments
if cfg.is_llama_derived_model:
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM
model = LlavaLlamaForCausalLM.from_pretrained(
cfg.base_model,
)
elif cfg.is_mistral_derived_model:
from axolotl.models.llava.llava_mistral import LlavaMistralForCausalLM
model = LlavaMistralForCausalLM.from_pretrained(
cfg.base_model,
)
else:
raise NotImplementedError(
"unhandled model architecture for multimodal training"
)
if cfg.mm_freeze_backbone:
model.model.requires_grad_(False)
def make_inputs_require_grad(
module, input, output
): # pylint: disable=redefined-builtin,unused-argument
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model_args = ModelArguments(
model_name_or_path=cfg.base_model,
version="v0",
freeze_backbone=cfg.mm_freeze_backbone or False,
tune_mm_mlp_adapter=cfg.tune_mm_mlp_adapter or False,
vision_tower=cfg.mm_vision_tower,
mm_vision_select_layer=cfg.mm_vision_select_layer or -1,
pretrain_mm_mlp_adapter=cfg.pretrain_mm_mlp_adapter,
mm_projector_type=cfg.mm_projector_type or "linear",
mm_use_im_start_end=cfg.mm_use_im_start_end or False,
mm_use_im_patch_token=cfg.mm_use_im_patch_token or True,
mm_vision_select_feature=cfg.mm_vision_select_feature or "patch",
)
if cfg.mm_vision_tower:
model.get_model().initialize_vision_modules(
model_args=model_args, fsdp=cfg.fsdp
)
vision_tower = model.get_vision_tower()
vision_tower.to(dtype=cfg.torch_dtype)
# pylint: disable=duplicate-code
data_args = DataArguments(
data_path=cfg.datasets[0]["path"],
lazy_preprocess=cfg.mm_lazy_preprocess
if cfg.mm_lazy_preprocess is not None
else True,
is_multimodal=True,
image_folder=cfg.mm_image_folder or None,
image_aspect_ratio=cfg.mm_image_aspect_ratio or "square",
image_grid_pinpoints=cfg.mm_image_grid_pinpoints or None,
)
data_args.image_processor = vision_tower.image_processor
model.config.image_aspect_ratio = data_args.image_aspect_ratio
model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
model.config.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
if model_args.tune_mm_mlp_adapter:
model.requires_grad_(False)
for (
p # pylint: disable=invalid-name
) in model.get_model().mm_projector.parameters():
p.requires_grad = True
model.config.freeze_mm_mlp_adapter = cfg.freeze_mm_mlp_adapter
if cfg.freeze_mm_mlp_adapter:
for (
p # pylint: disable=invalid-name
) in model.get_model().mm_projector.parameters():
p.requires_grad = False
model.config.mm_use_im_start_end = (
data_args.mm_use_im_start_end
) = model_args.mm_use_im_start_end
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
elif cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM
config_kwargs = {}
@@ -272,6 +358,20 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)
if cfg.flash_attention and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused,
)
if cfg.flash_attn_fuse_mlp:
LOG.info("patching with SwiGLU")
replace_llama_mlp_with_swiglu(model)
if cfg.flash_attn_fuse_qkv:
LOG.info("patching with fused QKV")
replace_llama_qkv_with_fused(model)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
# This is a WIP, still an issue with the backward pass
# RuntimeError: grad can be implicitly created only for scalar outputs
@@ -406,21 +506,21 @@ def load_model(
if hasattr(module, "weight"):
module.to(torch.float32)
require_peft: bool = False
if cfg.adapter in ["lora", "qlora", "ia3"]:
require_peft = True
if require_peft:
needs_fa2_dtype = cfg.adapter or cfg.fsdp
if (cfg.adapter == "lora" and load_in_8bit) or (
cfg.adapter == "qlora" and cfg.load_in_4bit
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
if require_peft or cfg.fsdp or (cfg.flash_attention and cfg.is_llama_derived_model):
if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
for name, module in model.named_modules():
if "norm" in name:
@@ -429,7 +529,7 @@ def load_model(
if hasattr(module, "weight"):
module.to(cfg.torch_dtype)
model, peft_config = load_adapter(model, cfg, cfg.adapter)
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}")
@@ -460,7 +560,7 @@ def load_model(
log_gpu_memory_usage(LOG, "after adapters", model.device)
# TODO resume_from_checkpoint handling
return model, peft_config
return model, lora_config
def load_adapter(model, cfg, adapter, inference=False):
@@ -470,8 +570,6 @@ def load_adapter(model, cfg, adapter, inference=False):
return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter == "ia3":
return load_ia3(model, cfg, inference=inference)
if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter":
@@ -490,11 +588,11 @@ def load_llama_adapter(model, cfg):
task_type="CAUSAL_LM",
)
if cfg.peft_model_dir:
if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - llama_adapter")
model = PeftModel.from_pretrained(
model,
cfg.peft_model_dir,
cfg.lora_model_dir,
torch_dtype=torch.float16,
)
else:
@@ -507,20 +605,27 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
peft_module_names = set()
lora_module_names = set()
multimodal_keywords = [
"mm_projector",
"vision_tower",
"vision_resampler",
] # for LLaVA
for name, module in model.named_modules():
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
continue
if (
isinstance(module, cls)
or "Linear" in module.__class__.__name__
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
):
names = name.split(".")
peft_module_names.add(names[0] if len(names) == 1 else names[-1])
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if "lm_head" in peft_module_names: # needed for 16-bit
peft_module_names.remove("lm_head")
if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
return list(peft_module_names)
return list(lora_module_names)
def load_lora(model, cfg, inference=False):
@@ -528,68 +633,34 @@ def load_lora(model, cfg, inference=False):
from peft import LoraConfig, PeftModel, get_peft_model
peft_target_modules = list(cfg.peft_target_modules or [])
lora_target_modules = list(cfg.lora_target_modules or [])
if cfg.peft_target_linear:
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(linear_names)}")
peft_target_modules = list(set(peft_target_modules + linear_names))
lora_target_modules = list(set(lora_target_modules + linear_names))
peft_config = LoraConfig(
r=cfg.peft_r,
lora_alpha=cfg.peft_alpha,
target_modules=peft_target_modules,
lora_dropout=cfg.peft_dropout,
fan_in_fan_out=cfg.peft_fan_in_fan_out,
modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None,
lora_config = LoraConfig(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
bias="none",
task_type="CAUSAL_LM",
)
if cfg.peft_model_dir:
if cfg.lora_model_dir:
LOG.debug("Loading pretained PEFT - LoRA")
model = PeftModel.from_pretrained(
model,
cfg.peft_model_dir,
cfg.lora_model_dir,
is_trainable=(not inference),
)
else:
model = get_peft_model(model, peft_config)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
return model, peft_config
def load_ia3(model, cfg, inference=False):
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import IA3Config, PeftModel, get_peft_model
peft_config_kwargs = {}
if cfg.peft_init_ia3_weights is not None:
peft_config_kwargs["init_ia3_weights"] = cfg.peft_init_ia3_weights
if cfg.peft_fan_in_fan_out is not None:
peft_config_kwargs["fan_in_fan_out"] = cfg.peft_fan_in_fan_out
peft_config = IA3Config(
target_modules=cfg.peft_target_modules,
feedforward_modules=cfg.peft_feedforward_modules,
modules_to_save=cfg.peft_modules_to_save,
task_type="CAUSAL_LM",
**peft_config_kwargs,
)
if cfg.peft_model_dir:
LOG.debug("Loading pretained PEFT - IA3")
model = PeftModel.from_pretrained(
model,
cfg.peft_model_dir,
is_trainable=(not inference),
)
else:
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
return model, peft_config
return model, lora_config

View File

@@ -34,6 +34,5 @@ def check_example_labels(example, tokenizer, text_only=False):
delimiter = "" if text_only else " "
LOG.info(delimiter.join(colored_tokens))
LOG.info("\n\n\n")
print(" ".join(colored_tokens))
return " ".join(colored_tokens)

View File

@@ -1,40 +1,19 @@
"""Module containing the Trainer class and related functions"""
import importlib
import logging
import math
import os
import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import List, Optional, Union
from typing import List
import numpy as np
import torch
import torch.cuda
import torch.distributed as dist
import transformers
from datasets import Dataset, set_caching_enabled
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import (
DataLoader,
DistributedSampler,
RandomSampler,
SequentialSampler,
)
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler
from datasets import set_caching_enabled
from torch.utils.data import DistributedSampler, RandomSampler
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.distributed import (
@@ -43,7 +22,6 @@ from axolotl.utils.distributed import (
reduce_and_broadcast,
zero_first,
)
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
LOG = logging.getLogger("axolotl")
@@ -110,269 +88,6 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
return weighted_cross_entropy(logits, labels, weights)
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Extend the base TrainingArguments for axolotl helpers
"""
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
sample_packing_seq_len_multiplier: int = field(
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
class AxolotlTrainer(Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
def __init__(self, *args, bench_data_collator=None, **kwargs):
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return DistributedSampler(
self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if (
self.args.world_size > 1
and self.args.sample_packing
and self.args.eval_sample_packing is not False
):
return SequentialDistributedSampler(
eval_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
batch_size=self.args.per_device_eval_batch_size,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
return self.accelerator.prepare(
MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_train_dataloader()
def get_eval_dataloader(
self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
return self.accelerator.prepare(
MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.eval_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> Union[DataLoader, MultipackDistributedDataloader]:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
pct_start=pct_start,
div_factor=6,
)
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
def add_position_ids(sample):
sample_len = len(sample["input_ids"])
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
@@ -550,245 +265,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
warmup_steps = (
cfg.warmup_steps
if cfg.warmup_steps is not None
else min(int(0.03 * total_num_steps), 100)
)
logging_steps = (
cfg.logging_steps
if cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset
training_arguments_kwargs = {}
if cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = cfg.bf16
training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if cfg.seed:
training_arguments_kwargs["seed"] = cfg.seed
if cfg.gradient_checkpointing:
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
if cfg.fsdp:
training_arguments_kwargs["fsdp"] = cfg.fsdp
if cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
# deepspeed
if cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = cfg.deepspeed
if cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
if cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
if cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
if cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
if cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
if cfg.hub_model_id:
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
if cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
if cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
elif cfg.evaluation_strategy:
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
elif cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if cfg.save_steps:
training_arguments_kwargs["save_strategy"] = "steps"
training_arguments_kwargs["save_steps"] = cfg.save_steps
elif cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = cfg.save_strategy
else:
# default to saving each epoch if not defined
training_arguments_kwargs["save_strategy"] = "epoch"
if cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
if cfg.metric_for_best_model:
training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
if cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
if cfg.torch_compile:
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
else:
import torch._dynamo # pylint: disable=redefined-outer-name
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = cfg.torch_compile
if cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = cfg.torch_compile_backend
# DDP Config
if cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
if cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
output_dir=cfg.output_dir,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
load_best_model_at_end=(
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
and cfg.val_set_size > 0
and cfg.save_steps
and cfg.eval_steps
and cfg.save_steps % cfg.eval_steps == 0
)
or False,
ddp_find_unused_parameters=False if cfg.ddp else None,
group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None,
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
lr_scheduler_type=cfg.lr_scheduler
if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
eval_sample_packing=cfg.eval_sample_packing,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
relora_steps=cfg.relora_steps,
relora_warmup_steps=cfg.relora_warmup_steps,
**training_arguments_kwargs,
)
trainer_kwargs = {}
if cfg.optimizer == "adamw_anyprecision":
if Path(cfg.torchdistx_path).exists():
sys.path.append(cfg.torchdistx_path)
importlib.import_module("torchdistx")
callbacks = []
callbacks.append(GPUStatsCallback(cfg))
callbacks.append(EvalFirstStepCallback)
if cfg.relora_steps:
callbacks.append(ReLoRACallback(cfg))
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
if cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
cfg.sequence_len / 64
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if cfg.is_llama_derived_model and cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
set_model_mem_id,
)
set_model_mem_id(model, tokenizer)
LOG.info("Adding landmark attention tokens to dataset")
for dataset in [train_dataset, eval_dataset]:
dataset = dataset.map(
partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
batched=False,
num_proc=32,
)
trainer_cls = AxolotlTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"):
trainer_cls = OneCycleLRSchedulerTrainer
elif cfg.relora_steps:
trainer_cls = ReLoRATrainer
trainer = trainer_cls(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
data_collator=DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=callbacks,
**trainer_kwargs,
)
if cfg.use_wandb and cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
trainer.add_callback(LogPredictionCallback(cfg))
if cfg.use_wandb:
trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path))
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
)
trainer.add_callback(early_stop_cb)
return trainer
return trainer_builder.build(total_num_steps)

View File

@@ -0,0 +1,72 @@
"""
E2E tests for lora llama
"""
import logging
import os
import tempfile
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestFusedLlama(unittest.TestCase):
"""
Test case for Llama models using Fused layers
"""
def test_fft_packing(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"flash_attention": True,
"flash_attn_fuse_qkv": True,
"flash_attn_fuse_mlp": True,
"sample_packing": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "pytorch_model.bin").exists()

View File

@@ -24,16 +24,11 @@ class TestLoraLlama(unittest.TestCase):
"""
def test_lora(self):
"""
support for legacy lora_ configs
:return:
"""
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
@@ -70,108 +65,12 @@ class TestLoraLlama(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_peft(self):
"""
support for legacy lora_ configs
:return:
"""
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"peft_r": 32,
"peft_alpha": 64,
"peft_dropout": 0.05,
"peft_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_ia3_peft(self):
"""
support for IA3 peft
:return:
"""
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "ia3",
"peft_r": 32,
"peft_alpha": 64,
"peft_dropout": 0.05,
"peft_target_modules": ["k_proj", "v_proj", "down_proj"],
"peft_feedforward_modules": ["down_proj"],
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists()
def test_lora_packing(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
@@ -216,7 +115,6 @@ class TestLoraLlama(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,

View File

@@ -31,7 +31,6 @@ class TestMistral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sequence_len": 1024,
"load_in_8bit": True,
@@ -77,7 +76,6 @@ class TestMistral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.1,

View File

@@ -31,7 +31,6 @@ class TestMistral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
@@ -78,7 +77,6 @@ class TestMistral(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,

View File

@@ -27,7 +27,6 @@ class TestPhi(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "microsoft/phi-1_5",
"base_model_config": "microsoft/phi-1_5",
"trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer",
@@ -71,7 +70,6 @@ class TestPhi(unittest.TestCase):
cfg = DictDefault(
{
"base_model": "microsoft/phi-1_5",
"base_model_config": "microsoft/phi-1_5",
"trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer",

View File

@@ -1,48 +0,0 @@
"""Module for testing the validation module"""
import logging
import unittest
from typing import Optional
import pytest
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
class NormalizationTest(unittest.TestCase):
"""
Test the cfg normalization module
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def test_lora_to_peft(self):
base_cfg = DictDefault(
{
"gradient_accumulation_steps": 1,
"micro_batch_size": 1,
"base_model": "NousResearch/Llama-2-7b-hf",
"base_model_config": "NousResearch/Llama-2-7b-hf",
}
)
cfg = base_cfg | DictDefault(
{
"adapter": "lora",
"lora_r": 128,
"lora_alpha": 64,
}
)
with self._caplog.at_level(logging.WARNING):
normalize_config(cfg)
assert any(
"soon to be deprecated. please use peft_" in record.message
for record in self._caplog.records
)
assert cfg.peft_r == 128
assert cfg.peft_alpha == 64

View File

@@ -0,0 +1,46 @@
"""
Test classes for checking functionality of the cfg normalization
"""
import unittest
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
class NormalizeConfigTestCase(unittest.TestCase):
"""
test class for normalize_config checks
"""
def _get_base_cfg(self):
return DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
def test_lr_as_float(self):
cfg = (
self._get_base_cfg()
| DictDefault( # pylint: disable=unsupported-binary-operation
{
"learning_rate": "5e-5",
}
)
)
normalize_config(cfg)
assert cfg.learning_rate == 0.00005
def test_base_model_config_set_when_empty(self):
cfg = self._get_base_cfg()
del cfg.base_model_config
normalize_config(cfg)
assert cfg.base_model_config == cfg.base_model

View File

@@ -565,3 +565,87 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
def test_eval_table_size_conflict_eval_packing(self):
cfg = DictDefault(
{
"sample_packing": True,
"eval_table_size": 100,
}
)
with pytest.raises(
ValueError, match=r".*Please set 'eval_sample_packing' to false.*"
):
validate_config(cfg)
cfg = DictDefault(
{
"sample_packing": True,
"eval_sample_packing": False,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"sample_packing": False,
"eval_table_size": 100,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"sample_packing": True,
"eval_table_size": 100,
"eval_sample_packing": False,
}
)
validate_config(cfg)
def test_load_in_x_bit_without_adapter(self):
cfg = DictDefault(
{
"load_in_4bit": True,
}
)
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = DictDefault(
{
"load_in_8bit": True,
}
)
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = DictDefault(
{
"load_in_4bit": True,
"adapter": "qlora",
}
)
validate_config(cfg)
cfg = DictDefault(
{
"load_in_8bit": True,
"adapter": "lora",
}
)
validate_config(cfg)