Compare commits
1 Commits
refactor-f
...
fp8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8836986a92 |
1
.github/workflows/tests.yml
vendored
1
.github/workflows/tests.yml
vendored
@@ -71,7 +71,6 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
|
||||||
pip3 uninstall -y transformers accelerate
|
pip3 uninstall -y transformers accelerate
|
||||||
pip3 install -U -e .[flash-attn]
|
pip3 install -U -e .[flash-attn]
|
||||||
pip3 install -r requirements-tests.txt
|
pip3 install -r requirements-tests.txt
|
||||||
|
|||||||
68
README.md
68
README.md
@@ -25,10 +25,8 @@ Features:
|
|||||||
- [Installation](#installation)
|
- [Installation](#installation)
|
||||||
- [Docker](#docker)
|
- [Docker](#docker)
|
||||||
- [Conda/Pip venv](#condapip-venv)
|
- [Conda/Pip venv](#condapip-venv)
|
||||||
- [Runpod](#runpod)
|
|
||||||
- [LambdaLabs](#lambdalabs)
|
- [LambdaLabs](#lambdalabs)
|
||||||
- [Windows](#windows)
|
- [Windows](#windows)
|
||||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
|
||||||
- [Dataset](#dataset)
|
- [Dataset](#dataset)
|
||||||
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
||||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||||
@@ -77,7 +75,6 @@ Features:
|
|||||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
|
||||||
|
|
||||||
|
|
||||||
## Quickstart ⚡
|
## Quickstart ⚡
|
||||||
@@ -86,19 +83,14 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
|||||||
|
|
||||||
**Requirements**: Python >=3.9 and Pytorch >=2.0.
|
**Requirements**: Python >=3.9 and Pytorch >=2.0.
|
||||||
|
|
||||||
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
|
|
||||||
|
|
||||||
### For developers
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging
|
pip3 install packaging
|
||||||
pip3 install -e '.[flash-attn,deepspeed]'
|
pip3 install -e '.[flash-attn,deepspeed]'
|
||||||
```
|
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||||
|
|
||||||
### Usage
|
|
||||||
```bash
|
|
||||||
# finetune lora
|
# finetune lora
|
||||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||||
|
|
||||||
@@ -119,6 +111,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
```bash
|
```bash
|
||||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||||
```
|
```
|
||||||
|
- `winglian/axolotl-runpod:main-latest`: for runpod or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||||
|
|
||||||
Or run on the current files for development:
|
Or run on the current files for development:
|
||||||
|
|
||||||
@@ -133,15 +126,13 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
A more powerful Docker command to run would be this:
|
A more powerful Docker command to run would be this:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
docker run --gpus '"all"' --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||||
```
|
```
|
||||||
|
|
||||||
It additionally:
|
It additionally:
|
||||||
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
|
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
|
||||||
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
|
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
|
||||||
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
|
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
|
||||||
* The `--privileged` flag gives all capabilities to the container.
|
|
||||||
* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
|
|
||||||
|
|
||||||
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
|
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
|
||||||
|
|
||||||
@@ -163,10 +154,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
```
|
```
|
||||||
Get the token at huggingface.co/settings/tokens
|
Get the token at huggingface.co/settings/tokens
|
||||||
|
|
||||||
#### Runpod
|
|
||||||
|
|
||||||
Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
|
||||||
|
|
||||||
#### LambdaLabs
|
#### LambdaLabs
|
||||||
<details>
|
<details>
|
||||||
|
|
||||||
@@ -214,28 +201,6 @@ Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runp
|
|||||||
#### Windows
|
#### Windows
|
||||||
Please use WSL or Docker!
|
Please use WSL or Docker!
|
||||||
|
|
||||||
|
|
||||||
#### Launching on public clouds via SkyPilot
|
|
||||||
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
|
|
||||||
```bash
|
|
||||||
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
|
|
||||||
sky check
|
|
||||||
```
|
|
||||||
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
|
|
||||||
```
|
|
||||||
git clone https://github.com/skypilot-org/skypilot.git
|
|
||||||
cd skypilot/llm/axolotl
|
|
||||||
```
|
|
||||||
Use one command to launch:
|
|
||||||
```bash
|
|
||||||
# On-demand
|
|
||||||
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
|
|
||||||
|
|
||||||
# Managed spot (auto-recovery on preemption)
|
|
||||||
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
### Dataset
|
### Dataset
|
||||||
|
|
||||||
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
|
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
|
||||||
@@ -432,12 +397,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
- path: knowrohit07/know_sql
|
- path: knowrohit07/know_sql
|
||||||
type: context_qa.load_v2
|
type: context_qa.load_v2
|
||||||
train_on_split: validation
|
train_on_split: validation
|
||||||
|
|
||||||
# loading from s3 or gcs
|
|
||||||
# s3 creds will be loaded from the system default and gcs only supports public access
|
|
||||||
dataset:
|
|
||||||
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
|
|
||||||
...
|
|
||||||
```
|
```
|
||||||
|
|
||||||
- loading
|
- loading
|
||||||
@@ -500,15 +459,6 @@ is_falcon_derived_model:
|
|||||||
is_llama_derived_model:
|
is_llama_derived_model:
|
||||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||||
is_mistral_derived_model:
|
is_mistral_derived_model:
|
||||||
is_qwen_derived_model:
|
|
||||||
|
|
||||||
# optional overrides to the base model configuration
|
|
||||||
model_config:
|
|
||||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
|
||||||
rope_scaling:
|
|
||||||
type: # linear | dynamic
|
|
||||||
factor: # float
|
|
||||||
|
|
||||||
|
|
||||||
# Whether you are training a 4-bit GPTQ quantized model
|
# Whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq: true
|
gptq: true
|
||||||
@@ -533,7 +483,7 @@ float16: true
|
|||||||
|
|
||||||
# A list of one or more datasets to finetune the model with
|
# A list of one or more datasets to finetune the model with
|
||||||
datasets:
|
datasets:
|
||||||
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
# HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
|
||||||
- path: vicgalle/alpaca-gpt4
|
- path: vicgalle/alpaca-gpt4
|
||||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||||
@@ -541,12 +491,9 @@ datasets:
|
|||||||
data_files: # Optional[str] path to source data files
|
data_files: # Optional[str] path to source data files
|
||||||
shards: # Optional[int] number of shards to split data into
|
shards: # Optional[int] number of shards to split data into
|
||||||
name: # Optional[str] name of dataset configuration to load
|
name: # Optional[str] name of dataset configuration to load
|
||||||
train_on_split: train # Optional[str] name of dataset split to load from
|
|
||||||
|
|
||||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
field_human: # Optional[str]. Human key to use for conversation.
|
|
||||||
field_model: # Optional[str]. Assistant key to use for conversation.
|
|
||||||
|
|
||||||
# Custom user prompt
|
# Custom user prompt
|
||||||
- path: repo
|
- path: repo
|
||||||
@@ -677,8 +624,7 @@ gradient_accumulation_steps: 1
|
|||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
eval_batch_size:
|
eval_batch_size:
|
||||||
num_epochs: 4
|
num_epochs: 4
|
||||||
warmup_steps: 100 # cannot use with warmup_ratio
|
warmup_steps: 100
|
||||||
warmup_ratio: 0.05 # cannot use with warmup_steps
|
|
||||||
learning_rate: 0.00003
|
learning_rate: 0.00003
|
||||||
lr_quadratic_warmup:
|
lr_quadratic_warmup:
|
||||||
logging_steps:
|
logging_steps:
|
||||||
@@ -780,6 +726,10 @@ landmark_attention:
|
|||||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||||
# LLaMA only
|
# LLaMA only
|
||||||
xpos_rope:
|
xpos_rope:
|
||||||
|
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||||
|
rope_scaling:
|
||||||
|
type: # linear | dynamic
|
||||||
|
factor: # float
|
||||||
|
|
||||||
# Resume from a specific checkpoint dir
|
# Resume from a specific checkpoint dir
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
|
base_model: PY007/TinyLlama-1.1B-step-50K-105b
|
||||||
|
|
||||||
model_type: LlamaForCausalLM
|
model_type: LlamaForCausalLM
|
||||||
tokenizer_type: LlamaTokenizer
|
tokenizer_type: LlamaTokenizer
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
base_model: microsoft/phi-1_5
|
base_model: microsoft/phi-1_5
|
||||||
model_type: PhiForCausalLM
|
model_type: MixFormerSequentialForCausalLM
|
||||||
tokenizer_type: AutoTokenizer
|
tokenizer_type: AutoTokenizer
|
||||||
is_llama_derived_model: false
|
is_llama_derived_model: false
|
||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
base_model: Qwen/Qwen-7B
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
|
|
||||||
is_qwen_derived_model: true
|
|
||||||
trust_remote_code: true
|
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.05
|
|
||||||
output_dir: ./lora-out
|
|
||||||
|
|
||||||
sequence_len: 2048 # supports up to 8192
|
|
||||||
sample_packing: false
|
|
||||||
pad_to_sequence_len:
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_run_id:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16: false
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: false
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention:
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
eval_steps: 0.05
|
|
||||||
eval_table_size:
|
|
||||||
eval_table_max_new_tokens: 128
|
|
||||||
save_steps:
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
base_model: Qwen/Qwen-7B
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
|
|
||||||
is_qwen_derived_model: true
|
|
||||||
trust_remote_code: true
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path:
|
|
||||||
val_set_size: 0.05
|
|
||||||
output_dir: ./lora-out
|
|
||||||
|
|
||||||
sequence_len: 2048 # supports up to 8192
|
|
||||||
sample_packing: false
|
|
||||||
pad_to_sequence_len:
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_fan_in_fan_out:
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_run_id:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16: false
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: false
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention:
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
eval_steps: 0.05
|
|
||||||
eval_table_size:
|
|
||||||
eval_table_max_new_tokens: 128
|
|
||||||
save_steps:
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.0
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,21 +1,22 @@
|
|||||||
|
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
auto-gptq==0.5.1
|
torch==2.0.1
|
||||||
|
auto-gptq==0.4.2
|
||||||
packaging
|
packaging
|
||||||
peft==0.6.0
|
peft==0.6.0
|
||||||
transformers==4.35.2
|
transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
|
||||||
tokenizers==0.15.0
|
|
||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.41.1
|
||||||
accelerate==0.24.1
|
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
|
||||||
deepspeed
|
deepspeed
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
datasets>=2.15.0
|
datasets
|
||||||
flash-attn==2.3.3
|
flash-attn>=2.3.0
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
xformers==0.0.22
|
xformers>=0.0.22
|
||||||
optimum==1.13.2
|
optimum==1.13.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
colorama
|
colorama
|
||||||
@@ -30,10 +31,4 @@ scikit-learn==1.2.2
|
|||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
fschat==0.2.29
|
fschat==0.2.29
|
||||||
gradio==3.50.2
|
gradio
|
||||||
tensorboard
|
|
||||||
|
|
||||||
# remote filesystems
|
|
||||||
s3fs
|
|
||||||
gcsfs
|
|
||||||
# adlfs
|
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
from axolotl.utils.tokenization import check_dataset_labels
|
from axolotl.utils.tokenization import check_dataset_labels
|
||||||
from axolotl.utils.trainer import prepare_optim_env
|
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
@@ -47,7 +46,7 @@ def print_axolotl_text_art(suffix=None):
|
|||||||
ascii_text = " axolotl"
|
ascii_text = " axolotl"
|
||||||
if suffix:
|
if suffix:
|
||||||
ascii_text += f" x {suffix}"
|
ascii_text += f" x {suffix}"
|
||||||
ascii_art = text2art(ascii_text, font=font)
|
ascii_art = text2art(" axolotl", font=font)
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
print(ascii_art)
|
print(ascii_art)
|
||||||
@@ -72,7 +71,7 @@ def do_merge_lora(
|
|||||||
|
|
||||||
LOG.info("running merge of LoRA with base model")
|
LOG.info("running merge of LoRA with base model")
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
model.to(dtype=cfg.torch_dtype)
|
model.to(dtype=torch.float16)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||||
@@ -297,8 +296,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
prepare_optim_env(cfg)
|
|
||||||
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from abc import abstractmethod
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
@@ -31,6 +31,7 @@ from axolotl.utils.callbacks import (
|
|||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
|
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
|
||||||
|
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler
|
from axolotl.utils.samplers import MultipackBatchSampler
|
||||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||||
|
|
||||||
@@ -214,7 +215,9 @@ class AxolotlTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
return super().get_train_dataloader()
|
return super().get_train_dataloader()
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
def get_eval_dataloader(
|
||||||
|
self, eval_dataset: Optional[Dataset] = None
|
||||||
|
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||||
eval_dataset = (
|
eval_dataset = (
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
@@ -257,7 +260,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
def get_bench_dataloader(
|
def get_bench_dataloader(
|
||||||
self,
|
self,
|
||||||
bench_dataset: Dataset,
|
bench_dataset: Dataset,
|
||||||
) -> DataLoader:
|
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||||
dataloader_params = {
|
dataloader_params = {
|
||||||
"batch_size": self.args.eval_batch_size,
|
"batch_size": self.args.eval_batch_size,
|
||||||
"collate_fn": self.bench_data_collator,
|
"collate_fn": self.bench_data_collator,
|
||||||
@@ -461,14 +464,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
warmup_steps = None
|
warmup_steps = (
|
||||||
if self.cfg.warmup_steps is not None:
|
self.cfg.warmup_steps
|
||||||
warmup_steps = self.cfg.warmup_steps
|
if self.cfg.warmup_steps is not None
|
||||||
elif self.cfg.warmup_ratio is not None:
|
else min(int(0.03 * total_num_steps), 100)
|
||||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
)
|
||||||
else:
|
|
||||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
|
||||||
|
|
||||||
logging_steps = (
|
logging_steps = (
|
||||||
self.cfg.logging_steps
|
self.cfg.logging_steps
|
||||||
if self.cfg.logging_steps is not None
|
if self.cfg.logging_steps is not None
|
||||||
@@ -483,6 +483,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["fp16"] = (
|
training_arguments_kwargs["fp16"] = (
|
||||||
self.cfg.fp16 and not self.cfg.bf16
|
self.cfg.fp16 and not self.cfg.bf16
|
||||||
) or False
|
) or False
|
||||||
|
if self.cfg.fp8:
|
||||||
|
training_arguments_kwargs["fp16"] = False
|
||||||
|
training_arguments_kwargs["bf16"] = False
|
||||||
|
|
||||||
training_arguments_kwargs["tf32"] = self.cfg.tf32
|
training_arguments_kwargs["tf32"] = self.cfg.tf32
|
||||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||||
@@ -546,16 +550,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
"dataloader_prefetch_factor"
|
"dataloader_prefetch_factor"
|
||||||
] = self.cfg.dataloader_prefetch_factor
|
] = self.cfg.dataloader_prefetch_factor
|
||||||
|
|
||||||
if self.cfg.val_set_size == 0:
|
if self.cfg.eval_steps:
|
||||||
# no eval set, so don't eval
|
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
|
||||||
elif self.cfg.eval_steps:
|
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
elif self.cfg.evaluation_strategy:
|
elif self.cfg.evaluation_strategy:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"evaluation_strategy"
|
"evaluation_strategy"
|
||||||
] = self.cfg.evaluation_strategy
|
] = self.cfg.evaluation_strategy
|
||||||
|
elif self.cfg.val_set_size == 0:
|
||||||
|
# no eval set, so don't eval
|
||||||
|
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||||
else:
|
else:
|
||||||
# we have an eval set, but no steps defined, default to use epoch
|
# we have an eval set, but no steps defined, default to use epoch
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||||
@@ -661,9 +665,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["eval_sample_packing"] = (
|
training_arguments_kwargs["eval_sample_packing"] = (
|
||||||
self.cfg.sample_packing
|
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||||
if self.cfg.eval_sample_packing is not False
|
|
||||||
else False
|
|
||||||
)
|
)
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"sample_packing_seq_len_multiplier"
|
"sample_packing_seq_len_multiplier"
|
||||||
|
|||||||
@@ -3,6 +3,4 @@ MixFormers model architecture used for phi models
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
|
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
|
||||||
from .configuration_phi import PhiConfig # noqa
|
|
||||||
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
|
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
|
||||||
from .modeling_phi import PhiForCausalLM # noqa
|
|
||||||
|
|||||||
@@ -1,65 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
# Copyright (c) Microsoft Corporation.
|
|
||||||
# Licensed under the MIT license.
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
|
||||||
|
|
||||||
|
|
||||||
class PhiConfig(PretrainedConfig):
|
|
||||||
"""Phi configuration."""
|
|
||||||
|
|
||||||
model_type = "phi"
|
|
||||||
attribute_map = {
|
|
||||||
"max_position_embeddings": "n_positions",
|
|
||||||
"hidden_size": "n_embd",
|
|
||||||
"num_attention_heads": "n_head",
|
|
||||||
"num_hidden_layers": "n_layer",
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vocab_size: int = 50304,
|
|
||||||
n_positions: int = 2048,
|
|
||||||
n_embd: int = 1024,
|
|
||||||
n_layer: int = 20,
|
|
||||||
n_inner: Optional[int] = None,
|
|
||||||
n_head: int = 16,
|
|
||||||
n_head_kv: Optional[int] = None,
|
|
||||||
rotary_dim: Optional[int] = 32,
|
|
||||||
activation_function: Optional[str] = "gelu_new",
|
|
||||||
flash_attn: bool = False,
|
|
||||||
flash_rotary: bool = False,
|
|
||||||
fused_dense: bool = False,
|
|
||||||
attn_pdrop: float = 0.0,
|
|
||||||
embd_pdrop: float = 0.0,
|
|
||||||
resid_pdrop: float = 0.0,
|
|
||||||
layer_norm_epsilon: float = 1e-5,
|
|
||||||
initializer_range: float = 0.02,
|
|
||||||
tie_word_embeddings: bool = False,
|
|
||||||
pad_vocab_size_multiple: int = 64,
|
|
||||||
**kwargs
|
|
||||||
) -> None:
|
|
||||||
self.vocab_size = int(
|
|
||||||
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
|
||||||
)
|
|
||||||
self.n_positions = n_positions
|
|
||||||
self.n_embd = n_embd
|
|
||||||
self.n_layer = n_layer
|
|
||||||
self.n_inner = n_inner
|
|
||||||
self.n_head = n_head
|
|
||||||
self.n_head_kv = n_head_kv
|
|
||||||
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
|
||||||
self.activation_function = activation_function
|
|
||||||
self.flash_attn = flash_attn
|
|
||||||
self.flash_rotary = flash_rotary
|
|
||||||
self.fused_dense = fused_dense
|
|
||||||
self.attn_pdrop = attn_pdrop
|
|
||||||
self.embd_pdrop = embd_pdrop
|
|
||||||
self.resid_pdrop = resid_pdrop
|
|
||||||
self.layer_norm_epsilon = layer_norm_epsilon
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
|
|
||||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,426 +0,0 @@
|
|||||||
import torch
|
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
from einops import rearrange
|
|
||||||
from functools import partial
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
|
||||||
from axolotl.monkeypatch.fused_modules import FusedAttention
|
|
||||||
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
|
||||||
flash_attn_kvpacked_func,
|
|
||||||
flash_attn_varlen_kvpacked_func,
|
|
||||||
flash_attn_varlen_qkvpacked_func,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
from flash_attn.flash_attn_interface import (
|
|
||||||
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
|
|
||||||
)
|
|
||||||
from flash_attn.flash_attn_interface import (
|
|
||||||
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
|
||||||
|
|
||||||
def flashattn_forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
|
||||||
max_seqlen: Optional[torch.Tensor] = None,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
"""Input shape: Batch x Time x Channel
|
|
||||||
|
|
||||||
attention_mask: [bsz, q_len]
|
|
||||||
"""
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
if not hasattr(self, "pretraining_tp"):
|
|
||||||
self.pretraining_tp = 1
|
|
||||||
|
|
||||||
if self.pretraining_tp > 1:
|
|
||||||
key_value_slicing = (
|
|
||||||
self.num_key_value_heads * self.head_dim
|
|
||||||
) // self.pretraining_tp
|
|
||||||
query_slices = self.q_proj.weight.split(
|
|
||||||
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
|
||||||
)
|
|
||||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
|
||||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
|
||||||
|
|
||||||
query_states = [
|
|
||||||
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
|
||||||
]
|
|
||||||
query_states = torch.cat(query_states, dim=-1)
|
|
||||||
|
|
||||||
key_states = [
|
|
||||||
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
|
||||||
]
|
|
||||||
key_states = torch.cat(key_states, dim=-1)
|
|
||||||
|
|
||||||
value_states = [
|
|
||||||
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
|
||||||
]
|
|
||||||
value_states = torch.cat(value_states, dim=-1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
if isinstance(self, FusedAttention):
|
|
||||||
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
|
|
||||||
self.out_features, dim=-1
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
query_states = query_states.view(
|
|
||||||
bsz, q_len, self.num_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
key_states = key_states.view(
|
|
||||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
value_states = value_states.view(
|
|
||||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
|
||||||
).transpose(1, 2)
|
|
||||||
# [bsz, q_len, nh, hd]
|
|
||||||
# [bsz, nh, q_len, hd]
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
|
||||||
if past_key_value is not None:
|
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
query_states, key_states = self.apply_rotary_fn(
|
|
||||||
query_states, key_states, cos, sin, position_ids
|
|
||||||
)
|
|
||||||
# [bsz, nh, t, hd]
|
|
||||||
|
|
||||||
use_sliding_windows = (
|
|
||||||
hasattr(self.config, "sliding_window") is not None
|
|
||||||
and kv_seq_len > self.config.sliding_window
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_sliding_windows:
|
|
||||||
window_size = (self.config.sliding_window, self.config.sliding_window)
|
|
||||||
else:
|
|
||||||
window_size = (-1, -1)
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
|
||||||
if (
|
|
||||||
hasattr(self.config, "sliding_window")
|
|
||||||
and kv_seq_len > self.config.sliding_window
|
|
||||||
):
|
|
||||||
slicing_tokens = kv_seq_len - self.config.sliding_window
|
|
||||||
|
|
||||||
past_key = past_key_value[0]
|
|
||||||
past_value = past_key_value[1]
|
|
||||||
|
|
||||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
|
||||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
|
||||||
|
|
||||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
|
||||||
f" {past_key.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
past_key_value = (past_key, past_value) if use_cache else None
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
|
||||||
key_states = self.repeat_kv_fn(key_states, self.num_key_value_groups)
|
|
||||||
value_states = self.repeat_kv_fn(value_states, self.num_key_value_groups)
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
warnings.warn(
|
|
||||||
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
#
|
|
||||||
# flash-attn v2 start
|
|
||||||
#
|
|
||||||
|
|
||||||
if self.training:
|
|
||||||
# during training q,k,v always have same seqlen
|
|
||||||
assert key_states.shape == query_states.shape
|
|
||||||
is_causal = True
|
|
||||||
else:
|
|
||||||
# turn off FA causal mask after first inference autoregressive iteration
|
|
||||||
# only on first autoregressive step q,k,v have same seqlen
|
|
||||||
is_causal = key_states.shape == query_states.shape
|
|
||||||
|
|
||||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
|
||||||
|
|
||||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
|
||||||
# special handling using sample packing
|
|
||||||
qkv = torch.stack(
|
|
||||||
[query_states, key_states, value_states], dim=2
|
|
||||||
) # [bsz, nh, 3, q_len, hd]
|
|
||||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
|
||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
|
||||||
|
|
||||||
output = flash_attn_varlen_qkvpacked_func(
|
|
||||||
qkv,
|
|
||||||
cu_seqlens,
|
|
||||||
max_seqlen,
|
|
||||||
dropout_p=dropout_rate,
|
|
||||||
softmax_scale=None,
|
|
||||||
causal=True,
|
|
||||||
window_size=window_size,
|
|
||||||
)
|
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
|
||||||
elif query_states.shape == key_states.shape:
|
|
||||||
query_states = query_states.transpose(1, 2)
|
|
||||||
key_states = key_states.transpose(1, 2)
|
|
||||||
value_states = value_states.transpose(1, 2)
|
|
||||||
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
qkvpacked=True,
|
|
||||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
|
||||||
# the attention_mask should be the same as the key_padding_mask
|
|
||||||
key_padding_mask=attention_mask,
|
|
||||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
|
||||||
if attention_mask is not None
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
|
||||||
qkv_unpad,
|
|
||||||
cu_seqlens_q,
|
|
||||||
max_seqlen_q,
|
|
||||||
dropout_p=dropout_rate,
|
|
||||||
softmax_scale=None,
|
|
||||||
causal=is_causal,
|
|
||||||
window_size=window_size,
|
|
||||||
)
|
|
||||||
output = output_pad_fn(output_unpad)
|
|
||||||
else:
|
|
||||||
query_states = query_states.transpose(1, 2)
|
|
||||||
key_states = key_states.transpose(1, 2)
|
|
||||||
value_states = value_states.transpose(1, 2)
|
|
||||||
if attention_mask is None or attention_mask.all().item():
|
|
||||||
output = flash_attn_kvpacked_func(
|
|
||||||
query_states,
|
|
||||||
torch.stack([key_states, value_states], 2),
|
|
||||||
dropout_p=dropout_rate,
|
|
||||||
causal=is_causal,
|
|
||||||
window_size=window_size,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
( # pylint: disable=unbalanced-tuple-unpacking
|
|
||||||
q_unpad,
|
|
||||||
kv_unpad,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
max_seqlen_q,
|
|
||||||
max_seqlen_k,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
output_pad_fn,
|
|
||||||
) = generate_qkv(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
kvpacked=True,
|
|
||||||
key_padding_mask=attention_mask,
|
|
||||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
|
||||||
if attention_mask is not None
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
if q_unpad.dtype != kv_unpad.dtype:
|
|
||||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
|
||||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
|
||||||
q_unpad,
|
|
||||||
kv_unpad,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
max_seqlen_q,
|
|
||||||
max_seqlen_k,
|
|
||||||
dropout_p=dropout_rate,
|
|
||||||
softmax_scale=None,
|
|
||||||
causal=is_causal,
|
|
||||||
window_size=window_size,
|
|
||||||
)
|
|
||||||
output = output_pad_fn(output_unpad)
|
|
||||||
|
|
||||||
attn_output = output
|
|
||||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
|
||||||
raise ValueError(
|
|
||||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
|
||||||
f" {attn_output.size()}"
|
|
||||||
)
|
|
||||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
|
||||||
|
|
||||||
#
|
|
||||||
# flash-attn v2 end
|
|
||||||
#
|
|
||||||
|
|
||||||
if self.pretraining_tp > 1:
|
|
||||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
|
||||||
o_proj_slices = self.o_proj.weight.split(
|
|
||||||
self.hidden_size // self.pretraining_tp, dim=1
|
|
||||||
)
|
|
||||||
attn_output = sum(
|
|
||||||
F.linear(attn_output[i], o_proj_slices[i])
|
|
||||||
for i in range(self.pretraining_tp)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
|
||||||
|
|
||||||
|
|
||||||
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
|
||||||
def generate_qkv(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
query_padding_mask=None,
|
|
||||||
key_padding_mask=None,
|
|
||||||
kvpacked=False,
|
|
||||||
qkvpacked=False,
|
|
||||||
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
|
||||||
"""
|
|
||||||
Arguments:
|
|
||||||
q: (batch_size, seqlen_q, nheads, d)
|
|
||||||
k: (batch_size, seqlen_k, nheads_k, d)
|
|
||||||
v: (batch_size, seqlen_k, nheads_k, d)
|
|
||||||
query_padding_mask: (batch_size, seqlen), bool
|
|
||||||
key_padding_mask: (batch_size, seqlen), bool
|
|
||||||
"""
|
|
||||||
assert not (kvpacked and qkvpacked)
|
|
||||||
batch_size, seqlen_q, nheads, d = q.shape
|
|
||||||
_, seqlen_k, nheads_k, _ = k.shape
|
|
||||||
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
|
||||||
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
|
||||||
|
|
||||||
if query_padding_mask is not None:
|
|
||||||
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
|
||||||
q, query_padding_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
|
|
||||||
output_unpad, indices_q, batch_size, seqlen_q
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
|
||||||
cu_seqlens_q = torch.arange(
|
|
||||||
0,
|
|
||||||
(batch_size + 1) * seqlen_q,
|
|
||||||
step=seqlen_q,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=q_unpad.device,
|
|
||||||
)
|
|
||||||
max_seqlen_q = seqlen_q
|
|
||||||
|
|
||||||
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
|
|
||||||
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if key_padding_mask is not None:
|
|
||||||
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
|
||||||
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
|
||||||
else:
|
|
||||||
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
|
||||||
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
|
||||||
cu_seqlens_k = torch.arange(
|
|
||||||
0,
|
|
||||||
(batch_size + 1) * seqlen_k,
|
|
||||||
step=seqlen_k,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=k_unpad.device,
|
|
||||||
)
|
|
||||||
max_seqlen_k = seqlen_k
|
|
||||||
|
|
||||||
if qkvpacked:
|
|
||||||
assert nheads == nheads_k
|
|
||||||
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
|
||||||
qkv = torch.stack([q, k, v], dim=2)
|
|
||||||
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
|
||||||
|
|
||||||
if kvpacked:
|
|
||||||
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
|
||||||
kv = torch.stack([k, v], dim=2)
|
|
||||||
return (
|
|
||||||
q_unpad,
|
|
||||||
kv_unpad,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
max_seqlen_q,
|
|
||||||
max_seqlen_k,
|
|
||||||
q,
|
|
||||||
kv,
|
|
||||||
output_pad_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
|
||||||
q_unpad,
|
|
||||||
k_unpad,
|
|
||||||
v_unpad,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
max_seqlen_q,
|
|
||||||
max_seqlen_k,
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
output_pad_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
def replace_cross_entropy(modeling_class, module_name):
|
|
||||||
"""
|
|
||||||
modeling_class: transformers.models.llama.modeling_<class>
|
|
||||||
module_name: CrossEntropyLoss
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
|
||||||
|
|
||||||
cross_entropy_loss = partial(
|
|
||||||
CrossEntropyLoss, inplace_backward=True
|
|
||||||
)
|
|
||||||
|
|
||||||
setattr(modeling_class, module_name, cross_entropy_loss)
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
LOG.info(
|
|
||||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def replace_rms_norm(modeling_class, module_name):
|
|
||||||
"""
|
|
||||||
modeling_class: transformers.models.llama.modeling_<class>
|
|
||||||
module_name: RMSNorm
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from flash_attn.ops.rms_norm import RMSNorm
|
|
||||||
|
|
||||||
class FlashRMSNorm(RMSNorm):
|
|
||||||
"""A faster RMS Norm."""
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
|
||||||
super().__init__(hidden_size, eps=eps)
|
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
|
||||||
setattr(modeling_class, module_name, FlashRMSNorm)
|
|
||||||
except ImportError:
|
|
||||||
LOG.info(
|
|
||||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
|
||||||
)
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
import torch
|
|
||||||
from typing import List
|
|
||||||
from xformers.ops import SwiGLU
|
|
||||||
from axolotl.monkeypatch.utils import set_module_name
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
|
||||||
LlamaAttention,
|
|
||||||
LlamaMLP,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: Generalize to other attention modules
|
|
||||||
class FusedAttention(LlamaAttention):
|
|
||||||
"""
|
|
||||||
Fused QKV Attention layer for incrementally improved training efficiency
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config,
|
|
||||||
q: torch.nn.Linear, # pylint: disable=invalid-name
|
|
||||||
k: torch.nn.Linear, # pylint: disable=invalid-name
|
|
||||||
v: torch.nn.Linear, # pylint: disable=invalid-name
|
|
||||||
o: torch.nn.Linear, # pylint: disable=invalid-name
|
|
||||||
):
|
|
||||||
super().__init__(config)
|
|
||||||
self.config = config
|
|
||||||
self.init_device = next(iter(q.state_dict().values())).device
|
|
||||||
|
|
||||||
# define equivalent fused qkv projection
|
|
||||||
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
|
|
||||||
self.qkv_proj = torch.nn.Linear(
|
|
||||||
q.in_features, sum(self.out_features), device=self.init_device, bias=False
|
|
||||||
)
|
|
||||||
self.o_proj = o
|
|
||||||
|
|
||||||
# overwrite initialized weights with pretrained weights
|
|
||||||
self.qkv_proj.weight.data = torch.cat(
|
|
||||||
(q.weight.data, k.weight.data, v.weight.data), dim=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def _post_training(self, model, name):
|
|
||||||
q_proj, k_proj, v_proj = torch.split(
|
|
||||||
self.qkv_proj.weight.data, self.out_features, dim=0
|
|
||||||
)
|
|
||||||
|
|
||||||
new_attn = LlamaAttention(self.config)
|
|
||||||
new_attn.q_proj.weight.data = q_proj
|
|
||||||
new_attn.k_proj.weight.data = k_proj
|
|
||||||
new_attn.v_proj.weight.data = v_proj
|
|
||||||
new_attn.o_proj.weight.data = self.o_proj.weight.data
|
|
||||||
|
|
||||||
set_module_name(model, name, new_attn)
|
|
||||||
|
|
||||||
|
|
||||||
class FusedMLP(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
Fused MLP layer for incrementally improved training efficiency
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config,
|
|
||||||
gate_proj: torch.nn.Linear,
|
|
||||||
up_proj: torch.nn.Linear,
|
|
||||||
down_proj: torch.nn.Linear,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.swiglu = SwiGLU(
|
|
||||||
in_features=config.hidden_size,
|
|
||||||
hidden_features=config.intermediate_size,
|
|
||||||
bias=False,
|
|
||||||
_pack_weights=True,
|
|
||||||
)
|
|
||||||
# overwrite initialized weights with pretrained weights
|
|
||||||
self.swiglu.w12.weight.data = torch.cat(
|
|
||||||
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
|
||||||
)
|
|
||||||
self.swiglu.w3.weight.data = down_proj.weight.data
|
|
||||||
|
|
||||||
def _post_training(self, model, name):
|
|
||||||
w1, w2 = torch.split( # pylint: disable=invalid-name
|
|
||||||
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assign the split weights back to the original layers
|
|
||||||
new_mlp = LlamaMLP(self.config)
|
|
||||||
new_mlp.gate_proj.weight.data = w1
|
|
||||||
new_mlp.up_proj.weight.data = w2
|
|
||||||
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
|
||||||
|
|
||||||
set_module_name(model, name, new_mlp)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
|
||||||
return self.swiglu(x)
|
|
||||||
@@ -3,10 +3,15 @@
|
|||||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import warnings
|
||||||
|
from functools import partial
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import transformers
|
import transformers
|
||||||
|
from einops import rearrange
|
||||||
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
@@ -14,20 +19,27 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
)
|
)
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaMLP,
|
LlamaMLP,
|
||||||
)
|
|
||||||
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
|
||||||
apply_rotary_pos_emb,
|
apply_rotary_pos_emb,
|
||||||
repeat_kv,
|
repeat_kv,
|
||||||
)
|
)
|
||||||
|
from xformers.ops import SwiGLU
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||||
from axolotl.monkeypatch.fused_modules import FusedAttention, FusedMLP
|
|
||||||
from axolotl.monkeypatch.flash_modules import (
|
try:
|
||||||
flashattn_forward,
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
replace_cross_entropy,
|
flash_attn_kvpacked_func,
|
||||||
replace_rms_norm
|
flash_attn_varlen_kvpacked_func,
|
||||||
)
|
flash_attn_varlen_qkvpacked_func,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
from flash_attn.flash_attn_interface import (
|
||||||
|
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
|
||||||
|
)
|
||||||
|
from flash_attn.flash_attn_interface import (
|
||||||
|
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -63,17 +75,129 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
_prepare_decoder_attention_mask
|
_prepare_decoder_attention_mask
|
||||||
)
|
)
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.apply_rotary_fn = apply_rotary_pos_emb
|
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.repeat_kv_fn = repeat_kv
|
|
||||||
if packed:
|
if packed:
|
||||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
||||||
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
||||||
llama_model_forward
|
llama_model_forward
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# skip only if explicitly disabled
|
||||||
if cross_entropy:
|
if cross_entropy:
|
||||||
replace_cross_entropy(transformers.models.llama.modeling_llama, "CrossEntropyLoss")
|
try:
|
||||||
|
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||||
|
|
||||||
|
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||||
|
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||||
|
CrossEntropyLoss, inplace_backward=True
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
LOG.info(
|
||||||
|
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# skip only if explicitly disabled
|
||||||
if rms_norm:
|
if rms_norm:
|
||||||
replace_rms_norm(transformers.models.llama.modeling_llama, "LlamaRMSNorm")
|
try:
|
||||||
|
from flash_attn.ops.rms_norm import RMSNorm
|
||||||
|
|
||||||
|
class LlamaRMSNorm(RMSNorm):
|
||||||
|
"""Patched LLamaRMSNorm"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
super().__init__(hidden_size, eps=eps)
|
||||||
|
|
||||||
|
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||||
|
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||||
|
except ImportError:
|
||||||
|
LOG.info(
|
||||||
|
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedAttention(LlamaAttention):
|
||||||
|
"""
|
||||||
|
Fused QKV Attention layer for incrementally improved training efficiency
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
q: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
k: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
v: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
o: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.init_device = next(iter(q.state_dict().values())).device
|
||||||
|
|
||||||
|
# define equivalent fused qkv projection
|
||||||
|
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
|
||||||
|
self.qkv_proj = torch.nn.Linear(
|
||||||
|
q.in_features, sum(self.out_features), device=self.init_device, bias=False
|
||||||
|
)
|
||||||
|
self.o_proj = o
|
||||||
|
|
||||||
|
# overwrite initialized weights with pretrained weights
|
||||||
|
self.qkv_proj.weight.data = torch.cat(
|
||||||
|
(q.weight.data, k.weight.data, v.weight.data), dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
def _post_training(self, model, name):
|
||||||
|
q_proj, k_proj, v_proj = torch.split(
|
||||||
|
self.qkv_proj.weight.data, self.out_features, dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
new_attn = LlamaAttention(self.config)
|
||||||
|
new_attn.q_proj.weight.data = q_proj
|
||||||
|
new_attn.k_proj.weight.data = k_proj
|
||||||
|
new_attn.v_proj.weight.data = v_proj
|
||||||
|
new_attn.o_proj.weight.data = self.o_proj.weight.data
|
||||||
|
|
||||||
|
set_module_name(model, name, new_attn)
|
||||||
|
|
||||||
|
|
||||||
|
class FusedMLP(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Fused MLP layer for incrementally improved training efficiency
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
gate_proj: torch.nn.Linear,
|
||||||
|
up_proj: torch.nn.Linear,
|
||||||
|
down_proj: torch.nn.Linear,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.swiglu = SwiGLU(
|
||||||
|
in_features=config.hidden_size,
|
||||||
|
hidden_features=config.intermediate_size,
|
||||||
|
bias=False,
|
||||||
|
_pack_weights=True,
|
||||||
|
)
|
||||||
|
# overwrite initialized weights with pretrained weights
|
||||||
|
self.swiglu.w12.weight.data = torch.cat(
|
||||||
|
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||||
|
)
|
||||||
|
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||||
|
|
||||||
|
def _post_training(self, model, name):
|
||||||
|
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||||
|
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign the split weights back to the original layers
|
||||||
|
new_mlp = LlamaMLP(self.config)
|
||||||
|
new_mlp.gate_proj.weight.data = w1
|
||||||
|
new_mlp.up_proj.weight.data = w2
|
||||||
|
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||||
|
|
||||||
|
set_module_name(model, name, new_mlp)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||||
|
return self.swiglu(x)
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
@@ -89,6 +213,322 @@ def _prepare_decoder_attention_mask(
|
|||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
def flashattn_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""Input shape: Batch x Time x Channel
|
||||||
|
|
||||||
|
attention_mask: [bsz, q_len]
|
||||||
|
"""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
if not hasattr(self, "pretraining_tp"):
|
||||||
|
self.pretraining_tp = 1
|
||||||
|
|
||||||
|
if self.pretraining_tp > 1:
|
||||||
|
key_value_slicing = (
|
||||||
|
self.num_key_value_heads * self.head_dim
|
||||||
|
) // self.pretraining_tp
|
||||||
|
query_slices = self.q_proj.weight.split(
|
||||||
|
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||||
|
)
|
||||||
|
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
|
||||||
|
query_states = [
|
||||||
|
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||||
|
]
|
||||||
|
query_states = torch.cat(query_states, dim=-1)
|
||||||
|
|
||||||
|
key_states = [
|
||||||
|
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||||
|
]
|
||||||
|
key_states = torch.cat(key_states, dim=-1)
|
||||||
|
|
||||||
|
value_states = [
|
||||||
|
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||||
|
]
|
||||||
|
value_states = torch.cat(value_states, dim=-1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if isinstance(self, FusedAttention):
|
||||||
|
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
|
||||||
|
self.out_features, dim=-1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
# [bsz, q_len, nh, hd]
|
||||||
|
# [bsz, nh, q_len, hd]
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
warnings.warn(
|
||||||
|
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
#
|
||||||
|
# flash-attn v2 start
|
||||||
|
#
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
# during training q,k,v always have same seqlen
|
||||||
|
assert key_states.shape == query_states.shape
|
||||||
|
is_causal = True
|
||||||
|
else:
|
||||||
|
# turn off FA causal mask after first inference autoregressive iteration
|
||||||
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
|
is_causal = key_states.shape == query_states.shape
|
||||||
|
|
||||||
|
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||||
|
# special handling using sample packing
|
||||||
|
qkv = torch.stack(
|
||||||
|
[query_states, key_states, value_states], dim=2
|
||||||
|
) # [bsz, nh, 3, q_len, hd]
|
||||||
|
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||||
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||||
|
)
|
||||||
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
|
elif query_states.shape == key_states.shape:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
qkvpacked=True,
|
||||||
|
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||||
|
# the attention_mask should be the same as the key_padding_mask
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
max_seqlen_q,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
else:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
if attention_mask is None or attention_mask.all().item():
|
||||||
|
output = flash_attn_kvpacked_func(
|
||||||
|
query_states,
|
||||||
|
torch.stack([key_states, value_states], 2),
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
( # pylint: disable=unbalanced-tuple-unpacking
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
output_pad_fn,
|
||||||
|
) = generate_qkv(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
kvpacked=True,
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
if q_unpad.dtype != kv_unpad.dtype:
|
||||||
|
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||||
|
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
|
||||||
|
attn_output = output
|
||||||
|
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||||
|
|
||||||
|
#
|
||||||
|
# flash-attn v2 end
|
||||||
|
#
|
||||||
|
|
||||||
|
if self.pretraining_tp > 1:
|
||||||
|
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||||
|
o_proj_slices = self.o_proj.weight.split(
|
||||||
|
self.hidden_size // self.pretraining_tp, dim=1
|
||||||
|
)
|
||||||
|
attn_output = sum(
|
||||||
|
F.linear(attn_output[i], o_proj_slices[i])
|
||||||
|
for i in range(self.pretraining_tp)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||||
|
def generate_qkv(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
query_padding_mask=None,
|
||||||
|
key_padding_mask=None,
|
||||||
|
kvpacked=False,
|
||||||
|
qkvpacked=False,
|
||||||
|
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
q: (batch_size, seqlen_q, nheads, d)
|
||||||
|
k: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
v: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
query_padding_mask: (batch_size, seqlen), bool
|
||||||
|
key_padding_mask: (batch_size, seqlen), bool
|
||||||
|
"""
|
||||||
|
assert not (kvpacked and qkvpacked)
|
||||||
|
batch_size, seqlen_q, nheads, d = q.shape
|
||||||
|
_, seqlen_k, nheads_k, _ = k.shape
|
||||||
|
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
|
||||||
|
if query_padding_mask is not None:
|
||||||
|
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||||
|
q, query_padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
|
||||||
|
output_unpad, indices_q, batch_size, seqlen_q
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_q,
|
||||||
|
step=seqlen_q,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=q_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_q = seqlen_q
|
||||||
|
|
||||||
|
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
|
||||||
|
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||||
|
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||||
|
else:
|
||||||
|
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||||
|
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_k = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_k,
|
||||||
|
step=seqlen_k,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=k_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_k = seqlen_k
|
||||||
|
|
||||||
|
if qkvpacked:
|
||||||
|
assert nheads == nheads_k
|
||||||
|
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||||
|
qkv = torch.stack([q, k, v], dim=2)
|
||||||
|
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||||
|
|
||||||
|
if kvpacked:
|
||||||
|
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||||
|
kv = torch.stack([k, v], dim=2)
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
kv,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
k_unpad,
|
||||||
|
v_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def llama_model_forward(
|
def llama_model_forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -25,8 +25,6 @@ def sdp_attention_forward(
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|||||||
@@ -29,8 +29,6 @@ def xformers_forward(
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|||||||
@@ -6,37 +6,29 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
from einops import rearrange
|
||||||
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
|
flash_attn_kvpacked_func,
|
||||||
|
flash_attn_varlen_kvpacked_func,
|
||||||
|
flash_attn_varlen_qkvpacked_func,
|
||||||
|
)
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
|
MistralAttention as OriginalMistralAttention,
|
||||||
|
)
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||||
MistralMLP
|
|
||||||
)
|
)
|
||||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
from axolotl.monkeypatch.flash_modules import (
|
|
||||||
flashattn_forward,
|
|
||||||
replace_cross_entropy,
|
|
||||||
replace_rms_norm
|
|
||||||
)
|
|
||||||
from axolotl.monkeypatch.fused_modules import FusedMLP
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||||
|
|
||||||
def replace_mistral_mlp_with_swiglu(model):
|
|
||||||
for name, module in model.named_modules():
|
|
||||||
if isinstance(module, MistralMLP):
|
|
||||||
mlp = FusedMLP(
|
|
||||||
module.config, module.gate_proj, module.up_proj, module.down_proj
|
|
||||||
)
|
|
||||||
set_module_name(model, name, mlp)
|
|
||||||
|
|
||||||
|
|
||||||
def replace_mistral_attn_with_flash_attn(
|
def replace_mistral_attn_with_flash_attn(
|
||||||
packed: Optional[bool] = False,
|
packed: Optional[bool] = False,
|
||||||
cross_entropy: Optional[bool] = False,
|
|
||||||
rms_norm: Optional[bool] = False,
|
|
||||||
):
|
):
|
||||||
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||||
_prepare_decoder_attention_mask
|
_prepare_decoder_attention_mask
|
||||||
@@ -44,8 +36,6 @@ def replace_mistral_attn_with_flash_attn(
|
|||||||
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||||
flashattn_forward
|
flashattn_forward
|
||||||
)
|
)
|
||||||
transformers.models.mistral.modeling_mistral.MistralAttention.apply_rotary_fn = apply_rotary_pos_emb
|
|
||||||
transformers.models.mistral.modeling_mistral.MistralAttention.repeat_kv_fn = repeat_kv
|
|
||||||
if packed:
|
if packed:
|
||||||
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||||
MistralDecoderLayer
|
MistralDecoderLayer
|
||||||
@@ -53,10 +43,6 @@ def replace_mistral_attn_with_flash_attn(
|
|||||||
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
|
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
|
||||||
mistral_model_forward
|
mistral_model_forward
|
||||||
)
|
)
|
||||||
if cross_entropy:
|
|
||||||
replace_cross_entropy(transformers.mistral.llama.modeling_mistral, "CrossEntropyLoss")
|
|
||||||
if rms_norm:
|
|
||||||
replace_rms_norm(transformers.mistral.llama.modeling_mistral, "MistralRMSNorm")
|
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
@@ -129,6 +115,299 @@ def _prepare_decoder_attention_mask(
|
|||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
def flashattn_forward(
|
||||||
|
self: OriginalMistralAttention,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
use_sliding_windows = (
|
||||||
|
hasattr(self.config, "sliding_window") is not None
|
||||||
|
and kv_seq_len > self.config.sliding_window
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_sliding_windows:
|
||||||
|
window_size = (self.config.sliding_window, self.config.sliding_window)
|
||||||
|
else:
|
||||||
|
window_size = (-1, -1)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||||
|
if (
|
||||||
|
hasattr(self.config, "sliding_window")
|
||||||
|
and kv_seq_len > self.config.sliding_window
|
||||||
|
):
|
||||||
|
slicing_tokens = kv_seq_len - self.config.sliding_window
|
||||||
|
|
||||||
|
past_key = past_key_value[0]
|
||||||
|
past_value = past_key_value[1]
|
||||||
|
|
||||||
|
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||||
|
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||||
|
|
||||||
|
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||||
|
f" {past_key.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
past_key_value = (past_key, past_value) if use_cache else None
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
# during training q,k,v always have same seqlen
|
||||||
|
assert key_states.shape == query_states.shape
|
||||||
|
is_causal = True
|
||||||
|
else:
|
||||||
|
# turn off FA causal mask after first inference autoregressive iteration
|
||||||
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
|
is_causal = key_states.shape == query_states.shape
|
||||||
|
|
||||||
|
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||||
|
# special handling using sample packing
|
||||||
|
qkv = torch.stack(
|
||||||
|
[query_states, key_states, value_states], dim=2
|
||||||
|
) # [bsz, nh, 3, q_len, hd]
|
||||||
|
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||||
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=True,
|
||||||
|
window_size=window_size,
|
||||||
|
)
|
||||||
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
|
elif query_states.shape == key_states.shape:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
qkvpacked=True,
|
||||||
|
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||||
|
# the attention_mask should be the same as the key_padding_mask
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
max_seqlen_q,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
window_size=window_size,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
else:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
if attention_mask is None or attention_mask.all().item():
|
||||||
|
output = flash_attn_kvpacked_func(
|
||||||
|
query_states,
|
||||||
|
torch.stack([key_states, value_states], 2),
|
||||||
|
causal=is_causal,
|
||||||
|
window_size=window_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
( # pylint: disable=unbalanced-tuple-unpacking
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
output_pad_fn,
|
||||||
|
) = generate_qkv(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
kvpacked=True,
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
if q_unpad.dtype != kv_unpad.dtype:
|
||||||
|
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||||
|
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
window_size=window_size,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
|
||||||
|
attn_output = output
|
||||||
|
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||||
|
def generate_qkv(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
query_padding_mask=None,
|
||||||
|
key_padding_mask=None,
|
||||||
|
kvpacked=False,
|
||||||
|
qkvpacked=False,
|
||||||
|
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
q: (batch_size, seqlen_q, nheads, d)
|
||||||
|
k: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
v: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
query_padding_mask: (batch_size, seqlen), bool
|
||||||
|
key_padding_mask: (batch_size, seqlen), bool
|
||||||
|
"""
|
||||||
|
assert not (kvpacked and qkvpacked)
|
||||||
|
batch_size, seqlen_q, nheads, d = q.shape
|
||||||
|
_, seqlen_k, nheads_k, _ = k.shape
|
||||||
|
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
|
||||||
|
if query_padding_mask is not None:
|
||||||
|
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||||
|
q, query_padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
|
||||||
|
output_unpad, indices_q, batch_size, seqlen_q
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_q,
|
||||||
|
step=seqlen_q,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=q_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_q = seqlen_q
|
||||||
|
|
||||||
|
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
|
||||||
|
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||||
|
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||||
|
else:
|
||||||
|
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||||
|
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_k = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_k,
|
||||||
|
step=seqlen_k,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=k_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_k = seqlen_k
|
||||||
|
|
||||||
|
if qkvpacked:
|
||||||
|
assert nheads == nheads_k
|
||||||
|
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||||
|
qkv = torch.stack([q, k, v], dim=2)
|
||||||
|
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||||
|
|
||||||
|
if kvpacked:
|
||||||
|
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||||
|
kv = torch.stack([k, v], dim=2)
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
kv,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
k_unpad,
|
||||||
|
v_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def mistral_model_forward(
|
def mistral_model_forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
|
|||||||
@@ -22,13 +22,7 @@ class PromptStyle(Enum):
|
|||||||
CHATML = "chatml"
|
CHATML = "chatml"
|
||||||
|
|
||||||
|
|
||||||
class Prompter:
|
class AlpacaPrompter:
|
||||||
"""
|
|
||||||
Base prompter class for all prompters
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class AlpacaPrompter(Prompter):
|
|
||||||
"""
|
"""
|
||||||
Base class for alpaca prompters
|
Base class for alpaca prompters
|
||||||
"""
|
"""
|
||||||
@@ -75,7 +69,7 @@ class AlpacaPrompter(Prompter):
|
|||||||
else:
|
else:
|
||||||
res = (
|
res = (
|
||||||
self.system_format.format(system=self.system_no_input_prompt)
|
self.system_format.format(system=self.system_no_input_prompt)
|
||||||
if self.system_no_input_prompt
|
if self.system_prompt
|
||||||
else ""
|
else ""
|
||||||
) + self.turn_no_input_format.format(instruction=instruction)
|
) + self.turn_no_input_format.format(instruction=instruction)
|
||||||
if output:
|
if output:
|
||||||
@@ -165,7 +159,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class ReflectAlpacaPrompter(Prompter):
|
class ReflectAlpacaPrompter:
|
||||||
"""
|
"""
|
||||||
Prompter for ReflectAlpaca
|
Prompter for ReflectAlpaca
|
||||||
"""
|
"""
|
||||||
@@ -260,7 +254,7 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||||
"""
|
"""
|
||||||
A prompter that generates prompts for the ShareGPT
|
A prompter that generates prompts for the ShareGPT
|
||||||
"""
|
"""
|
||||||
@@ -355,7 +349,7 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedPrompter(Prompter):
|
class UnsupportedPrompter:
|
||||||
"""
|
"""
|
||||||
A dummy class for custom prompters
|
A dummy class for custom prompters
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -70,7 +70,9 @@ def normalize_config(cfg):
|
|||||||
else:
|
else:
|
||||||
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
|
||||||
|
|
||||||
if cfg.bf16 or cfg.bfloat16:
|
if cfg.fp8:
|
||||||
|
cfg.torch_dtype = torch.bfloat16
|
||||||
|
elif cfg.bf16 or cfg.bfloat16:
|
||||||
cfg.torch_dtype = torch.bfloat16
|
cfg.torch_dtype = torch.bfloat16
|
||||||
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
||||||
cfg.torch_dtype = torch.float16
|
cfg.torch_dtype = torch.float16
|
||||||
@@ -122,19 +124,6 @@ def normalize_config(cfg):
|
|||||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg.is_qwen_derived_model = (
|
|
||||||
(
|
|
||||||
hasattr(model_config, "model_type")
|
|
||||||
and model_config.model_type
|
|
||||||
in [
|
|
||||||
"qwen",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
or cfg.is_qwen_derived_model
|
|
||||||
or "qwen" in cfg.base_model.lower()
|
|
||||||
or (cfg.model_type and "qwen" in cfg.model_type.lower())
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(cfg.learning_rate, str):
|
if isinstance(cfg.learning_rate, str):
|
||||||
cfg.learning_rate = float(cfg.learning_rate)
|
cfg.learning_rate = float(cfg.learning_rate)
|
||||||
|
|
||||||
@@ -178,11 +167,7 @@ def validate_config(cfg):
|
|||||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||||
)
|
)
|
||||||
if (
|
if cfg.eval_batch_size != cfg.micro_batch_size:
|
||||||
cfg.eval_batch_size
|
|
||||||
and cfg.micro_batch_size
|
|
||||||
and cfg.eval_batch_size != cfg.micro_batch_size
|
|
||||||
):
|
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||||
)
|
)
|
||||||
@@ -386,17 +371,6 @@ def validate_config(cfg):
|
|||||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.rope_scaling:
|
|
||||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
|
||||||
|
|
||||||
if cfg.warmup_steps and cfg.warmup_ratio:
|
|
||||||
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
|
||||||
|
|
||||||
if cfg.is_qwen_derived_model and cfg.gradient_checkpointing:
|
|
||||||
LOG.warning(
|
|
||||||
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import functools
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import (
|
from datasets import (
|
||||||
@@ -34,7 +34,6 @@ from axolotl.prompters import (
|
|||||||
JeopardyPrompter,
|
JeopardyPrompter,
|
||||||
MultipleChoiceConcisePrompter,
|
MultipleChoiceConcisePrompter,
|
||||||
MultipleChoiceExplainPrompter,
|
MultipleChoiceExplainPrompter,
|
||||||
Prompter,
|
|
||||||
ReflectAlpacaPrompter,
|
ReflectAlpacaPrompter,
|
||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
UnsupportedPrompter,
|
UnsupportedPrompter,
|
||||||
@@ -79,14 +78,6 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||||
cfg, train_dataset, eval_dataset, tokenizer
|
cfg, train_dataset, eval_dataset, tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
|
||||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
|
||||||
if total_eval_steps == 0:
|
|
||||||
raise ValueError(
|
|
||||||
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.max_steps:
|
if cfg.max_steps:
|
||||||
total_num_steps = min(
|
total_num_steps = min(
|
||||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||||
@@ -99,7 +90,7 @@ def prepare_dataset(cfg, tokenizer):
|
|||||||
|
|
||||||
def load_tokenized_prepared_datasets(
|
def load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
) -> DatasetDict:
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = tokenizer.__class__.__name__
|
||||||
ds_hash = str(
|
ds_hash = str(
|
||||||
md5(
|
md5(
|
||||||
@@ -107,12 +98,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
str(cfg.sequence_len)
|
str(cfg.sequence_len)
|
||||||
+ "@"
|
+ "@"
|
||||||
+ "|".join(
|
+ "|".join(
|
||||||
sorted(
|
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||||
[
|
|
||||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
|
|
||||||
for d in cfg.datasets
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
+ "|"
|
+ "|"
|
||||||
+ tokenizer_name
|
+ tokenizer_name
|
||||||
@@ -178,66 +164,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
except (FileNotFoundError, ConnectionError):
|
except (FileNotFoundError, ConnectionError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
ds_from_cloud = False
|
|
||||||
storage_options = {}
|
|
||||||
remote_file_system = None
|
|
||||||
if config_dataset.path.startswith("s3://"):
|
|
||||||
try:
|
|
||||||
import aiobotocore.session # type: ignore
|
|
||||||
import s3fs # type: ignore
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"s3:// paths require aiobotocore and s3fs to be installed"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
# Takes credentials from ~/.aws/credentials for default profile
|
|
||||||
s3_session = aiobotocore.session.AioSession(profile="default")
|
|
||||||
storage_options = {"session": s3_session}
|
|
||||||
remote_file_system = s3fs.S3FileSystem(**storage_options)
|
|
||||||
elif config_dataset.path.startswith(
|
|
||||||
"gs://"
|
|
||||||
) or config_dataset.path.startswith("gcs://"):
|
|
||||||
try:
|
|
||||||
import gcsfs # type: ignore
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"gs:// or gcs:// paths require gcsfs to be installed"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
# gcsfs will use default credentials from the environment else anon
|
|
||||||
# https://gcsfs.readthedocs.io/en/latest/#credentials
|
|
||||||
storage_options = {"token": None}
|
|
||||||
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
|
|
||||||
# TODO: Figure out how to get auth creds passed
|
|
||||||
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
|
|
||||||
# try:
|
|
||||||
# import adlfs
|
|
||||||
# except ImportError as exc:
|
|
||||||
# raise ImportError(
|
|
||||||
# "adl:// or abfs:// paths require adlfs to be installed"
|
|
||||||
# ) from exc
|
|
||||||
|
|
||||||
# # Gen 1
|
|
||||||
# storage_options = {
|
|
||||||
# "tenant_id": TENANT_ID,
|
|
||||||
# "client_id": CLIENT_ID,
|
|
||||||
# "client_secret": CLIENT_SECRET,
|
|
||||||
# }
|
|
||||||
# # Gen 2
|
|
||||||
# storage_options = {
|
|
||||||
# "account_name": ACCOUNT_NAME,
|
|
||||||
# "account_key": ACCOUNT_KEY,
|
|
||||||
# }
|
|
||||||
|
|
||||||
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
|
|
||||||
try:
|
|
||||||
if remote_file_system and remote_file_system.exists(
|
|
||||||
config_dataset.path
|
|
||||||
):
|
|
||||||
ds_from_cloud = True
|
|
||||||
except (FileNotFoundError, ConnectionError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# prefer local dataset, even if hub exists
|
# prefer local dataset, even if hub exists
|
||||||
local_path = Path(config_dataset.path)
|
local_path = Path(config_dataset.path)
|
||||||
if local_path.exists():
|
if local_path.exists():
|
||||||
@@ -251,8 +177,17 @@ def load_tokenized_prepared_datasets(
|
|||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
ds_type = get_ds_type(config_dataset)
|
ds_type = "json"
|
||||||
|
if config_dataset.ds_type:
|
||||||
|
ds_type = config_dataset.ds_type
|
||||||
|
elif ".parquet" in config_dataset.path:
|
||||||
|
ds_type = "parquet"
|
||||||
|
elif ".arrow" in config_dataset.path:
|
||||||
|
ds_type = "arrow"
|
||||||
|
elif ".csv" in config_dataset.path:
|
||||||
|
ds_type = "csv"
|
||||||
|
elif ".txt" in config_dataset.path:
|
||||||
|
ds_type = "text"
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
ds_type,
|
ds_type,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
@@ -272,22 +207,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
data_files=config_dataset.data_files,
|
data_files=config_dataset.data_files,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
elif ds_from_cloud and remote_file_system:
|
|
||||||
if remote_file_system.isdir(config_dataset.path):
|
|
||||||
ds = load_from_disk(
|
|
||||||
config_dataset.path,
|
|
||||||
storage_options=storage_options,
|
|
||||||
)
|
|
||||||
elif remote_file_system.isfile(config_dataset.path):
|
|
||||||
ds_type = get_ds_type(config_dataset)
|
|
||||||
ds = load_dataset(
|
|
||||||
ds_type,
|
|
||||||
name=config_dataset.name,
|
|
||||||
data_files=config_dataset.path,
|
|
||||||
streaming=False,
|
|
||||||
split=None,
|
|
||||||
storage_options=storage_options,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if isinstance(config_dataset.data_files, str):
|
if isinstance(config_dataset.data_files, str):
|
||||||
fp = hf_hub_download(
|
fp = hf_hub_download(
|
||||||
@@ -379,29 +298,11 @@ def load_tokenized_prepared_datasets(
|
|||||||
return dataset, prompters
|
return dataset, prompters
|
||||||
|
|
||||||
|
|
||||||
def get_ds_type(config_dataset: DictDefault):
|
|
||||||
"""
|
|
||||||
Get the dataset type from the path if it's not specified
|
|
||||||
"""
|
|
||||||
ds_type = "json"
|
|
||||||
if config_dataset.ds_type:
|
|
||||||
ds_type = config_dataset.ds_type
|
|
||||||
elif ".parquet" in config_dataset.path:
|
|
||||||
ds_type = "parquet"
|
|
||||||
elif ".arrow" in config_dataset.path:
|
|
||||||
ds_type = "arrow"
|
|
||||||
elif ".csv" in config_dataset.path:
|
|
||||||
ds_type = "csv"
|
|
||||||
elif ".txt" in config_dataset.path:
|
|
||||||
ds_type = "text"
|
|
||||||
return ds_type
|
|
||||||
|
|
||||||
|
|
||||||
def load_prepare_datasets(
|
def load_prepare_datasets(
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
cfg,
|
cfg,
|
||||||
default_dataset_prepared_path,
|
default_dataset_prepared_path,
|
||||||
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
) -> Tuple[Dataset, Dataset, List[Any]]:
|
||||||
max_packed_sequence_len = (
|
max_packed_sequence_len = (
|
||||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||||
)
|
)
|
||||||
@@ -410,7 +311,7 @@ def load_prepare_datasets(
|
|||||||
) # make sure we don't accidentally set it larger than sequence_len
|
) # make sure we don't accidentally set it larger than sequence_len
|
||||||
|
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = tokenizer.__class__.__name__
|
||||||
prompters: List[Prompter] = []
|
prompters = []
|
||||||
if cfg.max_packed_sequence_len is not None:
|
if cfg.max_packed_sequence_len is not None:
|
||||||
# see if we can go ahead and load the stacked dataset
|
# see if we can go ahead and load the stacked dataset
|
||||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||||
@@ -544,13 +445,14 @@ def load_prepare_datasets(
|
|||||||
train_fingerprint = md5(to_hash_train)
|
train_fingerprint = md5(to_hash_train)
|
||||||
test_fingerprint = md5(to_hash_test)
|
test_fingerprint = md5(to_hash_test)
|
||||||
|
|
||||||
dataset = dataset.train_test_split(
|
with zero_first(is_main_process()):
|
||||||
test_size=cfg.val_set_size,
|
dataset = dataset.train_test_split(
|
||||||
shuffle=False,
|
test_size=cfg.val_set_size,
|
||||||
seed=cfg.seed or 42,
|
shuffle=False,
|
||||||
train_new_fingerprint=train_fingerprint,
|
seed=cfg.seed or 42,
|
||||||
test_new_fingerprint=test_fingerprint,
|
train_new_fingerprint=train_fingerprint,
|
||||||
)
|
test_new_fingerprint=test_fingerprint,
|
||||||
|
)
|
||||||
|
|
||||||
train_dataset = dataset["train"]
|
train_dataset = dataset["train"]
|
||||||
eval_dataset = dataset["test"]
|
eval_dataset = dataset["test"]
|
||||||
|
|||||||
342
src/axolotl/utils/dataloader.py
Normal file
342
src/axolotl/utils/dataloader.py
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
# pylint: skip-file
|
||||||
|
import hashlib
|
||||||
|
import itertools
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
from queue import Queue
|
||||||
|
from threading import Thread
|
||||||
|
from typing import Any, Callable, List, Union
|
||||||
|
|
||||||
|
import numba
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import DistributedSampler, Sampler
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.utils.dataloader")
|
||||||
|
|
||||||
|
|
||||||
|
@numba.njit
|
||||||
|
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||||
|
# First-fit-decreasing bin packing
|
||||||
|
# Check if a[] could fit in n bins with capacity c
|
||||||
|
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||||
|
|
||||||
|
a = np.sort(a)[::-1]
|
||||||
|
bins = np.full((n,), c, dtype=a.dtype)
|
||||||
|
for size in a:
|
||||||
|
not_found = True
|
||||||
|
for idx in range(n):
|
||||||
|
if bins[idx] >= size:
|
||||||
|
bins[idx] -= size
|
||||||
|
not_found = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if not_found:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@numba.njit
|
||||||
|
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||||
|
# First-fit-decreasing bin packing (with result return)
|
||||||
|
|
||||||
|
indices = np.argsort(a)[::-1]
|
||||||
|
a = a[indices]
|
||||||
|
|
||||||
|
bins: List[Any] = []
|
||||||
|
bins_result: List[Any] = []
|
||||||
|
for a_id, size in enumerate(a):
|
||||||
|
add_new = True
|
||||||
|
for idx in range(len(bins)):
|
||||||
|
if bins[idx] >= size:
|
||||||
|
bins[idx] -= size
|
||||||
|
bins_result[idx].append(indices[a_id] + start_index)
|
||||||
|
add_new = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if add_new:
|
||||||
|
bins.append(c - size)
|
||||||
|
bins_result.append([indices[a_id] + start_index])
|
||||||
|
|
||||||
|
return bins_result, len(a)
|
||||||
|
|
||||||
|
|
||||||
|
@numba.njit
|
||||||
|
def allocate(
|
||||||
|
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
:param lengths: array of lengths of each sample
|
||||||
|
:param lengths_cumsum: cumulative sum of consecutive lengths
|
||||||
|
:param rank: rank for this process
|
||||||
|
:param c: length of tokens per batch
|
||||||
|
:param n: number of ranks
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# Dynamic batch allocator, similar to Multifit
|
||||||
|
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||||
|
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||||
|
|
||||||
|
s = 0
|
||||||
|
start_index = 0
|
||||||
|
result = []
|
||||||
|
result_totseqs = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# binary search [left, right)
|
||||||
|
left = 1
|
||||||
|
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||||
|
|
||||||
|
while right - left > 1:
|
||||||
|
mid = (left + right) // 2
|
||||||
|
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
||||||
|
left = mid
|
||||||
|
else:
|
||||||
|
right = mid
|
||||||
|
|
||||||
|
# use length left
|
||||||
|
batch, tot_seqs = ffd_with_result(
|
||||||
|
lengths[start_index : start_index + left], c, start_index
|
||||||
|
)
|
||||||
|
if len(batch) < n:
|
||||||
|
break
|
||||||
|
|
||||||
|
start_index += left
|
||||||
|
s = lengths_cumsum[start_index - 1]
|
||||||
|
|
||||||
|
# add local rank
|
||||||
|
result.append(batch[rank])
|
||||||
|
# add total seqs for all ranks
|
||||||
|
result_totseqs.append(tot_seqs)
|
||||||
|
# yield batch[rank], tot_seqs, s, len(result) * c * n
|
||||||
|
return result, result_totseqs, s, len(result) * c * n
|
||||||
|
|
||||||
|
|
||||||
|
def chunk(iterable, n):
|
||||||
|
"""
|
||||||
|
Chunk data into tuples of length n
|
||||||
|
"""
|
||||||
|
# batched('ABCDEFG', 3) --> ABC DEF G
|
||||||
|
if n < 1:
|
||||||
|
raise ValueError("n must be at least one")
|
||||||
|
it = iter(iterable)
|
||||||
|
while batch := tuple(itertools.islice(it, n)):
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
def hash_indices(lst: List[int]) -> str:
|
||||||
|
# Convert the list of integers to a string representation
|
||||||
|
concatenated = ",".join(map(str, lst))
|
||||||
|
|
||||||
|
# Generate the hash
|
||||||
|
sha256 = hashlib.sha256()
|
||||||
|
sha256.update(concatenated.encode())
|
||||||
|
|
||||||
|
return sha256.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class MultipackDistributedDataloader:
|
||||||
|
"""Unpadded data loading using Multipack.
|
||||||
|
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
|
||||||
|
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset: Any,
|
||||||
|
collate_fn: Callable,
|
||||||
|
seq_max_length: int = 2048,
|
||||||
|
batch_size: int = 1,
|
||||||
|
sampler: Union[Sampler, DistributedSampler] = None,
|
||||||
|
packing_efficiency_estimate: float = 1.0,
|
||||||
|
sample_packing_seq_len_multiplier: int = 1,
|
||||||
|
device_count: int = 1,
|
||||||
|
prefetch_max: int = 1000,
|
||||||
|
num_epochs: int = 1,
|
||||||
|
):
|
||||||
|
# Dataset
|
||||||
|
self.dataset = dataset
|
||||||
|
self.lengths = (
|
||||||
|
dataset.data.column("position_ids")
|
||||||
|
.to_pandas()
|
||||||
|
.apply(lambda x: x[-1] + 1)
|
||||||
|
.values
|
||||||
|
)
|
||||||
|
assert isinstance(self.lengths, np.ndarray)
|
||||||
|
assert batch_size % sample_packing_seq_len_multiplier == 0
|
||||||
|
assert batch_size >= sample_packing_seq_len_multiplier
|
||||||
|
self.sampler = sampler
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
|
||||||
|
self.seq_max_length = seq_max_length
|
||||||
|
self.batch_max_length = batch_size * seq_max_length
|
||||||
|
self.collate_fn = collate_fn
|
||||||
|
self.num_epochs = num_epochs
|
||||||
|
|
||||||
|
self.num_replicas = 1
|
||||||
|
self.rank = 0
|
||||||
|
|
||||||
|
# statistics
|
||||||
|
self.eff_total_used = 0
|
||||||
|
self.eff_total_slots = 0
|
||||||
|
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||||
|
self.device_count = device_count
|
||||||
|
|
||||||
|
# maxsize is maximum number of samples in queue
|
||||||
|
self.prefetch_max = prefetch_max
|
||||||
|
self.queue: Queue = Queue(maxsize=prefetch_max)
|
||||||
|
self.thread = None
|
||||||
|
|
||||||
|
def _worker(self):
|
||||||
|
LOG.info(
|
||||||
|
f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
|
||||||
|
)
|
||||||
|
for epoch in range(self.num_epochs):
|
||||||
|
for sample in self._internal_batch_generator():
|
||||||
|
while True:
|
||||||
|
if self.queue.full():
|
||||||
|
time.sleep(1)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
self.queue.put(sample)
|
||||||
|
|
||||||
|
# stop the queue when epoch is done
|
||||||
|
self.queue.put(None)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if hasattr(self.sampler, "set_epoch"):
|
||||||
|
new_epoch = self.sampler.epoch + 1
|
||||||
|
self.sampler.set_epoch(new_epoch)
|
||||||
|
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||||
|
|
||||||
|
if self.thread is None:
|
||||||
|
self.thread = Thread(target=self._worker, daemon=True)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
item = self.queue.get()
|
||||||
|
|
||||||
|
if item is None:
|
||||||
|
break
|
||||||
|
yield item
|
||||||
|
|
||||||
|
def generate_batches(self, set_stats=False):
|
||||||
|
LOG.info("generating packed batches")
|
||||||
|
if self.sampler:
|
||||||
|
indices = [idx for idx in self.sampler]
|
||||||
|
else:
|
||||||
|
indices = range(0, len(self.dataset))
|
||||||
|
|
||||||
|
LOG.info(hash_indices(indices))
|
||||||
|
lengths = self.lengths[indices]
|
||||||
|
lengths_cumsum = np.cumsum(lengths)
|
||||||
|
|
||||||
|
batches, totseqs, total_used, total_slots = allocate(
|
||||||
|
lengths=lengths,
|
||||||
|
lengths_cumsum=lengths_cumsum,
|
||||||
|
rank=self.rank,
|
||||||
|
# c=self.batch_max_length,
|
||||||
|
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
||||||
|
n=self.num_replicas,
|
||||||
|
)
|
||||||
|
|
||||||
|
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||||
|
|
||||||
|
# statistics
|
||||||
|
if set_stats:
|
||||||
|
self.eff_total_used += total_used
|
||||||
|
self.eff_total_slots += total_slots
|
||||||
|
|
||||||
|
return batches, totseqs
|
||||||
|
|
||||||
|
def _internal_batch_generator(self):
|
||||||
|
all_batches, _ = self.generate_batches(set_stats=True)
|
||||||
|
features = self.dataset.features.keys()
|
||||||
|
len_remaining = self._len_est()
|
||||||
|
for batches in chunk(
|
||||||
|
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
|
||||||
|
):
|
||||||
|
chunked_data = []
|
||||||
|
attn_mask_cum_idx = 0
|
||||||
|
for batch in batches:
|
||||||
|
concatenated = {}
|
||||||
|
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||||
|
for feature in features:
|
||||||
|
if feature == "length":
|
||||||
|
continue
|
||||||
|
if feature == "attention_mask":
|
||||||
|
arrays = [
|
||||||
|
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
||||||
|
for idx, item in enumerate(batched_data)
|
||||||
|
if feature in item
|
||||||
|
]
|
||||||
|
attn_mask_cum_idx += len(batched_data)
|
||||||
|
concatenated[feature] = np.concatenate(arrays)
|
||||||
|
else:
|
||||||
|
arrays = [
|
||||||
|
np.array(item[feature])
|
||||||
|
for item in batched_data
|
||||||
|
if feature in item
|
||||||
|
]
|
||||||
|
concatenated[feature] = np.concatenate(arrays)
|
||||||
|
chunked_data.append(concatenated)
|
||||||
|
yield self.collate_fn(chunked_data)
|
||||||
|
len_remaining -= 1
|
||||||
|
if not len_remaining:
|
||||||
|
return
|
||||||
|
# yield a no-op for cases where we don't have any data left to pack
|
||||||
|
for i in range(0, len_remaining):
|
||||||
|
yield self.collate_fn(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"input_ids": [0],
|
||||||
|
"labels": [-100],
|
||||||
|
"attention_mask": [True],
|
||||||
|
"position_ids": [0],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _len_est(self):
|
||||||
|
lengths_sum = np.sum(self.lengths)
|
||||||
|
lengths_sum_per_device = lengths_sum // self.device_count
|
||||||
|
LOG.info(
|
||||||
|
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||||
|
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||||
|
return (
|
||||||
|
math.floor(
|
||||||
|
0.99
|
||||||
|
* lengths_sum_per_device
|
||||||
|
/ self.packing_efficiency_estimate
|
||||||
|
// self.seq_max_length
|
||||||
|
// self.batch_size
|
||||||
|
)
|
||||||
|
- 1
|
||||||
|
)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
|
||||||
|
# the same share of total tokens
|
||||||
|
# if not self.eff_total_used:
|
||||||
|
# batches, _ = self.generate_batches(set_stats=True)
|
||||||
|
# LOG.info(
|
||||||
|
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||||
|
# f"actual packing efficiency: {self.efficiency()}"
|
||||||
|
# )
|
||||||
|
return max(1, self._len_est())
|
||||||
|
|
||||||
|
def len_w_stats(self):
|
||||||
|
if not self.eff_total_used:
|
||||||
|
batches, _ = self.generate_batches(set_stats=True)
|
||||||
|
LOG.info(
|
||||||
|
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||||
|
f"actual packing efficiency: {self.efficiency()}"
|
||||||
|
)
|
||||||
|
return max(1, self._len_est())
|
||||||
|
|
||||||
|
def efficiency(self):
|
||||||
|
return self.eff_total_used / self.eff_total_slots
|
||||||
@@ -17,6 +17,7 @@ from transformers import ( # noqa: F401
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
|
LlamaConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
)
|
)
|
||||||
@@ -31,14 +32,9 @@ LOG = logging.getLogger("axolotl")
|
|||||||
def load_model_config(cfg):
|
def load_model_config(cfg):
|
||||||
model_config_name = cfg.base_model_config or cfg.base_model
|
model_config_name = cfg.base_model_config or cfg.base_model
|
||||||
trust_remote_code = cfg.trust_remote_code is True
|
trust_remote_code = cfg.trust_remote_code is True
|
||||||
model_config = AutoConfig.from_pretrained(
|
return AutoConfig.from_pretrained(
|
||||||
model_config_name, trust_remote_code=trust_remote_code
|
model_config_name, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
if cfg.model_config:
|
|
||||||
for key, val in cfg.model_config.items():
|
|
||||||
setattr(model_config, key, val)
|
|
||||||
|
|
||||||
return model_config
|
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(cfg):
|
||||||
@@ -55,7 +51,7 @@ def load_tokenizer(cfg):
|
|||||||
if cfg.tokenizer_type:
|
if cfg.tokenizer_type:
|
||||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||||
|
|
||||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
|
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||||
tokenizer = tokenizer_cls.from_pretrained(
|
tokenizer = tokenizer_cls.from_pretrained(
|
||||||
tokenizer_config,
|
tokenizer_config,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
@@ -84,18 +80,6 @@ def load_tokenizer(cfg):
|
|||||||
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
||||||
tokenizer.padding_side = "left"
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
# Qwen base only has single token, so we need to set the special tokens
|
|
||||||
if cfg.is_qwen_derived_model:
|
|
||||||
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
|
||||||
for attr_name in token_ids:
|
|
||||||
if getattr(tokenizer, attr_name) is None:
|
|
||||||
setattr(tokenizer, attr_name, tokenizer.eod_id)
|
|
||||||
|
|
||||||
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
|
|
||||||
for attr_name in token_names:
|
|
||||||
if getattr(tokenizer, attr_name) is None:
|
|
||||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
|
||||||
|
|
||||||
if cfg.special_tokens:
|
if cfg.special_tokens:
|
||||||
for k, val in cfg.special_tokens.items():
|
for k, val in cfg.special_tokens.items():
|
||||||
tokenizer.add_special_tokens(
|
tokenizer.add_special_tokens(
|
||||||
@@ -126,6 +110,7 @@ def load_model(
|
|||||||
Load a model for a given configuration and tokenizer.
|
Load a model for a given configuration and tokenizer.
|
||||||
"""
|
"""
|
||||||
base_model = cfg.base_model
|
base_model = cfg.base_model
|
||||||
|
base_model_config = cfg.base_model_config
|
||||||
model_type = cfg.model_type
|
model_type = cfg.model_type
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
|
|
||||||
@@ -193,11 +178,7 @@ def load_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("patching with flash attention")
|
LOG.info("patching with flash attention")
|
||||||
replace_mistral_attn_with_flash_attn(
|
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||||
packed=cfg.sample_packing,
|
|
||||||
cross_entropy=cfg.flash_attn_cross_entropy,
|
|
||||||
rms_norm=cfg.flash_attn_rms_norm,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||||
@@ -257,9 +238,16 @@ def load_model(
|
|||||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import LlamaForCausalLM
|
||||||
|
|
||||||
|
config_kwargs = {}
|
||||||
|
if cfg.rope_scaling:
|
||||||
|
config_kwargs["rope_scaling"] = cfg.rope_scaling
|
||||||
|
config = LlamaConfig.from_pretrained(
|
||||||
|
base_model_config,
|
||||||
|
**config_kwargs,
|
||||||
|
)
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=config,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -278,15 +266,6 @@ def load_model(
|
|||||||
if cfg.flash_attn_fuse_qkv:
|
if cfg.flash_attn_fuse_qkv:
|
||||||
LOG.info("patching with fused QKV")
|
LOG.info("patching with fused QKV")
|
||||||
replace_llama_qkv_with_fused(model)
|
replace_llama_qkv_with_fused(model)
|
||||||
elif cfg.is_mistral_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
|
||||||
if cfg.flash_attention and not inference:
|
|
||||||
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
|
||||||
replace_mistral_mlp_with_swiglu,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.flash_attn_fuse_mlp:
|
|
||||||
LOG.info("patching with SwiGLU")
|
|
||||||
replace_mistral_mlp_with_swiglu(model)
|
|
||||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||||
# This is a WIP, still an issue with the backward pass
|
# This is a WIP, still an issue with the backward pass
|
||||||
# RuntimeError: grad can be implicitly created only for scalar outputs
|
# RuntimeError: grad can be implicitly created only for scalar outputs
|
||||||
@@ -313,10 +292,10 @@ def load_model(
|
|||||||
# device=cfg.device,
|
# device=cfg.device,
|
||||||
# )
|
# )
|
||||||
# model.train() # sets to train instead of eval mode
|
# model.train() # sets to train instead of eval mode
|
||||||
elif model_type == "PhiForCausalLM":
|
elif model_type == "MixFormerSequentialForCausalLM":
|
||||||
from axolotl.models.phi import PhiForCausalLM
|
from axolotl.models.phi import MixFormerSequentialForCausalLM
|
||||||
|
|
||||||
model = PhiForCausalLM.from_pretrained(
|
model = MixFormerSequentialForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
@@ -326,55 +305,66 @@ def load_model(
|
|||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = getattr(transformers, model_type).from_pretrained(
|
model = getattr(transformers, model_type).from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
)
|
||||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||||
# when training starts
|
# when training starts
|
||||||
if (
|
if (
|
||||||
hasattr(model_config, "max_seq_len")
|
hasattr(config, "max_seq_len")
|
||||||
and model_config.max_seq_len
|
and config.max_seq_len
|
||||||
and cfg.sequence_len > model_config.max_seq_len
|
and cfg.sequence_len > config.max_seq_len
|
||||||
):
|
):
|
||||||
model_config.max_seq_len = cfg.sequence_len
|
config.max_seq_len = cfg.sequence_len
|
||||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||||
elif (
|
elif (
|
||||||
hasattr(model_config, "max_sequence_length")
|
hasattr(config, "max_sequence_length")
|
||||||
and model_config.max_sequence_length
|
and config.max_sequence_length
|
||||||
and cfg.sequence_len > model_config.max_sequence_length
|
and cfg.sequence_len > config.max_sequence_length
|
||||||
):
|
):
|
||||||
model_config.max_sequence_length = cfg.sequence_len
|
config.max_sequence_length = cfg.sequence_len
|
||||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=config,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=model_config,
|
config=config,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
except Exception as err: # pylint: disable=broad-exception-caught
|
except Exception as err: # pylint: disable=broad-exception-caught
|
||||||
|
LOG.error(
|
||||||
|
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
||||||
|
)
|
||||||
LOG.exception(err)
|
LOG.exception(err)
|
||||||
raise err
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
embeddings_len = (
|
embeddings_len = (
|
||||||
math.ceil(len(tokenizer) / 32) * 32
|
math.ceil(len(tokenizer) / 32) * 32
|
||||||
@@ -425,22 +415,15 @@ def load_model(
|
|||||||
module.to(torch.float32)
|
module.to(torch.float32)
|
||||||
|
|
||||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||||
skip_prepare_model_for_kbit_training = False
|
|
||||||
|
|
||||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
|
||||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
|
||||||
skip_prepare_model_for_kbit_training = True
|
|
||||||
|
|
||||||
if (cfg.adapter == "lora" and load_in_8bit) or (
|
if (cfg.adapter == "lora" and load_in_8bit) or (
|
||||||
cfg.adapter == "qlora" and cfg.load_in_4bit
|
cfg.adapter == "qlora" and cfg.load_in_4bit
|
||||||
):
|
):
|
||||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
if not skip_prepare_model_for_kbit_training:
|
model = prepare_model_for_kbit_training(
|
||||||
model = prepare_model_for_kbit_training(
|
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
)
|
||||||
)
|
|
||||||
needs_fa2_dtype = True
|
needs_fa2_dtype = True
|
||||||
|
|
||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||||
|
|||||||
@@ -181,16 +181,13 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||||
return max(
|
return (
|
||||||
0,
|
world_size
|
||||||
(
|
* math.floor(
|
||||||
world_size
|
0.99
|
||||||
* math.floor(
|
* lengths_sum_per_device
|
||||||
0.99
|
/ self.packing_efficiency_estimate
|
||||||
* lengths_sum_per_device
|
// self.batch_max_len
|
||||||
/ self.packing_efficiency_estimate
|
)
|
||||||
// self.batch_max_len
|
- 1
|
||||||
)
|
|
||||||
- 1
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -141,35 +141,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
|
|
||||||
|
|
||||||
def calculate_total_num_steps(cfg, train_dataset, update=True):
|
def calculate_total_num_steps(cfg, train_dataset):
|
||||||
if not cfg.total_num_tokens:
|
|
||||||
total_num_tokens = np.sum(
|
|
||||||
train_dataset.data.column("input_ids")
|
|
||||||
.to_pandas()
|
|
||||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
|
||||||
.values
|
|
||||||
)
|
|
||||||
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
|
|
||||||
if update:
|
|
||||||
cfg.total_num_tokens = total_num_tokens
|
|
||||||
|
|
||||||
if not cfg.total_supervised_tokens:
|
|
||||||
total_supervised_tokens = (
|
|
||||||
train_dataset.data.column("labels")
|
|
||||||
.to_pandas()
|
|
||||||
.apply(lambda x: np.sum(np.array(x) != -100))
|
|
||||||
.sum()
|
|
||||||
)
|
|
||||||
LOG.debug(
|
|
||||||
f"`total_supervised_tokens: {total_supervised_tokens}`",
|
|
||||||
main_process_only=True,
|
|
||||||
)
|
|
||||||
if update:
|
|
||||||
cfg.total_supervised_tokens = total_supervised_tokens
|
|
||||||
|
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
# we have to drop anything longer then sequence len otherwise
|
# we have to drop anything longer then sequence len otherwise
|
||||||
# flash attention with position ids fails
|
# flash attention with position ids fails
|
||||||
|
if not cfg.total_num_tokens:
|
||||||
|
total_num_tokens = np.sum(
|
||||||
|
train_dataset.data.column("input_ids")
|
||||||
|
.to_pandas()
|
||||||
|
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
||||||
|
.values
|
||||||
|
)
|
||||||
|
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
|
||||||
|
cfg.total_num_tokens = total_num_tokens
|
||||||
|
|
||||||
|
if not cfg.total_supervised_tokens:
|
||||||
|
total_supervised_tokens = (
|
||||||
|
train_dataset.data.column("labels")
|
||||||
|
.to_pandas()
|
||||||
|
.apply(lambda x: np.sum(np.array(x) != -100))
|
||||||
|
.sum()
|
||||||
|
)
|
||||||
|
LOG.debug(
|
||||||
|
f"`total_supervised_tokens: {total_supervised_tokens}`",
|
||||||
|
main_process_only=True,
|
||||||
|
)
|
||||||
|
cfg.total_supervised_tokens = total_supervised_tokens
|
||||||
|
|
||||||
if cfg.sample_packing_eff_est:
|
if cfg.sample_packing_eff_est:
|
||||||
total_num_steps = (
|
total_num_steps = (
|
||||||
@@ -234,8 +231,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
sample_packing_eff_est = (
|
sample_packing_eff_est = (
|
||||||
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
|
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
|
||||||
)
|
)
|
||||||
if update:
|
cfg.sample_packing_eff_est = sample_packing_eff_est
|
||||||
cfg.sample_packing_eff_est = sample_packing_eff_est
|
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
|
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
|
||||||
main_process_only=True,
|
main_process_only=True,
|
||||||
@@ -267,14 +263,14 @@ def setup_fsdp_envs(cfg):
|
|||||||
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
||||||
|
|
||||||
|
|
||||||
def prepare_optim_env(cfg):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
|
if cfg.fp8:
|
||||||
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
|
||||||
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
||||||
trainer_builder.train_dataset = train_dataset
|
trainer_builder.train_dataset = train_dataset
|
||||||
trainer_builder.eval_dataset = eval_dataset
|
trainer_builder.eval_dataset = eval_dataset
|
||||||
|
|||||||
@@ -101,7 +101,6 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_torch",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"bf16": True,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class TestPhi(unittest.TestCase):
|
|||||||
{
|
{
|
||||||
"base_model": "microsoft/phi-1_5",
|
"base_model": "microsoft/phi-1_5",
|
||||||
"trust_remote_code": True,
|
"trust_remote_code": True,
|
||||||
"model_type": "PhiForCausalLM",
|
"model_type": "MixFormerSequentialForCausalLM",
|
||||||
"tokenizer_type": "AutoTokenizer",
|
"tokenizer_type": "AutoTokenizer",
|
||||||
"sequence_len": 512,
|
"sequence_len": 512,
|
||||||
"sample_packing": False,
|
"sample_packing": False,
|
||||||
@@ -76,7 +76,7 @@ class TestPhi(unittest.TestCase):
|
|||||||
{
|
{
|
||||||
"base_model": "microsoft/phi-1_5",
|
"base_model": "microsoft/phi-1_5",
|
||||||
"trust_remote_code": True,
|
"trust_remote_code": True,
|
||||||
"model_type": "PhiForCausalLM",
|
"model_type": "MixFormerSequentialForCausalLM",
|
||||||
"tokenizer_type": "AutoTokenizer",
|
"tokenizer_type": "AutoTokenizer",
|
||||||
"sequence_len": 512,
|
"sequence_len": 512,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
|
|||||||
@@ -1,95 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E tests for resuming training
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import subprocess
|
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from .utils import most_recent_subdir, with_temp_dir
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
class TestResumeLlama(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
Test case for resuming training of llama models
|
|
||||||
"""
|
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_resume_qlora(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"sample_packing": True,
|
|
||||||
"flash_attention": True,
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"adapter": "qlora",
|
|
||||||
"lora_r": 32,
|
|
||||||
"lora_alpha": 64,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"special_tokens": {},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "vicgalle/alpaca-gpt4",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 2,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"save_steps": 10,
|
|
||||||
"save_total_limit": 5,
|
|
||||||
"max_steps": 40,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if is_torch_bf16_gpu_available():
|
|
||||||
cfg.bf16 = True
|
|
||||||
else:
|
|
||||||
cfg.fp16 = True
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
|
|
||||||
resume_cfg = cfg | DictDefault(
|
|
||||||
{
|
|
||||||
"resume_from_checkpoint": f"{temp_dir}/checkpoint-30/",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
normalize_config(resume_cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
|
|
||||||
train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
|
||||||
|
|
||||||
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
|
|
||||||
cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"
|
|
||||||
res = subprocess.run(
|
|
||||||
cmd, shell=True, text=True, capture_output=True, check=True
|
|
||||||
)
|
|
||||||
pattern = r"first_step\s+(\d+)"
|
|
||||||
first_steps = int(re.findall(pattern, res.stdout)[0])
|
|
||||||
assert first_steps == 31
|
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
helper utils for tests
|
helper utils for tests
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
def with_temp_dir(test_func):
|
def with_temp_dir(test_func):
|
||||||
@@ -21,13 +20,3 @@ def with_temp_dir(test_func):
|
|||||||
shutil.rmtree(temp_dir)
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def most_recent_subdir(path):
|
|
||||||
base_path = Path(path)
|
|
||||||
subdirectories = [d for d in base_path.iterdir() if d.is_dir()]
|
|
||||||
if not subdirectories:
|
|
||||||
return None
|
|
||||||
subdir = max(subdirectories, key=os.path.getctime)
|
|
||||||
|
|
||||||
return subdir
|
|
||||||
|
|||||||
@@ -649,33 +649,3 @@ class ValidationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
def test_warmup_step_no_conflict(self):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"warmup_steps": 10,
|
|
||||||
"warmup_ratio": 0.1,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
|
|
||||||
):
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"warmup_steps": 10,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"warmup_ratio": 0.1,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user