Compare commits
11 Commits
4bit-optim
...
scatter_mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
10328b3429 | ||
|
|
5bfc470d57 | ||
|
|
04168801c9 | ||
|
|
d43a79b7bf | ||
|
|
884d81331e | ||
|
|
2ea75b4160 | ||
|
|
035e680631 | ||
|
|
26fc10df01 | ||
|
|
1bc008e901 | ||
|
|
3f7ed6a784 | ||
|
|
feea977923 |
67
README.md
67
README.md
@@ -13,9 +13,6 @@ Features:
|
|||||||
- Log results and optionally checkpoints to wandb or mlflow
|
- Log results and optionally checkpoints to wandb or mlflow
|
||||||
- And more!
|
- And more!
|
||||||
|
|
||||||
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
|
|
||||||
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
|
|
||||||
</a>
|
|
||||||
|
|
||||||
<table>
|
<table>
|
||||||
<tr>
|
<tr>
|
||||||
@@ -31,7 +28,6 @@ Features:
|
|||||||
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
|
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
|
||||||
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
|
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
|
||||||
- [Windows](#windows)
|
- [Windows](#windows)
|
||||||
- [Mac](#mac)
|
|
||||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||||
- [Dataset](#dataset)
|
- [Dataset](#dataset)
|
||||||
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
||||||
@@ -103,14 +99,24 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
|||||||
|
|
||||||
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
|
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
|
||||||
|
|
||||||
|
### For developers
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging
|
pip3 install packaging
|
||||||
|
```
|
||||||
|
|
||||||
|
General case:
|
||||||
|
```
|
||||||
pip3 install -e '.[flash-attn,deepspeed]'
|
pip3 install -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Mac: see https://github.com/OpenAccess-AI-Collective/axolotl/blob/13199f678b9aab39e92961323bdbce3234ee4b2b/docs/mac.md
|
||||||
|
```
|
||||||
|
pip3 install -e '.'
|
||||||
|
```
|
||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
```bash
|
```bash
|
||||||
# preprocess datasets - optional but recommended
|
# preprocess datasets - optional but recommended
|
||||||
@@ -243,31 +249,9 @@ For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud
|
|||||||
```
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
##### GCP
|
|
||||||
|
|
||||||
<details>
|
|
||||||
|
|
||||||
<summary>Click to Expand</summary>
|
|
||||||
|
|
||||||
Use a Deeplearning linux OS with cuda and pytorch installed. Then follow instructions on quickstart.
|
|
||||||
|
|
||||||
Make sure to run the below to uninstall xla.
|
|
||||||
```bash
|
|
||||||
pip uninstall -y torch_xla[tpu]
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
#### Windows
|
#### Windows
|
||||||
Please use WSL or Docker!
|
Please use WSL or Docker!
|
||||||
|
|
||||||
#### Mac
|
|
||||||
|
|
||||||
Use the below instead of the install method in QuickStart.
|
|
||||||
```
|
|
||||||
pip3 install -e '.'
|
|
||||||
```
|
|
||||||
More info: [mac.md](/docs/mac.md)
|
|
||||||
|
|
||||||
#### Launching on public clouds via SkyPilot
|
#### Launching on public clouds via SkyPilot
|
||||||
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
|
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
|
||||||
@@ -651,13 +635,9 @@ datasets:
|
|||||||
train_on_split: train # Optional[str] name of dataset split to load from
|
train_on_split: train # Optional[str] name of dataset split to load from
|
||||||
|
|
||||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
field_human: # Optional[str]. Human key to use for conversation.
|
field_human: # Optional[str]. Human key to use for conversation.
|
||||||
field_model: # Optional[str]. Assistant key to use for conversation.
|
field_model: # Optional[str]. Assistant key to use for conversation.
|
||||||
# Add additional keys from your dataset as input or output roles
|
|
||||||
roles:
|
|
||||||
input: # Optional[List[str]]. These will be masked based on train_on_input
|
|
||||||
output: # Optional[List[str]].
|
|
||||||
|
|
||||||
# Custom user instruction prompt
|
# Custom user instruction prompt
|
||||||
- path: repo
|
- path: repo
|
||||||
@@ -682,10 +662,6 @@ datasets:
|
|||||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||||
field:
|
field:
|
||||||
|
|
||||||
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
|
|
||||||
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
|
|
||||||
shuffle_merged_datasets: true
|
|
||||||
|
|
||||||
# A list of one or more datasets to eval the model with.
|
# A list of one or more datasets to eval the model with.
|
||||||
# You can use either test_datasets, or val_set_size, but not both.
|
# You can use either test_datasets, or val_set_size, but not both.
|
||||||
test_datasets:
|
test_datasets:
|
||||||
@@ -867,7 +843,7 @@ group_by_length: false
|
|||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||||
# gradient_checkpointing_kwargs:
|
# gradient_checkpointing_kwargs:
|
||||||
# use_reentrant: true
|
# use_reentrant: false
|
||||||
|
|
||||||
# Stop training after this many evaluation losses have increased in a row
|
# Stop training after this many evaluation losses have increased in a row
|
||||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||||
@@ -907,26 +883,7 @@ lr_div_factor: # Learning rate div factor
|
|||||||
# - paged_adamw_8bit
|
# - paged_adamw_8bit
|
||||||
# - paged_lion_32bit
|
# - paged_lion_32bit
|
||||||
# - paged_lion_8bit
|
# - paged_lion_8bit
|
||||||
# - galore_adamw
|
|
||||||
# - galore_adamw_8bit
|
|
||||||
# - galore_adafactor
|
|
||||||
# - galore_adamw_layerwise
|
|
||||||
# - galore_adamw_8bit_layerwise
|
|
||||||
# - galore_adafactor_layerwise
|
|
||||||
optimizer:
|
optimizer:
|
||||||
# Dictionary of arguments to pass to the optimizer
|
|
||||||
optim_args:
|
|
||||||
# For Galore Optimizers the following optim_args are available
|
|
||||||
# rank: # type: int
|
|
||||||
# update_proj_gap # type: int
|
|
||||||
# scale # type: float
|
|
||||||
# proj_type: # type: str, default = std
|
|
||||||
|
|
||||||
# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
|
|
||||||
optim_target_modules:
|
|
||||||
# - self_attn # for llama
|
|
||||||
# - mlp
|
|
||||||
|
|
||||||
# Specify weight decay
|
# Specify weight decay
|
||||||
weight_decay:
|
weight_decay:
|
||||||
# adamw hyperparams
|
# adamw hyperparams
|
||||||
|
|||||||
@@ -23,9 +23,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ WORKDIR /workspace/axolotl
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
@@ -1,29 +0,0 @@
|
|||||||
# Optimizers
|
|
||||||
|
|
||||||
Optimizers are an important component when training LLMs. Optimizers are responsible for updating the model's weights (parameters) based on the gradients computed during backpropagation.
|
|
||||||
The goal of an optimizer is to minimize the loss function.
|
|
||||||
|
|
||||||
### Adam/AdamW Optimizers
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
adam_beta1: 0.9
|
|
||||||
adam_beta2: 0.999
|
|
||||||
adam_epsilon: 1e-8
|
|
||||||
weight_decay: 0.0
|
|
||||||
```
|
|
||||||
|
|
||||||
### GaLore Optimizer
|
|
||||||
|
|
||||||
https://huggingface.co/papers/2403.03507
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
optimizer: galore_adamw | galore_adamw_8bit | galore_adafactor
|
|
||||||
optim_args:
|
|
||||||
rank: 128
|
|
||||||
update_proj_gap: 200
|
|
||||||
scale: 0.25
|
|
||||||
proj_type: std
|
|
||||||
optim_target_modules:
|
|
||||||
- mlp
|
|
||||||
- attn
|
|
||||||
```
|
|
||||||
15
docs/rlhf.md
15
docs/rlhf.md
@@ -34,21 +34,6 @@ datasets:
|
|||||||
rl: ipo
|
rl: ipo
|
||||||
```
|
```
|
||||||
|
|
||||||
#### ORPO
|
|
||||||
|
|
||||||
Paper: https://arxiv.org/abs/2403.07691
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
rl: orpo
|
|
||||||
orpo_alpha: 0.1
|
|
||||||
remove_unused_columns: false
|
|
||||||
|
|
||||||
chat_template: chatml
|
|
||||||
datasets:
|
|
||||||
- path: argilla/ultrafeedback-binarized-preferences-cleaned
|
|
||||||
type: orpo.chat_template
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Using local dataset files
|
#### Using local dataset files
|
||||||
```yaml
|
```yaml
|
||||||
datasets:
|
datasets:
|
||||||
|
|||||||
75
examples/mistral/mixtral_fused.py
Normal file
75
examples/mistral/mixtral_fused.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import gc
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||||
|
from transformers import AutoTokenizer, TextStreamer
|
||||||
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralForCausalLM, MixtralConfig
|
||||||
|
|
||||||
|
def compute_memory_used_pct(device):
|
||||||
|
memory_used = torch.cuda.max_memory_allocated(device) / (1024**3)
|
||||||
|
memory_pct = (
|
||||||
|
memory_used
|
||||||
|
/ (torch.cuda.get_device_properties(device).total_memory / (1024**3))
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
return memory_pct
|
||||||
|
|
||||||
|
model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
config = MixtralConfig.from_pretrained(model_path, max_position_embeddings=2048, use_cache=False)
|
||||||
|
model = MixtralForCausalLM.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
config=config,
|
||||||
|
device_map="auto",
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
modules = {k:v for k,v in model.named_modules() if isinstance(v, MixtralSparseMoeBlock)}
|
||||||
|
|
||||||
|
for device_index in range(torch.cuda.device_count()):
|
||||||
|
device_memory_pct = compute_memory_used_pct(device_index)
|
||||||
|
print(device_index, device_memory_pct)
|
||||||
|
|
||||||
|
with tqdm(modules.items(), desc="scatter moe") as pbar:
|
||||||
|
for i, (name, module) in enumerate(pbar):
|
||||||
|
smoe = SparseMoeBlock(
|
||||||
|
experts=module.experts,
|
||||||
|
gate=module.gate,
|
||||||
|
hidden_dim=module.hidden_dim,
|
||||||
|
ffn_dim=module.ffn_dim,
|
||||||
|
num_experts=module.num_experts,
|
||||||
|
top_k=module.top_k,
|
||||||
|
)
|
||||||
|
old_module = model.model.layers[i].block_sparse_moe
|
||||||
|
setattr(model.model.layers[i], "block_sparse_moe", smoe)
|
||||||
|
del old_module
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
for device_index in range(torch.cuda.device_count()):
|
||||||
|
device_memory_pct = compute_memory_used_pct(device_index)
|
||||||
|
print(device_index, device_memory_pct)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
|
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# Convert prompt to tokens
|
||||||
|
prompt_template = "[INST] {prompt} [/INST]"
|
||||||
|
|
||||||
|
prompt = "You're standing on the surface of the Earth. "\
|
||||||
|
"You walk one mile south, one mile west and one mile north. "\
|
||||||
|
"You end up exactly where you started. Where are you?"
|
||||||
|
|
||||||
|
tokens = tokenizer(
|
||||||
|
prompt_template.format(prompt=prompt),
|
||||||
|
return_tensors='pt'
|
||||||
|
).input_ids.cuda()
|
||||||
|
|
||||||
|
# Generate output
|
||||||
|
generation_output = model.generate(
|
||||||
|
tokens,
|
||||||
|
streamer=streamer,
|
||||||
|
max_new_tokens=512
|
||||||
|
)
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.9.0
|
peft==0.9.0
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa
|
transformers==4.38.2
|
||||||
tokenizers==0.15.0
|
tokenizers==0.15.0
|
||||||
bitsandbytes>=0.43.0
|
bitsandbytes>=0.43.0
|
||||||
accelerate==0.26.1
|
accelerate==0.26.1
|
||||||
@@ -39,8 +39,5 @@ s3fs
|
|||||||
gcsfs
|
gcsfs
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
|
trl>=0.7.9
|
||||||
fastcore>=1.5.29
|
fastcore>=1.5.29
|
||||||
|
|
||||||
lpmm @ git+https://github.com/thu-ml/low-bit-optimizers.git@main
|
|
||||||
yacs
|
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -89,8 +89,5 @@ setup(
|
|||||||
"lion-pytorch": [
|
"lion-pytorch": [
|
||||||
"lion-pytorch==0.1.2",
|
"lion-pytorch==0.1.2",
|
||||||
],
|
],
|
||||||
"galore": [
|
|
||||||
"galore_torch",
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
LOG.warning(msg)
|
LOG.warning(msg)
|
||||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
|
||||||
if parsed_cfg.rl and parsed_cfg.rl != "orpo":
|
if parsed_cfg.rl:
|
||||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
else:
|
else:
|
||||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|||||||
else:
|
else:
|
||||||
register_chatml_template()
|
register_chatml_template()
|
||||||
|
|
||||||
if cfg.rl and cfg.rl != "orpo":
|
if cfg.rl:
|
||||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -11,25 +11,21 @@ import math
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections import defaultdict
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
from typing import List, Optional, Type, Union
|
||||||
|
|
||||||
import lpmm
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from accelerate import FullyShardedDataParallelPlugin
|
from accelerate import FullyShardedDataParallelPlugin
|
||||||
from accelerate.utils import str_to_bool
|
from accelerate.utils import str_to_bool
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from torch import nn
|
|
||||||
from torch.distributed.fsdp import MixedPrecision
|
from torch.distributed.fsdp import MixedPrecision
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers import (
|
from transformers import (
|
||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
PreTrainedModel,
|
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
@@ -39,7 +35,6 @@ from transformers.utils import is_sagemaker_mp_enabled
|
|||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
|
from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
|
||||||
from axolotl.core.trainers import OptimizerNames
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
from axolotl.loraplus import create_loraplus_optimizer
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
@@ -66,9 +61,6 @@ from axolotl.utils.schedulers import (
|
|||||||
get_cosine_schedule_with_warmup_decay_constant,
|
get_cosine_schedule_with_warmup_decay_constant,
|
||||||
)
|
)
|
||||||
|
|
||||||
# monkeypatch so it accepts our custom optimizers
|
|
||||||
transformers.training_args.OptimizerNames = OptimizerNames
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
import smdistributed.modelparallel.torch as smp
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
@@ -208,9 +200,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "whether this is a qlora training"},
|
metadata={"help": "whether this is a qlora training"},
|
||||||
)
|
)
|
||||||
orpo_alpha: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(Trainer):
|
class AxolotlTrainer(Trainer):
|
||||||
@@ -227,115 +216,33 @@ class AxolotlTrainer(Trainer):
|
|||||||
num_epochs=1,
|
num_epochs=1,
|
||||||
bench_data_collator=None,
|
bench_data_collator=None,
|
||||||
eval_data_collator=None,
|
eval_data_collator=None,
|
||||||
**kwargs,
|
**kwargs
|
||||||
):
|
):
|
||||||
self.num_epochs = num_epochs
|
self.num_epochs = num_epochs
|
||||||
self.bench_data_collator = bench_data_collator
|
self.bench_data_collator = bench_data_collator
|
||||||
self.eval_data_collator = eval_data_collator
|
self.eval_data_collator = eval_data_collator
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
|
||||||
if self.args.orpo_alpha:
|
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_optimizer_cls_and_kwargs(
|
|
||||||
args: TrainingArguments, model: Optional[PreTrainedModel] = None
|
|
||||||
) -> Tuple[Any, Any]:
|
|
||||||
optim_args = {}
|
|
||||||
if args.optim_args:
|
|
||||||
for mapping in args.optim_args.replace(" ", "").split(","):
|
|
||||||
key, value = mapping.split("=")
|
|
||||||
optim_args[key] = value
|
|
||||||
|
|
||||||
optimizer_kwargs = {"lr": args.learning_rate}
|
|
||||||
|
|
||||||
adam_kwargs = {
|
|
||||||
"betas": (args.adam_beta1, args.adam_beta2),
|
|
||||||
"eps": args.adam_epsilon,
|
|
||||||
}
|
|
||||||
|
|
||||||
if args.optim in [
|
|
||||||
OptimizerNames.LPMM_ADAMW_4BIT,
|
|
||||||
OptimizerNames.LPMM_ADAMW_4BIT_FUSED,
|
|
||||||
]:
|
|
||||||
optimizer_cls = lpmm.optim.AdamW
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
if args.optim == OptimizerNames.LPMM_ADAMW_4BIT_FUSED:
|
|
||||||
optimizer_kwargs.update({"fused": True})
|
|
||||||
return optimizer_cls, optimizer_kwargs
|
|
||||||
|
|
||||||
return Trainer.get_optimizer_cls_and_kwargs(
|
|
||||||
args,
|
|
||||||
model=model,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
|
if self.args.loraplus_lr_ratio is None:
|
||||||
|
return super().create_optimizer()
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
|
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||||
optimizer_grouped_parameters = [
|
self.args,
|
||||||
{
|
)
|
||||||
"params": [
|
|
||||||
p
|
|
||||||
for n, p in opt_model.named_parameters()
|
|
||||||
if (n in decay_parameters and p.requires_grad)
|
|
||||||
],
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [
|
|
||||||
p
|
|
||||||
for n, p in opt_model.named_parameters()
|
|
||||||
if (n not in decay_parameters and p.requires_grad)
|
|
||||||
],
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
(
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
|
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
||||||
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
opt_model,
|
||||||
optimizer_cls,
|
optimizer_cls,
|
||||||
optimizer_kwargs,
|
optimizer_kwargs,
|
||||||
) = AxolotlTrainer.get_optimizer_cls_and_kwargs(self.args)
|
loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding,
|
||||||
if self.args.loraplus_lr_ratio:
|
)
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
|
||||||
loraplus_lr_embedding = getattr(
|
|
||||||
self.args, "loraplus_lr_embedding", None
|
|
||||||
)
|
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
opt_model,
|
|
||||||
optimizer_cls,
|
|
||||||
optimizer_kwargs,
|
|
||||||
loraplus_lr_ratio,
|
|
||||||
loraplus_lr_embedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
|
||||||
)
|
|
||||||
|
|
||||||
if optimizer_cls.__name__ == "Adam8bit":
|
|
||||||
import bitsandbytes
|
|
||||||
|
|
||||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
|
||||||
|
|
||||||
skipped = 0
|
|
||||||
for module in opt_model.modules():
|
|
||||||
if isinstance(module, nn.Embedding):
|
|
||||||
skipped += sum(
|
|
||||||
{
|
|
||||||
p.data_ptr(): p.numel() for p in module.parameters()
|
|
||||||
}.values()
|
|
||||||
)
|
|
||||||
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
|
||||||
manager.register_module_override(
|
|
||||||
module, "weight", {"optim_bits": 32}
|
|
||||||
)
|
|
||||||
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
|
||||||
LOG.info(f"skipped: {skipped/2**20}M params")
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
@@ -558,112 +465,8 @@ class AxolotlTrainer(Trainer):
|
|||||||
# outputs = model(**inputs)
|
# outputs = model(**inputs)
|
||||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||||
# return (loss, outputs) if return_outputs else loss
|
# return (loss, outputs) if return_outputs else loss
|
||||||
if self.args.orpo_alpha:
|
|
||||||
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
|
|
||||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||||
|
|
||||||
def orpo_compute_custom_loss(self, logits, labels):
|
|
||||||
logits = logits.contiguous()
|
|
||||||
loss = 0.0
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
# move labels to correct device to enable model parallelism
|
|
||||||
labels = labels.to(logits.device)
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
|
|
||||||
# Flatten the tokens
|
|
||||||
loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
|
|
||||||
dim=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def orpo_compute_logps(
|
|
||||||
self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
|
|
||||||
):
|
|
||||||
# Get the shape of chosen_attention_mask[:, :-1]
|
|
||||||
chosen_shape = chosen_attention_mask[:, :-1].shape
|
|
||||||
|
|
||||||
# Calculate the padding size
|
|
||||||
pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)
|
|
||||||
|
|
||||||
# Pad prompt_attention_mask with zeros to match the desired shape
|
|
||||||
prompt_attention_mask_padded = torch.nn.functional.pad(
|
|
||||||
prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Perform the subtraction operation
|
|
||||||
mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded
|
|
||||||
|
|
||||||
per_token_logps = torch.gather(
|
|
||||||
logits[:, :-1, :].log_softmax(-1),
|
|
||||||
dim=2,
|
|
||||||
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
|
|
||||||
).squeeze(2)
|
|
||||||
return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
|
|
||||||
dtype=torch.float64
|
|
||||||
) / mask.sum(dim=1).to(dtype=torch.float64)
|
|
||||||
|
|
||||||
def orpo_compute_loss(self, model, inputs, return_outputs=False):
|
|
||||||
outputs_neg = model(
|
|
||||||
**{
|
|
||||||
"input_ids": inputs["rejected_input_ids"],
|
|
||||||
"attention_mask": inputs["rejected_attention_mask"],
|
|
||||||
"labels": inputs["rejected_labels"],
|
|
||||||
},
|
|
||||||
output_hidden_states=True,
|
|
||||||
)
|
|
||||||
outputs_pos = model(
|
|
||||||
**{
|
|
||||||
"input_ids": inputs["input_ids"],
|
|
||||||
"attention_mask": inputs["attention_mask"],
|
|
||||||
"labels": inputs["labels"],
|
|
||||||
},
|
|
||||||
output_hidden_states=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate NLL loss
|
|
||||||
pos_loss = self.orpo_compute_custom_loss(
|
|
||||||
logits=outputs_pos.logits, labels=inputs["input_ids"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate Log Probability
|
|
||||||
pos_prob = self.orpo_compute_logps(
|
|
||||||
prompt_attention_mask=inputs["prompt_attention_mask"],
|
|
||||||
chosen_inputs=inputs["input_ids"],
|
|
||||||
chosen_attention_mask=inputs["attention_mask"],
|
|
||||||
logits=outputs_pos.logits,
|
|
||||||
)
|
|
||||||
neg_prob = self.orpo_compute_logps(
|
|
||||||
prompt_attention_mask=inputs["prompt_attention_mask"],
|
|
||||||
chosen_inputs=inputs["rejected_input_ids"],
|
|
||||||
chosen_attention_mask=inputs["rejected_attention_mask"],
|
|
||||||
logits=outputs_neg.logits,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate log odds
|
|
||||||
log_odds = (pos_prob - neg_prob) - (
|
|
||||||
torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
|
|
||||||
)
|
|
||||||
sig_ratio = torch.nn.functional.sigmoid(log_odds)
|
|
||||||
ratio = torch.log(sig_ratio)
|
|
||||||
|
|
||||||
# Calculate the Final Loss
|
|
||||||
loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
|
|
||||||
dtype=torch.bfloat16
|
|
||||||
)
|
|
||||||
|
|
||||||
metrics = {}
|
|
||||||
metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
|
|
||||||
metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
|
|
||||||
metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
|
|
||||||
metrics["log_odds"] = torch.mean(log_odds).cpu().item()
|
|
||||||
self.store_metrics(metrics, train_eval="train")
|
|
||||||
|
|
||||||
return (loss, outputs_pos) if return_outputs else loss
|
|
||||||
|
|
||||||
@wraps(Trainer.push_to_hub)
|
@wraps(Trainer.push_to_hub)
|
||||||
def push_to_hub(self, *args, **kwargs) -> str:
|
def push_to_hub(self, *args, **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -724,28 +527,6 @@ class AxolotlTrainer(Trainer):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float]) -> None:
|
|
||||||
"""
|
|
||||||
Log `logs` on the various objects watching training, including stored metrics.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
logs (`Dict[str, float]`):
|
|
||||||
The values to log.
|
|
||||||
"""
|
|
||||||
# logs either has 'loss' or 'eval_loss'
|
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
|
||||||
# Add averaged stored metrics to logs
|
|
||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
|
||||||
logs[key] = torch.tensor(metrics).mean().item()
|
|
||||||
del self._stored_metrics[train_eval]
|
|
||||||
return super().log(logs)
|
|
||||||
|
|
||||||
def store_metrics(
|
|
||||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
|
||||||
) -> None:
|
|
||||||
for key, value in metrics.items():
|
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1056,6 +837,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"gradient_checkpointing_kwargs"
|
"gradient_checkpointing_kwargs"
|
||||||
] = self.cfg.gradient_checkpointing_kwargs
|
] = self.cfg.gradient_checkpointing_kwargs
|
||||||
|
else:
|
||||||
|
training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
|
||||||
|
"use_reentrant": False
|
||||||
|
}
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||||
if self.cfg.fsdp_config:
|
if self.cfg.fsdp_config:
|
||||||
@@ -1118,11 +903,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
||||||
training_arguments_kwargs["dataloader_drop_last"] = True
|
training_arguments_kwargs["dataloader_drop_last"] = True
|
||||||
|
|
||||||
if self.cfg.remove_unused_columns is not None:
|
|
||||||
training_arguments_kwargs[
|
|
||||||
"remove_unused_columns"
|
|
||||||
] = self.cfg.remove_unused_columns
|
|
||||||
|
|
||||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||||
# no eval set, so don't eval
|
# no eval set, so don't eval
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||||
@@ -1236,18 +1016,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["optim"] = (
|
training_arguments_kwargs["optim"] = (
|
||||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||||
)
|
)
|
||||||
if self.cfg.optim_args:
|
|
||||||
if isinstance(self.cfg.optim_args, dict):
|
|
||||||
optim_args = ",".join(
|
|
||||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
optim_args = self.cfg.optim_args
|
|
||||||
training_arguments_kwargs["optim_args"] = optim_args
|
|
||||||
if self.cfg.optim_target_modules:
|
|
||||||
training_arguments_kwargs[
|
|
||||||
"optim_target_modules"
|
|
||||||
] = self.cfg.optim_target_modules
|
|
||||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"loraplus_lr_embedding"
|
"loraplus_lr_embedding"
|
||||||
@@ -1302,9 +1070,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||||
|
|
||||||
if self.cfg.rl == "orpo":
|
|
||||||
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
|
|
||||||
|
|
||||||
if self.cfg.neftune_noise_alpha is not None:
|
if self.cfg.neftune_noise_alpha is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"neftune_noise_alpha"
|
"neftune_noise_alpha"
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
"""module for trainer helpers like OptimizerNames"""
|
|
||||||
|
|
||||||
from transformers.utils import ExplicitEnum
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizerNames(ExplicitEnum):
|
|
||||||
"""
|
|
||||||
Stores the acceptable string identifiers for optimizers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
ADAMW_HF = "adamw_hf"
|
|
||||||
ADAMW_TORCH = "adamw_torch"
|
|
||||||
ADAMW_TORCH_FUSED = "adamw_torch_fused"
|
|
||||||
ADAMW_TORCH_XLA = "adamw_torch_xla"
|
|
||||||
ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
|
|
||||||
ADAMW_APEX_FUSED = "adamw_apex_fused"
|
|
||||||
ADAFACTOR = "adafactor"
|
|
||||||
ADAMW_ANYPRECISION = "adamw_anyprecision"
|
|
||||||
SGD = "sgd"
|
|
||||||
ADAGRAD = "adagrad"
|
|
||||||
ADAMW_BNB = "adamw_bnb_8bit"
|
|
||||||
ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit
|
|
||||||
LION_8BIT = "lion_8bit"
|
|
||||||
LION = "lion_32bit"
|
|
||||||
PAGED_ADAMW = "paged_adamw_32bit"
|
|
||||||
PAGED_ADAMW_8BIT = "paged_adamw_8bit"
|
|
||||||
PAGED_LION = "paged_lion_32bit"
|
|
||||||
PAGED_LION_8BIT = "paged_lion_8bit"
|
|
||||||
RMSPROP = "rmsprop"
|
|
||||||
RMSPROP_BNB = "rmsprop_bnb"
|
|
||||||
RMSPROP_8BIT = "rmsprop_bnb_8bit"
|
|
||||||
RMSPROP_32BIT = "rmsprop_bnb_32bit"
|
|
||||||
GALORE_ADAMW = "galore_adamw"
|
|
||||||
GALORE_ADAMW_8BIT = "galore_adamw_8bit"
|
|
||||||
GALORE_ADAFACTOR = "galore_adafactor"
|
|
||||||
GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
|
|
||||||
GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
|
|
||||||
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
|
|
||||||
LPMM_ADAMW_4BIT = "lmpp_adamw_4bit"
|
|
||||||
LPMM_ADAMW_4BIT_FUSED = "lmpp_adamw_4bit_fused"
|
|
||||||
|
|||||||
0
src/axolotl/monkeypatch/moe/__init__.py
Normal file
0
src/axolotl/monkeypatch/moe/__init__.py
Normal file
149
src/axolotl/monkeypatch/moe/linear.py
Normal file
149
src/axolotl/monkeypatch/moe/linear.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""
|
||||||
|
Adapted from:
|
||||||
|
https://github.com/shawntan/scattermoe
|
||||||
|
https://arxiv.org/abs/2403.08245
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from axolotl.monkeypatch.moe import ops
|
||||||
|
|
||||||
|
class ParallelLinear(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(
|
||||||
|
ctx, x, expert_weights, k,
|
||||||
|
sorted_expert_idxs, sorted_scattered_idxs,
|
||||||
|
padded_block_idxs, expert_offsets,
|
||||||
|
gates=None, grouped_in=False, grouped_out=False,
|
||||||
|
):
|
||||||
|
|
||||||
|
output = ops.scatter2scatter(
|
||||||
|
X=x, W=expert_weights,
|
||||||
|
sorted_expert_idxs=sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||||
|
padded_block_idxs=padded_block_idxs,
|
||||||
|
k=k, x_grouped=grouped_in, y_grouped=grouped_out
|
||||||
|
)
|
||||||
|
if gates is not None:
|
||||||
|
output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
|
||||||
|
output = torch.bmm(
|
||||||
|
gates[:, None, :],
|
||||||
|
output_expanded
|
||||||
|
).squeeze(1)
|
||||||
|
else:
|
||||||
|
output_expanded = None
|
||||||
|
|
||||||
|
ctx.save_for_backward(
|
||||||
|
x, expert_weights,
|
||||||
|
sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs,
|
||||||
|
padded_block_idxs, expert_offsets,
|
||||||
|
gates,
|
||||||
|
output_expanded
|
||||||
|
)
|
||||||
|
ctx.grouped_in = grouped_in
|
||||||
|
ctx.grouped_out = grouped_out
|
||||||
|
ctx.k = k
|
||||||
|
return output
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_out):
|
||||||
|
(x, expert_weights,
|
||||||
|
sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs,
|
||||||
|
padded_block_idxs, expert_offsets,
|
||||||
|
gates, output_expanded) = ctx.saved_tensors
|
||||||
|
k = ctx.k
|
||||||
|
grouped_in = ctx.grouped_in
|
||||||
|
grouped_out = ctx.grouped_out
|
||||||
|
# print("backward")
|
||||||
|
if gates is not None:
|
||||||
|
# calculate gates gradient
|
||||||
|
d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
|
||||||
|
gates_flat = gates.flatten()
|
||||||
|
gate_fan = gates.size(1)
|
||||||
|
# print("expanded and grouping")
|
||||||
|
grouped_grad_out = output_expanded.flatten(0, 1) # reuse expanded buffer later
|
||||||
|
else:
|
||||||
|
d_gates = None
|
||||||
|
gates_flat = None
|
||||||
|
gate_fan = 1
|
||||||
|
grouped_grad_out = None
|
||||||
|
|
||||||
|
if grouped_out:
|
||||||
|
grouped_grad_out = grad_out
|
||||||
|
else:
|
||||||
|
grouped_grad_out = ops.group(grad_out, sorted_scattered_idxs,
|
||||||
|
fan_out=gate_fan, coeff=gates_flat,
|
||||||
|
out=grouped_grad_out)
|
||||||
|
if grouped_in:
|
||||||
|
grouped_x = x
|
||||||
|
d_expanded_input = None
|
||||||
|
else:
|
||||||
|
grouped_x = ops.group(x, sorted_scattered_idxs, fan_out=k)
|
||||||
|
d_expanded_input = grouped_x
|
||||||
|
d_weights = ops.group_bwd_W(
|
||||||
|
DY=grouped_grad_out, X=grouped_x,
|
||||||
|
expert_offsets=expert_offsets,
|
||||||
|
E=expert_weights.size(0)
|
||||||
|
)
|
||||||
|
d_expanded_input = ops.scatter2scatter(
|
||||||
|
X=grouped_grad_out, x_grouped=True,
|
||||||
|
W=expert_weights.permute(0, 2, 1),
|
||||||
|
padded_block_idxs=padded_block_idxs,
|
||||||
|
sorted_expert_idxs=sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||||
|
k=1,
|
||||||
|
y_grouped=grouped_in,
|
||||||
|
out=d_expanded_input # Reuse grouped_x buffer
|
||||||
|
)
|
||||||
|
|
||||||
|
if k == 1:
|
||||||
|
d_input = d_expanded_input
|
||||||
|
else:
|
||||||
|
d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
|
||||||
|
# print("backward end.")
|
||||||
|
return (
|
||||||
|
# x, expert_weights, k,
|
||||||
|
d_input, d_weights, None,
|
||||||
|
# sorted_expert_idxs, sorted_scattered_idxs,
|
||||||
|
None, None,
|
||||||
|
# padded_block_idxs, expert_offsets,
|
||||||
|
None, None,
|
||||||
|
# gates
|
||||||
|
d_gates, None, None
|
||||||
|
)
|
||||||
|
|
||||||
|
def parallel_linear(inputs, expert_weights, k,
|
||||||
|
sorted_expert_idxs, sorted_scattered_idxs,
|
||||||
|
padded_block_idxs, expert_offsets,
|
||||||
|
gates=None):
|
||||||
|
results = ParallelLinear.apply(inputs, expert_weights, k,
|
||||||
|
sorted_expert_idxs, sorted_scattered_idxs,
|
||||||
|
padded_block_idxs, expert_offsets, gates)
|
||||||
|
return results
|
||||||
|
|
||||||
|
class ParallelExperts(nn.Module):
|
||||||
|
def __init__(self, num_experts, input_size, output_size, device) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
torch.empty(num_experts, output_size, input_size, device=device)
|
||||||
|
)
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.input_size = input_size
|
||||||
|
self.output_size = output_size
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return 'num_experts={}, input_size={}, output_size={}'.format(
|
||||||
|
self.num_experts, self.input_size, self.output_size)
|
||||||
|
|
||||||
|
def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
|
||||||
|
padded_block_idxs, expert_offsets,
|
||||||
|
gates=None, grouped_in=False, grouped_out=False):
|
||||||
|
|
||||||
|
results = ParallelLinear.apply(
|
||||||
|
inputs, self.weight.permute(0, 2, 1), k,
|
||||||
|
sorted_expert_idxs, sorted_scattered_idxs,
|
||||||
|
padded_block_idxs, expert_offsets,
|
||||||
|
gates, grouped_in, grouped_out
|
||||||
|
)
|
||||||
|
return results
|
||||||
86
src/axolotl/monkeypatch/moe/mlp.py
Normal file
86
src/axolotl/monkeypatch/moe/mlp.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""
|
||||||
|
Adapted from:
|
||||||
|
https://github.com/shawntan/scattermoe
|
||||||
|
https://arxiv.org/abs/2403.08245
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.moe import ops
|
||||||
|
from axolotl.monkeypatch.moe.linear import ParallelExperts
|
||||||
|
|
||||||
|
|
||||||
|
class FusedExperts(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
experts: nn.ModuleList =None,
|
||||||
|
hidden_dim=128,
|
||||||
|
ffn_dim=512,
|
||||||
|
num_experts=8,
|
||||||
|
top_k=2,
|
||||||
|
activation=nn.SiLU(),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This implements fused experts that are compatible with Mixtral.
|
||||||
|
MLP of type Gated-Linear Unit, typically with a SiLU activation function.
|
||||||
|
"""
|
||||||
|
super(FusedExperts, self).__init__()
|
||||||
|
|
||||||
|
device = experts[0].w1.weight.device
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
self.ffn_dim = ffn_dim
|
||||||
|
self.experts = ParallelExperts(num_experts, hidden_dim, 2 * ffn_dim, device=device)
|
||||||
|
self.output_experts = ParallelExperts(num_experts, ffn_dim, hidden_dim, device=device)
|
||||||
|
self.top_k = min(top_k, self.num_experts)
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(len(experts)):
|
||||||
|
self.experts.weight.data[i].copy_(
|
||||||
|
torch.cat(
|
||||||
|
[experts[i].w1.weight.detach(), experts[i].w3.weight.detach()],
|
||||||
|
dim=0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.output_experts.weight.data[i].copy_(
|
||||||
|
experts[i].w2.weight.detach()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x: torch.Tensor, routing_weights: torch.Tensor, selected_experts: torch.Tensor
|
||||||
|
):
|
||||||
|
x_shape = x.size()
|
||||||
|
x = x.view(-1, x_shape[-1])
|
||||||
|
with torch.no_grad():
|
||||||
|
sorted_expert_idxs, sorted_scattered_idxs = ops.flatten_and_sort(
|
||||||
|
selected_experts
|
||||||
|
)
|
||||||
|
padded_block_idxs, expert_offsets = ops.padded_block_indices(
|
||||||
|
sorted_expert_idxs, self.num_experts
|
||||||
|
)
|
||||||
|
|
||||||
|
h, gates = self.experts(
|
||||||
|
x,
|
||||||
|
self.top_k,
|
||||||
|
sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs,
|
||||||
|
padded_block_idxs,
|
||||||
|
expert_offsets,
|
||||||
|
grouped_out=True,
|
||||||
|
).chunk(2, dim=-1)
|
||||||
|
h = self.activation(gates) * h
|
||||||
|
y = self.output_experts(
|
||||||
|
h,
|
||||||
|
1,
|
||||||
|
sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs,
|
||||||
|
padded_block_idxs,
|
||||||
|
expert_offsets,
|
||||||
|
grouped_in=True,
|
||||||
|
gates=routing_weights,
|
||||||
|
)
|
||||||
|
y = y.view(*x_shape[:-1], y.size(-1))
|
||||||
|
return y
|
||||||
50
src/axolotl/monkeypatch/moe/moe.py
Normal file
50
src/axolotl/monkeypatch/moe/moe.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from axolotl.monkeypatch.moe.mlp import FusedExperts
|
||||||
|
|
||||||
|
class SparseMoeBlock(nn.Module):
|
||||||
|
def __init__(self, experts, gate, hidden_dim, ffn_dim, num_experts, top_k):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
self.ffn_dim = ffn_dim
|
||||||
|
self.num_experts = num_experts
|
||||||
|
self.top_k = top_k
|
||||||
|
self.gate = gate
|
||||||
|
self.experts = FusedExperts(
|
||||||
|
experts=experts,
|
||||||
|
hidden_dim=hidden_dim,
|
||||||
|
ffn_dim=ffn_dim,
|
||||||
|
num_experts=num_experts,
|
||||||
|
top_k=top_k,
|
||||||
|
activation=experts[0].act_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
def _post_training(self, model, name):
|
||||||
|
# get original weights back: reverse the concat + stack in the fused experts
|
||||||
|
w1s, w3s = torch.split(torch.unbind(self.experts.experts.weight, dim=0), 2, dim=1)
|
||||||
|
w2s = torch.unbind(self.experts.output_experts.weight, dim=0)
|
||||||
|
|
||||||
|
# TODO: recreate MoE class with original weights
|
||||||
|
experts = []
|
||||||
|
for i in range(self.num_experts):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
|
# router_logits: (batch * sequence_length, n_experts)
|
||||||
|
router_logits = self.gate(hidden_states)
|
||||||
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
|
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||||
|
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# we cast back to the input dtype
|
||||||
|
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||||
|
|
||||||
|
# Fused expert forward
|
||||||
|
final_hidden_states = self.experts(hidden_states, routing_weights, selected_experts)
|
||||||
|
|
||||||
|
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||||
|
return final_hidden_states, router_logits
|
||||||
353
src/axolotl/monkeypatch/moe/ops.py
Normal file
353
src/axolotl/monkeypatch/moe/ops.py
Normal file
@@ -0,0 +1,353 @@
|
|||||||
|
"""
|
||||||
|
Adapted from:
|
||||||
|
https://github.com/shawntan/scattermoe
|
||||||
|
https://arxiv.org/abs/2403.08245
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
BLOCK_M = 128
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def flatten_and_sort(expert_idxs:torch.Tensor):
|
||||||
|
flattened_expert_idxs = expert_idxs.flatten()
|
||||||
|
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
|
||||||
|
return sorted_expert_idxs, sorted_scattered_idxs
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def padded_block_indices(sorted_experts_idxs: torch.Tensor, k: int, N_BLOCK_SIZE: int=BLOCK_M) :
|
||||||
|
expert_counts = torch.bincount(sorted_experts_idxs, minlength=k)
|
||||||
|
padded_block_counts = ((expert_counts - 1) // N_BLOCK_SIZE) + 1
|
||||||
|
padded_expert_block_end = padded_block_counts.cumsum(-1)
|
||||||
|
expert_boundaries_end = expert_counts.cumsum(-1)
|
||||||
|
expert_boundaries_start = expert_boundaries_end - expert_counts
|
||||||
|
padded_expert_block_start = padded_expert_block_end - padded_block_counts
|
||||||
|
block_idxs = torch.arange(padded_expert_block_end[-1],
|
||||||
|
dtype=sorted_experts_idxs.dtype,
|
||||||
|
device=sorted_experts_idxs.device)
|
||||||
|
block_mask = (
|
||||||
|
(block_idxs[:, None] < padded_expert_block_start) |
|
||||||
|
(block_idxs[:, None] >= padded_expert_block_end)
|
||||||
|
)
|
||||||
|
expanded_block_idxs = (
|
||||||
|
N_BLOCK_SIZE * (block_idxs[:, None] - padded_expert_block_start) +
|
||||||
|
expert_boundaries_start
|
||||||
|
)
|
||||||
|
expanded_block_idxs = expanded_block_idxs.masked_fill(block_mask, 0).sum(-1)
|
||||||
|
return expanded_block_idxs, expert_boundaries_end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _scatter2scatter_configs():
|
||||||
|
return [
|
||||||
|
triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
|
||||||
|
]
|
||||||
|
|
||||||
|
@triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], )
|
||||||
|
@triton.heuristics({
|
||||||
|
"NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
|
||||||
|
"NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
|
||||||
|
})
|
||||||
|
@triton.jit
|
||||||
|
def _scatter2scatter(
|
||||||
|
X_ptr, stride_xm, stride_xk,
|
||||||
|
W_ptr, stride_we, stride_wk, stride_wn,
|
||||||
|
Y_ptr, stride_ym, stride_yn,
|
||||||
|
grouped_idx_ptr, expert_idxs_ptr, block_start_idx_ptr,
|
||||||
|
FAN_OUT: tl.constexpr,
|
||||||
|
M: tl.constexpr, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||||
|
ACC_TYPE: tl.constexpr,
|
||||||
|
OUT_M: tl.constexpr,
|
||||||
|
allow_tf32: tl.constexpr,
|
||||||
|
x_grouped: tl.constexpr, y_grouped: tl.constexpr,
|
||||||
|
NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
|
||||||
|
N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
|
||||||
|
M_block_id = pid // N_BLOCK_COUNT
|
||||||
|
N_block_id = pid % N_BLOCK_COUNT
|
||||||
|
M_range = tl.arange(0, BLOCK_M)
|
||||||
|
block_start_idx = tl.load(block_start_idx_ptr + M_block_id)
|
||||||
|
# M_block = tl.max_contiguous((block_start_idx + M_range) % OUT_M, BLOCK_M)
|
||||||
|
M_block = tl.max_contiguous(block_start_idx + M_range, BLOCK_M)
|
||||||
|
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_block < (FAN_OUT * M), other=E)
|
||||||
|
E_idx = tl.min(E_idxs)
|
||||||
|
E_mask = E_idxs == E_idx
|
||||||
|
M_idx = tl.load(grouped_idx_ptr + M_block, mask=E_mask, other=0)
|
||||||
|
if x_grouped:
|
||||||
|
M_in_idx = M_block
|
||||||
|
else:
|
||||||
|
M_in_idx = M_idx // FAN_OUT
|
||||||
|
|
||||||
|
if y_grouped:
|
||||||
|
M_out_idx = M_block
|
||||||
|
else:
|
||||||
|
M_out_idx = M_idx
|
||||||
|
|
||||||
|
K_block = tl.arange(0, BLOCK_K)
|
||||||
|
|
||||||
|
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
N_mask = N_block < N
|
||||||
|
# N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
|
||||||
|
# N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
|
||||||
|
X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
|
||||||
|
W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
|
||||||
|
|
||||||
|
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||||
|
iters = tl.cdiv(K, BLOCK_K)
|
||||||
|
for K_block_id in range(0, iters):
|
||||||
|
if NO_K_MASK:
|
||||||
|
x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
|
||||||
|
if NO_N_MASK:
|
||||||
|
w = tl.load(W_blk_ptrs)
|
||||||
|
else:
|
||||||
|
w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
|
||||||
|
else:
|
||||||
|
K_mask = (K_block_id * BLOCK_K + K_block) < K
|
||||||
|
x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
|
||||||
|
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
|
||||||
|
X_blk_ptrs += BLOCK_K * stride_xk
|
||||||
|
W_blk_ptrs += BLOCK_K * stride_wk
|
||||||
|
acc += tl.dot(x, w, allow_tf32=allow_tf32, out_dtype=ACC_TYPE)
|
||||||
|
|
||||||
|
Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
|
||||||
|
tl.store(Y_blk_ptrs, acc, mask=E_mask[:, None] & N_mask[None, :])
|
||||||
|
|
||||||
|
def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
|
||||||
|
padded_block_idxs, x_grouped=False, y_grouped=False,
|
||||||
|
out=None):
|
||||||
|
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||||
|
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||||
|
# Pre-kernel setup
|
||||||
|
x_dim = X.size(-1)
|
||||||
|
y_dim = W.size(-1)
|
||||||
|
L_scattered = sorted_expert_idxs.size(0)
|
||||||
|
if out is None:
|
||||||
|
O = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
|
||||||
|
else:
|
||||||
|
assert out.size(0) == L_scattered and out.size(1) == y_dim
|
||||||
|
O = out
|
||||||
|
|
||||||
|
def grid(META):
|
||||||
|
grid_num = (
|
||||||
|
padded_block_idxs.size(0) *
|
||||||
|
triton.cdiv(META['N'], META['BLOCK_N']),
|
||||||
|
)
|
||||||
|
return grid_num
|
||||||
|
"""
|
||||||
|
print("X", X.size(), X.stride(),
|
||||||
|
"W", W.size(), W.stride(),
|
||||||
|
"O", O.size(), O.stride(),
|
||||||
|
"sorted_idxs", sorted_scattered_idxs.size(),
|
||||||
|
"FAN_OUT", k,
|
||||||
|
"BLOCK_M", BLOCK_M,
|
||||||
|
"grouped", (x_grouped, y_grouped))
|
||||||
|
"""
|
||||||
|
_scatter2scatter[grid](
|
||||||
|
# X_ptr, stride_xm, stride_xk,
|
||||||
|
X, X.stride(0), X.stride(1),
|
||||||
|
# W_ptr, stride_we, stride_wk, stride_wn,
|
||||||
|
W, W.stride(0), W.stride(1), W.stride(2),
|
||||||
|
# Y_ptr, stride_ym, stride_yn,
|
||||||
|
O, O.stride(0), O.stride(1),
|
||||||
|
grouped_idx_ptr=sorted_scattered_idxs,
|
||||||
|
expert_idxs_ptr=sorted_expert_idxs,
|
||||||
|
block_start_idx_ptr=padded_block_idxs,
|
||||||
|
FAN_OUT=k,
|
||||||
|
M=X.size(0),
|
||||||
|
K=X.size(1),
|
||||||
|
N=O.size(1), E=W.size(0),
|
||||||
|
BLOCK_M=BLOCK_M,
|
||||||
|
ACC_TYPE=tl.float32,
|
||||||
|
OUT_M=O.size(0),
|
||||||
|
allow_tf32=True,
|
||||||
|
x_grouped=x_grouped, y_grouped=y_grouped,
|
||||||
|
)
|
||||||
|
return O
|
||||||
|
|
||||||
|
|
||||||
|
def _config_XtY():
|
||||||
|
return [
|
||||||
|
triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4),
|
||||||
|
]
|
||||||
|
|
||||||
|
def group_bwd_W(DY, X, expert_offsets, E):
|
||||||
|
DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
|
||||||
|
DW = DWt.permute(0, 2, 1)
|
||||||
|
def grid(META):
|
||||||
|
grid = (
|
||||||
|
E * triton.cdiv(META['K'], META['BLOCK_K']),
|
||||||
|
triton.cdiv(META['N'], META['BLOCK_N']),
|
||||||
|
)
|
||||||
|
return grid
|
||||||
|
_groupXtY[grid](
|
||||||
|
# DY_ptr, stride_dym, stride_dyk,
|
||||||
|
DY, DY.stride(0), DY.stride(1),
|
||||||
|
# X_ptr, stride_xm, stride_xn,
|
||||||
|
X, X.stride(0), X.stride(1),
|
||||||
|
# DW_ptr, stride_dwe, stride_dwk, stride_dwn,
|
||||||
|
DW, DW.stride(0), DW.stride(1), DW.stride(2),
|
||||||
|
# expert_offsets_ptr,
|
||||||
|
expert_offsets,
|
||||||
|
# K: tl.constexpr, N: tl.constexpr,
|
||||||
|
M=DY.size(0), N=DY.size(-1), K=X.size(-1),
|
||||||
|
# ACC_TYPE: tl.constexpr,
|
||||||
|
ACC_TYPE=tl.float32,
|
||||||
|
allow_tf32=True
|
||||||
|
)
|
||||||
|
return DW
|
||||||
|
|
||||||
|
@triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], )
|
||||||
|
@triton.heuristics({
|
||||||
|
"NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
|
||||||
|
"NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
|
||||||
|
})
|
||||||
|
@triton.jit
|
||||||
|
def _groupXtY(
|
||||||
|
DY_ptr, stride_dym, stride_dyk,
|
||||||
|
X_ptr, stride_xm, stride_xn,
|
||||||
|
DW_ptr, stride_dwe, stride_dwk, stride_dwn,
|
||||||
|
expert_offsets_ptr,
|
||||||
|
M: tl.constexpr, K: tl.constexpr, N: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||||
|
ACC_TYPE: tl.constexpr,
|
||||||
|
allow_tf32: tl.constexpr,
|
||||||
|
NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
|
||||||
|
):
|
||||||
|
pid0 = tl.program_id(axis=0)
|
||||||
|
pid1 = tl.program_id(axis=1)
|
||||||
|
num0 = tl.num_programs(0)
|
||||||
|
num1 = tl.num_programs(1)
|
||||||
|
pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
|
||||||
|
|
||||||
|
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
|
||||||
|
E_idx = pid0 // K_BLOCK_COUNT
|
||||||
|
K_block_id = pid0 % K_BLOCK_COUNT
|
||||||
|
N_block_id = pid1
|
||||||
|
|
||||||
|
if E_idx == 0:
|
||||||
|
start_idx = 0
|
||||||
|
else:
|
||||||
|
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
|
||||||
|
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
|
||||||
|
|
||||||
|
if end_idx > start_idx:
|
||||||
|
M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
|
||||||
|
|
||||||
|
K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||||
|
K_mask = K_block < K
|
||||||
|
K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
|
||||||
|
|
||||||
|
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
N_mask = N_block < N
|
||||||
|
N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
|
||||||
|
|
||||||
|
M_idxs = M_block
|
||||||
|
xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
|
||||||
|
dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
|
||||||
|
|
||||||
|
acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
|
||||||
|
iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
|
||||||
|
for i in range(0, iters):
|
||||||
|
M_mask = (i * BLOCK_M + M_block) < end_idx
|
||||||
|
if NO_K_MASK:
|
||||||
|
xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
|
||||||
|
else:
|
||||||
|
xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
|
||||||
|
if NO_N_MASK:
|
||||||
|
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
|
||||||
|
else:
|
||||||
|
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
|
||||||
|
acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
|
||||||
|
xt_blk_ptrs += BLOCK_M * stride_xm
|
||||||
|
dy_blk_ptrs += BLOCK_M * stride_dym
|
||||||
|
|
||||||
|
|
||||||
|
DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
|
||||||
|
acc = acc.to(DW_blk_ptrs.dtype.element_ty)
|
||||||
|
tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
|
||||||
|
|
||||||
|
|
||||||
|
def _config_grouping():
|
||||||
|
return [
|
||||||
|
triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
|
||||||
|
triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
|
||||||
|
]
|
||||||
|
|
||||||
|
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
|
||||||
|
N = sorted_expert_idxs.size(0)
|
||||||
|
K = A.size(1)
|
||||||
|
assert A.size(0) * fan_out == N
|
||||||
|
if out is not None:
|
||||||
|
Y = out
|
||||||
|
else:
|
||||||
|
Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
|
||||||
|
# print("grp init:", Y.size())
|
||||||
|
def grid(META):
|
||||||
|
grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),)
|
||||||
|
return grid_num
|
||||||
|
_group[grid](
|
||||||
|
# A_ptr, stride_an, stride_ai,
|
||||||
|
A, A.stride(0), A.stride(1), coeff is not None, coeff, fan_out,
|
||||||
|
# Y_ptr, stride_yn, stride_yk,
|
||||||
|
Y, Y.stride(0), Y.stride(1),
|
||||||
|
# grouped_idx_ptr,
|
||||||
|
sorted_expert_idxs,
|
||||||
|
# N: tl.constexpr, K: tl.constexpr,
|
||||||
|
N, K
|
||||||
|
)
|
||||||
|
return Y
|
||||||
|
|
||||||
|
@triton.autotune(configs=_config_grouping(), key=['K'])
|
||||||
|
@triton.heuristics({
|
||||||
|
"NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
|
||||||
|
})
|
||||||
|
@triton.jit
|
||||||
|
def _group(
|
||||||
|
src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
|
||||||
|
tgt_ptr, stride_tn, stride_ti,
|
||||||
|
grouped_idx_ptr,
|
||||||
|
N: tl.constexpr, K: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||||
|
NO_K_MASK: tl.constexpr
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
|
||||||
|
N_block_id = pid
|
||||||
|
N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
N_mask = N_blk < N
|
||||||
|
N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
|
||||||
|
N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
|
||||||
|
|
||||||
|
K_blk = tl.arange(0, BLOCK_K)
|
||||||
|
src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
|
||||||
|
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
|
||||||
|
|
||||||
|
if has_coeff:
|
||||||
|
c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
|
||||||
|
|
||||||
|
iters = tl.cdiv(K, BLOCK_K)
|
||||||
|
for i in range(0, iters):
|
||||||
|
if NO_K_MASK:
|
||||||
|
block = tl.load(src_blk_ptrs) # , mask=N_mask[:, None])
|
||||||
|
if has_coeff:
|
||||||
|
block *= c
|
||||||
|
tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
|
||||||
|
|
||||||
|
else:
|
||||||
|
K_mask = (i * BLOCK_K + K_blk) < K
|
||||||
|
mask = N_mask[:, None] & K_mask[None, :]
|
||||||
|
block = tl.load(src_blk_ptrs, mask=mask)
|
||||||
|
if has_coeff:
|
||||||
|
block *= c
|
||||||
|
tl.store(tgt_blk_ptrs, block, mask=mask)
|
||||||
|
|
||||||
|
src_blk_ptrs += BLOCK_K * stride_sk
|
||||||
|
tgt_blk_ptrs += BLOCK_K * stride_ti
|
||||||
66
src/axolotl/monkeypatch/moe/single.py
Normal file
66
src/axolotl/monkeypatch/moe/single.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""
|
||||||
|
Adapted from:
|
||||||
|
https://github.com/shawntan/scattermoe
|
||||||
|
https://arxiv.org/abs/2403.08245
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _single2scatter(
|
||||||
|
X_ptr, stride_xm, stride_xk,
|
||||||
|
W_ptr, stride_we, stride_wk, stride_wn,
|
||||||
|
Y_ptr, stride_ym, stride_yn,
|
||||||
|
expert_idxs_ptr,
|
||||||
|
FAN_OUT: tl.constexpr,
|
||||||
|
K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||||
|
ACC_TYPE: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid0 = tl.program_id(axis=0)
|
||||||
|
pid1 = tl.program_id(axis=1)
|
||||||
|
|
||||||
|
N_block_id = pid0
|
||||||
|
if FAN_OUT == 1:
|
||||||
|
in_idx = pid1
|
||||||
|
else:
|
||||||
|
in_idx = 0
|
||||||
|
out_idx = pid1
|
||||||
|
|
||||||
|
K_block = tl.arange(0, BLOCK_K)
|
||||||
|
N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
|
||||||
|
E_idx = tl.load(expert_idxs_ptr + pid1)
|
||||||
|
X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
|
||||||
|
W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
|
||||||
|
acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
|
||||||
|
for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
|
||||||
|
x = tl.load(X_blk_ptrs)
|
||||||
|
w = tl.load(W_blk_ptrs)
|
||||||
|
acc += tl.sum(x * w, axis=0)[None, :]
|
||||||
|
X_blk_ptrs += BLOCK_K * stride_xk
|
||||||
|
W_blk_ptrs += BLOCK_K * stride_wk
|
||||||
|
Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
|
||||||
|
tl.store(Y_blk_ptrs, acc)
|
||||||
|
|
||||||
|
def single2scatter(X, W, expert_idxs):
|
||||||
|
E, xdim, ydim = W.size()
|
||||||
|
k = expert_idxs.size(1)
|
||||||
|
assert X.size(0) == k or X.size(0) == 1
|
||||||
|
Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
|
||||||
|
BLOCK_N = 128
|
||||||
|
BLOCK_K = 128
|
||||||
|
grid = ydim // BLOCK_N, k
|
||||||
|
_single2scatter[grid](
|
||||||
|
X, X.stride(0), X.stride(1),
|
||||||
|
W, W.stride(0), W.stride(1), W.stride(2),
|
||||||
|
Y, Y.stride(0), Y.stride(1),
|
||||||
|
expert_idxs,
|
||||||
|
FAN_OUT=Y.size(0) // X.size(0),
|
||||||
|
K=xdim, N=ydim, E=E,
|
||||||
|
BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
|
||||||
|
ACC_TYPE=tl.float32
|
||||||
|
)
|
||||||
|
return Y
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
"""
|
|
||||||
module for base dataset transform strategies
|
|
||||||
"""
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import logging
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
|
||||||
|
|
||||||
|
|
||||||
def load(strategy, cfg, module_base=None, **kwargs):
|
|
||||||
try:
|
|
||||||
load_fn = strategy.split(".")[-1]
|
|
||||||
strategy = ".".join(strategy.split(".")[:-1])
|
|
||||||
mod = importlib.import_module(f".{strategy}", module_base)
|
|
||||||
func = getattr(mod, load_fn)
|
|
||||||
return func(cfg, **kwargs)
|
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
|
||||||
LOG.warning(f"unable to load strategy {strategy}")
|
|
||||||
return None
|
|
||||||
@@ -1,8 +1,20 @@
|
|||||||
"""
|
"""
|
||||||
module for DPO style dataset transform strategies
|
module for DPO style dataset transform strategies
|
||||||
"""
|
"""
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from ..base import load as load_base
|
import importlib
|
||||||
|
import logging
|
||||||
|
|
||||||
load = partial(load_base, module="axolotl.prompt_strategies.dpo")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
def load(strategy, cfg, **kwargs):
|
||||||
|
try:
|
||||||
|
load_fn = strategy.split(".")[-1]
|
||||||
|
strategy = ".".join(strategy.split(".")[:-1])
|
||||||
|
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
|
||||||
|
func = getattr(mod, load_fn)
|
||||||
|
return func(cfg, **kwargs)
|
||||||
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
LOG.warning(f"unable to load strategy {strategy}")
|
||||||
|
return None
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
"""
|
|
||||||
module for ORPO style dataset transform strategies
|
|
||||||
"""
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from ..base import load as load_base
|
|
||||||
|
|
||||||
load = partial(load_base, module="axolotl.prompt_strategies.orpo")
|
|
||||||
@@ -1,187 +0,0 @@
|
|||||||
"""chatml prompt tokenization strategy for ORPO"""
|
|
||||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
|
|
||||||
from axolotl.prompters import Prompter
|
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
|
||||||
"""message/turn"""
|
|
||||||
|
|
||||||
role: str
|
|
||||||
content: str
|
|
||||||
label: Optional[bool] = None
|
|
||||||
|
|
||||||
|
|
||||||
class MessageList(BaseModel):
|
|
||||||
"""conversation"""
|
|
||||||
|
|
||||||
messages: List[Message]
|
|
||||||
|
|
||||||
|
|
||||||
def load(
|
|
||||||
tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, **kwargs
|
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
|
||||||
"""
|
|
||||||
chatml transforms for datasets with system, input, chosen, rejected
|
|
||||||
"""
|
|
||||||
|
|
||||||
chat_template = chat_templates("chatml")
|
|
||||||
if ds_cfg and "chat_template" in ds_cfg:
|
|
||||||
chat_template = ds_cfg["chat_template"]
|
|
||||||
try:
|
|
||||||
chat_template = chat_templates(chat_template)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return ORPOTokenizingStrategy(
|
|
||||||
ORPOPrompter(chat_template, tokenizer),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
dataset_parser=ORPODatasetParsingStrategy(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ORPODatasetParsingStrategy:
|
|
||||||
"""Strategy to parse chosen rejected dataset into messagelist"""
|
|
||||||
|
|
||||||
def get_chosen_conversation_thread(self, prompt) -> MessageList:
|
|
||||||
"""Dataset structure mappings"""
|
|
||||||
|
|
||||||
messages: List[Message] = []
|
|
||||||
if system := prompt.get("system", None):
|
|
||||||
messages.append(Message(role="system", content=system, label=False))
|
|
||||||
messages.append(Message(role="user", content=prompt["prompt"], label=False))
|
|
||||||
messages.append(
|
|
||||||
Message(
|
|
||||||
role="assistant", content=prompt["chosen"][1]["content"], label=True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return MessageList(messages=messages)
|
|
||||||
|
|
||||||
def get_rejected_conversation_thread(self, prompt) -> MessageList:
|
|
||||||
"""Dataset structure mappings"""
|
|
||||||
|
|
||||||
messages: List[Message] = []
|
|
||||||
if system := prompt.get("system", None):
|
|
||||||
messages.append(Message(role="system", content=system, label=False))
|
|
||||||
messages.append(Message(role="user", content=prompt["prompt"], label=False))
|
|
||||||
messages.append(
|
|
||||||
Message(
|
|
||||||
role="assistant", content=prompt["rejected"][1]["content"], label=True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return MessageList(messages=messages)
|
|
||||||
|
|
||||||
|
|
||||||
class ORPOTokenizingStrategy(PromptTokenizingStrategy):
|
|
||||||
"""
|
|
||||||
rejected_input_ids
|
|
||||||
input_ids
|
|
||||||
rejected_attention_mask
|
|
||||||
attention_mask
|
|
||||||
rejected_labels
|
|
||||||
labels
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*args,
|
|
||||||
dataset_parser=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.dataset_parser = dataset_parser
|
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
|
||||||
# pass the rejected prompt/row to the Prompter to get the formatted prompt
|
|
||||||
prompt_len = 0
|
|
||||||
rejected_message_list = self.dataset_parser.get_rejected_conversation_thread(
|
|
||||||
prompt
|
|
||||||
)
|
|
||||||
input_ids = []
|
|
||||||
labels = []
|
|
||||||
for _, (part, label) in enumerate(
|
|
||||||
self.prompter.build_prompt(rejected_message_list)
|
|
||||||
):
|
|
||||||
if not part:
|
|
||||||
continue
|
|
||||||
_input_ids = self.tokenizer.encode(part, add_special_tokens=False)
|
|
||||||
prev_idx = len(input_ids)
|
|
||||||
input_ids += _input_ids[prev_idx:]
|
|
||||||
if label:
|
|
||||||
labels += input_ids[prev_idx:]
|
|
||||||
else:
|
|
||||||
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
|
|
||||||
prompt_len = len(input_ids)
|
|
||||||
# remap the input_ids, attention_mask and labels
|
|
||||||
rejected_input_ids = input_ids
|
|
||||||
rejected_labels = labels
|
|
||||||
# pass the chosen prompt/row to the Prompter to get the formatted prompt
|
|
||||||
chosen_message_list = self.dataset_parser.get_chosen_conversation_thread(prompt)
|
|
||||||
input_ids = []
|
|
||||||
labels = []
|
|
||||||
for _, (part, label) in enumerate(
|
|
||||||
self.prompter.build_prompt(chosen_message_list)
|
|
||||||
):
|
|
||||||
if not part:
|
|
||||||
continue
|
|
||||||
_input_ids = self.tokenizer.encode(part, add_special_tokens=False)
|
|
||||||
prev_idx = len(input_ids)
|
|
||||||
input_ids += _input_ids[prev_idx:]
|
|
||||||
if label:
|
|
||||||
labels += input_ids[prev_idx:]
|
|
||||||
else:
|
|
||||||
labels += [IGNORE_INDEX] * (len(input_ids) - prev_idx)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"rejected_input_ids": rejected_input_ids,
|
|
||||||
"rejected_labels": rejected_labels,
|
|
||||||
"rejected_attention_mask": [1] * len(rejected_labels),
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"labels": labels,
|
|
||||||
"attention_mask": [1] * len(labels),
|
|
||||||
"prompt_attention_mask": [1] * prompt_len
|
|
||||||
+ [0] * (len(labels) - prompt_len),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ORPOPrompter(Prompter):
|
|
||||||
"""Single Turn prompter for ORPO"""
|
|
||||||
|
|
||||||
def __init__(self, chat_template, tokenizer):
|
|
||||||
self.chat_template = chat_template
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
|
|
||||||
def build_prompt(
|
|
||||||
self,
|
|
||||||
message_list: MessageList,
|
|
||||||
) -> Generator[Tuple[str, bool], None, None]:
|
|
||||||
conversation = []
|
|
||||||
for message in message_list.messages:
|
|
||||||
conversation.append(message.model_dump())
|
|
||||||
if message.role == "system":
|
|
||||||
yield self.tokenizer.apply_chat_template(
|
|
||||||
conversation,
|
|
||||||
add_generation_prompt=False,
|
|
||||||
chat_template=self.chat_template,
|
|
||||||
tokenize=False,
|
|
||||||
), False
|
|
||||||
if message.role == "user":
|
|
||||||
yield self.tokenizer.apply_chat_template(
|
|
||||||
conversation,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
chat_template=self.chat_template,
|
|
||||||
tokenize=False,
|
|
||||||
), False
|
|
||||||
if message.role == "assistant":
|
|
||||||
yield self.tokenizer.apply_chat_template(
|
|
||||||
conversation,
|
|
||||||
add_generation_prompt=False,
|
|
||||||
chat_template=self.chat_template,
|
|
||||||
tokenize=False,
|
|
||||||
), True
|
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||||
@@ -12,8 +11,6 @@ from axolotl.utils.tokenization import (
|
|||||||
merge_consecutive_messages,
|
merge_consecutive_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
|
||||||
|
|
||||||
|
|
||||||
def register_chatml_template(system_message=None):
|
def register_chatml_template(system_message=None):
|
||||||
system_message = system_message or "You are a helpful assistant."
|
system_message = system_message or "You are a helpful assistant."
|
||||||
@@ -45,13 +42,11 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
)
|
)
|
||||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||||
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
|
||||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompterV2(
|
ShareGPTPrompterV2(
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
role_key_model=field_model,
|
role_key_model=field_model,
|
||||||
role_key_human=field_human,
|
role_key_human=field_human,
|
||||||
roles=roles,
|
|
||||||
),
|
),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
@@ -147,12 +142,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|||||||
"system": "system",
|
"system": "system",
|
||||||
}
|
}
|
||||||
turns = [
|
turns = [
|
||||||
{
|
{"from": role_map[t[role_key]], "value": t[value_key]}
|
||||||
"from": (
|
|
||||||
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
|
|
||||||
),
|
|
||||||
"value": t[value_key],
|
|
||||||
}
|
|
||||||
for t in conversations
|
for t in conversations
|
||||||
]
|
]
|
||||||
return turns
|
return turns
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from transformers import BatchEncoding, PreTrainedTokenizer
|
|||||||
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
||||||
add_get_turns_to_conversation,
|
add_get_turns_to_conversation,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
prompter: Prompter,
|
prompter,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
train_on_inputs: bool = False,
|
train_on_inputs: bool = False,
|
||||||
sequence_len: int = 2048,
|
sequence_len: int = 2048,
|
||||||
@@ -340,23 +340,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||||
)
|
)
|
||||||
|
|
||||||
input_roles = {conversation.roles[0]}
|
|
||||||
output_roles = {conversation.roles[1]}
|
|
||||||
|
|
||||||
if len(conversation.roles) == 3:
|
|
||||||
tool_role_label = conversation.roles[2]
|
|
||||||
input_roles.add(tool_role_label)
|
|
||||||
|
|
||||||
# Add roles from the config
|
|
||||||
if self.prompter.roles:
|
|
||||||
if "input" in self.prompter.roles and self.prompter.roles["input"]:
|
|
||||||
for role in self.prompter.roles["input"]:
|
|
||||||
input_roles.add(role)
|
|
||||||
|
|
||||||
if "output" in self.prompter.roles and self.prompter.roles["output"]:
|
|
||||||
for role in self.prompter.roles["output"]:
|
|
||||||
output_roles.add(role)
|
|
||||||
|
|
||||||
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
||||||
role_remap = []
|
role_remap = []
|
||||||
if (
|
if (
|
||||||
@@ -377,18 +360,19 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
LOG.warning(f"expected tuple, got {part}")
|
LOG.warning(f"expected tuple, got {part}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
tool_role_label = None
|
||||||
|
if len(conversation.roles) == 3:
|
||||||
|
(
|
||||||
|
user_role_label,
|
||||||
|
assistant_role_label,
|
||||||
|
tool_role_label,
|
||||||
|
) = conversation.roles
|
||||||
|
else:
|
||||||
|
user_role_label, assistant_role_label = conversation.roles
|
||||||
role, content = part
|
role, content = part
|
||||||
|
|
||||||
# Uses "in" because role contains extra characters
|
# Uses "in" because role contains extra characters
|
||||||
input_turn = any(r.lower() in role.lower() for r in input_roles)
|
if user_role_label in role:
|
||||||
output_turn = any(r.lower() in role.lower() for r in output_roles)
|
|
||||||
empty_role = role.strip() == ""
|
|
||||||
|
|
||||||
if not any([input_turn, output_turn, empty_role]):
|
|
||||||
LOG.warning(f"unhandled role: {role}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if input_turn:
|
|
||||||
role = (
|
role = (
|
||||||
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||||
if role_remap
|
if role_remap
|
||||||
@@ -408,7 +392,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
else:
|
else:
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
elif output_turn:
|
elif assistant_role_label in role:
|
||||||
role = (
|
role = (
|
||||||
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||||
if role_remap
|
if role_remap
|
||||||
@@ -439,7 +423,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
||||||
len_role, len(labels)
|
len_role, len(labels)
|
||||||
)
|
)
|
||||||
elif empty_role:
|
elif role == "":
|
||||||
turn = content
|
turn = content
|
||||||
# this is only ever the first part, should include the bos token and the user query
|
# this is only ever the first part, should include the bos token and the user query
|
||||||
res = self._tokenize(
|
res = self._tokenize(
|
||||||
@@ -450,6 +434,11 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
else:
|
else:
|
||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
|
elif tool_role_label and tool_role_label in role:
|
||||||
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
|
else:
|
||||||
|
LOG.warning(f"unhandled role: {role}")
|
||||||
|
continue
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
result, current_len = parse_tokenized_to_result(
|
result, current_len = parse_tokenized_to_result(
|
||||||
|
|||||||
@@ -259,12 +259,6 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
|
|||||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||||
)
|
)
|
||||||
|
|
||||||
CONVERSATION_ROLE_FORMAT = {
|
|
||||||
"chatml": "<|im_start|>{ROLE}",
|
|
||||||
"zephyr": "<|{ROLE}|>",
|
|
||||||
"vicuna_v1.1": "{ROLE}",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||||
"""
|
"""
|
||||||
@@ -274,9 +268,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
role_key_human = "human"
|
role_key_human = "human"
|
||||||
role_key_model = "gpt"
|
role_key_model = "gpt"
|
||||||
# Optional, only used for tool usage datasets.
|
# Optional, only used for tool usage datasets.
|
||||||
role_key_tool: Optional[str] = None
|
role_key_tool = None
|
||||||
# Optional, role input/output mapping
|
|
||||||
roles: Optional[dict] = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -285,7 +277,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
role_key_human: Optional[str] = None,
|
role_key_human: Optional[str] = None,
|
||||||
role_key_model: Optional[str] = None,
|
role_key_model: Optional[str] = None,
|
||||||
role_key_tool: Optional[str] = None,
|
role_key_tool: Optional[str] = None,
|
||||||
roles: Optional[dict] = None,
|
|
||||||
):
|
):
|
||||||
if conversation:
|
if conversation:
|
||||||
if isinstance(conversation, Conversation):
|
if isinstance(conversation, Conversation):
|
||||||
@@ -300,8 +291,6 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
self.role_key_model = role_key_model
|
self.role_key_model = role_key_model
|
||||||
if role_key_tool:
|
if role_key_tool:
|
||||||
self.role_key_tool = role_key_tool
|
self.role_key_tool = role_key_tool
|
||||||
if roles:
|
|
||||||
self.roles = roles
|
|
||||||
|
|
||||||
def _build_result(self, source):
|
def _build_result(self, source):
|
||||||
if len(source) < 2:
|
if len(source) < 2:
|
||||||
@@ -333,23 +322,11 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
|||||||
|
|
||||||
conv.messages = []
|
conv.messages = []
|
||||||
for _, sentence in enumerate(source):
|
for _, sentence in enumerate(source):
|
||||||
from_role = sentence["from"]
|
role = roles[sentence["from"]]
|
||||||
if from_role in roles:
|
if len(conv.messages) > 0 and (
|
||||||
role = roles[from_role]
|
(role == conv.messages[-1][0]) or (role not in conv.roles)
|
||||||
else:
|
):
|
||||||
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
|
|
||||||
"Please help us by creating an Issue to add support for this conversation type."
|
|
||||||
)
|
|
||||||
|
|
||||||
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
|
|
||||||
ROLE=from_role
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
|
||||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||||
|
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
return conv.get_turns()
|
return conv.get_turns()
|
||||||
@@ -377,13 +354,11 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
|||||||
conversation: Optional[Union[str, Conversation]] = None,
|
conversation: Optional[Union[str, Conversation]] = None,
|
||||||
role_key_human: Optional[str] = None,
|
role_key_human: Optional[str] = None,
|
||||||
role_key_model: Optional[str] = None,
|
role_key_model: Optional[str] = None,
|
||||||
roles: Optional[dict] = None,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
role_key_human=role_key_human,
|
role_key_human=role_key_human,
|
||||||
role_key_model=role_key_model,
|
role_key_model=role_key_model,
|
||||||
roles=roles,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ def train(
|
|||||||
model.generation_config.do_sample = True
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
model_ref = None
|
model_ref = None
|
||||||
if cfg.rl and cfg.rl != "orpo":
|
if cfg.rl:
|
||||||
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
if cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||||
# use built-in trl autounwrap
|
# use built-in trl autounwrap
|
||||||
LOG.debug("Passing model_ref: None to RL trainer")
|
LOG.debug("Passing model_ref: None to RL trainer")
|
||||||
@@ -110,6 +110,9 @@ def train(
|
|||||||
total_num_steps,
|
total_num_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if hasattr(model, "config"):
|
||||||
|
model.config.use_cache = False
|
||||||
|
|
||||||
# go ahead and presave, so we have the adapter config available to inspect
|
# go ahead and presave, so we have the adapter config available to inspect
|
||||||
if peft_config:
|
if peft_config:
|
||||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ def chat_templates(user_choice: str):
|
|||||||
templates = {
|
templates = {
|
||||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -191,11 +191,6 @@ def normalize_cfg_datasets(cfg):
|
|||||||
f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template"
|
f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template"
|
||||||
)
|
)
|
||||||
cfg.datasets[idx].conversation = "chatml"
|
cfg.datasets[idx].conversation = "chatml"
|
||||||
if ds_cfg.type == "orpo.chat_template" and not ds_cfg.chat_template:
|
|
||||||
LOG.info(
|
|
||||||
f"updating dataset {ds_cfg.path} with `chat_template: chatml` to match your chat_template"
|
|
||||||
)
|
|
||||||
cfg.datasets[idx].chat_template = "chatml"
|
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||||
|
|||||||
@@ -96,8 +96,6 @@ class SFTDataset(BaseModel):
|
|||||||
field_human: Optional[str] = None
|
field_human: Optional[str] = None
|
||||||
field_model: Optional[str] = None
|
field_model: Optional[str] = None
|
||||||
|
|
||||||
roles: Optional[Dict[str, List[str]]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedDPOType(BaseModel):
|
class UserDefinedDPOType(BaseModel):
|
||||||
"""User defined typing for DPO"""
|
"""User defined typing for DPO"""
|
||||||
@@ -126,7 +124,6 @@ class RLType(str, Enum):
|
|||||||
dpo = "dpo" # pylint: disable=invalid-name
|
dpo = "dpo" # pylint: disable=invalid-name
|
||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplate(str, Enum):
|
class ChatTemplate(str, Enum):
|
||||||
@@ -313,15 +310,6 @@ class HyperparametersConfig(BaseModel):
|
|||||||
learning_rate: Union[str, float]
|
learning_rate: Union[str, float]
|
||||||
weight_decay: Optional[float] = None
|
weight_decay: Optional[float] = None
|
||||||
optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
|
optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
|
||||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
|
||||||
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
|
|
||||||
)
|
|
||||||
optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "The target modules to optimize, i.e. the module names that you would like to train."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
torchdistx_path: Optional[str] = None
|
torchdistx_path: Optional[str] = None
|
||||||
lr_scheduler: Optional[SchedulerType] = None
|
lr_scheduler: Optional[SchedulerType] = None
|
||||||
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
||||||
@@ -427,7 +415,6 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
||||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
||||||
shuffle_merged_datasets: Optional[bool] = True
|
|
||||||
dataset_prepared_path: Optional[str] = None
|
dataset_prepared_path: Optional[str] = None
|
||||||
dataset_shard_num: Optional[int] = None
|
dataset_shard_num: Optional[int] = None
|
||||||
dataset_shard_idx: Optional[int] = None
|
dataset_shard_idx: Optional[int] = None
|
||||||
@@ -444,8 +431,6 @@ class AxolotlInputConfig(
|
|||||||
dataloader_prefetch_factor: Optional[int] = None
|
dataloader_prefetch_factor: Optional[int] = None
|
||||||
dataloader_drop_last: Optional[bool] = None
|
dataloader_drop_last: Optional[bool] = None
|
||||||
|
|
||||||
remove_unused_columns: Optional[bool] = None
|
|
||||||
|
|
||||||
push_dataset_to_hub: Optional[str] = None
|
push_dataset_to_hub: Optional[str] = None
|
||||||
hf_use_auth_token: Optional[bool] = None
|
hf_use_auth_token: Optional[bool] = None
|
||||||
|
|
||||||
@@ -530,8 +515,6 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
neftune_noise_alpha: Optional[float] = None
|
neftune_noise_alpha: Optional[float] = None
|
||||||
|
|
||||||
orpo_alpha: Optional[float] = None
|
|
||||||
|
|
||||||
max_memory: Optional[
|
max_memory: Optional[
|
||||||
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
||||||
] = None
|
] = None
|
||||||
|
|||||||
@@ -415,11 +415,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
dataset = concatenate_datasets(datasets)
|
dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
if len(datasets) > 1:
|
if len(datasets) > 1:
|
||||||
if cfg.shuffle_merged_datasets:
|
LOG.info("shuffle merged datasets")
|
||||||
LOG.debug("shuffle merged datasets")
|
dataset = dataset.shuffle(seed=seed)
|
||||||
dataset = dataset.shuffle(seed=seed)
|
|
||||||
else:
|
|
||||||
LOG.debug("NOT shuffling merged datasets")
|
|
||||||
|
|
||||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||||
|
|
||||||
@@ -822,11 +819,7 @@ def wrap_pretraining_dataset(
|
|||||||
else:
|
else:
|
||||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||||
|
|
||||||
if cfg.shuffle_merged_datasets:
|
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
||||||
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
|
||||||
else:
|
|
||||||
LOG.debug("NOT shuffling merged pretraining datasets")
|
|
||||||
|
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
encode,
|
encode,
|
||||||
batched=True,
|
batched=True,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ module to freeze/unfreeze parameters by name
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Callable, List, Tuple, Union
|
from typing import Callable, List, Tuple
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
@@ -99,7 +99,7 @@ def _invert_ranges(
|
|||||||
|
|
||||||
|
|
||||||
def _merge_ranges(
|
def _merge_ranges(
|
||||||
given_ranges: List[Tuple[int, Union[int, None]]], layer_size: int
|
given_ranges: List[Tuple[int, int | None]], layer_size: int
|
||||||
) -> List[Tuple[int, int]]:
|
) -> List[Tuple[int, int]]:
|
||||||
"""
|
"""
|
||||||
Merges overlapping ranges and sorts the given ranges.
|
Merges overlapping ranges and sorts the given ranges.
|
||||||
@@ -194,9 +194,7 @@ class LayerNamePattern:
|
|||||||
"""
|
"""
|
||||||
return self.name_regex.match(name) is not None
|
return self.name_regex.match(name) is not None
|
||||||
|
|
||||||
def _parse_pattern(
|
def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]:
|
||||||
self, pattern: str
|
|
||||||
) -> Tuple[str, Union[Tuple[int, Union[int, None]], None]]:
|
|
||||||
"""
|
"""
|
||||||
Extracts the range pattern from the given pattern.
|
Extracts the range pattern from the given pattern.
|
||||||
|
|
||||||
|
|||||||
@@ -715,32 +715,27 @@ def load_model(
|
|||||||
if cfg.flash_attn_fuse_qkv:
|
if cfg.flash_attn_fuse_qkv:
|
||||||
LOG.info("patching with fused QKV")
|
LOG.info("patching with fused QKV")
|
||||||
replace_llama_qkv_with_fused(model)
|
replace_llama_qkv_with_fused(model)
|
||||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
elif (
|
||||||
# This is a WIP, still an issue with the backward pass
|
model_config.model_type == "mixtral"
|
||||||
# RuntimeError: grad can be implicitly created only for scalar outputs
|
and not cfg.adapter
|
||||||
# TODO: try config.sequence_parallel = False
|
and cfg.fuse_moe
|
||||||
# # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
|
):
|
||||||
# # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
|
from axolotl.monkeypatch.utils import set_module_name
|
||||||
# # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
|
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||||
# from flash_attn.utils.pretrained import state_dict_from_pretrained
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||||
# from flash_attn.models.gpt import GPTLMHeadModel
|
|
||||||
# from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
|
for name, module in model.named_modules():
|
||||||
# from transformers import GPTNeoXConfig
|
if isinstance(module, MixtralSparseMoeBlock):
|
||||||
# config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
|
smoe = SparseMoeBlock(
|
||||||
# config.use_flash_attn = True
|
experts=module.experts,
|
||||||
# config.fused_bias_fc = True
|
gate=module.gate,
|
||||||
# config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast"
|
hidden_dim=module.hidden_dim,
|
||||||
# config.activation_function = "gelu_fast"
|
ffn_dim=module.ffn_dim,
|
||||||
# config.fused_dropout_add_ln = True
|
num_experts=module.num_experts,
|
||||||
# # config.residual_in_fp32 = True
|
top_k=module.top_k,
|
||||||
#
|
)
|
||||||
# model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
|
set_module_name(model, name, smoe)
|
||||||
# base_model,
|
|
||||||
# config,
|
|
||||||
# dtype=torch_dtype,
|
|
||||||
# device=cfg.device,
|
|
||||||
# )
|
|
||||||
# model.train() # sets to train instead of eval mode
|
|
||||||
elif model_type == "MambaLMHeadModel":
|
elif model_type == "MambaLMHeadModel":
|
||||||
# FIXME this is janky at best and hacked together to make it work
|
# FIXME this is janky at best and hacked together to make it work
|
||||||
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
||||||
@@ -888,9 +883,7 @@ def load_model(
|
|||||||
|
|
||||||
if cfg.adapter in ["lora", "qlora"]:
|
if cfg.adapter in ["lora", "qlora"]:
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable(
|
model.gradient_checkpointing_enable()
|
||||||
gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
|
|
||||||
)
|
|
||||||
if (
|
if (
|
||||||
cfg.load_in_8bit or cfg.load_in_4bit
|
cfg.load_in_8bit or cfg.load_in_4bit
|
||||||
) and not skip_prepare_model_for_kbit_training:
|
) and not skip_prepare_model_for_kbit_training:
|
||||||
|
|||||||
60
tests/monkeypatch/test_moe.py
Normal file
60
tests/monkeypatch/test_moe.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from axolotl.monkeypatch.moe.mlp import FusedExperts
|
||||||
|
from axolotl.monkeypatch.moe.moe import SparseMoeBlock
|
||||||
|
|
||||||
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock, MixtralConfig
|
||||||
|
|
||||||
|
def test_fused_mixtral_moe():
|
||||||
|
# NOTE: Requires torch 2.2.0
|
||||||
|
# Set random seeds for reproducibility
|
||||||
|
torch.set_default_dtype(torch.float16)
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
# Define the configuration for the MixtralSparseMoeBlock
|
||||||
|
config = MixtralConfig(
|
||||||
|
hidden_size=128,
|
||||||
|
intermediate_size=512,
|
||||||
|
num_local_experts=8,
|
||||||
|
num_experts_per_tok=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the MixtralSparseMoeBlock and SparseMoeBlock with the same configuration
|
||||||
|
mixtral_moe = MixtralSparseMoeBlock(config)
|
||||||
|
sparse_moe = SparseMoeBlock(
|
||||||
|
experts=mixtral_moe.experts,
|
||||||
|
gate=mixtral_moe.gate,
|
||||||
|
hidden_dim=config.hidden_size,
|
||||||
|
ffn_dim=config.intermediate_size,
|
||||||
|
num_experts=config.num_local_experts,
|
||||||
|
top_k=config.num_experts_per_tok
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.cat([
|
||||||
|
mixtral_moe.experts[0].w1.weight.data,
|
||||||
|
mixtral_moe.experts[0].w3.weight.data], dim=0
|
||||||
|
).equal(sparse_moe.experts.experts.weight[0])
|
||||||
|
|
||||||
|
# Generate random input data
|
||||||
|
batch_size = 16
|
||||||
|
sequence_length = 32
|
||||||
|
input_data = torch.randn(batch_size, sequence_length, config.hidden_size)
|
||||||
|
|
||||||
|
# Run the forward pass with gradients for both models
|
||||||
|
with torch.no_grad():
|
||||||
|
mixtral_output, mixtral_router_logits = mixtral_moe(input_data)
|
||||||
|
sparse_output, sparse_router_logits = sparse_moe(input_data)
|
||||||
|
|
||||||
|
# Compute the difference between the outputs
|
||||||
|
output_diff = torch.abs(mixtral_output - sparse_output).mean().item()
|
||||||
|
router_diff = torch.abs(mixtral_router_logits - sparse_router_logits).mean().item()
|
||||||
|
|
||||||
|
# Define the tolerance for the difference
|
||||||
|
tolerance = 0.05
|
||||||
|
|
||||||
|
# # Check if the difference is within the tolerance
|
||||||
|
assert output_diff < 0.05, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
||||||
|
assert router_diff == 0, f"Output difference is {output_diff}, which is greater than the tolerance of {tolerance}"
|
||||||
@@ -62,38 +62,6 @@ def fixture_sharegpt_glaive_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="multi_role_dataset")
|
|
||||||
def fixture_multi_role_dataset():
|
|
||||||
return Dataset.from_list(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"conversations": [
|
|
||||||
{
|
|
||||||
"from": "system",
|
|
||||||
"value": "use get_weather(city) to get the weather for a city",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "human",
|
|
||||||
"value": "hello, what's the weather in New York?",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "gpt",
|
|
||||||
"value": "let me get that for you",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "tool",
|
|
||||||
"value": "get_weather(New York)",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"from": "gpt",
|
|
||||||
"value": "the weather in New York is 70 degrees and sunny",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="tokenizer")
|
@pytest.fixture(name="tokenizer")
|
||||||
def fixture_tokenizer():
|
def fixture_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
||||||
@@ -228,39 +196,3 @@ class TestSharegpt:
|
|||||||
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
|
32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13 # gpt
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
|
|
||||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
|
||||||
ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
|
|
||||||
tokenizer,
|
|
||||||
False, # train_on_inputs
|
|
||||||
2048, # sequence_len
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
|
||||||
strategy, multi_role_dataset, process_count=1
|
|
||||||
)
|
|
||||||
|
|
||||||
input_ids = dataset_wrapper[0]["input_ids"]
|
|
||||||
# fmt: off
|
|
||||||
assert input_ids == [
|
|
||||||
1, # bos
|
|
||||||
32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13, # system
|
|
||||||
32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13, # human
|
|
||||||
32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
|
|
||||||
32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13, # tool
|
|
||||||
32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
|
|
||||||
]
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
labels = dataset_wrapper[0]["labels"]
|
|
||||||
# fmt: off
|
|
||||||
assert labels == [
|
|
||||||
-100, # bos
|
|
||||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # system
|
|
||||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # human
|
|
||||||
-100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13, # gpt
|
|
||||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool
|
|
||||||
-100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13 # gpt
|
|
||||||
]
|
|
||||||
# fmt: on
|
|
||||||
|
|||||||
@@ -8,8 +8,7 @@ from pathlib import Path
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from datasets import load_dataset
|
from transformers import AutoTokenizer, LlamaTokenizer
|
||||||
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
|
|
||||||
|
|
||||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||||
@@ -20,14 +19,12 @@ from axolotl.prompt_strategies.llama2_chat import (
|
|||||||
Llama2ChatPrompter,
|
Llama2ChatPrompter,
|
||||||
LLama2ChatTokenizingStrategy,
|
LLama2ChatTokenizingStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.prompt_strategies.orpo.chat_template import load
|
|
||||||
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
|
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
AlpacaPromptTokenizingStrategy,
|
AlpacaPromptTokenizingStrategy,
|
||||||
ShareGPTPromptTokenizingStrategy,
|
ShareGPTPromptTokenizingStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
|
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -449,57 +446,5 @@ If a question does not make any sense, or is not factually coherent, explain why
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OrpoTokenizationTest(unittest.TestCase):
|
|
||||||
"""test case for the ORPO tokenization"""
|
|
||||||
|
|
||||||
def setUp(self) -> None:
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
tokenizer = LlamaTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
|
|
||||||
tokenizer.add_special_tokens(
|
|
||||||
{
|
|
||||||
"eos_token": AddedToken(
|
|
||||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tokenizer.add_tokens(
|
|
||||||
[
|
|
||||||
AddedToken(
|
|
||||||
"<|im_start|>", rstrip=False, lstrip=False, normalized=False
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.dataset = load_dataset(
|
|
||||||
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
|
|
||||||
).select([0])
|
|
||||||
|
|
||||||
def test_orpo_integration(self):
|
|
||||||
strat = load(
|
|
||||||
self.tokenizer,
|
|
||||||
DictDefault({"train_on_inputs": False}),
|
|
||||||
DictDefault({"chat_template": "chatml"}),
|
|
||||||
)
|
|
||||||
res = strat.tokenize_prompt(self.dataset[0])
|
|
||||||
assert "rejected_input_ids" in res
|
|
||||||
assert "rejected_labels" in res
|
|
||||||
assert "input_ids" in res
|
|
||||||
assert "labels" in res
|
|
||||||
assert "prompt_attention_mask" in res
|
|
||||||
|
|
||||||
assert len(res["rejected_input_ids"]) == len(res["rejected_labels"])
|
|
||||||
assert len(res["input_ids"]) == len(res["labels"])
|
|
||||||
assert len(res["input_ids"]) == len(res["prompt_attention_mask"])
|
|
||||||
|
|
||||||
assert res["rejected_labels"][0] == -100
|
|
||||||
assert res["rejected_input_ids"][-1] == res["rejected_labels"][-1]
|
|
||||||
|
|
||||||
assert res["labels"][0] == -100
|
|
||||||
assert res["input_ids"][-1] == res["labels"][-1]
|
|
||||||
|
|
||||||
assert res["prompt_attention_mask"][0] == 1
|
|
||||||
assert res["prompt_attention_mask"][-1] == 0
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user