Compare commits: `grpo-path-` ... `topk-logpr` (46 commits)

68e97d032a, 23f029a89c, afbb44f08b, d753ead033, c011405117, a2e52a29e9, e82268e580, 75e1480c10, 45e1548d59, 165088e7c1, 2d5826f544, a4170030ab, bf842730a5, 1db6ad60a7, 29b366b2e1, b53a41372f, 02f45e94be, 954e192f38, 8dfadc2b3c, 23a9fcb0a7, c3d4f6e295, 7fa690fac8, 3c743c4bfb, 91bb95685a, b194e17c28, 3aac3b1da9, 3d8425fa91, 97a2fa2781, a98526ef78, 2e57391bf8, aa45fed451, a09a5cfd1c, 40362d60e0, ffae8d6a95, fdbb1a207c, 30046315d9, e37a4a536a, 44f64ab627, 826f1b1494, 526e5ee8b8, fd8cb32547, e48e2df4dd, b7616022ab, 1faf1a5c5a, 5bbad5ef93, a971eb4ce6
.github/workflows/base.yml (12 changed lines, vendored)

@@ -22,12 +22,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.10"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
@@ -40,6 +34,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4
.github/workflows/docs.yml (2 changed lines, vendored)

@@ -19,7 +19,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.11'
- name: install dependencies
run: |
python3 -m pip install jupyter
.github/workflows/lint.yml (2 changed lines, vendored)

@@ -19,6 +19,6 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
.github/workflows/main.yml (7 changed lines, vendored)

@@ -24,8 +24,13 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
axolotl_extras: vllm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
.github/workflows/multi-gpu-e2e.yml (16 changed lines, vendored)

@@ -4,6 +4,10 @@ on:
pull_request:
paths:
- 'tests/e2e/multigpu/*.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
@@ -24,13 +28,21 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
axolotl_extras: # no vllm support for 2.4.1
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
# awaiting vllm#12721
axolotl_extras:
num_gpus: 2
nightly_build: "true"
@@ -42,7 +54,7 @@ jobs:
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
.github/workflows/nightlies.yml (5 changed lines, vendored)

@@ -22,6 +22,11 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
.github/workflows/pypi.yml (2 changed lines, vendored)

@@ -36,7 +36,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"

- name: Install dependencies
run: |
.github/workflows/tests-nightly.yml (20 changed lines, vendored)

@@ -12,7 +12,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
@@ -25,13 +25,8 @@ jobs:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20

steps:
@@ -112,13 +107,20 @@ jobs:
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
.github/workflows/tests.yml (25 changed lines, vendored)

@@ -35,7 +35,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
@@ -48,13 +48,8 @@ jobs:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20

steps:
@@ -127,7 +122,7 @@ jobs:
max-parallel: 1
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20

steps:
@@ -209,14 +204,14 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
axolotl_extras: vllm
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
@@ -251,13 +246,19 @@ jobs:
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
@@ -51,7 +51,7 @@ Features:

**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- Python 3.11
- PyTorch ≥2.4.1

### Installation
@@ -4,8 +4,8 @@ set -e

python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"

pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels  # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
@@ -37,15 +37,11 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
    f.write(dockerfile_contents)

cicd_image = (
    Image.from_dockerfile(
        pathlib.Path(temp_dir) / "Dockerfile",
        force_build=True,
        gpu="A10G",
    )
    .env(df_args)
    .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
cicd_image = Image.from_dockerfile(
    pathlib.Path(temp_dir) / "Dockerfile",
    force_build=True,
    gpu="A10G",
).env(df_args)

app = App("Axolotl CI/CD", secrets=[])
@@ -1,6 +1,4 @@
"""
modal application to run axolotl gpu tests in Modal
"""
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code

import os
@@ -46,6 +46,10 @@ overrides_of_model_config:
type: # linear | dynamic
factor: # float

# optional overrides the base model loading from_pretrained
overrides_of_model_kwargs:
# use_cache: False

# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
@@ -87,7 +91,12 @@ datasets:
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
shards: # Optional[int] number of shards to split data into

shards: # Optional[int] split dataset into N pieces (use with shards_idx)
shards_idx: # Optional[int] = 0 the index of sharded dataset to use

preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)

name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
@@ -133,10 +142,19 @@ datasets:

# Key containing the messages (default: "messages")
field_messages: messages
# Key for role in each message (default: "role")
message_field_role: role
# Key for content in each message (default: "content")
message_field_content: content

# Mapping of properties from the input dataset to the chat template.
# (default: message_property_mappings={'role':'role', 'content':'content'})
# If a property exists in the template but not in this mapping, the system will attempt
# to load it directly from the message using the property name as the key.
# Example: In the mapping below, 'from' is loaded from the input dataset and used as 'role',
# while 'value' is loaded and used as 'content' in the chat template.
message_property_mappings:
role: from
content: value
# ...

message_property_mappings:

# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
@@ -296,6 +314,13 @@ lora_modules_to_save:

lora_fan_in_fan_out: false

# Apply custom LoRA autograd functions and activation function Triton kernels for
# speed and memory savings
# See: https://axolotl-ai-cloud.github.io/axolotl/docs/lora_optims.html
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

# LoRA+ hyperparameters
# For more details about the following options, see:
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
@@ -344,6 +369,9 @@ comet_mode: # Create a new experiment ("create") or log to an existing one ("get
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.

# Tensorboard
use_tensorboard: # Optional[bool]

# Where to save the full-finetuned model to
output_dir: ./completed-model

@@ -378,6 +406,12 @@ save_total_limit: # Checkpoints saved at a time
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
max_steps:

# bool of whether to include tokens trained per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second: # Optional[bool]

# whether to find a batch size that fits in memory. Passed to the underlying transformers Trainer
auto_find_batch_size: # Optional[bool]

eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
@@ -6,7 +6,7 @@ order: 3

## sharegpt

IMPORTANT: ShareGPT is deprecated! Please see the `chat_template` section below.
IMPORTANT: ShareGPT is deprecated! Please see the [chat_template](#chat_template) section below.

## pygmalion

@@ -22,7 +22,7 @@ Chat Template strategy uses a jinja2 template that converts a list of messages i
{"conversations": [{"role": "...", "content": "..."}]}
```

See `config.qmd` for full configs and supported templates.
See [configs](../config.qmd) for full configs and supported templates.

### Migrating from sharegpt

@@ -42,8 +42,9 @@ datasets:
type: chat_template

field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value

# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
@@ -52,8 +53,9 @@ datasets:
type: chat_template

field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value
```

We recommend checking the below examples for other use cases.
@@ -138,8 +140,9 @@ datasets:
type: chat_template
chat_template: tokenizer_default
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value
roles_to_train: []
train_on_eos: turn
message_field_training: train
@@ -1,14 +1,458 @@
---
title: Dataset Formats
description: Supported dataset formats.
listing:
  fields: [title, description]
  type: table
  sort-ui: false
  filter-ui: false
  max-description-length: 250
description: Guide to Dataset Formats in Axolotl
back-to-top-navigation: true
toc: true
toc-depth: 5
---

Axolotl supports a variety of dataset formats. It is recommended to use a JSONL format. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.

Below are these various formats organized by task:
Axolotl is a training framework that aims to make training convenient yet flexible: users configure everything by passing a single YAML config file.

As there are a lot of available options in Axolotl, this guide aims to simplify the experience of choosing the right ones.

Axolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has its own dataset format, described below.

## [Pre-training](pretraining.qmd)

When aiming to train on large corpora of text, pre-training is your go-to choice. Due to the size of these datasets, downloading the entire dataset before beginning training would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to only load batches into memory at a time.

A sample format for a pre-training dataset is as follows:

```json
{"text": "first row"}
{"text": "second row"}
...
```

It is typically recommended to save your dataset as `.jsonl` due to its flexibility and simplicity.

Axolotl supports loading from a Hugging Face hub repo or from local files.

::: {.callout-important}
For pre-training only, Axolotl splits texts that exceed the context length into multiple smaller prompts.
:::

### Pre-training from Hugging Face hub datasets

As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:

```yaml
pretraining_dataset: hf_org/name
```

### Pre-training from local dataset files

Given a few corpus files, `A.jsonl`, `B.jsonl`, and `C.jsonl`, your config will look like the below:

```yaml
pretraining_dataset:
  - path: json
    data_files:
      - A.jsonl
      - B.jsonl
      - C.jsonl
```

While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet`, `arrow`, `SQL`, `Webdataset`) that are supported by [`Dataset.load_dataset`](https://huggingface.co/docs/datasets/loading#local-and-remote-files).

### Pre-training without streaming

In the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This means that the entire dataset is pre-tokenized up front instead of tokenized on demand while streaming.

One benefit of this is that the tokenization can be performed separately on a CPU-only machine and then transferred to a GPU machine for training to save costs.

From Hugging Face:

```yaml
datasets:
  - path: hf_org/name
    type: completion
```

From local files (either example works):

```yaml
datasets:
  - path: A.jsonl
    type: completion

  - path: json
    data_files: ["A.jsonl", "B.jsonl", "C.jsonl"]
    type: completion
```

### Pre-training dataset configuration tips

#### Setting max_steps

When using streaming for large datasets, Axolotl does not know in advance how large the dataset is and does not know when to stop.

Therefore, it is necessary to set `max_steps: int` in your config for pre-training to run, so that Axolotl knows when to stop training.

One step is equal to `sequence_len * micro_batch_size * gradient_accumulation_steps * total_num_gpus` tokens.
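
As a quick illustration of that formula (the numbers below are made-up example values, not recommendations), you can compute the per-step token count and derive a `max_steps` value from a target token budget:

```python
# Illustrative only: how many tokens one optimizer step consumes,
# and how to derive max_steps from a desired total token budget.
sequence_len = 2048
micro_batch_size = 2
gradient_accumulation_steps = 4
total_num_gpus = 8

tokens_per_step = sequence_len * micro_batch_size * gradient_accumulation_steps * total_num_gpus
print(tokens_per_step)  # 131072 tokens per step

token_budget = 1_000_000_000  # e.g. train on roughly 1B tokens
max_steps = token_budget // tokens_per_step
print(max_steps)  # 7629 -> set `max_steps: 7629` in the config
```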

#### Group_by_length

It is recommended to leave this off if downloading from the Hugging Face hub, as it would require downloading the entire dataset, which can be very large.

## Supervised fine-tuning (SFT)

Supervised fine-tuning is the process of training models to respond to an instruction or chat input.

As there is a wide variety of dataset formats, Axolotl tries to support a majority of the formats available in public datasets.

Axolotl provides four approaches for loading datasets; however, it's easier to work backwards from the dataset you have available to figure out which approach to use.

A flow chart is as follows:

1. Do you already have the dataset tokenized? If yes, check [Pre-Tokenized Dataset](#pre-tokenized-dataset).

2. Do you want to format the dataset yourself and manually choose each section to mask? If yes, check [Template Free Dataset](#template-free-dataset).

3. Is your dataset in a "conversation" format, containing a `list[messages]`? If yes, check [Conversation Dataset](#conversation-dataset).

4. Is your dataset in an "instruct" format, containing `{ instruction, response }`? If yes, check [Instruction Dataset](#instruction-dataset).

If you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above or create a thread on GitHub Discussions.

::: {.callout-tip}
You can mix and match within each approach or across approaches to train a model on a variety of datasets.
:::

### [Pre-Tokenized Dataset](tokenized.qmd)

We suggest this approach when you want to bring your own tokenized dataset.

Axolotl expects the dataset to have three keys:
- `input_ids`: from tokenizing the formatted prompt
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
- `labels`: the same as `input_ids`; however, if you want to mask certain tokens, you set those indices to `-100`

::: {.callout-tip}
Make sure to add BOS/EOS tokens to your prompt and mask them appropriately.
:::

A config for this would look like:

```yaml
datasets:
  - path: A.jsonl
    type:
```

::: {.callout-note}
`type: ` is empty!
:::
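
For reference, here is a minimal sketch of what producing one such pre-tokenized row could look like. The model name and prompt text are only placeholders for illustration, not anything these docs prescribe:

```python
from transformers import AutoTokenizer

# Placeholder model; any tokenizer with BOS/EOS tokens works the same way.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")

prompt = "### Instruction:\nAdd 3+5\n\n### Response:\n"
response = "The answer is 8."

prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
response_ids = tokenizer(response, add_special_tokens=False)["input_ids"]

input_ids = [tokenizer.bos_token_id] + prompt_ids + response_ids + [tokenizer.eos_token_id]
attention_mask = [1] * len(input_ids)  # no padding, so all ones
# Mask BOS and the prompt so the loss is only computed on the response and EOS.
labels = [-100] * (1 + len(prompt_ids)) + response_ids + [tokenizer.eos_token_id]

row = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
```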

### [Template Free Dataset](template_free.qmd)

We recommend this approach when you want granular control over the prompt formatting, special tokens, and masking, whilst letting Axolotl handle the tokenization. This is very useful if your dataset has unique prompts that differ across samples and where one single general template wouldn't suffice.

In the example below, you can see that there is no fixed structure. At the same time, it's very flexible, as there are no constraints on how your prompt can look.

```json
{
  "segments": [
    {
      "label": true,
      "text": "<s>Hello\n"
    },
    {
      "label": true,
      "text": "hi there!. "
    },
    {
      "label": false,
      "text": "goodbye "
    },
    {
      "label": true,
      "text": "farewell</s>"
    }
  ]
}
```

Each prompt must have a key called `segments`, which is a list of `{ text, label }` objects.

```yaml
datasets:
  - path: A.jsonl
    type: input_output
```
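
Conceptually, each segment is tokenized and concatenated, and segments with `label: false` are excluded from the loss. The sketch below only illustrates that idea under the assumption that masking works this way; it is not the actual `input_output` implementation:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")  # placeholder model

segments = [
    {"label": True, "text": "<s>Hello\n"},
    {"label": True, "text": "hi there!. "},
    {"label": False, "text": "goodbye "},
    {"label": True, "text": "farewell</s>"},
]

input_ids, labels = [], []
for seg in segments:
    ids = tokenizer(seg["text"], add_special_tokens=False)["input_ids"]
    input_ids.extend(ids)
    # label: false -> tokens are seen by the model but excluded from the loss
    labels.extend(ids if seg["label"] else [-100] * len(ids))

attention_mask = [1] * len(input_ids)
```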

### [Conversation Dataset](conversation.qmd)

`conversation` messages are a list of messages which usually contain a `role` and a `content` key.

::: {.callout-tip}
Fun fact: Axolotl synonymously refers to "chat" messages as `conversation` messages due to how FastChat initially used this term to build a widely used [fastchat conversation](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) method for formatting chat messages prior to the creation of `chat_templates`.
:::

#### What are `chat_templates`?

The currently most popular and convenient method for inference is to use `chat_templates` for formatting prompts. Axolotl supports using `chat_templates` for training to ensure that the model performs in the same environment as in inference.

Here's a quick rundown on `chat_template`: a `chat_template` is a Jinja2 template which formats a list of messages into a prompt.

An example of a prompt formatted into a popular template called ChatML can be seen below.

Single prompt (pretty-printed):
```json
{
  "messages": [
    {
      "role": "user",
      "content": "Hi"
    },
    {
      "role": "assistant",
      "content": "How can I help you?"
    },
    {
      "role": "user",
      "content": "Can you add 3+5?"
    },
    {
      "role": "assistant",
      "content": "The answer is 8."
    }
  ]
}
```

The ChatML template is as follows:
```jinja2
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
```

The above prompt formatted into this template will result in:

```
<|im_start|>user
Hi<|im_end|>
<|im_start|>assistant
How can I help you?<|im_end|>
<|im_start|>user
Can you add 3+5?<|im_end|>
<|im_start|>assistant
The answer is 8.<|im_end|>
```

By using delimiters (`<|im_start|>` and `<|im_end|>`), the prompt separates the different speakers, which helps the model identify which portion belongs to whom.
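
You can reproduce this rendering yourself. The sketch below is only illustrative and uses `transformers`; the model name is a placeholder whose tokenizer ships a ChatML-style template, and any ChatML-templated tokenizer should behave similarly:

```python
from transformers import AutoTokenizer

# Placeholder: a tokenizer whose chat template is ChatML-style.
tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")

messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "How can I help you?"},
    {"role": "user", "content": "Can you add 3+5?"},
    {"role": "assistant", "content": "The answer is 8."},
]

# tokenize=False returns the rendered string instead of token ids.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
print(prompt)  # <|im_start|>user\nHi<|im_end|>\n...
```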

#### Common Conversation Dataset formats

Older conversation datasets with the following format are colloquially called `sharegpt` datasets.

```json
{"conversations": [{"from": "...", "value": "..."}]}
```

Newer conversation datasets usually follow the OpenAI format.

```json
{"messages": [{"role": "...", "content": "..."}]}
```

Axolotl supports both, as well as allowing customization of the keys.

#### [Chat Template Usage](conversation.qmd#chat_template)

To properly use this method, it is important to identify three things:

1. Which `chat_template` will you use?

2. What are the keys in your dataset, and what are the possible roles? For example, in the OpenAI format, the keys would be `messages`, `role`, and `content`, respectively, whereas the possible roles are `system`, `user`, and `assistant`.

3. What do you want to mask? For instance, only assistant messages, only the last message, or nothing.

##### Choosing a `chat_template`

There are a lot of `chat_templates` out there. Axolotl supports the common ones: [supported chat templates](https://github.com/axolotl-ai-cloud/axolotl/blob/860609392184cf62a7e0ca676658b170e059ce6c/src/axolotl/utils/chat_templates.py#L17). For example, to use ChatML, it would be `chat_template: chatml`.

However, it is also possible to use the template already configured within the tokenizer by specifying `chat_template: tokenizer_default`. If you want a fallback (in case some tokenizer does not have it pre-configured), you can use `chat_template: tokenizer_default_fallback_chatml` to fall back to the ChatML template if a tokenizer template was not found.

One last but powerful approach is to bring your own template. This can be set via:

```yaml
chat_template_jinja: # your template
```
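
Since a chat template is plain Jinja2, you can prototype a custom template locally before putting it into `chat_template_jinja`. This is only an illustrative sketch using the `jinja2` package, with the ChatML template from above as the template string:

```python
from jinja2 import Template

chatml = (
    "{% if not add_generation_prompt is defined %}"
    "{% set add_generation_prompt = false %}{% endif %}"
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}"
)

messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "How can I help you?"},
]

print(Template(chatml).render(messages=messages))
```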

##### Setting `chat_template` dataset keys

We currently default to the OpenAI format for dataset keys, so if that's your current dataset format, there's nothing to do here.

If your dataset format is different, here are the keys you should check (with their defaults):

```yaml
datasets:
  ...
  field_messages: messages # this should point to the key containing the list of conversations
  message_property_mappings: # this is a mapping from keys in your dataset to keys in the chat_template
    role: role
    content: content
```

In some `chat_templates` (e.g. [Gemma](https://huggingface.co/google/gemma-2b-it/blob/main/tokenizer_config.json#L1507)), the roles are hardcoded to `user` and `assistant`. Consequently, you may find it necessary to map the roles in your dataset to these. We currently have some defaults that should work for common datasets, but if you get a `KeyError`, it would be necessary to add a mapping for your roles. Here is an example of how it would look:

```yaml
datasets:
  ...
  roles:
    assistant:
      - gpt
      - model
    user:
      - human
```

In the example above, all `gpt` and `model` values are converted to `assistant`, and all `human` values are converted to `user`.
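
To make the effect of the `roles` mapping concrete, here is a small illustrative sketch of the normalization it describes. This is not Axolotl's internal code, only the idea:

```python
# Invert the `roles` mapping from the config into an alias -> canonical lookup table.
roles = {"assistant": ["gpt", "model"], "user": ["human"]}
normalize = {alias: canonical for canonical, aliases in roles.items() for alias in aliases}

conversation = [
    {"from": "human", "value": "Hi"},
    {"from": "gpt", "value": "How can I help you?"},
]

# sharegpt-style keys plus role aliases become OpenAI-style messages.
messages = [
    {"role": normalize.get(m["from"], m["from"]), "content": m["value"]}
    for m in conversation
]
# [{'role': 'user', 'content': 'Hi'}, {'role': 'assistant', 'content': 'How can I help you?'}]
```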

##### Handling masking

The common use case for `chat_template` is chat messages, therefore it is common to mask all non-assistant messages. Assistant messages refer to the bot messages that you want the model to learn on.

To train on all `assistant` messages, you would set the following configs:

```yaml
datasets:
  ...
  roles_to_train: ["assistant"]
  train_on_eos: "turn"
```

The `train_on_eos` config means that EOS tokens are masked for turns that aren't assistant turns. The other options are `all` and `last`, to choose which EOS tokens to train on.

Perhaps you want to train on `assistant` and `narrator` roles; you can simply add `narrator` to the list of `roles_to_train`. You would also need to add it to the mapping of `roles` above.

```yaml
datasets:
  ...
  roles_to_train: ["assistant", "narrator"]
  roles:
    assistant:
      - gpt
      - model
    user:
      - human
    narrator: ["narrator"]
```

#### Applying `chat_template`

Once all the above steps are completed, you can combine these configs together to form a bespoke configuration for your custom dataset. The final step is to correctly set the EOS token in your config:

```yaml
datasets:
  - path: A.jsonl
    type: chat_template

    # step 1
    chat_template: chatml

    # step 2
    field_messages: messages
    message_property_mappings:
      role: role
      content: content

    roles:
      assistant:
        - gpt
        - model
        - assistant
      user:
        - human
        - user

    # step 3
    roles_to_train: ["assistant"]
    train_on_eos: "turn"

special_tokens:
  eos_token: <|im_end|>
```

If this config were applied to the sample dataset above, the output would look as follows (it can be retrieved via `axolotl preprocess config.yaml --debug`):

```
<|im_start|>(-100, 128256) user(-100, 882)
(-100, 198) Hi(-100, 13347) <|im_end|>(-100, 128257)
(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)
(-100, 198) How(4438, 4438) can(649, 649) I(358, 358) help(1520, 1520) you(499, 499) ?(30, 30) <|im_end|>(128257, 128257)
(-100, 198) <|im_start|>(-100, 128256) user(-100, 882)
(-100, 198) Can(-100, 6854) you(-100, 499) add(-100, 923) (-100, 220) 3(-100, 18) +(-100, 10) 5(-100, 20) ?(-100, 30) <|im_end|>(-100, 128257)
(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)
(-100, 198) The(791, 791) answer(4320, 4320) is(374, 374) (220, 220) 8(23, 23) .(13, 13) <|im_end|>(128257, 128257)
(-100, 198)
```

The first number refers to the label and the second to the `token_id`. For example, `-100` labels appear on non-assistant portions, meaning that they are masked during training. For assistant portions, the label is the same as the `token_id`.

### [Instruction Dataset](inst_tune.qmd)

Instruction datasets are used to train instruction-following models and comprise a prompt, containing an instruction, and a single response. In contrast to chat datasets, which may be multi-turn, instruct datasets are typically single-turn.

An example of a common format is Alpaca:
```json
{"instruction": "...", "input": "...", "output": "..."}
```

Using those keys, a prompt can be built from each row:
```
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
{output}
```

This can be configured as such:
```yaml
datasets:
  - path: A.jsonl
    type: alpaca
```

Axolotl supports many kinds of instruction datasets. All of them can be found [here](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/inst_tune.html) with their respective type and sample row format.

#### Custom Instruct Prompt Format

Due to the myriad possible instruction formats, Axolotl allows you to customize your own instruction format without having to dive into the code directly.

In the example below, a sample row is formatted into the `mistral_v1` style.
```json
{"input": "...", "output": "..."}
```

```yaml
datasets:
  - path: repo
    type:
      system_prompt: ""

      field_system:
      field_instruction: input
      field_input:
      field_output: output

      # multi-line example with input
      format: |-
        [INST] {instruction} {input} [/INST]

      # single-line example without input
      no_input_format: "[INST] {instruction} [/INST]"
```

The config specifies that the `field_instruction` is actually named `input`, and the `field_input` is empty, as we don't have an `input` in this sample. Generally, `instruction` can be thought of as the question to the model, `input` as additional information, and `output` as the response. It is not necessary to have an `input` or a `system`. In the end, the most important part is to understand what you want the prompt to look like and how you can customize this to your use case.
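
To see what that config produces, here is a small illustrative sketch of how the `format` strings get filled from a row. The values are placeholders, and plain string substitution stands in for Axolotl's actual prompt builder:

```python
row = {"input": "Add 3+5", "output": "The answer is 8."}

fmt = "[INST] {instruction} {input} [/INST]"
no_input_format = "[INST] {instruction} [/INST]"

# field_instruction points at the row's "input" column; field_input is unset,
# so the no_input_format variant is the one that applies here.
instruction = row["input"]
prompt = no_input_format.format(instruction=instruction)
full_text = prompt + row["output"]
print(full_text)  # [INST] Add 3+5 [/INST]The answer is 8.
```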

## Reinforcement Learning from Human Feedback (RLHF)

There are multiple RLHF methods, each with its own dataset requirements. Please see the [RLHF datasets](../rlhf.qmd) documentation for more detail.
@@ -19,3 +19,11 @@ description: Frequently asked questions
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**

> A: You may be using deepspeed with a single GPU. Please don't set `deepspeed:` in the yaml or CLI.

**Q: The code is stuck on saving preprocessed datasets.**

> A: This is usually an issue with the GPU and can be resolved by setting the environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on RunPod, this is usually a pod issue; starting a new pod should take care of it.

**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

> A: This means that the property mapping for the stated attribute does not exist when building the `chat_template` prompt. For example, if the error says `no attribute 'content'`, please check that you have added the correct mapping for `content` under `message_property_mappings`.
docs/lora_optims.qmd (128 added lines, new file)

@@ -0,0 +1,128 @@
---
title: "LoRA Optimizations"
description: "Custom autograd functions and Triton kernels in Axolotl for optimized LoRA fine-tuning"
---

Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two optimizations for LoRA and QLoRA fine-tuning, supporting both single-GPU and multi-GPU (DDP and DeepSpeed) training. These are (1) SwiGLU and GEGLU activation function Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was to leverage operator fusion and tensor re-use in order to improve speed and reduce memory usage during the forward and backward passes of these calculations.

We currently support several common model architectures, including (but not limited to):

- `llama`
- `mistral`
- `qwen2`
- `gemma`
- `gemma2`

<details>

The set of models we support is currently limited by our attention patching strategy, which assumes (and replaces) specific code blocks for the query / key / value and output projections:

```python
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
    "\n"
)

ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip(
    "\n"
)
```

This is replaced with:

```python
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
    "\n"
)

PATCHED_O_CODE = """
attn_output = self.apply_o(attn_output)
""".lstrip(
    "\n"
)
```

Where `apply_qkv` and `apply_o` are defined in the `axolotl.kernels.lora` module.

We welcome testing of other model architectures and / or PRs to expand our patching logic to be compatible with more of them.

</details>

## Usage

These optimizations can be enabled in your Axolotl config YAML file. The `lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and `lora_o_kernel` enable the fused query-key-value projection and the optimized output projection, respectively.

```yaml
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```

## Requirements

- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
  - Note: set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
- Targeted LoRA adapters cannot use dropout
  - This may limit model expressivity / cause overfitting
- Targeted LoRA adapters cannot have bias terms
  - This may limit model expressivity

Models with pre-existing LoRA adapters that use dropout or have bias terms may need to be re-finetuned without these features in order to use these optimizations.

## Implementation details

### Custom autograd functions

The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the LoRA and base weight computations together and provides a single, efficient backward pass for the entire MLP block.

For attention components, similar optimizations are provided through a function that handles the query, key, and value projections, and a function that handles the output projection. They are designed to work with the existing `transformers` attention implementation via some monkey-patching logic.

### Triton kernels

Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for improved speed and memory performance. These kernels handle both the forward and backward passes.
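
As a rough mental model of what such a fused forward/backward pair computes, here is a minimal sketch of SwiGLU as a custom autograd function in plain PyTorch. It is only illustrative of the math; the actual implementation lives in the Triton kernels and autograd functions inside Axolotl:

```python
import torch
import torch.nn.functional as F


class SwiGLUFunction(torch.autograd.Function):
    """out = silu(gate) * up, with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, gate, up):
        ctx.save_for_backward(gate, up)
        return F.silu(gate) * up

    @staticmethod
    def backward(ctx, grad_out):
        gate, up = ctx.saved_tensors
        sig = torch.sigmoid(gate)
        silu = gate * sig
        # d/d(gate) silu(gate) = sigmoid(gate) * (1 + gate * (1 - sigmoid(gate)))
        d_silu = sig * (1 + gate * (1 - sig))
        return grad_out * up * d_silu, grad_out * silu


gate = torch.randn(4, 8, requires_grad=True)
up = torch.randn(4, 8, requires_grad=True)
SwiGLUFunction.apply(gate, up).sum().backward()  # gradients flow through the custom backward
```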

### Integration

The custom autograd functions and Triton kernels are designed to work together. The autograd function manages the high-level computation flow and gradient tracking, while calling the Triton kernels for the activation function computation. During the backward pass, the kernel computes both the activation output and the required gradients, which the autograd function then uses to compute the final gradients for the entire computation path.

## Future Work

- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions
@@ -3,6 +3,18 @@ title: Multi Node
description: How to use Axolotl on multiple machines
---

Below are three ways to train multi-node in Axolotl.

::: {.callout-important}
Each machine needs a copy of Axolotl; we suggest using the same commit to ensure compatibility.

You will also need to have the same configuration file for your model on each machine.

Make sure the main machine is reachable by the other machines.
:::

# Accelerate

You will need to create a configuration for accelerate, either by running `accelerate config` and following the instructions, or by using one of the presets below:

~/.cache/huggingface/accelerate/default_config.yaml
@@ -26,7 +38,7 @@ tpu_use_sudo: false
use_cpu: false
```

Configure your model to use FSDP with for example:
Configure your model to use FSDP in the Axolotl yaml. For example:
```yaml
fsdp:
  - full_shard
@@ -37,12 +49,40 @@ fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```

## Machine configuration

On each machine you need a copy of Axolotl; we suggest using the same commit to ensure compatibility.

You will also need to have the same configuration file for your model on each machine.

On the main machine only, make sure the port you set as `main_process_port` is open in TCP and reachable by the other machines.

All you have to do now is launch using accelerate as you usually would on each machine, and voila, the processes will start once you have launched accelerate on every machine.

# Ray Train

Please see the Ray Train doc [here](ray-integration.qmd).

# Torchrun

If you are using InfiniBand, we recommend torchrun to utilize the full bandwidth.

Set the following env vars (change the buffer size / socket name depending on your system):

```bash
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
export NCCL_BUFFSIZE=2097152
```

Run the following on each node:

```bash
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
```

Please make sure to substitute the placeholder variables:

- `num_nodes`: number of nodes (containing GPUs)
- `gpu_per_node`: number of GPUs per node
- `head_node_ip`: IP of the head node (make sure other machines can connect to this)
- `head_node_port`: port of the head node (make sure other machines can connect to this; default 29400)
- `rdzv_id`: a unique job ID that is used by the job across nodes

::: {.callout-note}
You need to call `axolotl.cli.train` instead of `axolotl train`, as the latter calls accelerate under the hood.
:::

More info on the available configs can be found in the PyTorch docs [here](https://pytorch.org/docs/stable/elastic/run.html).
docs/rlhf.qmd (451 changed lines)

@@ -1,26 +1,39 @@
---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-depth: 3
---

### Overview
# Overview

Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback. Various methods include, but are not limited to:

- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)
- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)

### RLHF using Axolotl
# RLHF using Axolotl

>[!IMPORTANT]
>This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
:::

The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML
We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap to expose in axolotl. Each method has its own supported ways of loading datasets and prompt formats.

::: {.callout-tip}
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}`, where `{method}` is one of our supported methods. The `type:` can be retrieved from `{method}.{function_name}`.
:::

## DPO

Example config:

#### DPO
```yaml
rl: dpo
datasets:
@@ -32,12 +45,265 @@ datasets:
    type: chatml
```

#### IPO
DPO supports the following types with the following dataset format:

### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### chatml.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

### llama3.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### llama3.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### zephyr.nectar

```json
{
  "prompt": "...",
  "answers": [
    {
      "answer": "...",
      "rank": 1
    },
    {
      "answer": "...",
      "rank": 2
    }
    // ... more answers with ranks
  ]
}
```

### chat_template.default

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]
```

Sample input format:

```json
{
  "messages": [
    {
      "role": "system",
      "content": "..."
    },
    {
      "role": "user",
      "content": "..."
    }
    // ... more messages
  ],
  "chosen": {
    "role": "assistant",
    "content": "..."
  },
  "rejected": {
    "role": "assistant",
    "content": "..."
  }
}
```

### user_defined.default

For custom behaviors:

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: user_defined.default

    field_prompt: "prompt"
    field_system: "system"
    field_chosen: "chosen"
    field_rejected: "rejected"
    prompt_format: "{prompt}"
    chosen_format: "{chosen}"
    rejected_format: "{rejected}"
```

The input format is a simple JSON input with customizable fields based on the above config.

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

## IPO

As IPO is just DPO with a different loss function, all supported options for DPO work here.

```yaml
rl: ipo
```

#### ORPO
## ORPO

Paper: https://arxiv.org/abs/2403.07691

@@ -52,8 +318,28 @@ datasets:
    type: chat_template.argilla
```

ORPO supports the following types with the following dataset format:

#### KTO
### chat_template.argilla

```json
{
  "system": "...", // optional
  "prompt": "...", // if available, will be taken as the user message for single-turn instead of from the list below

  // chosen/rejected should be identical until the last content, and contain only an even number of alternating user/assistant turns
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

## KTO

```yaml
rl: kto
@@ -72,7 +358,144 @@ gradient_checkpointing_kwargs:
  use_reentrant: true
```
|
||||
|
||||
#### Using local dataset files
|
||||
KTO supports the following types with the following dataset format:
|
||||
|
||||
### chatml.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"instruction": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
"chosen": [
|
||||
{"role": "user", "content": "..."}
|
||||
],
|
||||
"completion": [
|
||||
{"role": "assistant", "content": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.intel
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"question": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"instruction": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
"completion": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.intel
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"question": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### user_defined.default
|
||||
|
||||
For custom behaviors,
|
||||
|
||||
```yaml
|
||||
rl: kto
|
||||
datasets:
|
||||
- path: ...
|
||||
split: train
|
||||
type: user_defined.default
|
||||
|
||||
field_prompt: "prompt"
|
||||
field_system: "system"
|
||||
field_completion: "completion"
|
||||
field_label: "label"
|
||||
prompt_format: "{prompt}"
|
||||
completion_format: "{completion}"
|
||||
```
|
||||
|
||||
The input format is a simple JSON input with customizable fields based on the above config.
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"completion": "...",
|
||||
"label": "..."
|
||||
}
|
||||
```
|
||||
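The `prompt_format` and `completion_format` strings appear to be Python-style format templates over the configured fields. As a rough, hypothetical illustration of how such a row could be composed (this is not axolotl's internal implementation):

```python
# Hypothetical illustration only -- not axolotl's internal code.
row = {"system": "You are helpful.", "prompt": "What is 2+2?",
       "completion": "4", "label": True}

prompt_format = "{prompt}"          # from the config above
completion_format = "{completion}"  # from the config above

prompt = prompt_format.format(prompt=row["prompt"])                  # "What is 2+2?"
completion = completion_format.format(completion=row["completion"])  # "4"
```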
|
||||
## Using local dataset files
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- ds_type: json
|
||||
@@ -82,9 +505,9 @@ datasets:
|
||||
type: chatml.intel
|
||||
```
|
||||
|
||||
#### Trl autounwrap for peft
|
||||
## TRL auto-unwrapping for PEFT
|
||||
|
||||
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
|
||||
TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, as an additional reference model does not need to be loaded and reference-model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:
|
||||
|
||||
```yaml
|
||||
# load ref model when adapter training.
|
||||
|
||||
@@ -21,8 +21,9 @@ datasets:
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -16,8 +16,9 @@ datasets:
|
||||
type: chat_template
|
||||
drop_system_message: true
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -13,8 +13,9 @@ datasets:
|
||||
type: chat_template
|
||||
drop_system_message: true
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -17,8 +17,9 @@ datasets:
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
|
||||
@@ -17,8 +17,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
|
||||
@@ -14,8 +14,9 @@ datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
|
||||
@@ -17,8 +17,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
@@ -31,8 +32,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
|
||||
82
examples/llama-3/lora-1b-kernels.yml
Normal file
@@ -0,0 +1,82 @@
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
# Currently, we don't support dropout with our custom Triton kernels
|
||||
# lora_dropout: 0.05
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# These options enable our custom Triton kernels / autograd
|
||||
# functions for MLP and attention calculations
|
||||
lora_mlp_kernel: true
|
||||
lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
@@ -22,8 +22,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -14,8 +14,9 @@ datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
|
||||
@@ -12,8 +12,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.45.1
|
||||
bitsandbytes==0.45.2
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
flash-attn==2.7.0.post2
|
||||
flash-attn==2.7.4.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.5.2
|
||||
@@ -13,12 +13,12 @@ liger-kernel==0.5.2
|
||||
packaging==23.2
|
||||
|
||||
peft==0.14.0
|
||||
transformers==4.48.1
|
||||
transformers==4.49.0
|
||||
tokenizers>=0.21.0
|
||||
accelerate==1.3.0
|
||||
datasets==3.2.0
|
||||
deepspeed==0.16.1
|
||||
trl==0.13.0
|
||||
trl==0.15.1
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
@@ -26,7 +26,7 @@ sentencepiece
|
||||
gradio==3.50.2
|
||||
|
||||
modal==0.70.5
|
||||
pydantic==2.6.3
|
||||
pydantic==2.10.6
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
|
||||
@@ -31,27 +31,26 @@ def parse_dataset(dataset=None, split="train"):
|
||||
ds_cfg["field_messages"] = field_messages
|
||||
|
||||
message_fields = features[field_messages][0].keys()
|
||||
message_field_role = None
|
||||
|
||||
message_property_mappings = {"role": None, "content": None}
|
||||
for key in ["from", "role"]:
|
||||
if key in message_fields:
|
||||
message_field_role = key
|
||||
message_property_mappings["role"] = key
|
||||
break
|
||||
if not message_field_role:
|
||||
if not message_property_mappings["role"]:
|
||||
raise ValueError(
|
||||
f'No role field found in messages: {", ".join(message_fields)}'
|
||||
)
|
||||
ds_cfg["message_field_role"] = message_field_role
|
||||
|
||||
message_field_content = None
|
||||
for key in ["content", "text", "value"]:
|
||||
if key in message_fields:
|
||||
message_field_content = key
|
||||
message_property_mappings["content"] = key
|
||||
break
|
||||
if not message_field_content:
|
||||
if not message_property_mappings["content"]:
|
||||
raise ValueError(
|
||||
f'No content field found in messages: {", ".join(message_fields)}'
|
||||
)
|
||||
ds_cfg["message_field_content"] = message_field_content
|
||||
ds_cfg["message_property_mappings"] = message_property_mappings
|
||||
|
||||
print(yaml.dump({"datasets": [ds_cfg]}))
|
||||
|
||||
|
||||
12
setup.py
@@ -71,12 +71,15 @@ def parse_requirements():
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 5):
|
||||
if (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.29.post2")
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers==0.0.28.post3")
|
||||
_install_requires.append("xformers>=0.0.28.post3")
|
||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||
elif (major, minor) >= (2, 4):
|
||||
if patch == 0:
|
||||
@@ -122,7 +125,7 @@ setup(
|
||||
},
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn==2.7.0.post2",
|
||||
"flash-attn==2.7.4.post1",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.16.1",
|
||||
@@ -153,5 +156,8 @@ setup(
|
||||
"ray": [
|
||||
"ray[train]",
|
||||
],
|
||||
"vllm": [
|
||||
"vllm==0.7.2",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.6.0"
|
||||
__version__ = "0.8.0.dev0"
|
||||
|
||||
@@ -35,13 +35,18 @@ def do_cli_train(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
accelerate: bool = True,
|
||||
cwd=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
config_yaml = file.read()
|
||||
cloud.train(config_yaml, accelerate=accelerate)
|
||||
local_dirs = {}
|
||||
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
|
||||
local_dirs = {"/workspace/mounts": cwd}
|
||||
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
|
||||
|
||||
|
||||
def do_cli_lm_eval(
|
||||
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
import subprocess # nosec B404
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
from typing import Optional
|
||||
|
||||
import modal
|
||||
|
||||
@@ -22,8 +23,18 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
|
||||
|
||||
# modal workaround so it doesn't use the automounted axolotl
|
||||
new_env = copy.deepcopy(os.environ)
|
||||
|
||||
if "PYTHONPATH" in new_env:
|
||||
del new_env["PYTHONPATH"]
|
||||
paths = ["/workspace/mounts"]
|
||||
for sub_python_path_str in new_env["PYTHONPATH"].split(":"):
|
||||
sub_python_path = Path(sub_python_path_str)
|
||||
if not sub_python_path.joinpath("src", "axolotl").exists():
|
||||
# skip the automounted axolotl path, otherwise unexpected behavior occurs
|
||||
paths.append(str(sub_python_path))
|
||||
if paths:
|
||||
new_env["PYTHONPATH"] = ":".join(paths)
|
||||
else:
|
||||
del new_env["PYTHONPATH"]
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call( # nosec B603
|
||||
@@ -112,8 +123,6 @@ class ModalCloud(Cloud):
|
||||
if env := self.get_env():
|
||||
image = image.env(env)
|
||||
|
||||
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
|
||||
return image
|
||||
|
||||
def get_secrets(self):
|
||||
@@ -203,9 +212,12 @@ class ModalCloud(Cloud):
|
||||
memory = int(self.config.memory)
|
||||
return 1024 * memory
|
||||
|
||||
def get_train_env(self):
|
||||
def get_train_env(self, local_dirs=None):
|
||||
image = self.get_image()
|
||||
for mount, local_dir in (local_dirs or {}).items():
|
||||
image = image.add_local_dir(local_dir, mount)
|
||||
return self.app.function(
|
||||
image=self.get_image(),
|
||||
image=image,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
cpu=16.0,
|
||||
gpu=self.get_train_gpu(),
|
||||
@@ -214,14 +226,21 @@ class ModalCloud(Cloud):
|
||||
secrets=self.get_secrets(),
|
||||
)
|
||||
|
||||
def train(self, config_yaml: str, accelerate: bool = True):
|
||||
modal_fn = self.get_train_env()(_train)
|
||||
def train(
|
||||
self,
|
||||
config_yaml: str,
|
||||
accelerate: bool = True,
|
||||
local_dirs: Optional[dict[str, str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
modal_fn = self.get_train_env(local_dirs)(_train)
|
||||
with modal.enable_output():
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
accelerate=accelerate,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def lm_eval(self, config_yaml: str):
|
||||
@@ -252,7 +271,7 @@ def _preprocess(config_yaml: str, volumes=None):
|
||||
)
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
@@ -262,8 +281,11 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
||||
accelerate_args = "--accelerate"
|
||||
else:
|
||||
accelerate_args = "--no-accelerate"
|
||||
num_processes_args = ""
|
||||
if num_processes := kwargs.pop("num_processes", None):
|
||||
num_processes_args = f"--num-processes {num_processes}"
|
||||
run_cmd(
|
||||
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
|
||||
f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
@@ -2,19 +2,19 @@
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import logging
|
||||
import random
|
||||
import os
|
||||
import subprocess # nosec B404
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import axolotl
|
||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.cli.sweeps import generate_sweep_configs
|
||||
from axolotl.cli.utils import (
|
||||
add_options_from_config,
|
||||
add_options_from_dataclass,
|
||||
@@ -27,76 +27,6 @@ from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
|
||||
|
||||
def generate_sweep_configs(base_config, sweeps_config):
|
||||
"""
|
||||
Recursively generates all possible configurations by applying sweeps to the base config.
|
||||
|
||||
Args:
|
||||
base_config (dict): The original configuration dictionary
|
||||
sweeps_config (dict): Dictionary where keys are parameters and values are either:
|
||||
- lists of values to sweep independently
|
||||
- or for paired values, a list of dicts under the '_' key
|
||||
|
||||
Returns:
|
||||
list: List of all possible configuration dictionaries
|
||||
|
||||
Example:
|
||||
sweeps_config = {
|
||||
'learning_rate': [0.1, 0.01],
|
||||
'_': [
|
||||
{'load_in_8bit': True, 'adapter': 'lora'},
|
||||
{'load_in_4bit': True, 'adapter': 'qlora'}
|
||||
]
|
||||
}
|
||||
"""
|
||||
# Separate paired values from regular sweeps
|
||||
paired_values = sweeps_config.get("_", [])
|
||||
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
|
||||
|
||||
# Process regular sweeps
|
||||
param_names = list(regular_sweeps.keys())
|
||||
param_values = list(regular_sweeps.values())
|
||||
|
||||
# Generate combinations for regular sweeps
|
||||
regular_combinations = list(product(*param_values)) if param_values else [()]
|
||||
|
||||
# Combine regular sweeps with paired values
|
||||
all_combinations = []
|
||||
for reg_combo in regular_combinations:
|
||||
if paired_values:
|
||||
for paired_set in paired_values:
|
||||
new_config = {}
|
||||
# new_config = deepcopy(base_config)
|
||||
# Combine regular parameters with paired parameters
|
||||
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
|
||||
for param_name, param_value in full_combo.items():
|
||||
new_config[param_name] = param_value
|
||||
print(new_config)
|
||||
all_combinations.append(new_config)
|
||||
else:
|
||||
# If no paired values, just use regular combinations
|
||||
# new_config = deepcopy(base_config)
|
||||
new_config = {}
|
||||
for param_name, param_value in zip(param_names, reg_combo):
|
||||
new_config[param_name] = param_value
|
||||
print(new_config)
|
||||
all_combinations.append(new_config)
|
||||
|
||||
# randomize the order of trials
|
||||
random.seed(42)
|
||||
random.shuffle(all_combinations)
|
||||
|
||||
# Generate a new config for each combination
|
||||
result_configs = []
|
||||
for combination in all_combinations:
|
||||
new_config = deepcopy(base_config)
|
||||
for param_name, param_value in combination.items():
|
||||
new_config[param_name] = param_value
|
||||
result_configs.append(new_config)
|
||||
|
||||
return result_configs
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||
def cli():
|
||||
@@ -165,7 +95,6 @@ def train(
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
@@ -199,7 +128,16 @@ def train(
|
||||
try:
|
||||
if accelerate:
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
cwd = os.getcwd()
|
||||
do_cli_train(
|
||||
cloud_config=cloud,
|
||||
config=config,
|
||||
accelerate=True,
|
||||
cwd=cwd,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
accelerate_args = []
|
||||
if "main_process_port" in kwargs:
|
||||
@@ -208,7 +146,7 @@ def train(
|
||||
accelerate_args.append(str(main_process_port))
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes", None)
|
||||
accelerate_args.append("--num-processes")
|
||||
accelerate_args.append("--num_processes")
|
||||
accelerate_args.append(str(num_processes))
|
||||
|
||||
base_cmd = ["accelerate", "launch"]
|
||||
@@ -220,7 +158,11 @@ def train(
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
do_cli_train(
|
||||
cloud_config=cloud, config=config, accelerate=False, **kwargs
|
||||
)
|
||||
else:
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
@@ -381,4 +323,5 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
main()
|
||||
|
||||
77
src/axolotl/cli/sweeps.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Utilities for handling sweeps over configs for axolotl train CLI command"""
|
||||
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
|
||||
|
||||
def generate_sweep_configs(
|
||||
base_config: dict[str, list], sweeps_config: dict[str, list]
|
||||
) -> list[dict[str, list]]:
|
||||
"""
|
||||
Recursively generates all possible configurations by applying sweeps to the base config.
|
||||
|
||||
Args:
|
||||
base_config (dict): The original configuration dictionary
|
||||
sweeps_config (dict): Dictionary where keys are parameters and values are either:
|
||||
- lists of values to sweep independently
|
||||
- or for paired values, a list of dicts under the '_' key
|
||||
|
||||
Returns:
|
||||
list: List of all possible configuration dictionaries
|
||||
|
||||
Example:
|
||||
sweeps_config = {
|
||||
'learning_rate': [0.1, 0.01],
|
||||
'_': [
|
||||
{'load_in_8bit': True, 'adapter': 'lora'},
|
||||
{'load_in_4bit': True, 'adapter': 'qlora'}
|
||||
]
|
||||
}
|
||||
"""
|
||||
# Separate paired values from regular sweeps
|
||||
paired_values = sweeps_config.get("_", [])
|
||||
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
|
||||
|
||||
# Process regular sweeps
|
||||
param_names = list(regular_sweeps.keys())
|
||||
param_values = list(regular_sweeps.values())
|
||||
|
||||
# Generate combinations for regular sweeps
|
||||
regular_combinations = list(product(*param_values)) if param_values else [()]
|
||||
|
||||
# Combine regular sweeps with paired values
|
||||
all_combinations = []
|
||||
for reg_combo in regular_combinations:
|
||||
if paired_values:
|
||||
for paired_set in paired_values:
|
||||
new_config = {}
|
||||
# new_config = deepcopy(base_config)
|
||||
# Combine regular parameters with paired parameters
|
||||
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
|
||||
for param_name, param_value in full_combo.items():
|
||||
new_config[param_name] = param_value
|
||||
print(new_config)
|
||||
all_combinations.append(new_config)
|
||||
else:
|
||||
# If no paired values, just use regular combinations
|
||||
# new_config = deepcopy(base_config)
|
||||
new_config = {}
|
||||
for param_name, param_value in zip(param_names, reg_combo):
|
||||
new_config[param_name] = param_value
|
||||
print(new_config)
|
||||
all_combinations.append(new_config)
|
||||
|
||||
# randomize the order of trials
|
||||
random.seed(42)
|
||||
random.shuffle(all_combinations)
|
||||
|
||||
# Generate a new config for each combination
|
||||
result_configs = []
|
||||
for combination in all_combinations:
|
||||
new_config = deepcopy(base_config)
|
||||
for param_name, param_value in combination.items():
|
||||
new_config[param_name] = param_value
|
||||
result_configs.append(new_config)
|
||||
|
||||
return result_configs
|
||||
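As a quick, hedged illustration of what `generate_sweep_configs` produces (the base config values here are hypothetical):

```python
# Minimal sketch of expanding a sweep; config values are hypothetical.
from axolotl.cli.sweeps import generate_sweep_configs

base_config = {"base_model": "NousResearch/Llama-3.2-1B", "num_epochs": 1}
sweeps_config = {
    "learning_rate": [0.1, 0.01],
    # entries under "_" are applied together as paired parameter sets
    "_": [
        {"load_in_8bit": True, "adapter": "lora"},
        {"load_in_4bit": True, "adapter": "qlora"},
    ],
}

configs = generate_sweep_configs(base_config, sweeps_config)
# 2 learning rates x 2 paired sets = 4 configs, each a deep copy of
# base_config with the swept parameters applied (trial order is shuffled).
print(len(configs))  # 4
```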
@@ -122,9 +122,11 @@ def load_preference_datasets(
|
||||
`total_num_steps`.
|
||||
"""
|
||||
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||
total_num_steps = int(
|
||||
total_num_steps: Optional[int] = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
if cfg.rl == "grpo":
|
||||
total_num_steps = None
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
@@ -39,7 +39,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
from axolotl.core.trainers.base import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlDPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlMambaTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
@@ -48,9 +47,11 @@ from axolotl.core.trainers.base import (
|
||||
AxolotlTrainer,
|
||||
ReLoRATrainer,
|
||||
)
|
||||
from axolotl.core.trainers.dpo import DPOStrategy
|
||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlCPOConfig,
|
||||
AxolotlDPOConfig,
|
||||
AxolotlKTOConfig,
|
||||
AxolotlORPOConfig,
|
||||
AxolotlPRMConfig,
|
||||
@@ -329,6 +330,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
|
||||
training_arguments_kwargs = {}
|
||||
|
||||
if self.cfg.include_tokens_per_second is not None:
|
||||
training_arguments_kwargs[
|
||||
"include_tokens_per_second"
|
||||
] = self.cfg.include_tokens_per_second
|
||||
|
||||
if self.cfg.bf16 == "full":
|
||||
training_arguments_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
@@ -641,9 +648,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
|
||||
if self.cfg.rl == "orpo":
|
||||
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
|
||||
|
||||
if self.cfg.neftune_noise_alpha is not None:
|
||||
training_arguments_kwargs[
|
||||
"neftune_noise_alpha"
|
||||
@@ -652,7 +656,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_kwargs = {}
|
||||
|
||||
if self.cfg.reward_model:
|
||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
if self.cfg.optimizer in [
|
||||
@@ -965,10 +969,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
if self.cfg.rl_beta:
|
||||
training_args_kwargs["beta"] = self.cfg.rl_beta
|
||||
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
|
||||
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
|
||||
if self.cfg.orpo_alpha:
|
||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||
@@ -977,6 +982,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl == "simpo":
|
||||
training_args_cls = AxolotlCPOConfig
|
||||
training_args_kwargs["loss_type"] = "simpo"
|
||||
@@ -1001,11 +1007,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.kto_undesirable_weight or 1.0
|
||||
)
|
||||
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl == "grpo":
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||
|
||||
else:
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
if self.cfg.rl == "ipo":
|
||||
@@ -1016,11 +1026,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||
if self.cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||
if self.cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs[
|
||||
"use_logits_to_keep"
|
||||
] = self.cfg.dpo_use_logits_to_keep
|
||||
|
||||
for blocklist_key in blocklist_args_kwargs:
|
||||
if blocklist_key in training_args_kwargs:
|
||||
del training_args_kwargs[blocklist_key]
|
||||
|
||||
max_steps = self.cfg.max_steps or total_num_steps or -1
|
||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
output_dir=self.cfg.output_dir,
|
||||
self.cfg.output_dir,
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
max_steps=self.cfg.max_steps or total_num_steps,
|
||||
max_steps=max_steps,
|
||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||
learning_rate=self.cfg.learning_rate,
|
||||
warmup_steps=self.cfg.warmup_steps,
|
||||
@@ -1047,8 +1067,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs[
|
||||
"precompute_ref_log_probs"
|
||||
] = self.cfg.precompute_ref_log_probs
|
||||
if self.cfg.rl in ["dpo", "ipo"]:
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
if self.cfg.rl == "grpo":
|
||||
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||
trainer_cls_args = [self.model]
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
elif self.cfg.rl in ["dpo", "ipo"]:
|
||||
trainer_cls = DPOStrategy.get_trainer_class()
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
elif self.cfg.rl == "orpo":
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
@@ -1063,12 +1088,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "processing_class" in sig.parameters.keys():
|
||||
dpo_trainer_kwargs["processing_class"] = self.tokenizer
|
||||
else:
|
||||
if "tokenizer" in sig.parameters.keys():
|
||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
else:
|
||||
dpo_trainer_kwargs["processing_class"] = self.tokenizer
|
||||
|
||||
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
||||
if self.cfg.datasets is not None and (
|
||||
trainer_cls is DPOStrategy.get_trainer_class()
|
||||
):
|
||||
dpo_trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
|
||||
@@ -5,30 +5,21 @@ module for customized trainers
|
||||
from __future__ import annotations
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from typing import Dict, Literal, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import Trainer
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import (
|
||||
CPOTrainer,
|
||||
DPOTrainer,
|
||||
KTOTrainer,
|
||||
ORPOTrainer,
|
||||
PRMTrainer,
|
||||
RewardTrainer,
|
||||
)
|
||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||
from trl.trainer.utils import pad_to_length
|
||||
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
@@ -847,107 +838,6 @@ class ReLoRATrainer(AxolotlTrainer):
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "dpo"]
|
||||
|
||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
self.model_accepts_loss_kwargs = False
|
||||
|
||||
def create_optimizer(self):
|
||||
if self.args.loraplus_lr_ratio is None:
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
opt_model,
|
||||
)
|
||||
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
if loraplus_lr_ratio:
|
||||
print("Using lora+")
|
||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
|
||||
return self.optimizer
|
||||
|
||||
@wraps(DPOTrainer.push_to_hub)
|
||||
def push_to_hub(self, *args, **kwargs) -> str:
|
||||
"""
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> Dict:
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
)
|
||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||
for key in res.keys():
|
||||
res[key] = res[key][1:]
|
||||
|
||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
num_items_in_batch=None,
|
||||
) -> torch.Tensor:
|
||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
|
||||
33
src/axolotl/core/trainers/dpo/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
DPO Specific Strategy for training
|
||||
"""
|
||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||
|
||||
|
||||
class DPOStrategy:
|
||||
"""
|
||||
Strategy for DPO training
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls):
|
||||
return AxolotlDPOTrainer
|
||||
|
||||
@classmethod
|
||||
def get_training_args_class(cls):
|
||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||
|
||||
return AxolotlDPOConfig
|
||||
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg):
|
||||
training_args_kwargs = {}
|
||||
if cfg.rl == "ipo":
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
||||
if cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||
return training_args_kwargs
|
||||
15
src/axolotl/core/trainers/dpo/args.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Axolotl specific DPO args
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trl import DPOConfig
|
||||
|
||||
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
125
src/axolotl/core/trainers/dpo/trainer.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
DPO trainer for axolotl
|
||||
"""
|
||||
import gc
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from transformers import Trainer
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.trainers.base import (
|
||||
SchedulerMixin,
|
||||
_sanitize_kwargs_for_ds_tagging,
|
||||
_sanitize_kwargs_for_tagging,
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "dpo"]
|
||||
|
||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
self.model_accepts_loss_kwargs = False
|
||||
|
||||
def create_optimizer(self):
|
||||
# pylint: disable=duplicate-code
|
||||
if self.args.loraplus_lr_ratio is None:
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
opt_model,
|
||||
)
|
||||
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
if loraplus_lr_ratio:
|
||||
print("Using lora+")
|
||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
||||
# pylint: disable=duplicate-code
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
|
||||
return self.optimizer
|
||||
|
||||
@wraps(DPOTrainer.push_to_hub)
|
||||
def push_to_hub(self, *args, **kwargs) -> str:
|
||||
"""
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> Dict:
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
)
|
||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||
for key in res.keys():
|
||||
res[key] = res[key][1:]
|
||||
|
||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
num_items_in_batch=None,
|
||||
) -> torch.Tensor:
|
||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
119
src/axolotl/core/trainers/grpo/__init__.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
GRPO Specific Strategy for training
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
class GRPOStrategy:
|
||||
"""
|
||||
Strategy for GRPO training
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls):
|
||||
return AxolotlGRPOTrainer
|
||||
|
||||
@classmethod
|
||||
def get_training_args_class(cls):
|
||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||
|
||||
return AxolotlGRPOConfig
|
||||
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg):
|
||||
grpo_args_kwargs = {}
|
||||
if cfg.trl and cfg.trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
||||
if cfg.trl and cfg.trl.vllm_device:
|
||||
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
||||
else:
|
||||
grpo_args_kwargs["vllm_device"] = "auto"
|
||||
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs[
|
||||
"vllm_gpu_memory_utilization"
|
||||
] = cfg.trl.vllm_gpu_memory_utilization
|
||||
if cfg.trl and cfg.trl.vllm_max_model_len:
|
||||
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
||||
if cfg.trl and cfg.trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
||||
if cfg.trl and cfg.trl.sync_ref_model:
|
||||
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
||||
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
||||
grpo_args_kwargs[
|
||||
"ref_model_mixup_alpha"
|
||||
] = cfg.trl.ref_model_mixup_alpha
|
||||
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
||||
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
||||
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
||||
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
def set_trainer_args(cls, cfg):
|
||||
trainer_args = []
|
||||
if cfg.trl and cfg.trl.reward_funcs:
|
||||
reward_funcs = []
|
||||
for reward_func_fqn in cfg.trl.reward_funcs:
|
||||
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
||||
trainer_args.append(reward_funcs)
|
||||
return trainer_args
|
||||
|
||||
@classmethod
|
||||
def set_trainer_kwargs(cls, cfg):
|
||||
trainer_kwargs = {}
|
||||
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||
trainer_kwargs[
|
||||
"reward_processing_classes"
|
||||
] = cfg.trl.reward_processing_classes
|
||||
return trainer_kwargs
|
||||
|
||||
@classmethod
|
||||
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||
# No data collation is needed in GRPO, handled by trl's trainer __init__
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_blocklist_args_kwargs(cls):
|
||||
return ["dataset_num_proc"]
|
||||
|
||||
@classmethod
|
||||
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
|
||||
"""
|
||||
Returns the reward function from the given fully qualified name, or the path to the reward function model.
|
||||
|
||||
Args:
|
||||
reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
|
||||
or a HF hub path to the reward model.
|
||||
Raises:
|
||||
ValueError: If the reward function does not accept at least two arguments.
|
||||
|
||||
Returns:
|
||||
RewardFunc: A callable that accepts prompts and completions and returns rewards,
|
||||
or a path to a reward model.
|
||||
|
||||
"""
|
||||
try:
|
||||
# use importlib to dynamically load the reward function from the module
|
||||
reward_func_module_name = reward_func_fqn.split(".")[-1]
|
||||
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
|
||||
reward_func = getattr(reward_func_module, reward_func_module_name)
|
||||
if not len(inspect.signature(reward_func).parameters) >= 2:
|
||||
raise ValueError(
|
||||
"Reward function must accept at least two arguments: prompts: list and completions: list"
|
||||
)
|
||||
return reward_func
|
||||
except ModuleNotFoundError:
|
||||
# the user has passed a string (ideally indicating the path of a reward model)
|
||||
LOG.info(
|
||||
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
|
||||
)
|
||||
return reward_func_fqn
|
||||
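For reference, a callable loadable through `get_reward_func` only needs to accept `prompts` and `completions` and return one reward per completion. A minimal sketch (the module name, function name, and the assumption of plain-text completions are all hypothetical):

```python
# my_rewards.py -- hypothetical module, referenced e.g. via
# trl.reward_funcs: ["my_rewards.length_penalty"] in the YAML config.
def length_penalty(prompts, completions, **kwargs):
    """Return one reward per completion; shorter completions score higher.

    Assumes completions are plain strings rather than chat message lists.
    """
    return [-0.001 * len(completion) for completion in completions]
```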
15
src/axolotl/core/trainers/grpo/args.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Axolotl Specific Training Args
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trl import GRPOConfig
|
||||
|
||||
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""
|
||||
Axolotl GRPO Config for GRPO training
|
||||
"""
|
||||
108
src/axolotl/core/trainers/grpo/trainer.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
from accelerate.utils import is_peft_model
|
||||
from accelerate.utils.other import is_compiled_module
|
||||
from transformers import PreTrainedModel
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
from trl.models import unwrap_model_for_generation
|
||||
|
||||
from axolotl.core.trainers.base import SchedulerMixin
|
||||
|
||||
|
||||
# mypy: ignore-errors
|
||||
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
"""
|
||||
Extend the base GRPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# pylint: disable=access-member-before-definition
|
||||
# Enable gradient checkpointing if requested
|
||||
if kwargs["args"].gradient_checkpointing:
|
||||
# Ensure use_cache is disabled
|
||||
if hasattr(self.model, "config"):
|
||||
self.model.config.use_cache = False
|
||||
|
||||
# Enable gradient checkpointing on the base model for PEFT
|
||||
if is_peft_model(self.model) and hasattr(
|
||||
self.model.base_model, "gradient_checkpointing_enable"
|
||||
):
|
||||
self.model.base_model.gradient_checkpointing_enable()
|
||||
# Enable gradient checkpointing for non-PEFT models
|
||||
elif hasattr(self.model, "gradient_checkpointing_enable"):
|
||||
self.model.gradient_checkpointing_enable()
|
||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
||||
# pylint: enable=access-member-before-definition
|
||||
|
||||
def _enable_gradient_checkpointing(
|
||||
self, model: PreTrainedModel, args: GRPOConfig
|
||||
) -> PreTrainedModel:
|
||||
"""Enables gradient checkpointing for the model."""
|
||||
# pylint: disable=unused-argument,redefined-builtin
|
||||
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
||||
use_reentrant = (
|
||||
"use_reentrant" not in gradient_checkpointing_kwargs
|
||||
or gradient_checkpointing_kwargs["use_reentrant"]
|
||||
)
|
||||
|
||||
if use_reentrant:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(
|
||||
make_inputs_require_grad
|
||||
)
|
||||
|
||||
return model
|
||||
# pylint: enable=unused-argument,redefined-builtin
|
||||
|
||||
def _move_model_to_vllm(self):
|
||||
with unwrap_model_for_generation(
|
||||
self.model,
|
||||
self.accelerator,
|
||||
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
|
||||
) as unwrapped_model:
|
||||
if is_compiled_module(unwrapped_model):
|
||||
unwrapped_model = (
|
||||
unwrapped_model._orig_mod # pylint: disable=protected-access
|
||||
)
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.merge_adapter()
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
# Remove base_model and base_layer prefixes
|
||||
state_dict = {
|
||||
k.removeprefix("base_model.model.")
|
||||
.removeprefix("base_model.model.")
|
||||
.replace(".base_layer", ""): v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
# Remove values with adapter prefix (example: "_lora")
|
||||
state_dict = {
|
||||
k: v
|
||||
for k, v in state_dict.items()
|
||||
if unwrapped_model.prefix not in k
|
||||
}
|
||||
# When module to save, remove its prefix and discard the original module
|
||||
state_dict = {
|
||||
k.replace("modules_to_save.default.", ""): v
|
||||
for k, v in state_dict.items()
|
||||
if "original_module" not in k
|
||||
}
|
||||
else:
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
if self.accelerator.is_main_process:
|
||||
llm_model = (
|
||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
)
|
||||
llm_model.load_weights(state_dict.items())
|
||||
if is_peft_model(unwrapped_model):
|
||||
unwrapped_model.unmerge_adapter()
|
||||
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -217,13 +217,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
|
||||
"""
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
### AXOLOTL COMMUNITY LICENSE AGREEMENT
|
||||
|
||||
This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and
|
||||
any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms
|
||||
and conditions set forth in this Agreement.
|
||||
|
||||
1. Definitions
|
||||
1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.
|
||||
1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,
|
||||
which may be licensed separately by their respective authors and/or licensors.
|
||||
1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at
|
||||
https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which
|
||||
permits Plugin Integrations to integrate with the Axolotl service.
|
||||
2. Grant of License
|
||||
2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,
|
||||
publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:

- Licensee must comply with all the terms and conditions of this Agreement.
- Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial portions of the Software.

2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.

3. Restrictions

3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such third parties to fine-tune artificial intelligence models.

3.2 Licensee shall not:

- Use the Software for any illegal or unauthorized purpose.
- Reverse engineer, decompile, or disassemble the Software.
- Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.
- Use the Software in a way that could damage, disable, overburden, or impair the functionality of the Software or interfere with any third-party use of the Software.

3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.

4. Intellectual Property Rights

4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to Licensee.

5. Disclaimer of Warranty

5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

6. Termination

6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any copies in its possession.

7. Governing Law

7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California, without regards to conflicts of laws provisions thereof.

8. Entire Agreement

8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and Licensee’s continued use of the Software after any such updates shall constitute acceptance of updated terms on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be bound by the terms and conditions of this Agreement.

This Agreement was last updated on August 23, 2024.
391
src/axolotl/integrations/kd/topk_logprob/bench_kl.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
benchmark utility helper for benchmarking the KL divergence triton kernel
|
||||
"""
|
||||
import gc
|
||||
import time
|
||||
|
||||
import torch
|
||||
from torch.utils.benchmark import Timer
|
||||
|
||||
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
|
||||
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
|
||||
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
def benchmark_kl_div_loss_with_backward():
|
||||
# Test configurations
|
||||
batch_sizes = [1, 4]
|
||||
seq_lens = [64, 512, 2048, 4096, 8192]
|
||||
vocab_size = 32000
|
||||
top_k = 64
|
||||
|
||||
# Store results
|
||||
results = []
|
||||
|
||||
# Run benchmarks
|
||||
for batch_size in batch_sizes:
|
||||
for seq_len in seq_lens:
|
||||
# Generate random test data
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create tensors with gradients
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Clone student_logits for the two implementations
|
||||
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
|
||||
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
# Define functions for timing that include both forward and backward passes
|
||||
def run_reference():
|
||||
# Forward pass
|
||||
loss_ref = eager_loss(
|
||||
student_logits_ref, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
# Backward pass
|
||||
loss_ref.backward()
|
||||
|
||||
def run_triton():
|
||||
# Forward pass
|
||||
# pylint: disable=duplicate-code
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
# Backward pass
|
||||
loss_triton.backward()
|
||||
|
||||
# Benchmark reference implementation (forward + backward)
|
||||
t0 = Timer(
|
||||
stmt="run_reference()",
|
||||
globals={
|
||||
"run_reference": run_reference,
|
||||
},
|
||||
)
|
||||
# Reset gradients before timing
|
||||
student_logits_ref.grad = None
|
||||
ref_time = t0.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Benchmark Triton implementation (forward + backward)
|
||||
t1 = Timer(
|
||||
stmt="run_triton()",
|
||||
globals={
|
||||
"run_triton": run_triton,
|
||||
},
|
||||
)
|
||||
# Reset gradients before timing
|
||||
student_logits_triton.grad = None
|
||||
triton_time = t1.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Compute speedup
|
||||
speedup = ref_time / triton_time if triton_time > 0 else float("inf")
|
||||
|
||||
# Store results
|
||||
results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"seq_len": seq_len,
|
||||
"reference_time_ms": ref_time,
|
||||
"triton_time_ms": triton_time,
|
||||
"speedup": speedup,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
|
||||
print(f" Reference time (fwd+bwd): {ref_time:.2f} ms")
|
||||
print(f" Triton time (fwd+bwd): {triton_time:.2f} ms")
|
||||
print(f" Speedup: {speedup:.2f}x")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def benchmark_forward_backward_separately():
|
||||
"""
|
||||
Benchmark forward and backward passes separately to identify where the speedup comes from.
|
||||
"""
|
||||
# Test configurations
|
||||
batch_sizes = [1, 4, 8]
|
||||
seq_lens = [64, 512, 2048]
|
||||
vocab_size = 32000
|
||||
top_k = 64
|
||||
|
||||
# Store results
|
||||
detailed_results = []
|
||||
|
||||
# Run benchmarks
|
||||
for batch_size in batch_sizes:
|
||||
for seq_len in seq_lens:
|
||||
# Generate random test data
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create tensors with gradients
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Clone student_logits for the two implementations
|
||||
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
|
||||
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
# Forward-only reference
|
||||
def run_reference_forward():
|
||||
with torch.no_grad():
|
||||
return eager_loss(
|
||||
student_logits_ref,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
|
||||
# Forward-only triton
|
||||
def run_triton_forward():
|
||||
with torch.no_grad():
|
||||
return triton_loss(
|
||||
student_logits_triton,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
)
|
||||
|
||||
# Benchmark forward pass only
|
||||
|
||||
t0_fwd = Timer(
|
||||
stmt="run_reference_forward()",
|
||||
globals={
|
||||
"run_reference_forward": run_reference_forward,
|
||||
},
|
||||
)
|
||||
ref_fwd_time = t0_fwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
t1_fwd = Timer(
|
||||
stmt="run_triton_forward()",
|
||||
globals={
|
||||
"run_triton_forward": run_triton_forward,
|
||||
},
|
||||
)
|
||||
triton_fwd_time = t1_fwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Pre-compute losses for backward pass benchmarking
|
||||
loss_ref = eager_loss(
|
||||
student_logits_ref, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
|
||||
            # Backward-only reference. retain_graph=True keeps the autograd graph
            # alive so the timed closure can call backward() repeatedly on the
            # same precomputed loss.
            def run_reference_backward():
                student_logits_ref.grad = None
                loss_ref.backward(retain_graph=True)

            # Backward-only triton
            def run_triton_backward():
                student_logits_triton.grad = None
                loss_triton.backward(retain_graph=True)
|
||||
|
||||
# Benchmark backward pass only
|
||||
t0_bwd = Timer(
|
||||
stmt="run_reference_backward()",
|
||||
globals={
|
||||
"run_reference_backward": run_reference_backward,
|
||||
},
|
||||
)
|
||||
ref_bwd_time = t0_bwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
t1_bwd = Timer(
|
||||
stmt="run_triton_backward()",
|
||||
globals={
|
||||
"run_triton_backward": run_triton_backward,
|
||||
},
|
||||
)
|
||||
triton_bwd_time = t1_bwd.timeit(10).median * 1000 # Convert to ms
|
||||
|
||||
# Compute speedups
|
||||
fwd_speedup = (
|
||||
ref_fwd_time / triton_fwd_time if triton_fwd_time > 0 else float("inf")
|
||||
)
|
||||
bwd_speedup = (
|
||||
ref_bwd_time / triton_bwd_time if triton_bwd_time > 0 else float("inf")
|
||||
)
|
||||
total_ref_time = ref_fwd_time + ref_bwd_time
|
||||
total_triton_time = triton_fwd_time + triton_bwd_time
|
||||
total_speedup = (
|
||||
total_ref_time / total_triton_time
|
||||
if total_triton_time > 0
|
||||
else float("inf")
|
||||
)
|
||||
|
||||
# Store results
|
||||
detailed_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"seq_len": seq_len,
|
||||
"ref_forward_ms": ref_fwd_time,
|
||||
"triton_forward_ms": triton_fwd_time,
|
||||
"forward_speedup": fwd_speedup,
|
||||
"ref_backward_ms": ref_bwd_time,
|
||||
"triton_backward_ms": triton_bwd_time,
|
||||
"backward_speedup": bwd_speedup,
|
||||
"total_ref_ms": total_ref_time,
|
||||
"total_triton_ms": total_triton_time,
|
||||
"total_speedup": total_speedup,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
|
||||
print(
|
||||
f" Forward: Reference={ref_fwd_time:.2f}ms, Triton={triton_fwd_time:.2f}ms, Speedup={fwd_speedup:.2f}x"
|
||||
)
|
||||
print(
|
||||
f" Backward: Reference={ref_bwd_time:.2f}ms, Triton={triton_bwd_time:.2f}ms, Speedup={bwd_speedup:.2f}x"
|
||||
)
|
||||
print(
|
||||
f" Total: Reference={total_ref_time:.2f}ms, Triton={total_triton_time:.2f}ms, Speedup={total_speedup:.2f}x"
|
||||
)
|
||||
|
||||
return detailed_results
|
||||
|
||||
|
||||
def benchmark_memory_usage_with_backward():
|
||||
# Test configurations
|
||||
batch_sizes = [1, 2]
|
||||
seq_len = 8192
|
||||
vocab_size = 128000
|
||||
top_k = 64
|
||||
|
||||
# Store results
|
||||
mem_results = []
|
||||
|
||||
# Run benchmarks
|
||||
for batch_size in batch_sizes:
|
||||
# Generate random test data
|
||||
torch.manual_seed(42)
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||
)
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Clone student_logits for the implementations
|
||||
student_logits_ref = student_logits.clone().detach().requires_grad_(True)
|
||||
student_logits_triton = student_logits.clone().detach().requires_grad_(True)
|
||||
|
||||
# Measure PyTorch memory usage (forward + backward)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
loss_ref = eager_loss(
|
||||
student_logits_ref, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
loss_ref.backward()
|
||||
torch.cuda.synchronize()
|
||||
pytorch_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
|
||||
|
||||
# Measure Triton memory usage (forward + backward)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton, target_token_ids, target_logprobs, target_mask
|
||||
)
|
||||
loss_triton.backward()
|
||||
torch.cuda.synchronize()
|
||||
triton_mem = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
|
||||
|
||||
        # Measure Triton memory usage with different chunk sizes (forward + backward)
        for n_chunks in [1, 2, 4, 8]:
            student_logits_chunk = student_logits.clone().detach().requires_grad_(True)

            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
            # Pass max_seq_len so the loss actually splits the sequence into
            # n_chunks pieces; otherwise every iteration would measure the same thing.
            loss_chunk = triton_loss(
                student_logits_chunk,
                target_token_ids,
                target_logprobs,
                target_mask,
                max_seq_len=seq_len // n_chunks,
            )
            loss_chunk.backward()
            torch.cuda.synchronize()
            chunk_mem = torch.cuda.max_memory_allocated() / (1024**2)  # Convert to MB
|
||||
|
||||
mem_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"implementation": f"Triton (chunks={n_chunks})",
|
||||
"memory_mb": chunk_mem,
|
||||
}
|
||||
)
|
||||
|
||||
# Store results
|
||||
mem_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"implementation": "PyTorch",
|
||||
"memory_mb": pytorch_mem,
|
||||
}
|
||||
)
|
||||
|
||||
mem_results.append(
|
||||
{
|
||||
"batch_size": batch_size,
|
||||
"implementation": "Triton (default)",
|
||||
"memory_mb": triton_mem,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Batch size: {batch_size} (with backward pass)")
|
||||
print(f" PyTorch memory: {pytorch_mem:.2f} MB")
|
||||
print(f" Triton memory: {triton_mem:.2f} MB")
|
||||
print(f" Memory reduction: {(1 - triton_mem/pytorch_mem)*100:.2f}%")
|
||||
|
||||
return mem_results
|
||||
|
||||
|
||||
def main():
|
||||
print("Running benchmarks with forward and backward passes...")
|
||||
benchmark_kl_div_loss_with_backward()
|
||||
clean()
|
||||
|
||||
print("\nRunning detailed forward/backward benchmarks...")
|
||||
# benchmark_forward_backward_separately()
|
||||
# clean()
|
||||
|
||||
print("\nRunning memory usage benchmarks with backward passes...")
|
||||
benchmark_memory_usage_with_backward()
|
||||
clean()
|
||||
|
||||
|
||||
def clean():
|
||||
for _ in range(5):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
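For orientation, the quantity both implementations timed above are expected to compute is the masked top-k forward KL between the teacher's and the student's log-probabilities. The following is a minimal eager sketch for illustration only; shapes and names mirror the benchmark above, not the library's forward_kl.loss implementation, and the temperature/normalization handling is an assumption based on the code in this diff.

import torch


def topk_forward_kl_reference(
    student_logits: torch.Tensor,    # [B, S, V]
    target_token_ids: torch.Tensor,  # [B, S, K] teacher top-k token ids
    target_logprobs: torch.Tensor,   # [B, S, K] teacher log-probs over the top-k
    target_mask: torch.Tensor,       # [B, S, K] 1.0 where the top-k entry is valid
    kd_temperature: float = 1.0,
) -> torch.Tensor:
    # Student log-probs over the full vocab, temperature applied before softmax
    student_logprobs = torch.log_softmax(
        student_logits.float() / kd_temperature, dim=-1
    )
    # Pick out the student log-probs at the teacher's top-k token ids
    student_logprobs_topk = torch.gather(student_logprobs, -1, target_token_ids)
    teacher_probs = target_logprobs.exp()
    # Forward KL restricted to the (masked) top-k support
    token_losses = teacher_probs * (target_logprobs - student_logprobs_topk)
    kd_loss = (token_losses * target_mask).sum() * (kd_temperature**2)
    return kd_loss / target_mask.sum().clamp(min=1.0)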
@@ -1,14 +1,16 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
loss for top_k KL divergence

750
src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
Normal file
@@ -0,0 +1,750 @@
|
||||
"""
|
||||
Optimized Triton kernel for KL divergence loss between teacher and student models.
|
||||
"""
|
||||
# pylint: disable=invalid-name,unused-argument
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_logsumexp_logprobs_kernel(
|
||||
student_logits_ptr, # Input logits in original dtype
|
||||
student_logprobs_ptr, # Output logprobs (float32)
|
||||
token_ids_ptr, # Token IDs for top-k
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
K, # batch size, seq len, vocab size, top-k
|
||||
temperature,
|
||||
stride_l_b,
|
||||
stride_l_s,
|
||||
stride_l_v,
|
||||
stride_lp_b,
|
||||
stride_lp_s,
|
||||
stride_lp_k,
|
||||
stride_t_b,
|
||||
stride_t_s,
|
||||
stride_t_k,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Fused kernel that computes logsumexp and logprobs for topk tokens.
|
||||
All computations are done in float32 for numerical stability.
|
||||
"""
|
||||
# Program ID
|
||||
pid = tl.program_id(0)
|
||||
batch_idx = pid // S
|
||||
seq_idx = pid % S
|
||||
|
||||
# Bounds check
|
||||
if batch_idx >= B or seq_idx >= S:
|
||||
return
|
||||
|
||||
# Compute logsumexp over the vocabulary
|
||||
max_val = -float("inf")
|
||||
|
||||
# Phase 1: Find max value across vocabulary
|
||||
for v_offset in range(0, V, BLOCK_SIZE):
|
||||
# Create block indices and mask
|
||||
block_size = min(BLOCK_SIZE, V - v_offset)
|
||||
block_idx = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block_idx < block_size
|
||||
|
||||
# Load logits block and convert to float32 in-place
|
||||
ptrs = (
|
||||
student_logits_ptr
|
||||
+ batch_idx * stride_l_b
|
||||
+ seq_idx * stride_l_s
|
||||
+ (v_offset + block_idx) * stride_l_v
|
||||
)
|
||||
block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
|
||||
|
||||
# Apply temperature scaling if needed
|
||||
if temperature != 1.0:
|
||||
block_logits = block_logits / temperature
|
||||
|
||||
# Update max value
|
||||
block_max = tl.max(block_logits, axis=0)
|
||||
max_val = tl.maximum(max_val, block_max)
|
||||
|
||||
# Phase 2: Compute sum of exp(logits - max_val)
|
||||
sum_exp = 0.0
|
||||
|
||||
for v_offset in range(0, V, BLOCK_SIZE):
|
||||
# Create block indices and mask
|
||||
block_size = min(BLOCK_SIZE, V - v_offset)
|
||||
block_idx = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block_idx < block_size
|
||||
|
||||
# Load logits block and convert to float32 in-place
|
||||
ptrs = (
|
||||
student_logits_ptr
|
||||
+ batch_idx * stride_l_b
|
||||
+ seq_idx * stride_l_s
|
||||
+ (v_offset + block_idx) * stride_l_v
|
||||
)
|
||||
block_logits = tl.load(ptrs, mask=mask, other=-float("inf")).to(tl.float32)
|
||||
|
||||
# Apply temperature scaling if needed
|
||||
if temperature != 1.0:
|
||||
block_logits = block_logits / temperature
|
||||
|
||||
# Compute exp(logits - max_val) and add to sum
|
||||
block_exp = tl.exp(block_logits - max_val)
|
||||
sum_exp += tl.sum(block_exp * mask, axis=0)
|
||||
|
||||
# Compute final logsumexp
|
||||
logsumexp = max_val + tl.log(sum_exp)
|
||||
|
||||
# Phase 3: Compute and store logprobs for the top-k tokens
|
||||
token_ids_base = token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
|
||||
logprobs_base = (
|
||||
student_logprobs_ptr + batch_idx * stride_lp_b + seq_idx * stride_lp_s
|
||||
)
|
||||
|
||||
for k in range(K):
|
||||
# Load token ID for position k
|
||||
token_id = tl.load(token_ids_base + k * stride_t_k)
|
||||
|
||||
# Load the corresponding logit and convert to float32
|
||||
token_logit_ptr = (
|
||||
student_logits_ptr
|
||||
+ batch_idx * stride_l_b
|
||||
+ seq_idx * stride_l_s
|
||||
+ token_id * stride_l_v
|
||||
)
|
||||
token_logit = tl.load(token_logit_ptr).to(tl.float32)
|
||||
|
||||
# Apply temperature scaling if needed
|
||||
if temperature != 1.0:
|
||||
token_logit = token_logit / temperature
|
||||
|
||||
# Compute logprob directly: logit - logsumexp
|
||||
token_logprob = token_logit - logsumexp
|
||||
|
||||
# Store the result
|
||||
tl.store(logprobs_base + k * stride_lp_k, token_logprob)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def grad_softmax_kernel(
|
||||
grad_student_logits_ptr,
|
||||
target_token_ids_ptr,
|
||||
teacher_probs_ptr,
|
||||
student_probs_ptr,
|
||||
mask_ptr,
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
K, # batch size, seq len, vocab size, top-k
|
||||
scale,
|
||||
stride_gl_b,
|
||||
stride_gl_s,
|
||||
stride_gl_v,
|
||||
stride_t_b,
|
||||
stride_t_s,
|
||||
stride_t_k,
|
||||
stride_p_b,
|
||||
stride_p_s,
|
||||
stride_p_k,
|
||||
stride_sp_b,
|
||||
stride_sp_s,
|
||||
stride_sp_k,
|
||||
stride_m_b,
|
||||
stride_m_s,
|
||||
stride_m_k,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# Program ID
|
||||
pid = tl.program_id(0)
|
||||
batch_idx = pid // S
|
||||
seq_idx = pid % S
|
||||
|
||||
# Bounds check
|
||||
if batch_idx >= B or seq_idx >= S:
|
||||
return
|
||||
|
||||
# Base pointers for this (batch, seq) pair
|
||||
grad_logits_base = (
|
||||
grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
|
||||
)
|
||||
token_ids_base = (
|
||||
target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
|
||||
)
|
||||
teacher_probs_base = (
|
||||
teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
|
||||
)
|
||||
student_probs_base = (
|
||||
student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
|
||||
)
|
||||
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
|
||||
|
||||
# Process each teacher probability one at a time, computing all gradients for it
|
||||
for k in range(0, K):
|
||||
# Load data for current position k
|
||||
teacher_prob = tl.load(teacher_probs_base + k * stride_p_k)
|
||||
student_prob_k = tl.load(student_probs_base + k * stride_sp_k)
|
||||
mask_val = tl.load(mask_base + k * stride_m_k)
|
||||
|
||||
# Precompute the self-influence term (multiplied by scale)
|
||||
self_term = teacher_prob * (1.0 - student_prob_k) * scale
|
||||
|
||||
# Calculate gradient contributions for all positions j
|
||||
for j in range(0, K):
|
||||
token_id_j = tl.load(token_ids_base + j * stride_t_k)
|
||||
student_prob_j = tl.load(student_probs_base + j * stride_sp_k)
|
||||
mask_j = tl.load(mask_base + j * stride_m_k)
|
||||
|
||||
# Calculate the masking factor
|
||||
combined_mask = mask_val * mask_j
|
||||
|
||||
# Determine if this is a diagonal or off-diagonal term
|
||||
is_k_equals_j = tl.where(k == j, 1.0, 0.0)
|
||||
|
||||
# Compute the gradient contribution
|
||||
# For diagonal (k==j): -teacher_prob * (1-student_prob_k) * scale * mask
|
||||
# For off-diagonal: -(-teacher_prob * student_prob_j) * scale * mask
|
||||
grad_contribution = (
|
||||
-(
|
||||
self_term * is_k_equals_j
|
||||
- teacher_prob * student_prob_j * scale * (1.0 - is_k_equals_j)
|
||||
)
|
||||
* combined_mask
|
||||
)
|
||||
|
||||
# Atomically update the gradient for this token
|
||||
tl.atomic_add(
|
||||
grad_logits_base + token_id_j * stride_gl_v, grad_contribution
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def grad_topk_softmax_kernel(
|
||||
grad_student_logits_ptr,
|
||||
student_logits_ptr,
|
||||
target_token_ids_ptr,
|
||||
teacher_probs_ptr,
|
||||
student_probs_ptr,
|
||||
mask_ptr,
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
K, # batch size, seq len, vocab size, top-k
|
||||
scale,
|
||||
stride_gl_b,
|
||||
stride_gl_s,
|
||||
stride_gl_v,
|
||||
stride_l_b,
|
||||
stride_l_s,
|
||||
stride_l_v,
|
||||
stride_t_b,
|
||||
stride_t_s,
|
||||
stride_t_k,
|
||||
stride_p_b,
|
||||
stride_p_s,
|
||||
stride_p_k,
|
||||
stride_sp_b,
|
||||
stride_sp_s,
|
||||
stride_sp_k,
|
||||
stride_m_b,
|
||||
stride_m_s,
|
||||
stride_m_k,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# Program ID
|
||||
pid = tl.program_id(0)
|
||||
batch_idx = pid // S
|
||||
seq_idx = pid % S
|
||||
|
||||
# Bounds check
|
||||
if batch_idx >= B or seq_idx >= S:
|
||||
return
|
||||
|
||||
# Base pointers for this (batch, seq) pair
|
||||
grad_logits_base = (
|
||||
grad_student_logits_ptr + batch_idx * stride_gl_b + seq_idx * stride_gl_s
|
||||
)
|
||||
# logits_base = student_logits_ptr + batch_idx * stride_l_b + seq_idx * stride_l_s
|
||||
token_ids_base = (
|
||||
target_token_ids_ptr + batch_idx * stride_t_b + seq_idx * stride_t_s
|
||||
)
|
||||
teacher_probs_base = (
|
||||
teacher_probs_ptr + batch_idx * stride_p_b + seq_idx * stride_p_s
|
||||
)
|
||||
student_probs_base = (
|
||||
student_probs_ptr + batch_idx * stride_sp_b + seq_idx * stride_sp_s
|
||||
)
|
||||
mask_base = mask_ptr + batch_idx * stride_m_b + seq_idx * stride_m_s
|
||||
|
||||
# Load all token IDs, probs and masks for this position
|
||||
token_ids = tl.zeros([K], dtype=tl.int32)
|
||||
teacher_probs = tl.zeros([K], dtype=tl.float32)
|
||||
student_probs = tl.zeros([K], dtype=tl.float32)
|
||||
masks = tl.zeros([K], dtype=tl.float32)
|
||||
|
||||
for k in range(K):
|
||||
token_ids[k] = tl.load(token_ids_base + k * stride_t_k)
|
||||
teacher_probs[k] = tl.load(teacher_probs_base + k * stride_p_k)
|
||||
student_probs[k] = tl.load(student_probs_base + k * stride_sp_k)
|
||||
masks[k] = tl.load(mask_base + k * stride_m_k)
|
||||
|
||||
# Process gradients for all tokens in this position
|
||||
for k in range(K):
|
||||
# token_id = token_ids[k]
|
||||
mask_k = masks[k]
|
||||
|
||||
# Skip computation if mask is zero by multiplying gradient by mask
|
||||
for j in range(K):
|
||||
other_token_id = token_ids[j]
|
||||
mask_j = masks[j]
|
||||
combined_mask = mask_k * mask_j
|
||||
|
||||
# Compute gradient differently for diagonal vs off-diagonal entries
|
||||
# Using * 1.0 to convert boolean to float
|
||||
is_diagonal = tl.where(j == k, 1.0, 0.0)
|
||||
|
||||
# Self influence: gradient = teacher_prob * (1 - student_prob)
|
||||
self_grad = teacher_probs[k] * (1.0 - student_probs[k]) * is_diagonal
|
||||
|
||||
# Cross influence: gradient = -teacher_prob[k] * student_prob[j]
|
||||
cross_grad = -teacher_probs[k] * student_probs[j] * (1.0 - is_diagonal)
|
||||
|
||||
# Combined gradient scaled by mask
|
||||
grad_val = (self_grad + cross_grad) * scale * combined_mask
|
||||
|
||||
tl.atomic_add(grad_logits_base + other_token_id * stride_gl_v, grad_val)
|
||||
|
||||
|
||||
# Triton-accelerated implementation of KL divergence loss for top-k tokens
|
||||
# Chunking helper functions for handling long sequences
|
||||
def chunk_tensor(
    tensor: torch.Tensor, max_seq_len: int
) -> Tuple[list, int]:
    """Split a tensor along the sequence dimension into chunks of at most max_seq_len."""
    _, seq_len, *_ = tensor.shape

    # Always return a list of chunks plus the chunk count so callers can
    # iterate uniformly, even when no splitting is needed.
    if seq_len <= max_seq_len:
        return [tensor], 1

    num_chunks = (seq_len + max_seq_len - 1) // max_seq_len
    chunks = []

    for i in range(num_chunks):
        start_idx = i * max_seq_len
        end_idx = min((i + 1) * max_seq_len, seq_len)
        chunks.append(tensor[:, start_idx:end_idx, ...])

    return chunks, num_chunks
|
||||
|
||||
|
||||
def merge_chunks(chunks: list, original_shape: torch.Size):
|
||||
"""Merge chunks back into a single tensor with original shape."""
|
||||
return torch.cat(chunks, dim=1)
|
||||
|
||||
|
||||
# Triton-accelerated implementation of KL divergence loss for top-k tokens
|
||||
class TopKKLDivergence(torch.autograd.Function):
|
||||
"""
|
||||
Autograd function for KL divergence loss between top-k logprobs
|
||||
with support for chunking to handle very long sequences.
|
||||
"""
|
||||
|
||||
# Max sequence length to process in a single kernel launch
|
||||
# This is a tunable parameter that might need adjustment based on GPU memory
|
||||
MAX_SEQ_LEN = 8192
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
num_items_in_batch=-1,
|
||||
kd_temperature=1.0,
|
||||
top_k_before_softmax=0,
|
||||
):
|
||||
"""
|
||||
Forward pass for KL divergence loss between top-k logprobs with chunking.
|
||||
"""
|
||||
# Only convert target_logprobs to float, leave student_logits as is
|
||||
target_logprobs = target_logprobs.float()
|
||||
|
||||
# Get dimensions
|
||||
batch_size, _, vocab_size = student_logits.shape
|
||||
_, teacher_seq_len, top_k = target_token_ids.shape
|
||||
|
||||
# Slice student logits to match teacher sequence length
|
||||
student_logits_for_kd = student_logits[:, :teacher_seq_len, :]
|
||||
|
||||
# Store original values for backward pass
|
||||
ctx.original_seq_len = teacher_seq_len
|
||||
ctx.original_dtype = student_logits.dtype
|
||||
|
||||
# Apply chunking for long sequences
|
||||
if teacher_seq_len > TopKKLDivergence.MAX_SEQ_LEN:
|
||||
# Chunk the inputs
|
||||
student_logits_chunks, num_chunks = chunk_tensor(
|
||||
student_logits_for_kd, TopKKLDivergence.MAX_SEQ_LEN
|
||||
)
|
||||
target_token_ids_chunks, _ = chunk_tensor(
|
||||
target_token_ids, TopKKLDivergence.MAX_SEQ_LEN
|
||||
)
|
||||
# target_logprobs_chunks, _ = chunk_tensor(
|
||||
# target_logprobs, TopKKLDivergence.MAX_SEQ_LEN
|
||||
# )
|
||||
# target_mask_chunks, _ = chunk_tensor(
|
||||
# target_mask, TopKKLDivergence.MAX_SEQ_LEN
|
||||
# )
|
||||
|
||||
# Process each chunk
|
||||
student_logprobs_chunks = []
|
||||
student_probs_chunks = []
|
||||
|
||||
for i in range(num_chunks):
|
||||
chunk_logits = student_logits_chunks[i]
|
||||
chunk_token_ids = target_token_ids_chunks[i]
|
||||
chunk_seq_len = chunk_logits.shape[1]
|
||||
|
||||
if top_k_before_softmax:
|
||||
# Apply temperature to student logits
|
||||
if kd_temperature != 1.0:
|
||||
chunk_logits = chunk_logits / kd_temperature
|
||||
|
||||
# Gather student logits for top-k tokens
|
||||
chunk_logits_topk = torch.gather(
|
||||
chunk_logits, dim=-1, index=chunk_token_ids
|
||||
)
|
||||
|
||||
# Compute softmax over gathered logits
|
||||
chunk_logprobs_topk = torch.log_softmax(chunk_logits_topk, dim=-1)
|
||||
chunk_probs_topk = torch.exp(chunk_logprobs_topk)
|
||||
else:
|
||||
# Allocate output tensor for logprobs directly (always in float32)
|
||||
chunk_logprobs_topk = torch.empty(
|
||||
(batch_size, chunk_seq_len, top_k),
|
||||
dtype=torch.float32,
|
||||
device=chunk_logits.device,
|
||||
)
|
||||
|
||||
# Launch fused kernel directly
|
||||
grid = (batch_size * chunk_seq_len,)
|
||||
fused_logsumexp_logprobs_kernel[grid](
|
||||
chunk_logits.contiguous(),
|
||||
chunk_logprobs_topk,
|
||||
chunk_token_ids.contiguous(),
|
||||
batch_size,
|
||||
chunk_seq_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
kd_temperature,
|
||||
chunk_logits.stride(0),
|
||||
chunk_logits.stride(1),
|
||||
chunk_logits.stride(2),
|
||||
chunk_logprobs_topk.stride(0),
|
||||
chunk_logprobs_topk.stride(1),
|
||||
chunk_logprobs_topk.stride(2),
|
||||
chunk_token_ids.stride(0),
|
||||
chunk_token_ids.stride(1),
|
||||
chunk_token_ids.stride(2),
|
||||
min(1024, triton.next_power_of_2(vocab_size)),
|
||||
)
|
||||
|
||||
# Calculate probs from logprobs
|
||||
chunk_probs_topk = torch.exp(chunk_logprobs_topk)
|
||||
|
||||
# Store results
|
||||
student_logprobs_chunks.append(chunk_logprobs_topk)
|
||||
student_probs_chunks.append(chunk_probs_topk)
|
||||
|
||||
# Merge results
|
||||
student_logprobs_topk = torch.cat(student_logprobs_chunks, dim=1)
|
||||
student_probs_topk = torch.cat(student_probs_chunks, dim=1)
|
||||
|
||||
# Save chunking info for backward pass
|
||||
ctx.used_chunking = True
|
||||
ctx.num_chunks = num_chunks
|
||||
|
||||
else:
|
||||
# Original code path for shorter sequences
|
||||
if top_k_before_softmax:
|
||||
# Apply temperature to student logits
|
||||
if kd_temperature != 1.0:
|
||||
student_logits_for_kd = student_logits_for_kd / kd_temperature
|
||||
|
||||
# Gather student logits for top-k tokens
|
||||
student_logits_topk = torch.gather(
|
||||
student_logits_for_kd, dim=-1, index=target_token_ids
|
||||
)
|
||||
|
||||
# Compute softmax over gathered logits
|
||||
student_logprobs_topk = torch.log_softmax(student_logits_topk, dim=-1)
|
||||
student_probs_topk = torch.exp(student_logprobs_topk)
|
||||
else:
|
||||
# Allocate output tensor for logprobs directly (always in float32)
|
||||
student_logprobs_topk = torch.empty(
|
||||
(batch_size, teacher_seq_len, top_k),
|
||||
dtype=torch.float32,
|
||||
device=student_logits.device,
|
||||
)
|
||||
|
||||
# Launch fused kernel directly
|
||||
grid = (batch_size * teacher_seq_len,)
|
||||
fused_logsumexp_logprobs_kernel[grid](
|
||||
student_logits_for_kd.contiguous(),
|
||||
student_logprobs_topk,
|
||||
target_token_ids.contiguous(),
|
||||
batch_size,
|
||||
teacher_seq_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
kd_temperature,
|
||||
student_logits_for_kd.stride(0),
|
||||
student_logits_for_kd.stride(1),
|
||||
student_logits_for_kd.stride(2),
|
||||
student_logprobs_topk.stride(0),
|
||||
student_logprobs_topk.stride(1),
|
||||
student_logprobs_topk.stride(2),
|
||||
target_token_ids.stride(0),
|
||||
target_token_ids.stride(1),
|
||||
target_token_ids.stride(2),
|
||||
min(1024, triton.next_power_of_2(vocab_size)),
|
||||
)
|
||||
|
||||
# Calculate probs from logprobs
|
||||
student_probs_topk = torch.exp(student_logprobs_topk)
|
||||
|
||||
# No chunking used
|
||||
ctx.used_chunking = False
|
||||
|
||||
# Save tensors for backward pass
|
||||
ctx.save_for_backward(
|
||||
student_logits_for_kd,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
student_probs_topk,
|
||||
)
|
||||
ctx.kd_temperature = kd_temperature
|
||||
ctx.top_k_before_softmax = top_k_before_softmax
|
||||
ctx.num_items_in_batch = num_items_in_batch
|
||||
|
||||
# Convert mask to boolean
|
||||
valid_mask = target_mask.bool()
|
||||
|
||||
# Extract valid tokens only - this is where the error was happening
|
||||
# Use cloned contiguous tensors and explicit indexing for safety
|
||||
student_logprobs_flat = student_logprobs_topk.view(-1, top_k)
|
||||
target_logprobs_flat = target_logprobs.view(-1, top_k)
|
||||
valid_mask_flat = valid_mask.view(-1, top_k)
|
||||
|
||||
# Gather valid indices explicitly to avoid illegal memory access
|
||||
valid_indices = torch.nonzero(valid_mask_flat.view(-1)).squeeze(-1)
|
||||
student_logprobs_valid = torch.index_select(
|
||||
student_logprobs_flat.view(-1), 0, valid_indices
|
||||
)
|
||||
target_logprobs_valid = torch.index_select(
|
||||
target_logprobs_flat.view(-1), 0, valid_indices
|
||||
)
|
||||
|
||||
# Convert teacher logprobs to probabilities
|
||||
teacher_probs_valid = torch.exp(target_logprobs_valid)
|
||||
|
||||
# Compute KL divergence loss
|
||||
token_losses = teacher_probs_valid * (
|
||||
target_logprobs_valid - student_logprobs_valid
|
||||
)
|
||||
kd_loss = token_losses.sum()
|
||||
|
||||
# Apply temperature scaling
|
||||
# pylint: disable=duplicate-code
|
||||
if kd_temperature != 1.0:
|
||||
kd_loss = kd_loss * (kd_temperature**2)
|
||||
|
||||
# Normalize by number of items or valid tokens
|
||||
if num_items_in_batch > 0:
|
||||
kd_loss = kd_loss / float(num_items_in_batch)
|
||||
else:
|
||||
num_valid_tokens = valid_indices.numel()
|
||||
kd_loss = kd_loss / float(num_valid_tokens if num_valid_tokens > 0 else 1)
|
||||
|
||||
return kd_loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
"""
|
||||
Optimized backward pass for KL divergence loss with proper dtype handling and chunking.
|
||||
"""
|
||||
(
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
student_probs,
|
||||
) = ctx.saved_tensors
|
||||
kd_temperature = ctx.kd_temperature
|
||||
num_items_in_batch = ctx.num_items_in_batch
|
||||
original_dtype = ctx.original_dtype
|
||||
|
||||
# Get dimensions
|
||||
batch_size, _, vocab_size = student_logits.shape
|
||||
_, teacher_seq_len, top_k = target_token_ids.shape
|
||||
|
||||
# Initialize gradient tensor in float32 to support atomic operations
|
||||
grad_student_logits = torch.zeros_like(student_logits, dtype=torch.float32)
|
||||
|
||||
# Compute scaling factor
|
||||
scale = grad_output.item()
|
||||
|
||||
# Apply temperature scaling from forward pass
|
||||
if kd_temperature != 1.0:
|
||||
scale = scale * (kd_temperature**2)
|
||||
|
||||
# Normalize by number of items or valid tokens
|
||||
if num_items_in_batch > 0:
|
||||
scale = scale / float(num_items_in_batch)
|
||||
else:
|
||||
scale = scale / float(target_mask.sum().item())
|
||||
|
||||
# Apply chain rule for temperature scaling (1/temperature)
|
||||
if kd_temperature != 1.0:
|
||||
scale = scale / kd_temperature
|
||||
|
||||
# Convert teacher logprobs to probabilities
|
||||
teacher_probs = torch.exp(target_logprobs)
|
||||
|
||||
# Use chunking for the backward pass if used in forward
|
||||
if getattr(ctx, "used_chunking", False):
|
||||
num_chunks = ctx.num_chunks
|
||||
max_seq = TopKKLDivergence.MAX_SEQ_LEN
|
||||
|
||||
# Process each chunk
|
||||
for i in range(num_chunks):
|
||||
start_idx = i * max_seq
|
||||
end_idx = min((i + 1) * max_seq, teacher_seq_len)
|
||||
chunk_len = end_idx - start_idx
|
||||
|
||||
# Get chunk slices
|
||||
# student_logits_chunk = student_logits[:, start_idx:end_idx, :]
|
||||
target_token_ids_chunk = target_token_ids[:, start_idx:end_idx, :]
|
||||
teacher_probs_chunk = teacher_probs[:, start_idx:end_idx, :]
|
||||
student_probs_chunk = student_probs[:, start_idx:end_idx, :]
|
||||
target_mask_chunk = target_mask[:, start_idx:end_idx, :]
|
||||
grad_student_logits_chunk = grad_student_logits[:, start_idx:end_idx, :]
|
||||
|
||||
# Launch gradient computation kernel for this chunk
|
||||
grid = (batch_size * chunk_len,)
|
||||
grad_softmax_kernel[grid](
|
||||
grad_student_logits_chunk.contiguous(),
|
||||
target_token_ids_chunk.contiguous(),
|
||||
teacher_probs_chunk.contiguous(),
|
||||
student_probs_chunk.contiguous(),
|
||||
target_mask_chunk.contiguous(),
|
||||
batch_size,
|
||||
chunk_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
scale,
|
||||
grad_student_logits_chunk.stride(0),
|
||||
grad_student_logits_chunk.stride(1),
|
||||
grad_student_logits_chunk.stride(2),
|
||||
target_token_ids_chunk.stride(0),
|
||||
target_token_ids_chunk.stride(1),
|
||||
target_token_ids_chunk.stride(2),
|
||||
teacher_probs_chunk.stride(0),
|
||||
teacher_probs_chunk.stride(1),
|
||||
teacher_probs_chunk.stride(2),
|
||||
student_probs_chunk.stride(0),
|
||||
student_probs_chunk.stride(1),
|
||||
student_probs_chunk.stride(2),
|
||||
target_mask_chunk.stride(0),
|
||||
target_mask_chunk.stride(1),
|
||||
target_mask_chunk.stride(2),
|
||||
min(1024, triton.next_power_of_2(top_k)),
|
||||
)
|
||||
|
||||
# Update the gradient tensor (already in-place)
|
||||
else:
|
||||
# Original code path for shorter sequences
|
||||
# Launch gradient computation kernel
|
||||
grid = (batch_size * teacher_seq_len,)
|
||||
grad_softmax_kernel[grid](
|
||||
grad_student_logits.contiguous(),
|
||||
target_token_ids.contiguous(),
|
||||
teacher_probs.contiguous(),
|
||||
student_probs.contiguous(),
|
||||
target_mask.contiguous(),
|
||||
batch_size,
|
||||
teacher_seq_len,
|
||||
vocab_size,
|
||||
top_k,
|
||||
scale,
|
||||
grad_student_logits.stride(0),
|
||||
grad_student_logits.stride(1),
|
||||
grad_student_logits.stride(2),
|
||||
target_token_ids.stride(0),
|
||||
target_token_ids.stride(1),
|
||||
target_token_ids.stride(2),
|
||||
teacher_probs.stride(0),
|
||||
teacher_probs.stride(1),
|
||||
teacher_probs.stride(2),
|
||||
student_probs.stride(0),
|
||||
student_probs.stride(1),
|
||||
student_probs.stride(2),
|
||||
target_mask.stride(0),
|
||||
target_mask.stride(1),
|
||||
target_mask.stride(2),
|
||||
min(1024, triton.next_power_of_2(top_k)),
|
||||
)
|
||||
|
||||
# Convert gradient back to original dtype if needed
|
||||
if original_dtype != torch.float32:
|
||||
grad_student_logits = grad_student_logits.to(original_dtype)
|
||||
|
||||
# Return gradients for student_logits and None for other inputs
|
||||
return grad_student_logits, None, None, None, None, None, None
|
||||
|
||||
|
||||
# Wrapper function for chunked computation
|
||||
def loss(
|
||||
student_logits: torch.Tensor,
|
||||
target_token_ids: torch.Tensor,
|
||||
target_logprobs: torch.Tensor,
|
||||
target_mask: torch.Tensor,
|
||||
num_items_in_batch: int = -1,
|
||||
kd_temperature: float = 1.0,
|
||||
top_k_before_softmax: int = 0,
|
||||
max_seq_len: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Triton-accelerated Memory-efficient KL divergence loss computation for knowledge distillation
|
||||
with support for very long sequences.
|
||||
|
||||
Args:
|
||||
student_logits: Student logits [B, seq_len, vocab_size]
|
||||
target_token_ids: Teacher token IDs [B, seq_len, top_k]
|
||||
target_logprobs: Teacher logprobs [B, seq_len, top_k]
|
||||
target_mask: Token mask [B, seq_len, top_k]
|
||||
num_items_in_batch: Number of items for normalization (-1 for auto)
|
||||
kd_temperature: Temperature for KD
|
||||
top_k_before_softmax: Flag for softmax application order
|
||||
max_seq_len: Override default MAX_SEQ_LEN value for chunking
|
||||
"""
|
||||
# Allow overriding the max sequence length
|
||||
if max_seq_len is not None and max_seq_len > 0:
|
||||
TopKKLDivergence.MAX_SEQ_LEN = max_seq_len
|
||||
|
||||
total_loss = TopKKLDivergence.apply(
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
-1 if num_items_in_batch <= 0 else num_items_in_batch,
|
||||
kd_temperature,
|
||||
top_k_before_softmax,
|
||||
)
|
||||
|
||||
return total_loss
|
||||
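A minimal usage sketch of the loss wrapper defined above, assuming a CUDA device and tensors shaped as in bench_kl.py; the tensor values and the chosen max_seq_len are illustrative only.

import torch

from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss

batch_size, seq_len, vocab_size, top_k = 1, 2048, 32000, 64
student_logits = torch.randn(
    batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
)
target_token_ids = torch.randint(
    0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
)
target_logprobs = torch.log_softmax(
    torch.randn(batch_size, seq_len, top_k, device="cuda"), dim=-1
)
target_mask = torch.ones(batch_size, seq_len, top_k, device="cuda")

# max_seq_len overrides TopKKLDivergence.MAX_SEQ_LEN: sequences longer than this
# are processed in chunks to bound peak memory.
kd_loss = triton_loss(
    student_logits,
    target_token_ids,
    target_logprobs,
    target_mask,
    kd_temperature=1.0,
    max_seq_len=4096,
)
kd_loss.backward()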
67
src/axolotl/integrations/kd/topk_logprob/logsumexp.py
Normal file
@@ -0,0 +1,67 @@
"""
Optimized Triton kernels for logsumexp
"""
# pylint: disable=invalid-name,unused-argument
import triton
import triton.language as tl


# Helper function for computing logsumexp
@triton.jit
def logsumexp_kernel(
    logits_ptr,
    output_ptr,
    B,
    S,
    V,  # batch size, seq len, vocab size
    stride_b,
    stride_s,
    stride_v,
    out_stride_b,
    out_stride_s,
    BLOCK_SIZE: tl.constexpr,
):
    # Program ID
    # pylint: disable=duplicate-code
    pid = tl.program_id(0)
    batch_idx = pid // S
    seq_idx = pid % S

    # Bounds check
    if batch_idx >= B or seq_idx >= S:
        return

    # Pointers
    logits_base = logits_ptr + batch_idx * stride_b + seq_idx * stride_s

    # Find maximum for numerical stability
    max_val = -float("inf")
    for v_offset in range(0, V, BLOCK_SIZE):
        v_size = min(BLOCK_SIZE, V - v_offset)
        mask = tl.arange(0, BLOCK_SIZE) < v_size

        logits_block = tl.load(
            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
            mask=mask,
            other=-float("inf"),
        )
        max_val = tl.maximum(max_val, tl.max(logits_block, axis=0))

    # Compute sum of exp(logit - max_val)
    sum_exp = 0.0
    for v_offset in range(0, V, BLOCK_SIZE):
        v_size = min(BLOCK_SIZE, V - v_offset)
        mask = tl.arange(0, BLOCK_SIZE) < v_size

        logits_block = tl.load(
            logits_base + (v_offset + tl.arange(0, BLOCK_SIZE)) * stride_v,
            mask=mask,
            other=-float("inf"),
        )
        sum_exp += tl.sum(tl.exp(logits_block - max_val), axis=0)

    # Compute logsumexp
    result = max_val + tl.log(sum_exp)

    # Store result
    tl.store(output_ptr + batch_idx * out_stride_b + seq_idx * out_stride_s, result)
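A hedged sketch of how logsumexp_kernel might be launched from PyTorch, following the same one-program-per-(batch, seq) grid and BLOCK_SIZE choice used for the fused kernel earlier in this diff; the wrapper function below is not part of the module and only illustrates the expected arguments and strides.

import torch
import triton

from axolotl.integrations.kd.topk_logprob.logsumexp import logsumexp_kernel


def logsumexp_over_vocab(logits: torch.Tensor) -> torch.Tensor:
    """logits: [B, S, V] CUDA tensor -> logsumexp over the vocab dim, shape [B, S]."""
    B, S, V = logits.shape
    logits = logits.float().contiguous()  # the kernel does not cast internally
    out = torch.empty((B, S), dtype=torch.float32, device=logits.device)
    grid = (B * S,)  # one program per (batch, seq) position, as in the fused kernel
    logsumexp_kernel[grid](
        logits,
        out,
        B,
        S,
        V,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        out.stride(0),
        out.stride(1),
        min(1024, triton.next_power_of_2(V)),
    )
    return out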
@@ -20,6 +20,7 @@ from axolotl.core.trainers.base import AxolotlTrainer

from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
from .topk_logprob.forward_kl_triton import loss as topk_kd_loss_triton


class AxolotlKDTrainer(AxolotlTrainer):
@@ -85,7 +86,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
                num_items_in_batch=num_items_in_batch,
            )
        else:
            loss_kd = topk_kd_loss(
            loss_fn = (
                topk_kd_loss
                if self.args.kd_top_k_before_softmax
                else topk_kd_loss_triton
            )
            loss_kd = loss_fn(
                shift_logits,
                target_token_ids_for_loss,
                target_logprobs_for_loss,
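For reference, the objective the selected loss functions minimize, and the gradient the Triton backward kernel accumulates (restricted to the teacher's top-k token positions, per the comments in grad_softmax_kernel), can be written as:

\mathcal{L}_{\mathrm{KD}}
  \;=\; \frac{T^{2}}{N} \sum_{\text{valid } (b,s,k)} p^{T}_{k}\,\bigl(\log p^{T}_{k} - \log p^{S}_{k}\bigr),
\qquad
\frac{\partial \mathcal{L}_{\mathrm{KD}}}{\partial z^{S}_{j}}
  \;\propto\; -\sum_{k} p^{T}_{k}\,\bigl(\delta_{kj} - p^{S}_{j}\bigr)
  \;=\; p^{S}_{j}\sum_{k} p^{T}_{k} \;-\; p^{T}_{j},

where p^T and p^S are the teacher and student probabilities at the top-k token ids, z^S_j is the student logit for token j, T is the KD temperature, and N is either num_items_in_batch or the number of valid top-k entries. Restricting the gradient to the top-k positions is how the kernels above approximate the full-vocabulary softmax gradient; this summary is derived from the kernel comments in this diff, not from separate documentation.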
File diff suppressed because it is too large (7 large file diffs omitted)
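The new JSON file below maps module names to a signal-to-noise ratio ("snr"). Assuming it is meant to be consumed by tooling that ranks modules (for example, to decide which projections to train), a hedged sketch of loading and ranking it follows; the file path and the cutoff of 8 are illustrative only.

import json

with open("snr_results.json", "r", encoding="utf-8") as f:  # illustrative path
    snr = json.load(f)  # Python's json parser accepts the bare Infinity literals

# Keep only modules with a finite SNR (layernorms/embeddings are recorded as Infinity)
finite = {
    name: info["snr"] for name, info in snr.items() if info["snr"] != float("inf")
}
top_modules = sorted(finite, key=finite.get, reverse=True)[:8]
print(top_modules)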
@@ -0,0 +1,590 @@
|
||||
{
|
||||
"model.layers.0.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.1.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.2.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.3.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.4.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.5.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.6.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.7.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.8.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.9.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.10.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.11.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.12.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.13.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.14.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.15.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"lm_head": {
|
||||
"snr": Infinity,
|
||||
"type": "lm_head"
|
||||
},
|
||||
"model.layers.0.mlp.down_proj": {
|
||||
"snr": 70.0594253540039,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.1.mlp.down_proj": {
|
||||
"snr": 11.135851860046387,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.2.mlp.down_proj": {
|
||||
"snr": 7.035482883453369,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.3.mlp.down_proj": {
|
||||
"snr": 6.422532081604004,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.4.mlp.down_proj": {
|
||||
"snr": 5.748020172119141,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.5.mlp.down_proj": {
|
||||
"snr": 3.885556697845459,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.6.mlp.down_proj": {
|
||||
"snr": 3.4336745738983154,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.7.mlp.down_proj": {
|
||||
"snr": 2.791595935821533,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.8.mlp.down_proj": {
|
||||
"snr": 5.36277961730957,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.9.mlp.down_proj": {
|
||||
"snr": 4.459208011627197,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.10.mlp.down_proj": {
|
||||
"snr": 6.272170066833496,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.11.mlp.down_proj": {
|
||||
"snr": 5.264761447906494,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.12.mlp.down_proj": {
|
||||
"snr": 4.324735641479492,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.13.mlp.down_proj": {
|
||||
"snr": 3.878648042678833,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.14.mlp.down_proj": {
|
||||
"snr": 2.9773054122924805,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.15.mlp.down_proj": {
|
||||
"snr": 4.471445560455322,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.0.mlp.gate_proj": {
|
||||
"snr": 25.227100372314453,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.1.mlp.gate_proj": {
|
||||
"snr": 6.58299446105957,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.2.mlp.gate_proj": {
|
||||
"snr": 3.4688243865966797,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.3.mlp.gate_proj": {
|
||||
"snr": 1.555246114730835,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.4.mlp.gate_proj": {
|
||||
"snr": 0.7770601511001587,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.5.mlp.gate_proj": {
|
||||
"snr": 0.6239906549453735,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.6.mlp.gate_proj": {
|
||||
"snr": 0.6440379023551941,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.7.mlp.gate_proj": {
|
||||
"snr": 0.5120116472244263,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.8.mlp.gate_proj": {
|
||||
"snr": 0.6544050574302673,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.9.mlp.gate_proj": {
|
||||
"snr": 0.5381016731262207,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.10.mlp.gate_proj": {
|
||||
"snr": 0.622873842716217,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.11.mlp.gate_proj": {
|
||||
"snr": 0.9361700415611267,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.12.mlp.gate_proj": {
|
||||
"snr": 1.475605845451355,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.13.mlp.gate_proj": {
|
||||
"snr": 1.608325719833374,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.14.mlp.gate_proj": {
|
||||
"snr": 1.0720024108886719,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.15.mlp.gate_proj": {
|
||||
"snr": 0.7111338973045349,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.0.mlp.up_proj": {
|
||||
"snr": 28.431896209716797,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.1.mlp.up_proj": {
|
||||
"snr": 15.546019554138184,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.2.mlp.up_proj": {
|
||||
"snr": 23.048023223876953,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.3.mlp.up_proj": {
|
||||
"snr": 25.790977478027344,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.4.mlp.up_proj": {
|
||||
"snr": 18.552549362182617,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.5.mlp.up_proj": {
|
||||
"snr": 8.85106372833252,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.6.mlp.up_proj": {
|
||||
"snr": 10.653799057006836,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.7.mlp.up_proj": {
|
||||
"snr": 7.365357875823975,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.8.mlp.up_proj": {
|
||||
"snr": 11.98373794555664,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.9.mlp.up_proj": {
|
||||
"snr": 8.04493236541748,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.10.mlp.up_proj": {
|
||||
"snr": 8.523039817810059,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.11.mlp.up_proj": {
|
||||
"snr": 5.381742477416992,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.12.mlp.up_proj": {
|
||||
"snr": 3.9845118522644043,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.13.mlp.up_proj": {
|
||||
"snr": 3.4893221855163574,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.14.mlp.up_proj": {
|
||||
"snr": 1.764201045036316,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.15.mlp.up_proj": {
|
||||
"snr": 0.9730708599090576,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.embed_tokens": {
|
||||
"snr": Infinity,
|
||||
"type": "model.embed_tokens"
|
||||
},
|
||||
"model.norm": {
|
||||
"snr": Infinity,
|
||||
"type": "model.norm"
|
||||
},
|
||||
"model.layers.0.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.1.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.2.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.3.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.4.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.5.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.6.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.7.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.8.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.9.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.10.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.11.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.12.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.13.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.14.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.15.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.0.self_attn.k_proj": {
|
||||
"snr": 0.11727584153413773,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.1.self_attn.k_proj": {
|
||||
"snr": 0.24786807596683502,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.2.self_attn.k_proj": {
|
||||
"snr": 0.36378130316734314,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.3.self_attn.k_proj": {
|
||||
"snr": 0.2983120381832123,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.4.self_attn.k_proj": {
|
||||
"snr": 0.33789733052253723,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.5.self_attn.k_proj": {
|
||||
"snr": 0.29155924916267395,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.6.self_attn.k_proj": {
|
||||
"snr": 0.2537297010421753,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.7.self_attn.k_proj": {
|
||||
"snr": 0.28204113245010376,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.8.self_attn.k_proj": {
|
||||
"snr": 0.2776711583137512,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.9.self_attn.k_proj": {
|
||||
"snr": 0.2927376627922058,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.10.self_attn.k_proj": {
|
||||
"snr": 0.31486213207244873,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.11.self_attn.k_proj": {
|
||||
"snr": 0.32363659143447876,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.12.self_attn.k_proj": {
|
||||
"snr": 0.31382912397384644,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.13.self_attn.k_proj": {
|
||||
"snr": 0.4635234773159027,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.14.self_attn.k_proj": {
|
||||
"snr": 0.25379249453544617,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.15.self_attn.k_proj": {
|
||||
"snr": 0.2628238797187805,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.0.self_attn.o_proj": {
|
||||
"snr": 0.27602291107177734,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.1.self_attn.o_proj": {
|
||||
"snr": 0.2149604707956314,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.2.self_attn.o_proj": {
|
||||
"snr": 0.2540294826030731,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.3.self_attn.o_proj": {
|
||||
"snr": 0.27978822588920593,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.4.self_attn.o_proj": {
|
||||
"snr": 0.3121289908885956,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.5.self_attn.o_proj": {
|
||||
"snr": 0.35037684440612793,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.6.self_attn.o_proj": {
|
||||
"snr": 0.366205096244812,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.7.self_attn.o_proj": {
|
||||
"snr": 0.3692712187767029,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.8.self_attn.o_proj": {
|
||||
"snr": 0.3301038146018982,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.9.self_attn.o_proj": {
|
||||
"snr": 0.3003396987915039,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.10.self_attn.o_proj": {
|
||||
"snr": 0.30804169178009033,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.11.self_attn.o_proj": {
|
||||
"snr": 0.28501132130622864,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.12.self_attn.o_proj": {
|
||||
"snr": 0.2171541005373001,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.13.self_attn.o_proj": {
|
||||
"snr": 0.19183959066867828,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.14.self_attn.o_proj": {
|
||||
"snr": 0.19215913116931915,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.15.self_attn.o_proj": {
|
||||
"snr": 0.25486502051353455,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.0.self_attn.q_proj": {
|
||||
"snr": 0.03850084915757179,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.1.self_attn.q_proj": {
|
||||
"snr": 0.0713055431842804,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.2.self_attn.q_proj": {
|
||||
"snr": 0.07948919385671616,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.3.self_attn.q_proj": {
|
||||
"snr": 0.08047746121883392,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.4.self_attn.q_proj": {
|
||||
"snr": 0.0852593332529068,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.5.self_attn.q_proj": {
|
||||
"snr": 0.09794823825359344,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.6.self_attn.q_proj": {
|
||||
"snr": 0.09627152234315872,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.7.self_attn.q_proj": {
|
||||
"snr": 0.11065381020307541,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.8.self_attn.q_proj": {
|
||||
"snr": 0.12031875550746918,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.9.self_attn.q_proj": {
|
||||
"snr": 0.09804573655128479,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.10.self_attn.q_proj": {
|
||||
"snr": 0.10897502303123474,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.11.self_attn.q_proj": {
|
||||
"snr": 0.09267337620258331,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.12.self_attn.q_proj": {
|
||||
"snr": 0.08803492039442062,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.13.self_attn.q_proj": {
|
||||
"snr": 0.0902542844414711,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.14.self_attn.q_proj": {
|
||||
"snr": 0.10154066979885101,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.15.self_attn.q_proj": {
|
||||
"snr": 0.09083802253007889,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.0.self_attn.v_proj": {
|
||||
"snr": 2.842210054397583,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.1.self_attn.v_proj": {
|
||||
"snr": 10.59461498260498,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.2.self_attn.v_proj": {
|
||||
"snr": 8.993025779724121,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.3.self_attn.v_proj": {
|
||||
"snr": 62.567787170410156,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.4.self_attn.v_proj": {
|
||||
"snr": 23.80082893371582,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.5.self_attn.v_proj": {
|
||||
"snr": 7.957369804382324,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.6.self_attn.v_proj": {
|
||||
"snr": 12.01815414428711,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.7.self_attn.v_proj": {
|
||||
"snr": 5.095500469207764,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.8.self_attn.v_proj": {
|
||||
"snr": 11.719332695007324,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.9.self_attn.v_proj": {
|
||||
"snr": 555.0869750976562,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.10.self_attn.v_proj": {
|
||||
"snr": 22.95538330078125,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.11.self_attn.v_proj": {
|
||||
"snr": 30.042158126831055,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.12.self_attn.v_proj": {
|
||||
"snr": 9.577271461486816,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.13.self_attn.v_proj": {
|
||||
"snr": 18.176361083984375,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.14.self_attn.v_proj": {
|
||||
"snr": 1.5695856809616089,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.15.self_attn.v_proj": {
|
||||
"snr": 2.7235565185546875,
|
||||
"type": "self_attn.v_proj"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,590 @@
{
  [590-line per-module SNR report for a 16-layer model, one {"snr": <float>, "type": <module>} entry per module: "model.embed_tokens", "model.norm", "lm_head", and every input_layernorm / post_attention_layernorm entry report "snr": Infinity; the finite entries span mlp.down_proj ~2.5-57, mlp.gate_proj ~0.48-22, mlp.up_proj ~0.87-26, self_attn.k_proj ~0.12-0.47, self_attn.o_proj ~0.18-0.37, self_attn.q_proj ~0.04-0.12, and self_attn.v_proj ~1.5-553 across layers 0-15]
}
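Per-module SNR reports like the two above are the kind of signal typically used to decide which modules are worth adapting or unfreezing, for example by ranking modules by SNR and targeting the top fraction. The sketch below is illustrative only and not part of this changeset; the file name `snr_results.json` and the 25% cutoff are assumptions.

import json
import math

# Hypothetical path; the actual report location depends on how the analysis is run.
with open("snr_results.json", "r", encoding="utf-8") as f:
    snr_report = json.load(f)

# Drop modules with infinite SNR (layernorms, embeddings, lm_head in the reports above).
finite = {
    name: entry["snr"]
    for name, entry in snr_report.items()
    if not math.isinf(entry["snr"])
}

# Rank modules by SNR; a caller might then target only the top fraction.
ranked = sorted(finite.items(), key=lambda kv: kv[1], reverse=True)
top_quarter = ranked[: max(1, len(ranked) // 4)]
print([name for name, _ in top_quarter[:5]])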
File diff suppressed because it is too large
File diff suppressed because it is too large
0
src/axolotl/kernels/__init__.py
Normal file
159
src/axolotl/kernels/geglu.py
Normal file
@@ -0,0 +1,159 @@
"""
Module for definition of GEGLU Triton kernels.

See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).

Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code

import torch
import triton
import triton.language as tl

SQRT_2_PI: tl.constexpr = 0.7978845608028654  # sqrt(2/π)


@triton.jit
def _geglu_fwd_kernel(
    gate_ptr,
    up_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """GEGLU forward kernel.

    Args:
        gate_ptr: Pointer to gate tensor [*, hidden_dim].
        up_ptr: Pointer to up-projection tensor [*, hidden_dim].
        out_ptr: Pointer to output tensor [*, hidden_dim].
        n_elements: Total number of elements in the input tensors.
        BLOCK_SIZE: Size of thread blocks for parallel computation.
    """
    block_idx = tl.program_id(0)
    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0)

    # Compute activation in fp32 then convert back
    gelu_gate = 0.5 * gate * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
    gelu_gate = gelu_gate.to(up.dtype)
    result = gelu_gate * up

    tl.store(out_ptr + offsets, result, mask=mask)


def geglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """GEGLU forward pass.

    Args:
        gate: Input gate tensor of shape [batch, seq_len, hidden_dim].
        up: Up-projection tensor of shape [batch, seq_len, hidden_dim].

    Returns:
        torch.Tensor: Output tensor of shape [batch, seq_len, hidden_dim].
    """
    batch, seq_len, hidden_dim = gate.shape
    n_elements = gate.numel()
    out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda")

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)  # noqa: E731
    _geglu_fwd_kernel[grid](
        gate_ptr=gate,
        up_ptr=up,
        out_ptr=out,
        n_elements=n_elements,
        BLOCK_SIZE=1024,
    )
    return out


@triton.jit
def _geglu_bwd_kernel(
    grad_out_ptr,
    gate_ptr,
    up_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """GEGLU backward kernel. Stores gradient results in-place.

    Args:
        grad_out_ptr: Pointer to gradient output tensor [*, hidden_dim].
        gate_ptr: Pointer to gate tensor [*, hidden_dim].
        up_ptr: Pointer to up-projection tensor [*, hidden_dim].
        n_elements: Total number of elements in the input tensors.
        BLOCK_SIZE: Size of thread blocks for parallel computation.

    Note:
        After kernel execution, tensors are modified in-place:
        - `grad_out_ptr` contains GEGLU activation output (`h`)
        - `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
        - `up_ptr` contains gradient w.r.t up (`grad_up`)
    """
    block_idx = tl.program_id(0)
    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0)

    # Forward pass
    gelu_partial = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
    gelu_gate = gelu_partial * gate
    gelu_gate = gelu_gate.to(grad_out.dtype)

    # Forward output
    h = gelu_gate * up

    # Compute gradients
    grad_up = grad_out * gelu_gate

    # Compute gate gradient using GELU derivative
    temp = grad_out * up
    t = 0.3989422804014327  # 1/sqrt(2*pi)
    dgelu_dgate = gelu_partial + t * gate * tl.exp(-0.5 * gate * gate)
    grad_gate = temp.to(tl.float32) * dgelu_dgate
    grad_gate = grad_gate.to(grad_out.dtype)

    # Store results
    tl.store(grad_out_ptr + offsets, h, mask=mask)
    tl.store(gate_ptr + offsets, grad_gate, mask=mask)
    tl.store(up_ptr + offsets, grad_up, mask=mask)


def geglu_backward(
    grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """GEGLU backward pass using in-place operations.

    Args:
        grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
        gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
        up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.

    Returns:
        Tuple containing:
        - GEGLU activation output (`h`)
        - Gradient with respect to gate (`grad_gate`)
        - Gradient with respect to up (`grad_up`)

    Note:
        This function modifies its input tensors in-place to store results.
    """
    n_elements = grad_output.numel()

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)  # noqa: E731
    _geglu_bwd_kernel[grid](
        grad_out_ptr=grad_output,
        gate_ptr=gate,
        up_ptr=up,
        n_elements=n_elements,
        BLOCK_SIZE=1024,
    )

    return grad_output, gate, up
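Because `geglu_backward` overwrites its input buffers, the natural way to consume these kernels is through a `torch.autograd.Function`; in this changeset that wiring lives in `LoRA_MLP` (see `lora.py` below), which takes `geglu_forward`/`geglu_backward` as its activation pair. The snippet below is a simplified, standalone sketch of that idea plus a quick numerical check against PyTorch's exact (erf-based) GELU; the `GEGLUFunction` wrapper and the test shapes are assumptions for illustration, not code from this PR.

import torch

class GEGLUFunction(torch.autograd.Function):
    """Minimal sketch: expose the Triton GEGLU kernels to autograd."""

    @staticmethod
    def forward(ctx, gate, up):
        out = geglu_forward(gate, up)
        ctx.save_for_backward(gate, up)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        gate, up = ctx.saved_tensors
        # geglu_backward writes results into its arguments, so hand it clones we own.
        _, grad_gate, grad_up = geglu_backward(
            grad_output.contiguous().clone(), gate.clone(), up.clone()
        )
        return grad_gate, grad_up

if torch.cuda.is_available():
    gate = torch.randn(2, 16, 64, device="cuda", dtype=torch.float16)
    up = torch.randn_like(gate)
    # Reference: exact GELU applied to the gate, multiplied by the up projection.
    reference = torch.nn.functional.gelu(gate.float(), approximate="none") * up.float()
    triton_out = geglu_forward(gate, up)
    print(torch.allclose(triton_out.float(), reference, atol=1e-2, rtol=1e-2))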
779
src/axolotl/kernels/lora.py
Normal file
@@ -0,0 +1,779 @@
"""
Module for definition of Low-Rank Adaptation (LoRA) Triton kernels.

See "LoRA: Low-Rank Adaptation of Large Language Models"
(https://arxiv.org/abs/2106.09685).

Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name

from typing import Callable

import torch
from bitsandbytes.functional import QuantState
from torch import nn

from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize
from .swiglu import swiglu_backward, swiglu_forward
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd


def get_lora_parameters(
    proj: nn.Module,
) -> tuple[
    torch.Tensor,
    QuantState | None,
    torch.Tensor | None,
    torch.Tensor | None,
    float | None,
]:
    """
    Gets LoRA parameters from a projection module.

    Args:
        proj: The projection module to extract parameters from.

    Returns:
        A tuple containing the base weight matrix, quantization state, LoRA A matrix,
        LoRA B matrix, and scaling factor. States and matrices may be None if not
        available.
    """
    # For DPO or disabled adapters
    base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
    W = base_layer.weight

    if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
        quant_state = getattr(W, "quant_state", None)
        return W, quant_state, None, None, None

    active_adapter = (
        proj.active_adapters[0]
        if hasattr(proj, "active_adapters")
        else proj.active_adapter
    )
    A = proj.lora_A[active_adapter].weight
    B = proj.lora_B[active_adapter].weight
    s = proj.scaling[active_adapter]

    quant_state = getattr(W, "quant_state", None)

    return W, quant_state, A, B, s


def matmul_lora(
    X: torch.Tensor,
    W: torch.Tensor,
    W_quant: QuantState,
    A: torch.Tensor,
    B: torch.Tensor,
    s: float,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Efficient fused matmul + LoRA computation.

    Args:
        X: Input tensor [*, in_features]
        W: Base weight matrix [out_features, in_features]
        W_quant: Quantization state for W
        A: LoRA A matrix [rank, in_features]
        B: LoRA B matrix [out_features, rank]
        s: LoRA scaling factor
        out: Optional output tensor for inplace operations

    Returns:
        Result of X @ W + X @ A @ B
    """
    dtype = X.dtype
    W = dequantize(W.t(), W_quant)

    if X.dim() == 3:
        batch, seq_len, _ = X.shape
        X = X.view(-1, X.shape[-1])
        reshape = True
    else:
        reshape = False

    out = torch.matmul(X, W, out=out)
    if W_quant is not None:
        del W

    if A is not None:
        A, B = A.t(), B.t()
        out += (X @ A.to(dtype)) @ (s * B.to(dtype))

    return out.view(batch, seq_len, -1) if reshape else out


class LoRA_MLP(torch.autograd.Function):
    """Optimized LoRA MLP implementation."""

    @staticmethod
    @torch_amp_custom_fwd
    def forward(
        ctx,
        X: torch.Tensor,
        gate_weight: torch.Tensor,
        gate_quant: object | None,
        gate_A: torch.Tensor | None,
        gate_B: torch.Tensor | None,
        gate_scale: float,
        up_weight: torch.Tensor,
        up_quant: object | None,
        up_A: torch.Tensor | None,
        up_B: torch.Tensor | None,
        up_scale: float,
        down_weight: torch.Tensor,
        down_quant: object | None,
        down_A: torch.Tensor | None,
        down_B: torch.Tensor | None,
        down_scale: float,
        activation_fn: Callable,
        activation_fn_backward: Callable,
        inplace: bool | None = True,
    ) -> torch.Tensor:
        """
        Forward pass for LoRA MLP.

        Args:
            ctx: Autograd context
            X: Input features
            gate_weight: Gate projection weight
            gate_quant: Gate quantization state
            gate_A: Gate LoRA A matrix
            gate_B: Gate LoRA B matrix
            gate_scale: Gate LoRA scale
            up_weight: Up-projection weight
            up_quant: Up-projection quantization state
            up_A: Up-projection LoRA A matrix
            up_B: Up-projection LoRA B matrix
            up_scale: Up-projection LoRA scale
            down_weight: Down-projection weight
            down_quant: Down-projection quantization state
            down_A: Down-projection LoRA A matrix
            down_B: Down-projection LoRA B matrix
            down_scale: Down-projection LoRA scale
            activation_fn: Forward activation function
            activation_fn_backward: Backward activation function
            inplace: Whether to perform operations in-place

        Returns:
            Output transformed by multi-layer perceptron and activation function
        """
        # Compute projections
        gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
        up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)

        # Activation
        hidden = activation_fn(gate, up)

        # Down projection
        output = matmul_lora(
            hidden, down_weight, down_quant, down_A, down_B, down_scale
        )

        # Save for backward
        ctx.save_for_backward(X, gate, up, gate_A, gate_B, up_A, up_B, down_A, down_B)
        ctx.scales = (gate_scale, up_scale, down_scale)
        ctx.quants = (gate_quant, up_quant, down_quant)
        ctx.weights = (gate_weight, up_weight, down_weight)
        ctx.activation_fn = activation_fn
        ctx.activation_fn_backward = activation_fn_backward
        ctx.inplace = inplace

        return output

    @staticmethod
    @torch_amp_custom_bwd
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_output: torch.Tensor,
    ) -> tuple[
        torch.Tensor | None,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        None,
    ]:
        """
        Performs backward pass computation for LoRA MLP.

        Args:
            ctx: Context object storing tensors saved during forward pass
            grad_output: Gradient of loss with respect to layer output

        Returns:
            Tuple containing gradients for all inputs from forward pass:
            - Input gradient tensor (or `None`)
            - `None` for weights/quantization states
            - LoRA A/B matrix gradients (or `None`)
            - `None` for scaling factors
            - `None` for activation functions and flags
        """
        (
            X,
            gate,
            up,
            gate_A,
            gate_B,
            up_A,
            up_B,
            down_A,
            down_B,
        ) = ctx.saved_tensors
        gate_scale, up_scale, down_scale = ctx.scales
        gate_quant, up_quant, down_quant = ctx.quants
        gate_weight, up_weight, down_weight = ctx.weights

        # Transpose all LoRA matrices
        gate_A, gate_B = (
            gate_A.t() if gate_A is not None else None,
            gate_B.t() if gate_B is not None else None,
        )
        up_A, up_B = (
            up_A.t() if up_A is not None else None,
            up_B.t() if up_B is not None else None,
        )
        down_A, down_B = (
            down_A.t() if down_A is not None else None,
            down_B.t() if down_B is not None else None,
        )

        # Reshape inputs
        batch, seq_len, hd = X.shape
        grad_output = grad_output.view(-1, grad_output.shape[-1])
        X = X.view(-1, X.shape[-1])
        gate = gate.view(-1, gate.shape[-1])
        up = up.view(-1, up.shape[-1])
        dtype = X.dtype

        # Down projection
        DW = matmul_lora(
            grad_output,
            down_weight.t(),
            down_quant,
            down_B,
            down_A,
            down_scale,
        )

        # Activation backward
        h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)

        # Initialize and compute LoRA gradients
        d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None

        if down_A is not None:
            d_down_A = h.t() @ (grad_output @ down_B.t())
            d_down_B = (down_A.t() @ h.t()) @ grad_output
            d_down_A *= down_scale
            d_down_B *= down_scale

        if up_A is not None:
            d_up_A = X.t() @ (grad_up @ up_B.t())
            d_up_B = (up_A.t() @ X.t()) @ grad_up
            d_up_A *= up_scale
            d_up_B *= up_scale

        if gate_A is not None:
            d_gate_A = X.t() @ (grad_gate @ gate_B.t())
            d_gate_B = (gate_A.t() @ X.t()) @ grad_gate
            d_gate_A *= gate_scale
            d_gate_B *= gate_scale

        # Compute input gradients
        dX = torch.zeros_like(X) if ctx.needs_input_grad[0] else None

        if dX is not None:
            # Up projection gradients
            up_weight = dequantize(up_weight.t(), up_quant)
            if ctx.inplace:
                dX = torch.matmul(grad_up, up_weight.t(), out=X)
            else:
                dX = torch.matmul(grad_up, up_weight.t())
            del up_weight

            # Note the .to(dtype) only where mixing LoRA with base weights
            if up_A is not None:
                dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())

            # Gate projection gradients
            gate_weight = dequantize(gate_weight.t(), gate_quant)
            dX += grad_gate @ gate_weight.t()
            del gate_weight

            if gate_A is not None:
                dX += (
                    grad_gate
                    @ gate_B.to(dtype).t()
                    @ (gate_scale * gate_A.to(dtype).t())
                )

            # Reshape back
            dX = dX.view(batch, seq_len, hd)

        # Return gradients in correct order matching forward inputs
        return (
            dX,
            None,
            None,
            d_gate_A.t() if d_gate_A is not None else None,
            d_gate_B.t() if d_gate_B is not None else None,
            None,
            None,
            None,
            d_up_A.t() if d_up_A is not None else None,
            d_up_B.t() if d_up_B is not None else None,
            None,
            None,
            None,
            d_down_A.t() if d_down_A is not None else None,
            d_down_B.t() if d_down_B is not None else None,
            None,
            None,
            None,
            None,
        )


def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
    """
    Applies LoRA to MLP layer with SwiGLU activation.

    Args:
        X: Input tensor for the MLP layer
        inplace: Whether to perform operations in-place to save memory

    Returns:
        Output tensor after applying LoRA-adapted MLP with SwiGLU activation
    """
    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)

    out = LoRA_MLP.apply(
        X,
        gateW,
        gateW_quant,
        gateA,
        gateB,
        gateS,
        upW,
        upW_quant,
        upA,
        upB,
        upS,
        downW,
        downW_quant,
        downA,
        downB,
        downS,
        swiglu_forward,
        swiglu_backward,
        inplace,
    )

    return out


def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
    """
    Applies LoRA to MLP layer with GEGLU activation.

    Args:
        X: Input tensor for the MLP layer
        inplace: Whether to perform operations in-place to save memory

    Returns:
        Output tensor after applying LoRA-adapted MLP with GEGLU activation
    """
    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
    out = LoRA_MLP.apply(
        X,
        gateW,
        gateW_quant,
        gateA,
        gateB,
        gateS,
        upW,
        upW_quant,
        upA,
        upB,
        upS,
        downW,
        downW_quant,
        downA,
        downB,
        downS,
        geglu_forward,
        geglu_backward,
        inplace,
    )

    return out


class LoRA_QKV(torch.autograd.Function):
    """
    Optimized LoRA QKV implementation with quantization support.

    Implements efficient computation of query, key, value projections with LoRA,
    supporting quantization and memory optimization.
    """

    @staticmethod
    @torch_amp_custom_fwd
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        X: torch.Tensor,
        q_weight: torch.Tensor,
        q_quant: QuantState | None,
        q_A: torch.Tensor | None,
        q_B: torch.Tensor | None,
        q_scale: float,
        k_weight: torch.Tensor,
        k_quant: QuantState | None,
        k_A: torch.Tensor | None,
        k_B: torch.Tensor | None,
        k_scale: float,
        v_weight: torch.Tensor,
        v_quant: QuantState | None,
        v_A: torch.Tensor | None,
        v_B: torch.Tensor | None,
        v_scale: float,
        inplace: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass computing Q, K, V projections with LoRA.

        Args:
            ctx: Autograd context
            X: Input tensor
            q_weight: Query projection weight
            q_quant: Query quantization state
            q_A: Query LoRA A matrix
            q_B: Query LoRA B matrix
            q_scale: Query LoRA scale
            k_weight: Key projection weight
            k_quant: Key quantization state
            k_A: Key LoRA A matrix
            k_B: Key LoRA B matrix
            k_scale: Key LoRA scale
            v_weight: Value projection weight
            v_quant: Value quantization state
            v_A: Value LoRA A matrix
            v_B: Value LoRA B matrix
            v_scale: Value LoRA scale
            inplace: Whether to perform operations in-place

        Returns:
            Tuple of (Query, Key, Value) projection tensors
        """
        Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
        K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
        V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)

        ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
        ctx.scales = (q_scale, k_scale, v_scale)
        ctx.quants = (q_quant, k_quant, v_quant)
        ctx.weights = (q_weight, k_weight, v_weight)
        ctx.inplace = inplace

        return Q, K, V

    @staticmethod
    @torch_amp_custom_bwd
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        q_grad: torch.Tensor,
        k_grad: torch.Tensor,
        v_grad: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
    ]:
        """
        Backward pass computing gradients for LoRA QKV.

        Args:
            ctx: Autograd context
            q_grad: Gradient for query projection
            k_grad: Gradient for key projection
            v_grad: Gradient for value projection

        Returns:
            Tuple containing gradients for all forward inputs
        """
        X, A_q, B_q, A_k, B_k, A_v, B_v = ctx.saved_tensors
        q_weight, k_weight, v_weight = ctx.weights
        q_quant, k_quant, v_quant = ctx.quants
        q_scale, k_scale, v_scale = ctx.scales
        dtype = X.dtype

        # Reshape gradients
        batch, seq_len = X.shape[:2]
        q_grad = q_grad.view(-1, q_grad.shape[-1])
        k_grad = k_grad.reshape(-1, k_grad.shape[-1])
        v_grad = v_grad.view(-1, v_grad.shape[-1])
        X = X.view(-1, X.shape[-1])

        # Pre-transpose X once
        X_t = X.t()

        # Initialize LoRA gradients as None
        d_A_q = d_B_q = d_A_k = d_B_k = d_A_v = d_B_v = None

        # Compute q path LoRA gradients if adapters exist
        if A_q is not None and B_q is not None:
            A_q_scaled = (q_scale * A_q).to(dtype)
            B_q_scaled = B_q.to(dtype)
            d_A_q = torch.mm(X_t, torch.mm(q_grad, B_q_scaled))
            d_B_q = torch.mm(torch.mm(A_q_scaled, X_t), q_grad)

        # Compute k path LoRA gradients if adapters exist
        if A_k is not None and B_k is not None:
            A_k_scaled = (k_scale * A_k).to(dtype)
            B_k_scaled = B_k.to(dtype)
            d_A_k = torch.mm(X_t, torch.mm(k_grad, B_k_scaled))
            d_B_k = torch.mm(torch.mm(A_k_scaled, X_t), k_grad)

        # Compute v path LoRA gradients if adapters exist
        if A_v is not None and B_v is not None:
            A_v_scaled = (v_scale * A_v).to(dtype)
            B_v_scaled = B_v.to(dtype)
            d_A_v = torch.mm(X_t, torch.mm(v_grad, B_v_scaled))
            d_B_v = torch.mm(torch.mm(A_v_scaled, X_t), v_grad)

        # Compute input gradient, reusing X memory if possible
        out_buffer = X if ctx.inplace else None

        # Q path
        q_weight_t = dequantize(q_weight, q_quant)
        grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
        del q_weight
        del q_weight_t
        if A_q is not None and B_q is not None:
            grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))

        # K path
        k_weight_t = dequantize(k_weight, k_quant)
        grad_X.addmm_(k_grad, k_weight_t)
        del k_weight
        del k_weight_t
        if A_k is not None and B_k is not None:
            grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))

        # V path
        v_weight_t = dequantize(v_weight, v_quant)
        grad_X.addmm_(v_grad, v_weight_t)
        del v_weight
        del v_weight_t
        if A_v is not None and B_v is not None:
            grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))

        # Transpose gradients if needed
        if d_A_q is not None:
            d_A_q = d_A_q.t()
        if d_B_q is not None:
            d_B_q = d_B_q.t()
        if d_A_k is not None:
            d_A_k = d_A_k.t()
        if d_B_k is not None:
            d_B_k = d_B_k.t()
        if d_A_v is not None:
            d_A_v = d_A_v.t()
        if d_B_v is not None:
            d_B_v = d_B_v.t()

        return (
            grad_X.view(batch, seq_len, -1),
            None,
            None,
            d_A_q,
            d_B_q,
            None,
            None,
            None,
            d_A_k,
            d_B_k,
            None,
            None,
            None,
            d_A_v,
            d_B_v,
            None,
            None,
        )


def apply_lora_qkv(
    self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Applies LoRA to compute Query, Key, Value projections.

    Args:
        X: Input tensor
        inplace: Whether to perform operations in-place

    Returns:
        Tuple of (Query, Key, Value) projection tensors
    """
    QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
    KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
    VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
    Q, K, V = LoRA_QKV.apply(
        X,
        QW,
        QW_quant,
        QA,
        QB,
        QS,
        KW,
        KW_quant,
        KA,
        KB,
        KS,
        VW,
        VW_quant,
        VA,
        VB,
        VS,
        inplace,
    )

    return Q, K, V


class LoRA_O(torch.autograd.Function):
    """Optimized LoRA implementation for output projection."""

    @staticmethod
    @torch_amp_custom_fwd
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        X: torch.Tensor,
        W: torch.Tensor,
        W_quant: QuantState | None,
        A: torch.Tensor | None,
        B: torch.Tensor | None,
        S: float,
    ) -> torch.Tensor:
        """
        Forward pass for output projection with LoRA.

        Args:
            ctx: Autograd context
            X: Input tensor
            W: Output projection weight
            W_quant: Weight quantization state
            A: LoRA A matrix
            B: LoRA B matrix
            S: LoRA scaling factor

        Returns:
            Output projection tensor
        """
        XW = matmul_lora(X, W, W_quant, A, B, S)
        ctx.custom_saved_tensors = (
            W,
            W_quant,
            S,
        )
        ctx.save_for_backward(A, B, X)

        return XW

    @staticmethod
    @torch_amp_custom_bwd
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        dY: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
    ]:
        """
        Backward pass computing gradients for LoRA output projection.

        Args:
            ctx: Autograd context
            dY: Gradient of loss with respect to output

        Returns:
            Tuple containing gradients for all forward inputs
        """
        W, W_quant, S = ctx.custom_saved_tensors
        A, B, X = ctx.saved_tensors

        batch, seq_len, hd = X.shape
        dY = dY.reshape(-1, dY.shape[-1])
        X = X.reshape(-1, X.shape[-1])
        dtype = X.dtype

        # Weight projection
        dY_X = X.t() @ dY
        d_A = S * dY_X @ B
        d_B = S * A @ dY_X

        # Get derivative for dX
        W = dequantize(W.t(), W_quant)
        dX = dY @ W.t()
        del W
        dX += dY @ B.to(dtype) @ (S * A.to(dtype))

        # W, W_quant, A, B, S
        return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None


def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
    """
    Applies LoRA to output projection layer.

    Args:
        X: Input tensor

    Returns:
        Transformed output tensor
    """
    OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
    output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)

    return output
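The `apply_lora_*` helpers take `self` as their first argument because they are intended to be bound onto existing attention/MLP modules as replacement forward paths. The sketch below shows what that monkey-patching could look like; the module layout (`model.model.layers[i].mlp` with gate/up/down projections) assumes a Llama-style PEFT model and is an illustration, not the patching code from this diff.

import types

def patch_llama_mlp_with_lora_kernels(model):
    """Illustrative only: bind the fused SwiGLU LoRA path onto each decoder MLP.

    Assumes a Llama-like layout; the real integration in axolotl may differ.
    """
    for layer in model.model.layers:
        mlp = layer.mlp
        # Replace the MLP forward with the fused Triton + LoRA implementation.
        mlp.forward = types.MethodType(apply_lora_mlp_swiglu, mlp)
    return model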
149
src/axolotl/kernels/quantize.py
Normal file
@@ -0,0 +1,149 @@
"""Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement

import ctypes

import bitsandbytes as bnb
import torch
from bitsandbytes.functional import QuantState, get_ptr
from packaging.version import Version

cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4

CUDA_STREAM: torch.cuda.Stream | None = None
HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3")


def dequantize(
    W: torch.Tensor,
    quant_state: QuantState | list | None = None,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Fast NF4 dequantization using `bitsandbytes` CUDA kernels.

    Performs efficient dequantization of weights from NF4 format using `bitsandbytes`'
    optimized CUDA implementations. Supports both legacy list and new `QuantState`
    formats.

    Args:
        W: Quantized weight tensor to dequantize
        quant_state: Quantization state containing metadata needed for
            dequantization. Can be either a `QuantState` object or legacy list format.
            If None, returns `W` unchanged.
        out: Optional output tensor for storing dequantized results. Must match
            expected shape and dtype if provided.

    Returns:
        Dequantized tensor in the specified dtype (fp16 or bf16). Will be transposed if
        input `W` was transposed.

    Raises:
        AssertionError: If provided output tensor doesn't match expected shape / dtype.

    Note:
        Uses CUDA streams for better performance when available in newer `bitsandbytes`
        versions (>0.43.3).
    """
    if quant_state is None:
        return W

    # Get the target device from input tensor W
    target_device = W.device

    # Extract quantization state
    if not isinstance(quant_state, list):
        # New style quant_state class
        absmax = quant_state.absmax.to(target_device)
        shape = quant_state.shape
        dtype = quant_state.dtype
        blocksize = quant_state.blocksize
        offset = quant_state.offset.to(target_device)
        state2 = quant_state.state2
        absmax2 = state2.absmax.to(target_device)
        code2 = state2.code.to(target_device)
        blocksize2 = state2.blocksize
    else:
        # Legacy list format
        absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
        absmax = absmax.to(target_device)
        offset, state2 = compressed_stats
        offset = offset.to(target_device)
        absmax2, code2, blocksize2, _, _, _, _ = state2
        absmax2 = absmax2.to(target_device)
        code2 = code2.to(target_device)

    # Setup output tensor on the same device as input
    if out is None:
        out = torch.empty(shape, dtype=dtype, device=target_device)
    else:
        assert out.shape == shape and out.dtype == dtype
        out = out.to(target_device)

    # Dequantize statistics on the target device
    n_elements_absmax: int = absmax.numel()
    out_absmax: torch.Tensor = torch.empty(
        n_elements_absmax, dtype=torch.float32, device=target_device
    )
    ptr_out_absmax: int = get_ptr(out_absmax)

    # Use CUDA stream if available
    if HAS_CUDA_STREAM:
        global CUDA_STREAM
        if CUDA_STREAM is None:
            CUDA_STREAM = torch.cuda.current_stream(target_device)

        cdequantize_blockwise_fp32(
            get_ptr(code2),
            get_ptr(absmax),
            get_ptr(absmax2),
            ptr_out_absmax,
            ctypes.c_int(blocksize2),
            ctypes.c_int(n_elements_absmax),
            CUDA_STREAM,
        )
    else:
        cdequantize_blockwise_fp32(
            get_ptr(code2),
            get_ptr(absmax),
            get_ptr(absmax2),
            ptr_out_absmax,
            ctypes.c_int(blocksize2),
            ctypes.c_int(n_elements_absmax),
        )

    out_absmax += offset

    # Choose appropriate dequantization function
    fx = (
        cdequantize_blockwise_fp16_nf4
        if dtype == torch.float16
        else cdequantize_blockwise_bf16_nf4
    )

    # Dequantize weights
    if HAS_CUDA_STREAM:
        fx(
            get_ptr(None),
            get_ptr(W),
            ptr_out_absmax,
            get_ptr(out),
            ctypes.c_int(blocksize),
            ctypes.c_int(out.numel()),
            CUDA_STREAM,
        )
    else:
        fx(
            get_ptr(None),
            get_ptr(W),
            ptr_out_absmax,
            get_ptr(out),
            ctypes.c_int(blocksize),
            ctypes.c_int(out.numel()),
        )

    # Handle transposed data
    is_transposed: bool = W.shape[0] == 1
    return out.t() if is_transposed else out
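For context, `dequantize` reads the nested statistics (`offset`, `state2`) that `bitsandbytes` produces when double quantization is enabled, so a usable `quant_state` must come from a 4-bit weight quantized with compressed statistics. The snippet below is a small usage sketch under that assumption; the tensor shape and tolerance are arbitrary illustration values.

import torch
import bitsandbytes.functional as F_bnb

if torch.cuda.is_available():
    # Pack an fp16 weight into NF4 with nested (double) quantization of the statistics.
    W = torch.randn(256, 128, device="cuda", dtype=torch.float16)
    W_nf4, quant_state = F_bnb.quantize_4bit(W, quant_type="nf4", compress_statistics=True)

    # Reconstruct the weight with the fast kernel path defined above.
    W_restored = dequantize(W_nf4, quant_state)
    print(W_restored.shape, W_restored.dtype)        # torch.Size([256, 128]) torch.float16
    print((W - W_restored).abs().mean())             # small block-wise quantization error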
163
src/axolotl/kernels/swiglu.py
Normal file
@@ -0,0 +1,163 @@
"""
Module for definition of SwiGLU Triton kernels.

See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).

Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
import torch
import triton
import triton.language as tl


@triton.jit
def _swiglu_fwd_kernel(
    gate_ptr,
    up_ptr,
    out_ptr,
    n_elements,
    block_size: tl.constexpr,
):
    """
    SwiGLU forward kernel. The kernel computes activation in fp32 precision for better
    numerical stability, then converts back to original dtype for the final result.

    Args:
        gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
        up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
        out_ptr: Pointer to output tensor `[*, hidden_dim]`.
        n_elements: Total number of elements in the input tensors.
        block_size: Size of thread blocks for parallel computation.
    """
    block_idx = tl.program_id(0)
    offsets = block_idx * block_size + tl.arange(0, block_size)
    mask = offsets < n_elements

    # Load gate in fp32, keep up in original dtype
    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0)

    # Compute activation in fp32 then convert back
    f = gate * tl.sigmoid(gate)
    f = f.to(up.dtype)
    result = f * up

    tl.store(out_ptr + offsets, result, mask=mask)


@triton.jit
def _swiglu_bwd_kernel(
    grad_out_ptr,
    gate_ptr,
    up_ptr,
    n_elements,
    block_size: tl.constexpr,
):
    """
    SwiGLU backward kernel. Stores gradient results in-place.

    Args:
        grad_out_ptr: Pointer to gradient output tensor `[*, hidden_dim]`.
        gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
        up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
        n_elements: Total number of elements in the input tensors.
        block_size: Size of thread blocks for parallel computation.

    Note:
        After kernel execution, tensors are modified in-place:
        - `grad_out_ptr` contains forward output (`h`)
        - `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
        - `up_ptr` contains gradient w.r.t up (`grad_up`)
    """
    block_idx = tl.program_id(0)
    offsets = block_idx * block_size + tl.arange(0, block_size)
    mask = offsets < n_elements

    # Load values - only convert gate to fp32
    grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0)

    # Compute SiLU and forward output
    sigmoid_gate = tl.sigmoid(gate)
    silu_gate = sigmoid_gate * gate
    silu_gate = silu_gate.to(grad_out.dtype)
    h = silu_gate * up

    # Compute gradients
    grad_up = grad_out * silu_gate  # gradient for up is grad_out * SiLU(gate)

    # Compute gate gradient
    temp = grad_out * up
    grad_gate = temp.to(tl.float32) * sigmoid_gate * (1.0 + gate * (1.0 - sigmoid_gate))
    grad_gate = grad_gate.to(grad_out.dtype)

    # Store results with correct gradient ordering
    tl.store(grad_out_ptr + offsets, h, mask=mask)
    tl.store(gate_ptr + offsets, grad_gate, mask=mask)  # grad wrt gate
    tl.store(up_ptr + offsets, grad_up, mask=mask)  # grad wrt up


# pylint: disable=unnecessary-lambda-assignment
def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """
    SwiGLU forward pass. Computes SwiGLU activation: `x * sigmoid(x) * up`, where
    `x` is the gate tensor.

    Args:
        gate: Input gate tensor of shape `[batch, seq_len, hidden_dim]`.
        up: Up-projection tensor of shape `[batch, seq_len, hidden_dim]`.

    Returns:
        Output tensor of shape `[batch, seq_len, hidden_dim]`.
    """
    batch, seq_len, hidden_dim = gate.shape
    n_elements = gate.numel()
    out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda")

    grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),)  # noqa: E731
    _swiglu_fwd_kernel[grid](
        gate_ptr=gate,
        up_ptr=up,
        out_ptr=out,
        n_elements=n_elements,
        block_size=1024,
    )

    return out


# pylint: disable=unnecessary-lambda-assignment
def swiglu_backward(
    grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    SwiGLU backward pass using in-place operations.

    Args:
        grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
        gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
        up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.

    Returns:
        Tuple containing:
        - Forward pass output (`h`)
        - Gradient with respect to gate (`df`)
        - Gradient with respect to up-projection (`de`)
    """
    n_elements = grad_output.numel()

    grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),)  # noqa: E731
    _swiglu_bwd_kernel[grid](
        grad_out_ptr=grad_output,
        gate_ptr=gate,
|
||||
up_ptr=up,
|
||||
n_elements=n_elements,
|
||||
block_size=1024,
|
||||
)
|
||||
|
||||
# After kernel execution, tensors contain:
|
||||
# grad_output: h (forward output)
|
||||
# gate: grad_gate (grad wrt gate)
|
||||
# up: grad_up (grad wrt up)
|
||||
return grad_output, gate, up
|
||||
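A minimal usage sketch for the two helpers above, checking the forward result against a plain PyTorch reference (assumes a CUDA device with Triton available):

import torch
import torch.nn.functional as F

from axolotl.kernels.swiglu import swiglu_forward

gate = torch.randn(2, 16, 64, dtype=torch.float16, device="cuda")
up = torch.randn_like(gate)

out = swiglu_forward(gate, up)
ref = F.silu(gate.float()).to(gate.dtype) * up  # same fp32-then-cast ordering as the kernel
assert torch.allclose(out, ref, atol=1e-3)

# Note: swiglu_backward mutates its inputs in-place, returning (h, grad_gate, grad_up)
# in the grad_output, gate, and up buffers respectively.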
src/axolotl/kernels/utils.py (new file, 11 lines)
@@ -0,0 +1,11 @@
|
||||
"""Utilities for `axolotl.kernels` submodules."""
|
||||
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
|
||||
if Version(torch.__version__) < Version("2.4.0"):
|
||||
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
|
||||
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
|
||||
else:
|
||||
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||
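These shims exist because the AMP custom-op decorators moved from `torch.cuda.amp` to `torch.amp` (with a `device_type` argument) in torch 2.4. A sketch of how an autograd function might consume them; the function body here is illustrative only:

import torch

from axolotl.kernels.utils import torch_amp_custom_bwd, torch_amp_custom_fwd

class ScaleByTwo(torch.autograd.Function):
    """Toy example; real kernels would dispatch to Triton here."""

    @staticmethod
    @torch_amp_custom_fwd
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @torch_amp_custom_bwd
    def backward(ctx, grad_output):
        return grad_output * 2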
src/axolotl/monkeypatch/lora_kernels.py (new file, 333 lines)
@@ -0,0 +1,333 @@
|
||||
"""Module for patching custom LoRA Triton kernels and `torch.autograd` functions."""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import types
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
from peft import PeftModelForCausalLM
|
||||
from torch import nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
from axolotl.kernels.lora import (
|
||||
apply_lora_mlp_geglu,
|
||||
apply_lora_mlp_swiglu,
|
||||
apply_lora_o,
|
||||
apply_lora_qkv,
|
||||
)
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
ORIGINAL_QKV_CODE = """
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
PATCHED_QKV_CODE = """
|
||||
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
||||
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
ORIGINAL_O_CODE = """
|
||||
attn_output = self.o_proj(attn_output)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
PATCHED_O_CODE = """
|
||||
attn_output = self.apply_o(attn_output)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
SUPPORTED_ACTIVATIONS = ["silu", "gelu"]
|
||||
APPLY_FN_MAPPING = {
|
||||
"silu": apply_lora_mlp_swiglu,
|
||||
"gelu": apply_lora_mlp_geglu,
|
||||
}
|
||||
|
||||
|
||||
def original_apply_qkv(
|
||||
self: nn.Module, hidden_states: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Original implementation of QKV projection without optimizations.
|
||||
|
||||
Args:
|
||||
self: The attention module instance.
|
||||
hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim].
|
||||
|
||||
Returns:
|
||||
A tuple `(query_states, key_states, value_states)` containing the projected
|
||||
states for query, key, and value.
|
||||
"""
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
return query_states, key_states, value_states
|
||||
|
||||
|
||||
def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Original implementation of output projection without optimizations.
|
||||
|
||||
Args:
|
||||
self: The attention module instance.
|
||||
hidden_states: Input tensor of shape `[batch_size, seq_len, hidden_dim]`.
|
||||
|
||||
Returns:
|
||||
The output projection result.
|
||||
"""
|
||||
attn_output = self.o_proj(hidden_states)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
"""
|
||||
Get the appropriate attention class by inspecting the model config.
|
||||
Uses dynamic import to support any model architecture that follows
|
||||
the standard transformers naming convention.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
The appropriate attention class for the model.
|
||||
|
||||
Raises:
|
||||
ValueError: If `base_model` not specified or attention class cannot be imported
|
||||
ImportError: If the model module or attention class doesn't exist
|
||||
"""
|
||||
if "base_model" not in cfg:
|
||||
raise ValueError("base_model must be specified in config")
|
||||
|
||||
# Get model config without loading the model
|
||||
model_config = AutoConfig.from_pretrained(cfg["base_model"])
|
||||
model_type = model_config.model_type
|
||||
|
||||
# Special case for model_type = "qwen2"
|
||||
if model_type == "qwen2":
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
|
||||
|
||||
return Qwen2Attention
|
||||
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
module = __import__(
|
||||
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
|
||||
)
|
||||
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
|
||||
|
||||
return attention_cls
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(
|
||||
f"Could not import attention class for model_type: {model_type}. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
|
||||
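A hedged usage sketch (the checkpoint id is a placeholder; `AutoConfig.from_pretrained` will read or download the model config):

from axolotl.utils.dict import DictDefault

cfg = DictDefault({"base_model": "your-org/your-llama-model"})  # hypothetical checkpoint
attention_cls = get_attention_cls_from_config(cfg)
# For a config with model_type == "llama", this resolves to
# transformers.models.llama.modeling_llama.LlamaAttention.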
|
||||
# pylint: disable=protected-access
|
||||
def patch_self_attn_lora(cfg: DictDefault):
|
||||
"""
|
||||
Given an `axolotl` config, this method patches the inferred attention class forward
|
||||
pass with optimized LoRA implementations.
|
||||
|
||||
It modifies the attention class to use optimized QKV and output projections. The
|
||||
original implementation is preserved and can be restored if needed.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the required code blocks are not found in the attention
|
||||
implementation.
|
||||
"""
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
|
||||
# Check if already patched
|
||||
if hasattr(attention_cls, "_original_forward"):
|
||||
LOG.info(f"{attention_cls.__name__} already patched")
|
||||
return
|
||||
|
||||
self_attn_forward = inspect.getsource(attention_cls.forward)
|
||||
attention_cls._original_forward = self_attn_forward
|
||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
||||
|
||||
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
|
||||
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
|
||||
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(
|
||||
"def forward(",
|
||||
"def axolotl_attn_forward(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Load necessary imports
|
||||
module_name = attention_cls.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in self_attn_forward:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
|
||||
attention_cls.forward = (
|
||||
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
|
||||
def apply_lora_kernel_patches(
|
||||
model: PeftModelForCausalLM, cfg: DictDefault
|
||||
) -> PeftModelForCausalLM:
|
||||
"""
|
||||
Applies optimized Triton kernel patches to a PEFT model.
|
||||
|
||||
Patches a PEFT model with optimized implementations for MLP and attention
|
||||
computations. The optimizations include custom Triton kernels for activation
|
||||
functions and specialized autograd functions for LoRA computations.
|
||||
|
||||
Args:
|
||||
model: A PEFT model to be patched with optimized kernels.
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
PeftModelForCausalLM: The patched model with optimized kernels.
|
||||
|
||||
Raises:
|
||||
TypeError: If the provided model is not a `PeftModelForCausalLM`.
|
||||
NotImplementedError: If the model type is not supported.
|
||||
AssertionError: If multiple adapters are active (currently unsupported).
|
||||
|
||||
Note:
|
||||
The optimizations require LoRA adapters with no dropout and no bias terms. The
|
||||
function will skip patching if these conditions aren't met.
|
||||
"""
|
||||
if not isinstance(model, PeftModelForCausalLM):
|
||||
raise TypeError("Model must be a PeftModelForCausalLM")
|
||||
|
||||
# Get active LoRA adapter config
|
||||
if hasattr(model, "active_adapters"):
|
||||
assert (
|
||||
len(model.active_adapters) == 1
|
||||
), "Axolotl currently does not support LoRA Triton kernels for multiple adapters"
|
||||
active_adapter = model.active_adapters[0]
|
||||
else:
|
||||
active_adapter = model.active_adapter
|
||||
lora_config = model.model.peft_config[active_adapter]
|
||||
|
||||
# Only patch if conditions are met
|
||||
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
|
||||
|
||||
if not can_patch:
|
||||
LOG.warning("Cannot patch layers - requires no dropout and no bias")
|
||||
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
|
||||
return model
|
||||
|
||||
# This needs to be reset after patching
|
||||
original_level = LOG.getEffectiveLevel()
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
# Choose activation based on model type
|
||||
activation = model.config.hidden_act
|
||||
if activation not in SUPPORTED_ACTIVATIONS:
|
||||
raise NotImplementedError(f"Activation {activation} is not supported")
|
||||
|
||||
# Patch each layer
|
||||
for layer in model.model.model.layers:
|
||||
# Add QKV, O fallback implementations to start
|
||||
# These will be overwritten later (if some conditions apply)
|
||||
layer.self_attn.apply_qkv = types.MethodType(
|
||||
original_apply_qkv, layer.self_attn
|
||||
)
|
||||
layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
|
||||
|
||||
if cfg.lora_mlp_kernel:
|
||||
# MLP patching
|
||||
gate_proj = layer.mlp.gate_proj
|
||||
up_proj = layer.mlp.up_proj
|
||||
down_proj = layer.mlp.down_proj
|
||||
|
||||
can_patch_mlp = all(
|
||||
hasattr(proj, "lora_A")
|
||||
and getattr(proj, "base_layer", proj).bias is None
|
||||
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
||||
for proj in (gate_proj, up_proj, down_proj)
|
||||
)
|
||||
|
||||
if can_patch_mlp:
|
||||
apply_fn = APPLY_FN_MAPPING[activation]
|
||||
layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
|
||||
)
|
||||
if cfg.lora_qkv_kernel:
|
||||
# Query, key, value patching
|
||||
layer_modules = [
|
||||
getattr(layer.self_attn, linear_proj)
|
||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
||||
]
|
||||
can_patch_qkv = all(
|
||||
hasattr(module, "lora_A")
|
||||
and getattr(module, "base_layer", module).bias is None
|
||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
if can_patch_qkv:
|
||||
# Add optimized implementation
|
||||
layer.self_attn.apply_qkv = types.MethodType(
|
||||
apply_lora_qkv, layer.self_attn
|
||||
)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
|
||||
)
|
||||
if cfg.lora_o_kernel:
|
||||
# Output patching
|
||||
layer_modules = [
|
||||
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||
]
|
||||
can_patch_o = all(
|
||||
hasattr(module, "lora_A")
|
||||
and getattr(module, "base_layer", module).bias is None
|
||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
if can_patch_o:
|
||||
layer.self_attn.apply_o = types.MethodType(
|
||||
apply_lora_o, layer.self_attn
|
||||
)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
|
||||
)
|
||||
|
||||
LOG.setLevel(original_level)
|
||||
|
||||
return model
|
||||
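A sketch of the config fragment that exercises these patches (keys mirror the `cfg.lora_*_kernel` flags checked above; `lora_dropout: 0` is required or patching is skipped):

from axolotl.utils.dict import DictDefault

kernel_cfg = DictDefault(
    {
        "lora_mlp_kernel": True,   # patch gate/up/down MLP projections
        "lora_qkv_kernel": True,   # patch the QKV projections
        "lora_o_kernel": True,     # patch the attention output projection
        "lora_dropout": 0,         # required: kernels assume no dropout
    }
)
# patched_model = apply_lora_kernel_patches(peft_model, kernel_cfg)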
@@ -127,6 +127,8 @@ class ReLoRACallback(TrainerCallback):
|
||||
optimizer: torch.optim.Optimizer,
|
||||
**_kwargs,
|
||||
):
|
||||
if not optimizer:
|
||||
optimizer = state.optimizer
|
||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
|
||||
@@ -41,10 +41,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
||||
load_kwargs["ds_cfg"] = ds_cfg
|
||||
if "processor" in sig.parameters:
|
||||
load_kwargs["processor"] = processor
|
||||
|
||||
return func(tokenizer, cfg, **load_kwargs)
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||
raise exc
|
||||
return None
|
||||
|
||||
@@ -13,8 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs):
|
||||
if len(strategy.split(".")) == 1:
|
||||
strategy = strategy + ".default"
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{strategy}", module_base)
|
||||
if len(strategy.split(".")) > 1:
|
||||
try:
|
||||
importlib.import_module(
|
||||
strategy.split(".")[-2],
|
||||
".".join(strategy.split(".")[:-2]),
|
||||
)
|
||||
module_base = ".".join(strategy.split(".")[:-2])
|
||||
strategy = strategy.split(".")[-2]
|
||||
except ModuleNotFoundError:
|
||||
strategy = "." + ".".join(strategy.split(".")[:-1])
|
||||
else:
|
||||
strategy = "." + ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(strategy, module_base)
|
||||
func = getattr(mod, load_fn)
|
||||
return func(cfg, **kwargs)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
|
||||
@@ -34,15 +34,12 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
|
||||
max_length = self.prompter.max_length
|
||||
|
||||
self.messages = "chosen_messages"
|
||||
# pylint: disable=duplicate-code
|
||||
prompt[self.messages] = []
|
||||
prompt["messages"] = []
|
||||
if prompt["system"]:
|
||||
prompt[self.messages].append(
|
||||
{"role": "system", "content": prompt["system"]}
|
||||
)
|
||||
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
||||
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
|
||||
prompt["messages"].append({"role": "system", "content": prompt["system"]})
|
||||
prompt["messages"].append({"role": "user", "content": prompt["input"]})
|
||||
prompt["messages"].append({"role": "assistant", "content": prompt["chosen"]})
|
||||
chosen_tokenized = super()._tokenize_single_prompt(prompt)
|
||||
|
||||
if len(chosen_tokenized["input_ids"]) > max_length:
|
||||
@@ -55,17 +52,12 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
:max_length
|
||||
]
|
||||
|
||||
self.messages = "rejected_messages"
|
||||
# pylint: disable=duplicate-code
|
||||
prompt[self.messages] = []
|
||||
prompt["messages"] = []
|
||||
if prompt["system"]:
|
||||
prompt[self.messages].append(
|
||||
{"role": "system", "content": prompt["system"]}
|
||||
)
|
||||
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
||||
prompt[self.messages].append(
|
||||
{"role": "assistant", "content": prompt["rejected"]}
|
||||
)
|
||||
prompt["messages"].append({"role": "system", "content": prompt["system"]})
|
||||
prompt["messages"].append({"role": "user", "content": prompt["input"]})
|
||||
prompt["messages"].append({"role": "assistant", "content": prompt["rejected"]})
|
||||
rejected_tokenized = super()._tokenize_single_prompt(prompt)
|
||||
|
||||
if len(rejected_tokenized["input_ids"]) > max_length:
|
||||
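For illustration, a sample handled by `BTChatTemplateStrategy` might look like this (field names follow the branches shown above):

prompt = {
    "system": "You are a helpful assistant.",
    "input": "What is 2 + 2?",
    "chosen": "4",
    "rejected": "5",
}
# After the chosen branch, prompt["chosen_messages"] holds:
# [{"role": "system", "content": "You are a helpful assistant."},
#  {"role": "user", "content": "What is 2 + 2?"},
#  {"role": "assistant", "content": "4"}]
# The rejected branch builds prompt["rejected_messages"] analogously,
# swapping the final assistant turn for the rejected answer.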
@@ -99,8 +91,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_template_string,
|
||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||
"message_property_mappings": ds_cfg.get(
|
||||
"message_property_mappings",
|
||||
{
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
),
|
||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||
"message_field_training_detail": ds_cfg.get(
|
||||
"message_field_training_detail", None
|
||||
@@ -124,7 +121,4 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
||||
)
|
||||
|
||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
|
||||
return strategy
|
||||
|
||||
@@ -4,13 +4,16 @@ HF Chat Templates prompt strategy
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import ProcessorMixin
|
||||
|
||||
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
|
||||
|
||||
# Configure the logger
|
||||
LOG = logging.getLogger("axolotl")
|
||||
@@ -23,16 +26,23 @@ class ChatTemplatePrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
chat_template: str,
|
||||
processor=None,
|
||||
chat_template=None,
|
||||
max_length=2048,
|
||||
message_field_role: str = "role",
|
||||
message_field_content: str = "content",
|
||||
message_property_mappings: Optional[Dict[str, str]] = None,
|
||||
message_field_training: Optional[str] = None,
|
||||
message_field_training_detail: Optional[str] = None,
|
||||
field_messages: str = "messages",
|
||||
roles: Optional[Dict[str, List[str]]] = None,
|
||||
drop_system_message: bool = False,
|
||||
):
|
||||
# check if message_property_mappings is None or empty dict
|
||||
if message_property_mappings is None or (not message_property_mappings):
|
||||
message_property_mappings = {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
}
|
||||
|
||||
if roles:
|
||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||
else:
|
||||
@@ -45,18 +55,28 @@ class ChatTemplatePrompter(Prompter):
|
||||
"tool": "tool",
|
||||
}
|
||||
|
||||
self.message_field_role = message_field_role
|
||||
self.message_field_content = message_field_content
|
||||
self._chat_template_msg_variables = self.get_chat_template_msg_variables(
|
||||
chat_template, field_messages
|
||||
)
|
||||
self.message_property_mappings = message_property_mappings
|
||||
self.message_field_training = message_field_training
|
||||
self.message_field_training_detail = message_field_training_detail
|
||||
self.field_messages = field_messages
|
||||
self.tokenizer = tokenizer
|
||||
self.processor: ProcessorMixin = processor
|
||||
self.processor: Optional[ProcessorMixin] = processor
|
||||
self.chat_template = chat_template
|
||||
self.max_length = max_length
|
||||
self.drop_system_message = drop_system_message
|
||||
|
||||
@property
|
||||
def chat_template_msg_variables(self) -> Set[str]:
|
||||
return self._chat_template_msg_variables
|
||||
|
||||
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
|
||||
if self.processor:
|
||||
if not callable(self.processor):
|
||||
raise TypeError("Processor must be callable")
|
||||
|
||||
text = self.processor.apply_chat_template(
|
||||
conversation,
|
||||
chat_template=self.chat_template,
|
||||
@@ -184,17 +204,21 @@ class ChatTemplatePrompter(Prompter):
|
||||
|
||||
return adjusted_details
|
||||
|
||||
def get_chat_template_msg_variables(
|
||||
self, chat_template: str, field_messages: str
|
||||
) -> Set[str]:
|
||||
template_analyzer = JinjaTemplateAnalyzer(chat_template)
|
||||
return template_analyzer.get_message_vars(field_messages)
|
||||
|
||||
|
||||
class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for instruction-based prompts.
|
||||
"""
|
||||
|
||||
_messages = "messages"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter: ChatTemplatePrompter,
|
||||
prompter: "ChatTemplatePrompter",
|
||||
tokenizer,
|
||||
train_on_inputs,
|
||||
sequence_len,
|
||||
@@ -202,6 +226,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
train_on_eos=None,
|
||||
):
|
||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||
self.prompter: ChatTemplatePrompter = prompter
|
||||
|
||||
self.roles_to_train = []
|
||||
if roles_to_train:
|
||||
@@ -213,13 +238,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
self.train_on_eos = train_on_eos
|
||||
self.images = "images"
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
return self._messages
|
||||
|
||||
@messages.setter
|
||||
def messages(self, messages):
|
||||
self._messages = messages
|
||||
LOG.debug(
|
||||
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_batched(self) -> bool:
|
||||
@@ -229,7 +250,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
|
||||
try:
|
||||
return all(isinstance(v, list) for v in prompt.values()) and all(
|
||||
isinstance(v, list) for v in prompt[self.messages]
|
||||
isinstance(v, list) for v in prompt[self.prompter.field_messages]
|
||||
)
|
||||
except KeyError:
|
||||
return False
|
||||
@@ -251,8 +272,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
dict(zip(feature_names, row))
|
||||
)
|
||||
for key, val in tokenized_prompt.items():
|
||||
for i in range(0, len(val), self.sequence_len):
|
||||
res[key].append(val[i : i + self.sequence_len])
|
||||
res[key].append(val)
|
||||
|
||||
# If there are no examples left, return an empty dictionary
|
||||
if not res:
|
||||
@@ -464,30 +484,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
turns = []
|
||||
optional_keys = [
|
||||
"tool_calls", # tool that 'assistant' calls
|
||||
"name", # name of tool given by 'tool'
|
||||
"tool_call_id", # mistral/mixtral requires this
|
||||
]
|
||||
for message in prompt[self.messages]:
|
||||
for message in prompt[self.prompter.field_messages]:
|
||||
transformed_message = self.transform_message(message)
|
||||
|
||||
turn = {
|
||||
"role": self.prompter.roles[message[self.prompter.message_field_role]],
|
||||
**transformed_message,
|
||||
"training": message.get(self.prompter.message_field_training),
|
||||
"training_detail": message.get(
|
||||
self.prompter.message_field_training_detail
|
||||
),
|
||||
}
|
||||
|
||||
# do not add content if None as it may conflict with some templates due to tools
|
||||
content = message.get(self.prompter.message_field_content, None)
|
||||
if content is not None:
|
||||
turn["content"] = content
|
||||
|
||||
for key in optional_keys:
|
||||
value = message.get(key, None)
|
||||
if value is not None:
|
||||
turn[key] = value
|
||||
|
||||
turns.append(turn)
|
||||
|
||||
if self.prompter.drop_system_message and turns[0]["role"] == "system":
|
||||
@@ -495,6 +502,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
return turns
|
||||
|
||||
def transform_message(self, message):
|
||||
# Build the initial transformed message from the mappings
|
||||
transformed_message = {}
|
||||
for key, value in self.prompter.message_property_mappings.items():
|
||||
if message.get(value) is not None:
|
||||
transformed_message[key] = message[value]
|
||||
else:
|
||||
LOG.debug(
|
||||
f"Could not find value for property {value} in message: {message}"
|
||||
)
|
||||
|
||||
# Map the role if necessary
|
||||
if "role" in transformed_message:
|
||||
transformed_message["role"] = self.prompter.roles.get(
|
||||
transformed_message["role"], transformed_message["role"]
|
||||
)
|
||||
|
||||
# Determine which keys in the original message were not mapped
|
||||
mapped_values = set(self.prompter.message_property_mappings.values())
|
||||
remaining_keys = set(message) - mapped_values
|
||||
|
||||
# Keep only the properties defined in the chat template
|
||||
# and not already mapped
|
||||
for key in self.prompter.chat_template_msg_variables:
|
||||
if key in remaining_keys:
|
||||
val = message.get(key)
|
||||
if val is not None:
|
||||
transformed_message[key] = val
|
||||
|
||||
return transformed_message
|
||||
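A small, hypothetical example of how `transform_message` behaves with ShareGPT-style property mappings (values below are illustrative):

# Assume the prompter was built with:
#   message_property_mappings = {"role": "from", "content": "value"}
# and a role map that maps "human" -> "user" (as the defaults do).
message = {"from": "human", "value": "Hello!", "weight": 1}
# transform_message(message) -> {"role": "user", "content": "Hello!"}
# "weight" is carried over only if the chat template actually references it,
# i.e. it appears in prompter.chat_template_msg_variables.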
|
||||
def get_images(self, prompt):
|
||||
return prompt.get(self.images, None)
|
||||
|
||||
@@ -516,33 +554,46 @@ class StrategyLoader:
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
|
||||
self,
|
||||
tokenizer,
|
||||
cfg,
|
||||
ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
|
||||
processor=None,
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
ds_cfg = ds_cfg or {}
|
||||
if ds_cfg is None:
|
||||
dataset_config = {}
|
||||
elif isinstance(ds_cfg, BaseModel):
|
||||
dataset_config = ds_cfg.model_dump()
|
||||
else:
|
||||
dataset_config = ds_cfg
|
||||
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
|
||||
)
|
||||
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_template_string,
|
||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||
"message_field_training_detail": ds_cfg.get(
|
||||
"message_property_mappings": dataset_config.get(
|
||||
"message_property_mappings", {}
|
||||
),
|
||||
"message_field_training": dataset_config.get(
|
||||
"message_field_training", None
|
||||
),
|
||||
"message_field_training_detail": dataset_config.get(
|
||||
"message_field_training_detail",
|
||||
None,
|
||||
),
|
||||
"roles": ds_cfg.get("roles"),
|
||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||
"field_messages": dataset_config.get("field_messages", "messages"),
|
||||
"roles": dataset_config.get("roles"),
|
||||
"drop_system_message": dataset_config.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||
"max_length": cfg.sequence_len + 1,
|
||||
"processor": processor,
|
||||
}
|
||||
|
||||
strategy_params = self._get_strategy_params(cfg, ds_cfg)
|
||||
strategy_params = self._get_strategy_params(cfg, dataset_config)
|
||||
strategy_cls = self._get_strategy_cls()
|
||||
|
||||
strategy = strategy_cls(
|
||||
@@ -551,9 +602,6 @@ class StrategyLoader:
|
||||
**strategy_params,
|
||||
)
|
||||
|
||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
|
||||
return strategy
|
||||
|
||||
|
||||
|
||||
@@ -3,20 +3,28 @@ DPO prompt strategies for using tokenizer chat templates.
|
||||
"""
|
||||
|
||||
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
|
||||
|
||||
|
||||
def default(
|
||||
cfg, dataset_idx=0, **kwargs
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
ds_cfg = cfg["datasets"][dataset_idx]
|
||||
ds_cfg = handle_legacy_message_fields_logic(ds_cfg)
|
||||
|
||||
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||
cfg=cfg, ds_cfg=ds_cfg
|
||||
)
|
||||
field_messages = ds_cfg.get("field_messages", "messages")
|
||||
field_chosen = ds_cfg.get("field_chosen", "chosen")
|
||||
field_rejected = ds_cfg.get("field_rejected", "rejected")
|
||||
field_message_role = ds_cfg.get("message_field_role", "role")
|
||||
field_message_content = ds_cfg.get("message_field_content", "content")
|
||||
message_property_mappings = ds_cfg.get(
|
||||
"message_property_mappings",
|
||||
{
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
)
|
||||
role_map_inv = ds_cfg.get(
|
||||
"roles",
|
||||
{
|
||||
@@ -40,18 +48,18 @@ def default(
|
||||
messages = sample[field_messages]
|
||||
messages = [
|
||||
{
|
||||
"role": role_map[m[field_message_role]],
|
||||
"content": m[field_message_content],
|
||||
"role": role_map[m[message_property_mappings["role"]]],
|
||||
"content": m[message_property_mappings["content"]],
|
||||
}
|
||||
for m in messages
|
||||
]
|
||||
chosen = {
|
||||
"role": role_map[sample[field_chosen][field_message_role]],
|
||||
"content": sample[field_chosen][field_message_content],
|
||||
"role": role_map[sample[field_chosen][message_property_mappings["role"]]],
|
||||
"content": sample[field_chosen][message_property_mappings["content"]],
|
||||
}
|
||||
rejected = {
|
||||
"role": role_map[sample[field_rejected][field_message_role]],
|
||||
"content": sample[field_rejected][field_message_content],
|
||||
"role": role_map[sample[field_rejected][message_property_mappings["role"]]],
|
||||
"content": sample[field_rejected][message_property_mappings["content"]],
|
||||
}
|
||||
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
|
||||
|
||||
|
||||
src/axolotl/prompt_strategies/dpo/passthrough.py (new file, 14 lines)
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Passthrough (zero-processing) DPO prompt strategy.
|
||||
"""
|
||||
|
||||
|
||||
def default(
|
||||
cfg, dataset_idx=0, **kwargs
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(
|
||||
sample, tokenizer=None
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
src/axolotl/prompt_strategies/jinja_template_analyzer.py (new file, 318 lines)
@@ -0,0 +1,318 @@
|
||||
"""Module for inspect jinja templates for the variables they use"""
|
||||
from typing import Dict, Optional, Set, TypedDict, Union
|
||||
|
||||
from jinja2 import Environment, meta, nodes
|
||||
|
||||
|
||||
class JinjaTemplateAnalysis(TypedDict):
|
||||
"""
|
||||
Represents the detailed analysis of a Jinja template variable.
|
||||
|
||||
Attributes:
|
||||
accessed_properties (Set[str]): A set of properties accessed from the variable
|
||||
(e.g., `foo.bar` results in 'bar' being accessed for 'foo').
|
||||
accessed_indices (Set[Union[int, float]]): A set of indices accessed from the variable.
|
||||
is_iterated (bool): Indicates if the variable is used as an iteration source in a `for` loop.
|
||||
is_conditional (bool): Indicates if the variable is referenced within a conditional statement (e.g., an `if` block).
|
||||
iteration_source (Optional[str]): The name of the variable being iterated over, if applicable.
|
||||
iteration_target (Optional[Union[str, list[str]]]): The loop target(s) assigned in the iteration.
|
||||
"""
|
||||
|
||||
accessed_properties: Set[str]
|
||||
accessed_indices: Set[Union[int, float]]
|
||||
is_iterated: bool
|
||||
is_conditional: bool
|
||||
iteration_source: Optional[str]
|
||||
iteration_target: Optional[Union[str, list[str]]]
|
||||
|
||||
|
||||
class JinjaTemplateAnalyzer:
|
||||
"""
|
||||
Analyzes Jinja templates to extract information about variable usage,
|
||||
including accessed properties, iteration, and conditional references.
|
||||
|
||||
Attributes:
|
||||
env (jinja2.Environment): The Jinja2 environment used for parsing templates.
|
||||
property_access (Dict[str, Set[str]]): Tracks accessed properties for variables.
|
||||
iteration_targets (Dict[str, str]): Maps iteration target variables to their sources.
|
||||
|
||||
Methods:
|
||||
get_template_variables(template: str) -> Dict[str, Set[str]]:
|
||||
Parse a Jinja template and return a mapping of variables to their accessed properties.
|
||||
|
||||
analyze_template(template: str) -> Dict[str, JinjaTemplateAnalysis]:
|
||||
Perform a detailed analysis of the template, including variable usage,
|
||||
iteration, and conditional references.
|
||||
|
||||
Private Methods:
|
||||
_visit_node(node) -> None:
|
||||
Recursively visit AST nodes to detect attribute access and iteration targets.
|
||||
|
||||
_get_base_name(node) -> Optional[str]:
|
||||
Extract the base variable name from a node.
|
||||
|
||||
_get_target_name(node) -> Optional[Union[str, list[str]]]:
|
||||
Extract the target name(s) from a `For` node.
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
self.env: Environment = Environment(autoescape=True)
|
||||
self.property_access: Dict[str, Set[str]] = {}
|
||||
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
|
||||
self.index_access: Dict[str, Set[Union[int, float]]] = {}
|
||||
self.ast: nodes.Node = self.env.parse(template)
|
||||
self.template: str = template
|
||||
self.variable_assignments: Dict[str, str] = {}
|
||||
|
||||
def _visit_node(self, node) -> None:
|
||||
"""Recursively visit AST nodes to find attribute access."""
|
||||
# Handle attribute access (dot notation)
|
||||
if isinstance(node, nodes.Getattr):
|
||||
base_name = self._get_base_name(node.node)
|
||||
if base_name:
|
||||
self.property_access.setdefault(base_name, set()).add(node.attr)
|
||||
|
||||
# Handle dictionary access (subscript notation)
|
||||
elif isinstance(node, nodes.Getitem):
|
||||
base_name = self._get_base_name(node.node)
|
||||
if base_name and isinstance(node.arg, nodes.Const):
|
||||
value = node.arg.value
|
||||
if isinstance(value, (int, float)):
|
||||
self.index_access.setdefault(base_name, set()).add(value)
|
||||
else:
|
||||
self.property_access.setdefault(base_name, set()).add(value)
|
||||
|
||||
elif isinstance(node, nodes.Test) and node.name == "defined":
|
||||
base_name = self._get_base_name(node.node)
|
||||
if base_name:
|
||||
if isinstance(node.node, nodes.Getattr):
|
||||
self.property_access.setdefault(base_name, set()).add(
|
||||
node.node.attr
|
||||
)
|
||||
|
||||
# Handle loop variables
|
||||
elif isinstance(node, nodes.For):
|
||||
iter_name = self._get_base_name(node.iter)
|
||||
target_name = self._get_target_name(node.target)
|
||||
if iter_name and target_name:
|
||||
self.iteration_targets[target_name] = iter_name
|
||||
self.property_access.setdefault(iter_name, set())
|
||||
|
||||
elif isinstance(node, nodes.Assign):
|
||||
target_name = self._get_target_name(node.target)
|
||||
source_name = self._get_base_name(node.node)
|
||||
if target_name and source_name:
|
||||
self.variable_assignments[target_name] = source_name
|
||||
|
||||
elif isinstance(node, nodes.Filter):
|
||||
if node.name == "selectattr":
|
||||
target = self._get_base_name(node.node)
|
||||
if target:
|
||||
self.variable_assignments[f"filtered_{target}"] = target
|
||||
|
||||
for child in node.iter_child_nodes():
|
||||
self._visit_node(child)
|
||||
|
||||
def _get_target_name(self, node) -> Optional[str]:
|
||||
"""Get the target variable name from a For node.
|
||||
|
||||
Args:
|
||||
node: A Jinja AST node representing either a Name or Tuple node
|
||||
|
||||
Returns:
|
||||
- str: For simple variable targets (e.g., "item" in "for item in items")
|
||||
- None: If the node type is not recognized or is a tuple
|
||||
"""
|
||||
if isinstance(node, nodes.Name):
|
||||
return node.name
|
||||
return None
|
||||
|
||||
def _get_target_names(self, node) -> list[str]:
|
||||
"""Get all target variable names from a For node, including tuple unpacking.
|
||||
|
||||
Args:
|
||||
node: A Jinja AST node representing either a Name or Tuple node
|
||||
|
||||
Returns:
|
||||
List of target variable names
|
||||
"""
|
||||
if isinstance(node, nodes.Name):
|
||||
return [node.name]
|
||||
|
||||
if isinstance(node, nodes.Tuple):
|
||||
names = []
|
||||
for n in node.items:
|
||||
if isinstance(n, nodes.Name):
|
||||
names.append(n.name)
|
||||
return names
|
||||
|
||||
return []
|
||||
|
||||
def _get_base_name(self, node) -> Optional[str]:
|
||||
"""Get the base variable name from a node."""
|
||||
if isinstance(node, nodes.Name):
|
||||
return node.name
|
||||
|
||||
if isinstance(node, nodes.Getattr):
|
||||
return self._get_base_name(node.node)
|
||||
|
||||
if isinstance(node, nodes.Getitem):
|
||||
return self._get_base_name(node.node)
|
||||
|
||||
return None
|
||||
|
||||
def get_template_variables(self) -> Dict[str, Set[str]]:
|
||||
"""
|
||||
Parse a Jinja template and return both variables and their accessed properties.
|
||||
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Set[str]]: Dictionary mapping variable names to sets of accessed properties
|
||||
"""
|
||||
# Parse the template
|
||||
ast = self.env.parse(self.template)
|
||||
|
||||
# Get all undeclared variables
|
||||
variables = meta.find_undeclared_variables(ast)
|
||||
|
||||
# Reset property access tracking
|
||||
self.property_access = {}
|
||||
|
||||
# Visit all nodes to find property access
|
||||
self._visit_node(ast)
|
||||
|
||||
# Create result dictionary
|
||||
result: Dict[str, Set[str]] = {var: set() for var in variables}
|
||||
# Merge in any discovered sub-properties
|
||||
for var, props in self.property_access.items():
|
||||
if var not in result:
|
||||
result[var] = set()
|
||||
result[var].update(props)
|
||||
|
||||
return result
|
||||
|
||||
def analyze_template(self) -> Dict[str, JinjaTemplateAnalysis]:
|
||||
"""
|
||||
Provide a detailed analysis of template variables and their usage.
|
||||
"""
|
||||
variables = self.get_template_variables()
|
||||
self.iteration_targets = {}
|
||||
|
||||
analysis: Dict[str, JinjaTemplateAnalysis] = {
|
||||
var: JinjaTemplateAnalysis(
|
||||
accessed_properties=props,
|
||||
accessed_indices=set(),
|
||||
is_iterated=False,
|
||||
is_conditional=False,
|
||||
iteration_source=None,
|
||||
iteration_target=None,
|
||||
)
|
||||
for var, props in variables.items()
|
||||
}
|
||||
|
||||
for var, indices in self.index_access.items():
|
||||
if var in analysis:
|
||||
analysis[var]["accessed_indices"] = indices
|
||||
|
||||
def visit_node(node):
|
||||
if isinstance(node, nodes.If):
|
||||
|
||||
def find_test_vars(test_node):
|
||||
if isinstance(test_node, nodes.Name):
|
||||
if test_node.name in analysis:
|
||||
analysis[test_node.name]["is_conditional"] = True
|
||||
for child in test_node.iter_child_nodes():
|
||||
find_test_vars(child)
|
||||
|
||||
find_test_vars(node.test)
|
||||
|
||||
if isinstance(node, nodes.For):
|
||||
iter_target = self._get_base_name(node.iter)
|
||||
target_name = self._get_target_name(node.target)
|
||||
if iter_target in analysis:
|
||||
analysis[iter_target]["is_iterated"] = True
|
||||
if target_name:
|
||||
analysis[iter_target]["iteration_target"] = target_name
|
||||
if isinstance(target_name, str) and target_name not in analysis:
|
||||
analysis[target_name] = {
|
||||
"accessed_properties": set(),
|
||||
"is_iterated": False,
|
||||
"is_conditional": False,
|
||||
"iteration_source": iter_target,
|
||||
"iteration_target": None,
|
||||
}
|
||||
|
||||
for child in node.iter_child_nodes():
|
||||
visit_node(child)
|
||||
|
||||
visit_node(self.ast)
|
||||
return analysis
|
||||
|
||||
def get_downstream_properties(self, start_var: str) -> Dict[str, Set[str]]:
|
||||
"""
|
||||
Get all properties accessed on a variable and its downstream assignments.
|
||||
|
||||
Args:
|
||||
start_var: The starting variable to trace
|
||||
|
||||
Returns:
|
||||
Dict mapping variable names to their accessed properties
|
||||
"""
|
||||
visited = set()
|
||||
properties = {}
|
||||
|
||||
def trace_variable(var_name: str):
|
||||
if var_name in visited:
|
||||
return
|
||||
visited.add(var_name)
|
||||
|
||||
# Get direct properties
|
||||
if var_name in self.property_access:
|
||||
properties[var_name] = self.property_access[var_name]
|
||||
|
||||
# Get properties from iteration targets
|
||||
if var_name in self.iteration_targets:
|
||||
target = self.iteration_targets[var_name]
|
||||
if isinstance(target, str):
|
||||
trace_variable(target)
|
||||
elif isinstance(target, list):
|
||||
for t in target:
|
||||
trace_variable(t)
|
||||
|
||||
# Follow assignments
|
||||
for target, source in self.variable_assignments.items():
|
||||
if source == var_name:
|
||||
trace_variable(target)
|
||||
|
||||
# Check for array slicing
|
||||
analysis = self.analyze_template()
|
||||
if var_name in analysis:
|
||||
var_info = analysis[var_name]
|
||||
if var_info["accessed_indices"]:
|
||||
# If this variable is sliced, follow the resulting assignment
|
||||
slice_result = f"{var_name}_slice"
|
||||
if slice_result in self.property_access:
|
||||
trace_variable(slice_result)
|
||||
|
||||
trace_variable(start_var)
|
||||
return properties
|
||||
|
||||
def get_message_vars(self, field_messages: str = "messages") -> Set[str]:
|
||||
"""
|
||||
Get all properties accessed on messages and derived variables.
|
||||
"""
|
||||
all_properties = self.get_downstream_properties(field_messages)
|
||||
|
||||
# Combine all properties from all related variables
|
||||
combined_properties = set()
|
||||
for properties in all_properties.values():
|
||||
combined_properties.update(properties)
|
||||
|
||||
# Also include properties from the message iteration variable
|
||||
analysis = self.analyze_template()
|
||||
if "message" in analysis:
|
||||
combined_properties.update(analysis["message"]["accessed_properties"])
|
||||
|
||||
return combined_properties
|
||||
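A quick usage sketch of the analyzer on a toy template (set ordering is not guaranteed):

from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer

template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)
analyzer = JinjaTemplateAnalyzer(template)
print(analyzer.get_message_vars("messages"))  # expected: {'role', 'content'}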
@@ -51,8 +51,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
ds_cfg = ds_cfg or {}
|
||||
|
||||
field_messages = ds_cfg.get("field_messages")
|
||||
message_field_role = ds_cfg.get("message_field_role")
|
||||
message_field_content = ds_cfg.get("message_field_content")
|
||||
message_property_mappings = ds_cfg.get("message_property_mappings")
|
||||
message_field_role = (
|
||||
message_property_mappings.get("role") if message_property_mappings else None
|
||||
)
|
||||
message_field_content = (
|
||||
message_property_mappings.get("content") if message_property_mappings else None
|
||||
)
|
||||
message_field_training = ds_cfg.get("message_field_training")
|
||||
|
||||
builder_kwargs = {}
|
||||
|
||||
@@ -175,6 +175,7 @@ def train(
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
|
||||
pretrain_hooks(cfg, trainer)
|
||||
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
@@ -185,6 +186,7 @@ def train(
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
post_train_hooks(cfg, trainer)
|
||||
|
||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
@@ -15,7 +15,7 @@ _DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
|
||||
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
|
||||
|
||||
_CHAT_TEMPLATES = {
|
||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||
"alpaca": "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' and loop.first %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '### Instruction:\n' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '### Response:\n' + message['content'] + eos_token }}{% endif %}{% if not loop.last %}{{ '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\n\n### Response:\n' }}{% endif %}",
|
||||
"mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1...
|
||||
"mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large...
|
||||
"mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3-Tekken: Nemo, Pixtral...
|
||||
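To see what the reworked `alpaca` template renders to, a minimal sketch (assuming `get_chat_template("alpaca")` resolves to the dictionary entry above; rendering with plain Jinja2 rather than a tokenizer for brevity):

from jinja2 import Template

from axolotl.utils.chat_templates import get_chat_template

messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
]
rendered = Template(get_chat_template("alpaca")).render(
    messages=messages, bos_token="<s>", eos_token="</s>", add_generation_prompt=False
)
# rendered == "<s>### Instruction:\nHi\n\n### Response:\nHello</s>"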
@@ -38,7 +38,7 @@ def get_chat_template(
|
||||
user_choice: str,
|
||||
jinja_template: Optional[str] = None,
|
||||
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||
):
|
||||
) -> str:
|
||||
"""
|
||||
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
|
||||
|
||||
@@ -70,7 +70,7 @@ def get_chat_template(
|
||||
f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
|
||||
f"Please add a chat_template in tokenizer config"
|
||||
)
|
||||
return tokenizer.chat_template
|
||||
return tokenizer.chat_template # type: ignore
|
||||
|
||||
if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
|
||||
if not tokenizer:
|
||||
@@ -78,7 +78,7 @@ def get_chat_template(
|
||||
f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
|
||||
)
|
||||
if tokenizer.chat_template:
|
||||
return tokenizer.chat_template
|
||||
return tokenizer.chat_template # type: ignore
|
||||
|
||||
user_choice = user_choice[
|
||||
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
|
||||
|
||||
@@ -18,6 +18,7 @@ from axolotl.utils.config.models.input.v0_4_1 import (
|
||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||
)
|
||||
from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model_config
|
||||
|
||||
@@ -258,7 +259,7 @@ def validate_config(
|
||||
cfg: DictDefault,
|
||||
capabilities: Optional[dict] = None,
|
||||
env_capabilities: Optional[dict] = None,
|
||||
):
|
||||
) -> DictDefault:
|
||||
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
|
||||
AxolotlInputConfig = AxolotlInputConfigBase
|
||||
|
||||
@@ -268,6 +269,16 @@ def validate_config(
|
||||
AxolotlInputConfig, # pylint: disable=invalid-name
|
||||
) = merge_input_args()
|
||||
|
||||
# Convert datasets to proper format if needed
|
||||
if cfg.get("datasets"):
|
||||
for idx, ds_cfg in enumerate(cfg["datasets"]):
|
||||
if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset):
|
||||
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
||||
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
||||
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
||||
elif not isinstance(ds_cfg, SFTDataset):
|
||||
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
|
||||
|
||||
if capabilities or env_capabilities:
|
||||
if (capabilities and env_capabilities is None) or (
|
||||
env_capabilities and capabilities is None
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
"""
|
||||
Module for pydantic models for configuration
|
||||
"""
|
||||
|
||||
"""Module with Pydantic models for configuration."""
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
@@ -9,12 +6,13 @@ import os
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from annotated_types import MinLen
|
||||
from packaging import version
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
StringConstraints,
|
||||
conlist,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
@@ -24,6 +22,8 @@ from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
||||
|
||||
from .trl import TRLConfig
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
||||
|
||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||
@@ -33,6 +33,7 @@ class RLType(str, Enum):
|
||||
"""RL trainer type configuration subset"""
|
||||
|
||||
dpo = "dpo" # pylint: disable=invalid-name
|
||||
grpo = "grpo" # pylint: disable=invalid-name
|
||||
ipo = "ipo" # pylint: disable=invalid-name
|
||||
orpo = "orpo" # pylint: disable=invalid-name
|
||||
kto = "kto" # pylint: disable=invalid-name
|
||||
@@ -115,6 +116,9 @@ class RemappedParameters(BaseModel):
|
||||
overrides_of_model_config: Optional[Dict[str, Any]] = Field(
|
||||
default=None, alias="model_config"
|
||||
)
|
||||
overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(
|
||||
default=None, alias="model_kwargs"
|
||||
)
|
||||
type_of_model: Optional[str] = Field(default=None, alias="model_type")
|
||||
revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
|
||||
|
||||
@@ -163,6 +167,7 @@ class SFTDataset(BaseModel):
|
||||
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
||||
input_transform: Optional[str] = None
|
||||
shards: Optional[int] = None
|
||||
shards_idx: Optional[int] = None
|
||||
preprocess_shards: Optional[int] = None
|
||||
conversation: Optional[str] = None
|
||||
# Do not make this too strict or it will break the validator that chooses between dataset classes
|
||||
@@ -182,8 +187,13 @@ class SFTDataset(BaseModel):
|
||||
field_human: Optional[str] = None
|
||||
field_model: Optional[str] = None
|
||||
field_messages: Optional[str] = None
|
||||
message_field_role: Optional[str] = None
|
||||
message_field_content: Optional[str] = None
|
||||
message_field_role: Optional[
|
||||
str
|
||||
] = None # deprecated, use message_property_mappings
|
||||
message_field_content: Optional[
|
||||
str
|
||||
] = None # deprecated, use message_property_mappings
|
||||
message_property_mappings: Optional[Dict[str, str]] = None
|
||||
message_field_training: Optional[str] = None
|
||||
message_field_training_detail: Optional[str] = None
|
||||
logprobs_field: Optional[str] = None
|
||||
@@ -195,9 +205,18 @@ class SFTDataset(BaseModel):
|
||||
trust_remote_code: Optional[bool] = False
|
||||
revision: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def handle_legacy_message_fields(cls, data):
|
||||
"""Handle backwards compatibility between legacy message field mapping and new property mapping system."""
|
||||
return handle_legacy_message_fields_logic(data)
|
||||
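A hedged illustration of the legacy-to-new mapping this validator is expected to perform (the exact behavior lives in `handle_legacy_message_fields_logic`, which is not shown in this diff):

legacy = {"message_field_role": "from", "message_field_content": "value"}
# presumably normalized to:
normalized = {"message_property_mappings": {"role": "from", "content": "value"}}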
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_chat_template_config(cls, data):
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump()
|
||||
|
||||
# Set chat_template to tokenizer_default if not set
|
||||
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
||||
data["chat_template"] = ChatTemplate.tokenizer_default
|
||||
@@ -237,6 +256,7 @@ class DPODataset(BaseModel):
|
||||
type: Optional[Union[UserDefinedDPOType, str]] = None
|
||||
data_files: Optional[List[str]] = None
|
||||
revision: Optional[str] = None
|
||||
field_messages: Optional[str] = None
|
||||
|
||||
|
||||
class StepwiseSupervisedDataset(BaseModel):
|
||||
@@ -273,6 +293,9 @@ class KTODataset(BaseModel):
|
||||
revision: Optional[str] = None
|
||||
|
||||
|
||||
DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]
|
||||
|
||||
|
||||
class LoftQConfig(BaseModel):
|
||||
"""LoftQ configuration subset"""
|
||||
|
||||
@@ -319,6 +342,7 @@ class LoraConfig(BaseModel):
|
||||
peft_use_dora: Optional[bool] = None
|
||||
peft_use_rslora: Optional[bool] = None
|
||||
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
|
||||
peft_init_lora_weights: Optional[Union[bool, str]] = None
|
||||
|
||||
qlora_sharded_model_loading: Optional[bool] = Field(
|
||||
default=False,
|
||||
@@ -412,6 +436,8 @@ class ReLoRAConfig(BaseModel):
|
||||
class ModelInputConfig(BaseModel):
|
||||
"""model to train on configuration subset"""
|
||||
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
base_model: str
|
||||
base_model_config: Optional[str] = None
|
||||
cls_model_config: Optional[str] = None
|
||||
@@ -426,8 +452,6 @@ class ModelInputConfig(BaseModel):
|
||||
)
|
||||
trust_remote_code: Optional[bool] = None
|
||||
|
||||
model_kwargs: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator("trust_remote_code")
|
||||
@classmethod
|
||||
def hint_trust_remote_code(cls, trust_remote_code):
|
||||
@@ -480,7 +504,7 @@ class HyperparametersConfig(BaseModel):
|
||||
"adopt_adamw",
|
||||
],
|
||||
]
|
||||
] = OptimizerNames.ADAMW_HF.value
|
||||
] = OptimizerNames.ADAMW_HF
|
||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
||||
@@ -492,7 +516,9 @@ class HyperparametersConfig(BaseModel):
|
||||
},
|
||||
)
|
||||
torchdistx_path: Optional[str] = None
|
||||
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
|
||||
lr_scheduler: Optional[
|
||||
Union[SchedulerType, Literal["one_cycle"]]
|
||||
] = SchedulerType.COSINE
|
||||
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
||||
lr_quadratic_warmup: Optional[bool] = None
|
||||
cosine_min_lr_ratio: Optional[float] = None
|
||||
@@ -616,19 +642,19 @@ class RayConfig(BaseModel):
|
||||
use_ray: bool = Field(default=False)
|
||||
ray_run_name: Optional[str] = Field(
|
||||
default=None,
|
||||
metadata={
|
||||
json_schema_extra={
|
||||
"help": "The training results will be saved at `saves/ray_run_name`."
|
||||
},
|
||||
)
|
||||
ray_num_workers: int = Field(
|
||||
default=1,
|
||||
metadata={
|
||||
json_schema_extra={
|
||||
"help": "The number of workers for Ray training. Default is 1 worker."
|
||||
},
|
||||
)
|
||||
resources_per_worker: dict = Field(
|
||||
default_factory=lambda: {"GPU": 1},
|
||||
metadata={
|
||||
json_schema_extra={
|
||||
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
|
||||
},
|
||||
)
|
||||
@@ -653,35 +679,49 @@ class AxolotlInputConfig(
|
||||
):
|
||||
"""wrapper of all config options"""
|
||||
|
||||
class Config:
|
||||
"""Config for alias"""
|
||||
|
||||
populate_by_name = True
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
strict: Optional[bool] = Field(default=False)
|
||||
resume_from_checkpoint: Optional[str] = None
|
||||
auto_resume_from_checkpoints: Optional[bool] = None
|
||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||
mean_resizing_embeddings: Optional[bool] = False
|
||||
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||
shrink_embeddings: Optional[bool] = None
|
||||
|
||||
rl: Optional[RLType] = None
|
||||
trl: Optional[TRLConfig] = Field(
|
||||
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
|
||||
)
|
||||
reward_model: Optional[bool] = None
|
||||
process_reward_model: Optional[bool] = None
|
||||
num_labels: Optional[int] = None
|
||||
dpo_use_weighting: Optional[
|
||||
bool
|
||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||
dpo_use_logits_to_keep: Optional[bool] = None
|
||||
|
||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
||||
datasets: Optional[
|
||||
Annotated[
|
||||
list[Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]],
|
||||
MinLen(1),
|
||||
]
|
||||
] = None
|
||||
|
||||
test_datasets: Optional[
|
||||
Annotated[
|
||||
list[Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]],
|
||||
MinLen(1),
|
||||
]
|
||||
] = None
|
||||
shuffle_merged_datasets: Optional[bool] = True
|
||||
dataset_prepared_path: Optional[str] = None
|
||||
dataset_shard_num: Optional[int] = None
|
||||
dataset_shard_idx: Optional[int] = None
|
||||
skip_prepare_dataset: Optional[bool] = False
|
||||
|
||||
pretraining_dataset: Optional[ # type: ignore
|
||||
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
|
||||
pretraining_dataset: Optional[
|
||||
Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]
|
||||
] = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||
@@ -799,6 +839,10 @@ class AxolotlInputConfig(
|
||||
unsloth_rms_norm: Optional[bool] = None
|
||||
unsloth_rope: Optional[bool] = None
|
||||
|
||||
lora_mlp_kernel: Optional[bool] = None
|
||||
lora_qkv_kernel: Optional[bool] = None
|
||||
lora_o_kernel: Optional[bool] = None
|
||||
|
||||
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
|
||||
fsdp: Optional[List[str]] = None
|
||||
fsdp_config: Optional[Dict[str, Any]] = None
|
||||
@@ -821,7 +865,7 @@ class AxolotlInputConfig(
|
||||
warmup_steps: Optional[int] = None
|
||||
warmup_ratio: Optional[float] = None
|
||||
eval_steps: Optional[Union[int, float]] = None
|
||||
evals_per_epoch: Optional[Union[int]] = None
|
||||
evals_per_epoch: Optional[int] = None
|
||||
eval_strategy: Optional[str] = None
|
||||
save_steps: Optional[Union[int, float]] = None
|
||||
saves_per_epoch: Optional[int] = None
|
||||
@@ -833,6 +877,7 @@ class AxolotlInputConfig(
|
||||
save_only_model: Optional[bool] = False
|
||||
use_tensorboard: Optional[bool] = None
|
||||
profiler_steps: Optional[int] = None
|
||||
include_tokens_per_second: Optional[bool] = None
|
||||
|
||||
neftune_noise_alpha: Optional[float] = None
|
||||
|
||||
@@ -882,10 +927,15 @@ class AxolotlInputConfig(
|
||||
@classmethod
|
||||
def deprecate_sharegpt_datasets(cls, datasets):
|
||||
for _, ds_cfg in enumerate(datasets):
|
||||
if not ds_cfg.get("type"):
|
||||
# Handle both dict and pydantic model cases
|
||||
ds_type = (
|
||||
ds_cfg.get("type")
|
||||
if isinstance(ds_cfg, dict)
|
||||
else getattr(ds_cfg, "type", None)
|
||||
)
|
||||
if not ds_type:
|
||||
continue
|
||||
|
||||
ds_type = ds_cfg["type"]
|
||||
# skip if it's a dict (for custom user instruction prompt)
|
||||
if isinstance(ds_type, dict):
|
||||
continue
|
||||
@@ -897,6 +947,14 @@ class AxolotlInputConfig(
|
||||
|
||||
return datasets
|
||||
|
||||
@field_serializer("datasets")
|
||||
def datasets_serializer(
|
||||
self, ds_configs: Optional[List[DatasetConfig]]
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
if ds_configs:
|
||||
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
|
||||
return None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_batch_size_fields(cls, data):
|
||||
@@ -1522,12 +1580,42 @@ class AxolotlInputConfig(
|
||||
or data.get("unsloth_lora_qkv")
|
||||
or data.get("unsloth_lora_o")
|
||||
):
|
||||
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
|
||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
||||
raise ValueError(
|
||||
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_8bit(cls, data):
|
||||
if (
|
||||
data.get("lora_mlp_kernel")
|
||||
or data.get("lora_qkv_kernel")
|
||||
or data.get("lora_o_kernel")
|
||||
):
|
||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_axolotl_unsloth(cls, data):
|
||||
is_lora_kernel = any(
|
||||
data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
|
||||
)
|
||||
is_unsloth_lora = any(
|
||||
data.get(k)
|
||||
for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
||||
)
|
||||
if is_lora_kernel and is_unsloth_lora:
|
||||
raise ValueError(
|
||||
"both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_torch_compile_deepspeed(cls, data):
|
||||
@@ -1660,6 +1748,29 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_multigpu_lora_kernels(cls, data):
|
||||
if (
|
||||
data.get("lora_mlp_kernel")
|
||||
or data.get("lora_qkv_kernel")
|
||||
or data.get("lora_o_kernel")
|
||||
):
|
||||
capabilities = data.get("capabilities")
|
||||
is_fsdp = data.get("fsdp") is not None
|
||||
is_deepspeed = data.get("deepspeed") is not None
|
||||
|
||||
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
||||
if is_fsdp:
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
|
||||
)
|
||||
if is_deepspeed:
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with DeepSpeed."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_adopt_torch_version(cls, data):
|
||||
@@ -1696,3 +1807,77 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
else:
|
||||
data["torch_compile"] = False
|
||||
return data
|
||||
|
||||
|
||||
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
||||
"""
|
||||
Handle backwards compatibility between legacy message field mapping and new property mapping system.
|
||||
|
||||
Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
|
||||
- message_field_role: Mapped to the role field
|
||||
- message_field_content: Mapped to the content field
|
||||
|
||||
The new system uses message_property_mappings to support arbitrary field mappings:
|
||||
message_property_mappings:
|
||||
role: source_role_field
|
||||
content: source_content_field
|
||||
additional_field: source_field
|
||||
|
||||
Args:
|
||||
data: Dictionary containing configuration data
|
||||
|
||||
Returns:
|
||||
Updated dictionary with message field mappings consolidated
|
||||
|
||||
Raises:
|
||||
ValueError: If there are conflicts between legacy and new mappings
|
||||
"""
|
||||
data = data.copy() # Create a copy to avoid modifying the original
|
||||
|
||||
if data.get("message_property_mappings") is None:
|
||||
data["message_property_mappings"] = {}
|
||||
|
||||
# Check for conflicts and handle role
|
||||
if "message_field_role" in data:
|
||||
LOG.warning(
|
||||
"message_field_role is deprecated, use message_property_mappings instead. "
|
||||
f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
|
||||
)
|
||||
if (
|
||||
"role" in data["message_property_mappings"]
|
||||
and data["message_property_mappings"]["role"] != data["message_field_role"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
|
||||
f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
|
||||
)
|
||||
data["message_property_mappings"]["role"] = data["message_field_role"] or "role"
|
||||
|
||||
del data["message_field_role"]
|
||||
elif "role" not in data["message_property_mappings"]:
|
||||
data["message_property_mappings"]["role"] = "role"
|
||||
|
||||
# Check for conflicts and handle content
|
||||
if "message_field_content" in data:
|
||||
LOG.warning(
|
||||
"message_field_content is deprecated, use message_property_mappings instead. "
|
||||
f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
|
||||
)
|
||||
if (
|
||||
"content" in data["message_property_mappings"]
|
||||
and data["message_property_mappings"]["content"]
|
||||
!= data["message_field_content"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
|
||||
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
|
||||
)
|
||||
data["message_property_mappings"]["content"] = (
|
||||
data["message_field_content"] or "content"
|
||||
)
|
||||
|
||||
del data["message_field_content"]
|
||||
elif "content" not in data["message_property_mappings"]:
|
||||
data["message_property_mappings"]["content"] = "content"
|
||||
|
||||
return data
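Illustration (not part of the diff): how the helper above migrates the deprecated keys, with handle_legacy_message_fields_logic from this module in scope:

legacy = {"message_field_role": "from", "message_field_content": "value"}
migrated = handle_legacy_message_fields_logic(legacy)
# migrated["message_property_mappings"] == {"role": "from", "content": "value"}
# the deprecated keys are removed and a deprecation warning is logged for each one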
|
||||
|
||||
35
src/axolotl/utils/config/models/input/v0_4_1/trl.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
GRPO specific configuration args
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TRLConfig(BaseModel):
|
||||
"""
|
||||
Input args for TRL.
|
||||
"""
|
||||
|
||||
beta: Optional[float] = None
|
||||
max_completion_length: Optional[int] = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Maximum length of the completion for RL training"
|
||||
},
|
||||
)
|
||||
|
||||
# GRPO specific args
|
||||
use_vllm: Optional[bool] = False
|
||||
vllm_device: Optional[str] = "auto"
|
||||
vllm_gpu_memory_utilization: Optional[float] = 0.9
|
||||
vllm_max_model_len: Optional[int] = None
|
||||
vllm_dtype: Optional[str] = "auto"
|
||||
|
||||
reward_funcs: Optional[List[str]] = None
|
||||
num_generations: Optional[int] = None
|
||||
log_completions: Optional[bool] = False
|
||||
|
||||
sync_ref_model: Optional[bool] = False
|
||||
ref_model_mixup_alpha: Optional[float] = 0.9
|
||||
ref_model_sync_steps: Optional[int] = 64
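Illustration (not part of the diff): exercising the new model directly; field names come from TRLConfig above, the values are only examples.

from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig

trl_cfg = TRLConfig(beta=0.04, max_completion_length=256, use_vllm=True, num_generations=4)
print(trl_cfg.model_dump(exclude_none=True))
# unset fields keep their defaults, e.g. vllm_gpu_memory_utilization=0.9, ref_model_sync_steps=64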
|
||||
@@ -4,15 +4,16 @@ import inspect
|
||||
import logging
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Union
|
||||
|
||||
import yaml
|
||||
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
|
||||
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.prompt_strategies.kto import load as load_kto
|
||||
from axolotl.prompt_strategies.orpo import load as load_orpo
|
||||
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
@@ -57,7 +58,7 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
|
||||
|
||||
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
||||
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||
sig = inspect.signature(ds_transform_fn)
|
||||
if "tokenizer" in sig.parameters:
|
||||
if not tokenizer:
|
||||
@@ -70,6 +71,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
||||
data_set = data_set.map(
|
||||
ds_transform_fn,
|
||||
desc="Mapping RL Dataset",
|
||||
**map_kwargs,
|
||||
)
|
||||
|
||||
return data_set
|
||||
@@ -112,29 +114,21 @@ def drop_long_rl_seq(
|
||||
|
||||
return (len_prompt + len_completion) <= sequence_len
|
||||
|
||||
if rl == "grpo":
|
||||
return True
|
||||
|
||||
raise ValueError("Unknown RL type")
|
||||
|
||||
|
||||
def load_prepare_preference_datasets(cfg):
|
||||
def load_split(dataset_cfgs, _cfg):
|
||||
split_datasets: List[Any] = []
|
||||
for i, ds_cfg in enumerate(dataset_cfgs):
|
||||
if ds_cfg["ds_type"] == "json":
|
||||
for data_file in ds_cfg["data_files"]:
|
||||
data_files = {ds_cfg["split"]: data_file}
|
||||
ds = load_dataset( # pylint: disable=invalid-name
|
||||
"json",
|
||||
data_files=data_files,
|
||||
split=ds_cfg["split"],
|
||||
)
|
||||
split_datasets.insert(i, ds)
|
||||
else:
|
||||
ds = load_dataset( # pylint: disable=invalid-name
|
||||
ds_cfg["path"],
|
||||
split=ds_cfg["split"],
|
||||
revision=ds_cfg.get("revision", None),
|
||||
)
|
||||
split_datasets.insert(i, ds)
|
||||
use_auth_token = _cfg.hf_use_auth_token
|
||||
for config_dataset in datasets_w_name_generator(dataset_cfgs):
|
||||
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
||||
config_dataset, use_auth_token, streaming=False
|
||||
)
|
||||
split_datasets.append(ds)
|
||||
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
@@ -150,36 +144,45 @@ def load_prepare_preference_datasets(cfg):
|
||||
else:
|
||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||
|
||||
map_kwargs = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
split_datasets[i] = map_dataset(
|
||||
cfg, data_set, ds_transform_fn, tokenizer
|
||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
elif _cfg.rl == "kto":
|
||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||
map_kwargs = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
split_datasets[i] = map_dataset(
|
||||
cfg, data_set, ds_transform_fn, tokenizer
|
||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
else:
|
||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||
# "prompt", "chosen" and "rejected" already preprocessed
|
||||
split_datasets[i] = data_set
|
||||
|
||||
drop_long = partial(
|
||||
drop_long_rl_seq,
|
||||
rl=_cfg.rl,
|
||||
tokenizer=tokenizer,
|
||||
sequence_len=cfg.sequence_len,
|
||||
)
|
||||
if not cfg.skip_prepare_dataset:
|
||||
drop_long = partial(
|
||||
drop_long_rl_seq,
|
||||
rl=_cfg.rl,
|
||||
tokenizer=tokenizer,
|
||||
sequence_len=cfg.sequence_len,
|
||||
)
|
||||
|
||||
prior_len = len(split_datasets[i])
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(split_datasets[i])
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
|
||||
prior_len = len(split_datasets[i])
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(split_datasets[i])
|
||||
if dropped:
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} long samples from dataset index {i}"
|
||||
)
|
||||
|
||||
combined_datasets = concatenate_datasets(split_datasets)
|
||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
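Illustration (not part of the diff): with the map_kwargs handling above, a DPO/KTO prompt strategy can return either a transform function or a (transform_fn, map_kwargs) tuple, and the extra kwargs are forwarded to datasets.map(). A hypothetical loader sketch (field names are made up):

def load(cfg, dataset_idx=0):
    def transform_fn(sample):
        return {
            "prompt": sample["question"],
            "chosen": sample["good_answer"],
            "rejected": sample["bad_answer"],
        }

    # remove_columns is a standard datasets.map() keyword argument
    return transform_fn, {"remove_columns": ["question", "good_answer", "bad_answer"]}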
|
||||
|
||||
@@ -43,7 +43,7 @@ from axolotl.prompters import (
|
||||
UnsupportedPrompter,
|
||||
)
|
||||
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||
from axolotl.utils.data.shared import load_dataset_w_config
|
||||
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
|
||||
from axolotl.utils.data.utils import (
|
||||
deduplicate_and_log_datasets,
|
||||
drop_long_seq_in_dataset,
|
||||
@@ -180,6 +180,7 @@ def load_tokenized_prepared_datasets(
|
||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
||||
tokenizer_name = cfg.tokenizer_config
|
||||
|
||||
ds_hash = str(
|
||||
md5(
|
||||
(
|
||||
@@ -263,30 +264,11 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
datasets = []
|
||||
|
||||
def for_d_in_datasets(dataset_configs):
|
||||
for dataset in dataset_configs:
|
||||
if dataset.name and isinstance(dataset.name, list):
|
||||
# load_dataset doesn't properly handle multiple named configurations
|
||||
# at the same time for a given dataset
|
||||
for name in dataset.name:
|
||||
yield DictDefault({**dataset, "name": name})
|
||||
elif dataset.preprocess_shards and not dataset.shards:
|
||||
for shard in range(dataset.preprocess_shards):
|
||||
yield DictDefault(
|
||||
{
|
||||
**dataset,
|
||||
"shards": dataset.preprocess_shards,
|
||||
"shards_idx": shard,
|
||||
}
|
||||
)
|
||||
else:
|
||||
yield dataset
|
||||
|
||||
streaming_ds = False
|
||||
if preprocess_iterable:
|
||||
streaming_ds = True
|
||||
# pylint: disable=invalid-name
|
||||
for config_dataset in for_d_in_datasets(cfg_datasets):
|
||||
for config_dataset in datasets_w_name_generator(cfg_datasets):
|
||||
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
||||
config_dataset, use_auth_token, streaming=streaming_ds
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
dataset loading shared utils
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -29,9 +30,43 @@ def get_ds_type(config_dataset: DictDefault):
|
||||
return ds_type
|
||||
|
||||
|
||||
def datasets_w_name_generator(dataset_configs: list[DictDefault]):
|
||||
"""
|
||||
Yields dataset configs handling multiple names or preprocess_shards
|
||||
|
||||
Args:
|
||||
dataset_configs: list of dataset configs (equivalent to cfg.datasets)
|
||||
"""
|
||||
for dataset in dataset_configs:
|
||||
if dataset.name and isinstance(dataset.name, list):
|
||||
# load_dataset doesn't properly handle multiple named configurations
|
||||
# at the same time for a given dataset
|
||||
for name in dataset.name:
|
||||
yield DictDefault({**dataset, "name": name})
|
||||
elif dataset.preprocess_shards and not dataset.shards:
|
||||
for shard in range(dataset.preprocess_shards):
|
||||
yield DictDefault(
|
||||
{
|
||||
**dataset,
|
||||
"shards": dataset.preprocess_shards,
|
||||
"shards_idx": shard,
|
||||
}
|
||||
)
|
||||
else:
|
||||
yield dataset
|
||||
|
||||
|
||||
def load_dataset_w_config(
|
||||
config_dataset, auth_token, streaming=False
|
||||
config_dataset: DictDefault, use_auth_token: bool, streaming=False
|
||||
) -> Union[Dataset, DatasetDict]:
|
||||
"""
|
||||
Load a dataset from a config
|
||||
|
||||
Args:
|
||||
config_dataset: single dataset config
|
||||
use_auth_token: whether to use HF auth token
|
||||
streaming: whether to stream the dataset
|
||||
"""
|
||||
# pylint: disable=invalid-name
|
||||
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
||||
ds_from_hub = False
|
||||
@@ -43,7 +78,7 @@ def load_dataset_w_config(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=True,
|
||||
token=auth_token,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=ds_trust_remote_code,
|
||||
)
|
||||
@@ -161,7 +196,7 @@ def load_dataset_w_config(
|
||||
name=config_dataset.name,
|
||||
streaming=streaming,
|
||||
data_files=config_dataset.data_files,
|
||||
token=auth_token,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
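Illustration (not part of the diff): what datasets_w_name_generator yields for a dataset config that lists several named configurations. A minimal sketch (the dataset path is a placeholder):

ds_cfg = DictDefault({"path": "org/some-dataset", "name": ["en", "de"], "type": "alpaca"})
expanded = list(datasets_w_name_generator([ds_cfg]))
# two configs are yielded, identical except for name="en" and name="de"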
|
||||
|
||||
@@ -172,10 +172,11 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
||||
)
|
||||
|
||||
try:
|
||||
min_input_len = np.min(get_dataset_lengths(dataset))
|
||||
LOG.debug(f"min_input_len: {min_input_len}")
|
||||
max_input_len = np.max(get_dataset_lengths(dataset))
|
||||
LOG.debug(f"max_input_len: {max_input_len}")
|
||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||
min_input_len = np.min(ds_lengths)
|
||||
LOG.info(f"min_input_len: {min_input_len}")
|
||||
max_input_len = np.max(ds_lengths)
|
||||
LOG.info(f"max_input_len: {max_input_len}")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -13,3 +13,26 @@ class DictDefault(Dict):
|
||||
|
||||
def __or__(self, other):
|
||||
return DictDefault(super().__ror__(other))
|
||||
|
||||
def __setitem__(self, name, value):
|
||||
# workaround for pickle/unpickle issues and __frozen not being available
|
||||
try:
|
||||
isFrozen = hasattr( # pylint: disable=invalid-name
|
||||
self, "__frozen"
|
||||
) and object.__getattribute__(self, "__frozen")
|
||||
except AttributeError:
|
||||
isFrozen = False # pylint: disable=invalid-name
|
||||
|
||||
if isFrozen and name not in super().keys():
|
||||
raise KeyError(name)
|
||||
super(Dict, self).__setitem__(name, value) # pylint: disable=bad-super-call
|
||||
try:
|
||||
p = object.__getattribute__(self, "__parent")
|
||||
key = object.__getattribute__(self, "__key")
|
||||
except AttributeError:
|
||||
p = None
|
||||
key = None
|
||||
if p is not None:
|
||||
p[key] = self
|
||||
object.__delattr__(self, "__parent")
|
||||
object.__delattr__(self, "__key")
|
||||
|
||||
75
src/axolotl/utils/lora.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# Copyright 2025 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
module to get the state dict of a merged lora model
|
||||
"""
|
||||
import torch
|
||||
from peft.tuners.tuners_utils import onload_layer
|
||||
from peft.utils import ModulesToSaveWrapper, _get_submodules
|
||||
|
||||
|
||||
def get_lora_merged_state_dict(
|
||||
model: torch.nn.Module,
|
||||
) -> dict:
|
||||
r"""
|
||||
Create and return a state_dict that has the LoRA deltas
|
||||
merged into the base model’s weights, without modifying `model` in place.
|
||||
|
||||
Arguments:
|
||||
model (torch.nn.Module): A model that has LoRA/PEFT adapters attached.
|
||||
|
||||
Returns:
|
||||
dict: A state_dict of the merged parameters.
|
||||
"""
|
||||
|
||||
base_model_prefix = "base_model.model."
|
||||
state_dict = {}
|
||||
key_list = [key for key, _ in model.named_modules() if model.prefix not in key]
|
||||
for key in key_list:
|
||||
try:
|
||||
_, target, _ = _get_submodules(model, key)
|
||||
except AttributeError:
|
||||
continue
|
||||
with onload_layer(target):
|
||||
weight_key = key.replace(base_model_prefix, "") + ".weight"
|
||||
bias_key = key.replace(base_model_prefix, "") + ".bias"
|
||||
if hasattr(target, "base_layer"):
|
||||
target.merge(safe_merge=True, adapter_names=None)
|
||||
# get the state_dict of target.base_layer
|
||||
layer_state_dict = target.base_layer.state_dict()
|
||||
state_dict[weight_key] = layer_state_dict["weight"]
|
||||
elif isinstance(target, ModulesToSaveWrapper):
|
||||
# save any additional trainable modules part of `modules_to_save`
|
||||
new_module = target.modules_to_save[target.active_adapter]
|
||||
if hasattr(new_module, "base_layer"):
|
||||
# check if the module is itself a tuner layer
|
||||
new_module.merge(safe_merge=True, adapter_names=None)
|
||||
layer_state_dict = new_module.state_dict()
|
||||
state_dict[weight_key] = layer_state_dict["weight"]
|
||||
elif hasattr(target, "weight"):
|
||||
if any(
|
||||
skip in key
|
||||
for skip in [
|
||||
".original_module",
|
||||
".modules_to_save",
|
||||
".base_layer",
|
||||
]
|
||||
):
|
||||
continue
|
||||
layer_state_dict = target.state_dict()
|
||||
state_dict[weight_key] = layer_state_dict["weight"]
|
||||
if hasattr(target, "bias") and "bias" in layer_state_dict.keys():
|
||||
state_dict[bias_key] = layer_state_dict["bias"]
|
||||
return state_dict
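Illustration (not part of the diff): a minimal usage sketch, where peft_model is assumed to be a PEFT model with LoRA adapters attached:

import torch
from axolotl.utils.lora import get_lora_merged_state_dict

merged_state_dict = get_lora_merged_state_dict(peft_model)
torch.save(merged_state_dict, "merged_state_dict.pt")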
|
||||
@@ -357,8 +357,8 @@ class ModelLoader:
|
||||
|
||||
# init model kwargs
|
||||
self.model_kwargs: Dict[str, Any] = {}
|
||||
if cfg.model_kwargs:
|
||||
for key, val in cfg.model_kwargs.items():
|
||||
if cfg.overrides_of_model_kwargs:
|
||||
for key, val in cfg.overrides_of_model_kwargs.items():
|
||||
self.model_kwargs[key] = val
|
||||
|
||||
# init model
|
||||
@@ -414,6 +414,7 @@ class ModelLoader:
|
||||
has_remote_code = "AutoModelForCausalLM" in auto_map_config
|
||||
else:
|
||||
has_remote_code = False
|
||||
|
||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
|
||||
has_remote_code = self.cfg.trust_remote_code
|
||||
@@ -425,10 +426,6 @@ class ModelLoader:
|
||||
|
||||
if self.cfg.is_llama_derived_model:
|
||||
self.patch_loss_llama()
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora()
|
||||
elif self.cfg.is_llama_derived_model:
|
||||
self.patch_llama_derived_model()
|
||||
|
||||
@@ -442,6 +439,11 @@ class ModelLoader:
|
||||
|
||||
patch_mistral_cross_entropy()
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora(self.cfg)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
if hasattr(self.model_config, "model_type"):
|
||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||
@@ -472,9 +474,7 @@ class ModelLoader:
|
||||
return importlib.util.find_spec("flash_attn") is not None
|
||||
|
||||
def patch_loss_llama(self) -> None:
|
||||
"""
|
||||
Patch loss functions
|
||||
"""
|
||||
"""Patch loss functions and other optimizations"""
|
||||
if self.has_flash_attn:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
patch_fa_llama_cross_entropy,
|
||||
@@ -494,15 +494,14 @@ class ModelLoader:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
||||
|
||||
patch_unsloth_layernorm()
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora()
|
||||
|
||||
def patch_llama_derived_model(self) -> None:
|
||||
"""
|
||||
Modify all llama derived models in one block
|
||||
"""
|
||||
"""Modify all llama derived models in one block"""
|
||||
self.patch_loss_llama()
|
||||
|
||||
if self.cfg.flash_attention:
|
||||
@@ -1013,7 +1012,8 @@ class ModelLoader:
|
||||
if hasattr(module, "weight"):
|
||||
module.to(dist_dtype)
|
||||
|
||||
def apply_lora_patch(self) -> None:
|
||||
# TODO: Deprecate this.
|
||||
def apply_unsloth_lora_patch(self) -> None:
|
||||
if self.cfg.unsloth_lora_mlp:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
|
||||
|
||||
@@ -1027,6 +1027,16 @@ class ModelLoader:
|
||||
|
||||
integrate_rope_embeddings()
|
||||
|
||||
def apply_lora_patch(self) -> None:
|
||||
if (
|
||||
self.cfg.lora_mlp_kernel
|
||||
or self.cfg.lora_qkv_kernel
|
||||
or self.cfg.lora_o_kernel
|
||||
):
|
||||
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
|
||||
|
||||
apply_lora_kernel_patches(self.model, self.cfg)
|
||||
|
||||
def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||
self.apply_patches()
|
||||
self.set_auto_model_loader()
|
||||
@@ -1053,9 +1063,12 @@ class ModelLoader:
|
||||
if self.cfg.resize_token_embeddings_to_32x
|
||||
else len(self.tokenizer)
|
||||
)
|
||||
if (
|
||||
hasattr(self.model, "get_input_embeddings")
|
||||
and self.model.get_input_embeddings().num_embeddings != embeddings_len
|
||||
if hasattr(self.model, "get_input_embeddings") and (
|
||||
self.model.get_input_embeddings().num_embeddings < embeddings_len
|
||||
or (
|
||||
self.model.get_input_embeddings().num_embeddings > embeddings_len
|
||||
and self.cfg.shrink_embeddings
|
||||
)
|
||||
):
|
||||
resize_kwargs = {}
|
||||
if self.cfg.mean_resizing_embeddings is not None:
|
||||
@@ -1168,6 +1181,7 @@ class ModelLoader:
|
||||
if self.cfg.adapter is not None:
|
||||
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
||||
|
||||
self.apply_unsloth_lora_patch()
|
||||
self.apply_lora_patch()
|
||||
|
||||
for _ in range(3):
|
||||
@@ -1307,8 +1321,11 @@ def load_lora(model, cfg, inference=False, config_only=False):
|
||||
if loftq_bits:
|
||||
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
|
||||
lora_config_kwargs["init_lora_weights"] = "loftq"
|
||||
if cfg.peft_init_lora_weights:
|
||||
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
|
||||
if cfg.peft_use_dora:
|
||||
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
|
||||
LOG.info("Initializing LoRA weights using dora. This might take longer.")
|
||||
if cfg.peft_use_rslora:
|
||||
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
|
||||
if cfg.peft_layer_replication:
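Illustration (not part of the diff): the new apply_lora_patch() is driven by three config flags; the validators earlier in this diff reject them together with 8-bit LoRA, with the unsloth_* equivalents, and with multi-GPU FSDP/DeepSpeed. A config fragment sketch, not a complete training config:

from axolotl.utils.dict import DictDefault

kernel_flags = DictDefault(
    {
        "adapter": "lora",
        "load_in_8bit": False,  # True would be rejected by check_lora_8bit
        "lora_mlp_kernel": True,
        "lora_qkv_kernel": True,
        "lora_o_kernel": True,
    }
)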
|
||||
|
||||
@@ -4,13 +4,17 @@ helper util to calculate dataset lengths
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_dataset_lengths(dataset):
|
||||
if "length" in dataset.data.column_names:
|
||||
lengths = np.array(dataset.data.column("length"))
|
||||
elif "position_ids" in dataset.data.column_names:
|
||||
position_ids = dataset.data.column("position_ids")
|
||||
def get_dataset_lengths(dataset, from_arrow=False):
|
||||
if "length" in dataset.column_names:
|
||||
lengths = np.array(dataset["length"])
|
||||
elif "position_ids" in dataset.column_names:
|
||||
position_ids = dataset["position_ids"]
|
||||
lengths = np.array([x[-1] + 1 for x in position_ids])
|
||||
else:
|
||||
input_ids = dataset.data.column("input_ids")
|
||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||
if from_arrow:
|
||||
input_ids = dataset.data.column("input_ids")
|
||||
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
|
||||
else:
|
||||
input_ids = dataset["input_ids"]
|
||||
lengths = np.array([len(seq) for seq in input_ids])
|
||||
return lengths
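Illustration (not part of the diff): the default path now reads python-backed columns, while from_arrow=True keeps the Arrow-column behaviour used by drop_long_seq_in_dataset above. A toy sketch, assuming get_dataset_lengths is in scope (the module path is not shown in this hunk):

from datasets import Dataset

toy = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5], [6]]})
print(get_dataset_lengths(toy))  # -> [3 2 1]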
|
||||
|
||||
@@ -396,8 +396,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
):
|
||||
total_num_tokens = np.sum(
|
||||
train_dataset.select_columns("input_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
||||
.to_pandas()["input_ids"]
|
||||
.apply(len)
|
||||
.values
|
||||
)
|
||||
LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True)
|
||||
@@ -576,7 +576,7 @@ def prepare_opinionated_env(cfg):
|
||||
def setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||
):
|
||||
if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
|
||||
if cfg.rl:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||
trainer_builder.model_ref = model[1]
|
||||
trainer_builder.peft_config = model[2]
|
||||
|
||||
@@ -9,7 +9,7 @@ from e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@@ -79,6 +79,7 @@ class TestKnowledgeDistillation:
|
||||
def test_llama_kd(self, temp_dir, kd_min_cfg):
|
||||
cfg = DictDefault(kd_min_cfg)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
@@ -89,6 +90,12 @@ class TestKnowledgeDistillation:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"load_in_8bit",
|
||||
@@ -109,6 +116,7 @@ class TestKnowledgeDistillation:
|
||||
| kd_min_cfg
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
@@ -119,3 +127,9 @@ class TestKnowledgeDistillation:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
|
||||
)
|
||||
|
||||
163
tests/e2e/integrations/test_kl_loss.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
sanity checks on kl loss and gradients
|
||||
"""
|
||||
import torch
|
||||
|
||||
# Import both implementations
|
||||
from axolotl.integrations.kd.topk_logprob.forward_kl import loss as eager_loss
|
||||
from axolotl.integrations.kd.topk_logprob.forward_kl_triton import loss as triton_loss
|
||||
|
||||
|
||||
def test_kl_loss_gradient():
|
||||
"""Test that the gradient of the Triton implementation matches the eager implementation."""
|
||||
|
||||
# Set the random seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create random inputs
|
||||
batch_size = 2
|
||||
seq_len = 3
|
||||
vocab_size = 100
|
||||
top_k = 5
|
||||
|
||||
# Generate random student logits
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
|
||||
)
|
||||
student_logits_triton = student_logits.detach().clone().requires_grad_(True)
|
||||
|
||||
# Generate random target token IDs, ensuring they're valid indices
|
||||
# pylint: disable=duplicate-code
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
|
||||
# Generate random target logprobs (before normalization)
|
||||
target_logprobs_raw = torch.randn(batch_size, seq_len, top_k, device="cuda")
|
||||
|
||||
# Normalize the target logprobs to ensure they form a valid distribution
|
||||
target_logprobs = torch.log_softmax(target_logprobs_raw, dim=-1)
|
||||
|
||||
# Create a random mask with some tokens masked out
|
||||
target_mask = torch.randint(
|
||||
0, 2, (batch_size, seq_len, top_k), device="cuda"
|
||||
).float()
|
||||
|
||||
# Additional parameters
|
||||
num_items_in_batch = batch_size * seq_len
|
||||
kd_temperature = 1.0
|
||||
top_k_before_softmax = 0 # Test both modes
|
||||
|
||||
# Compute the loss and gradients with eager implementation
|
||||
loss_eager = eager_loss(
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
num_items_in_batch,
|
||||
kd_temperature,
|
||||
top_k_before_softmax,
|
||||
)
|
||||
loss_eager.backward()
|
||||
grad_eager = student_logits.grad.clone()
|
||||
|
||||
# Reset gradients
|
||||
student_logits.grad.zero_()
|
||||
|
||||
# Compute the loss and gradients with Triton implementation
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
num_items_in_batch,
|
||||
kd_temperature,
|
||||
top_k_before_softmax,
|
||||
)
|
||||
loss_triton.backward()
|
||||
grad_triton = student_logits_triton.grad.clone()
|
||||
|
||||
# Compare loss values
|
||||
print(f"Eager loss: {loss_eager.item()}")
|
||||
print(f"Triton loss: {loss_triton.item()}")
|
||||
loss_diff = abs(loss_eager.item() - loss_triton.item())
|
||||
print(f"Loss difference: {loss_diff}")
|
||||
assert loss_diff < 1e-5, "Loss values differ significantly!"
|
||||
|
||||
# Compare gradients
|
||||
grad_diff = (grad_eager - grad_triton).abs().max().item()
|
||||
print(f"Max gradient difference: {grad_diff}")
|
||||
|
||||
# Print some sample gradients
|
||||
sample_idx = (0, 0, 0) # (batch, seq, vocab)
|
||||
print(f"Sample eager gradient: {grad_eager[sample_idx].item()}")
|
||||
print(f"Sample triton gradient: {grad_triton[sample_idx].item()}")
|
||||
|
||||
# Compute relative difference for non-zero gradients
|
||||
mask = grad_eager.abs() > 1e-10
|
||||
if mask.sum() > 0:
|
||||
rel_diff = (
|
||||
(
|
||||
(grad_eager[mask] - grad_triton[mask]).abs()
|
||||
/ (grad_eager[mask].abs() + 1e-10)
|
||||
)
|
||||
.max()
|
||||
.item()
|
||||
)
|
||||
print(f"Max relative gradient difference: {rel_diff}")
|
||||
assert rel_diff < 1e-3, "Gradients differ significantly!"
|
||||
|
||||
# Also test top_k_before_softmax = 1 mode
|
||||
top_k_before_softmax = 1
|
||||
|
||||
# Reset the gradients
|
||||
student_logits = torch.randn(
|
||||
batch_size, seq_len, vocab_size, requires_grad=True, device="cuda"
|
||||
)
|
||||
student_logits_triton = student_logits.detach().clone().requires_grad_(True)
|
||||
|
||||
# Compute the loss and gradients with eager implementation
|
||||
loss_eager = eager_loss(
|
||||
student_logits,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
num_items_in_batch,
|
||||
kd_temperature,
|
||||
top_k_before_softmax,
|
||||
)
|
||||
loss_eager.backward()
|
||||
grad_eager = student_logits.grad.clone()
|
||||
|
||||
# Compute the loss and gradients with Triton implementation
|
||||
loss_triton = triton_loss(
|
||||
student_logits_triton,
|
||||
target_token_ids,
|
||||
target_logprobs,
|
||||
target_mask,
|
||||
num_items_in_batch,
|
||||
kd_temperature,
|
||||
top_k_before_softmax,
|
||||
)
|
||||
loss_triton.backward()
|
||||
grad_triton = student_logits_triton.grad.clone()
|
||||
|
||||
# Compare gradients for top_k_before_softmax = 1
|
||||
grad_diff = (grad_eager - grad_triton).abs().max().item()
|
||||
print("\nWith top_k_before_softmax=1:")
|
||||
print(f"Max gradient difference: {grad_diff}")
|
||||
|
||||
# Compute relative difference for non-zero gradients
|
||||
mask = grad_eager.abs() > 1e-10
|
||||
if mask.sum() > 0:
|
||||
rel_diff = (
|
||||
(
|
||||
(grad_eager[mask] - grad_triton[mask]).abs()
|
||||
/ (grad_eager[mask].abs() + 1e-10)
|
||||
)
|
||||
.max()
|
||||
.item()
|
||||
)
|
||||
assert (
|
||||
rel_diff < 1e-3
|
||||
), f"Gradients differ significantly, Max relative gradient difference: {rel_diff}"
|
||||
204
tests/e2e/integrations/test_logsumexp.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
sanity checks on logsumexp kernel validity
|
||||
"""
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from axolotl.integrations.kd.topk_logprob.logsumexp import logsumexp_kernel
|
||||
|
||||
|
||||
# PyTorch implementation of logsumexp for reference
|
||||
def torch_logsumexp(logits):
|
||||
"""PyTorch implementation of logsumexp over last dimension"""
|
||||
return torch.logsumexp(logits, dim=-1)
|
||||
|
||||
|
||||
# Wrapper function for Triton logsumexp kernel
|
||||
def triton_logsumexp(logits):
|
||||
"""Triton implementation of logsumexp over last dimension"""
|
||||
B, S, V = logits.shape # pylint: disable=invalid-name
|
||||
output = torch.empty((B, S), dtype=torch.float32, device=logits.device)
|
||||
|
||||
grid = (B * S,)
|
||||
logsumexp_kernel[grid](
|
||||
logits.contiguous(),
|
||||
output,
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
logits.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
min(1024, triton.next_power_of_2(V)),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TritonLogSumExp(torch.autograd.Function):
|
||||
"""
|
||||
Wrap a custom autograd function to use the Triton logsumexp for gradient testing
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, logits):
|
||||
B, S, V = logits.shape # pylint: disable=invalid-name
|
||||
output = torch.empty((B, S), dtype=torch.float32, device=logits.device)
|
||||
|
||||
# Save inputs for backward pass
|
||||
ctx.save_for_backward(logits)
|
||||
ctx.shape = logits.shape
|
||||
|
||||
grid = (B * S,)
|
||||
logsumexp_kernel[grid](
|
||||
logits.contiguous(),
|
||||
output,
|
||||
B,
|
||||
S,
|
||||
V,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
logits.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
min(1024, triton.next_power_of_2(V)),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
(logits,) = ctx.saved_tensors
|
||||
|
||||
# For logsumexp, the gradient is softmax(input) * grad_output
|
||||
# First compute the logsumexp
|
||||
lse = TritonLogSumExp.apply(logits)
|
||||
|
||||
# Compute softmax by exponentiating differences
|
||||
softmax_output = torch.exp(logits - lse.unsqueeze(-1))
|
||||
|
||||
# Compute gradient of logsumexp by multiplying the softmax output by the gradient
|
||||
grad_input = softmax_output * grad_output.unsqueeze(-1)
|
||||
|
||||
return grad_input
|
||||
|
||||
|
||||
def test_logsumexp_values():
|
||||
"""Test that the Triton logsumexp implementation matches PyTorch's"""
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Test with various input shapes
|
||||
test_shapes = [
|
||||
(2, 3, 10), # small vocab
|
||||
(4, 5, 100), # medium vocab
|
||||
(2, 2, 32000), # large vocab (typical for LLMs)
|
||||
]
|
||||
|
||||
for shape in test_shapes:
|
||||
# Create random input tensors
|
||||
logits = torch.randn(shape, device="cuda", requires_grad=False)
|
||||
|
||||
# Compute logsumexp using both implementations
|
||||
torch_result = torch_logsumexp(logits)
|
||||
triton_result = triton_logsumexp(logits)
|
||||
|
||||
# Compare results
|
||||
max_diff = (torch_result - triton_result).abs().max().item()
|
||||
print(f"Shape {shape}, Max diff: {max_diff}")
|
||||
|
||||
# Assert that the results are very close
|
||||
assert max_diff < 1e-5, f"Results differ for shape {shape}: max diff {max_diff}"
|
||||
|
||||
|
||||
def test_logsumexp_edge_cases():
|
||||
"""Test edge cases for numerical stability"""
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Case 1: Very large values that might cause overflow
|
||||
logits_large = torch.ones(2, 3, 100, device="cuda") * 1000
|
||||
|
||||
# Case 2: Very small values that might cause underflow
|
||||
logits_small = torch.ones(2, 3, 100, device="cuda") * -1000
|
||||
|
||||
# Case 3: Mix of large and small values
|
||||
logits_mixed = torch.zeros(2, 3, 100, device="cuda")
|
||||
logits_mixed[:, :, 0] = 1000 # One very large value
|
||||
|
||||
# Case 4: All identical values
|
||||
logits_identical = torch.ones(2, 3, 100, device="cuda") * 5
|
||||
|
||||
# Case 5: Extreme values with NaN check
|
||||
logits_extreme = torch.cat(
|
||||
[
|
||||
torch.full((1, 3, 50), 1e10, device="cuda"),
|
||||
torch.full((1, 3, 50), -1e10, device="cuda"),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
for i, logits in enumerate(
|
||||
[logits_large, logits_small, logits_mixed, logits_identical, logits_extreme]
|
||||
):
|
||||
# Compute logsumexp using both implementations
|
||||
torch_result = torch_logsumexp(logits)
|
||||
triton_result = triton_logsumexp(logits)
|
||||
|
||||
# Check for NaNs
|
||||
assert not torch.isnan(
|
||||
torch_result
|
||||
).any(), f"PyTorch produced NaNs for case {i+1}"
|
||||
assert not torch.isnan(
|
||||
triton_result
|
||||
).any(), f"Triton produced NaNs for case {i+1}"
|
||||
|
||||
# Compare results
|
||||
max_diff = (torch_result - triton_result).abs().max().item()
|
||||
print(f"Edge case {i+1}, Max diff: {max_diff}")
|
||||
|
||||
# For very extreme values, allow a bit more tolerance
|
||||
if i == 4: # extreme case
|
||||
assert max_diff < 1e-2, f"Results differ too much for edge case {i+1}"
|
||||
else:
|
||||
assert max_diff < 1e-5, f"Results differ too much for edge case {i+1}"
|
||||
|
||||
|
||||
def test_logsumexp_gradients():
|
||||
"""Test that the gradients of Triton logsumexp match PyTorch's"""
|
||||
# Set random seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create input tensors with gradients enabled
|
||||
shapes = [(2, 3, 10), (4, 5, 100)]
|
||||
|
||||
for shape in shapes:
|
||||
# Create two identical tensors for PyTorch and Triton
|
||||
logits_torch = torch.randn(shape, device="cuda", requires_grad=True)
|
||||
logits_triton = logits_torch.clone().detach().requires_grad_(True)
|
||||
|
||||
# Forward pass
|
||||
torch_output = torch_logsumexp(logits_torch)
|
||||
triton_output = TritonLogSumExp.apply(logits_triton)
|
||||
|
||||
# Compare forward pass values
|
||||
max_diff_forward = (torch_output - triton_output).abs().max().item()
|
||||
assert max_diff_forward < 1e-5, f"Forward pass values differ for shape {shape}"
|
||||
|
||||
# Create random gradient
|
||||
grad_output = torch.randn_like(torch_output)
|
||||
|
||||
# Backward pass
|
||||
torch_output.backward(grad_output)
|
||||
triton_output.backward(grad_output)
|
||||
|
||||
# Compare gradients
|
||||
max_diff_grad = (logits_torch.grad - logits_triton.grad).abs().max().item()
|
||||
print(f"Shape {shape}, Max gradient diff: {max_diff_grad}")
|
||||
|
||||
# Assert that gradients are very close
|
||||
assert (
|
||||
max_diff_grad < 1e-5
|
||||
), f"Gradients differ for shape {shape}: max diff {max_diff_grad}"
|
||||
76
tests/e2e/kernels/test_geglu.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Tests for GEGLU activation function Triton kernels."""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.kernels.geglu import geglu_backward, geglu_forward
|
||||
|
||||
|
||||
def test_geglu_forward_shape():
|
||||
"""Test that GEGLU forward pass preserves expected shapes."""
|
||||
batch, seq_len, hidden_dim = 2, 3, 64
|
||||
gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
|
||||
up = torch.randn(batch, seq_len, hidden_dim, device="cuda")
|
||||
|
||||
out = geglu_forward(gate, up)
|
||||
assert out.shape == (batch, seq_len, hidden_dim)
|
||||
assert out.dtype == gate.dtype
|
||||
assert out.device == gate.device
|
||||
|
||||
|
||||
def test_geglu_forward_values():
|
||||
"""Test GEGLU forward pass matches PyTorch reference implementation."""
|
||||
gate = torch.randn(2, 3, 64, device="cuda")
|
||||
up = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
# Custom implementation
|
||||
triton_out = geglu_forward(gate.clone(), up.clone())
|
||||
|
||||
# PyTorch reference
|
||||
torch_out = F.gelu(gate) * up
|
||||
|
||||
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
|
||||
|
||||
|
||||
def test_geglu_backward():
|
||||
"""Test GEGLU backward pass matches PyTorch autograd."""
|
||||
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
||||
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
||||
grad_output = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
# PyTorch reference - compute intermediates
|
||||
gelu_gate = F.gelu(gate)
|
||||
torch_out = gelu_gate * up
|
||||
torch_out.backward(grad_output)
|
||||
|
||||
# Custom backward pass
|
||||
gate_clone = gate.clone().detach()
|
||||
up_clone = up.clone().detach()
|
||||
grad_output_clone = grad_output.clone()
|
||||
|
||||
h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone)
|
||||
|
||||
# Compare outputs and gradients
|
||||
assert torch.allclose(h, torch_out, rtol=1e-3)
|
||||
assert torch.allclose(grad_gate, gate.grad, rtol=1e-3)
|
||||
assert torch.allclose(grad_up, up.grad, rtol=1e-3)
|
||||
|
||||
|
||||
def test_geglu_inplace_preservation():
|
||||
"""Test that GEGLU backward doesn't modify original tensors unexpectedly."""
|
||||
gate = torch.randn(2, 3, 64, device="cuda")
|
||||
up = torch.randn(2, 3, 64, device="cuda")
|
||||
grad_output = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
gate_copy = gate.clone()
|
||||
up_copy = up.clone()
|
||||
grad_copy = grad_output.clone()
|
||||
|
||||
geglu_backward(grad_output, gate, up)
|
||||
|
||||
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
|
||||
assert not torch.equal(up, up_copy), "Up should be modified in-place"
|
||||
assert not torch.equal(
|
||||
grad_output, grad_copy
|
||||
), "Grad output should be modified in-place"
|
||||
Some files were not shown because too many files have changed in this diff.