diff --git a/FAQS.md b/FAQS.md
index bdf056be7..f3c9dd525 100644
--- a/FAQS.md
+++ b/FAQS.md
@@ -2,3 +2,6 @@
- Can you train StableLM with this? Yes, but only with a single GPU atm. Multi GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874)
- Will this work with Deepspeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases
+- `Error invalid argument at line 359 in file /workspace/bitsandbytes/csrc/pythonInterface.c`
+`/arrow/cpp/src/arrow/filesystem/s3fs.cc:2598: arrow::fs::FinalizeS3 was not called even though S3 was initialized.`
+This could lead to a segmentation fault at exit. Try reinstalling bitsandbytes and transformers from source.
diff --git a/README.md b/README.md
index 2bc55732d..32bba7490 100644
--- a/README.md
+++ b/README.md
@@ -16,13 +16,14 @@
## Axolotl supports
-| | fp16/fp32 | fp16/fp32 w/ lora | qlora | 4bit-quant | 4bit-quant w/flash attention | flash attention | xformers attention |
-|---------|:----------|:------------------|------|------------|------------------------------|-----------------|--------------------|
-| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| Pythia | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
-| cerebras | ✅ | ✅ | ❓ | ❌ | ❌ | ❌ | ❓ |
-| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
-| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
+| | fp16/fp32 | lora | qlora | gptq | gptq w/ lora | gptq w/flash attn | flash attn | xformers attn |
+|----------|:----------|:-----|-------|------|:-------------|-------------------|------------|---------------|
+| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| Pythia | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ❓ |
+| cerebras | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
+| mpt | ✅ | ❌ | ❓ | ❌ | ❓ | ❌ | ❌ | ❓ |
+| falcon | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❌ | ✅ |
+| gpt-j | ✅ | ✅ | ✅ | ❌ | ❓ | ❌ | ❓ | ✅ |
## Quickstart ⚡
@@ -38,10 +39,10 @@ pip3 install -U git+https://github.com/huggingface/peft.git
accelerate config
# finetune lora
-accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml
+accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
# inference
-accelerate launch scripts/finetune.py examples/lora-openllama-3b/config.yml \
+accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
--inference --lora_model_dir="./lora-out"
```
@@ -218,6 +219,14 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
+- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
+ ```json
+ {"conversations": [{"role": "...", "value": "..."}]}
+ ```
+- `sharegpt_jokes`: creates a chat where the bot is asked to tell a joke, then explain why the joke is funny
+ ```json
+ {"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
+ ```
@@ -381,6 +390,8 @@ num_epochs: 3
warmup_steps: 100
learning_rate: 0.00003
logging_steps:
+save_steps:
+eval_steps:
# whether to mask out or include the human's prompt from the training labels
train_on_inputs: false
@@ -497,6 +508,11 @@ Pass the appropriate flag to the train command:
```bash
--inference --base_model ./completed-model
```
+- Full weights finetune w/ a prompt from a text file:
+ ```bash
+ cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
+ --base_model ./completed-model --inference --prompter=None --load_in_8bit=True
+ ```
### Merge LORA to base
@@ -524,7 +540,7 @@ Try set `fp16: true`
Try to turn off xformers.
-## Need help? 🙋♂️
+## Need help? 🙋♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
diff --git a/configs/accelerate/default_config.yaml b/configs/accelerate/default_config.yaml
deleted file mode 100644
index 9759703af..000000000
--- a/configs/accelerate/default_config.yaml
+++ /dev/null
@@ -1,15 +0,0 @@
-compute_environment: LOCAL_MACHINE
-distributed_type: 'NO'
-downcast_bf16: 'no'
-gpu_ids: all
-machine_rank: 0
-main_training_function: main
-mixed_precision: bf16
-num_machines: 1
-num_processes: 1
-rdzv_backend: static
-same_network: true
-tpu_env: []
-tpu_use_cluster: false
-tpu_use_sudo: false
-use_cpu: false
diff --git a/configs/cerebras_1_3B_alpaca.yml b/configs/cerebras_1_3B_alpaca.yml
deleted file mode 100644
index 958bf4c5a..000000000
--- a/configs/cerebras_1_3B_alpaca.yml
+++ /dev/null
@@ -1,40 +0,0 @@
-base_model: cerebras/Cerebras-GPT-1.3B
-model_type: AutoModelForCausalLM
-tokenizer_type: AutoTokenizer
-load_in_8bit: true
-datasets:
- - path: data/alpaca_data_gpt4.jsonl
- type: alpaca
- - path: data/vicuna_cleaned.jsonl
- type: sharegpt
- - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
- type: gpteacher
- - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
- type: gpteacher
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.05
-adapter: lora
-sequence_len: 2048
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
- - c_attn
-lora_fan_in_fan_out: false
-wandb_project: pythia-1.4b-lora
-wandb_watch:
-wandb_run_id:
-wandb_log_model:
-output_dir: ./lora-alpaca
-gradient_accumulation_steps: 1
-micro_batch_size: 4
-num_epochs: 5
-learning_rate: 0.0003
-train_on_inputs: false
-group_by_length: false
-bf16: True
-tf32: True
-gradient_checkpointing:
-early_stopping_patience:
-resume_from_checkpoint:
-local_rank:
diff --git a/configs/galactica_1_3B.yml b/configs/galactica_1_3B.yml
deleted file mode 100644
index 2abb4c6b4..000000000
--- a/configs/galactica_1_3B.yml
+++ /dev/null
@@ -1,41 +0,0 @@
-base_model: facebook/galactica-1.3b
-model_type: AutoModelForCausalLM
-tokenizer_type: AutoTokenizer
-load_in_8bit: false
-datasets:
- - path: tatsu-lab/alpaca
- type: alpaca
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.1
-adapter:
-lora_model_dir:
-sequence_len: 1024
-max_packed_sequence_len: 1024
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
- - q_proj
- - v_proj
-lora_fan_in_fan_out: false
-wandb_project:
-wandb_watch:
-wandb_run_id:
-wandb_log_model:
-output_dir: ./lora-llama-alpaca
-gradient_accumulation_steps: 1
-micro_batch_size: 16
-num_epochs: 3
-learning_rate: 0.00003
-train_on_inputs: false
-group_by_length: false
-bf16: false
-tf32: false
-early_stopping_patience:
-resume_from_checkpoint:
-local_rank:
-tokens:
- pad_token: "[PAD]"
- bos_token: ""
- eos_token: ""
- unk_token: ""
diff --git a/configs/llama_13B_alpaca.yml b/configs/llama_13B_alpaca.yml
deleted file mode 100644
index 99c9883fe..000000000
--- a/configs/llama_13B_alpaca.yml
+++ /dev/null
@@ -1,39 +0,0 @@
-base_model: huggyllama/llama-13b
-model_type: LlamaForCausalLM
-tokenizer_type: LlamaTokenizer
-load_in_8bit: true
-datasets:
- - path: anon8231489123/ShareGPT_Vicuna_unfiltered
- data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
- type: sharegpt
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.002
-adapter:
-lora_model_dir:
-sequence_len: 2048
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
- - q_proj
- - v_proj
-lora_fan_in_fan_out: false
-wandb_project:
-wandb_watch:
-wandb_run_id:
-wandb_log_model:
-output_dir: ./llama-13b-sharegpt
-gradient_accumulation_steps: 1
-micro_batch_size: 2
-warmup_steps: 1000
-save_steps:
-eval_steps:
-num_epochs: 5
-learning_rate: 0.00003
-train_on_inputs: false
-group_by_length: false
-bf16: true
-tf32: true
-early_stopping_patience: 5
-resume_from_checkpoint:
-local_rank:
diff --git a/configs/llama_65B_alpaca.yml b/configs/llama_65B_alpaca.yml
deleted file mode 100644
index e7d2c211c..000000000
--- a/configs/llama_65B_alpaca.yml
+++ /dev/null
@@ -1,44 +0,0 @@
-base_model: huggyllama/llama-65b
-model_type: LlamaForCausalLM
-tokenizer_type: LlamaTokenizer
-load_in_8bit: true
-datasets:
- - path: data/alpaca_data_gpt4.jsonl
- type: alpaca
- - path: anon8231489123/ShareGPT_Vicuna_unfiltered
- data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
- type: sharegpt
- - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
- type: gpteacher
- - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
- type: gpteacher
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.04
-adapter: lora
-lora_model_dir:
-sequence_len: 2048
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
- - q_proj
- - v_proj
-lora_fan_in_fan_out: false
-wandb_project: llama-65b-lora
-wandb_watch:
-wandb_run_id:
-wandb_log_model:
-output_dir: ./lora-llama-alpaca
-gradient_accumulation_steps: 1
-micro_batch_size: 16
-warmup_steps: 1000
-save_steps:
-num_epochs: 5
-learning_rate: 0.00003
-train_on_inputs: false
-group_by_length: false
-bf16: true
-tf32: true
-early_stopping_patience:
-resume_from_checkpoint:
-local_rank:
diff --git a/configs/llama_7B_4bit.yml b/configs/llama_7B_4bit.yml
deleted file mode 100644
index a7451516c..000000000
--- a/configs/llama_7B_4bit.yml
+++ /dev/null
@@ -1,45 +0,0 @@
-base_model: decapoda-research/llama-7b-hf-int4
-base_model_config: decapoda-research/llama-7b-hf
-model_type: LlamaForCausalLM
-tokenizer_type: LlamaTokenizer
-load_in_8bit: true
-datasets:
- - path: tatsu-lab/alpaca # original alpaca dataset
- type: alpaca
-dataset_prepared_path: data/last_run_prepared
-val_set_size: 0.04
-adapter: lora
-lora_model_dir:
-sequence_len: 2048
-max_packed_sequence_len: 1024
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
- - q_proj
- - v_proj
-# - k_proj
-# - o_proj
-lora_fan_in_fan_out: false
-wandb_project:
-wandb_watch:
-wandb_run_id:
-wandb_log_model:
-output_dir: ./lora-test
-gradient_accumulation_steps: 1
-micro_batch_size: 2
-num_epochs: 3
-warmup_steps: 100
-learning_rate: 0.00003
-train_on_inputs: false
-group_by_length: false
-bf16: true
-tf32: true
-gradient_checkpointing: false
-early_stopping_patience: 3
-resume_from_checkpoint:
-auto_resume_from_checkpoints: true
-local_rank:
-load_4bit: true
-xformers_attention: true
-flash_attention:
diff --git a/configs/llama_7B_alpaca.yml b/configs/llama_7B_alpaca.yml
deleted file mode 100644
index 7db2f65aa..000000000
--- a/configs/llama_7B_alpaca.yml
+++ /dev/null
@@ -1,41 +0,0 @@
-base_model: huggyllama/llama-7b
-model_type: LlamaForCausalLM
-tokenizer_type: LlamaTokenizer
-load_in_8bit: true
-datasets:
- - path: data/alpaca_data_gpt4.jsonl
- type: alpaca
- - path: data/vicuna_cleaned.jsonl
- type: sharegpt
- - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
- type: gpteacher
- - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
- type: gpteacher
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.04
-adapter: lora
-lora_model_dir:
-sequence_len: 2048
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
- - q_proj
- - v_proj
-lora_fan_in_fan_out: false
-wandb_project: llama-7b-lora
-wandb_watch:
-wandb_run_id:
-wandb_log_model:
-output_dir: ./lora-llama-alpaca
-gradient_accumulation_steps: 1
-micro_batch_size: 16
-num_epochs: 5
-learning_rate: 0.00003
-train_on_inputs: false
-group_by_length: false
-bf16: true
-tf32: true
-early_stopping_patience:
-resume_from_checkpoint:
-local_rank:
diff --git a/configs/quickstart.yml b/configs/quickstart.yml
deleted file mode 100644
index 2362916fc..000000000
--- a/configs/quickstart.yml
+++ /dev/null
@@ -1,45 +0,0 @@
-base_model: decapoda-research/llama-7b-hf-int4
-base_model_config: decapoda-research/llama-7b-hf
-model_type: LlamaForCausalLM
-tokenizer_type: LlamaTokenizer
-load_in_8bit: true
-datasets:
- - path: tatsu-lab/alpaca # original alpaca dataset
- type: alpaca
-dataset_prepared_path: data/last_run_prepared
-val_set_size: 0.04
-adapter: lora
-lora_model_dir:
-sequence_len: 1024
-max_packed_sequence_len: 1024
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
- - q_proj
- - v_proj
-# - k_proj
-# - o_proj
-lora_fan_in_fan_out: false
-wandb_project:
-wandb_watch:
-wandb_run_id:
-wandb_log_model:
-output_dir: ./lora-test
-gradient_accumulation_steps: 1
-micro_batch_size: 1
-num_epochs: 3
-warmup_steps: 100
-learning_rate: 0.00003
-train_on_inputs: false
-group_by_length: false
-bf16: true
-tf32: true
-gradient_checkpointing: false
-early_stopping_patience: 3
-resume_from_checkpoint:
-auto_resume_from_checkpoints: true
-local_rank:
-gptq: true
-xformers_attention: true
-flash_attention:
diff --git a/configs/sample.yml b/configs/sample.yml
deleted file mode 100644
index ddd95cb55..000000000
--- a/configs/sample.yml
+++ /dev/null
@@ -1,87 +0,0 @@
-# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
-# this can also be a relative path to a model on disk
-base_model: decapoda-research/llama-7b-hf-int4
-# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
-base_model_ignore_patterns:
-# if the base_model repo on hf hub doesn't include configuration .json files,
-# you can set that here, or leave this empty to default to base_model
-base_model_config: decapoda-research/llama-7b-hf
-# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
-model_type: AutoModelForCausalLM
-# Corresponding tokenizer for the model AutoTokenizer is a good choice
-tokenizer_type: AutoTokenizer
-# whether you are training a 4-bit quantized model
-load_4bit: true
-# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
-load_in_8bit: true
-# a list of one or more datasets to finetune the model with
-datasets:
- # this can be either a hf dataset, or relative path
- - path: vicgalle/alpaca-gpt4
- # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
- type: alpaca
-# axolotl attempts to save the dataset as an arrow after packing the data together so
-# subsequent training attempts load faster, relative path
-dataset_prepared_path: data/last_run_prepared
-# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc
-val_set_size: 0.04
-# if you want to use lora, leave blank to train all parameters in original model
-adapter: lora
-# if you already have a lora model trained that you want to load, put that here
-lora_model_dir:
-# the maximum length of an input to train with, this should typically be less than 2048
-# as most models have a token/context limit of 2048
-sequence_len: 2048
-# max sequence length to concatenate training samples together up to
-# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
-max_packed_sequence_len: 1024
-# lora hyperparameters
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
- - q_proj
- - v_proj
-# - k_proj
-# - o_proj
-lora_fan_in_fan_out: false
-# wandb configuration if your're using it
-wandb_project:
-wandb_watch:
-wandb_run_id:
-wandb_log_model:
-# where to save the finsihed model to
-output_dir: ./completed-model
-# training hyperparameters
-gradient_accumulation_steps: 1
-batch_size:
-micro_batch_size: 2
-num_epochs: 3
-warmup_steps: 100
-learning_rate: 0.00003
-# whether to mask out or include the human's prompt from the training labels
-train_on_inputs: false
-# don't use this, leads to wonky training (according to someone on the internet)
-group_by_length: false
-# Use CUDA bf16
-bf16: true
-# Use CUDA tf32
-tf32: true
-# does not work with current implementation of 4-bit LoRA
-gradient_checkpointing: false
-# stop training after this many evaluation losses have increased in a row
-# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
-early_stopping_patience: 3
-# specify a scheduler to use with the optimizer. only one_cycle is supported currently
-lr_scheduler:
-# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
-xformers_attention:
-# whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
-flash_attention:
-# resume from a specific checkpoint dir
-resume_from_checkpoint:
-# if resume_from_checkpoint isn't set and you simply want it to start where it left off
-# be careful with this being turned on between different models
-auto_resume_from_checkpoints: false
-# don't mess with this, it's here for accelerate and torchrun
-local_rank:
diff --git a/configs/stability_3b.yml b/configs/stability_3b.yml
deleted file mode 100644
index 83516a20a..000000000
--- a/configs/stability_3b.yml
+++ /dev/null
@@ -1,56 +0,0 @@
-base_model: stabilityai/stablelm-base-alpha-3b
-base_model_config: stabilityai/stablelm-base-alpha-3b
-load_in_8bit: false
-datasets:
- - path: vicgalle/alpaca-gpt4
- type: alpaca
-dataset_prepared_path: last_run_prepared
-val_set_size: 0.04
-adapter:
-lora_model_dir:
-sequence_len: 4096
-max_packed_sequence_len: 4096
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
- - q_proj
- - v_proj
-lora_fan_in_fan_out: false
-wandb_project: stable-alpaca-3b
-wandb_watch:
-wandb_run_id:
-wandb_log_model:
-output_dir: ./stable-alpaca-3b
-gradient_accumulation_steps: 1
-micro_batch_size: 1
-num_epochs: 1
-optimizer: adamw_bnb_8bit
-torchdistx_path:
-lr_scheduler: cosine
-learning_rate: 0.0000002
-train_on_inputs: false
-group_by_length: false
-bf16: true
-tf32: true
-early_stopping_patience:
-resume_from_checkpoint:
-local_rank:
-logging_steps: 1
-xformers_attention: true
-flash_attention:
-gptq_groupsize:
-gptq_model_v1:
-warmup_steps: 100
-eval_steps: 50
-save_steps: 200
-debug:
-deepspeed:
-weight_decay: 0.01
-fsdp:
-fsdp_config:
-#tokens:
-# pad_token: "[PAD]"
-# bos_token: ""
-# eos_token: ""
-# unk_token: ""
diff --git a/configs/vicuna_13B_4bit_reflect.yml b/configs/vicuna_13B_4bit_reflect.yml
deleted file mode 100644
index 3e37f5334..000000000
--- a/configs/vicuna_13B_4bit_reflect.yml
+++ /dev/null
@@ -1,45 +0,0 @@
-base_model: anon8231489123/vicuna-13b-GPTQ-4bit-128g
-base_model_config: anon8231489123/vicuna-13b-GPTQ-4bit-128g
-model_type: LlamaForCausalLM
-tokenizer_type: LlamaTokenizer
-load_in_8bit: false
-load_4bit: true
-gptq_groupsize: 128
-gptq_model_v1: false
-datasets:
-# https://github.com/vaguenebula/AlpacaDataReflect/blob/main/alpaca_reflect_pruned.json
- - path: data/alpaca_reflect_pruned.jsonl
- type: reflection
-dataset_prepared_path: data/last_run_prepared
-val_set_size: 0.04
-adapter: lora
-lora_model_dir:
-sequence_len: 2048
-max_packed_sequence_len: 2048
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
-lora_target_modules:
- - q_proj
- - v_proj
-# - k_proj
-# - o_proj
-lora_fan_in_fan_out: false
-wandb_project:
-wandb_watch:
-wandb_run_id:
-wandb_log_model:
-output_dir: ./lora-reflect
-gradient_accumulation_steps: 1
-micro_batch_size: 2
-num_epochs: 3
-learning_rate: 0.00003
-train_on_inputs: false
-group_by_length: false
-bf16: true
-tf32: true
-gradient_checkpointing: false
-early_stopping_patience: 3
-resume_from_checkpoint:
-local_rank:
-flash_attention: true
diff --git a/examples/cerebras/qlora.yml b/examples/cerebras/qlora.yml
new file mode 100644
index 000000000..9340299b9
--- /dev/null
+++ b/examples/cerebras/qlora.yml
@@ -0,0 +1,60 @@
+base_model: cerebras/Cerebras-GPT-1.3B
+base_model_config: cerebras/Cerebras-GPT-1.3B
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+push_dataset_to_hub:
+datasets:
+ - path: teknium/GPT4-LLM-Cleaned
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+adapter: qlora
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len: 2048
+lora_r: 16
+lora_alpha: 32
+lora_dropout: 0.05
+lora_target_modules:
+ - c_fc
+ - c_attn
+ - c_proj
+lora_target_linear:
+lora_fan_in_fan_out:
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./qlora-out
+batch_size: 4
+micro_batch_size: 4
+num_epochs: 2
+optimizer: paged_adamw_8bit
+torchdistx_path:
+lr_scheduler: cosine
+learning_rate: 0.0002
+train_on_inputs: false
+group_by_length: true
+bf16: true
+fp16: false
+tf32: true
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention: true
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 10
+eval_steps: 20
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.1
+fsdp:
+fsdp_config:
+special_tokens:
+ pad_token: "<|endoftext|>"
diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml
index 090cc6bcf..8aa585851 100644
--- a/examples/falcon/config-7b-lora.yml
+++ b/examples/falcon/config-7b-lora.yml
@@ -23,7 +23,7 @@ lora_dropout: 0.0
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
-wandb_project: falcon-7b
+wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml
index dc67d6125..b267566ce 100644
--- a/examples/falcon/config-7b.yml
+++ b/examples/falcon/config-7b.yml
@@ -23,7 +23,7 @@ lora_dropout: 0.0
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
-wandb_project: falcon-7b
+wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
diff --git a/examples/gptj/qlora.yml b/examples/gptj/qlora.yml
new file mode 100644
index 000000000..858c14862
--- /dev/null
+++ b/examples/gptj/qlora.yml
@@ -0,0 +1,57 @@
+base_model: EleutherAI/gpt-j-6b
+base_model_config: EleutherAI/gpt-j-6b
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+push_dataset_to_hub:
+datasets:
+ - path: teknium/GPT4-LLM-Cleaned
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+adapter: qlora
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len:
+lora_r: 8
+lora_alpha: 32
+lora_dropout: 0.05
+lora_target_modules:
+lora_target_linear: true
+lora_fan_in_fan_out:
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./qlora-out
+gradient_accumulation_steps: 2
+micro_batch_size: 2
+num_epochs: 2
+optimizer: paged_adamw_8bit
+torchdistx_path:
+lr_scheduler: cosine
+learning_rate: 0.0001
+train_on_inputs: false
+group_by_length: true
+bf16: true
+fp16: false
+tf32: true
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention: true
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 10
+eval_steps: 20
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.1
+fsdp:
+fsdp_config:
+special_tokens:
+ pad_token: "<|endoftext|>"
diff --git a/examples/gptq-lora-7b/README.md b/examples/gptq-lora-7b/README.md
index eefe98d3f..0bde51b06 100644
--- a/examples/gptq-lora-7b/README.md
+++ b/examples/gptq-lora-7b/README.md
@@ -3,6 +3,6 @@
This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
```shell
-accelerate launch scripts/finetune.py examples/4bit-lora-7b/config.yml
+accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
```
diff --git a/configs/llama_7B_jeopardy.yml b/examples/jeopardy-bot/config.yml
similarity index 75%
rename from configs/llama_7B_jeopardy.yml
rename to examples/jeopardy-bot/config.yml
index 287d6d6ab..b803c6074 100644
--- a/configs/llama_7B_jeopardy.yml
+++ b/examples/jeopardy-bot/config.yml
@@ -7,30 +7,28 @@ datasets:
- path: openaccess-ai-collective/jeopardy
type: jeopardy
dataset_prepared_path: last_run_prepared
-val_set_size: 0.01
+val_set_size: 0.02
adapter:
lora_model_dir:
-sequence_len: 2048
-max_packed_sequence_len: 2048
-lora_r: 8
-lora_alpha: 16
-lora_dropout: 0.05
+sequence_len: 512
+max_packed_sequence_len:
+lora_r:
+lora_alpha:
+lora_dropout:
lora_target_modules:
- - q_proj
- - v_proj
lora_fan_in_fan_out: false
-wandb_project: jeopardy-bot-7b
+wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./jeopardy-bot-7b
-gradient_accumulation_steps: 2
+gradient_accumulation_steps: 1
micro_batch_size: 1
-num_epochs: 2
+num_epochs: 3
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
-learning_rate: 0.0000002
+learning_rate: 0.00003
train_on_inputs: false
group_by_length: false
bf16: true
@@ -48,11 +46,10 @@ eval_steps: 110
save_steps: 660
debug:
deepspeed:
-weight_decay: 0.0001
+weight_decay: 0.1
fsdp:
fsdp_config:
tokens:
- pad_token: "[PAD]"
bos_token: ""
eos_token: ""
unk_token: ""
diff --git a/examples/openllama-3b/README.md b/examples/openllama-3b/README.md
new file mode 100644
index 000000000..3e9501a54
--- /dev/null
+++ b/examples/openllama-3b/README.md
@@ -0,0 +1,16 @@
+# openllama-3b
+
+Basic full tune
+```shell
+accelerate launch scripts/finetune.py examples/openllama-3b/config.yml
+```
+
+LoRA
+```shell
+accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
+```
+
+QLoRA
+```shell
+accelerate launch scripts/finetune.py examples/openllama-3b/qlora.yml
+```
diff --git a/examples/openllama-3b/config.yml b/examples/openllama-3b/config.yml
new file mode 100644
index 000000000..6fd704ffc
--- /dev/null
+++ b/examples/openllama-3b/config.yml
@@ -0,0 +1,61 @@
+base_model: openlm-research/open_llama_3b
+base_model_config: openlm-research/open_llama_3b
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+push_dataset_to_hub:
+datasets:
+ - path: teknium/GPT4-LLM-Cleaned
+ type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.02
+adapter:
+lora_model_dir:
+sequence_len: 256
+max_packed_sequence_len:
+lora_r:
+lora_alpha:
+lora_dropout:
+lora_target_modules:
+lora_target_linear:
+lora_fan_in_fan_out:
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./openllama-out
+batch_size: 16
+micro_batch_size: 4
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+torchdistx_path:
+lr_scheduler: cosine
+learning_rate: 0.0002
+train_on_inputs: false
+group_by_length: false
+bf16: false
+fp16: true
+tf32: false
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention: true
+flash_attention:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 10
+eval_steps: 50
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+ bos_token: ""
+ eos_token: ""
+ unk_token: ""
diff --git a/examples/lora-openllama-3b/config.yml b/examples/openllama-3b/lora.yml
similarity index 89%
rename from examples/lora-openllama-3b/config.yml
rename to examples/openllama-3b/lora.yml
index 2e1644546..d1f252455 100644
--- a/examples/lora-openllama-3b/config.yml
+++ b/examples/openllama-3b/lora.yml
@@ -1,5 +1,5 @@
-base_model: openlm-research/open_llama_3b_600bt_preview
-base_model_config: openlm-research/open_llama_3b_600bt_preview
+base_model: openlm-research/open_llama_3b
+base_model_config: openlm-research/open_llama_3b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
@@ -49,7 +49,7 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
-xformers_attention:
+xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
diff --git a/examples/qlora-openllama-3b/config.yml b/examples/openllama-3b/qlora.yml
similarity index 90%
rename from examples/qlora-openllama-3b/config.yml
rename to examples/openllama-3b/qlora.yml
index 87e1dfd94..83ae31f91 100644
--- a/examples/qlora-openllama-3b/config.yml
+++ b/examples/openllama-3b/qlora.yml
@@ -1,5 +1,5 @@
-base_model: openlm-research/open_llama_3b_600bt_preview
-base_model_config: openlm-research/open_llama_3b_600bt_preview
+base_model: openlm-research/open_llama_3b
+base_model_config: openlm-research/open_llama_3b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
diff --git a/configs/pythia_1_2B_alpaca.yml b/examples/pythia/lora.yml
similarity index 56%
rename from configs/pythia_1_2B_alpaca.yml
rename to examples/pythia/lora.yml
index 52ed58cb5..e2b28f218 100644
--- a/configs/pythia_1_2B_alpaca.yml
+++ b/examples/pythia/lora.yml
@@ -1,36 +1,29 @@
base_model: EleutherAI/pythia-1.4b-deduped
-model_type: GPTNeoXForCausalLM
-tokenizer_type: AutoTokenizer
+base_model_config: EleutherAI/pythia-1.4b-deduped
load_in_8bit: true
datasets:
- - path: data/alpaca_data_gpt4.jsonl
+ - path: teknium/GPT4-LLM-Cleaned
type: alpaca
- - path: data/vicuna_cleaned.jsonl
- type: sharegpt
- - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
- type: gpteacher
- - path: data/roleplay-similarity_0.6-instruct-dataset.jsonl
- type: gpteacher
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
adapter: lora
lora_model_dir:
-sequence_len: 2048
-lora_r: 8
+sequence_len: 512
+lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- query_key_value
-# - xxx
+lora_target_linear:
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
-wandb_project: pythia-1.4b-lora
+wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
-output_dir: ./lora-alpaca
+output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1
micro_batch_size: 4
-num_epochs: 5
+num_epochs: 3
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
@@ -39,3 +32,6 @@ tf32: True
early_stopping_patience:
resume_from_checkpoint:
local_rank:
+weight_decay: 0.1
+eval_steps: 20
+logging_steps: 1
diff --git a/examples/qlora-openllama-3b/README.md b/examples/qlora-openllama-3b/README.md
deleted file mode 100644
index d79ea7f3f..000000000
--- a/examples/qlora-openllama-3b/README.md
+++ /dev/null
@@ -1,6 +0,0 @@
-# qlora-openllama-3b
-
-```shell
-accelerate launch scripts/finetune.py examples/qlora-openllama-3b/config.yml
-
-```
diff --git a/scripts/finetune.py b/scripts/finetune.py
index 2f6bef3ef..8f49cfba5 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -72,7 +72,19 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
- prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
+ prompter_module = None
+ if prompter:
+ prompter_module = getattr(
+ importlib.import_module("axolotl.prompters"), prompter
+ )
+
+ if cfg.landmark_attention:
+ from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
+
+ set_model_mem_id(model, tokenizer)
+ model.set_mem_cache_args(
+ max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
+ )
while True:
print("=" * 80)
@@ -80,10 +92,14 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
instruction = get_multi_line_input()
if not instruction:
return
- prompt: str = next(
- prompter_module().build_prompt(instruction=instruction.strip("\n"))
- )
+ if prompter_module:
+ prompt: str = next(
+ prompter_module().build_prompt(instruction=instruction.strip("\n"))
+ )
+ else:
+ prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
print("=" * 40)
model.eval()
with torch.no_grad():
@@ -159,7 +175,7 @@ def train(
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
- if k in cfg_keys or cfg.strict is False:
+ if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
@@ -199,8 +215,8 @@ def train(
logging.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
- if check_not_in(
- ["inference", "shard", "merge_lora"], kwargs
+ if (
+ check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these
if not cfg.pretraining_dataset:
train_dataset, eval_dataset = load_prepare_datasets(
@@ -239,7 +255,6 @@ def train(
tokenizer,
cfg,
adapter=cfg.adapter,
- inference=("inference" in kwargs),
)
if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -252,9 +267,15 @@ def train(
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
- if "inference" in kwargs:
+ if cfg.inference:
logging.info("calling do_inference function")
- do_inference(cfg, model, tokenizer)
+ inf_kwargs: Dict[str, Any] = {}
+ if "prompter" in kwargs:
+ if kwargs["prompter"] == "None":
+ inf_kwargs["prompter"] = None
+ else:
+ inf_kwargs["prompter"] = kwargs["prompter"]
+ do_inference(cfg, model, tokenizer, **inf_kwargs)
return
if "shard" in kwargs:
diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index d6367ce7c..40c58bc9c 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -33,12 +33,16 @@ class TokenizedPromptDataset(IterableDataset):
def __iter__(self):
iterator = iter(self.dataset)
+ count = 0
# Loop through the entire dataset
for example in iterator:
try:
yield self.prompt_tokenizer.tokenize_prompt(example)
+ count += 1
except InvalidDataException:
pass
+ if count == 0:
+ raise RuntimeError("Expected at least one datapoint in dataset.")
# TODO this isn't the best since it can't interleave datasets
diff --git a/src/axolotl/monkeypatch/llama_landmark_attn.py b/src/axolotl/monkeypatch/llama_landmark_attn.py
index 18e913f09..2a4cdbc36 100644
--- a/src/axolotl/monkeypatch/llama_landmark_attn.py
+++ b/src/axolotl/monkeypatch/llama_landmark_attn.py
@@ -28,15 +28,24 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.activations import ACT2FN
+from torch.nn import CrossEntropyLoss
+from transformers import LlamaTokenizer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
- SequenceClassifierOutputWithPast,
)
-from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+ LLAMA_INPUTS_DOCSTRING,
+ LLAMA_START_DOCSTRING,
+ LlamaMLP,
+ LlamaPreTrainedModel,
+ LlamaRMSNorm,
+ LlamaRotaryEmbedding,
+ _expand_mask,
+ _make_causal_mask,
+ rotate_half,
+)
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -51,131 +60,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
MEM_TOKEN = "" # nosec
-# Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(
- input_ids_shape: torch.Size,
- dtype: torch.dtype,
- device: torch.device,
- past_key_values_length: int = 0,
-):
- """
- Make causal mask used for bi-directional self-attention.
- """
- bsz, tgt_len = input_ids_shape
- mask = torch.full(
- (tgt_len, tgt_len),
- torch.tensor(torch.finfo(dtype).min, device=device),
- device=device,
- )
- mask_cond = torch.arange(mask.size(-1), device=device)
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
- mask = mask.to(dtype)
-
- if past_key_values_length > 0:
- mask = torch.cat(
- [
- torch.zeros(
- tgt_len, past_key_values_length, dtype=dtype, device=device
- ),
- mask,
- ],
- dim=-1,
- )
- return mask[None, None, :, :].expand(
- bsz, 1, tgt_len, tgt_len + past_key_values_length
- )
-
-
-# Copied from transformers.models.bart.modeling_bart._expand_mask
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
- """
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
- """
- bsz, src_len = mask.size()
- tgt_len = tgt_len if tgt_len is not None else src_len
-
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
- inverted_mask = 1.0 - expanded_mask
-
- return inverted_mask.masked_fill(
- inverted_mask.to(torch.bool), torch.finfo(dtype).min
- )
-
-
-class LlamaRMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
- # convert into half-precision if necessary
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
- hidden_states = hidden_states.to(self.weight.dtype)
-
- return self.weight * hidden_states
-
-
-class LlamaRotaryEmbedding(torch.nn.Module):
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
- super().__init__()
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
- self.register_buffer("inv_freq", inv_freq)
-
- # Build here to make `torch.jit.trace` work.
- self.max_seq_len_cached = max_position_embeddings
- t = torch.arange(
- self.max_seq_len_cached,
- device=self.inv_freq.device,
- dtype=self.inv_freq.dtype,
- )
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- self.register_buffer(
- "cos_cached", emb.cos()[None, None, :, :], persistent=False
- )
- self.register_buffer(
- "sin_cached", emb.sin()[None, None, :, :], persistent=False
- )
-
- def forward(self, x, seq_len=None):
- # x: [bs, num_attention_heads, seq_len, head_size]
- # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
- if seq_len > self.max_seq_len_cached:
- self.max_seq_len_cached = seq_len
- t = torch.arange(
- self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
- )
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
- self.register_buffer(
- "cos_cached", emb.cos()[None, None, :, :], persistent=False
- )
- self.register_buffer(
- "sin_cached", emb.sin()[None, None, :, :], persistent=False
- )
- return (
- self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
- self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
- )
-
-
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
@@ -190,24 +74,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
return q_embed, k_embed
-class LlamaMLP(nn.Module):
- def __init__(
- self,
- hidden_size: int,
- intermediate_size: int,
- hidden_act: str,
- ):
- super().__init__()
- self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
- self.act_fn = ACT2FN[hidden_act]
-
- def forward(self, x):
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
+ """
+ Landmark grouped softmax function.
+ """
+
# Note that forward, setup_context, and backward are @staticmethods
@staticmethod
def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
@@ -682,16 +553,14 @@ class LlamaAttention(nn.Module):
# upcast attention to fp32
if is_mem is None:
raise ValueError("Don't use this without landmarks")
- # attn_weights = nn.functional.softmax(
- # attn_weights, dim=-1, dtype=torch.float32
- # ).to(query_states.dtype)
- else:
- attn_weights = landmark_grouped_softmax(
- attn_weights,
- dim=-1,
- is_mem=is_mem.expand(-1, self.num_heads, -1, -1),
- last_section_mask=last_section_mask,
- ).to(query_states.dtype)
+
+ attn_weights = landmark_grouped_softmax(
+ attn_weights,
+ dim=-1,
+ is_mem=is_mem.expand(-1, self.num_heads, -1, -1),
+ last_section_mask=last_section_mask,
+ ).to(query_states.dtype)
+
if attn_prefix is not None:
attn_prefix, attn_weights = torch.split(
attn_weights,
@@ -722,6 +591,10 @@ class LlamaAttention(nn.Module):
class LlamaDecoderLayer(nn.Module):
+ """
+ Llama Decoder layer
+ """
+
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
@@ -802,114 +675,6 @@ class LlamaDecoderLayer(nn.Module):
return outputs
-LLAMA_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
-
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
-
- Parameters:
- config ([`LlamaConfig`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
- LLAMA_START_DOCSTRING,
-)
-class LlamaPreTrainedModel(PreTrainedModel):
- config_class = LlamaConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["LlamaDecoderLayer"]
- _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
-
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, LlamaModel):
- module.gradient_checkpointing = value
-
-
-LLAMA_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
- `past_key_values`).
-
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
-
- [What are position IDs?](../glossary#position-ids)
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAMA_START_DOCSTRING,
@@ -1178,6 +943,10 @@ class LlamaModel(LlamaPreTrainedModel):
class LlamaForCausalLM(LlamaPreTrainedModel):
+ """
+ Llama model with a causal language modeling head.
+ """
+
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
@@ -1448,148 +1217,33 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
return reordered_past
-@add_start_docstrings(
- """
- The LLaMa Model transformer with a sequence classification head on top (linear layer).
-
- [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
- (e.g. GPT-2) do.
-
- Since it does classification on the last token, it requires to know the position of the last token. If a
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
- each row of the batch).
- """,
- LLAMA_START_DOCSTRING,
-)
-class LlamaForSequenceClassification(LlamaPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
-
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.model = LlamaModel(config)
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.model.embed_tokens
-
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
-
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = (
- return_dict if return_dict is not None else self.config.use_return_dict
- )
-
- transformer_outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- hidden_states = transformer_outputs[0]
- logits = self.score(hidden_states)
-
- if input_ids is not None:
- batch_size = input_ids.shape[0]
- else:
- batch_size = inputs_embeds.shape[0]
-
- if self.config.pad_token_id is None and batch_size != 1:
- raise ValueError(
- "Cannot handle batch sizes > 1 if no padding token is defined."
- )
- if self.config.pad_token_id is None:
- sequence_lengths = -1
- else:
- if input_ids is not None:
- sequence_lengths = (
- torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
- ).to(logits.device)
- else:
- sequence_lengths = -1
-
- pooled_logits = logits[
- torch.arange(batch_size, device=logits.device), sequence_lengths
- ]
-
- loss = None
- if labels is not None:
- labels = labels.to(logits.device)
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = "regression"
- elif self.num_labels > 1 and (
- labels.dtype == torch.long or labels.dtype == torch.int
- ):
- self.config.problem_type = "single_label_classification"
- else:
- self.config.problem_type = "multi_label_classification"
-
- if self.config.problem_type == "regression":
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(pooled_logits, labels)
- elif self.config.problem_type == "single_label_classification":
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(
- pooled_logits.view(-1, self.num_labels), labels.view(-1)
- )
- elif self.config.problem_type == "multi_label_classification":
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(pooled_logits, labels)
- if not return_dict:
- output = (pooled_logits,) + transformer_outputs[1:]
- return ((loss,) + output) if loss is not None else output
-
- return SequenceClassifierOutputWithPast(
- loss=loss,
- logits=pooled_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
-
-
def add_mem_tokens(example, mem_freq, mem_id):
- x = example["input_ids"]
+ ids = example["input_ids"]
ret = []
prev_idx = 0
- for t_idx in range(mem_freq, len(x), mem_freq):
- ret.extend(x[prev_idx:t_idx])
+ for t_idx in range(mem_freq, len(ids), mem_freq):
+ ret.extend(ids[prev_idx:t_idx])
ret.append(mem_id)
prev_idx = t_idx
- ret.extend(x[prev_idx:])
+ ret.extend(ids[prev_idx:])
# drop attention_mask
return {"input_ids": ret}
+
+
+def patch_llama_with_landmark_attn():
+ import transformers
+
+ transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
+ transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
+ transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
+ transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
+ transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
+
+
+def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer):
+ mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
+ model.set_mem_id(mem_id)
+
+
+def get_mem_id(tokenizer: LlamaTokenizer):
+ return tokenizer.convert_tokens_to_ids(MEM_TOKEN)
diff --git a/src/axolotl/prompt_strategies/sharegpt_jokes.py b/src/axolotl/prompt_strategies/sharegpt_jokes.py
new file mode 100644
index 000000000..ac424bf7c
--- /dev/null
+++ b/src/axolotl/prompt_strategies/sharegpt_jokes.py
@@ -0,0 +1,28 @@
+"""Module for Jokes prompts using sharegpt style"""
+from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
+from axolotl.prompters import PromptStyle, ShareGPTPrompter
+
+
+def load(tokenizer, cfg):
+ return SimpleJokesShareGPTPromptTokenizingStrategy(
+ ShareGPTPrompter(PromptStyle.CHAT.value),
+ tokenizer,
+ cfg.train_on_inputs,
+ cfg.sequence_len,
+ )
+
+
+class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+ """
+ Tokenization strategy for asking bot to tell a joke and then explain why it's funny
+ """
+
+ # title, text, explanation
+ def get_conversation_thread(self, prompt):
+ title = "" if not prompt["title"] else prompt["title"] + " "
+ return [
+ {"from": "human", "value": "Tell me a joke."},
+ {"from": "gpt", "value": title + prompt["text"]},
+ {"from": "human", "value": "Why is that joke funny?"},
+ {"from": "gpt", "value": prompt["explanation"]},
+ ]
diff --git a/src/axolotl/prompt_strategies/sharegpt_simple.py b/src/axolotl/prompt_strategies/sharegpt_simple.py
index 4346663f2..bfe0d164b 100644
--- a/src/axolotl/prompt_strategies/sharegpt_simple.py
+++ b/src/axolotl/prompt_strategies/sharegpt_simple.py
@@ -13,6 +13,15 @@ def load(tokenizer, cfg):
)
+def load_role(tokenizer, cfg):
+ return SimpleRoleShareGPTPromptTokenizingStrategy(
+ ShareGPTPrompter(PromptStyle.CHAT.value),
+ tokenizer,
+ cfg.train_on_inputs,
+ cfg.sequence_len,
+ )
+
+
def load_guanaco(tokenizer, cfg):
return GuanacoShareGPTPromptTokenizingStrategy(
ShareGPTPrompter(PromptStyle.CHAT.value),
@@ -31,6 +40,18 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
return prompt["conversations"]
+class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+ """
+ basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
+ """
+
+ def get_conversation_thread(self, prompt):
+ conversations = prompt["conversations"]
+ # rename key: role => from (role values are passed through as-is); value is kept unchanged
+ turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
+ return turns
+
+
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps oasst data to sharegpt format
diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py
index 39c74023b..29cc4446b 100644
--- a/src/axolotl/prompters.py
+++ b/src/axolotl/prompters.py
@@ -261,28 +261,33 @@ class Conversation:
self.messages.append([role, message])
-conv_vicuna_v1_1 = Conversation(
- system="A chat between a curious user and an artificial intelligence assistant. "
- "The assistant gives helpful, detailed, and polite answers to the user's questions.",
- roles=["USER", "ASSISTANT"],
- messages=[],
- offset=0,
- sep_style=SeparatorStyle.TWO,
- sep=" ",
- sep2=" ",
-)
-
-
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
"""
A prompter that generates prompts for the ShareGPT
"""
- def __init__(self, prompt_style=None):
+ def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
if prompt_style != PromptStyle.CHAT.value:
raise ValueError(
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
)
+ system: str = (
+ system_prompt
+ if system_prompt
+ else (
+ "A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ )
+ )
+ self._conversation = Conversation(
+ system=system,
+ roles=["USER", "ASSISTANT"],
+ messages=[],
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2=" ",
+ )
# def match_prompt_style(self):
# if self.prompt_style == PromptStyle.chat.value:
@@ -300,7 +305,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
# also happens on the data splitting leaving empty conversations
raise IndexError
- conv = conv_vicuna_v1_1.copy()
+ conv = self._conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
try:
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 058c24bcd..c36bfcee9 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -240,8 +240,15 @@ def load_tokenized_prepared_datasets(
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else:
- logging.error(f"unhandled prompt tokenization strategy: {d.type}")
- raise ValueError(f"unhandled prompt tokenization strategy: {d.type}")
+ suffix = ""
+ if ":load_" in d.type:
+ suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
+ logging.error(
+ f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
+ )
+ raise ValueError(
+ f"unhandled prompt tokenization strategy: {d.type} {suffix}"
+ )
logging.info("tokenizing, merging, and shuffling master dataset")
samples: List[int] = []
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 532fa5518..05acfce93 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -20,15 +20,6 @@ from transformers import (
LlamaConfig,
)
-try:
- from transformers import ( # pylint: disable=unused-import # noqa: F401
- LlamaForCausalLM,
- )
-except ImportError:
- logging.warning(
- "This version of transformers does not support Llama. Consider upgrading."
- )
-
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
if TYPE_CHECKING:
@@ -78,15 +69,9 @@ def load_tokenizer(
def load_model(
- base_model,
- base_model_config,
- model_type,
- tokenizer,
- cfg,
- adapter="lora",
- inference=False,
+ base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
):
- # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+ # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
"""
Load a model from a base model and a model type.
"""
@@ -98,7 +83,7 @@ def load_model(
)
if cfg.is_llama_derived_model and cfg.flash_attention:
- if cfg.device not in ["mps", "cpu"] and inference is False:
+ if cfg.device not in ["mps", "cpu"] and not cfg.inference:
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
logging.info("patching with flash attention")
@@ -118,14 +103,15 @@ def load_model(
logging.info("patching with sdp attention")
hijack_llama_sdp_attention()
elif cfg.is_llama_derived_model and cfg.landmark_attention:
- from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
+ from axolotl.monkeypatch.llama_landmark_attn import (
MEM_TOKEN,
- LlamaForCausalLM,
+ patch_llama_with_landmark_attn,
)
logging.info("patching with landmark attention")
+ patch_llama_with_landmark_attn()
- # TODO: Check if this would overwrite previous additional_special_tokens
+ # Note: This might overwrite previous additional_special_tokens
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
if cfg.is_llama_derived_model and cfg.xpos_rope:
@@ -210,7 +196,9 @@ def load_model(
else True,
)
load_in_8bit = False
- elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
+ elif cfg.is_llama_derived_model:
+ from transformers import LlamaForCausalLM
+
config = LlamaConfig.from_pretrained(base_model_config)
model = LlamaForCausalLM.from_pretrained(
base_model,
@@ -314,7 +302,9 @@ def load_model(
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
):
logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
- model = prepare_model_for_kbit_training(model)
+ model = prepare_model_for_kbit_training(
+ model, use_gradient_checkpointing=cfg.gradient_checkpointing
+ )
model, lora_config = load_adapter(model, cfg, adapter)
@@ -387,7 +377,6 @@ def load_llama_adapter(model, cfg):
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
- device_map=cfg.device_map,
torch_dtype=torch.float16,
)
else:
@@ -449,8 +438,7 @@ def load_lora(model, cfg):
model = PeftModel.from_pretrained(
model,
cfg.lora_model_dir,
- device_map=cfg.device_map,
- # torch_dtype=torch.float16,
+ is_trainable=not cfg.inference,
)
else:
model = get_peft_model(model, lora_config)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index b7823fea4..5152e649b 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -245,16 +245,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.is_llama_derived_model and cfg.landmark_attention:
from functools import partial
- from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
+ from axolotl.monkeypatch.llama_landmark_attn import (
+ add_mem_tokens,
+ get_mem_id,
+ set_model_mem_id,
+ )
- mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
- model.set_mem_id(mem_id)
+ set_model_mem_id(model, tokenizer)
logging.info("Adding landmark attention tokens to dataset")
for dataset in [train_dataset, eval_dataset]:
dataset = dataset.map(
- partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
+ partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
batched=False,
num_proc=32,
)
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 603afbfee..298d36c4e 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -59,6 +59,11 @@ def validate_config(cfg):
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
raise ValueError("FSDP is not supported for falcon models")
+ if (
+ cfg.base_model and "mpt" in cfg.base_model.lower()
+ ) and cfg.gradient_checkpointing:
+ raise ValueError("gradient_checkpointing is not supported for MPT models")
+
if cfg.flash_optimum is True:
if cfg.adapter:
logging.warning(
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 575392ab4..dba54586e 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -199,6 +199,20 @@ class ValidationTest(unittest.TestCase):
validate_config(cfg)
+ def test_mpt_gradient_checkpointing(self):
+ regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
+
+ # Check for lower-case
+ cfg = DictDefault(
+ {
+ "base_model": "mosaicml/mpt-7b",
+ "gradient_checkpointing": True,
+ }
+ )
+
+ with pytest.raises(ValueError, match=regex_exp):
+ validate_config(cfg)
+
def test_flash_optimum(self):
cfg = DictDefault(
{