Compare commits
26 Commits
20240307-u
...
4bit-optim
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e6b78c1fca | ||
|
|
a236f5eab5 | ||
|
|
dd449c5cd8 | ||
|
|
40a88e8c4a | ||
|
|
43bdc5d3de | ||
|
|
b1e3e1b25f | ||
|
|
2ea70ebbd8 | ||
|
|
e8c8ea64b3 | ||
|
|
d485a08393 | ||
|
|
f083aed2c7 | ||
|
|
868c33954d | ||
|
|
8df7b888ff | ||
|
|
6366b0c212 | ||
|
|
05bcc9ea56 | ||
|
|
3bd8203c35 | ||
|
|
8b12468230 | ||
|
|
0976781e15 | ||
|
|
8a82d2e0a4 | ||
|
|
4326520829 | ||
|
|
b7d8a7dc4d | ||
|
|
b0ee9ec734 | ||
|
|
0bc114d2e1 | ||
|
|
7659c001aa | ||
|
|
3fd8093717 | ||
|
|
9b6ee83a73 | ||
|
|
638c2dafb5 |
76
README.md
76
README.md
@@ -13,6 +13,9 @@ Features:
|
||||
- Log results and optionally checkpoints to wandb or mlflow
|
||||
- And more!
|
||||
|
||||
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
|
||||
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
|
||||
</a>
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
@@ -25,9 +28,10 @@ Features:
|
||||
- [Environment](#environment)
|
||||
- [Docker](#docker)
|
||||
- [Conda/Pip venv](#condapip-venv)
|
||||
- [Cloud GPU](#cloud-gpu) - Latitude.sh, RunPod
|
||||
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
|
||||
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
|
||||
- [Windows](#windows)
|
||||
- [Mac](#mac)
|
||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||
- [Dataset](#dataset)
|
||||
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
||||
@@ -99,24 +103,14 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
||||
|
||||
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
|
||||
|
||||
### For developers
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging
|
||||
```
|
||||
|
||||
General case:
|
||||
```
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
Mac: see https://github.com/OpenAccess-AI-Collective/axolotl/blob/13199f678b9aab39e92961323bdbce3234ee4b2b/docs/mac.md
|
||||
```
|
||||
pip3 install -e '.'
|
||||
```
|
||||
|
||||
### Usage
|
||||
```bash
|
||||
# preprocess datasets - optional but recommended
|
||||
@@ -199,6 +193,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
||||
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
||||
|
||||
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
|
||||
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
|
||||
#### Bare Metal Cloud GPU
|
||||
@@ -248,9 +243,31 @@ For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud
|
||||
```
|
||||
</details>
|
||||
|
||||
##### GCP
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
|
||||
Use a Deeplearning linux OS with cuda and pytorch installed. Then follow instructions on quickstart.
|
||||
|
||||
Make sure to run the below to uninstall xla.
|
||||
```bash
|
||||
pip uninstall -y torch_xla[tpu]
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### Windows
|
||||
Please use WSL or Docker!
|
||||
|
||||
#### Mac
|
||||
|
||||
Use the below instead of the install method in QuickStart.
|
||||
```
|
||||
pip3 install -e '.'
|
||||
```
|
||||
More info: [mac.md](/docs/mac.md)
|
||||
|
||||
#### Launching on public clouds via SkyPilot
|
||||
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
|
||||
@@ -634,9 +651,13 @@ datasets:
|
||||
train_on_split: train # Optional[str] name of dataset split to load from
|
||||
|
||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
field_human: # Optional[str]. Human key to use for conversation.
|
||||
field_model: # Optional[str]. Assistant key to use for conversation.
|
||||
# Add additional keys from your dataset as input or output roles
|
||||
roles:
|
||||
input: # Optional[List[str]]. These will be masked based on train_on_input
|
||||
output: # Optional[List[str]].
|
||||
|
||||
# Custom user instruction prompt
|
||||
- path: repo
|
||||
@@ -661,6 +682,10 @@ datasets:
|
||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||
field:
|
||||
|
||||
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
|
||||
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
|
||||
shuffle_merged_datasets: true
|
||||
|
||||
# A list of one or more datasets to eval the model with.
|
||||
# You can use either test_datasets, or val_set_size, but not both.
|
||||
test_datasets:
|
||||
@@ -842,7 +867,7 @@ group_by_length: false
|
||||
gradient_checkpointing: false
|
||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||
# gradient_checkpointing_kwargs:
|
||||
# use_reentrant: false
|
||||
# use_reentrant: true
|
||||
|
||||
# Stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
@@ -882,7 +907,26 @@ lr_div_factor: # Learning rate div factor
|
||||
# - paged_adamw_8bit
|
||||
# - paged_lion_32bit
|
||||
# - paged_lion_8bit
|
||||
# - galore_adamw
|
||||
# - galore_adamw_8bit
|
||||
# - galore_adafactor
|
||||
# - galore_adamw_layerwise
|
||||
# - galore_adamw_8bit_layerwise
|
||||
# - galore_adafactor_layerwise
|
||||
optimizer:
|
||||
# Dictionary of arguments to pass to the optimizer
|
||||
optim_args:
|
||||
# For Galore Optimizers the following optim_args are available
|
||||
# rank: # type: int
|
||||
# update_proj_gap # type: int
|
||||
# scale # type: float
|
||||
# proj_type: # type: str, default = std
|
||||
|
||||
# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
|
||||
optim_target_modules:
|
||||
# - self_attn # for llama
|
||||
# - mlp
|
||||
|
||||
# Specify weight decay
|
||||
weight_decay:
|
||||
# adamw hyperparams
|
||||
@@ -1079,6 +1123,10 @@ fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
##### FSDP + QLoRA
|
||||
|
||||
Axolotl supports training with FSDP and QLoRA, see [these docs](docs/fsdp_qlora.md) for more information.
|
||||
|
||||
##### Weights & Biases Logging
|
||||
|
||||
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
||||
@@ -1298,4 +1346,6 @@ consider sponsoring the project via [GitHub Sponsors](https://github.com/sponsor
|
||||
|
||||
#### 🥉 Bronze Sponsors - $500/mo
|
||||
|
||||
- [JarvisLabs.ai](https://jarvislabs.ai)
|
||||
|
||||
---
|
||||
|
||||
@@ -23,9 +23,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
# So we can test the Docker image
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
|
||||
@@ -21,9 +21,9 @@ WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
# So we can test the Docker image
|
||||
|
||||
37
docs/fsdp_qlora.md
Normal file
37
docs/fsdp_qlora.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# FDSP + QLoRA
|
||||
|
||||
## Background
|
||||
|
||||
Using FSDP with QLoRA is essential for **fine-tuning larger (70b+ parameter) LLMs on consumer GPUs.** For example, you can use FSDP + QLoRA to train a 70b model on two 24GB GPUs[^1].
|
||||
|
||||
Below, we describe how to use this feature in Axolotl.
|
||||
|
||||
## Usage
|
||||
|
||||
To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
||||
|
||||
> ![Tip]
|
||||
> See the [example config](#example-config) file in addition to reading these instructions.
|
||||
|
||||
1. Set `adapter: qlora` in your axolotl config file.
|
||||
2. Enable FSDP in your axolotl config, as [described here](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#fsdp).
|
||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||
|
||||
## Example Config
|
||||
|
||||
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
|
||||
|
||||
## References
|
||||
|
||||
- [PR #1378](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378) enabling QLoRA in FSDP in Axolotl.
|
||||
- [Blog Post](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the [Answer.AI](https://www.answer.ai/) team describing the work that enabled QLoRA in FSDP.
|
||||
- Related HuggingFace PRs Enabling FDSP + QLoRA:
|
||||
- Accelerate [PR#2544](https://github.com/huggingface/accelerate/pull/2544 )
|
||||
- Transformers [PR#29587](https://github.com/huggingface/transformers/pull/29587)
|
||||
- TRL [PR#1416](https://github.com/huggingface/trl/pull/1416)
|
||||
- PEFT [PR#1550](https://github.com/huggingface/peft/pull/1550)
|
||||
|
||||
|
||||
|
||||
|
||||
[^1]: This was enabled by [this work](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) from the Answer.AI team.
|
||||
29
docs/optimizers.md
Normal file
29
docs/optimizers.md
Normal file
@@ -0,0 +1,29 @@
|
||||
# Optimizers
|
||||
|
||||
Optimizers are an important component when training LLMs. Optimizers are responsible for updating the model's weights (parameters) based on the gradients computed during backpropagation.
|
||||
The goal of an optimizer is to minimize the loss function.
|
||||
|
||||
### Adam/AdamW Optimizers
|
||||
|
||||
```yaml
|
||||
adam_beta1: 0.9
|
||||
adam_beta2: 0.999
|
||||
adam_epsilon: 1e-8
|
||||
weight_decay: 0.0
|
||||
```
|
||||
|
||||
### GaLore Optimizer
|
||||
|
||||
https://huggingface.co/papers/2403.03507
|
||||
|
||||
```yaml
|
||||
optimizer: galore_adamw | galore_adamw_8bit | galore_adafactor
|
||||
optim_args:
|
||||
rank: 128
|
||||
update_proj_gap: 200
|
||||
scale: 0.25
|
||||
proj_type: std
|
||||
optim_target_modules:
|
||||
- mlp
|
||||
- attn
|
||||
```
|
||||
15
docs/rlhf.md
15
docs/rlhf.md
@@ -34,6 +34,21 @@ datasets:
|
||||
rl: ipo
|
||||
```
|
||||
|
||||
#### ORPO
|
||||
|
||||
Paper: https://arxiv.org/abs/2403.07691
|
||||
|
||||
```yaml
|
||||
rl: orpo
|
||||
orpo_alpha: 0.1
|
||||
remove_unused_columns: false
|
||||
|
||||
chat_template: chatml
|
||||
datasets:
|
||||
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
||||
type: orpo.chat_template
|
||||
```
|
||||
|
||||
#### Using local dataset files
|
||||
```yaml
|
||||
datasets:
|
||||
|
||||
@@ -21,7 +21,7 @@ lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
|
||||
70
examples/llama-2/qlora-fsdp.yml
Normal file
70
examples/llama-2/qlora-fsdp.yml
Normal file
@@ -0,0 +1,70 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: yahma/alpaca-cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 512
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 4
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.00001
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
special_tokens:
|
||||
74
examples/mistral/mixtral-qlora-fsdp.yml
Normal file
74
examples/mistral/mixtral-qlora-fsdp.yml
Normal file
@@ -0,0 +1,74 @@
|
||||
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
output_dir: ./qlora-out
|
||||
|
||||
model_config:
|
||||
output_router_logits: true
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 1024
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
|
||||
special_tokens:
|
||||
@@ -16,12 +16,12 @@ output_dir: ./qlora-out
|
||||
|
||||
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
||||
unfrozen_parameters:
|
||||
# - lm_head.*
|
||||
# - model.embed_tokens.*
|
||||
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
|
||||
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
|
||||
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
||||
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
||||
# - ^lm_head.weight$
|
||||
# - ^model.embed_tokens.weight$[:32000]
|
||||
# - model.layers.2[0-9]+.block_sparse_moe.gate
|
||||
# - model.layers.2[0-9]+.block_sparse_moe.experts
|
||||
# - model.layers.3[0-9]+.block_sparse_moe.gate
|
||||
# - model.layers.3[0-9]+.block_sparse_moe.experts
|
||||
|
||||
model_config:
|
||||
output_router_logits: true
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft==0.9.0
|
||||
transformers==4.38.2
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa
|
||||
tokenizers==0.15.0
|
||||
bitsandbytes>=0.41.1
|
||||
bitsandbytes>=0.43.0
|
||||
accelerate==0.26.1
|
||||
deepspeed==0.13.1
|
||||
pydantic==2.6.3
|
||||
@@ -39,4 +39,8 @@ s3fs
|
||||
gcsfs
|
||||
# adlfs
|
||||
|
||||
trl>=0.7.9
|
||||
trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
|
||||
fastcore>=1.5.29
|
||||
|
||||
lpmm @ git+https://github.com/thu-ml/low-bit-optimizers.git@main
|
||||
yacs
|
||||
|
||||
3
setup.py
3
setup.py
@@ -89,5 +89,8 @@ setup(
|
||||
"lion-pytorch": [
|
||||
"lion-pytorch==0.1.2",
|
||||
],
|
||||
"galore": [
|
||||
"galore_torch",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
if parsed_cfg.rl:
|
||||
if parsed_cfg.rl and parsed_cfg.rl != "orpo":
|
||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
else:
|
||||
register_chatml_template()
|
||||
|
||||
if cfg.rl:
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
0
src/axolotl/core/policies/__init__.py
Normal file
0
src/axolotl/core/policies/__init__.py
Normal file
55
src/axolotl/core/policies/auto_wrap.py
Normal file
55
src/axolotl/core/policies/auto_wrap.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""module for building the auto wrap policy for FSDP"""
|
||||
import functools
|
||||
|
||||
from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
|
||||
from torch.distributed.fsdp.wrap import (
|
||||
_or_policy,
|
||||
lambda_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
||||
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
|
||||
|
||||
SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
|
||||
"llama",
|
||||
"mistral",
|
||||
"mixtral",
|
||||
]
|
||||
|
||||
|
||||
def get_wrapping_policy_factory(model_type):
|
||||
if model_type == "llama":
|
||||
layer_to_wrap = LlamaDecoderLayer
|
||||
elif model_type == "mistral":
|
||||
layer_to_wrap = MistralDecoderLayer
|
||||
elif model_type == "mixtral":
|
||||
layer_to_wrap = MixtralDecoderLayer
|
||||
|
||||
def get_wrapping_policy():
|
||||
"""This checks for lora layers (has weight and requires_grad)"""
|
||||
|
||||
def lambda_policy_fn(module):
|
||||
return (
|
||||
len(list(module.named_children())) == 0
|
||||
and getattr(module, "weight", None) is not None
|
||||
and module.weight.requires_grad
|
||||
)
|
||||
|
||||
lambda_policy = functools.partial(
|
||||
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
|
||||
)
|
||||
transformer_layer_name = layer_to_wrap
|
||||
transformer_wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls=(
|
||||
PrefixEncoder,
|
||||
PromptEncoder,
|
||||
PromptEmbedding,
|
||||
transformer_layer_name,
|
||||
),
|
||||
)
|
||||
policies = [lambda_policy, transformer_wrap_policy]
|
||||
return functools.partial(_or_policy, policies=policies)
|
||||
|
||||
return get_wrapping_policy
|
||||
@@ -8,20 +8,28 @@ import importlib
|
||||
import importlib.util
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Type, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
import lpmm
|
||||
import torch
|
||||
import transformers
|
||||
from accelerate import FullyShardedDataParallelPlugin
|
||||
from accelerate.utils import str_to_bool
|
||||
from datasets import Dataset
|
||||
from torch import nn
|
||||
from torch.distributed.fsdp import MixedPrecision
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import (
|
||||
EarlyStoppingCallback,
|
||||
PreTrainedModel,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
@@ -30,6 +38,8 @@ from transformers.trainer_utils import seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
|
||||
from axolotl.core.trainers import OptimizerNames
|
||||
from axolotl.loraplus import create_loraplus_optimizer
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
@@ -56,6 +66,9 @@ from axolotl.utils.schedulers import (
|
||||
get_cosine_schedule_with_warmup_decay_constant,
|
||||
)
|
||||
|
||||
# monkeypatch so it accepts our custom optimizers
|
||||
transformers.training_args.OptimizerNames = OptimizerNames
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
@@ -191,6 +204,13 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
)
|
||||
qlora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "whether this is a qlora training"},
|
||||
)
|
||||
orpo_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -207,33 +227,115 @@ class AxolotlTrainer(Trainer):
|
||||
num_epochs=1,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self.num_epochs = num_epochs
|
||||
self.bench_data_collator = bench_data_collator
|
||||
self.eval_data_collator = eval_data_collator
|
||||
super().__init__(*_args, **kwargs)
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
@staticmethod
|
||||
def get_optimizer_cls_and_kwargs(
|
||||
args: TrainingArguments, model: Optional[PreTrainedModel] = None
|
||||
) -> Tuple[Any, Any]:
|
||||
optim_args = {}
|
||||
if args.optim_args:
|
||||
for mapping in args.optim_args.replace(" ", "").split(","):
|
||||
key, value = mapping.split("=")
|
||||
optim_args[key] = value
|
||||
|
||||
optimizer_kwargs = {"lr": args.learning_rate}
|
||||
|
||||
adam_kwargs = {
|
||||
"betas": (args.adam_beta1, args.adam_beta2),
|
||||
"eps": args.adam_epsilon,
|
||||
}
|
||||
|
||||
if args.optim in [
|
||||
OptimizerNames.LPMM_ADAMW_4BIT,
|
||||
OptimizerNames.LPMM_ADAMW_4BIT_FUSED,
|
||||
]:
|
||||
optimizer_cls = lpmm.optim.AdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
if args.optim == OptimizerNames.LPMM_ADAMW_4BIT_FUSED:
|
||||
optimizer_kwargs.update({"fused": True})
|
||||
return optimizer_cls, optimizer_kwargs
|
||||
|
||||
return Trainer.get_optimizer_cls_and_kwargs(
|
||||
args,
|
||||
model=model,
|
||||
)
|
||||
|
||||
def create_optimizer(self):
|
||||
if self.args.loraplus_lr_ratio is None:
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
)
|
||||
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in opt_model.named_parameters()
|
||||
if (n in decay_parameters and p.requires_grad)
|
||||
],
|
||||
"weight_decay": self.args.weight_decay,
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in opt_model.named_parameters()
|
||||
if (n not in decay_parameters and p.requires_grad)
|
||||
],
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
(
|
||||
optimizer_cls,
|
||||
optimizer_kwargs,
|
||||
loraplus_lr_ratio,
|
||||
loraplus_lr_embedding,
|
||||
)
|
||||
) = AxolotlTrainer.get_optimizer_cls_and_kwargs(self.args)
|
||||
|
||||
if self.args.loraplus_lr_ratio:
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
loraplus_lr_embedding = getattr(
|
||||
self.args, "loraplus_lr_embedding", None
|
||||
)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
optimizer_kwargs,
|
||||
loraplus_lr_ratio,
|
||||
loraplus_lr_embedding,
|
||||
)
|
||||
|
||||
else:
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
|
||||
if optimizer_cls.__name__ == "Adam8bit":
|
||||
import bitsandbytes
|
||||
|
||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||
|
||||
skipped = 0
|
||||
for module in opt_model.modules():
|
||||
if isinstance(module, nn.Embedding):
|
||||
skipped += sum(
|
||||
{
|
||||
p.data_ptr(): p.numel() for p in module.parameters()
|
||||
}.values()
|
||||
)
|
||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||
manager.register_module_override(
|
||||
module, "weight", {"optim_bits": 32}
|
||||
)
|
||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
@@ -456,8 +558,112 @@ class AxolotlTrainer(Trainer):
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
def orpo_compute_custom_loss(self, logits, labels):
|
||||
logits = logits.contiguous()
|
||||
loss = 0.0
|
||||
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
# Flatten the tokens
|
||||
loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
|
||||
dim=-1
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def orpo_compute_logps(
|
||||
self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
|
||||
):
|
||||
# Get the shape of chosen_attention_mask[:, :-1]
|
||||
chosen_shape = chosen_attention_mask[:, :-1].shape
|
||||
|
||||
# Calculate the padding size
|
||||
pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)
|
||||
|
||||
# Pad prompt_attention_mask with zeros to match the desired shape
|
||||
prompt_attention_mask_padded = torch.nn.functional.pad(
|
||||
prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
|
||||
)
|
||||
|
||||
# Perform the subtraction operation
|
||||
mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded
|
||||
|
||||
per_token_logps = torch.gather(
|
||||
logits[:, :-1, :].log_softmax(-1),
|
||||
dim=2,
|
||||
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
|
||||
).squeeze(2)
|
||||
return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
|
||||
dtype=torch.float64
|
||||
) / mask.sum(dim=1).to(dtype=torch.float64)
|
||||
|
||||
def orpo_compute_loss(self, model, inputs, return_outputs=False):
|
||||
outputs_neg = model(
|
||||
**{
|
||||
"input_ids": inputs["rejected_input_ids"],
|
||||
"attention_mask": inputs["rejected_attention_mask"],
|
||||
"labels": inputs["rejected_labels"],
|
||||
},
|
||||
output_hidden_states=True,
|
||||
)
|
||||
outputs_pos = model(
|
||||
**{
|
||||
"input_ids": inputs["input_ids"],
|
||||
"attention_mask": inputs["attention_mask"],
|
||||
"labels": inputs["labels"],
|
||||
},
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
# Calculate NLL loss
|
||||
pos_loss = self.orpo_compute_custom_loss(
|
||||
logits=outputs_pos.logits, labels=inputs["input_ids"]
|
||||
)
|
||||
|
||||
# Calculate Log Probability
|
||||
pos_prob = self.orpo_compute_logps(
|
||||
prompt_attention_mask=inputs["prompt_attention_mask"],
|
||||
chosen_inputs=inputs["input_ids"],
|
||||
chosen_attention_mask=inputs["attention_mask"],
|
||||
logits=outputs_pos.logits,
|
||||
)
|
||||
neg_prob = self.orpo_compute_logps(
|
||||
prompt_attention_mask=inputs["prompt_attention_mask"],
|
||||
chosen_inputs=inputs["rejected_input_ids"],
|
||||
chosen_attention_mask=inputs["rejected_attention_mask"],
|
||||
logits=outputs_neg.logits,
|
||||
)
|
||||
|
||||
# Calculate log odds
|
||||
log_odds = (pos_prob - neg_prob) - (
|
||||
torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
|
||||
)
|
||||
sig_ratio = torch.nn.functional.sigmoid(log_odds)
|
||||
ratio = torch.log(sig_ratio)
|
||||
|
||||
# Calculate the Final Loss
|
||||
loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
|
||||
dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
metrics = {}
|
||||
metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
|
||||
metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
|
||||
metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
|
||||
metrics["log_odds"] = torch.mean(log_odds).cpu().item()
|
||||
self.store_metrics(metrics, train_eval="train")
|
||||
|
||||
return (loss, outputs_pos) if return_outputs else loss
|
||||
|
||||
@wraps(Trainer.push_to_hub)
|
||||
def push_to_hub(self, *args, **kwargs) -> str:
|
||||
"""
|
||||
@@ -468,6 +674,78 @@ class AxolotlTrainer(Trainer):
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@wraps(Trainer.create_accelerator_and_postprocess)
|
||||
def create_accelerator_and_postprocess(self):
|
||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
res = super().create_accelerator_and_postprocess()
|
||||
|
||||
if self.args.qlora is False:
|
||||
return res
|
||||
|
||||
# the rest of this method override is specific to fsdp + qlora (for now)
|
||||
sync_module_states = (
|
||||
str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
|
||||
)
|
||||
|
||||
mp_policy = None
|
||||
amp = os.environ["ACCELERATE_MIXED_PRECISION"]
|
||||
if amp == "fp16":
|
||||
mp_policy = MixedPrecision(
|
||||
param_dtype=torch.float32,
|
||||
reduce_dtype=torch.float32,
|
||||
buffer_dtype=torch.float32,
|
||||
)
|
||||
elif amp == "bf16":
|
||||
mp_policy = MixedPrecision(
|
||||
param_dtype=torch.float32,
|
||||
reduce_dtype=torch.float32,
|
||||
buffer_dtype=torch.float32,
|
||||
)
|
||||
|
||||
# If somehow we figure out how we want to parameterize we want to autocast buffers...
|
||||
# mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
|
||||
# load_param_skip_names = ['inv_freq']
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
auto_wrap_policy=wrapping_policy(),
|
||||
cpu_offload=False,
|
||||
use_orig_params=False,
|
||||
limit_all_gathers=True,
|
||||
param_init_fn=lambda module: module.to_empty(
|
||||
device=torch.device("cuda"), recurse=False
|
||||
)
|
||||
if (rank != 0 and sync_module_states)
|
||||
else None,
|
||||
mixed_precision_policy=mp_policy,
|
||||
)
|
||||
self.accelerator.state.fsdp_plugin = fsdp_plugin
|
||||
|
||||
return res
|
||||
|
||||
def log(self, logs: Dict[str, float]) -> None:
|
||||
"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
|
||||
Args:
|
||||
logs (`Dict[str, float]`):
|
||||
The values to log.
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
return super().log(logs)
|
||||
|
||||
def store_metrics(
|
||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||
) -> None:
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""
|
||||
@@ -741,6 +1019,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
warmup_steps = None
|
||||
if self.cfg.warmup_steps is not None:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
elif self.cfg.warmup_ratio is not None:
|
||||
@@ -777,15 +1056,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs[
|
||||
"gradient_checkpointing_kwargs"
|
||||
] = self.cfg.gradient_checkpointing_kwargs
|
||||
else:
|
||||
training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
|
||||
"use_reentrant": False
|
||||
}
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
|
||||
|
||||
if self.cfg.adapter == "qlora":
|
||||
training_arguments_kwargs["qlora"] = True
|
||||
|
||||
# deepspeed
|
||||
if self.cfg.deepspeed:
|
||||
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
||||
@@ -840,6 +1118,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
||||
training_arguments_kwargs["dataloader_drop_last"] = True
|
||||
|
||||
if self.cfg.remove_unused_columns is not None:
|
||||
training_arguments_kwargs[
|
||||
"remove_unused_columns"
|
||||
] = self.cfg.remove_unused_columns
|
||||
|
||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
@@ -953,6 +1236,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
)
|
||||
if self.cfg.optim_args:
|
||||
if isinstance(self.cfg.optim_args, dict):
|
||||
optim_args = ",".join(
|
||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
||||
)
|
||||
else:
|
||||
optim_args = self.cfg.optim_args
|
||||
training_arguments_kwargs["optim_args"] = optim_args
|
||||
if self.cfg.optim_target_modules:
|
||||
training_arguments_kwargs[
|
||||
"optim_target_modules"
|
||||
] = self.cfg.optim_target_modules
|
||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||
training_arguments_kwargs[
|
||||
"loraplus_lr_embedding"
|
||||
@@ -1007,6 +1302,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||
|
||||
if self.cfg.rl == "orpo":
|
||||
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
|
||||
|
||||
if self.cfg.neftune_noise_alpha is not None:
|
||||
training_arguments_kwargs[
|
||||
"neftune_noise_alpha"
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""module for trainer helpers like OptimizerNames"""
|
||||
|
||||
from transformers.utils import ExplicitEnum
|
||||
|
||||
|
||||
class OptimizerNames(ExplicitEnum):
|
||||
"""
|
||||
Stores the acceptable string identifiers for optimizers.
|
||||
"""
|
||||
|
||||
ADAMW_HF = "adamw_hf"
|
||||
ADAMW_TORCH = "adamw_torch"
|
||||
ADAMW_TORCH_FUSED = "adamw_torch_fused"
|
||||
ADAMW_TORCH_XLA = "adamw_torch_xla"
|
||||
ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
|
||||
ADAMW_APEX_FUSED = "adamw_apex_fused"
|
||||
ADAFACTOR = "adafactor"
|
||||
ADAMW_ANYPRECISION = "adamw_anyprecision"
|
||||
SGD = "sgd"
|
||||
ADAGRAD = "adagrad"
|
||||
ADAMW_BNB = "adamw_bnb_8bit"
|
||||
ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit
|
||||
LION_8BIT = "lion_8bit"
|
||||
LION = "lion_32bit"
|
||||
PAGED_ADAMW = "paged_adamw_32bit"
|
||||
PAGED_ADAMW_8BIT = "paged_adamw_8bit"
|
||||
PAGED_LION = "paged_lion_32bit"
|
||||
PAGED_LION_8BIT = "paged_lion_8bit"
|
||||
RMSPROP = "rmsprop"
|
||||
RMSPROP_BNB = "rmsprop_bnb"
|
||||
RMSPROP_8BIT = "rmsprop_bnb_8bit"
|
||||
RMSPROP_32BIT = "rmsprop_bnb_32bit"
|
||||
GALORE_ADAMW = "galore_adamw"
|
||||
GALORE_ADAMW_8BIT = "galore_adamw_8bit"
|
||||
GALORE_ADAFACTOR = "galore_adafactor"
|
||||
GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
|
||||
GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
|
||||
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
|
||||
LPMM_ADAMW_4BIT = "lmpp_adamw_4bit"
|
||||
LPMM_ADAMW_4BIT_FUSED = "lmpp_adamw_4bit_fused"
|
||||
|
||||
@@ -30,6 +30,7 @@ class ColorfulFormatter(Formatter):
|
||||
|
||||
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"simple": {
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
"""multipack patching for v2 of sample packing"""
|
||||
import importlib
|
||||
|
||||
import transformers
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||
@@ -12,11 +15,12 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"falcon",
|
||||
"phi",
|
||||
"gemma",
|
||||
"gemmoe",
|
||||
"starcoder2",
|
||||
]
|
||||
|
||||
|
||||
def patch_for_multipack(model_type):
|
||||
def patch_for_multipack(model_type, model_name=None):
|
||||
if model_type == "mixtral":
|
||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
@@ -43,3 +47,15 @@ def patch_for_multipack(model_type):
|
||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemmoe":
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_gemmoe to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(
|
||||
".configuration_gemmoe", ".modeling_gemmoe"
|
||||
)
|
||||
modeling_gemmoe = importlib.import_module(module_name)
|
||||
modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
|
||||
20
src/axolotl/prompt_strategies/base.py
Normal file
20
src/axolotl/prompt_strategies/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
module for base dataset transform strategies
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def load(strategy, cfg, module_base=None, **kwargs):
|
||||
try:
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{strategy}", module_base)
|
||||
func = getattr(mod, load_fn)
|
||||
return func(cfg, **kwargs)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
LOG.warning(f"unable to load strategy {strategy}")
|
||||
return None
|
||||
@@ -1,20 +1,8 @@
|
||||
"""
|
||||
module for DPO style dataset transform strategies
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from ..base import load as load_base
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def load(strategy, cfg, **kwargs):
|
||||
try:
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
|
||||
func = getattr(mod, load_fn)
|
||||
return func(cfg, **kwargs)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
LOG.warning(f"unable to load strategy {strategy}")
|
||||
return None
|
||||
load = partial(load_base, module="axolotl.prompt_strategies.dpo")
|
||||
|
||||
@@ -24,6 +24,25 @@ def argilla(
|
||||
return transform_fn
|
||||
|
||||
|
||||
def argilla_chat(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
for argilla/dpo-mix-7k conversations
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
|
||||
|
||||
def icr(
|
||||
cfg,
|
||||
**kwargs,
|
||||
|
||||
9
src/axolotl/prompt_strategies/orpo/__init__.py
Normal file
9
src/axolotl/prompt_strategies/orpo/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
module for ORPO style dataset transform strategies
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
|
||||
from ..base import load as load_base
|
||||
|
||||
load = partial(load_base, module="axolotl.prompt_strategies.orpo")
|
||||
187
src/axolotl/prompt_strategies/orpo/chat_template.py
Normal file
187
src/axolotl/prompt_strategies/orpo/chat_template.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""chatml prompt tokenization strategy for ORPO"""
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
|
||||
from axolotl.prompters import Prompter
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""message/turn"""
|
||||
|
||||
role: str
|
||||
content: str
|
||||
label: Optional[bool] = None
|
||||
|
||||
|
||||
class MessageList(BaseModel):
|
||||
"""conversation"""
|
||||
|
||||
messages: List[Message]
|
||||
|
||||
|
||||
def load(
|
||||
tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected
|
||||
"""
|
||||
|
||||
chat_template = chat_templates("chatml")
|
||||
if ds_cfg and "chat_template" in ds_cfg:
|
||||
chat_template = ds_cfg["chat_template"]
|
||||
try:
|
||||
chat_template = chat_templates(chat_template)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return ORPOTokenizingStrategy(
|
||||
ORPOPrompter(chat_template, tokenizer),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
dataset_parser=ORPODatasetParsingStrategy(),
|
||||
)
|
||||
|
||||
|
||||
class ORPODatasetParsingStrategy:
|
||||
"""Strategy to parse chosen rejected dataset into messagelist"""
|
||||
|
||||
def get_chosen_conversation_thread(self, prompt) -> MessageList:
|
||||
"""Dataset structure mappings"""
|
||||
|
||||
messages: List[Message] = []
|
||||
if system := prompt.get("system", None):
|
||||
messages.append(Message(role="system", content=system, label=False))
|
||||
messages.append(Message(role="user", content=prompt["prompt"], label=False))
|
||||
messages.append(
|
||||
Message(
|
||||
role="assistant", content=prompt["chosen"][1]["content"], label=True
|
||||
)
|
||||
)
|
||||
return MessageList(messages=messages)
|
||||
|
||||
def get_rejected_conversation_thread(self, prompt) -> MessageList:
|
||||
"""Dataset structure mappings"""
|
||||
|
||||
messages: List[Message] = []
|
||||
if system := prompt.get("system", None):
|
||||
messages.append(Message(role="system", content=system, label=False))
|
||||
messages.append(Message(role="user", content=prompt["prompt"], label=False))
|
||||
messages.append(
|
||||
Message(
|
||||
role="assistant", content=prompt["rejected"][1]["content"], label=True
|
||||
)
|
||||
)
|
||||
return MessageList(messages=messages)
|
||||
|
||||
|
||||
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
rejected_input_ids
|
||||
input_ids
|
||||
rejected_attention_mask
|
||||
attention_mask
|
||||
rejected_labels
|
||||
labels
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
dataset_parser=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_parser = dataset_parser
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# pass the rejected prompt/row to the Prompter to get the formatted prompt
|
||||
prompt_len = 0
|
||||
rejected_message_list = self.dataset_parser.get_rejected_conversation_thread(
|
||||
prompt
|
||||
)
|
||||
input_ids = []
|
||||
labels = []
|
||||
for _, (part, label) in enumerate(
|
||||
self.prompter.build_prompt(rejected_message_list)
|
||||
):
|
||||
if not part:
|
||||
continue
|
||||
_input_ids = self.tokenizer.encode(part, add_special_tokens=False)
|
||||
prev_idx = len(input_ids)
|
||||
input_ids += _input_ids[prev_idx:]
|
||||
if label:
|
||||
labels += input_ids[prev_idx:]
|
||||
else:
|
||||
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
|
||||
prompt_len = len(input_ids)
|
||||
# remap the input_ids, attention_mask and labels
|
||||
rejected_input_ids = input_ids
|
||||
rejected_labels = labels
|
||||
# pass the chosen prompt/row to the Prompter to get the formatted prompt
|
||||
chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt)
|
||||
input_ids = []
|
||||
labels = []
|
||||
for _, (part, label) in enumerate(
|
||||
self.prompter.build_prompt(chosen_message_list)
|
||||
):
|
||||
if not part:
|
||||
continue
|
||||
_input_ids = self.tokenizer.encode(part, add_special_tokens=False)
|
||||
prev_idx = len(input_ids)
|
||||
input_ids += _input_ids[prev_idx:]
|
||||
if label:
|
||||
labels += input_ids[prev_idx:]
|
||||
else:
|
||||
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
|
||||
|
||||
return {
|
||||
"rejected_input_ids": rejected_input_ids,
|
||||
"rejected_labels": rejected_labels,
|
||||
"rejected_attention_mask": [1] * len(rejected_labels),
|
||||
"input_ids": input_ids,
|
||||
"labels": labels,
|
||||
"attention_mask": [1] * len(labels),
|
||||
"prompt_attention_mask": [1] * prompt_len
|
||||
+ [0] * (len(labels) - prompt_len),
|
||||
}
|
||||
|
||||
|
||||
class ORPOPrompter(Prompter):
|
||||
"""Single Turn prompter for ORPO"""
|
||||
|
||||
def __init__(self, chat_template, tokenizer):
|
||||
self.chat_template = chat_template
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
message_list: MessageList,
|
||||
) -> Generator[Tuple[str, bool], None, None]:
|
||||
conversation = []
|
||||
for message in message_list.messages:
|
||||
conversation.append(message.model_dump())
|
||||
if message.role == "system":
|
||||
yield self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
add_generation_prompt=False,
|
||||
chat_template=self.chat_template,
|
||||
tokenize=False,
|
||||
), False
|
||||
if message.role == "user":
|
||||
yield self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
add_generation_prompt=True,
|
||||
chat_template=self.chat_template,
|
||||
tokenize=False,
|
||||
), False
|
||||
if message.role == "assistant":
|
||||
yield self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
add_generation_prompt=False,
|
||||
chat_template=self.chat_template,
|
||||
tokenize=False,
|
||||
), True
|
||||
@@ -1,10 +1,18 @@
|
||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
from axolotl.utils.tokenization import (
|
||||
chatml_to_conversation,
|
||||
merge_consecutive_messages,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def register_chatml_template(system_message=None):
|
||||
@@ -19,6 +27,16 @@ def register_chatml_template(system_message=None):
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="chatml_glaive",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message=system_message,
|
||||
roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
@@ -27,11 +45,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
)
|
||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
role_key_human=field_human,
|
||||
roles=roles,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
@@ -77,6 +97,20 @@ def load_guanaco(tokenizer, cfg):
|
||||
)
|
||||
|
||||
|
||||
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"]
|
||||
if ds_cfg and "conversation" in ds_cfg
|
||||
else "chatml_glaive"
|
||||
)
|
||||
return GlaiveShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(conversation=conversation),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
@@ -113,7 +147,12 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"system": "system",
|
||||
}
|
||||
turns = [
|
||||
{"from": role_map[t[role_key]], "value": t[value_key]}
|
||||
{
|
||||
"from": (
|
||||
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
|
||||
),
|
||||
"value": t[value_key],
|
||||
}
|
||||
for t in conversations
|
||||
]
|
||||
return turns
|
||||
@@ -158,3 +197,15 @@ class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingSt
|
||||
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps glaive data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversation = chatml_to_conversation(prompt)
|
||||
conversation = merge_consecutive_messages(conversation)
|
||||
|
||||
return conversation
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers import BatchEncoding, PreTrainedTokenizer
|
||||
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
||||
add_get_turns_to_conversation,
|
||||
)
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter,
|
||||
prompter: Prompter,
|
||||
tokenizer,
|
||||
train_on_inputs: bool = False,
|
||||
sequence_len: int = 2048,
|
||||
@@ -340,6 +340,23 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||
)
|
||||
|
||||
input_roles = {conversation.roles[0]}
|
||||
output_roles = {conversation.roles[1]}
|
||||
|
||||
if len(conversation.roles) == 3:
|
||||
tool_role_label = conversation.roles[2]
|
||||
input_roles.add(tool_role_label)
|
||||
|
||||
# Add roles from the config
|
||||
if self.prompter.roles:
|
||||
if "input" in self.prompter.roles and self.prompter.roles["input"]:
|
||||
for role in self.prompter.roles["input"]:
|
||||
input_roles.add(role)
|
||||
|
||||
if "output" in self.prompter.roles and self.prompter.roles["output"]:
|
||||
for role in self.prompter.roles["output"]:
|
||||
output_roles.add(role)
|
||||
|
||||
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
||||
role_remap = []
|
||||
if (
|
||||
@@ -360,11 +377,18 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
LOG.warning(f"expected tuple, got {part}")
|
||||
continue
|
||||
|
||||
user, assistant = conversation.roles
|
||||
role, content = part
|
||||
|
||||
# Uses "in" because role contains extra characters
|
||||
if user in role:
|
||||
input_turn = any(r.lower() in role.lower() for r in input_roles)
|
||||
output_turn = any(r.lower() in role.lower() for r in output_roles)
|
||||
empty_role = role.strip() == ""
|
||||
|
||||
if not any([input_turn, output_turn, empty_role]):
|
||||
LOG.warning(f"unhandled role: {role}")
|
||||
continue
|
||||
|
||||
if input_turn:
|
||||
role = (
|
||||
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||
if role_remap
|
||||
@@ -384,7 +408,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
else:
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif assistant in role:
|
||||
elif output_turn:
|
||||
role = (
|
||||
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||
if role_remap
|
||||
@@ -415,7 +439,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
||||
len_role, len(labels)
|
||||
)
|
||||
elif role == "":
|
||||
elif empty_role:
|
||||
turn = content
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
@@ -426,9 +450,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
else:
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
LOG.warning(f"unhandled role: {role}")
|
||||
continue
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
|
||||
@@ -259,6 +259,12 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||
)
|
||||
|
||||
CONVERSATION_ROLE_FORMAT = {
|
||||
"chatml": "<|im_start|>{ROLE}",
|
||||
"zephyr": "<|{ROLE}|>",
|
||||
"vicuna_v1.1": "{ROLE}",
|
||||
}
|
||||
|
||||
|
||||
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
@@ -267,6 +273,10 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
|
||||
role_key_human = "human"
|
||||
role_key_model = "gpt"
|
||||
# Optional, only used for tool usage datasets.
|
||||
role_key_tool: Optional[str] = None
|
||||
# Optional, role input/output mapping
|
||||
roles: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -274,6 +284,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
role_key_tool: Optional[str] = None,
|
||||
roles: Optional[dict] = None,
|
||||
):
|
||||
if conversation:
|
||||
if isinstance(conversation, Conversation):
|
||||
@@ -286,6 +298,10 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
self.role_key_human = role_key_human
|
||||
if role_key_model:
|
||||
self.role_key_model = role_key_model
|
||||
if role_key_tool:
|
||||
self.role_key_tool = role_key_tool
|
||||
if roles:
|
||||
self.roles = roles
|
||||
|
||||
def _build_result(self, source):
|
||||
if len(source) < 2:
|
||||
@@ -303,6 +319,8 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
source.pop(0)
|
||||
|
||||
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
|
||||
if self.role_key_tool:
|
||||
roles[self.role_key_tool] = conv.roles[2]
|
||||
|
||||
try:
|
||||
# Apply prompt templates
|
||||
@@ -315,11 +333,23 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
|
||||
conv.messages = []
|
||||
for _, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
if len(conv.messages) > 0 and (
|
||||
(role == conv.messages[-1][0]) or (role not in conv.roles)
|
||||
):
|
||||
from_role = sentence["from"]
|
||||
if from_role in roles:
|
||||
role = roles[from_role]
|
||||
else:
|
||||
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
|
||||
raise NotImplementedError(
|
||||
f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
|
||||
"Please help us by creating an Issue to add support for this conversation type."
|
||||
)
|
||||
|
||||
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
|
||||
ROLE=from_role
|
||||
)
|
||||
|
||||
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
return conv.get_turns()
|
||||
@@ -347,11 +377,13 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
roles: Optional[dict] = None,
|
||||
):
|
||||
super().__init__(
|
||||
conversation=conversation,
|
||||
role_key_human=role_key_human,
|
||||
role_key_model=role_key_model,
|
||||
roles=roles,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from peft import PeftModel, PeftModelForCausalLM
|
||||
from peft import PeftModel
|
||||
from pkg_resources import get_distribution # type: ignore
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
@@ -19,7 +19,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.freeze import freeze_parameters_except
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
@@ -85,7 +85,7 @@ def train(
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
model_ref = None
|
||||
if cfg.rl:
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||
# use built-in trl autounwrap
|
||||
LOG.debug("Passing model_ref: None to RL trainer")
|
||||
@@ -99,7 +99,7 @@ def train(
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
||||
freeze_layers_except(model, cfg.unfrozen_parameters)
|
||||
|
||||
trainer = setup_trainer(
|
||||
cfg,
|
||||
@@ -110,9 +110,6 @@ def train(
|
||||
total_num_steps,
|
||||
)
|
||||
|
||||
if hasattr(model, "config"):
|
||||
model.config.use_cache = False
|
||||
|
||||
# go ahead and presave, so we have the adapter config available to inspect
|
||||
if peft_config:
|
||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||
@@ -207,20 +204,6 @@ def train(
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
if cfg.adapter and isinstance(model, (PeftModel, PeftModelForCausalLM)):
|
||||
model.to("cpu")
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||
model.save_pretrained(
|
||||
str(Path(cfg.output_dir) / "merged"),
|
||||
safe_serialization=safe_serialization,
|
||||
progressbar=True,
|
||||
)
|
||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
try:
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
|
||||
@@ -24,9 +24,9 @@ def check_cuda_device(default_value):
|
||||
or not torch.cuda.is_available()
|
||||
or device == "auto"
|
||||
or torch.device(device).type == "cpu"
|
||||
or torch.device(device).type == "meta"
|
||||
):
|
||||
return default_value
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -21,7 +21,7 @@ def chat_templates(user_choice: str):
|
||||
templates = {
|
||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||
}
|
||||
|
||||
|
||||
@@ -191,6 +191,11 @@ def normalize_cfg_datasets(cfg):
|
||||
f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template"
|
||||
)
|
||||
cfg.datasets[idx].conversation = "chatml"
|
||||
if ds_cfg.type == "orpo.chat_template" and not ds_cfg.chat_template:
|
||||
LOG.info(
|
||||
f"updating dataset {ds_cfg.path} with `chat_template: chatml` to match your chat_template"
|
||||
)
|
||||
cfg.datasets[idx].chat_template = "chatml"
|
||||
|
||||
|
||||
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Module for pydantic models for configuration
|
||||
"""
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
import os
|
||||
@@ -95,6 +96,8 @@ class SFTDataset(BaseModel):
|
||||
field_human: Optional[str] = None
|
||||
field_model: Optional[str] = None
|
||||
|
||||
roles: Optional[Dict[str, List[str]]] = None
|
||||
|
||||
|
||||
class UserDefinedDPOType(BaseModel):
|
||||
"""User defined typing for DPO"""
|
||||
@@ -123,13 +126,16 @@ class RLType(str, Enum):
|
||||
dpo = "dpo" # pylint: disable=invalid-name
|
||||
ipo = "ipo" # pylint: disable=invalid-name
|
||||
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
||||
orpo = "orpo" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
"""Chat templates configuration subset"""
|
||||
|
||||
alpaca = "alpaca" # pylint: disable=invalid-name
|
||||
chatml = "chatml" # pylint: disable=invalid-name
|
||||
inst = "inst" # pylint: disable=invalid-name
|
||||
gemma = "gemma" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LoftQConfig(BaseModel):
|
||||
@@ -179,6 +185,7 @@ class LoraConfig(BaseModel):
|
||||
peft_layers_to_transform: Optional[List[int]] = None
|
||||
peft: Optional[PeftConfig] = None
|
||||
peft_use_dora: Optional[bool] = None
|
||||
peft_use_relora: Optional[bool] = None
|
||||
|
||||
lora_on_cpu: Optional[bool] = None
|
||||
gptq: Optional[bool] = None
|
||||
@@ -306,6 +313,15 @@ class HyperparametersConfig(BaseModel):
|
||||
learning_rate: Union[str, float]
|
||||
weight_decay: Optional[float] = None
|
||||
optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
|
||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
|
||||
)
|
||||
optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The target modules to optimize, i.e. the module names that you would like to train."
|
||||
},
|
||||
)
|
||||
torchdistx_path: Optional[str] = None
|
||||
lr_scheduler: Optional[SchedulerType] = None
|
||||
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
||||
@@ -411,6 +427,7 @@ class AxolotlInputConfig(
|
||||
|
||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
||||
shuffle_merged_datasets: Optional[bool] = True
|
||||
dataset_prepared_path: Optional[str] = None
|
||||
dataset_shard_num: Optional[int] = None
|
||||
dataset_shard_idx: Optional[int] = None
|
||||
@@ -427,6 +444,8 @@ class AxolotlInputConfig(
|
||||
dataloader_prefetch_factor: Optional[int] = None
|
||||
dataloader_drop_last: Optional[bool] = None
|
||||
|
||||
remove_unused_columns: Optional[bool] = None
|
||||
|
||||
push_dataset_to_hub: Optional[str] = None
|
||||
hf_use_auth_token: Optional[bool] = None
|
||||
|
||||
@@ -511,10 +530,14 @@ class AxolotlInputConfig(
|
||||
|
||||
neftune_noise_alpha: Optional[float] = None
|
||||
|
||||
max_memory: Optional[Union[int, str]] = None
|
||||
orpo_alpha: Optional[float] = None
|
||||
|
||||
max_memory: Optional[
|
||||
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
||||
] = None
|
||||
gpu_memory_limit: Optional[Union[int, str]] = None
|
||||
|
||||
chat_template: Optional[Union[Literal["chatml", "inst"], ChatTemplate]] = None
|
||||
chat_template: Optional[ChatTemplate] = None
|
||||
default_system_message: Optional[str] = None
|
||||
|
||||
# INTERNALS - document for now, generally not set externally
|
||||
@@ -989,3 +1012,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_deepspeed(cls, data):
|
||||
if data.get("deepspeed") and data.get("fsdp"):
|
||||
raise ValueError("deepspeed and fsdp cannot be used together.")
|
||||
return data
|
||||
|
||||
@@ -114,9 +114,7 @@ def prepare_dataset(cfg, tokenizer):
|
||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||
if total_eval_steps == 0:
|
||||
raise ValueError(
|
||||
"eval dataset split is too small for sample_packing. "
|
||||
"You should set `eval_sample_packing: False` "
|
||||
"or decrease the value of `eval_batch_size`. "
|
||||
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
|
||||
)
|
||||
|
||||
if cfg.max_steps:
|
||||
@@ -417,8 +415,11 @@ def load_tokenized_prepared_datasets(
|
||||
dataset = concatenate_datasets(datasets)
|
||||
|
||||
if len(datasets) > 1:
|
||||
LOG.info("shuffle merged datasets")
|
||||
dataset = dataset.shuffle(seed=seed)
|
||||
if cfg.shuffle_merged_datasets:
|
||||
LOG.debug("shuffle merged datasets")
|
||||
dataset = dataset.shuffle(seed=seed)
|
||||
else:
|
||||
LOG.debug("NOT shuffling merged datasets")
|
||||
|
||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||
|
||||
@@ -821,7 +822,11 @@ def wrap_pretraining_dataset(
|
||||
else:
|
||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||
|
||||
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
||||
if cfg.shuffle_merged_datasets:
|
||||
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
||||
else:
|
||||
LOG.debug("NOT shuffling merged pretraining datasets")
|
||||
|
||||
dataset = dataset.map(
|
||||
encode,
|
||||
batched=True,
|
||||
|
||||
@@ -3,13 +3,14 @@ module to freeze/unfreeze parameters by name
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.freeze")
|
||||
|
||||
|
||||
def freeze_parameters_except(model, regex_patterns):
|
||||
def freeze_layers_except(model, regex_patterns):
|
||||
"""
|
||||
Freezes all layers of the given model except for the layers that match given regex patterns.
|
||||
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
||||
@@ -17,22 +18,211 @@ def freeze_parameters_except(model, regex_patterns):
|
||||
Parameters:
|
||||
- model (nn.Module): The PyTorch model to be modified.
|
||||
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
||||
Note that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names.
|
||||
Also, to match the entire layer name, the pattern should start with "^" and end with "$", otherwise it will match any part of the layer name.
|
||||
The range pattern part is optional and it is not compiled as a regex pattern which means you must put "$" before the range pattern if you want to match the entire layer name.
|
||||
E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]
|
||||
|
||||
Returns:
|
||||
None; the model is modified in place.
|
||||
"""
|
||||
# Escape periods and compile the regex patterns
|
||||
compiled_patterns = [
|
||||
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
|
||||
]
|
||||
if isinstance(regex_patterns, str):
|
||||
regex_patterns = [regex_patterns]
|
||||
|
||||
# First, freeze all parameters in the model
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
patterns = [LayerNamePattern(pattern) for pattern in regex_patterns]
|
||||
|
||||
# Unfreeze layers that match the regex patterns
|
||||
for name, param in model.named_parameters():
|
||||
if any(pattern.match(name) for pattern in compiled_patterns):
|
||||
if is_main_process():
|
||||
LOG.debug(f"unfreezing {name}")
|
||||
param.requires_grad = False
|
||||
unfrozen_ranges = []
|
||||
for pattern in patterns:
|
||||
if not pattern.match(name):
|
||||
continue
|
||||
|
||||
param.requires_grad = True
|
||||
|
||||
if pattern.range is not None:
|
||||
unfrozen_ranges.append(pattern.range)
|
||||
|
||||
merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param))
|
||||
|
||||
if param.requires_grad and is_main_process():
|
||||
unfrozen_ranges = (
|
||||
f" with ranges {merged_unfrozen_ranges}"
|
||||
if merged_unfrozen_ranges
|
||||
else ""
|
||||
)
|
||||
LOG.debug(f"Unfrozen {name}{unfrozen_ranges}")
|
||||
|
||||
if not merged_unfrozen_ranges:
|
||||
continue
|
||||
|
||||
# The range list we need is actually the inverted of the merged ranges
|
||||
ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))
|
||||
|
||||
param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))
|
||||
|
||||
if is_main_process() and all(
|
||||
not param.requires_grad for param in model.parameters()
|
||||
):
|
||||
LOG.warning("All parameters are frozen. Model will not be trained.")
|
||||
|
||||
|
||||
def _invert_ranges(
|
||||
given_ranges: List[Tuple[int, int]], layer_size: int
|
||||
) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Inverts a list of ranges to obtain the ranges not covered by the given ranges.
|
||||
|
||||
Parameters:
|
||||
- given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
|
||||
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
|
||||
Returns:
|
||||
- List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
|
||||
"""
|
||||
if not given_ranges:
|
||||
return [(0, layer_size)]
|
||||
|
||||
inverted_ranges = []
|
||||
current_start = 0
|
||||
|
||||
for start, end in sorted(given_ranges):
|
||||
if start > current_start:
|
||||
inverted_ranges.append((current_start, start))
|
||||
current_start = max(current_start, end)
|
||||
|
||||
# Handle the case where the last given range does not reach the end of the total_size
|
||||
if current_start < layer_size:
|
||||
inverted_ranges.append((current_start, layer_size))
|
||||
|
||||
return inverted_ranges
|
||||
|
||||
|
||||
def _merge_ranges(
|
||||
given_ranges: List[Tuple[int, Union[int, None]]], layer_size: int
|
||||
) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Merges overlapping ranges and sorts the given ranges.
|
||||
|
||||
This function takes a list of ranges and merges any overlapping ranges. The ranges are represented
|
||||
as tuples, where the first element is the start index (inclusive) and the second element is the end
|
||||
index (exclusive). The end index can be None, indicating that the range extends to the end of the
|
||||
sequence.
|
||||
|
||||
Parameters:
|
||||
- given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.
|
||||
- layer_size (int): The length of the layer. E.g., len(model.layer.weight)
|
||||
|
||||
Returns:
|
||||
- List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.
|
||||
"""
|
||||
# End of each range can be determined now since we have the total size
|
||||
processed_ranges = [
|
||||
(start, end if end is not None else layer_size) for start, end in given_ranges
|
||||
]
|
||||
|
||||
# No need to merge if there's only one or no ranges
|
||||
if len(processed_ranges) <= 1:
|
||||
return processed_ranges
|
||||
|
||||
sorted_ranges = sorted(processed_ranges)
|
||||
|
||||
merged_ranges = [sorted_ranges[0]]
|
||||
for start, end in sorted_ranges[1:]:
|
||||
prev_start, prev_end = merged_ranges[-1]
|
||||
if start <= prev_end:
|
||||
merged_ranges[-1] = (prev_start, max(prev_end, end))
|
||||
else:
|
||||
merged_ranges.append((start, end))
|
||||
|
||||
return merged_ranges
|
||||
|
||||
|
||||
def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:
|
||||
"""
|
||||
Create a hook to freeze parameters in specified ranges by setting their gradients to zero.
|
||||
|
||||
This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain
|
||||
two integers representing the start and end indices of the range.
|
||||
|
||||
Parameters:
|
||||
- ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.
|
||||
|
||||
Returns:
|
||||
- Callable: A hook function to be used with `register_hook` on parameters.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
ranges_to_freeze = [(0, 10), (20, 30)]
|
||||
hook = _create_freeze_parameters_hook(ranges_to_freeze)
|
||||
model.register_hook(hook)
|
||||
```
|
||||
"""
|
||||
|
||||
def freeze_parameters_hook(gradients):
|
||||
for start, end in ranges_to_freeze:
|
||||
gradients[start:end].zero_()
|
||||
|
||||
return freeze_parameters_hook
|
||||
|
||||
|
||||
class LayerNamePattern:
|
||||
"""
|
||||
Represents a regex pattern for layer names, potentially including a parameter index range.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern: str):
|
||||
"""
|
||||
Initializes a new instance of the LayerNamePattern class.
|
||||
|
||||
Parameters:
|
||||
- pattern (str): The regex pattern for layer names, potentially including a parameter index range.
|
||||
"""
|
||||
self.raw_pattern = pattern
|
||||
name_pattern, self.range = self._parse_pattern(pattern)
|
||||
self.name_regex = re.compile(name_pattern.replace(".", "\\."))
|
||||
|
||||
def match(self, name: str) -> bool:
|
||||
"""
|
||||
Checks if the given layer name matches the regex pattern.
|
||||
|
||||
Parameters:
|
||||
- name (str): The layer name to check.
|
||||
|
||||
Returns:
|
||||
- bool: True if the layer name matches the pattern, False otherwise.
|
||||
"""
|
||||
return self.name_regex.match(name) is not None
|
||||
|
||||
def _parse_pattern(
|
||||
self, pattern: str
|
||||
) -> Tuple[str, Union[Tuple[int, Union[int, None]], None]]:
|
||||
"""
|
||||
Extracts the range pattern from the given pattern.
|
||||
|
||||
Parameters:
|
||||
- pattern (str): The pattern to extract the range from.
|
||||
|
||||
Returns:
|
||||
- Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.
|
||||
"""
|
||||
match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern)
|
||||
if not match:
|
||||
return pattern, None
|
||||
|
||||
base_pattern, start_part, end_part = match.groups()
|
||||
|
||||
if end_part is None and start_part.isdecimal():
|
||||
index = int(start_part)
|
||||
return base_pattern, (index, index + 1)
|
||||
|
||||
# [:end] or [start:] or [start:end]
|
||||
start = int(start_part) if start_part else 0
|
||||
end = int(end_part) if end_part else None
|
||||
|
||||
if end is not None and start >= end:
|
||||
raise ValueError(
|
||||
f"Invalid range in layer name pattern: {pattern}."
|
||||
"End of range must be greater than start."
|
||||
)
|
||||
return base_pattern, (start, end)
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
"""Module for models and model loading"""
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||
import types
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union # noqa: F401
|
||||
|
||||
import addict
|
||||
import bitsandbytes as bnb
|
||||
import safetensors
|
||||
import torch
|
||||
import transformers
|
||||
from accelerate import init_empty_weights
|
||||
from bitsandbytes.nn import Linear4bit, Params4bit
|
||||
from fastcore.parallel import parallel
|
||||
from peft import (
|
||||
LoftQConfig,
|
||||
PeftConfig,
|
||||
@@ -16,6 +23,7 @@ from peft import (
|
||||
prepare_model_for_kbit_training,
|
||||
)
|
||||
from peft.tuners.lora import QuantLinear
|
||||
from torch import Tensor, nn
|
||||
from transformers import ( # noqa: F401
|
||||
AddedToken,
|
||||
AutoConfig,
|
||||
@@ -27,7 +35,9 @@ from transformers import ( # noqa: F401
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
|
||||
|
||||
from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
@@ -262,6 +272,117 @@ def load_tokenizer(cfg):
|
||||
return tokenizer
|
||||
|
||||
|
||||
def replace_linear(
|
||||
model: nn.Module,
|
||||
linear_replacement: Type[nn.Module],
|
||||
quant_config: Union[dict, None] = None,
|
||||
skip_modules=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Replace linear modules with a new Linear module.
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model or `torch.nn.Module` as the function is run recursively.
|
||||
linear_replacement (`torch.nn.Module`):
|
||||
The linear module that replaces the old one. Only expects standard arguments.
|
||||
If other arguments need to be passed, use a lambda.
|
||||
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
|
||||
List of modules names not to convert. Defaults to `lm_head`.
|
||||
"""
|
||||
if skip_modules is None:
|
||||
skip_modules = ["lm_head"]
|
||||
for name, module in model.named_children():
|
||||
if len(list(module.children())) > 0:
|
||||
replace_linear(
|
||||
module, linear_replacement, quant_config, skip_modules, **kwargs
|
||||
)
|
||||
|
||||
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
|
||||
if issubclass(linear_replacement, Linear4bit):
|
||||
model._modules[ # pylint: disable=protected-access
|
||||
name
|
||||
] = linear_replacement(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported linear replacement: {type(linear_replacement)}"
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def load_and_quantize(
|
||||
module: nn.Module,
|
||||
name: str,
|
||||
value: Tensor,
|
||||
device: torch.device = None,
|
||||
dtype: torch.dtype = None,
|
||||
skip_names: Optional[List[str]] = None,
|
||||
is_meta_rank: bool = False,
|
||||
low_memory: bool = True,
|
||||
verbose: bool = False,
|
||||
quant_method: str = "bnb",
|
||||
):
|
||||
"""
|
||||
Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
|
||||
|
||||
Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
|
||||
"""
|
||||
|
||||
if skip_names is None:
|
||||
skip_names = []
|
||||
|
||||
def place_on_device(value):
|
||||
if is_meta_rank:
|
||||
device = "meta"
|
||||
elif low_memory:
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cuda"
|
||||
return value.to(device=device, dtype=dtype)
|
||||
|
||||
if any(skip_name in name for skip_name in skip_names):
|
||||
if verbose:
|
||||
print(f"Skipping {name} because it is in skip_names")
|
||||
return
|
||||
|
||||
module_key, _, value_key = name.rpartition(".")
|
||||
try:
|
||||
submodule = module.get_submodule(module_key)
|
||||
except AttributeError as exc:
|
||||
print(f"Module {module_key} not found:\n{exc}")
|
||||
return
|
||||
|
||||
try:
|
||||
if quant_method == "bnb":
|
||||
param = submodule.get_parameter(value_key)
|
||||
if isinstance(param, Params4bit):
|
||||
# With `sync_module_states=True`, a meta device Params4bit needs to be the same
|
||||
# shape as the quantized Params4bit with an initialized quant_state. However,
|
||||
# FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
|
||||
# workaround quantizes Params4bit to initialize quant_state on all ranks, then
|
||||
# replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
|
||||
value = type(param)(
|
||||
value.to(device=device, dtype=dtype).data, **param.__dict__
|
||||
).cuda(device)
|
||||
if is_meta_rank:
|
||||
value = type(param)(value.data.to("meta"), **value.__dict__)
|
||||
elif low_memory:
|
||||
value = type(param)(value.data.to("cpu"), **value.__dict__)
|
||||
else:
|
||||
value = type(param)(place_on_device(value).data)
|
||||
|
||||
except AttributeError:
|
||||
# it's a buffer
|
||||
value = place_on_device(value)
|
||||
|
||||
setattr(submodule, value_key, value)
|
||||
|
||||
|
||||
def load_model(
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@@ -308,7 +429,7 @@ def load_model(
|
||||
and cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
patch_for_multipack(cfg.model_config_type)
|
||||
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
|
||||
elif cfg.is_llama_derived_model:
|
||||
# Modify all llama derived models in one block
|
||||
|
||||
@@ -394,7 +515,7 @@ def load_model(
|
||||
|
||||
if max_memory is not None:
|
||||
# Based on https://github.com/togethercomputer/OpenChatKit/blob/main/inference/bot.py
|
||||
from accelerate import infer_auto_device_map, init_empty_weights
|
||||
from accelerate import infer_auto_device_map
|
||||
|
||||
with init_empty_weights():
|
||||
model_canvas = AutoModelForCausalLM.from_config(model_config)
|
||||
@@ -496,8 +617,78 @@ def load_model(
|
||||
model_kwargs["attn_implementation"] = "eager"
|
||||
model_config._attn_implementation = "eager" # pylint: disable=protected-access
|
||||
|
||||
qlora_fsdp = (
|
||||
cfg.fsdp
|
||||
and cfg.adapter == "qlora"
|
||||
and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
|
||||
)
|
||||
|
||||
try:
|
||||
if (
|
||||
if qlora_fsdp:
|
||||
if cfg.bf16 or cfg.bfloat16:
|
||||
torch_dtype, compute_dtype = torch.float32, torch.bfloat16
|
||||
elif cfg.fp16 or cfg.float16:
|
||||
torch_dtype, compute_dtype = torch.float32, torch.float16
|
||||
else:
|
||||
torch_dtype, compute_dtype = torch.float32, torch.float16
|
||||
|
||||
with init_empty_weights():
|
||||
LOG.info("Loading model with empty weights.")
|
||||
model = AutoModelForCausalLM.from_config(model_config)
|
||||
model.model = replace_linear(
|
||||
model.model,
|
||||
Linear4bit,
|
||||
compute_dtype=compute_dtype,
|
||||
quant_type="nf4",
|
||||
quant_storage=torch_dtype,
|
||||
)
|
||||
|
||||
model.is_loaded_in_4bit = True
|
||||
|
||||
# Grab the safetensors files that hold the weights
|
||||
try:
|
||||
idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
|
||||
files, _ = hub.get_checkpoint_shard_files(base_model, idx)
|
||||
except OSError:
|
||||
try:
|
||||
# This means the model doesn't have a model.safetensors.index.json because it is not sharded
|
||||
files = []
|
||||
files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
|
||||
except OSError as exc:
|
||||
# This means the model probably doesn't have a safetensors file
|
||||
raise exc
|
||||
|
||||
# Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
|
||||
# and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
|
||||
def load_and_quantize_parallel(name_param, model, **kwargs):
|
||||
name, param = name_param
|
||||
load_and_quantize(model, name, param, **kwargs)
|
||||
|
||||
param_count = sum((p.numel() for n, p in model.named_parameters()))
|
||||
for filename in files:
|
||||
weights = safetensors.torch.load_file(filename)
|
||||
quant_method = "bnb"
|
||||
devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||
left = int(os.cpu_count() / torch.cuda.device_count())
|
||||
right = int(
|
||||
8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
|
||||
)
|
||||
n_workers = min(left, right)
|
||||
parallel(
|
||||
load_and_quantize_parallel,
|
||||
weights.items(),
|
||||
n_workers=n_workers,
|
||||
threadpool=True,
|
||||
model=model,
|
||||
dtype=torch_dtype,
|
||||
device=cfg.local_rank,
|
||||
skip_names=[],
|
||||
is_meta_rank=(cfg.local_rank != 0),
|
||||
verbose=False,
|
||||
quant_method=quant_method,
|
||||
)
|
||||
|
||||
elif (
|
||||
model_config.model_type == "llama"
|
||||
and not cfg.trust_remote_code
|
||||
and not cfg.gptq
|
||||
@@ -613,7 +804,7 @@ def load_model(
|
||||
LOG.exception(err)
|
||||
raise err
|
||||
|
||||
if isinstance(model, (PeftModel, PeftModelForCausalLM)):
|
||||
if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
|
||||
model = model.merge_and_unload()
|
||||
|
||||
embeddings_len = (
|
||||
@@ -692,9 +883,14 @@ def load_model(
|
||||
if cfg.adapter == "lora" and loftq_bits:
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if qlora_fsdp:
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if cfg.adapter in ["lora", "qlora"]:
|
||||
if cfg.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
model.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
|
||||
)
|
||||
if (
|
||||
cfg.load_in_8bit or cfg.load_in_4bit
|
||||
) and not skip_prepare_model_for_kbit_training:
|
||||
@@ -706,7 +902,7 @@ def load_model(
|
||||
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
if needs_fa2_dtype or cfg.flash_attention:
|
||||
if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
|
||||
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name:
|
||||
@@ -724,7 +920,12 @@ def load_model(
|
||||
else:
|
||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||
|
||||
if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
|
||||
if (
|
||||
cfg.ddp
|
||||
and not load_in_8bit
|
||||
and not (cfg.rl and cfg.load_in_4bit)
|
||||
and not qlora_fsdp
|
||||
):
|
||||
# TODO revaldate this conditional
|
||||
model.to(f"cuda:{cfg.local_rank}")
|
||||
|
||||
@@ -813,6 +1014,30 @@ def find_all_linear_names(model):
|
||||
return list(lora_module_names)
|
||||
|
||||
|
||||
def setup_quantized_meta_for_peft(model: nn.Module):
|
||||
"""Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device"""
|
||||
|
||||
def temp_to_method(self, *args, **kwargs): # pylint: disable=unused-argument
|
||||
return self
|
||||
|
||||
for param in model.parameters():
|
||||
if isinstance(param, Params4bit):
|
||||
param.quant_state._orig_to = ( # pylint: disable=protected-access
|
||||
param.quant_state.to
|
||||
)
|
||||
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
|
||||
|
||||
|
||||
def setup_quantized_peft_meta_for_training(model: nn.Module):
|
||||
"""Replaces dummy `quant_state.to` method with the original function to allow training to continue"""
|
||||
for param in model.parameters():
|
||||
if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"):
|
||||
param.quant_state.to = (
|
||||
param.quant_state._orig_to # pylint: disable=protected-access
|
||||
)
|
||||
param.quant_state._orig_to = None # pylint: disable=protected-access
|
||||
|
||||
|
||||
def load_lora(model, cfg, inference=False, config_only=False):
|
||||
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
|
||||
|
||||
@@ -832,6 +1057,8 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
lora_config_kwargs["init_lora_weights"] = "loftq"
|
||||
if cfg.peft_use_dora:
|
||||
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
||||
if cfg.peft_use_rslora:
|
||||
lora_config_kwargs["use_rslora"] = cfg.use_rslora
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=cfg.lora_r,
|
||||
@@ -849,6 +1076,11 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
if config_only:
|
||||
return None, lora_config
|
||||
|
||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
|
||||
if cfg.fsdp and cfg.adapter == "qlora" and rank != 0:
|
||||
setup_quantized_meta_for_peft(model)
|
||||
|
||||
if cfg.lora_model_dir:
|
||||
LOG.debug("Loading pretrained PEFT - LoRA")
|
||||
model_kwargs: Any = {}
|
||||
@@ -864,6 +1096,9 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
else:
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
model.print_trainable_parameters()
|
||||
if rank == 0:
|
||||
model.print_trainable_parameters()
|
||||
elif cfg.fsdp and cfg.adapter == "qlora":
|
||||
setup_quantized_peft_meta_for_training(model)
|
||||
|
||||
return model, lora_config
|
||||
|
||||
@@ -5,7 +5,7 @@ Multipack Batch Sampler
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Iterable, List, Union, Optional
|
||||
from typing import Any, Iterable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
@@ -115,14 +115,12 @@ class MultipackBatchSampler(BatchSampler):
|
||||
batch_max_len: int,
|
||||
lengths: np.ndarray,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
consistent_length: Optional[bool] = False,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.batch_size = batch_size
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.consistent_length = consistent_length
|
||||
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
@@ -166,18 +164,11 @@ class MultipackBatchSampler(BatchSampler):
|
||||
|
||||
def __iter__(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
if self.consistent_length:
|
||||
length = self._len_est()
|
||||
return iter(batches[:length])
|
||||
else:
|
||||
return iter(batches)
|
||||
return iter(batches)
|
||||
|
||||
def num_batches(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
if self.consistent_length:
|
||||
return self._len_est()
|
||||
else:
|
||||
return len(batches)
|
||||
return len(batches)
|
||||
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
@@ -36,3 +38,65 @@ def check_example_labels(example, tokenizer, text_only=False):
|
||||
LOG.info("\n\n\n")
|
||||
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
|
||||
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
||||
GLAIVE_TO_SHAREGPT_ROLE = {
|
||||
"SYSTEM": "system",
|
||||
"USER": "human",
|
||||
"ASSISTANT": "gpt",
|
||||
"FUNCTION RESPONSE": "tool",
|
||||
}
|
||||
|
||||
GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")
|
||||
|
||||
|
||||
def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Converts a ChatML formatted row to a list of messages in ShareGPT format.
|
||||
Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.
|
||||
"""
|
||||
|
||||
system_prompt = row.get("system")
|
||||
if system_prompt:
|
||||
system_prompt = system_prompt.removeprefix("SYSTEM: ")
|
||||
|
||||
chat_str = row["chat"]
|
||||
chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s]
|
||||
|
||||
chat_msg_dicts = [
|
||||
{"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value}
|
||||
for role, value in zip(chat_msgs[::2], chat_msgs[1::2])
|
||||
]
|
||||
|
||||
if system_prompt:
|
||||
chat_msg_dicts = [
|
||||
{"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
|
||||
] + chat_msg_dicts
|
||||
|
||||
return chat_msg_dicts
|
||||
|
||||
|
||||
def merge_consecutive_messages(messages):
|
||||
"""
|
||||
Merge consecutive messages from the same sender into a single message.
|
||||
This can be useful with datasets that contain multiple consecutive tool calls.
|
||||
"""
|
||||
|
||||
merged_messages = []
|
||||
current_from = None
|
||||
current_message = ""
|
||||
|
||||
for msg in messages:
|
||||
if current_from == msg["from"]:
|
||||
current_message += msg["value"]
|
||||
else:
|
||||
if current_from is not None:
|
||||
merged_messages.append({"from": current_from, "value": current_message})
|
||||
current_from = msg["from"]
|
||||
current_message = msg["value"]
|
||||
|
||||
if current_from is not None:
|
||||
merged_messages.append({"from": current_from, "value": current_message})
|
||||
|
||||
return merged_messages
|
||||
|
||||
@@ -277,7 +277,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
calc_sample_packing_eff_est,
|
||||
)
|
||||
sample_packing_eff_est = (
|
||||
math.ceil(sample_packing_actual_eff_all * 10000.0) / 10000.0
|
||||
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
|
||||
)
|
||||
if update:
|
||||
cfg.sample_packing_eff_est = sample_packing_eff_est
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Test module for sharegpt integration w chatml
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from tokenizers import AddedToken
|
||||
@@ -8,6 +9,7 @@ from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
GlaiveShareGPTPromptTokenizingStrategy,
|
||||
SimpleShareGPTPromptTokenizingStrategy,
|
||||
register_chatml_template,
|
||||
)
|
||||
@@ -48,6 +50,50 @@ def fixture_sharegpt_dataset():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="glaive_dataset")
|
||||
def fixture_sharegpt_glaive_dataset():
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"system": "SYSTEM: This is a system prompt",
|
||||
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="multi_role_dataset")
|
||||
def fixture_multi_role_dataset():
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{
|
||||
"from": "system",
|
||||
"value": "use get_weather(city) to get the weather for a city",
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "hello, what's the weather in New York?",
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "let me get that for you",
|
||||
},
|
||||
{
|
||||
"from": "tool",
|
||||
"value": "get_weather(New York)",
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "the weather in New York is 70 degrees and sunny",
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
|
||||
def fixture_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||
@@ -156,3 +202,65 @@ class TestSharegpt:
|
||||
32001, 13892, 13, 12684, 17664, 32000, 28705, 13, # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_chatml_glaive(self, glaive_dataset, tokenizer):
|
||||
strategy = GlaiveShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation="chatml",
|
||||
role_key_model=None,
|
||||
role_key_human=None,
|
||||
),
|
||||
tokenizer,
|
||||
True, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, glaive_dataset, process_count=1
|
||||
)
|
||||
|
||||
labels = dataset_wrapper[0]["labels"]
|
||||
# fmt: off
|
||||
assert labels == [
|
||||
1, # bos
|
||||
32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13, # system
|
||||
32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13, # human
|
||||
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
|
||||
tokenizer,
|
||||
False, # train_on_inputs
|
||||
2048, # sequence_len
|
||||
)
|
||||
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
strategy, multi_role_dataset, process_count=1
|
||||
)
|
||||
|
||||
input_ids = dataset_wrapper[0]["input_ids"]
|
||||
# fmt: off
|
||||
assert input_ids == [
|
||||
1, # bos
|
||||
32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13, # system
|
||||
32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13, # human
|
||||
32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
|
||||
32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13, # tool
|
||||
32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
labels = dataset_wrapper[0]["labels"]
|
||||
# fmt: off
|
||||
assert labels == [
|
||||
-100, # bos
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human
|
||||
-100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool
|
||||
-100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
285
tests/test_freeze.py
Normal file
285
tests/test_freeze.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
This module contains unit tests for the `freeze_layers_except` function.
|
||||
|
||||
The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers.
|
||||
The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
|
||||
ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
||||
ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
||||
|
||||
|
||||
class TestFreezeLayersExcept(unittest.TestCase):
|
||||
"""
|
||||
A test case class for the `freeze_layers_except` function.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.model = _TestModel()
|
||||
|
||||
def test_freeze_layers_with_dots_in_name(self):
|
||||
freeze_layers_except(self.model, ["features.layer"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
def test_freeze_layers_without_dots_in_name(self):
|
||||
freeze_layers_except(self.model, ["classifier"])
|
||||
self.assertFalse(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertTrue(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
def test_freeze_layers_regex_patterns(self):
|
||||
# The second pattern cannot match because only characters 'a' to 'c' are allowed after the word 'class', whereas it should be matching the character 'i'.
|
||||
freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
def test_all_layers_frozen(self):
|
||||
freeze_layers_except(self.model, [])
|
||||
self.assertFalse(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be frozen.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
def test_all_layers_unfrozen(self):
|
||||
freeze_layers_except(self.model, ["features.layer", "classifier"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertTrue(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be trainable.",
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_start_end(self):
|
||||
freeze_layers_except(self.model, ["features.layer[1:5]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_single_index(self):
|
||||
freeze_layers_except(self.model, ["features.layer[5]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_start_omitted(self):
|
||||
freeze_layers_except(self.model, ["features.layer[:5]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_end_omitted(self):
|
||||
freeze_layers_except(self.model, ["features.layer[4:]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_merge_included(self):
|
||||
freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_merge_intersect(self):
|
||||
freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"])
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ZERO,
|
||||
]
|
||||
)
|
||||
|
||||
def test_freeze_layers_with_range_pattern_merge_separate(self):
|
||||
freeze_layers_except(
|
||||
self.model,
|
||||
["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"],
|
||||
)
|
||||
self.assertTrue(
|
||||
self.model.features.layer.weight.requires_grad,
|
||||
"model.features.layer should be trainable.",
|
||||
)
|
||||
self.assertFalse(
|
||||
self.model.classifier.weight.requires_grad,
|
||||
"model.classifier should be frozen.",
|
||||
)
|
||||
|
||||
self._assert_gradient_output(
|
||||
[
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ONE_TO_TEN,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
ZERO,
|
||||
]
|
||||
)
|
||||
|
||||
def _assert_gradient_output(self, expected):
|
||||
input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32)
|
||||
|
||||
self.model.features.layer.weight.grad = None # Reset gradients
|
||||
output = self.model.features.layer(input_tensor)
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
expected_grads = torch.tensor(expected)
|
||||
torch.testing.assert_close(
|
||||
self.model.features.layer.weight.grad, expected_grads
|
||||
)
|
||||
|
||||
|
||||
class _SubLayerModule(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layer = nn.Linear(10, 10)
|
||||
|
||||
|
||||
class _TestModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.features = _SubLayerModule()
|
||||
self.classifier = nn.Linear(10, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module for testing prompt tokenizers."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import unittest
|
||||
@@ -7,7 +8,8 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, LlamaTokenizer
|
||||
from datasets import load_dataset
|
||||
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||
@@ -18,11 +20,14 @@ from axolotl.prompt_strategies.llama2_chat import (
|
||||
Llama2ChatPrompter,
|
||||
LLama2ChatTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompt_strategies.orpo.chat_template import load
|
||||
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -266,6 +271,23 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
idx = res["input_ids"].index(20255) # assistant token
|
||||
assert res["labels"][idx] == -100
|
||||
|
||||
def test_glaive_tool_label_ignore(self):
|
||||
conversation = {
|
||||
"system": "SYSTEM: This is a system prompt",
|
||||
"chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
|
||||
}
|
||||
prompter = ShareGPTPrompterV2()
|
||||
strat = GlaiveShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
res = strat.tokenize_prompt(conversation)
|
||||
idx = res["input_ids"].index(13566) # assistant token
|
||||
assert res["labels"][idx] == -100
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
@@ -427,5 +449,57 @@ If a question does not make any sense, or is not factually coherent, explain why
|
||||
)
|
||||
|
||||
|
||||
class OrpoTokenizationTest(unittest.TestCase):
|
||||
"""test case for the ORPO tokenization"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||
tokenizer.add_special_tokens(
|
||||
{
|
||||
"eos_token": AddedToken(
|
||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
||||
)
|
||||
}
|
||||
)
|
||||
tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken(
|
||||
"<|im_start|>", rstrip=False, lstrip=False, normalized=False
|
||||
),
|
||||
]
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
self.dataset = load_dataset(
|
||||
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
|
||||
).select([0])
|
||||
|
||||
def test_orpo_integration(self):
|
||||
strat = load(
|
||||
self.tokenizer,
|
||||
DictDefault({"train_on_inputs": False}),
|
||||
DictDefault({"chat_template": "chatml"}),
|
||||
)
|
||||
res = strat.tokenize_prompt(self.dataset[0])
|
||||
assert "rejected_input_ids" in res
|
||||
assert "rejected_labels" in res
|
||||
assert "input_ids" in res
|
||||
assert "labels" in res
|
||||
assert "prompt_attention_mask" in res
|
||||
|
||||
assert len(res["rejected_input_ids"]) == len(res["rejected_labels"])
|
||||
assert len(res["input_ids"]) == len(res["labels"])
|
||||
assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
|
||||
|
||||
assert res["rejected_labels"][0] == -100
|
||||
assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1]
|
||||
|
||||
assert res["labels"][0] == -100
|
||||
assert res["input_ids"][-1] == res["labels"][-1]
|
||||
|
||||
assert res["prompt_attention_mask"][0] == 1
|
||||
assert res["prompt_attention_mask"][-1] == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user