Compare commits

38 commits: `feat/linea`... `seq-parall`

| SHA1 |
|---|
| ee489d16bf |
| d88e071120 |
| a4170030ab |
| bf842730a5 |
| 1db6ad60a7 |
| 29b366b2e1 |
| b53a41372f |
| 02f45e94be |
| 954e192f38 |
| 8dfadc2b3c |
| 23a9fcb0a7 |
| c3d4f6e295 |
| 7fa690fac8 |
| 3c743c4bfb |
| 91bb95685a |
| b194e17c28 |
| 3aac3b1da9 |
| 3d8425fa91 |
| 97a2fa2781 |
| a98526ef78 |
| 2e57391bf8 |
| aa45fed451 |
| a09a5cfd1c |
| 40362d60e0 |
| ffae8d6a95 |
| fdbb1a207c |
| 30046315d9 |
| e37a4a536a |
| 44f64ab627 |
| 826f1b1494 |
| 526e5ee8b8 |
| fd8cb32547 |
| e48e2df4dd |
| b7616022ab |
| 1faf1a5c5a |
| 5bbad5ef93 |
| a971eb4ce6 |
| a620d481e2 |
12
.github/workflows/base.yml
vendored
@@ -22,12 +22,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.10"
|
||||
pytorch: 2.4.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
@@ -40,6 +34,12 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
2
.github/workflows/docs.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: '3.11'
|
||||
- name: install dependencies
|
||||
run: |
|
||||
python3 -m pip install jupyter
|
||||
|
||||
2
.github/workflows/lint.yml
vendored
@@ -19,6 +19,6 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: "3.11"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
|
||||
7
.github/workflows/main.yml
vendored
@@ -24,8 +24,13 @@ jobs:
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
16
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -4,6 +4,10 @@ on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'tests/e2e/multigpu/*.py'
|
||||
- 'requirements.txt'
|
||||
- 'setup.py'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/multi-gpu-e2e.yml'
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||
@@ -24,13 +28,21 @@ jobs:
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
axolotl_extras: # no vllm support for 2.4.1
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
# awaiting vllm#12721
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
@@ -42,7 +54,7 @@ jobs:
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
5
.github/workflows/nightlies.yml
vendored
@@ -22,6 +22,11 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
2
.github/workflows/pypi.yml
vendored
@@ -36,7 +36,7 @@ jobs:
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
|
||||
20
.github/workflows/tests-nightly.yml
vendored
@@ -12,7 +12,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: "3.11"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
env:
|
||||
@@ -25,13 +25,8 @@ jobs:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11"]
|
||||
pytorch_version: ["2.4.1", "2.5.1"]
|
||||
exclude:
|
||||
- python_version: "3.10"
|
||||
pytorch_version: "2.4.1"
|
||||
- python_version: "3.10"
|
||||
pytorch_version: "2.5.1"
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -112,13 +107,20 @@ jobs:
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
25
.github/workflows/tests.yml
vendored
@@ -35,7 +35,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: "3.11"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
env:
|
||||
@@ -48,13 +48,8 @@ jobs:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11"]
|
||||
pytorch_version: ["2.4.1", "2.5.1"]
|
||||
exclude:
|
||||
- python_version: "3.10"
|
||||
pytorch_version: "2.4.1"
|
||||
- python_version: "3.10"
|
||||
pytorch_version: "2.5.1"
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -127,7 +122,7 @@ jobs:
|
||||
max-parallel: 1
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.4.1", "2.5.1"]
|
||||
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -209,14 +204,14 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
@@ -251,13 +246,19 @@ jobs:
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
||||
@@ -51,7 +51,7 @@ Features:
|
||||
|
||||
**Requirements**:
|
||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python ≥3.10
|
||||
- Python 3.11
|
||||
- PyTorch ≥2.4.1
|
||||
|
||||
### Installation
|
||||
|
||||
@@ -4,8 +4,8 @@ set -e
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
||||
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
|
||||
@@ -37,15 +37,11 @@ temp_dir = tempfile.mkdtemp()
|
||||
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||
f.write(dockerfile_contents)
|
||||
|
||||
cicd_image = (
|
||||
Image.from_dockerfile(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
force_build=True,
|
||||
gpu="A10G",
|
||||
)
|
||||
.env(df_args)
|
||||
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
)
|
||||
cicd_image = Image.from_dockerfile(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
force_build=True,
|
||||
gpu="A10G",
|
||||
).env(df_args)
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
modal application to run axolotl gpu tests in Modal
|
||||
"""
|
||||
"""Modal app to run axolotl GPU tests"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
|
||||
@@ -46,6 +46,10 @@ overrides_of_model_config:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# optional overrides the base model loading from_pretrained
|
||||
overrides_of_model_kwargs:
|
||||
# use_cache: False
|
||||
|
||||
# optional overrides to the bnb 4bit quantization configuration
|
||||
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
||||
bnb_config_kwargs:
|
||||
@@ -87,7 +91,12 @@ datasets:
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
|
||||
data_files: # Optional[str] path to source data files
|
||||
shards: # Optional[int] number of shards to split data into
|
||||
|
||||
shards: # Optional[int] split dataset into N pieces (use with shards_idx)
|
||||
shards_idx: # Optional[int] = 0 the index of sharded dataset to use
|
||||
|
||||
preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)
|
||||
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
train_on_split: train # Optional[str] name of dataset split to load from
|
||||
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
||||
@@ -133,10 +142,19 @@ datasets:
|
||||
|
||||
# Key containing the messages (default: "messages")
|
||||
field_messages: messages
|
||||
# Key for role in each message (default: "role")
|
||||
message_field_role: role
|
||||
# Key for content in each message (default: "content")
|
||||
message_field_content: content
|
||||
|
||||
# Mapping of properties from the input dataset to the chat template.
|
||||
# (default: message_property_mappings={'role':'role', 'content':'content'})
|
||||
# If a property exists in the template but not in this mapping, the system will attempt
|
||||
# to load it directly from the message using the property name as the key.
|
||||
# Example: In the mapping below, 'from' is loaded from input dataset and used as 'role',
|
||||
# while 'value' is loaded and used as 'content' in the chat template.
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
# ...
|
||||
|
||||
message_property_mappings:
|
||||
|
||||
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
||||
roles:
|
||||
@@ -296,6 +314,13 @@ lora_modules_to_save:
|
||||
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# Apply custom LoRA autograd functions and activation function Triton kernels for
|
||||
# speed and memory savings
|
||||
# See: https://axolotl-ai-cloud.github.io/axolotl/docs/lora_optims.html
|
||||
lora_mlp_kernel: true
|
||||
lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
|
||||
# LoRA+ hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
|
||||
@@ -344,6 +369,9 @@ comet_mode: # Create a new experiment ("create") or log to an existing one ("get
|
||||
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
|
||||
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
|
||||
|
||||
# Tensorboard
|
||||
use_tensorboard: # Optional[bool]
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
@@ -378,6 +406,12 @@ save_total_limit: # Checkpoints saved at a time
|
||||
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
||||
max_steps:
|
||||
|
||||
# bool of whether to include tokens-per-second throughput in the training metrics. This iterates over the entire dataset once, so it takes some time.
|
||||
include_tokens_per_second: # Optional[bool]
|
||||
|
||||
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
|
||||
auto_find_batch_size: # Optional[bool]
|
||||
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
||||
|
||||
@@ -6,7 +6,7 @@ order: 3
|
||||
|
||||
## sharegpt
|
||||
|
||||
IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
|
||||
IMPORTANT: ShareGPT is deprecated! Please see the [chat_template](#chat_template) section below.
|
||||
|
||||
## pygmalion
|
||||
|
||||
@@ -22,7 +22,7 @@ Chat Template strategy uses a jinja2 template that converts a list of messages i
|
||||
{"conversations": [{"role": "...", "content": "..."}]}
|
||||
```
|
||||
|
||||
See `config.qmd` for full configs and supported templates.
|
||||
See [configs](../config.qmd) for full configs and supported templates.
|
||||
|
||||
### Migrating from sharegpt
|
||||
|
||||
@@ -42,8 +42,9 @@ datasets:
|
||||
type: chat_template
|
||||
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
# new (if setting a new chat_template like chatml, gemma, etc)
|
||||
chat_template: chatml
|
||||
@@ -52,8 +53,9 @@ datasets:
|
||||
type: chat_template
|
||||
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
```
|
||||
|
||||
We recommend checking the examples below for other use cases.
|
||||
@@ -138,8 +140,9 @@ datasets:
|
||||
type: chat_template
|
||||
chat_template: tokenizer_default
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
roles_to_train: []
|
||||
train_on_eos: turn
|
||||
message_field_training: train
|
||||
|
||||
@@ -1,14 +1,458 @@
|
||||
---
|
||||
title: Dataset Formats
|
||||
description: Supported dataset formats.
|
||||
listing:
|
||||
fields: [title, description]
|
||||
type: table
|
||||
sort-ui: false
|
||||
filter-ui: false
|
||||
max-description-length: 250
|
||||
description: Guide to Dataset Formats in Axolotl
|
||||
back-to-top-navigation: true
|
||||
toc: true
|
||||
toc-depth: 5
|
||||
---
|
||||
|
||||
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL format. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
|
||||
|
||||
Below are these various formats organized by task:
|
||||
Axolotl is a training framework that aims to make the process convenient yet flexible to users by simply passing a config yaml file.
|
||||
|
||||
As there are a lot of available options in Axolotl, this guide aims to simplify the experience of choosing the right format for your use case.
|
||||
|
||||
Axolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has its own dataset formats, which are described below.
|
||||
|
||||
## [Pre-training](pretraining.qmd)
|
||||
|
||||
When aiming to train on large corpora of text, pre-training is your go-to choice. Due to the size of these datasets, downloading the entire dataset before beginning training would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to load only a batch at a time into memory.
|
||||
|
||||
A sample format for a pre-training dataset is as follows:
|
||||
|
||||
```json
|
||||
{"text": "first row"}
|
||||
{"text": "second row"}
|
||||
...
|
||||
```
|
||||
|
||||
It is typically recommended to save your dataset as `.jsonl` due to its flexibility and simplicity.
|
||||
|
||||
Axolotl supports loading from a Hugging Face hub repo or from local files.
|
||||
|
||||
::: {.callout-important}
|
||||
For pre-training only, Axolotl splits texts that exceed the context length into multiple smaller prompts.
|
||||
:::
|
||||
|
||||
### Pre-training from Hugging Face hub datasets
|
||||
|
||||
As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:
|
||||
|
||||
```yaml
|
||||
pretraining_dataset: hf_org/name
|
||||
```
|
||||
|
||||
### Pre-training from local dataset files
|
||||
|
||||
Given a few corpus files `A.jsonl`, `B.jsonl`, and `C.jsonl`, your config would look like this:
|
||||
|
||||
```yaml
|
||||
pretraining_dataset:
|
||||
- path: json
|
||||
data_files:
|
||||
- A.jsonl
|
||||
- B.jsonl
|
||||
- C.jsonl
|
||||
```
|
||||
|
||||
While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet`, `arrow`, `SQL`, `WebDataset`) that are supported by [`load_dataset`](https://huggingface.co/docs/datasets/loading#local-and-remote-files)
|
||||
|
||||
### Pre-training without streaming
|
||||
|
||||
In the rare case that the dataset is small and can be loaded entirely into memory, another approach to pre-training is to use the `completion` format. This means the entire dataset is pre-tokenized up front instead of tokenized on demand while streaming.
|
||||
|
||||
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.
|
||||
|
||||
From Hugging Face:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: hf_org/name
|
||||
type: completion
|
||||
```
|
||||
|
||||
From local files (either example works):
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: A.jsonl
|
||||
type: completion
|
||||
|
||||
- path: json
|
||||
data_files: ["A.jsonl", "B.jsonl", "C.jsonl"]
|
||||
type: completion
|
||||
```
|
||||
|
||||
### Pre-training dataset configuration tips
|
||||
|
||||
#### Setting max_steps
|
||||
|
||||
When using streaming for large datasets, Axolotl does not know in advance how large the dataset is and does not know when to stop.
|
||||
|
||||
Therefore, it is necessary to set `max_steps: int` in your config for pre-training to run, so that Axolotl knows when to stop training.
|
||||
|
||||
One step is equal to `sequence_len * micro_batch_size * gradient_accumulation_steps * total_num_gpus` tokens.
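
For example (illustrative numbers, assuming 8 GPUs and a target of roughly 1B tokens):

```yaml
sequence_len: 2048
micro_batch_size: 2
gradient_accumulation_steps: 4
# with 8 GPUs: 2048 * 2 * 4 * 8 = 131,072 tokens per step
max_steps: 7630  # ~1e9 / 131,072
```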
|
||||
|
||||
#### Group_by_length
|
||||
|
||||
It is recommended to leave this off when streaming from the Hugging Face hub, as it would otherwise download the entire dataset, which can be very large.
|
||||
|
||||
## Supervised fine-tuning (SFT)
|
||||
|
||||
Supervised fine-tuning is the process of training models to respond to an instruction or chat input.
|
||||
|
||||
As there is a wide variety of dataset formats, Axolotl tries to support a majority of the formats found in public datasets.
|
||||
|
||||
Axolotl provides four approaches for loading datasets; however, it's easier to work backwards from the dataset you have available to figure out which approach to use.
|
||||
|
||||
A flow chart is as follows:
|
||||
|
||||
1. Do you already have the dataset tokenized? If yes, check [Pre-Tokenized Dataset](#pre-tokenized-dataset).
|
||||
|
||||
2. Do you want to format the dataset yourself and manually choose each section to mask? If yes, check [Template Free Dataset](#template-free-dataset)
|
||||
|
||||
3. Is your dataset in a "conversation" format, containing a `list[messages]`? If yes, check [Conversation Dataset](#conversation-dataset)
|
||||
|
||||
4. Is your dataset in an "instruct" format, containing `{ instruction, response }`? If yes, check [Instruction Dataset](#instruction-dataset)
|
||||
|
||||
If you went through the flow chart and did not find one that matches, we recommend preprocessing your dataset into one of the above or creating a thread on GitHub Discussions.
|
||||
|
||||
::: {.callout-tip}
|
||||
You can mix and match within each approach or across approaches to train a model on a variety of datasets.
|
||||
:::
|
||||
|
||||
### [Pre-Tokenized Dataset](tokenized.qmd)
|
||||
|
||||
We suggest this approach when you want to bring your own tokenized dataset.
|
||||
|
||||
Axolotl expects the dataset to have three keys:
|
||||
- `input_ids`: from tokenizing formatted prompt
|
||||
- `attention_mask`: for masking padding. If you don't add padding, it is simply `[1] * len(input_ids)`
|
||||
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.
|
||||
|
||||
::: {.callout-tip}
|
||||
Make sure to add BOS/EOS tokens to your prompt and mask it appropriately.
|
||||
:::
|
||||
|
||||
A config for this would look like:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: A.jsonl
|
||||
type:
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
`type: ` is empty!
|
||||
:::
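
For reference, here is a minimal sketch of producing one such row with a Hugging Face tokenizer. This script is illustrative and not part of Axolotl; the model name, prompt text, and masking split are placeholders:

```python
import json

from transformers import AutoTokenizer

# any causal LM tokenizer works similarly; this repo name is just an example
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")

prompt = "Question: What is 2+2?\nAnswer:"
response = " 4"

# BOS is added by the tokenizer; EOS is appended manually after the response
prompt_ids = tokenizer(prompt, add_special_tokens=True)["input_ids"]
response_ids = tokenizer(response, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id]

input_ids = prompt_ids + response_ids
row = {
    "input_ids": input_ids,
    "attention_mask": [1] * len(input_ids),              # no padding added here
    "labels": [-100] * len(prompt_ids) + response_ids,   # mask the prompt, train on the response
}

with open("A.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(row) + "\n")
```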
|
||||
|
||||
### [Template Free Dataset](template_free.qmd)
|
||||
|
||||
We recommend this approach when you want granular control over the prompt formatting, special tokens, and masking, whilst letting Axolotl handle the tokenization. This is very useful if your dataset has unique prompts that differ across samples and where one single general template wouldn't suffice.
|
||||
|
||||
In the example below, you can see that there is no fixed structure. At the same time, it's very flexible, as there are no constraints on how your prompt can look.
|
||||
|
||||
```json
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"label": true,
|
||||
"text": "<s>Hello\n"
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "hi there!. "
|
||||
},
|
||||
{
|
||||
"label": false,
|
||||
"text": "goodbye "
|
||||
},
|
||||
{
|
||||
"label": true,
|
||||
"text": "farewell</s>"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Each prompt must have a key called `segments`, which is a list of `{ text, label }`.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: A.jsonl
|
||||
type: input_output
|
||||
```
|
||||
|
||||
### [Conversation Dataset](conversation.qmd)
|
||||
|
||||
`conversation` messages are a list of messages which usually contain a `role` and `content` key.
|
||||
|
||||
::: {.callout-tip}
|
||||
Fun fact: Axolotl synonymously refers to "chat" messages as `conversation` messages due to how FastChat initially used this term to build a widely used [fastchat conversation](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) method for formatting chat messages prior to the creation of `chat_templates`.
|
||||
:::
|
||||
|
||||
#### What are `chat_templates`?
|
||||
|
||||
The current most popular and convenient method for inference is to use `chat_templates` for formatting prompts. Axolotl supports using `chat_templates` for training to ensure that the model performs in the same environment as in inference.
|
||||
|
||||
Here's a quick rundown on `chat_template`: A `chat_template` is a Jinja2 template which formats a list of messages into a prompt.
|
||||
|
||||
An example of a prompt formatted into a popular template called ChatML can be seen below:
|
||||
|
||||
Single prompt (pretty-printed):
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "How can I help you?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Can you add 3+5?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The answer is 8."
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The ChatML template is as follows:
|
||||
```jinja2
|
||||
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
|
||||
```
|
||||
|
||||
The above prompt formatted into this template will result in:
|
||||
|
||||
```
|
||||
<|im_start|>user
|
||||
Hi<|im_end|>
|
||||
<|im_start|>assistant
|
||||
How can I help you?<|im_end|>
|
||||
<|im_start|>user
|
||||
Can you add 3+5?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
The answer is 8.<|im_end|>
|
||||
```
|
||||
|
||||
By using delimiters (`<|im_start|>` and `<|im_end|>`), a prompt separates different speakers which helps the model identify which portion belongs to whom.
|
||||
|
||||
#### Common Conversation Dataset formats
|
||||
|
||||
Older conversation datasets with the following format are colloquially called `sharegpt` datasets.
|
||||
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
Newer conversation datasets usually follow the OpenAI format.
|
||||
|
||||
```json
|
||||
{"messages": [{"role": "...", "content": "..."}]}
|
||||
```
|
||||
|
||||
Axolotl supports both, as well as customization of arbitrary keys.
|
||||
|
||||
#### [Chat Template Usage](conversation.qmd#chat_template)
|
||||
|
||||
To properly use this method, it is important to identify three things:
|
||||
|
||||
1. Which `chat_template` would you use?
|
||||
|
||||
2. What are the keys in your dataset, and what are the possible roles? For example, in OpenAI format, the keys would be `messages`, `role`, and `content`, respectively, whereas the possible roles are `system`, `user`, and `assistant`.
|
||||
|
||||
3. What do you want to mask? For instance, only assistant messages, only last message, or nothing.
|
||||
|
||||
##### Choosing a `chat_template`
|
||||
|
||||
There are a lot of `chat_templates` out there. Axolotl supports the common ones: [supported chat templates](https://github.com/axolotl-ai-cloud/axolotl/blob/860609392184cf62a7e0ca676658b170e059ce6c/src/axolotl/utils/chat_templates.py#L17). For example, to use ChatML, it would be `chat_template: chatml`.
|
||||
|
||||
However, it is also possible to use the template already configured within the tokenizer by specifying `chat_template: tokenizer_default`. If you want a fallback (in case a tokenizer does not have one pre-configured), you can set `chat_template: tokenizer_default_fallback_chatml` to fall back to the ChatML template if no tokenizer template is found.
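
For example, to prefer the tokenizer's own template and fall back to ChatML:

```yaml
chat_template: tokenizer_default_fallback_chatml
```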
|
||||
|
||||
One last but powerful approach is to bring your own template. This can be set via:
|
||||
|
||||
```yaml
|
||||
chat_template_jinja: # your template
|
||||
```
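
As a minimal illustration, a bespoke template could be supplied like this (the `<|...|>` tags here are made up, not a known template):

```yaml
chat_template_jinja: |
  {% for message in messages %}{{ '<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' + '\n' }}{% endif %}
```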
|
||||
|
||||
##### Setting `chat_template` dataset keys
|
||||
|
||||
We currently default to OpenAI format for dataset keys, so if that's your current dataset format, there's nothing to do here.
|
||||
|
||||
If your dataset format is different, here are the keys you should check (with their defaults):
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
...
|
||||
field_messages: messages # this should point to the key containing the list of conversations
|
||||
message_property_mappings: # this is a mapping from keys in your dataset to keys in chat_template
|
||||
role: role
|
||||
content: content
|
||||
```
|
||||
|
||||
In some `chat_templates` (e.g. [Gemma](https://huggingface.co/google/gemma-2b-it/blob/main/tokenizer_config.json#L1507)), the roles are hardcoded to `user` and `assistant`. Consequently, you may find it necessary to map the roles in your dataset to these. We currently have some defaults that should work for common datasets, but if you get a `KeyError`, you will need to add a mapping for your roles. Here is an example of how it would look:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
...
|
||||
roles:
|
||||
assistant:
|
||||
- gpt
|
||||
- model
|
||||
user:
|
||||
- human
|
||||
```
|
||||
|
||||
In the example above, all `gpt` and `model` values are converted to `assistant`. All `human` values are converted to `user`.
|
||||
|
||||
##### Handling masking
|
||||
|
||||
The common use case for `chat_template` is chat messages; therefore, it is common to mask all non-assistant messages. Assistant messages refer to the bot messages that you want the model to learn from.
|
||||
|
||||
To train on all `assistant` messages, you would set the following configs.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
...
|
||||
roles_to_train: ["assistant"]
|
||||
train_on_eos: "turn"
|
||||
```
|
||||
|
||||
With `train_on_eos: turn`, the EOS token is masked for every turn that is not an assistant turn. The other options, `all` and `last`, choose which EOS tokens to train on.
|
||||
|
||||
If, say, you want to train on `assistant` and `narrator` roles, you can simply add `narrator` to the list of `roles_to_train`. You would also need to add it to the `roles` mapping above.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
...
|
||||
roles_to_train: ["assistant", "narrator"]
|
||||
roles:
|
||||
assistant:
|
||||
- gpt
|
||||
- model
|
||||
user:
|
||||
- human
|
||||
narrator: ["narrator"]
|
||||
```
|
||||
|
||||
#### Applying `chat_template`
|
||||
|
||||
Once all the above steps are completed, you can combine these configs to form a bespoke configuration for your custom dataset. The final step is to correctly set the EOS token in your config:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: A.jsonl
|
||||
type: chat_template
|
||||
|
||||
# step 1
|
||||
chat_template: chatml
|
||||
|
||||
# step 2
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
roles:
|
||||
assistant:
|
||||
- gpt
|
||||
- model
|
||||
- assistant
|
||||
user:
|
||||
- human
|
||||
- user
|
||||
|
||||
# step 3
|
||||
roles_to_train: ["assistant"]
|
||||
train_on_eos: "turn"
|
||||
|
||||
special_tokens:
|
||||
eos_token: <|im_end|>
|
||||
```
|
||||
|
||||
If this config were to be applied to the sample dataset above, the output would look as such (which can be retrieved via `axolotl preprocess config.yaml --debug`):
|
||||
|
||||
```
|
||||
<|im_start|>(-100, 128256) user(-100, 882)
|
||||
(-100, 198) Hi(-100, 13347) <|im_end|>(-100, 128257)
|
||||
(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)
|
||||
(-100, 198) How(4438, 4438) can(649, 649) I(358, 358) help(1520, 1520) you(499, 499) ?(30, 30) <|im_end|>(128257, 128257)
|
||||
(-100, 198) <|im_start|>(-100, 128256) user(-100, 882)
|
||||
(-100, 198) Can(-100, 6854) you(-100, 499) add(-100, 923) (-100, 220) 3(-100, 18) +(-100, 10) 5(-100, 20) ?(-100, 30) <|im_end|>(-100, 128257)
|
||||
(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)
|
||||
(-100, 198) The(791, 791) answer(4320, 4320) is(374, 374) (220, 220) 8(23, 23) .(13, 13) <|im_end|>(128257, 128257)
|
||||
(-100, 198)
|
||||
```
|
||||
|
||||
The first number refers to the label and the second to the `token_id`. For example, `-100` labels appear on the non-assistant portions, meaning that they are masked during training. For assistant portions, the label is the same as the `token_id`.
|
||||
|
||||
### [Instruction Dataset](inst_tune.qmd)
|
||||
|
||||
Instruction datasets are used to train instruction-following models and comprise a prompt, containing an instruction, and a single response. In contrast to chat datasets which may be multi-turn, instruct datasets are typically single-turn.
|
||||
|
||||
An example is of a common format called Alpaca:
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "output": "..."}
|
||||
```
|
||||
|
||||
Using those keys, a prompt is built from the row:
|
||||
```
|
||||
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction:
|
||||
{instruction}
|
||||
|
||||
### Input:
|
||||
{input}
|
||||
|
||||
### Response:
|
||||
{output}
|
||||
```
|
||||
|
||||
This can be configured as such:
|
||||
```yaml
|
||||
datasets:
|
||||
- path: A.jsonl
|
||||
type: alpaca
|
||||
```
|
||||
|
||||
Axolotl supports many kinds of instruction datasets. All of them can be found in the [instruction tuning docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/inst_tune.html) with their respective `type` and sample row format.
|
||||
|
||||
#### Custom Instruct Prompt Format
|
||||
|
||||
Due to the myriad possibilities of instruction formats, Axolotl allows customizing your own instruction format without having to dive into the code directly.
|
||||
|
||||
In the example below, a sample row of the following shape is formatted into the `mistral_v1` prompt format.
|
||||
```json
|
||||
{"input": "...", "output": "..."}
|
||||
```
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: repo
|
||||
type:
|
||||
system_prompt: ""
|
||||
|
||||
field_system:
|
||||
field_instruction: input
|
||||
field_input:
|
||||
field_output: output
|
||||
|
||||
# multi-line example with input
|
||||
format: |-
|
||||
[INST] {instruction} {input} [/INST]
|
||||
|
||||
# single-line example without input
|
||||
no_input_format: "[INST] {instruction} [/INST]"
|
||||
```
|
||||
|
||||
The config specifies that `field_instruction` is actually named `input`, and `field_input` is left empty as this sample has no separate input. Generally, `instruction` can be thought of as the question to the model, `input` as additional context, and `output` as the response. It is not necessary to have an `input` or a `system` field. In the end, the most important part is to understand what you want the prompt to look like and how to customize this for your use case.
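
For illustration, a hypothetical row `{"input": "What is 2+2?", "output": "4"}` has no mapped `field_input`, so `no_input_format` applies and the rendered training text would look roughly like this (exact whitespace handling may differ):

```
[INST] What is 2+2? [/INST]4
```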
|
||||
|
||||
## Reinforcement Learning from Human Feedback (RLHF)
|
||||
|
||||
As there are multiple RLHF methods, each with its own dataset requirements, please see the [RLHF datasets](../rlhf.qmd) documentation for more detail.
|
||||
|
||||
@@ -19,3 +19,11 @@ description: Frequently asked questions
|
||||
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
|
||||
|
||||
> A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
|
||||
|
||||
**Q: The code is stuck on saving preprocessed datasets.**
|
||||
|
||||
> A: This is usually an issue with the GPU. It can be resolved by setting the environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on RunPod, this is usually a pod issue; starting a new pod should take care of it.
|
||||
|
||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||
|
||||
> A: This means that the property mapping for the stated attribute does not exist when building the `chat_template` prompt. For example, for `no attribute 'content'`, please check that you have added the correct mapping for `content` under `message_property_mappings`.
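
For example, if your dataset stores messages as `{"from": ..., "value": ...}`, adding a mapping like the one below (illustrative) resolves the error:

```yaml
datasets:
  - path: ...
    type: chat_template
    message_property_mappings:
      role: from
      content: value
```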
|
||||
|
||||
128
docs/lora_optims.qmd
Normal file
@@ -0,0 +1,128 @@
|
||||
---
|
||||
title: "LoRA Optimizations"
|
||||
description: "Custom autograd functions and Triton kernels in Axolotl for optimized
|
||||
LoRA fine-tuning"
|
||||
---
|
||||
|
||||
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
|
||||
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
|
||||
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
|
||||
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
|
||||
to leverage operator fusion and tensor re-use in order to improve speed and reduce
|
||||
memory usage during the forward and backward passes of these calculations.
|
||||
|
||||
We currently support several common model architectures, including (but not limited to):
|
||||
|
||||
- `llama`
|
||||
- `mistral`
|
||||
- `qwen2`
|
||||
- `gemma`
|
||||
- `gemma2`
|
||||
|
||||
<details>
|
||||
|
||||
The set of models we support is currently limited by our attention patching strategy,
|
||||
which assumes (and replaces) specific code blocks for query / key / value and output
|
||||
projections:
|
||||
|
||||
```python
|
||||
ORIGINAL_QKV_CODE = """
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
ORIGINAL_O_CODE = """
|
||||
attn_output = self.o_proj(attn_output)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
```
|
||||
|
||||
Is replaced with:
|
||||
|
||||
```python
|
||||
PATCHED_QKV_CODE = """
|
||||
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
||||
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
PATCHED_O_CODE = """
|
||||
attn_output = self.apply_o(attn_output)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
```
|
||||
|
||||
Where `apply_qkv` and `apply_o` are defined in the `axolotl.kernels.lora` module.
|
||||
|
||||
We welcome testing of other model architectures and / or PRs to expand our patching
|
||||
logic to be compatible with more of them.
|
||||
|
||||
</details>
|
||||
|
||||
## Usage
|
||||
|
||||
These optimizations can be enabled in your Axolotl config YAML file. The
|
||||
`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
|
||||
`lora_o_kernel` enable the fused query-key-value projection and optimized output
|
||||
projection, respectively.
|
||||
|
||||
```yaml
|
||||
lora_mlp_kernel: true
|
||||
lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
```
|
||||
|
||||
## Requirements
|
||||
|
||||
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
||||
- Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
|
||||
- Targeted LoRA adapters cannot use Dropout
|
||||
- This may limit model expressivity / cause overfitting
|
||||
- Targeted LoRA adapters cannot have bias terms
|
||||
- This may limit model expressivity
|
||||
|
||||
Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
|
||||
be re-finetuned without these features in order to be useful.
|
||||
|
||||
## Implementation details
|
||||
|
||||
### Custom autograd functions
|
||||
|
||||
The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the
|
||||
LoRA and base weight computations together and provides a single, efficient backward
|
||||
pass for the entire MLP block.
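
For intuition, here is a drastically simplified sketch of the idea (this is not the actual code in `axolotl.kernels.lora`): a fused LoRA linear computes `y = x @ W.T + scale * (x @ A.T) @ B.T` in one forward and returns hand-derived gradients for `x`, `A`, and `B` in a single backward, with the frozen base weight receiving no gradient.

```python
import torch


class FusedLoRALinearSketch(torch.autograd.Function):
    """Illustrative only. Assumes 2-D x of shape (tokens, hidden)."""

    @staticmethod
    def forward(ctx, x, weight, lora_a, lora_b, scale):
        ctx.save_for_backward(x, weight, lora_a, lora_b)
        ctx.scale = scale
        return x @ weight.t() + scale * ((x @ lora_a.t()) @ lora_b.t())

    @staticmethod
    def backward(ctx, grad_out):
        x, weight, lora_a, lora_b = ctx.saved_tensors
        scale = ctx.scale
        grad_x = grad_out @ weight + scale * ((grad_out @ lora_b) @ lora_a)
        grad_a = scale * (grad_out @ lora_b).t() @ x      # dL/dA
        grad_b = scale * grad_out.t() @ (x @ lora_a.t())  # dL/dB
        # base weight is frozen in (Q)LoRA, so no gradient for it or for scale
        return grad_x, None, grad_a, grad_b, None
```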
|
||||
|
||||
For attention components, similar optimizations are provided through a function that
|
||||
handles the query, key, and value projections, and a function that handles the output
|
||||
projection. They are designed to work with the existing `transformers` attention
|
||||
implementation via some monkey-patching logic.
|
||||
|
||||
### Triton kernels
|
||||
|
||||
Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for
|
||||
improved speed and memory performance. These kernels handle both the forward and
|
||||
backward passes.
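
Likewise, a minimal elementwise SwiGLU forward kernel in Triton might look like the sketch below (simplified; not the kernel shipped in `axolotl.kernels`):

```python
import triton
import triton.language as tl


@triton.jit
def swiglu_fwd_kernel(gate_ptr, up_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    gate = tl.load(gate_ptr + offsets, mask=mask)
    up = tl.load(up_ptr + offsets, mask=mask)
    # SwiGLU: SiLU(gate) * up, fused into a single elementwise pass
    tl.store(out_ptr + offsets, gate * tl.sigmoid(gate) * up, mask=mask)
```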
|
||||
|
||||
### Integration
|
||||
|
||||
The custom autograd functions and Triton kernels are designed to work together. The
|
||||
autograd function manages the high-level computation flow and gradient tracking, while
|
||||
calling the Triton kernels for the activation function computation. During the backward
|
||||
pass, the kernel computes both the activation output and the required gradients, which
|
||||
the autograd function then uses to compute the final gradients for the entire
|
||||
computation path.
|
||||
|
||||
## Future Work
|
||||
|
||||
- Support for additional model architectures
|
||||
- Support for the FSDP setting
|
||||
- Support for dropout and bias
|
||||
- Additional operator fusions
|
||||
@@ -3,6 +3,18 @@ title: Multi Node
|
||||
description: How to use Axolotl on multiple machines
|
||||
---
|
||||
|
||||
Below are three ways to run multi-node training in Axolotl.
|
||||
|
||||
::: {.callout-important}
|
||||
Each machine needs a copy of Axolotl, we suggest using the same commit to ensure compatibility.
|
||||
|
||||
You will also need to have the same configuration file for your model on each machine.
|
||||
|
||||
Make sure the main machine is reachable by other machines.
|
||||
:::
|
||||
|
||||
# Accelerate
|
||||
|
||||
You will need to create a configuration for accelerate, either by running `accelerate config` and following the instructions, or by using one of the presets below:
|
||||
|
||||
~/.cache/huggingface/accelerate/default_config.yaml
|
||||
@@ -26,7 +38,7 @@ tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
Configure your model to use FSDP with for example:
|
||||
Configure your model to use FSDP in the Axolotl yaml. For example:
|
||||
```yaml
|
||||
fsdp:
|
||||
- full_shard
|
||||
@@ -37,12 +49,40 @@ fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
```
|
||||
|
||||
## Machine configuration
|
||||
|
||||
On each machine you need a copy of Axolotl, we suggest using the same commit to ensure compatibility.
|
||||
|
||||
You will also need to have the same configuration file for your model on each machine.
|
||||
|
||||
On the main machine only, make sure the port you set as `main_process_port` is open in TCP and reachable by other machines.
|
||||
|
||||
All you have to do now is launch with accelerate on each machine as you usually would; the processes will start once accelerate has been launched on every machine.
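
For example, with the same config on every machine:

```bash
accelerate launch -m axolotl.cli.train config.yaml
```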
|
||||
|
||||
# Ray Train
|
||||
|
||||
Please see the Ray Train docs [here](ray-integration.qmd).
|
||||
|
||||
# Torchrun
|
||||
|
||||
If you are using InfiniBand, we recommend torchrun to utilize the full bandwidth.
|
||||
|
||||
Set the following environment variables (change the buffer size / socket name depending on your system):
|
||||
|
||||
```bash
|
||||
export NCCL_IB_DISABLE=0
|
||||
export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
|
||||
export NCCL_BUFFSIZE=2097152
|
||||
```
|
||||
|
||||
Run the following on each node:
|
||||
|
||||
```bash
|
||||
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
|
||||
```
|
||||
|
||||
Please make sure to substitute the placeholder variables.
|
||||
|
||||
- `num_nodes`: Number of nodes (containing GPUs)
|
||||
- `gpu_per_node`: Number of gpus per node
|
||||
- `head_node_ip`: IP of the head node (make sure other machines can connect to this)
|
||||
- `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
|
||||
- `rdzv_id`: A unique job ID that is used by the job across nodes.
|
||||
|
||||
::: {.callout-note}
|
||||
You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
|
||||
:::
|
||||
|
||||
More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)
|
||||
|
||||
451
docs/rlhf.qmd
@@ -1,26 +1,39 @@
|
||||
---
|
||||
title: "RLHF (Beta)"
|
||||
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
||||
back-to-top-navigation: true
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
---
|
||||
|
||||
### Overview
|
||||
# Overview
|
||||
|
||||
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
|
||||
feedback. Various methods include, but are not limited to:
|
||||
|
||||
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
|
||||
- Direct Preference Optimization (DPO)
|
||||
- Identity Preference Optimization (IPO)
|
||||
- [Direct Preference Optimization (DPO)](#dpo)
|
||||
- [Identity Preference Optimization (IPO)](#ipo)
|
||||
- [Kahneman-Tversky Optimization (KTO)](#kto)
|
||||
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
|
||||
|
||||
|
||||
### RLHF using Axolotl
|
||||
# RLHF using Axolotl
|
||||
|
||||
>[!IMPORTANT]
|
||||
>This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
|
||||
::: {.callout-important}
|
||||
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
|
||||
:::
|
||||
|
||||
The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML
|
||||
We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of the various RL training methods, which we wrap and expose in axolotl. Each method has its own supported ways of loading datasets and prompt formats.
|
||||
|
||||
::: {.callout-tip}
|
||||
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type: ` can be retrieved from `{method}.{function_name}`.
|
||||
:::
|
||||
|
||||
## DPO
|
||||
|
||||
Example config:
|
||||
|
||||
#### DPO
|
||||
```yaml
|
||||
rl: dpo
|
||||
datasets:
|
||||
@@ -32,12 +45,265 @@ datasets:
|
||||
type: chatml
|
||||
```
|
||||
|
||||
#### IPO
|
||||
DPO supports the following types, each with its corresponding dataset format:
|
||||
|
||||
### chatml.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"instruction": "...",
|
||||
"chosen_response": "...",
|
||||
"rejected_response": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
"chosen": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.icr
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"input": "...",
|
||||
"chosen": "...",
|
||||
"rejected": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.intel
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"question": "...",
|
||||
"chosen": "...",
|
||||
"rejected": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"chosen": "...",
|
||||
"rejected": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"chosen": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"instruction": "...",
|
||||
"chosen_response": "...",
|
||||
"rejected_response": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
"chosen": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.icr
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"input": "...",
|
||||
"chosen": "...",
|
||||
"rejected": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.intel
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"question": "...",
|
||||
"chosen": "...",
|
||||
"rejected": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"chosen": "...",
|
||||
"rejected": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"chosen": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### zephyr.nectar
|
||||
|
||||
```json
|
||||
{
|
||||
"prompt": "...",
|
||||
"answers": [
|
||||
{
|
||||
"answer": "...",
|
||||
"rank": 1
|
||||
},
|
||||
{
|
||||
"answer": "...",
|
||||
"rank": 2
|
||||
}
|
||||
// ... more answers with ranks
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### chat_template.default
|
||||
|
||||
```yaml
|
||||
rl: dpo
|
||||
datasets:
|
||||
- path: ...
|
||||
split: train
|
||||
type: chat_template.default
|
||||
field_messages: "messages"
|
||||
field_chosen: "chosen"
|
||||
field_rejected: "rejected"
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user: ["user"]
|
||||
assistant: ["assistant"]
|
||||
system: ["system"]
|
||||
```
|
||||
|
||||
Sample input format:
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "..."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "..."
|
||||
},
|
||||
// ... more messages
|
||||
],
|
||||
"chosen": {
|
||||
"role": "assistant",
|
||||
"content": "..."
|
||||
},
|
||||
"rejected": {
|
||||
"role": "assistant",
|
||||
"content": "..."
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### user_defined.default
|
||||
|
||||
For custom behaviors, use the following config:
|
||||
|
||||
```yaml
|
||||
rl: dpo
|
||||
datasets:
|
||||
- path: ...
|
||||
split: train
|
||||
type: user_defined.default
|
||||
|
||||
field_prompt: "prompt"
|
||||
field_system: "system"
|
||||
field_chosen: "chosen"
|
||||
field_rejected: "rejected"
|
||||
prompt_format: "{prompt}"
|
||||
chosen_format: "{chosen}"
|
||||
rejected_format: "{rejected}"
|
||||
```
|
||||
|
||||
The input format is a simple JSON input with customizable fields based on the above config.
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"chosen": "...",
|
||||
"rejected": "..."
|
||||
}
|
||||
```
|
||||
|
||||
## IPO
|
||||
|
||||
As IPO is just DPO with a different loss function, all supported options for DPO work here.
|
||||
|
||||
```yaml
|
||||
rl: ipo
|
||||
```
|
||||
|
||||
#### ORPO
|
||||
## ORPO
|
||||
|
||||
Paper: https://arxiv.org/abs/2403.07691
|
||||
|
||||
@@ -52,8 +318,28 @@ datasets:
|
||||
type: chat_template.argilla
|
||||
```
|
||||
|
||||
ORPO supports the following types, each with its corresponding dataset format:
|
||||
|
||||
#### KTO
|
||||
### chat_template.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...", // if available, will be taken as user message for single-turn instead of from list below
|
||||
|
||||
// chosen/rejected should be identical up to the last content and contain an even number of alternating user/assistant turns
|
||||
"chosen": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## KTO
|
||||
|
||||
```yaml
|
||||
rl: kto
|
||||
@@ -72,7 +358,144 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: true
|
||||
```
|
||||
|
||||
#### Using local dataset files
|
||||
KTO supports the following types, each with its corresponding dataset format:
|
||||
|
||||
### chatml.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"instruction": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
"chosen": [
|
||||
{"role": "user", "content": "..."}
|
||||
],
|
||||
"completion": [
|
||||
{"role": "assistant", "content": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.intel
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"question": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### chatml.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"instruction": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.argilla_chat
|
||||
|
||||
```json
|
||||
{
|
||||
"completion": [
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.intel
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"question": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.prompt_pairs
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### llama3.ultra
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"completion": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### user_defined.default
|
||||
|
||||
For custom behaviors, use the following config:
|
||||
|
||||
```yaml
|
||||
rl: kto
|
||||
datasets:
|
||||
- path: ...
|
||||
split: train
|
||||
type: user_defined.default
|
||||
|
||||
field_prompt: "prompt"
|
||||
field_system: "system"
|
||||
field_completion: "completion"
|
||||
field_label: "label"
|
||||
prompt_format: "{prompt}"
|
||||
completion_format: "{completion}"
|
||||
```
|
||||
|
||||
The input is simple JSON, with field names customizable via the config above.
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"prompt": "...",
|
||||
"completion": "...",
|
||||
"label": "..."
|
||||
}
|
||||
```
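As a quick illustration (file name and values are hypothetical), such a dataset can be written as JSONL, with `label` marking whether each completion is desirable (TRL's KTO trainer expects a boolean here):

```python
import json

# hypothetical records matching the user_defined.default config above
records = [
    {
        "system": "You are concise.",
        "prompt": "Name a prime number.",
        "completion": "7",
        "label": True,
    },
    {"prompt": "Name a prime number.", "completion": "9", "label": False},
]

# one JSON object per line (JSONL), loadable via ds_type: json
with open("kto_data.jsonl", "w", encoding="utf-8") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")
```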
|
||||
|
||||
## Using local dataset files
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- ds_type: json
|
||||
@@ -82,9 +505,9 @@ datasets:
|
||||
type: chatml.intel
|
||||
```
|
||||
|
||||
#### Trl autounwrap for peft
|
||||
## TRL auto-unwrapping for PEFT
|
||||
|
||||
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
|
||||
TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, since an additional reference model does not need to be loaded and reference-model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:
|
||||
|
||||
```yaml
|
||||
# load ref model when adapter training.
|
||||
|
||||
@@ -21,8 +21,9 @@ datasets:
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -16,8 +16,9 @@ datasets:
|
||||
type: chat_template
|
||||
drop_system_message: true
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -13,8 +13,9 @@ datasets:
|
||||
type: chat_template
|
||||
drop_system_message: true
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -17,8 +17,9 @@ datasets:
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
|
||||
@@ -17,8 +17,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
|
||||
@@ -14,8 +14,9 @@ datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
|
||||
@@ -17,8 +17,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
@@ -31,8 +32,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
|
||||
82
examples/llama-3/lora-1b-kernels.yml
Normal file
@@ -0,0 +1,82 @@
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
# Currently, we don't support dropout with our custom Triton kernels
|
||||
# lora_dropout: 0.05
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# These options enable our custom Triton kernels / autograd
|
||||
# functions for MLP and attention calculations
|
||||
lora_mlp_kernel: true
|
||||
lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
@@ -22,8 +22,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -14,8 +14,9 @@ datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
|
||||
@@ -12,8 +12,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.45.1
|
||||
bitsandbytes==0.45.2
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
flash-attn==2.7.0.post2
|
||||
flash-attn==2.7.4.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.5.2
|
||||
@@ -13,12 +13,12 @@ liger-kernel==0.5.2
|
||||
packaging==23.2
|
||||
|
||||
peft==0.14.0
|
||||
transformers==4.48.1
|
||||
transformers==4.49.0
|
||||
tokenizers>=0.21.0
|
||||
accelerate==1.3.0
|
||||
datasets==3.2.0
|
||||
deepspeed==0.16.1
|
||||
trl==0.13.0
|
||||
trl==0.15.1
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
@@ -26,7 +26,7 @@ sentencepiece
|
||||
gradio==3.50.2
|
||||
|
||||
modal==0.70.5
|
||||
pydantic==2.6.3
|
||||
pydantic==2.10.6
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
|
||||
@@ -31,27 +31,26 @@ def parse_dataset(dataset=None, split="train"):
|
||||
ds_cfg["field_messages"] = field_messages
|
||||
|
||||
message_fields = features[field_messages][0].keys()
|
||||
message_field_role = None
|
||||
|
||||
message_property_mappings = {"role": None, "content": None}
|
||||
for key in ["from", "role"]:
|
||||
if key in message_fields:
|
||||
message_field_role = key
|
||||
message_property_mappings["role"] = key
|
||||
break
|
||||
if not message_field_role:
|
||||
if not message_property_mappings["role"]:
|
||||
raise ValueError(
|
||||
f'No role field found in messages: {", ".join(message_fields)}'
|
||||
)
|
||||
ds_cfg["message_field_role"] = message_field_role
|
||||
|
||||
message_field_content = None
|
||||
for key in ["content", "text", "value"]:
|
||||
if key in message_fields:
|
||||
message_field_content = key
|
||||
message_property_mappings["content"] = key
|
||||
break
|
||||
if not message_field_content:
|
||||
if not message_property_mappings["content"]:
|
||||
raise ValueError(
|
||||
f'No content field found in messages: {", ".join(message_fields)}'
|
||||
)
|
||||
ds_cfg["message_field_content"] = message_field_content
|
||||
ds_cfg["message_property_mappings"] = message_property_mappings
|
||||
|
||||
print(yaml.dump({"datasets": [ds_cfg]}))
|
||||
|
||||
|
||||
12
setup.py
@@ -71,12 +71,15 @@ def parse_requirements():
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 5):
|
||||
if (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.29.post2")
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers==0.0.28.post3")
|
||||
_install_requires.append("xformers>=0.0.28.post3")
|
||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||
elif (major, minor) >= (2, 4):
|
||||
if patch == 0:
|
||||
@@ -122,7 +125,7 @@ setup(
|
||||
},
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn==2.7.0.post2",
|
||||
"flash-attn==2.7.4.post1",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.16.1",
|
||||
@@ -153,5 +156,8 @@ setup(
|
||||
"ray": [
|
||||
"ray[train]",
|
||||
],
|
||||
"vllm": [
|
||||
"vllm==0.7.2",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.6.0"
|
||||
__version__ = "0.8.0.dev0"
|
||||
|
||||
@@ -35,13 +35,18 @@ def do_cli_train(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
accelerate: bool = True,
|
||||
cwd=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
config_yaml = file.read()
|
||||
cloud.train(config_yaml, accelerate=accelerate)
|
||||
local_dirs = {}
|
||||
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
|
||||
local_dirs = {"/workspace/mounts": cwd}
|
||||
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
|
||||
|
||||
|
||||
def do_cli_lm_eval(
|
||||
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
import subprocess # nosec B404
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
from typing import Optional
|
||||
|
||||
import modal
|
||||
|
||||
@@ -22,8 +23,18 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
|
||||
|
||||
# modal workaround so it doesn't use the automounted axolotl
|
||||
new_env = copy.deepcopy(os.environ)
|
||||
|
||||
if "PYTHONPATH" in new_env:
|
||||
del new_env["PYTHONPATH"]
|
||||
paths = ["/workspace/mounts"]
|
||||
for sub_python_path_str in new_env["PYTHONPATH"].split(":"):
|
||||
sub_python_path = Path(sub_python_path_str)
|
||||
if not sub_python_path.joinpath("src", "axolotl").exists():
|
||||
# skip the automounted axolotl path, otherwise unexpected behavior can occur
|
||||
paths.append(str(sub_python_path))
|
||||
if paths:
|
||||
new_env["PYTHONPATH"] = ":".join(paths)
|
||||
else:
|
||||
del new_env["PYTHONPATH"]
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call( # nosec B603
|
||||
@@ -112,8 +123,6 @@ class ModalCloud(Cloud):
|
||||
if env := self.get_env():
|
||||
image = image.env(env)
|
||||
|
||||
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
|
||||
return image
|
||||
|
||||
def get_secrets(self):
|
||||
@@ -203,9 +212,12 @@ class ModalCloud(Cloud):
|
||||
memory = int(self.config.memory)
|
||||
return 1024 * memory
|
||||
|
||||
def get_train_env(self):
|
||||
def get_train_env(self, local_dirs=None):
|
||||
image = self.get_image()
|
||||
for mount, local_dir in (local_dirs or {}).items():
|
||||
image = image.add_local_dir(local_dir, mount)
|
||||
return self.app.function(
|
||||
image=self.get_image(),
|
||||
image=image,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
cpu=16.0,
|
||||
gpu=self.get_train_gpu(),
|
||||
@@ -214,14 +226,21 @@ class ModalCloud(Cloud):
|
||||
secrets=self.get_secrets(),
|
||||
)
|
||||
|
||||
def train(self, config_yaml: str, accelerate: bool = True):
|
||||
modal_fn = self.get_train_env()(_train)
|
||||
def train(
|
||||
self,
|
||||
config_yaml: str,
|
||||
accelerate: bool = True,
|
||||
local_dirs: Optional[dict[str, str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
modal_fn = self.get_train_env(local_dirs)(_train)
|
||||
with modal.enable_output():
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
accelerate=accelerate,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def lm_eval(self, config_yaml: str):
|
||||
@@ -252,7 +271,7 @@ def _preprocess(config_yaml: str, volumes=None):
|
||||
)
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
@@ -262,8 +281,11 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
||||
accelerate_args = "--accelerate"
|
||||
else:
|
||||
accelerate_args = "--no-accelerate"
|
||||
num_processes_args = ""
|
||||
if num_processes := kwargs.pop("num_processes", None):
|
||||
num_processes_args = f"--num-processes {num_processes}"
|
||||
run_cmd(
|
||||
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
|
||||
f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
"""CLI to run training on a model."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.integrations.lolcats.linear_llama.configuration_linear_llama import (
|
||||
LinearLlamaConfig,
|
||||
)
|
||||
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
|
||||
LinearLlamaForCausalLM,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model_config
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
"""
|
||||
Convert attention to linear attention and perform attention transfer via distillation.
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
|
||||
# ensure quantization and peft are turned off (due to how we need to re-apply peft later)
|
||||
cfg.load_in_8bit = False
|
||||
cfg.load_in_4bit = False
|
||||
cfg.adapter = None
|
||||
|
||||
# load model
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
# freeze model
|
||||
for p in model.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
# convert to linear llama
|
||||
linear_llama_config = LinearLlamaConfig.from_llama(
|
||||
model.config, cfg.attention_config
|
||||
)
|
||||
model = LinearLlamaForCausalLM.from_llama(
|
||||
model, config=linear_llama_config, train_attention=True
|
||||
)
|
||||
|
||||
# set save_path, save tokenizer and model config.
|
||||
save_path = str(os.path.join(cfg.output_dir, "distilled"))
|
||||
tokenizer.save_pretrained(save_path)
|
||||
if hasattr(model, "config"):
|
||||
model.config.save_pretrained(save_path)
|
||||
|
||||
# Get datasets
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
eval_dataset = dataset_meta.eval_dataset
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
# toggle attention to be trainable
|
||||
model.toggle_attention(train=True)
|
||||
|
||||
# Setup trainer
|
||||
trainer = setup_trainer(
|
||||
cfg=cfg,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
model=(model, None, None),
|
||||
tokenizer=tokenizer,
|
||||
processor=None,
|
||||
total_num_steps=total_num_steps,
|
||||
)
|
||||
|
||||
# train
|
||||
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
|
||||
|
||||
# drop base_attention + remove training attn
|
||||
model.toggle_attention(train=False)
|
||||
model.remove_base_attention()
|
||||
|
||||
# NOTE: If in peft mode, consider whether to auto-merge
|
||||
|
||||
# save model
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
# NOTE: may need to consider other ways of saving due to multi-gpu etc
|
||||
model.save_pretrained(save_path, safe_serialization=safe_serialization)
|
||||
|
||||
# cleanup
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
|
||||
del model
|
||||
del tokenizer
|
||||
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
"""
|
||||
Parses `axolotl` config, CLI args, and calls `do_train`.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# load cfg, force linearize and add plugin to linearize
|
||||
parsed_cfg = load_cfg(
|
||||
config,
|
||||
linearize=True,
|
||||
plugins=["axolotl.integrations.lolcats.LinearizePlugin"],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
parser = HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
do_linearize(parsed_cfg, parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
@@ -2,19 +2,19 @@
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import logging
|
||||
import random
|
||||
import os
|
||||
import subprocess # nosec B404
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import axolotl
|
||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.cli.sweeps import generate_sweep_configs
|
||||
from axolotl.cli.utils import (
|
||||
add_options_from_config,
|
||||
add_options_from_dataclass,
|
||||
@@ -27,76 +27,6 @@ from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
|
||||
|
||||
def generate_sweep_configs(base_config, sweeps_config):
|
||||
"""
|
||||
Recursively generates all possible configurations by applying sweeps to the base config.
|
||||
|
||||
Args:
|
||||
base_config (dict): The original configuration dictionary
|
||||
sweeps_config (dict): Dictionary where keys are parameters and values are either:
|
||||
- lists of values to sweep independently
|
||||
- or for paired values, a list of dicts under the '_' key
|
||||
|
||||
Returns:
|
||||
list: List of all possible configuration dictionaries
|
||||
|
||||
Example:
|
||||
sweeps_config = {
|
||||
'learning_rate': [0.1, 0.01],
|
||||
'_': [
|
||||
{'load_in_8bit': True, 'adapter': 'lora'},
|
||||
{'load_in_4bit': True, 'adapter': 'qlora'}
|
||||
]
|
||||
}
|
||||
"""
|
||||
# Separate paired values from regular sweeps
|
||||
paired_values = sweeps_config.get("_", [])
|
||||
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
|
||||
|
||||
# Process regular sweeps
|
||||
param_names = list(regular_sweeps.keys())
|
||||
param_values = list(regular_sweeps.values())
|
||||
|
||||
# Generate combinations for regular sweeps
|
||||
regular_combinations = list(product(*param_values)) if param_values else [()]
|
||||
|
||||
# Combine regular sweeps with paired values
|
||||
all_combinations = []
|
||||
for reg_combo in regular_combinations:
|
||||
if paired_values:
|
||||
for paired_set in paired_values:
|
||||
new_config = {}
|
||||
# new_config = deepcopy(base_config)
|
||||
# Combine regular parameters with paired parameters
|
||||
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
|
||||
for param_name, param_value in full_combo.items():
|
||||
new_config[param_name] = param_value
|
||||
print(new_config)
|
||||
all_combinations.append(new_config)
|
||||
else:
|
||||
# If no paired values, just use regular combinations
|
||||
# new_config = deepcopy(base_config)
|
||||
new_config = {}
|
||||
for param_name, param_value in zip(param_names, reg_combo):
|
||||
new_config[param_name] = param_value
|
||||
print(new_config)
|
||||
all_combinations.append(new_config)
|
||||
|
||||
# randomize the order of trials
|
||||
random.seed(42)
|
||||
random.shuffle(all_combinations)
|
||||
|
||||
# Generate a new config for each combination
|
||||
result_configs = []
|
||||
for combination in all_combinations:
|
||||
new_config = deepcopy(base_config)
|
||||
for param_name, param_value in combination.items():
|
||||
new_config[param_name] = param_value
|
||||
result_configs.append(new_config)
|
||||
|
||||
return result_configs
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||
def cli():
|
||||
@@ -165,7 +95,6 @@ def train(
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
@@ -199,7 +128,16 @@ def train(
|
||||
try:
|
||||
if accelerate:
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
cwd = os.getcwd()
|
||||
do_cli_train(
|
||||
cloud_config=cloud,
|
||||
config=config,
|
||||
accelerate=True,
|
||||
cwd=cwd,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
accelerate_args = []
|
||||
if "main_process_port" in kwargs:
|
||||
@@ -208,7 +146,7 @@ def train(
|
||||
accelerate_args.append(str(main_process_port))
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes", None)
|
||||
accelerate_args.append("--num-processes")
|
||||
accelerate_args.append("--num_processes")
|
||||
accelerate_args.append(str(num_processes))
|
||||
|
||||
base_cmd = ["accelerate", "launch"]
|
||||
@@ -220,7 +158,11 @@ def train(
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
do_cli_train(
|
||||
cloud_config=cloud, config=config, accelerate=False, **kwargs
|
||||
)
|
||||
else:
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
@@ -381,4 +323,5 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
main()
|
||||
|
||||
77
src/axolotl/cli/sweeps.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Utilities for handling sweeps over configs for axolotl train CLI command"""
|
||||
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
|
||||
|
||||
def generate_sweep_configs(
|
||||
base_config: dict[str, list], sweeps_config: dict[str, list]
|
||||
) -> list[dict[str, list]]:
|
||||
"""
|
||||
Recursively generates all possible configurations by applying sweeps to the base config.
|
||||
|
||||
Args:
|
||||
base_config (dict): The original configuration dictionary
|
||||
sweeps_config (dict): Dictionary where keys are parameters and values are either:
|
||||
- lists of values to sweep independently
|
||||
- or for paired values, a list of dicts under the '_' key
|
||||
|
||||
Returns:
|
||||
list: List of all possible configuration dictionaries
|
||||
|
||||
Example:
|
||||
sweeps_config = {
|
||||
'learning_rate': [0.1, 0.01],
|
||||
'_': [
|
||||
{'load_in_8bit': True, 'adapter': 'lora'},
|
||||
{'load_in_4bit': True, 'adapter': 'qlora'}
|
||||
]
|
||||
}
|
||||
"""
|
||||
# Separate paired values from regular sweeps
|
||||
paired_values = sweeps_config.get("_", [])
|
||||
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
|
||||
|
||||
# Process regular sweeps
|
||||
param_names = list(regular_sweeps.keys())
|
||||
param_values = list(regular_sweeps.values())
|
||||
|
||||
# Generate combinations for regular sweeps
|
||||
regular_combinations = list(product(*param_values)) if param_values else [()]
|
||||
|
||||
# Combine regular sweeps with paired values
|
||||
all_combinations = []
|
||||
for reg_combo in regular_combinations:
|
||||
if paired_values:
|
||||
for paired_set in paired_values:
|
||||
new_config = {}
|
||||
# new_config = deepcopy(base_config)
|
||||
# Combine regular parameters with paired parameters
|
||||
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
|
||||
for param_name, param_value in full_combo.items():
|
||||
new_config[param_name] = param_value
|
||||
print(new_config)
|
||||
all_combinations.append(new_config)
|
||||
else:
|
||||
# If no paired values, just use regular combinations
|
||||
# new_config = deepcopy(base_config)
|
||||
new_config = {}
|
||||
for param_name, param_value in zip(param_names, reg_combo):
|
||||
new_config[param_name] = param_value
|
||||
print(new_config)
|
||||
all_combinations.append(new_config)
|
||||
|
||||
# randomize the order of trials
|
||||
random.seed(42)
|
||||
random.shuffle(all_combinations)
|
||||
|
||||
# Generate a new config for each combination
|
||||
result_configs = []
|
||||
for combination in all_combinations:
|
||||
new_config = deepcopy(base_config)
|
||||
for param_name, param_value in combination.items():
|
||||
new_config[param_name] = param_value
|
||||
result_configs.append(new_config)
|
||||
|
||||
return result_configs
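A minimal usage sketch of this helper (the base and sweep values below are illustrative only):

```python
from axolotl.cli.sweeps import generate_sweep_configs

base_config = {"base_model": "NousResearch/Llama-3.2-1B", "learning_rate": 0.0002}
sweeps_config = {
    "learning_rate": [0.1, 0.01],
    "_": [
        {"load_in_8bit": True, "adapter": "lora"},
        {"load_in_4bit": True, "adapter": "qlora"},
    ],
}

# 2 learning rates x 2 paired sets = 4 configs, each a copy of base_config
# with the swept values applied, returned in a shuffled (seeded) order
configs = generate_sweep_configs(base_config, sweeps_config)
print(len(configs))  # 4
```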
|
||||
@@ -122,9 +122,11 @@ def load_preference_datasets(
|
||||
`total_num_steps`.
|
||||
"""
|
||||
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||
total_num_steps = int(
|
||||
total_num_steps: Optional[int] = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
if cfg.rl == "grpo":
|
||||
total_num_steps = None
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
@@ -39,7 +39,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
from axolotl.core.trainers.base import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlDPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlMambaTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
@@ -48,9 +47,11 @@ from axolotl.core.trainers.base import (
|
||||
AxolotlTrainer,
|
||||
ReLoRATrainer,
|
||||
)
|
||||
from axolotl.core.trainers.dpo import DPOStrategy
|
||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlCPOConfig,
|
||||
AxolotlDPOConfig,
|
||||
AxolotlKTOConfig,
|
||||
AxolotlORPOConfig,
|
||||
AxolotlPRMConfig,
|
||||
@@ -58,6 +59,7 @@ from axolotl.core.training_args import (
|
||||
AxolotlTrainingArguments,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType, get_extract_fn
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
@@ -329,6 +331,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
|
||||
training_arguments_kwargs = {}
|
||||
|
||||
if self.cfg.include_tokens_per_second is not None:
|
||||
training_arguments_kwargs[
|
||||
"include_tokens_per_second"
|
||||
] = self.cfg.include_tokens_per_second
|
||||
|
||||
if self.cfg.bf16 == "full":
|
||||
training_arguments_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
@@ -641,9 +649,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
|
||||
if self.cfg.rl == "orpo":
|
||||
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
|
||||
|
||||
if self.cfg.neftune_noise_alpha is not None:
|
||||
training_arguments_kwargs[
|
||||
"neftune_noise_alpha"
|
||||
@@ -652,7 +657,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_kwargs = {}
|
||||
|
||||
if self.cfg.reward_model:
|
||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
if self.cfg.optimizer in [
|
||||
@@ -742,6 +747,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
if self.cfg.sp_ulysses_degree:
|
||||
data_collator_kwargs["sp_extract_fn"] = get_extract_fn(
|
||||
USPRingAttnType.ZIGZAG,
|
||||
sp_ulysses_degree=self.cfg.sp_ulysses_degree
|
||||
)
|
||||
|
||||
if self.cfg.reward_model:
|
||||
data_collator_kwargs["max_length"] = self.cfg.sequence_len
|
||||
@@ -965,10 +975,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
if self.cfg.rl_beta:
|
||||
training_args_kwargs["beta"] = self.cfg.rl_beta
|
||||
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
|
||||
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
|
||||
if self.cfg.orpo_alpha:
|
||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||
@@ -977,6 +988,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl == "simpo":
|
||||
training_args_cls = AxolotlCPOConfig
|
||||
training_args_kwargs["loss_type"] = "simpo"
|
||||
@@ -1001,11 +1013,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.kto_undesirable_weight or 1.0
|
||||
)
|
||||
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl == "grpo":
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||
|
||||
else:
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
if self.cfg.rl == "ipo":
|
||||
@@ -1016,11 +1032,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||
if self.cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||
if self.cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs[
|
||||
"use_logits_to_keep"
|
||||
] = self.cfg.dpo_use_logits_to_keep
|
||||
|
||||
for blocklist_key in blocklist_args_kwargs:
|
||||
if blocklist_key in training_args_kwargs:
|
||||
del training_args_kwargs[blocklist_key]
|
||||
|
||||
max_steps = self.cfg.max_steps or total_num_steps or -1
|
||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
output_dir=self.cfg.output_dir,
|
||||
self.cfg.output_dir,
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
max_steps=self.cfg.max_steps or total_num_steps,
|
||||
max_steps=max_steps,
|
||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||
learning_rate=self.cfg.learning_rate,
|
||||
warmup_steps=self.cfg.warmup_steps,
|
||||
@@ -1047,8 +1073,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
dpo_trainer_kwargs[
|
||||
"precompute_ref_log_probs"
|
||||
] = self.cfg.precompute_ref_log_probs
|
||||
if self.cfg.rl in ["dpo", "ipo"]:
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
if self.cfg.rl == "grpo":
|
||||
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||
trainer_cls_args = [self.model]
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
elif self.cfg.rl in ["dpo", "ipo"]:
|
||||
trainer_cls = DPOStrategy.get_trainer_class()
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
elif self.cfg.rl == "orpo":
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
@@ -1063,12 +1094,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "processing_class" in sig.parameters.keys():
|
||||
dpo_trainer_kwargs["processing_class"] = self.tokenizer
|
||||
else:
|
||||
if "tokenizer" in sig.parameters.keys():
|
||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
else:
|
||||
dpo_trainer_kwargs["processing_class"] = self.tokenizer
|
||||
|
||||
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
||||
if self.cfg.datasets is not None and (
|
||||
trainer_cls is DPOStrategy.get_trainer_class()
|
||||
):
|
||||
dpo_trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
|
||||
@@ -5,30 +5,21 @@ module for customized trainers
|
||||
from __future__ import annotations
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from typing import Dict, Literal, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import Trainer
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import (
|
||||
CPOTrainer,
|
||||
DPOTrainer,
|
||||
KTOTrainer,
|
||||
ORPOTrainer,
|
||||
PRMTrainer,
|
||||
RewardTrainer,
|
||||
)
|
||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
||||
from trl.trainer.utils import pad_to_length
|
||||
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
@@ -847,107 +838,6 @@ class ReLoRATrainer(AxolotlTrainer):
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "dpo"]
|
||||
|
||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
self.model_accepts_loss_kwargs = False
|
||||
|
||||
def create_optimizer(self):
|
||||
if self.args.loraplus_lr_ratio is None:
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
opt_model,
|
||||
)
|
||||
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
if loraplus_lr_ratio:
|
||||
print("Using lora+")
|
||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
|
||||
return self.optimizer
|
||||
|
||||
@wraps(DPOTrainer.push_to_hub)
|
||||
def push_to_hub(self, *args, **kwargs) -> str:
|
||||
"""
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> Dict:
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
)
|
||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||
for key in res.keys():
|
||||
res[key] = res[key][1:]
|
||||
|
||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
num_items_in_batch=None,
|
||||
) -> torch.Tensor:
|
||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
Extend the base ORPOTrainer for axolotl helpers
|
||||
|
||||
33
src/axolotl/core/trainers/dpo/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
DPO Specific Strategy for training
|
||||
"""
|
||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||
|
||||
|
||||
class DPOStrategy:
|
||||
"""
|
||||
Strategy for DPO training
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls):
|
||||
return AxolotlDPOTrainer
|
||||
|
||||
@classmethod
|
||||
def get_training_args_class(cls):
|
||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||
|
||||
return AxolotlDPOConfig
|
||||
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg):
|
||||
training_args_kwargs = {}
|
||||
if cfg.rl == "ipo":
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
||||
if cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||
return training_args_kwargs
|
||||
15
src/axolotl/core/trainers/dpo/args.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Axolotl specific DPO args
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trl import DPOConfig
|
||||
|
||||
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
DPO config for DPO training
|
||||
"""
|
||||
125
src/axolotl/core/trainers/dpo/trainer.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
DPO trainer for axolotl
|
||||
"""
|
||||
import gc
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from transformers import Trainer
|
||||
from transformers.utils import is_sagemaker_mp_enabled
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.trainers.base import (
|
||||
SchedulerMixin,
|
||||
_sanitize_kwargs_for_ds_tagging,
|
||||
_sanitize_kwargs_for_tagging,
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
import smdistributed.modelparallel.torch as smp
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "dpo"]
|
||||
|
||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
self.model_accepts_loss_kwargs = False
|
||||
|
||||
def create_optimizer(self):
|
||||
# pylint: disable=duplicate-code
|
||||
if self.args.loraplus_lr_ratio is None:
|
||||
return super().create_optimizer()
|
||||
|
||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||
self.args,
|
||||
opt_model,
|
||||
)
|
||||
|
||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||
if loraplus_lr_ratio:
|
||||
print("Using lora+")
|
||||
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
|
||||
# pylint: disable=duplicate-code
|
||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||
opt_model,
|
||||
optimizer_cls,
|
||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||
**optimizer_kwargs,
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
self.optimizer
|
||||
)
|
||||
|
||||
return self.optimizer
|
||||
|
||||
@wraps(DPOTrainer.push_to_hub)
|
||||
def push_to_hub(self, *args, **kwargs) -> str:
|
||||
"""
|
||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||
)
|
||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> Dict:
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
)
|
||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||
for key in res.keys():
|
||||
res[key] = res[key][1:]
|
||||
|
||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
num_items_in_batch=None,
|
||||
) -> torch.Tensor:
|
||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
119
src/axolotl/core/trainers/grpo/__init__.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
GRPO Specific Strategy for training
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
class GRPOStrategy:
|
||||
"""
|
||||
Strategy for GRPO training
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(cls):
|
||||
return AxolotlGRPOTrainer
|
||||
|
||||
@classmethod
|
||||
def get_training_args_class(cls):
|
||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||
|
||||
return AxolotlGRPOConfig
|
||||
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg):
|
||||
grpo_args_kwargs = {}
|
||||
if cfg.trl and cfg.trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
||||
if cfg.trl and cfg.trl.vllm_device:
|
||||
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
||||
else:
|
||||
grpo_args_kwargs["vllm_device"] = "auto"
|
||||
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs[
|
||||
"vllm_gpu_memory_utilization"
|
||||
] = cfg.trl.vllm_gpu_memory_utilization
|
||||
if cfg.trl and cfg.trl.vllm_max_model_len:
|
||||
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
||||
if cfg.trl and cfg.trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
||||
if cfg.trl and cfg.trl.sync_ref_model:
|
||||
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
||||
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
||||
grpo_args_kwargs[
|
||||
"ref_model_mixup_alpha"
|
||||
] = cfg.trl.ref_model_mixup_alpha
|
||||
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
||||
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
||||
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
||||
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
def set_trainer_args(cls, cfg):
|
||||
trainer_args = []
|
||||
if cfg.trl and cfg.trl.reward_funcs:
|
||||
reward_funcs = []
|
||||
for reward_func_fqn in cfg.trl.reward_funcs:
|
||||
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
|
||||
trainer_args.append(reward_funcs)
|
||||
return trainer_args
|
||||
|
||||
@classmethod
|
||||
def set_trainer_kwargs(cls, cfg):
|
||||
trainer_kwargs = {}
|
||||
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||
trainer_kwargs[
|
||||
"reward_processing_classes"
|
||||
] = cfg.trl.reward_processing_classes
|
||||
return trainer_kwargs
|
||||
|
||||
@classmethod
|
||||
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||
# No data collation is needed in GRPO, handled by trl's trainer __init__
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_blocklist_args_kwargs(cls):
|
||||
return ["dataset_num_proc"]
|
||||
|
||||
@classmethod
|
||||
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
|
||||
"""
|
||||
Returns the reward function from the given fully qualified name, or the path to the reward function model.
|
||||
|
||||
Args:
|
||||
reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
|
||||
or a HF hub path to the reward model.
|
||||
Raises:
|
||||
ValueError: If the reward function does not accept at least two arguments.
|
||||
|
||||
Returns:
|
||||
RewardFunc: A callable that accepts prompts and completions and returns rewards,
|
||||
or a path to a reward model.
|
||||
|
||||
"""
|
||||
try:
|
||||
# use importlib to dynamically load the reward function from the module
|
||||
reward_func_module_name = reward_func_fqn.split(".")[-1]
|
||||
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
|
||||
reward_func = getattr(reward_func_module, reward_func_module_name)
|
||||
if not len(inspect.signature(reward_func).parameters) >= 2:
|
||||
raise ValueError(
|
||||
"Reward function must accept at least two arguments: prompts: list and completions: list"
|
||||
)
|
||||
return reward_func
|
||||
except ModuleNotFoundError:
|
||||
# the user has passed a string (ideally indicating the path of a reward model)
|
||||
LOG.info(
|
||||
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
|
||||
)
|
||||
# import failed, so treat the string as a reward model path rather than a callable
return reward_func_fqn
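As a concrete, hypothetical example of a reward function satisfying the two-argument requirement described in the docstring above (assuming completions arrive as plain strings):

```python
# my_rewards.py (hypothetical module); reference it in the config as
#   trl:
#     reward_funcs: ["my_rewards.length_penalty"]
def length_penalty(prompts, completions, **kwargs):
    """Toy reward: prefer completions shorter than 200 characters."""
    return [1.0 if len(completion) < 200 else 0.0 for completion in completions]
```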
|
||||
15
src/axolotl/core/trainers/grpo/args.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Axolotl Specific Training Args
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trl import GRPOConfig
|
||||
|
||||
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""
|
||||
Axolotl GRPO Config for GRPO training
|
||||
"""
|
||||
108
src/axolotl/core/trainers/grpo/trainer.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
from accelerate.utils import is_peft_model
|
||||
from accelerate.utils.other import is_compiled_module
|
||||
from transformers import PreTrainedModel
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
from trl.models import unwrap_model_for_generation
|
||||
|
||||
from axolotl.core.trainers.base import SchedulerMixin
|
||||
|
||||
|
||||
# mypy: ignore-errors
|
||||
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
    """
    Extend the base GRPOTrainer for axolotl helpers
    """

    _tag_names = ["trl", "grpo", "axolotl"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # pylint: disable=access-member-before-definition
        # Enable gradient checkpointing if requested
        if kwargs["args"].gradient_checkpointing:
            # Ensure use_cache is disabled
            if hasattr(self.model, "config"):
                self.model.config.use_cache = False

            # Enable gradient checkpointing on the base model for PEFT
            if is_peft_model(self.model) and hasattr(
                self.model.base_model, "gradient_checkpointing_enable"
            ):
                self.model.base_model.gradient_checkpointing_enable()
            # Enable gradient checkpointing for non-PEFT models
            elif hasattr(self.model, "gradient_checkpointing_enable"):
                self.model.gradient_checkpointing_enable()
            self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
        # pylint: enable=access-member-before-definition

    def _enable_gradient_checkpointing(
        self, model: PreTrainedModel, args: GRPOConfig
    ) -> PreTrainedModel:
        """Enables gradient checkpointing for the model."""
        # pylint: disable=unused-argument,redefined-builtin
        gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
        use_reentrant = (
            "use_reentrant" not in gradient_checkpointing_kwargs
            or gradient_checkpointing_kwargs["use_reentrant"]
        )

        if use_reentrant:
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad
                )

        return model
        # pylint: enable=unused-argument,redefined-builtin

    def _move_model_to_vllm(self):
        with unwrap_model_for_generation(
            self.model,
            self.accelerator,
            gather_deepspeed3_params=self.args.ds3_gather_for_generation,
        ) as unwrapped_model:
            if is_compiled_module(unwrapped_model):
                unwrapped_model = (
                    unwrapped_model._orig_mod  # pylint: disable=protected-access
                )
            if is_peft_model(unwrapped_model):
                unwrapped_model.merge_adapter()
                state_dict = unwrapped_model.state_dict()
                # Remove base_model and base_layer prefixes
                state_dict = {
                    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
                    for k, v in state_dict.items()
                }
                # Remove values with adapter prefix (example: "_lora")
                state_dict = {
                    k: v
                    for k, v in state_dict.items()
                    if unwrapped_model.prefix not in k
                }
                # When using modules_to_save, remove the prefix and discard the original module
                state_dict = {
                    k.replace("modules_to_save.default.", ""): v
                    for k, v in state_dict.items()
                    if "original_module" not in k
                }
            else:
                state_dict = unwrapped_model.state_dict()
            if self.accelerator.is_main_process:
                llm_model = (
                    self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                )
                llm_model.load_weights(state_dict.items())
            if is_peft_model(unwrapped_model):
                unwrapped_model.unmerge_adapter()
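For reference, a minimal sketch of the key renaming that `_move_model_to_vllm` performs when syncing merged PEFT weights into vLLM. The example keys and the `lora_` prefix are illustrative stand-ins (PEFT's LoRA prefix), not values read from a real checkpoint:

```python
# Illustrative only: mirrors the three dict comprehensions above.
# Keys are hypothetical; string values stand in for weight tensors.
state_dict = {
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": "W_q",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": "A",
    "base_model.model.model.embed_tokens.modules_to_save.default.weight": "E",
}
prefix = "lora_"  # plays the role of `unwrapped_model.prefix`

# 1. strip the PEFT wrapper prefix and the .base_layer suffix
state_dict = {
    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
    for k, v in state_dict.items()
}
# 2. drop adapter weights (they were already merged into the base weights)
state_dict = {k: v for k, v in state_dict.items() if prefix not in k}
# 3. unwrap modules_to_save and discard the frozen original-module copies
state_dict = {
    k.replace("modules_to_save.default.", ""): v
    for k, v in state_dict.items()
    if "original_module" not in k
}
print(sorted(state_dict))
# ['model.embed_tokens.weight', 'model.layers.0.self_attn.q_proj.weight']
```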
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig


@dataclass
@@ -206,6 +216,16 @@ class AxolotlTrainingMixins:
        },
    )
    sp_ulysses_degree: Optional[int] = field(
        default=None,
        metadata={"help": "Ulysses parallelism degree for hybrid sequence-parallel long-context attention"},
    )
    sp_ring_degree: Optional[int] = field(
        default=None,
        metadata={"help": "Ring-attention parallelism degree for sequence-parallel long-context attention"},
    )


@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
@@ -217,13 +227,6 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
    """


@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
    """
    DPO config for DPO training
    """


@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
    """
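The `sp_ulysses_degree` and `sp_ring_degree` fields added above wire hybrid sequence parallelism into the training arguments. A minimal sketch of how they might be set in an axolotl YAML config follows; the keys mirror the dataclass field names, and whether axolotl surfaces them at the top level under exactly these names is an assumption:

```yaml
# Illustrative sketch only: hybrid sequence-parallelism degrees.
# sp_ulysses_degree * sp_ring_degree GPUs cooperate on each long sequence.
sequence_len: 65536
sp_ulysses_degree: 4   # Ulysses (all-to-all over attention heads) degree
sp_ring_degree: 2      # ring-attention (sequence-sharded) degree
```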
@@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -1,44 +0,0 @@
# Low-rank Linear Conversion via Attention Transfer (LoLCATs)

https://github.com/HazyResearch/lolcats/

### Usage

Install the `causal_dot_product` CUDA kernel (see the README in the `csrc` directory):

```bash
cd src/axolotl/integrations/lolcats/linear_llama/csrc

# Edit `setup.py` to point to the correct CUDA capabilities (lines 40-44)
# nano setup.py

# Build the CUDA kernel
python setup.py install
```

Step 1:

```yaml
plugins:
  - axolotl.integrations.lolcats.LinearizePlugin

linearize: true
```

Run axolotl: `python -m axolotl.cli.convert_linear_attention config.yaml` (TODO: change path CLI)

Step 2: Remove `linearize: true` from the config and finetune with LoRA using the target modules below.

```yaml
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]

# The optional config below can also be used, but it requires patching axolotl
# so that it works together with LoRA
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
```

`axolotl train config.yaml --base-model={output_dir}/distilled --trust-remote-code --learning-rate=0.0001 # --wandb-project="..."`

Step 3: Run inference on the finetuned model:

`axolotl inference config.yaml --lora-model-dir="{output_dir}" --trust-remote-code # --prompter="AlpacaPrompter"`
@@ -1,43 +0,0 @@
|
||||
"""
|
||||
Module for the Plugin for LoLCATs linear attention integration with Axolotl.
|
||||
|
||||
Low-rank Linear Conversion via Attention Transfer
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.integrations.lolcats.trainer.distill_attention_xent_mse import (
|
||||
DistillAttentionXentMSETrainer,
|
||||
)
|
||||
|
||||
from .args import LinearAttentionArgs  # pylint: disable=unused-import  # noqa: F401
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.lolcats")
|
||||
|
||||
|
||||
class LinearizePlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for lolcats integration with Axolotl.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Register the Linear Llama model with transformers
|
||||
from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
|
||||
register_linear_llama,
|
||||
)
|
||||
|
||||
register_linear_llama()
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.lolcats.LinearAttentionArgs"
|
||||
|
||||
def get_trainer_cls(self, cfg):
|
||||
# default to XentMSE
|
||||
# TODO: add check to allow MSE_linear
|
||||
if cfg.linearize:
|
||||
return DistillAttentionXentMSETrainer
|
||||
|
||||
return None
|
||||
@@ -1,47 +0,0 @@
|
||||
"""
|
||||
Module for handling linear attention input arguments.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class FeatureMapKwargs(BaseModel):
|
||||
"""Args for feature map"""
|
||||
|
||||
eps: float
|
||||
mlp: Optional[None] = None
|
||||
fullspace: bool
|
||||
|
||||
|
||||
class LearnedKernelKwargs(BaseModel):
|
||||
"""Args for learned kernel"""
|
||||
|
||||
feature_dim: int
|
||||
skip_connection: bool
|
||||
bias: bool
|
||||
zero_init: bool
|
||||
|
||||
|
||||
class AttentionConfig(BaseModel):
|
||||
"""Args for attention config"""
|
||||
|
||||
attention_type: str
|
||||
feature_map: str
|
||||
feature_map_kwargs: FeatureMapKwargs
|
||||
layer_idx: Optional[None] = None
|
||||
learned_kernel: str
|
||||
learned_kernel_kwargs: LearnedKernelKwargs
|
||||
tie_qk_kernels: bool
|
||||
train_qk: bool
|
||||
|
||||
|
||||
class LinearAttentionArgs(BaseModel):
|
||||
"""
|
||||
Input args for linear attention
|
||||
"""
|
||||
|
||||
attention_config: AttentionConfig
|
||||
|
||||
linearize: Optional[bool] = False
|
||||
@@ -1,90 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Linear LLaMA model configuration"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers import LlamaConfig
|
||||
|
||||
|
||||
class LinearLlamaConfig(LlamaConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`LinearLlamaModel`].
|
||||
It is a modified LlamaConfig that includes additional parameters for linear attention.
|
||||
|
||||
Args:
|
||||
attention_config (`dict`):
|
||||
Dictionary containing the configuration for linear attention mechanism.
|
||||
Expected contents:
|
||||
`attention_type` (str):
|
||||
The type of attention to convert to.
|
||||
`feature_map` (`str`):
|
||||
The type of feature map to use for linear attention.
|
||||
`feature_map_kwargs` (`dict`):
|
||||
Additional arguments for the feature map.
|
||||
`learned_kernel` (`str`, *optional*):
|
||||
Type of learned kernel to use, if any.
|
||||
`learned_kernel_kwargs` (`dict`, *optional*):
|
||||
Additional arguments for the learned kernel.
|
||||
`tie_qk_kernels` (`bool`, *optional*, defaults to False):
|
||||
Whether to tie query and key kernels.
|
||||
`rotary_config` (`dict`, *optional*):
|
||||
Configuration for rotary embeddings.
|
||||
`train_attention` (`bool`, *optional*, defaults to False):
|
||||
Whether to train attention to match softmax attention.
|
||||
`remove_base_attn` (`bool`, *optional*, defaults to True):
|
||||
Whether to remove base attention after initialization.
|
||||
`mask_value` (`int`, *optional*, defaults to 0):
|
||||
Value to use for masking.
|
||||
`eps` (`float`, *optional*, defaults to 1e-12):
|
||||
Epsilon value for numerical stability.
|
||||
`fp32_attention` (`bool`, *optional*, defaults to False):
|
||||
Whether to use fp32 precision for attention computation.
|
||||
`track_state_grads` (`bool`, *optional*, defaults to False):
|
||||
Whether to track gradients of attention states.
|
||||
|
||||
**kwargs:
|
||||
Additional arguments inherited from LlamaConfig.
|
||||
"""
|
||||
|
||||
model_type = "linear_llama"
|
||||
|
||||
def __init__(self, attention_config: Optional[dict] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set auto_map
|
||||
self.auto_map = {
|
||||
"AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
|
||||
"AutoModel": "modeling_linear_llama.LinearLlamaModel",
|
||||
"AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
|
||||
}
|
||||
|
||||
# Set default attention config if none provided
|
||||
self.attention_config = attention_config or {"attention_type": "softmax"}
|
||||
|
||||
@classmethod
|
||||
def from_llama(cls, llama_config: LlamaConfig, attention_config: dict):
|
||||
"""
|
||||
Instantiate a LinearLlamaConfig from a LlamaConfig and additional attention config.
|
||||
|
||||
Args:
|
||||
llama_config (:class:`~transformers.LlamaConfig`):
|
||||
The LlamaConfig to inherit from.
|
||||
|
||||
attention_config (`dict`):
|
||||
Dictionary containing the configuration for linear attention mechanism.
|
||||
"""
|
||||
|
||||
return cls(attention_config=attention_config, **llama_config.to_dict())
|
||||
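A minimal usage sketch of the `from_llama` helper above. The base model name is hypothetical, and the `attention_config` values simply mirror the pydantic arg classes earlier in this diff; they are illustrative rather than recommended defaults:

```python
# Illustrative sketch: wrap an existing LlamaConfig with linear-attention settings.
# Assumes the configuration module above is importable as `configuration_linear_llama`.
from transformers import LlamaConfig
from configuration_linear_llama import LinearLlamaConfig

llama_config = LlamaConfig.from_pretrained("meta-llama/Llama-3.1-8B")  # hypothetical base model
attention_config = {
    "attention_type": "lolcats_llama",
    "feature_map": "softmax_dim",
    "feature_map_kwargs": {"eps": 1e-12, "fullspace": True},
    "learned_kernel": "untied_head_einsum",
    "learned_kernel_kwargs": {
        "feature_dim": 64,
        "skip_connection": False,
        "bias": False,
        "zero_init": False,
    },
    "tie_qk_kernels": False,
    "train_qk": False,
}
config = LinearLlamaConfig.from_llama(llama_config, attention_config)
print(config.model_type)  # "linear_llama"
```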
@@ -1,30 +0,0 @@
|
||||
# Causal linear attention CUDA kernel
|
||||
|
||||
Usage:
|
||||
```bash
|
||||
cd src/axolotl/integrations/lolcats/linear_llama/csrc
|
||||
|
||||
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
|
||||
# nano setup.py
|
||||
|
||||
# Build the CUDA kernel
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
Reference: https://github.com/idiap/fast-transformers/
|
||||
|
||||
```bib
|
||||
@inproceedings{katharopoulos_et_al_2020,
|
||||
author = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
|
||||
title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
|
||||
booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
|
||||
year = {2020}
|
||||
}
|
||||
|
||||
@article{vyas_et_al_2020,
|
||||
author={Vyas, A. and Katharopoulos, A. and Fleuret, F.},
|
||||
title={Fast Transformers with Clustered Attention},
|
||||
booktitle = {Proceedings of the International Conference on Neural Information Processing Systems (NeurIPS)},
|
||||
year={2020}
|
||||
}
|
||||
```
|
||||
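Once the extension is built, a minimal call sketch (shapes follow the `CausalDotProduct` autograd wrapper later in this diff; the kernel expects contiguous float32 CUDA tensors):

```python
# Illustrative sketch only: exercise the compiled causal_dot_product kernel.
import torch
from csrc import causal_dot_product  # exposed by the package __init__ below

N, H, L, E, M = 1, 8, 128, 64, 64
q = torch.randn(N, H, L, E, device="cuda", dtype=torch.float32).contiguous()
k = torch.randn(N, H, L, E, device="cuda", dtype=torch.float32).contiguous()
v = torch.randn(N, H, L, M, device="cuda", dtype=torch.float32).contiguous()

out = causal_dot_product(q, k, v)  # (N, H, L, M), causally masked q @ cumsum(k^T v)
```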
@@ -1,6 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||
# Apoorv Vyas <avyas@idiap.ch>
|
||||
#
|
||||
from .causal_attention import causal_dot_product
|
||||
@@ -1,225 +0,0 @@
|
||||
//
|
||||
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||
// Apoorv Vyas <avyas@idiap.ch>
|
||||
//
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
|
||||
/**
|
||||
* Compute a*b^T and save it into out.
|
||||
*
|
||||
* a \in R^A
|
||||
* b \in R^B
|
||||
*/
|
||||
inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
|
||||
for (int i=0; i<A; i++) {
|
||||
float * bi = b;
|
||||
for (int j=0; j<B; j++) {
|
||||
*out += (*a) * (*bi);
|
||||
out++;
|
||||
bi++;
|
||||
}
|
||||
a++;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Implement a vector matrix product v*m and save it into out.
|
||||
*
|
||||
* v \in R^A
|
||||
* m \in R^{AxB}
|
||||
*/
|
||||
inline void vm_dot(float *v, float *m, float *out, int A, int B) {
|
||||
// TODO: Consider removing the zeroing part and assuming out already
|
||||
// contains 0s
|
||||
for (int i=0; i<B; i++) {
|
||||
out[i] = 0;
|
||||
}
|
||||
|
||||
for (int i=0; i<A; i++) {
|
||||
float *oi = out;
|
||||
for (int j=0; j<B; j++) {
|
||||
*oi += (*v) * (*m);
|
||||
oi++;
|
||||
m++;
|
||||
}
|
||||
v++;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Implement a vector transposed-matrix product and save it into out.
|
||||
*
|
||||
* v \in R^B
|
||||
* m \in R^{AxB}
|
||||
*/
|
||||
inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
|
||||
for (int i=0; i<A; i++) {
|
||||
float *vi = v;
|
||||
float s = 0;
|
||||
for (int j=0; j<B; j++) {
|
||||
s += (*vi) * (*m);
|
||||
vi++;
|
||||
m++;
|
||||
}
|
||||
// TODO: Should we be aggregating? See the comment on vm_dot.
|
||||
*out = s;
|
||||
out++;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute the causally masked dot products of queries, keys and values.
|
||||
*
|
||||
* Basically compute V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} for all j. The
|
||||
* computation is done efficiently by changing the order of the dot products.
|
||||
*/
|
||||
void causal_dot_product(
|
||||
const torch::Tensor queries,
|
||||
const torch::Tensor keys,
|
||||
const torch::Tensor values,
|
||||
torch::Tensor product
|
||||
) {
|
||||
// Extract some shapes
|
||||
int N = queries.size(0);
|
||||
int H = queries.size(1);
|
||||
int L = queries.size(2);
|
||||
int E = queries.size(3);
|
||||
int M = values.size(3);
|
||||
|
||||
// Create accessors for all the arguments
|
||||
auto qa = queries.accessor<float, 4>();
|
||||
auto ka = keys.accessor<float, 4>();
|
||||
auto va = values.accessor<float, 4>();
|
||||
auto pa = product.accessor<float, 4>();
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int n=0; n<N; n++) {
|
||||
for (int h=0; h<H; h++) {
|
||||
auto kv = torch::zeros({E, M}, queries.options());
|
||||
float *kvp = kv.data_ptr<float>();
|
||||
for (int l=0; l<L; l++) {
|
||||
vvt_dot(
|
||||
&ka[n][h][l][0],
|
||||
&va[n][h][l][0],
|
||||
kvp,
|
||||
E,
|
||||
M
|
||||
);
|
||||
vm_dot(
|
||||
&qa[n][h][l][0],
|
||||
kvp,
|
||||
&pa[n][h][l][0],
|
||||
E,
|
||||
M
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute the gradients of queries, keys and values given the gradient of the
|
||||
* causal_dot_product output.
|
||||
*
|
||||
* Make sure that everything is computed in O(N D^2) complexity.
|
||||
*/
|
||||
void causal_dot_backward(
|
||||
const torch::Tensor queries,
|
||||
const torch::Tensor keys,
|
||||
const torch::Tensor values,
|
||||
const torch::Tensor grad_out,
|
||||
torch::Tensor grad_queries,
|
||||
torch::Tensor grad_keys,
|
||||
torch::Tensor grad_values
|
||||
) {
|
||||
// Extract some shapes
|
||||
int N = queries.size(0);
|
||||
int H = queries.size(1);
|
||||
int L = queries.size(2);
|
||||
int E = queries.size(3);
|
||||
int M = values.size(3);
|
||||
|
||||
// Create accessors for all the arguments
|
||||
auto qa = queries.accessor<float, 4>();
|
||||
auto ka = keys.accessor<float, 4>();
|
||||
auto va = values.accessor<float, 4>();
|
||||
auto ga = grad_out.accessor<float, 4>();
|
||||
auto gqa = grad_queries.accessor<float, 4>();
|
||||
auto gka = grad_keys.accessor<float, 4>();
|
||||
auto gva = grad_values.accessor<float, 4>();
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int n=0; n<N; n++) {
|
||||
for (int h=0; h<H; h++) {
|
||||
auto kv = torch::zeros({E, M}, queries.options());
|
||||
float *kvp = kv.data_ptr<float>();
|
||||
|
||||
// Compute the gradient wrt the queries
|
||||
for (int l=0; l<L; l++) {
|
||||
vvt_dot(
|
||||
&ka[n][h][l][0],
|
||||
&va[n][h][l][0],
|
||||
kvp,
|
||||
E,
|
||||
M
|
||||
);
|
||||
vmt_dot(
|
||||
&ga[n][h][l][0],
|
||||
kvp,
|
||||
&gqa[n][h][l][0],
|
||||
E,
|
||||
M
|
||||
);
|
||||
}
|
||||
|
||||
// Compute the gradient wrt the keys and values
|
||||
kv.zero_();
|
||||
for (int l=L-1; l>=0; l--) {
|
||||
vvt_dot(
|
||||
&qa[n][h][l][0],
|
||||
&ga[n][h][l][0],
|
||||
kvp,
|
||||
E,
|
||||
M
|
||||
);
|
||||
vmt_dot(
|
||||
&va[n][h][l][0],
|
||||
kvp,
|
||||
&gka[n][h][l][0],
|
||||
E,
|
||||
M
|
||||
);
|
||||
vm_dot(
|
||||
&ka[n][h][l][0],
|
||||
kvp,
|
||||
&gva[n][h][l][0],
|
||||
E,
|
||||
M
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"causal_dot_product",
|
||||
&causal_dot_product,
|
||||
"Compute the weighted sum of values but attending only to previous "
|
||||
"values."
|
||||
);
|
||||
m.def(
|
||||
"causal_dot_backward",
|
||||
&causal_dot_backward,
|
||||
"Compute the gradient of queries, keys and values given the gradient "
|
||||
"of causal_dot_product."
|
||||
);
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||
# Apoorv Vyas <avyas@idiap.ch>
|
||||
#
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
|
||||
from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
|
||||
except ImportError as e:
|
||||
print(e)
|
||||
causal_dot_product_cuda = causal_dot_backward_cuda = None
|
||||
|
||||
|
||||
class CausalDotProduct(torch.autograd.Function):
|
||||
"""Compute the weighted sum of values but attending only to previous
|
||||
values."""
|
||||
|
||||
dot = {
|
||||
# "cpu": causal_dot_product_cpu,
|
||||
"cuda": causal_dot_product_cuda
|
||||
}
|
||||
dot_backward = {
|
||||
# "cpu": causal_dot_backward_cpu,
|
||||
"cuda": causal_dot_backward_cuda
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, Q, K, V):
|
||||
# Save the inputs for the gradient computation
|
||||
ctx.save_for_backward(Q, K, V)
|
||||
|
||||
# Create the output tensor
|
||||
device = Q.device
|
||||
N, H, L, _ = Q.shape
|
||||
_, _, _, M = V.shape
|
||||
product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device)
|
||||
|
||||
# Actually perform the dot product
|
||||
CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
|
||||
# breakpoint()
|
||||
# CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)
|
||||
|
||||
return product
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
# Extract the saved tensors
|
||||
Q, K, V = ctx.saved_tensors
|
||||
|
||||
# Allocate memory for the gradients
|
||||
grad_Q = torch.zeros_like(Q)
|
||||
grad_K = torch.zeros_like(K)
|
||||
grad_V = torch.zeros_like(V)
|
||||
|
||||
# Actually compute the gradients
|
||||
CausalDotProduct.dot_backward[Q.device.type](
|
||||
Q.data, K.data, V.data, grad_out, grad_Q, grad_K, grad_V
|
||||
)
|
||||
|
||||
return grad_Q, grad_K, grad_V
|
||||
|
||||
|
||||
# Alias the autograd functions to python style snake case naming
|
||||
causal_dot_product = CausalDotProduct.apply
|
||||
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,65 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||
# Apoorv Vyas <avyas@idiap.ch>
|
||||
#
|
||||
|
||||
import subprocess # nosec
|
||||
|
||||
import torch
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
|
||||
|
||||
|
||||
def get_last_arch_torch():
|
||||
arch = torch.cuda.get_arch_list()[-1]
|
||||
print(f"Found arch: {arch} from existing torch installation")
|
||||
return arch
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version(cuda_dir):
|
||||
raw_output = subprocess.check_output(
|
||||
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True # nosec
|
||||
)
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
release = output[release_idx].split(".")
|
||||
bare_metal_major = release[0]
|
||||
bare_metal_minor = release[1][0]
|
||||
return raw_output, bare_metal_major, bare_metal_minor
|
||||
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args):
|
||||
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
|
||||
return nvcc_extra_args + ["--threads", "4"]
|
||||
return nvcc_extra_args
|
||||
|
||||
|
||||
arch = get_last_arch_torch()
|
||||
sm_num = arch[-2:]
|
||||
cc_flag = ["--generate-code=arch=compute_90,code=compute_90"] # for H100
|
||||
# cc_flag = ['--generate-code=arch=compute_80,code=compute_80'] # for A100
|
||||
# cc_flag = ['--generate-code=arch=compute_89,code=compute_89'] # for RTX 6000, 4090
|
||||
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86'] # for A6000, 3090
|
||||
# cc_flag = ['--generate-code=arch=compute_75,code=compute_75']
|
||||
|
||||
setup(
|
||||
name="causal_attention_cuda_cpp",
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
"causal_attention_cuda",
|
||||
[
|
||||
# 'causal_attention.cpp',
|
||||
"causal_attention_cuda.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": ["-O3"],
|
||||
"nvcc": append_nvcc_threads(
|
||||
["-O3", "-lineinfo", "--use_fast_math", "-std=c++17"] + cc_flag
|
||||
),
|
||||
},
|
||||
)
|
||||
],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
||||
@@ -1,856 +0,0 @@
|
||||
"""
|
||||
Linear attention classes
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
# Causal linear attention dot product CUDA kernel from fast-transformers
|
||||
try:
|
||||
from csrc import causal_dot_product as fast_causal_dot_product
|
||||
except ImportError:
|
||||
fast_causal_dot_product = None
|
||||
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
# -------------------
|
||||
# Attention functions
|
||||
# -------------------
|
||||
|
||||
|
||||
def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
"""
|
||||
Causal linear attention dot product
|
||||
- If available, use CUDA kernel from fast-transformers
|
||||
"""
|
||||
if fast_causal_dot_product is None:
|
||||
kv = torch.einsum("bhlf,bhld->bhlfd", k, v)
|
||||
return torch.einsum("bhlf,bhlfd->bhld", q, kv.cumsum(dim=2))
|
||||
return fast_causal_dot_product(q, k, v)
|
||||
|
||||
|
||||
def linear_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
fp32_attention: bool = False,
|
||||
eps: float = 1e-12,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""
|
||||
Compute linear attention with CUDA kernel implementation from fast-transformers
|
||||
- https://github.com/idiap/fast-transformers
|
||||
- Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
|
||||
v is shape (b, h, l, head_dim)
|
||||
"""
|
||||
dtype = q.dtype
|
||||
# Causal mask already applied
|
||||
y = causal_dot_product(
|
||||
q.contiguous().to(dtype=torch.float32),
|
||||
k.contiguous().to(dtype=torch.float32),
|
||||
v.contiguous().to(dtype=torch.float32),
|
||||
)
|
||||
if fp32_attention:
|
||||
y = (
|
||||
y
|
||||
/ (
|
||||
torch.einsum("bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)) + eps
|
||||
)[..., None]
|
||||
).to(dtype=dtype)
|
||||
else:
|
||||
y = y.to(dtype=dtype)
|
||||
k = k.float().cumsum(dim=2).to(dtype=dtype)
|
||||
y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None]
|
||||
return y, None, None
|
||||
|
||||
|
||||
def softmax_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: Optional[torch.Tensor] = None,
|
||||
causal: bool = True,
|
||||
fp32_attention: bool = True,
|
||||
):
|
||||
"""
|
||||
Standard softmax attention; only compute outputs if v is not None
|
||||
-> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim)
|
||||
"""
|
||||
y = None
|
||||
a = torch.einsum("bhmd,bhnd->bhmn", q, k) * (k.shape[-1] ** -0.5)
|
||||
if causal: # Apply causal mask
|
||||
m, n = a.shape[-2:]
|
||||
causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
|
||||
n - m + 1
|
||||
)
|
||||
a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max)
|
||||
if fp32_attention:
|
||||
a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||
else:
|
||||
a = torch.softmax(a, dim=-1)
|
||||
if v is not None:
|
||||
y = torch.einsum("bhmn,bhnd->bhmd", a, v)
|
||||
return y, a, None
|
||||
|
||||
|
||||
def quadratic_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: Optional[torch.Tensor] = None,
|
||||
causal: bool = True,
|
||||
fp32_attention: bool = False,
|
||||
eps: float = 1e-12,
|
||||
):
|
||||
"""
|
||||
Compute attention with feature maps by instantiating L x L matrix of attention weights
|
||||
-> Use for attention distillation
|
||||
-> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)
|
||||
"""
|
||||
y = None
|
||||
dtype = q.dtype
|
||||
if fp32_attention:
|
||||
q, k = q.float(), k.float()
|
||||
a = torch.einsum("bhmd,bhnd->bhmn", q, k) # note we don't scale, tho we could
|
||||
if causal: # Apply causal mask
|
||||
m, n = a.shape[-2:]
|
||||
causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
|
||||
n - m + 1
|
||||
)
|
||||
a = a.masked_fill(causal_mask, 0)
|
||||
# Normalize to compute attention
|
||||
a = a / (a.sum(dim=-1, keepdim=True) + eps)
|
||||
a = a.to(dtype=dtype) if fp32_attention else a
|
||||
if torch.isnan(a).sum() > 0:
|
||||
breakpoint()
|
||||
if v is not None:
|
||||
y = torch.einsum("bhmn,bhnd->bhmd", a, v)
|
||||
return y, a, None
|
||||
|
||||
|
||||
# ---------------------
|
||||
# Attention layer class
|
||||
# ---------------------
|
||||
|
||||
|
||||
class LolcatsLinearAttention(nn.Module):
|
||||
"""
|
||||
LoLCATs attention implementation initialized from a
|
||||
`LlamaAttention` or `MistralAttention` object (base_attn)
|
||||
|
||||
Most of the arguments are directly tied to argparse args
|
||||
- For now we don't support padding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_attn: nn.Module, # like LlamaAttention
|
||||
feature_map: str,
|
||||
feature_map_kwargs: dict,
|
||||
layer_idx: Optional[int] = None,
|
||||
max_layer_idx: Optional[int] = None,
|
||||
learned_kernel: Optional[str] = None,
|
||||
learned_kernel_kwargs: Optional[dict] = None,
|
||||
tie_qk_kernels: Optional[bool] = False,
|
||||
rotary_config: Optional[dict] = None,
|
||||
train_attention: Optional[bool] = False,
|
||||
remove_base_attn: bool = True,
|
||||
attention_type: Optional[str] = "lolcats_llama",
|
||||
mask_value: int = 0,
|
||||
eps: float = 1e-12,
|
||||
fp32_attention: bool = False,
|
||||
track_state_grads: bool = False,
|
||||
rank: Optional[int] = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.base_config = getattr(base_attn, "config", None)
|
||||
if self.base_config is not None:
|
||||
self.base_config = self.base_config.to_dict()
|
||||
self.attention_type = attention_type
|
||||
self.mask_value = mask_value
|
||||
self.eps = eps
|
||||
self.layer_idx = layer_idx if layer_idx is not None else base_attn.layer_idx
|
||||
self.max_layer_idx = max_layer_idx
|
||||
self.tie_qk_kernels = tie_qk_kernels
|
||||
self.train_attention = train_attention
|
||||
self.base_inference = False
|
||||
self.fp32_attention = fp32_attention
|
||||
self.track_state_grads = track_state_grads
|
||||
if rank == 0: # multi-gpu
|
||||
if fp32_attention and layer_idx == 0:
|
||||
print(f"-> fp32_attention is {fp32_attention}")
|
||||
if layer_idx == 0 and feature_map_kwargs is not None:
|
||||
for k, v in feature_map_kwargs.items():
|
||||
print(f"-> {k}: {v}")
|
||||
if layer_idx == 0 and learned_kernel_kwargs is not None:
|
||||
for k, v in learned_kernel_kwargs.items():
|
||||
print(f"-> {k}: {v}")
|
||||
|
||||
self.remove_base_attn = remove_base_attn
|
||||
|
||||
self.init_weights_(base_attn, remove_base_attn)
|
||||
self.init_feature_map_(
|
||||
feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
|
||||
)
|
||||
|
||||
def init_feature_map_(
|
||||
self,
|
||||
feature_map: str,
|
||||
feature_map_kwargs: dict,
|
||||
learned_kernel: Optional[str] = None,
|
||||
learned_kernel_kwargs: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize MLP-based feature map
|
||||
"""
|
||||
self.fmap_gqa = False # Turn True if specified below
|
||||
if learned_kernel is not None and learned_kernel_kwargs is not None:
|
||||
# Ensure dict
|
||||
learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
|
||||
learned_kernel_kwargs["num_heads"] = self.num_heads
|
||||
learned_kernel_kwargs["head_dim"] = self.head_dim
|
||||
learned_kernel_kwargs["dtype"] = self.q_proj.weight.dtype
|
||||
learned_kernel_kwargs["device"] = self.q_proj.weight.device
|
||||
# Create MLP
|
||||
mlp_learned_kernel = init_learned_kernel(
|
||||
learned_kernel, **learned_kernel_kwargs
|
||||
)
|
||||
# Add "activation"; see src.models.feature_map.py
|
||||
self.feature_map_q = init_feature_map(
|
||||
name=feature_map, mlp=mlp_learned_kernel, **feature_map_kwargs
|
||||
)
|
||||
if self.tie_qk_kernels: # tie mlp weights for query and key feature maps
|
||||
self.feature_map_k = self.feature_map_q
|
||||
else:
|
||||
self.feature_map_k = copy.deepcopy(self.feature_map_q)
|
||||
|
||||
def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
|
||||
"""
|
||||
Initialize module layers, weights, positional dependencies, etc.
|
||||
from original softmax attention layer (base_attn)
|
||||
"""
|
||||
# Make other attributes accessible
|
||||
self.attention_dropout = 0 # We don't use dropout
|
||||
self.hidden_size = base_attn.config.hidden_size
|
||||
self.num_heads = base_attn.config.num_attention_heads
|
||||
self.head_dim = base_attn.head_dim
|
||||
self.num_key_value_heads = base_attn.config.num_key_value_heads
|
||||
self.num_key_value_groups = base_attn.num_key_value_groups
|
||||
|
||||
self.q_shape = [self.num_heads, self.head_dim]
|
||||
self.k_shape = [self.num_key_value_heads, self.head_dim]
|
||||
self.v_shape = [self.num_key_value_heads, self.head_dim]
|
||||
|
||||
# Copy original model projection layers
|
||||
self.q_proj = base_attn.q_proj
|
||||
self.k_proj = base_attn.k_proj
|
||||
self.v_proj = base_attn.v_proj
|
||||
self.o_proj = base_attn.o_proj
|
||||
try: # If wanting to use FA2 for ground-truth inference
|
||||
self._flash_attn_uses_top_left_mask = (
|
||||
base_attn._flash_attn_uses_top_left_mask
|
||||
)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
if self.remove_base_attn or remove_base_attn:
|
||||
del base_attn # We don't need to keep these around
|
||||
else:
|
||||
self.base_attn = base_attn # For some training runs helpful to just call
|
||||
|
||||
def process_qkv(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
past_key_value: Optional[Any] = None,
|
||||
):
|
||||
"""
|
||||
Compute queries, keys, and values
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
kv_seq_len = k.shape[-2]
|
||||
|
||||
# Shape is (batch_size, seq_len, num_heads, head_dim)
|
||||
q = q.view(b, l, *self.q_shape).transpose(1, 2)
|
||||
k = k.view(b, l, *self.k_shape).transpose(1, 2)
|
||||
v = v.view(b, l, *self.v_shape).transpose(1, 2)
|
||||
|
||||
if (
|
||||
past_key_value is not None
|
||||
): # and k.shape[2] > q.shape[2]: # e.g., when generating
|
||||
past_key_value.window_size = getattr(
|
||||
self, "decode_window_size", None
|
||||
) # self.decode_window_size
|
||||
if isinstance(
|
||||
past_key_value, Cache
|
||||
): # In Transformers v4.36+ this is a DynamicCache object
|
||||
kv_seq_len += past_key_value.get_usable_length(
|
||||
kv_seq_len, self.layer_idx
|
||||
)
|
||||
else:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
# Apply rotary embeddings
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
||||
|
||||
k = repeat_kv(k, self.num_key_value_groups)
|
||||
v = repeat_kv(v, self.num_key_value_groups)
|
||||
return q, k, v, kv_seq_len
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
past_key_value: Optional[Any] = None, # "legacy" cache approach
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
|
||||
- Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
q, k, v, kv_seq_len = self.process_qkv(
|
||||
hidden_states, attention_mask, position_embeddings, past_key_value
|
||||
)
|
||||
|
||||
if self.base_inference:
|
||||
with torch.no_grad():
|
||||
# 1. Compute "ground-truth" attention output and weights
|
||||
y_true, _, _ = softmax_attention(q, k, v, causal=True)
|
||||
y_true = (
|
||||
y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
)
|
||||
y_true = self.o_proj(y_true)
|
||||
attn_weights = (None, None)
|
||||
|
||||
elif self.train_attention: # Distilling / learning attentions
|
||||
# Note for now we assume no padding when distilling; attention masks only enforce causality
|
||||
assert (
|
||||
output_attentions is True
|
||||
), f"When training feature maps, output_attentions should be True but is {output_attentions}"
|
||||
with torch.no_grad():
|
||||
# 1. Compute "ground-truth" attention output and weights
|
||||
_y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
|
||||
y_true = (
|
||||
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
)
|
||||
y_true = self.o_proj(y_true)
|
||||
|
||||
# 2. Compute "predicted" attention (just weights)
|
||||
q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
|
||||
y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
|
||||
attn_weights = ( # type: ignore
|
||||
(attn_pred, attn_true),
|
||||
(y_pred, _y_true),
|
||||
) # Save both attention weights so we can supervise.
|
||||
|
||||
else: # Finetuning
|
||||
q, k = self.feature_map_q(q), self.feature_map_k(k)
|
||||
# Apply prefill mask
|
||||
if attention_mask is not None and q.shape[2] > 1:
|
||||
if len(attention_mask.shape) == 4:
|
||||
lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][
|
||||
..., None
|
||||
] # b, 1, k_len, 1
|
||||
else:
|
||||
lin_attn_mask = attention_mask.bool()[:, None, :, None] # b, 1, k_len, 1
|
||||
k = k.masked_fill(~lin_attn_mask, 0)
|
||||
|
||||
if past_key_value is not None: # Initialize states
|
||||
if len(past_key_value.kv_states) == self.layer_idx:
|
||||
b, h, _, f = k.shape
|
||||
past_key_value.kv_states.append(
|
||||
torch.zeros(
|
||||
b, h, f, self.head_dim, dtype=q.dtype, device=q.device
|
||||
)
|
||||
)
|
||||
past_key_value.k_states.append(
|
||||
torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
|
||||
)
|
||||
# Generating
|
||||
if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
|
||||
assert use_cache is True
|
||||
kv_state, k_state = past_key_value.update(
|
||||
k, v, self.layer_idx, accumulate_in_fp32=self.fp32_attention
|
||||
)
|
||||
if self.fp32_attention:
|
||||
q = q.float()
|
||||
y_true = (
|
||||
torch.einsum("bhlf,bhfd->bhld", q, kv_state.float())
|
||||
/ torch.einsum("bhlf,bhlf->bhl", q, k_state.float())[
|
||||
..., None
|
||||
]
|
||||
).to(dtype=k.dtype)
|
||||
else:
|
||||
y_true = (
|
||||
torch.einsum("bhlf,bhfd->bhld", q, kv_state)
|
||||
/ torch.einsum("bhlf,bhlf->bhl", q, k_state)[..., None]
|
||||
)
|
||||
else:
|
||||
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||
k_state = past_key_value.k_states[self.layer_idx]
|
||||
y_true, _, _ = linear_attention(
|
||||
q, k, v, self.fp32_attention, self.eps
|
||||
) # Ordinarily the states are ignored
|
||||
past_key_value.update(
|
||||
k.detach(),
|
||||
v.detach(),
|
||||
self.layer_idx,
|
||||
accumulate_in_fp32=self.fp32_attention,
|
||||
)
|
||||
# doing some unnecessary recomputation here
|
||||
else:
|
||||
y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)
|
||||
|
||||
# Concatenate heads and apply output projection
|
||||
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
y_true = self.o_proj(y_true)
|
||||
attn_weights = None
|
||||
|
||||
return y_true, attn_weights
|
||||
|
||||
|
||||
class LinearAttentionState(Cache):
|
||||
"""
|
||||
Handle the KV and K states for linear attention
|
||||
- Adopts HF Transformers `past_key_values` convention
|
||||
- Inherits from `Cache` class
|
||||
- Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||
self._seen_tokens_by_layer: List[int] = []
|
||||
self.kv_states: List[torch.Tensor] = []
|
||||
self.k_states: List[torch.Tensor] = []
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||
"""
|
||||
Returns the sequence length of the cached states. A layer index can be optionally passed.
|
||||
"""
|
||||
if layer_idx is None:
|
||||
raise ValueError("Layer index must not be None")
|
||||
|
||||
if len(self._seen_tokens_by_layer) <= layer_idx: # Initializing kv and k states
|
||||
self._seen_tokens_by_layer.append(0)
|
||||
return self._seen_tokens_by_layer[layer_idx]
|
||||
|
||||
def get_max_length(self) -> Optional[int]:
|
||||
"""
|
||||
Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_usable_length(
|
||||
self, new_seq_length: int, layer_idx: Optional[int] = 0
|
||||
) -> int:
|
||||
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
||||
# Cache without size limit -> all cache is usable
|
||||
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
||||
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
||||
max_length = self.get_max_length()
|
||||
previous_seq_length = self.get_seq_length(layer_idx)
|
||||
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
||||
return max_length - new_seq_length
|
||||
return previous_seq_length
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: Optional[int] = None,
|
||||
cache_kwargs: Optional[Any] = None,
|
||||
accumulate_in_fp32: bool = True,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if layer_idx is None:
|
||||
raise ValueError("Layer index must not be None")
|
||||
|
||||
with torch.no_grad():
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += key_states.shape[-2]
|
||||
dtype = key_states.dtype
|
||||
if accumulate_in_fp32:
|
||||
key_states, value_states = key_states.float(), value_states.float()
|
||||
|
||||
kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd", key_states, value_states
|
||||
).detach()
|
||||
k_state = key_states.sum(
|
||||
dim=-2, keepdim=True
|
||||
).detach() # b, h, 1, f; note the 1
|
||||
# Update the cache
|
||||
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
||||
print(
|
||||
"if len(self.k_states) <= layer_idx: # Initializing kv and k states"
|
||||
)
|
||||
self.kv_states.append(kv_state.to(dtype))
|
||||
self.k_states.append(k_state.to(dtype))
|
||||
else:
|
||||
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
|
||||
dtype
|
||||
)
|
||||
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
|
||||
dtype
|
||||
)
|
||||
self.kv_states[layer_idx] = kv_state
|
||||
self.k_states[layer_idx] = k_state
|
||||
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
||||
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
||||
|
||||
def to_legacy_cache(self):
|
||||
"""Hack, but just return self"""
|
||||
return self
|
||||
|
||||
def reorder_cache(self, beam_idx: torch.LongTensor):
|
||||
"""
|
||||
Reorders the cache for beam search, given the selected beam indices.
|
||||
-> Copied from transformers/src/transformers/cache_utils.py
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Reordering cache not implemented for LinearAttentionState"
|
||||
)
|
||||
|
||||
|
||||
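As a side note, the state carried by `LinearAttentionState.update` corresponds to the standard linear-attention recurrence, read off from the einsums above (a sketch, with the queries and keys already passed through the feature map):

```latex
S_\ell = \sum_{t} k_t v_t^{\top} \in \mathbb{R}^{f \times d},
\qquad
z_\ell = \sum_{t} k_t \in \mathbb{R}^{f},
\qquad
y_t = \frac{q_t^{\top} S_\ell}{q_t^{\top} z_\ell}
```

Here `kv_states[layer_idx]` holds the running sum S and `k_states[layer_idx]` holds z, so decoding one token only needs these fixed-size accumulators rather than a growing KV cache.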
# -------------------
|
||||
# feature map functions
|
||||
# -------------------
|
||||
|
||||
|
||||
def init_feature_map(name: str, mlp: nn.Module, **kwargs):
|
||||
"""
|
||||
Initialize feature map final activation for linear attention
|
||||
"""
|
||||
return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
|
||||
|
||||
|
||||
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
|
||||
"""
|
||||
Initialize feature map final activation for linear attention
|
||||
"""
|
||||
if name == "softmax_dim" and fullspace:
|
||||
return SoftmaxDim(**kwargs)
|
||||
elif name == "softmax_dim" and not fullspace:
|
||||
return SoftmaxDimHalfspace(**kwargs)
|
||||
elif name == "exp_dim" and fullspace:
|
||||
return Exp(**kwargs)
|
||||
elif name == "exp_dim" and not fullspace:
|
||||
return ExpHalfspace(**kwargs)
|
||||
elif name == "pos_elu":
|
||||
return PosELU(**kwargs)
|
||||
elif name == "relu":
|
||||
return ReLU(**kwargs)
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def init_learned_kernel(name: str, **kwargs):
|
||||
"""
|
||||
Initialize feature map MLP for linear attention
|
||||
"""
|
||||
if name == "untied_head_einsum":
|
||||
return FeatureMapMLP(**kwargs)
|
||||
elif name == "untied_head_adapter":
|
||||
return FeatureMapAdapter(**kwargs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FeatureMap(nn.Module):
|
||||
"""
|
||||
Final 'activation' of feature map. Can probably be combined with
|
||||
`FeatureMapMLP` below
|
||||
|
||||
Full feature map is like f(xW + b)
|
||||
-> This is the `f` part
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation_name: str,
|
||||
head_dim_idx: int = -1,
|
||||
eps: float = 1e-12,
|
||||
mlp: Optional[nn.Module] = None,
|
||||
fullspace: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.head_dim_idx = head_dim_idx
|
||||
self.eps = eps
|
||||
self.mlp = mlp if mlp is not None else nn.Identity()
|
||||
self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)
|
||||
|
||||
def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
|
||||
"""
|
||||
Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
|
||||
"""
|
||||
return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)
|
||||
|
||||
def q_map(self, *args, **kwargs):
|
||||
"""
|
||||
Use for inference in case q and k feature maps differ
|
||||
"""
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def k_map(self, *args, **kwargs):
|
||||
"""
|
||||
Use for inference in case q and k feature maps differ
|
||||
"""
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Feature map activations
|
||||
# -----------------------
|
||||
class FeatureMapAct(nn.Module):
|
||||
"""
|
||||
Base class for feature map activations
|
||||
"""
|
||||
|
||||
def __init__(self, eps: float = 1e-12):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
"""
|
||||
x.shape is (batch_size, n_heads, seq_len, head_dim)
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
class PosELU(FeatureMapAct):
|
||||
"""
|
||||
1 + ELU activation as in https://arxiv.org/abs/2006.16236
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
return (1 + F.elu(x)).clamp(min=self.eps)
|
||||
|
||||
|
||||
class ReLU(FeatureMapAct):
|
||||
"""
|
||||
ReLU activation as in https://arxiv.org/abs/2103.13076
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
return F.relu(x).clamp(min=self.eps)
|
||||
|
||||
|
||||
class SoftmaxDim(FeatureMapAct):
|
||||
"""
|
||||
Softmax activation as in https://arxiv.org/abs/2402.04347
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
return torch.cat(
|
||||
[torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)], dim=-1
|
||||
).clamp(min=self.eps)
|
||||
|
||||
|
||||
class SoftmaxDimHalfspace(FeatureMapAct):
|
||||
"""
|
||||
Softmax activation as in https://arxiv.org/abs/2402.04347
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
return torch.softmax(x, dim=-1).clamp(min=self.eps)
|
||||
|
||||
|
||||
class Exp(FeatureMapAct):
|
||||
"""
|
||||
Exp activation as in https://arxiv.org/abs/2402.04347
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
x_max = torch.amax(x, dim=-1, keepdim=True)
|
||||
x_min = torch.amin(x, dim=-1, keepdim=True)
|
||||
return torch.cat([torch.exp(x - x_max), torch.exp(-x + x_min)], dim=-1).clamp(
|
||||
min=self.eps
|
||||
)
|
||||
|
||||
|
||||
class ExpHalfspace(FeatureMapAct):
|
||||
"""
|
||||
Exp activation as in https://arxiv.org/abs/2402.04347
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
x_max = torch.amax(x, dim=-1, keepdim=True)
|
||||
return torch.exp(x - x_max).clamp(min=self.eps)
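# Illustrative sketch (not part of the original module): the "fullspace"
# activations (SoftmaxDim, Exp) map both x and -x and concatenate the results,
# so they double the feature dimension, while the "halfspace" variants keep it.
#
#   x = torch.randn(2, 4, 8, 16)         # (batch, heads, seq, head_dim)
#   SoftmaxDim()(x).shape                 # -> torch.Size([2, 4, 8, 32])
#   SoftmaxDimHalfspace()(x).shape        # -> torch.Size([2, 4, 8, 16])
#   Exp()(x).shape                        # -> torch.Size([2, 4, 8, 32])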
|
||||
|
||||
|
||||
# ----------------
|
||||
# Feature map MLPs
|
||||
# ----------------
|
||||
|
||||
|
||||
class FeatureMapMLP(nn.Module):
|
||||
"""
|
||||
Learnable MLP in feature map.
|
||||
|
||||
Full feature map is like f(xW + b)
|
||||
-> This is the `W` and (optional) `b` part
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_dim: int, # input dim
|
||||
feature_dim: int, # output dim
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
skip_connection: bool = False,
|
||||
bias: bool = False,
|
||||
zero_init: bool = False,
|
||||
normal_init: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.feature_dim = feature_dim
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.skip_connection = skip_connection
|
||||
self.bias = bias
|
||||
self.zero_init = zero_init
|
||||
self.normal_init = normal_init
|
||||
self.init_weights_()
|
||||
|
||||
if self.zero_init: # Zero-out weights or set as identity post-initialization
|
||||
            if self.skip_connection:
                self.zero_init_with_skip_()
            else:
                self.zero_init_()
|
||||
|
||||
if self.normal_init:
|
||||
with torch.no_grad():
|
||||
nn.init.normal_(self.layer)
|
||||
|
||||
if self.skip_connection:
|
||||
assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
|
||||
assert self.head_dim == self.feature_dim, assertion_fail
|
||||
|
||||
def init_weights_(self):
|
||||
"""
|
||||
Initialize (W)eights and (b)iases
|
||||
"""
|
||||
self.layer = nn.Parameter(
|
||||
torch.zeros(
|
||||
(self.num_heads, self.head_dim, self.feature_dim),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
)
|
||||
nn.init.kaiming_uniform_(self.layer)
|
||||
|
||||
if self.bias:
|
||||
self.bias = nn.Parameter(
|
||||
torch.zeros(
|
||||
(1, self.num_heads, 1, 1), # self.feature_dim),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
)
|
||||
nn.init.kaiming_uniform_(self.bias)
|
||||
else:
|
||||
self.bias = 0.0 # hack
|
||||
|
||||
def zero_init_with_skip_(self):
|
||||
"""
|
||||
Initialize weights to zero matrix if skip connection
|
||||
"""
|
||||
with torch.no_grad():
|
||||
nn.init.zeros_(self.layer)
|
||||
|
||||
def zero_init_(self):
|
||||
"""
|
||||
Initialize weights to identity matrix if no skip connection
|
||||
"""
|
||||
with torch.no_grad():
|
||||
for i in range(self.layer.shape[0]):
|
||||
try:
|
||||
nn.init.eye_(self.layer[i])
|
||||
except RuntimeError:
|
||||
with torch.no_grad():
|
||||
dtype = self.layer[i].dtype
|
||||
weight = torch.eye(
|
||||
*self.layer[i].shape,
|
||||
requires_grad=self.layer[i].requires_grad,
|
||||
device=self.layer[i].device,
|
||||
)
|
||||
self.layer[i] = weight.to(dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
|
||||
"""
|
||||
_x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
|
||||
return x + _x if self.skip_connection else _x
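# Illustrative usage (sketch, not part of the original module; constructor
# arguments are example values): each head gets its own (head_dim x feature_dim)
# projection, applied position-wise via the einsum above.
#
#   fmap = FeatureMapMLP(num_heads=4, head_dim=16, feature_dim=32,
#                        dtype=torch.float32, device=torch.device("cpu"))
#   x = torch.randn(2, 4, 128, 16)        # (batch, heads, seq, head_dim)
#   fmap(x).shape                         # -> torch.Size([2, 4, 128, 32])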
|
||||
|
||||
|
||||
class FeatureMapAdapter(FeatureMapMLP):
|
||||
"""
|
||||
Learnable Feature map with bottleneck adapter
|
||||
as in https://arxiv.org/abs/1902.00751
|
||||
|
||||
    We don't use this, but it could be fun to try
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int, *args, **kwargs):
|
||||
kwargs["skip_connection"] = True
|
||||
kwargs["bias"] = True
|
||||
kwargs["zero_init"] = True
|
||||
self.hidden_dim = hidden_dim
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def init_weights_(self):
|
||||
"""
|
||||
Initialize (W)eights and (b)iases
|
||||
"""
|
||||
kwargs = {"dtype": self.dtype, "device": self.device}
|
||||
self.layer0 = nn.Parameter(
|
||||
torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
|
||||
)
|
||||
self.layer1 = nn.Parameter(
|
||||
torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
|
||||
)
|
||||
nn.init.kaiming_uniform_(self.layer0)
|
||||
nn.init.kaiming_uniform_(self.layer1)
|
||||
|
||||
self.bias0 = nn.Parameter(
|
||||
torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
|
||||
)
|
||||
self.bias1 = nn.Parameter(
|
||||
torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
|
||||
)
|
||||
nn.init.kaiming_uniform_(self.bias0)
|
||||
nn.init.kaiming_uniform_(self.bias1)
|
||||
|
||||
def zero_init_with_skip_(self):
|
||||
with torch.no_grad():
|
||||
nn.init.zeros_(self.layer0)
|
||||
nn.init.zeros_(self.layer1)
|
||||
nn.init.zeros_(self.bias0)
|
||||
nn.init.zeros_(self.bias1)
|
||||
|
||||
def zero_init_(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
|
||||
-> Down-project, apply nonlinearity, up-project; add skip connection
|
||||
"""
|
||||
_x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
|
||||
_x = F.relu(_x)
|
||||
_x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
|
||||
return x + _x if self.skip_connection else _x
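# Illustrative usage (sketch, not part of the original module; argument values
# are examples): the adapter forces skip_connection=True and zero_init=True, so
# a freshly constructed adapter acts as the identity map, and head_dim must
# equal feature_dim for the skip connection.
#
#   fmap = init_learned_kernel(
#       "untied_head_adapter", hidden_dim=8, num_heads=4, head_dim=16,
#       feature_dim=16, dtype=torch.float32, device=torch.device("cpu"),
#   )
#   x = torch.randn(2, 4, 64, 16)
#   torch.allclose(fmap(x), x)            # -> True (zero-initialized adapter)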
|
||||
@@ -1,460 +0,0 @@
|
||||
"""
|
||||
Subquadratic attention combining sliding window and linear attentions
|
||||
- Using "standard" sliding windows
|
||||
- Didactically computes outputs with n^2 attention weights for now
|
||||
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
||||
|
||||
For each layer:
|
||||
- We first compute (softmax) attention over sliding windows
|
||||
- We then compute standard linear attention to "fill in" the earlier parts
|
||||
- We combine to model the entire sequence
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from .linear_attention import (
|
||||
LinearAttentionState,
|
||||
LolcatsLinearAttention,
|
||||
softmax_attention,
|
||||
)
|
||||
|
||||
|
||||
# ----------------------
|
||||
# Sliding window helpers
|
||||
# ----------------------
|
||||
def get_masks(
|
||||
window_size: int, q_len: int, k_len: int, device: torch.device
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return masks for softmax and linear attention terms
|
||||
-> 1 is include, 0 is ignore
|
||||
"""
|
||||
causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
|
||||
k_len - q_len
|
||||
)
|
||||
linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
|
||||
k_len - q_len - window_size
|
||||
)
|
||||
window_mask = causal_mask - linear_mask
|
||||
# Return softmax mask (window), linear attention mask
|
||||
# -> shapes broadcast over (b, h, q_len, k_len)
|
||||
return window_mask[None, None, ...], linear_mask[None, None, ...]
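# Illustrative sketch (not part of the original module): for q_len == k_len the
# two masks partition the causal mask -- the window mask keeps the most recent
# `window_size` keys per query, the linear mask keeps everything older.
#
#   w_mask, l_mask = get_masks(2, 4, 4, torch.device("cpu"))
#   w_mask[0, 0]                          # tensor([[1, 0, 0, 0],
#                                         #         [1, 1, 0, 0],
#                                         #         [0, 1, 1, 0],
#                                         #         [0, 0, 1, 1]], dtype=torch.int32)
#   l_mask[0, 0]                          # tensor([[0, 0, 0, 0],
#                                         #         [0, 0, 0, 0],
#                                         #         [1, 0, 0, 0],
#                                         #         [1, 1, 0, 0]], dtype=torch.int32)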
|
||||
|
||||
|
||||
def hybrid_attention_quadratic(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
f_q: torch.Tensor,
|
||||
f_k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
window_factor: torch.Tensor,
|
||||
linear_factor: torch.Tensor,
|
||||
window_size: int,
|
||||
kv_state: Optional[torch.Tensor] = None,
|
||||
k_state: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-12,
|
||||
mask_value: float = -1e8,
|
||||
):
|
||||
"""
|
||||
Hybrid attention combining sliding window and linear attentions
|
||||
"""
|
||||
|
||||
mask_window, mask_linear = get_masks(
|
||||
window_size, q.shape[-2], k.shape[-2], q.device
|
||||
)
|
||||
|
||||
# 1. Sliding window (softmax attention)
|
||||
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
|
||||
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
|
||||
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# 2. Under window (linear attention)
|
||||
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
|
||||
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
|
||||
sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
||||
|
||||
# 3. Combine
|
||||
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
|
||||
# Allow outputs to also depend on prior kv_state and k_state
|
||||
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
|
||||
if (
|
||||
kv_state is not None and k_state is not None
|
||||
): # Combine with prior kv_state and k_state
|
||||
y += linear_factor * torch.einsum(
|
||||
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
|
||||
)
|
||||
sum_ln += (
|
||||
linear_factor
|
||||
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
|
||||
)
|
||||
y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
||||
return y, a # attention weights only for the last chunk
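# Illustrative usage (sketch, not part of the original module; the 1 + ELU
# feature map here is just an example choice). Because the softmax-window and
# linear terms share the normalizer (sum_sm + sum_ln), each row of the returned
# weights `a` sums to 1.
#
#   b, h, l, d = 2, 4, 32, 16
#   q, k, v = (torch.randn(b, h, l, d) for _ in range(3))
#   f_q, f_k = 1 + F.elu(q), 1 + F.elu(k)
#   ones = torch.ones(1, h, 1, 1)
#   y, a = hybrid_attention_quadratic(q, k, f_q, f_k, v, ones, ones, window_size=8)
#   y.shape, a.shape                      # -> (2, 4, 32, 16), (2, 4, 32, 32)
#   torch.allclose(a.sum(-1), torch.ones(b, h, l), atol=1e-4)   # -> True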
|
||||
|
||||
|
||||
# ---------------------
|
||||
# Attention layer class
|
||||
# ---------------------
|
||||
class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
|
||||
"""
|
||||
Lolcats attention combining sliding window and linear attention
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
window_size: int = 64,
|
||||
decode_window_size: Optional[int] = None,
|
||||
affine_attention_factors: bool = False,
|
||||
init_window_factor: float = 0,
|
||||
train_window_factor: bool = True,
|
||||
state_grad_enabled: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.window_size = window_size
|
||||
self.decode_window_size = (
|
||||
decode_window_size if decode_window_size is not None else window_size
|
||||
)
|
||||
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||
super().__init__(**kwargs)
|
||||
self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_sw'
|
||||
# Determine how we compute attentions
|
||||
self.quadratic_attention = hybrid_attention_quadratic
|
||||
self.attention_type = kwargs[
|
||||
"attention_type"
|
||||
] # 'hedgehog_long_llama_window_sw'
|
||||
# Learnable factor for combining attentions
|
||||
self.affine_attention_factors = affine_attention_factors
|
||||
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
|
||||
if train_window_factor:
|
||||
self.window_factors = nn.Parameter(
|
||||
init_window_factor
|
||||
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.register_buffer(
|
||||
"window_factors",
|
||||
init_window_factor
|
||||
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
# Whether we use original flash attention 2 inference (use during attention transfer)
|
||||
self.base_inference = False
|
||||
self.state_grad_enabled = state_grad_enabled
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass with the option to compute attention weights multiple ways
|
||||
if self.train_attention is True
|
||||
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
q, k, v, kv_seq_len = self.process_qkv(
|
||||
hidden_states, attention_mask, position_ids, past_key_value
|
||||
)
|
||||
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
|
||||
k
|
||||
) # Have to do after repeat for grouped-query attn if we use same fmap
|
||||
|
||||
if self.train_attention:
|
||||
# 1. Compute "ground-truth" attention output and weights
|
||||
with torch.no_grad():
|
||||
_y_true, a_true = softmax_attention(q, k, v)[:2]
|
||||
y_true = (
|
||||
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
)
|
||||
y_true = self.o_proj(y_true)
|
||||
|
||||
# 2. Compute "predicted" attention outputs
|
||||
# compute attn weights under sliding window
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
||||
y_pred, a_pred = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
|
||||
else:
|
||||
attn_weights = None
|
||||
# attention_mask = None # For now this is always True
|
||||
if past_key_value is None: # Regular training
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_true, a_pred = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
attn_weights = a_pred
|
||||
else:
|
||||
past_key_value.window_size = self.decode_window_size
|
||||
if (
|
||||
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
|
||||
): # Generating
|
||||
assert use_cache is True
|
||||
_kv = past_key_value.update_for_decoding(
|
||||
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
|
||||
)
|
||||
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
||||
|
||||
# Sliding window + linear attention decode
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
|
||||
# Softmax attention terms
|
||||
a_sm = torch.einsum(
|
||||
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
|
||||
) * (k.shape[-1] ** -0.5)
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Combine with linear attention terms
|
||||
y_true = torch.einsum(
|
||||
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||
) + linear_factors * torch.einsum(
|
||||
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
|
||||
)
|
||||
sum_ln = (
|
||||
linear_factors
|
||||
* torch.einsum(
|
||||
"bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
|
||||
)[..., None]
|
||||
)
|
||||
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
||||
|
||||
else: # Stateful training
|
||||
try:
|
||||
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||
k_state = past_key_value.k_states[self.layer_idx]
|
||||
except IndexError:
|
||||
kv_state, k_state = None, None
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_true, _ = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
kv_state=kv_state,
|
||||
k_state=k_state,
|
||||
)
|
||||
# Save and update KV cache and states
|
||||
# past_key_value.update(k, v.detach(), self.layer_idx,
|
||||
# fmap_key_states=f_k.detach(),
|
||||
# accumulate_in_fp32=True)
|
||||
past_key_value.update(
|
||||
k,
|
||||
v,
|
||||
self.layer_idx,
|
||||
fmap_key_states=f_k,
|
||||
accumulate_in_fp32=True,
|
||||
)
|
||||
# Concatenate heads and apply output projection
|
||||
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
y_true = self.o_proj(y_true)
|
||||
return y_true, attn_weights, past_key_value
|
||||
|
||||
|
||||
class LinearAttentionSlidingWindowCache(LinearAttentionState):
|
||||
"""
|
||||
Class for `past_key_values`
|
||||
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
||||
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||
"""
|
||||
|
||||
def __init__(self, window_size: int = 64) -> None:
|
||||
super().__init__()
|
||||
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||
self._seen_tokens_by_layer: List[int] = []
|
||||
self.kv_states: List[torch.Tensor] = []
|
||||
self.k_states: List[torch.Tensor] = []
|
||||
|
||||
# Account for sliding windows
|
||||
self.decode_kv_states: List[torch.Tensor] = []
|
||||
self.decode_k_states: List[torch.Tensor] = []
|
||||
self.k_cache: List[torch.Tensor] = []
|
||||
self.v_cache: List[torch.Tensor] = []
|
||||
self.window_size = window_size
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: Optional[int] = None,
|
||||
cache_kwargs: Optional[Any] = None,
|
||||
accumulate_in_fp32: bool = False,
|
||||
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
|
||||
grad_enabled: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Update KV, K states; and KV cache during training
|
||||
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
||||
up to sliding window terms
|
||||
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
||||
up to end of sequence
|
||||
- Likewise for `self.decode_k_states` and `self.k_states`
|
||||
"""
|
||||
if fmap_key_states is None:
|
||||
raise ValueError("fmap_key_states must not be None")
|
||||
|
||||
if layer_idx is None:
|
||||
raise ValueError("Layer index must not be None")
|
||||
|
||||
with torch.set_grad_enabled(grad_enabled):
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += key_states.shape[-2]
|
||||
|
||||
dtype = key_states.dtype
|
||||
if accumulate_in_fp32:
|
||||
# key_states = key_states.float()
|
||||
fmap_key_states = fmap_key_states.float()
|
||||
value_states = value_states.float()
|
||||
|
||||
# Decoding KV state (KV terms up to last window_size)
|
||||
decode_kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd",
|
||||
fmap_key_states[:, :, : -self.window_size],
|
||||
value_states[:, :, : -self.window_size],
|
||||
)
|
||||
# KV state
|
||||
kv_state = decode_kv_state + torch.einsum(
|
||||
"bhlf,bhld->bhfd",
|
||||
fmap_key_states[:, :, -self.window_size :],
|
||||
value_states[:, :, -self.window_size :],
|
||||
)
|
||||
# shape is b, h, 1, f; note the 1
|
||||
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
|
||||
dim=-2, keepdim=True
|
||||
)
|
||||
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
|
||||
dim=-2, keepdim=True
|
||||
)
|
||||
|
||||
# Update the cache
|
||||
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
||||
self.kv_states.append(kv_state.to(dtype))
|
||||
self.k_states.append(k_state.to(dtype))
|
||||
|
||||
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
||||
self.decode_k_states.append(decode_k_state.to(dtype))
|
||||
|
||||
self.k_cache.append(key_states[:, :, -self.window_size :, :])
|
||||
self.v_cache.append(
|
||||
value_states[:, :, -self.window_size :, :].to(dtype)
|
||||
)
|
||||
                self._seen_tokens_by_layer.append(
                    key_states.shape[-2]
                )  # initialize per-layer token count
|
||||
else:
|
||||
# Update kv and k states recurrently
|
||||
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
|
||||
dtype
|
||||
)
|
||||
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
|
||||
dtype
|
||||
)
|
||||
self.kv_states[layer_idx] = kv_state
|
||||
self.k_states[layer_idx] = k_state
|
||||
|
||||
decode_kv_state = (
|
||||
self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
||||
+ decode_kv_state
|
||||
).to(dtype)
|
||||
decode_k_state = (
|
||||
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
|
||||
).to(dtype)
|
||||
self.decode_kv_states[layer_idx] = decode_kv_state
|
||||
self.decode_k_states[layer_idx] = decode_k_state
|
||||
|
||||
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
|
||||
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
|
||||
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
||||
|
||||
return self.kv_states[layer_idx], self.k_states[layer_idx]
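    # State recurrences (sketch, not part of the original module): writing phi
    # for the feature map, each `update` call with a new chunk of keys/values
    # adds, per layer,
    #     kv_state        += sum_l phi(k_l) v_l^T     (all positions in the chunk)
    #     k_state         += sum_l phi(k_l)
    # while decode_kv_state / decode_k_state accumulate the same sums over the
    # positions outside the chunk's last `window_size` tokens. Chunked training
    # over long sequences therefore keeps O(1) state in sequence length.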
|
||||
|
||||
def update_for_decoding(
|
||||
self,
|
||||
keys: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
layer_idx: int,
|
||||
feature_map_k: Callable,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""
|
||||
        Update the decoding KV and K states, and KV cache, during decoding
|
||||
"""
|
||||
with torch.no_grad():
|
||||
k_cache = self.k_cache[layer_idx]
|
||||
v_cache = self.v_cache[layer_idx]
|
||||
|
||||
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
||||
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
||||
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
||||
else:
|
||||
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
|
||||
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
|
||||
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
|
||||
# else:
|
||||
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
|
||||
k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||
v_state = v_cache[:, :, :1, :]
|
||||
kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
|
||||
).to(
|
||||
dtype
|
||||
) # b, h, f, d
|
||||
self.decode_kv_states[layer_idx] += kv_state
|
||||
self.decode_k_states[layer_idx] += k_state
|
||||
|
||||
self.k_cache[layer_idx] = torch.cat(
|
||||
[k_cache[:, :, 1:, :], keys], dim=-2
|
||||
)
|
||||
self.v_cache[layer_idx] = torch.cat(
|
||||
[v_cache[:, :, 1:, :], values], dim=-2
|
||||
)
|
||||
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += keys.shape[-2]
|
||||
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
||||
return (
|
||||
self.k_cache[layer_idx],
|
||||
self.v_cache[layer_idx],
|
||||
self.decode_kv_states[layer_idx],
|
||||
self.decode_k_states[layer_idx],
|
||||
)
|
||||
@@ -1,685 +0,0 @@
|
||||
"""
|
||||
Subquadratic attention combining sliding window and linear attentions
|
||||
- Using "standard" sliding windows
|
||||
- Didactically computes outputs with n^2 attention weights for now
|
||||
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
||||
|
||||
For each layer:
|
||||
- We first compute (softmax) attention over sliding windows
|
||||
- We then compute standard linear attention to "fill in" the earlier parts
|
||||
- We combine to model the entire sequence
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
except ModuleNotFoundError:
|
||||
_flash_attention_forward = None # Transformers v4.36
|
||||
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
# Causal linear attention dot product CUDA kernel from fast-transformers
|
||||
from .linear_attention import (
|
||||
LinearAttentionState,
|
||||
LolcatsLinearAttention,
|
||||
causal_dot_product,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ----------------------
|
||||
# Sliding window helpers
|
||||
# ----------------------
|
||||
def get_masks(
|
||||
window_size: int, q_len: int, k_len: int, device: torch.device
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return masks for softmax and linear attention terms
|
||||
-> 1 is include, 0 is ignore
|
||||
"""
|
||||
causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
|
||||
max(k_len - q_len, 0)
|
||||
)
|
||||
linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
|
||||
max(k_len - q_len, 0) - window_size
|
||||
)
|
||||
window_mask = causal_mask - linear_mask
|
||||
# Return softmax mask (window), linear attention mask
|
||||
# -> shapes broadcast over (b, h, q_len, k_len)
|
||||
return window_mask[None, None, ...], linear_mask[None, None, ...]
|
||||
|
||||
|
||||
def hybrid_attention_quadratic(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
f_q: torch.Tensor,
|
||||
f_k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
window_factor: torch.Tensor,
|
||||
linear_factor: torch.Tensor,
|
||||
window_size: int,
|
||||
kv_state: Optional[torch.Tensor] = None,
|
||||
k_state: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-12,
|
||||
mask_value: float = -1e8,
|
||||
):
|
||||
"""
|
||||
Hybrid attention combining sliding window and linear attentions
|
||||
"""
|
||||
|
||||
mask_window, mask_linear = get_masks(
|
||||
window_size, q.shape[-2], k.shape[-2], q.device
|
||||
)
|
||||
|
||||
# 1. Sliding window (softmax attention)
|
||||
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
|
||||
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
|
||||
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# 2. Under window (linear attention)
|
||||
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
|
||||
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
|
||||
sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
||||
|
||||
# 3. Combine
|
||||
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
|
||||
# Allow outputs to also depend on prior kv_state and k_state
|
||||
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
|
||||
if (
|
||||
kv_state is not None and k_state is not None
|
||||
): # Combine with prior kv_state and k_state
|
||||
y += linear_factor * torch.einsum(
|
||||
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
|
||||
)
|
||||
sum_ln += (
|
||||
linear_factor
|
||||
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
|
||||
)
|
||||
y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
||||
return y, a # attention weights only for the last chunk
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Hybrid window attention linear
|
||||
# ------------------------------
|
||||
def under_window_linear_attention(
|
||||
f_q: torch.Tensor,
|
||||
f_k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
window_size: int,
|
||||
linear_factor: torch.Tensor,
|
||||
eps: float = 1e-12,
|
||||
):
|
||||
"""Compute hybrid window attention dot product with linear complexity in q_len"""
|
||||
dtype = f_q.dtype
|
||||
w = window_size
|
||||
f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :]
|
||||
v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :]
|
||||
qkv = linear_factor * causal_dot_product(
|
||||
f_q.contiguous().to(dtype=torch.float32),
|
||||
f_k.contiguous().to(dtype=torch.float32),
|
||||
v.contiguous().to(dtype=torch.float32),
|
||||
).to(dtype=dtype)
|
||||
sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype)
|
||||
sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None]
|
||||
sum_qk[sum_qk == 0] += eps
|
||||
return qkv, sum_qk
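# Reference semantics (sketch, not part of the original module): with keys and
# values shifted right by `window_size` before the causal product, this returns
#     qkv[l]    = linear_factor * sum_{j <= l - window_size} (f_q[l] . f_k[j]) v[j]
#     sum_qk[l] = linear_factor * sum_{j <= l - window_size} (f_q[l] . f_k[j])
# (with eps added where the sum is exactly zero), i.e. only the "under the
# window" part of linear attention, accumulated causally in O(seq_len) rather
# than O(seq_len^2).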
|
||||
|
||||
|
||||
def sliding_window_softmax_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
window_size: int,
|
||||
window_factor: torch.Tensor,
|
||||
mask_value: float = -1e8,
|
||||
):
|
||||
"""
|
||||
Compute sliding window softmax attention without materializing
|
||||
O(seq_len^2) attention weights
|
||||
"""
|
||||
d = q.shape[-1]
|
||||
# Compute windows for keys
|
||||
window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||
k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
|
||||
v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
|
||||
|
||||
# Compute windowed_softmax(qk); causal in its construction
|
||||
a_sm = torch.einsum("bhld,bhldw->bhlw", q, k) * (d**-0.5)
|
||||
a_sm[a_sm == 0] = -torch.finfo(
|
||||
q.dtype
|
||||
).max # heuristic for zeroing out padding above
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
return torch.einsum("bhlw,bhldw->bhld", a_sm, v), sum_sm
|
||||
# return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v)
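# Shape sketch (not part of the original module): padding the sequence dim with
# window_size - 1 zeros and unfolding turns k, v of shape (b, h, l, d) into
# (b, h, l, d, window_size), so each position carries its own window of
# keys/values and no (l x l) attention matrix is ever materialized.
#
#   k = torch.randn(2, 4, 128, 16)
#   F.pad(k, (0, 0, 63, 0), value=0).unfold(dimension=2, size=64, step=1).shape
#   # -> torch.Size([2, 4, 128, 16, 64])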
|
||||
|
||||
|
||||
def hybrid_attention_linear(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
f_q: torch.Tensor,
|
||||
f_k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
window_factor: Optional[torch.Tensor] = None,
|
||||
linear_factor: Optional[torch.Tensor] = None,
|
||||
window_size: int = 64,
|
||||
kv_state: Optional[torch.Tensor] = None,
|
||||
k_state: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-12,
|
||||
mask_value: float = -1e8,
|
||||
):
|
||||
"""
|
||||
Alternative hybrid attention combining sliding window and linear attentions
|
||||
    -> Uses O(n) memory, where n is the sequence length, by padding and unfolding windows
|
||||
"""
|
||||
# window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||
if window_factor is None:
|
||||
raise ValueError("window_factor must be provided")
|
||||
|
||||
if linear_factor is None:
|
||||
raise ValueError("linear_factor must be provided")
|
||||
|
||||
# 1. Sliding window (softmax attention)
|
||||
with torch.no_grad():
|
||||
qkv_sm, sum_qk_sm = sliding_window_softmax_attention(
|
||||
q, k, v, window_size, window_factor, mask_value
|
||||
)
|
||||
|
||||
# 2. Under window (linear attention)
|
||||
qkv_ln, sum_qk_ln = under_window_linear_attention(
|
||||
f_q, f_k, v, window_size, linear_factor, eps
|
||||
)
|
||||
|
||||
# 3. Combine
|
||||
y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln)
|
||||
return y, None
|
||||
|
||||
|
||||
# ---------------------
|
||||
# Attention layer class
|
||||
# ---------------------
|
||||
class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention):
|
||||
"""
|
||||
Lolcats attention combining sliding window and linear attention
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
window_size: int = 64,
|
||||
decode_window_size: Optional[int] = None,
|
||||
affine_attention_factors: bool = False,
|
||||
init_window_factor: float = 0,
|
||||
train_window_factor: bool = True,
|
||||
state_grad_enabled: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.window_size = window_size
|
||||
self.decode_window_size = (
|
||||
decode_window_size if decode_window_size is not None else window_size
|
||||
)
|
||||
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||
super().__init__(**kwargs)
|
||||
# Determine how we compute attentions
|
||||
self.linear_attention = hybrid_attention_linear
|
||||
self.attention_type = "lolcats_llama_window_sw"
|
||||
# Learnable factor for combining attentions
|
||||
self.affine_attention_factors = affine_attention_factors
|
||||
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
|
||||
if train_window_factor:
|
||||
self.window_factors = nn.Parameter(
|
||||
init_window_factor
|
||||
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.register_buffer(
|
||||
"window_factors",
|
||||
init_window_factor
|
||||
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
# Whether we use original flash attention 2 inference (use during attention transfer)
|
||||
self.base_inference = False
|
||||
self.state_grad_enabled = state_grad_enabled
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass with the option to compute attention weights multiple ways
|
||||
if self.train_attention is True
|
||||
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
|
||||
if self.train_attention and self.base_inference:
|
||||
with torch.no_grad():
|
||||
_y_true = flash_attention_2(
|
||||
self, # self.base_attn,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=None,
|
||||
position_ids=position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
)[0]
|
||||
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
|
||||
y_true = _y_true.reshape(b, l, -1).contiguous()
|
||||
y_true = self.o_proj(y_true)
|
||||
# layer_io = (hidden_states, _y_true) # hack
|
||||
layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
|
||||
return y_true, layer_io, None
|
||||
|
||||
else:
|
||||
q, k, v, kv_seq_len = self.process_qkv(
|
||||
hidden_states, attention_mask, position_ids, past_key_value
|
||||
)
|
||||
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
|
||||
k
|
||||
) # Have to do after repeat for grouped-query attn if we use same fmap
|
||||
|
||||
attn_weights = None
|
||||
# attention_mask = None # For now this is always True
|
||||
if past_key_value is None: # Regular training
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_true, a_pred = self.linear_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
attn_weights = a_pred
|
||||
else:
|
||||
past_key_value.window_size = self.decode_window_size
|
||||
if (
|
||||
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
|
||||
): # Generating
|
||||
assert use_cache is True
|
||||
_kv = past_key_value.update_for_decoding(
|
||||
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
|
||||
)
|
||||
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
||||
|
||||
# Sliding window + linear attention decode
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
|
||||
# Softmax attention terms
|
||||
a_sm = torch.einsum(
|
||||
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
|
||||
) * (k.shape[-1] ** -0.5)
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Combine with linear attention terms
|
||||
y_true = torch.einsum(
|
||||
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||
) + linear_factors * torch.einsum(
|
||||
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
|
||||
)
|
||||
sum_ln = (
|
||||
linear_factors
|
||||
* torch.einsum(
|
||||
"bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
|
||||
)[..., None]
|
||||
)
|
||||
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
||||
|
||||
else: # Stateful training
|
||||
try:
|
||||
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||
k_state = past_key_value.k_states[self.layer_idx]
|
||||
except IndexError:
|
||||
kv_state, k_state = None, None
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_true, _ = self.linear_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
kv_state=kv_state,
|
||||
k_state=k_state,
|
||||
)
|
||||
# Save and update KV cache and states
|
||||
# past_key_value.update(k, v.detach(), self.layer_idx,
|
||||
# fmap_key_states=f_k.detach(),
|
||||
# accumulate_in_fp32=True)
|
||||
past_key_value.update(
|
||||
k,
|
||||
v,
|
||||
self.layer_idx,
|
||||
fmap_key_states=f_k,
|
||||
accumulate_in_fp32=True,
|
||||
)
|
||||
# Concatenate heads and apply output projection
|
||||
_y_true = y_true.transpose(1, 2).contiguous()
|
||||
y_true = self.o_proj(_y_true.view(b, l, self.hidden_size))
|
||||
|
||||
if self.train_attention:
|
||||
attn_weights = _y_true # flash_attn outputs are shape (b, l, h, d)
|
||||
return y_true, attn_weights, past_key_value
|
||||
|
||||
|
||||
class LinearAttentionSlidingWindowCache(LinearAttentionState):
|
||||
"""
|
||||
Class for `past_key_values`
|
||||
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
||||
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||
"""
|
||||
|
||||
def __init__(self, window_size: int = 64) -> None:
|
||||
super().__init__()
|
||||
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||
self._seen_tokens_by_layer: List[int] = []
|
||||
self.kv_states: List[torch.Tensor] = []
|
||||
self.k_states: List[torch.Tensor] = []
|
||||
|
||||
# Account for sliding windows
|
||||
self.decode_kv_states: List[torch.Tensor] = []
|
||||
self.decode_k_states: List[torch.Tensor] = []
|
||||
self.k_cache: List[torch.Tensor] = []
|
||||
self.v_cache: List[torch.Tensor] = []
|
||||
self.window_size = window_size
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: Optional[int] = None,
|
||||
cache_kwargs: Optional[Any] = None,
|
||||
accumulate_in_fp32: bool = False,
|
||||
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
|
||||
grad_enabled: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Update KV, K states; and KV cache during training
|
||||
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
||||
up to sliding window terms
|
||||
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
||||
up to end of sequence
|
||||
- Likewise for `self.decode_k_states` and `self.k_states`
|
||||
"""
|
||||
if fmap_key_states is None:
|
||||
raise ValueError("fmap_key_states must not be None")
|
||||
|
||||
if layer_idx is None:
|
||||
raise ValueError("Layer index must not be None")
|
||||
|
||||
with torch.set_grad_enabled(grad_enabled):
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += key_states.shape[-2]
|
||||
|
||||
dtype = key_states.dtype
|
||||
if accumulate_in_fp32:
|
||||
# key_states = key_states.float()
|
||||
fmap_key_states = fmap_key_states.float()
|
||||
value_states = value_states.float()
|
||||
|
||||
# Decoding KV state (KV terms up to last window_size)
|
||||
decode_kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd",
|
||||
fmap_key_states[:, :, : -self.window_size],
|
||||
value_states[:, :, : -self.window_size],
|
||||
)
|
||||
# KV state
|
||||
kv_state = decode_kv_state + torch.einsum(
|
||||
"bhlf,bhld->bhfd",
|
||||
fmap_key_states[:, :, -self.window_size :],
|
||||
value_states[:, :, -self.window_size :],
|
||||
)
|
||||
# shape is b, h, 1, f; note the 1
|
||||
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
|
||||
dim=-2, keepdim=True
|
||||
)
|
||||
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
|
||||
dim=-2, keepdim=True
|
||||
)
|
||||
|
||||
# Update the cache
|
||||
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
||||
self.kv_states.append(kv_state.to(dtype))
|
||||
self.k_states.append(k_state.to(dtype))
|
||||
|
||||
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
||||
self.decode_k_states.append(decode_k_state.to(dtype))
|
||||
|
||||
self.k_cache.append(key_states[:, :, -self.window_size :, :])
|
||||
self.v_cache.append(
|
||||
value_states[:, :, -self.window_size :, :].to(dtype)
|
||||
)
|
||||
                self._seen_tokens_by_layer.append(
                    key_states.shape[-2]
                )  # initialize per-layer token count
|
||||
else:
|
||||
# Update kv and k states recurrently
|
||||
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
|
||||
dtype
|
||||
)
|
||||
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
|
||||
dtype
|
||||
)
|
||||
self.kv_states[layer_idx] = kv_state
|
||||
self.k_states[layer_idx] = k_state
|
||||
|
||||
decode_kv_state = (
|
||||
self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
||||
+ decode_kv_state
|
||||
).to(dtype)
|
||||
decode_k_state = (
|
||||
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
|
||||
).to(dtype)
|
||||
self.decode_kv_states[layer_idx] = decode_kv_state
|
||||
self.decode_k_states[layer_idx] = decode_k_state
|
||||
|
||||
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
|
||||
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
|
||||
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
||||
|
||||
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
||||
|
||||
def update_for_decoding(
|
||||
self,
|
||||
keys: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
layer_idx: int,
|
||||
feature_map_k: Callable,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""
|
||||
        Update the decoding KV and K states, and KV cache, during decoding
|
||||
"""
|
||||
with torch.no_grad():
|
||||
k_cache = self.k_cache[layer_idx]
|
||||
v_cache = self.v_cache[layer_idx]
|
||||
|
||||
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
||||
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
||||
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
||||
else:
|
||||
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
|
||||
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
|
||||
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
|
||||
# else:
|
||||
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
|
||||
k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||
v_state = v_cache[:, :, :1, :]
|
||||
kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
|
||||
).to(
|
||||
dtype
|
||||
) # b, h, f, d
|
||||
self.decode_kv_states[layer_idx] += kv_state
|
||||
self.decode_k_states[layer_idx] += k_state
|
||||
|
||||
self.k_cache[layer_idx] = torch.cat(
|
||||
[k_cache[:, :, 1:, :], keys], dim=-2
|
||||
)
|
||||
self.v_cache[layer_idx] = torch.cat(
|
||||
[v_cache[:, :, 1:, :], values], dim=-2
|
||||
)
|
||||
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += keys.shape[-2]
|
||||
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
||||
return (
|
||||
self.k_cache[layer_idx],
|
||||
self.v_cache[layer_idx],
|
||||
self.decode_kv_states[layer_idx],
|
||||
self.decode_k_states[layer_idx],
|
||||
)
|
||||
|
||||
|
||||
# -----------------
|
||||
# Flash Attention 2
|
||||
# -----------------
|
||||
|
||||
|
||||
def flash_attention_2(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
"""
|
||||
Wrapper for LlamaFlashAttention2
|
||||
Copied and modified from HF Transformers v4.36 and v4.43 implementations
|
||||
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
|
||||
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
|
||||
"""
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
try: # As in Transformers v4.36
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
except Exception: # As in Transformers v4.39
|
||||
cos, sin = self.rotary_emb(key_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
LOG.debug(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
if getattr(self, "_flash_attention_forward", False):
|
||||
attn_output = self._flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=0, # dropout_rate,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
return attn_output, past_key_value
|
||||
@@ -1,24 +0,0 @@
|
||||
"""
|
||||
LoLCATs attention combining sliding window and linear attentions
|
||||
- Using standard sliding window arrangement
|
||||
- Training over long sequences with fixed memory with recurrent view
|
||||
- During attention transfer, use Flash Attention to compute softmax attention outputs
|
||||
|
||||
For each layer:
|
||||
- We first compute (softmax) attention over sliding windows
|
||||
- We then compute standard linear attention to "fill in" the earlier parts
|
||||
- We combine to model the entire sequence
|
||||
"""
|
||||
from .linear_window_attention_sw import hybrid_attention_quadratic
|
||||
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||
|
||||
|
||||
class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
|
||||
"""
|
||||
Lolcats attention combining sliding window and linear attention
|
||||
"""
|
||||
|
||||
def __init__(self, remove_base_attn=True, **kwargs):
|
||||
# keep self.base_attn for Flash Attention inference
|
||||
super().__init__(remove_base_attn=True, **kwargs)
|
||||
self.quadratic_attention = hybrid_attention_quadratic
|
||||
@@ -1,466 +0,0 @@
|
||||
"""
|
||||
Subquadratic attention combining sliding window and linear attentions
|
||||
- Using the TK "terracing" arrangement
|
||||
|
||||
For each layer:
|
||||
- We first compute (softmax) attention over sliding windows
|
||||
- We then compute standard linear attention to "fill in" the earlier parts
|
||||
- We combine to model the entire sequence
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
from .linear_attention import (
|
||||
LinearAttentionState,
|
||||
LolcatsLinearAttention,
|
||||
softmax_attention,
|
||||
)
|
||||
|
||||
|
||||
# ----------------------
|
||||
# Sliding window helpers
|
||||
# ----------------------
|
||||
def get_masks(
|
||||
window_size: int, q_len: int, k_len: int, device: torch.device
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return masks for softmax and linear attention terms
|
||||
-> 1 is include, 0 is ignore
|
||||
"""
|
||||
win_len = window_size
|
||||
m = math.ceil(max(q_len, k_len) / window_size)
|
||||
    # Creates an n x n block-diagonal mask, where n = ceil(max(q_len, k_len) / window_size) * window_size
|
||||
mask = torch.block_diag(
|
||||
*[
|
||||
torch.ones(
|
||||
(win_len, win_len),
|
||||
)
|
||||
]
|
||||
* m
|
||||
)
|
||||
mask += torch.roll(mask, -win_len, -1) # this adds the terracing
|
||||
if mask.shape[0] > q_len:
|
||||
mask = mask[-q_len:]
|
||||
if mask.shape[1] > k_len:
|
||||
mask = mask[:, -k_len:]
|
||||
# Return softmax mask (window), linear attention mask
|
||||
mask = mask[None, None, ...] # b, h, q_len, k_len
|
||||
return (
|
||||
torch.tril(mask).to(device=device, dtype=torch.int),
|
||||
torch.tril(1 - mask).to(device=device, dtype=torch.int),
|
||||
)
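# Illustrative sketch (not part of the original module): unlike the fixed-size
# sliding-window masks in linear_window_attention_sw, the TK "terracing" lets
# each query attend (with softmax) to its own window_size-aligned block plus the
# previous block, causally; the linear mask covers the remaining older positions.
#
#   w_mask, l_mask = get_masks(2, 6, 6, torch.device("cpu"))
#   w_mask[0, 0]                          # tensor([[1, 0, 0, 0, 0, 0],
#                                         #         [1, 1, 0, 0, 0, 0],
#                                         #         [1, 1, 1, 0, 0, 0],
#                                         #         [1, 1, 1, 1, 0, 0],
#                                         #         [0, 0, 1, 1, 1, 0],
#                                         #         [0, 0, 1, 1, 1, 1]], dtype=torch.int32)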
|
||||
|
||||
|
||||
def hybrid_attention_quadratic(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
f_q: torch.Tensor,
|
||||
f_k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
window_factor: torch.Tensor,
|
||||
linear_factor: torch.Tensor,
|
||||
window_size: int,
|
||||
kv_state: Optional[torch.Tensor] = None,
|
||||
k_state: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-12,
|
||||
mask_value: float = -1e8,
|
||||
):
|
||||
"""
|
||||
Hybrid attention combining sliding window and linear attentions
|
||||
"""
|
||||
|
||||
mask_window, mask_linear = get_masks(
|
||||
window_size, q.shape[-2], k.shape[-2], q.device
|
||||
)
|
||||
|
||||
# 1. Sliding window (softmax attention)
|
||||
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
|
||||
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
|
||||
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# 2. Under window (linear attention)
|
||||
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
|
||||
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
|
||||
sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
||||
|
||||
# 3. Combine
|
||||
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
|
||||
# Allow outputs to also depend on prior kv_state and k_state
|
||||
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
|
||||
if (
|
||||
kv_state is not None and k_state is not None
|
||||
): # Combine with prior kv_state and k_state
|
||||
y += linear_factor * torch.einsum(
|
||||
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
|
||||
)
|
||||
sum_ln += (
|
||||
linear_factor
|
||||
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
|
||||
)
|
||||
y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
||||
return y, a # attention weights only for the last chunk
|
||||
|
||||
|
||||
# ---------------------
|
||||
# Attention layer class
|
||||
# ---------------------
|
||||
class LolcatsTKWindowAttention(LolcatsLinearAttention):
|
||||
"""
|
||||
Lolcats attention combining sliding window and linear attention
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
window_size: int = 64,
|
||||
decode_window_size: Optional[int] = None,
|
||||
affine_attention_factors: bool = False,
|
||||
init_window_factor: float = 0,
|
||||
train_window_factor: bool = True,
|
||||
state_grad_enabled: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.window_size = window_size
|
||||
self.decode_window_size = (
|
||||
decode_window_size if decode_window_size is not None else window_size
|
||||
)
|
||||
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||
super().__init__(**kwargs)
|
||||
self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_tk'
|
||||
# Determine how we compute attentions
|
||||
self.quadratic_attention = hybrid_attention_quadratic
|
||||
self.attention_type = kwargs[
|
||||
"attention_type"
|
||||
] # 'hedgehog_long_llama_window_tk'
|
||||
# Learnable factor for combining attentions
|
||||
self.affine_attention_factors = affine_attention_factors
|
||||
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
|
||||
if train_window_factor:
|
||||
self.window_factors = nn.Parameter(
|
||||
init_window_factor
|
||||
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.register_buffer(
|
||||
"window_factors",
|
||||
init_window_factor
|
||||
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
# Whether we use original flash attention 2 inference (use during attention transfer)
|
||||
self.base_inference = False
|
||||
self.state_grad_enabled = state_grad_enabled
|
||||
self.window_factor = self.window_factors # legacy naming support
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass with the option to compute attention weights multiple ways
|
||||
if self.train_attention is True
|
||||
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
q, k, v, kv_seq_len = self.process_qkv(
|
||||
hidden_states, attention_mask, position_ids, past_key_value
|
||||
)
|
||||
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
|
||||
k
|
||||
) # Have to do after repeat for grouped-query attn if we use same fmap
|
||||
|
||||
if self.train_attention:
|
||||
# 1. Compute "ground-truth" attention output and weights
|
||||
with torch.no_grad():
|
||||
_y_true, a_true = softmax_attention(q, k, v)[:2]
|
||||
y_true = (
|
||||
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
)
|
||||
y_true = self.o_proj(y_true)
|
||||
|
||||
# 2. Compute "predicted" attention outputs
|
||||
# compute attn weights under sliding window
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
||||
y_pred, a_pred = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
|
||||
else:
|
||||
attn_weights = None
|
||||
# attention_mask = None # For now this is always True
|
||||
if past_key_value is None: # Regular training
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_true, a_pred = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
attn_weights = a_pred
|
||||
else:
|
||||
past_key_value.window_size = self.decode_window_size
|
||||
if (
|
||||
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
|
||||
): # Generating
|
||||
assert use_cache is True
|
||||
_kv = past_key_value.update_for_decoding(
|
||||
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
|
||||
)
|
||||
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
||||
|
||||
# Sliding window + linear attention decode
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
|
||||
# Softmax attention terms
|
||||
a_sm = torch.einsum(
|
||||
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
|
||||
) * (k.shape[-1] ** -0.5)
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Combine with linear attention terms
|
||||
y_true = torch.einsum(
|
||||
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||
) + linear_factors * torch.einsum(
|
||||
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
|
||||
)
|
||||
sum_ln = (
|
||||
linear_factors
|
||||
* torch.einsum(
|
||||
"bhld,bhnd->bhl", f_q.float(), f_k_state.float()
|
||||
)[..., None]
|
||||
)
|
||||
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
||||
|
||||
else: # Stateful training
|
||||
try:
|
||||
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||
k_state = past_key_value.k_states[self.layer_idx]
|
||||
except IndexError:
|
||||
kv_state, k_state = None, None
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_true, _ = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
kv_state=kv_state,
|
||||
k_state=k_state,
|
||||
)
|
||||
# Save and update KV cache and states
|
||||
# past_key_value.update(k, v.detach(), self.layer_idx,
|
||||
# fmap_key_states=f_k.detach(),
|
||||
# accumulate_in_fp32=True)
|
||||
past_key_value.update(
|
||||
k,
|
||||
v,
|
||||
self.layer_idx,
|
||||
fmap_key_states=f_k,
|
||||
accumulate_in_fp32=True,
|
||||
)
|
||||
# Concatenate heads and apply output projection
|
||||
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
y_true = self.o_proj(y_true)
|
||||
return y_true, attn_weights, past_key_value
|
||||
|
||||
|
||||
class LinearAttentionTKWindowCache(LinearAttentionState):
|
||||
"""
|
||||
Class for `past_key_values`
|
||||
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
||||
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||
"""
|
||||
|
||||
def __init__(self, window_size: int = 64) -> None:
|
||||
super().__init__()
|
||||
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||
self._seen_tokens_by_layer: List[int] = []
|
||||
self.kv_states: List[torch.Tensor] = []
|
||||
self.k_states: List[torch.Tensor] = []
|
||||
|
||||
# Account for sliding windows
|
||||
self.decode_kv_states: List[torch.Tensor] = []
|
||||
self.decode_k_states: List[torch.Tensor] = []
|
||||
self.k_cache: List[torch.Tensor] = []
|
||||
self.v_cache: List[torch.Tensor] = []
|
||||
self.window_size = window_size
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: Optional[int] = None,
|
||||
cache_kwargs: Optional[Any] = None,
|
||||
accumulate_in_fp32: bool = False,
|
||||
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
|
||||
grad_enabled: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Update KV, K states; and KV cache during training
|
||||
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
||||
up to sliding window terms
|
||||
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
||||
up to end of sequence
|
||||
- Likewise for `self.decode_k_states` and `self.k_states`
|
||||
"""
|
||||
if fmap_key_states is None:
|
||||
raise ValueError("fmap_key_states should not be None")
|
||||
|
||||
if layer_idx is None:
|
||||
raise ValueError("layer_idx should not be None")
|
||||
|
||||
with torch.set_grad_enabled(grad_enabled):
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += key_states.shape[-2]
|
||||
|
||||
dtype = key_states.dtype
|
||||
if accumulate_in_fp32:
|
||||
# key_states = key_states.float()
|
||||
fmap_key_states = fmap_key_states.float()
|
||||
value_states = value_states.float()
|
||||
|
||||
# Decoding KV state (KV terms up to last window_size)
|
||||
decode_kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd",
|
||||
fmap_key_states[:, :, : -self.window_size],
|
||||
value_states[:, :, : -self.window_size],
|
||||
)
|
||||
# KV state
|
||||
kv_state = decode_kv_state + torch.einsum(
|
||||
"bhlf,bhld->bhfd",
|
||||
fmap_key_states[:, :, -self.window_size :],
|
||||
value_states[:, :, -self.window_size :],
|
||||
)
|
||||
# shape is b, h, 1, f; note the 1
|
||||
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
|
||||
dim=-2, keepdim=True
|
||||
)
|
||||
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
|
||||
dim=-2, keepdim=True
|
||||
)
|
||||
|
||||
# Update the cache
|
||||
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
||||
self.kv_states.append(kv_state.to(dtype))
|
||||
self.k_states.append(k_state.to(dtype))
|
||||
|
||||
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
||||
self.decode_k_states.append(decode_k_state.to(dtype))
|
||||
|
||||
self.k_cache.append(key_states[:, :, -self.window_size :, :])
|
||||
self.v_cache.append(
|
||||
value_states[:, :, -self.window_size :, :].to(dtype)
|
||||
)
|
||||
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
|
||||
else:
|
||||
# Update kv and k states recurrently
|
||||
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
|
||||
dtype
|
||||
)
|
||||
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
|
||||
dtype
|
||||
)
|
||||
self.kv_states[layer_idx] = kv_state
|
||||
self.k_states[layer_idx] = k_state
|
||||
|
||||
decode_kv_state = (
|
||||
self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
||||
+ decode_kv_state
|
||||
).to(dtype)
|
||||
decode_k_state = (
|
||||
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
|
||||
).to(dtype)
|
||||
self.decode_kv_states[layer_idx] = decode_kv_state
|
||||
self.decode_k_states[layer_idx] = decode_k_state
|
||||
|
||||
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
|
||||
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
|
||||
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
||||
|
||||
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
||||
|
||||
def update_for_decoding(
|
||||
self,
|
||||
keys: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
layer_idx: int,
|
||||
feature_map_k: Callable,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""
|
||||
Update the decoding KV and K states, and the KV cache, during decoding
|
||||
"""
|
||||
with torch.no_grad():
|
||||
k_cache = self.k_cache[layer_idx]
|
||||
v_cache = self.v_cache[layer_idx]
|
||||
|
||||
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
||||
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
||||
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
||||
else:
|
||||
k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||
v_state = v_cache[:, :, :1, :]
|
||||
kv_state = torch.einsum(
|
||||
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
|
||||
).to(
|
||||
dtype
|
||||
) # b, h, f, d
|
||||
self.decode_kv_states[layer_idx] += kv_state
|
||||
self.decode_k_states[layer_idx] += k_state
|
||||
|
||||
self.k_cache[layer_idx] = torch.cat(
|
||||
[k_cache[:, :, 1:, :], keys], dim=-2
|
||||
)
|
||||
self.v_cache[layer_idx] = torch.cat(
|
||||
[v_cache[:, :, 1:, :], values], dim=-2
|
||||
)
|
||||
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += keys.shape[-2]
|
||||
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
||||
return (
|
||||
self.k_cache[layer_idx],
|
||||
self.v_cache[layer_idx],
|
||||
self.decode_kv_states[layer_idx],
|
||||
self.decode_k_states[layer_idx],
|
||||
)
|
||||
@@ -1,219 +0,0 @@
|
||||
"""
|
||||
LoLCATs + ThunderKittens linear attention + sliding window for generation
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .linear_attention import LinearAttentionState
|
||||
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from thunderkittens import hedgehog as tk_window_hedgehog_attention
|
||||
|
||||
LOG.debug("Successfully imported ThunderKittens for TK window attention")
|
||||
except ImportError:
|
||||
LOG.debug("Failed to import ThunderKittens for TK window attention")
|
||||
|
||||
|
||||
class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention):
|
||||
def __init__(self, *args, window_size: int = 64, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.train_attention = False
|
||||
self.base_inference = False
|
||||
self.window_size = 64 # hard-coded support for TK kernel
|
||||
self.decode_window_size = 64
|
||||
|
||||
b, h, l, d = 1, 32, 8192, 128
|
||||
self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device="cuda")
|
||||
self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device="cuda")
|
||||
self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device="cuda")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Any] = None, # "legacy" cache approach
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass with the option to compute attention weights multiple ways
|
||||
if self.train_attention is True
|
||||
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
assert (
|
||||
past_key_value is not None
|
||||
), "past_key_value must be provided for generation"
|
||||
assert (
|
||||
self.train_attention is False
|
||||
), "train_attention is not supported for generation"
|
||||
assert (
|
||||
self.base_inference is False
|
||||
), "base_inference is not supported for generation"
|
||||
assert use_cache is True, "use_cache must be True for generation"
|
||||
past_key_value.window_size = self.decode_window_size
|
||||
q, k, v, kv_seq_len = self.process_qkv(
|
||||
hidden_states, attention_mask, position_ids, past_key_value
|
||||
)
|
||||
if q.shape[2] == 1 and kv_seq_len > 1: # Generating after prefill
|
||||
f_q = self.feature_map_q(q)
|
||||
_kv = past_key_value.update_for_decoding(
|
||||
k, v, self.layer_idx, self.feature_map_k
|
||||
)
|
||||
k_cache, v_cache, kv_state, k_state = _kv
|
||||
# Sliding window + linear attention decode
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
||||
|
||||
# Softmax attention terms
|
||||
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
|
||||
k.shape[-1] ** -0.5
|
||||
)
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
# Combine with linear attention terms
|
||||
y_true = torch.einsum(
|
||||
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||
) + linear_factors * torch.einsum(
|
||||
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
|
||||
)
|
||||
sum_ln = (
|
||||
linear_factors
|
||||
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[
|
||||
..., None
|
||||
]
|
||||
)
|
||||
self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
||||
|
||||
else: # Process prefill
|
||||
# Use TK-implemented linear + terrace window attention
|
||||
b, h, l, d = q.shape
|
||||
device = q.device
|
||||
# tk.hedgehog arguments
|
||||
# y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device)
|
||||
# kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device=device)
|
||||
# k_state = torch.zeros(b, h, d, dtype=torch.float32, device=device)
|
||||
betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32))
|
||||
alphas = (
|
||||
1 - betas
|
||||
if self.affine_attention_factors
|
||||
else torch.ones(betas.shape, dtype=torch.float32, device=device)
|
||||
)
|
||||
q_map = self.feature_map_q.mlp.layer
|
||||
k_map = self.feature_map_k.mlp.layer
|
||||
# Saves outputs to y_pred, k_state, kv_state, where we fuse:
|
||||
# 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
|
||||
# 2. y_pred = attention(q, k, f_q, f_k, v) # b, h, l, d
|
||||
# 3. kv_state = torch.einsum('bhlf,bhld->bhfd',
|
||||
# f_k[:, :, :-self.window_size],
|
||||
# v[:, :, :-self.window_size]) # b, h, f, d
|
||||
# 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2) # b, h, d
|
||||
|
||||
tk_window_hedgehog_attention(
|
||||
q.contiguous(),
|
||||
k.contiguous(),
|
||||
v.contiguous(),
|
||||
self.y_true,
|
||||
self.k_state,
|
||||
self.kv_state,
|
||||
q_map,
|
||||
k_map,
|
||||
alphas,
|
||||
betas,
|
||||
)
|
||||
|
||||
past_key_value.update_with_kv(
|
||||
self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx
|
||||
)
|
||||
|
||||
# Concatenate heads and apply output projection
|
||||
y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||
y_true = self.o_proj(y_true)
|
||||
return y_true, None, past_key_value
|
||||
|
||||
|
||||
class LinearAttentionTKWindowGenerationCache(LinearAttentionState):
|
||||
"""
|
||||
Class for `past_key_values`
|
||||
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
||||
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||
"""
|
||||
|
||||
def __init__(self, window_size: int = 64) -> None:
|
||||
super().__init__()
|
||||
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||
self._seen_tokens_by_layer: List[int] = []
|
||||
self.window_size = window_size
|
||||
|
||||
self.decode_kv_states: List[torch.Tensor] = []
|
||||
self.decode_k_states: List[torch.Tensor] = []
|
||||
self.k_cache: List[torch.Tensor] = []
|
||||
self.v_cache: List[torch.Tensor] = []
|
||||
|
||||
def update_with_kv(
|
||||
self,
|
||||
kv_state: torch.Tensor,
|
||||
k_state: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer_idx: int,
|
||||
):
|
||||
"""
|
||||
Update the cache with new KV and K states
|
||||
"""
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += k.shape[2]
|
||||
self._seen_tokens_by_layer.append(k.shape[2])
|
||||
|
||||
# Initialize KV and K states
|
||||
if len(self.decode_k_states) <= layer_idx:
|
||||
self.decode_kv_states.append(kv_state)
|
||||
self.decode_k_states.append(k_state)
|
||||
else: # Update KV and K states
|
||||
self.decode_kv_states[layer_idx] = (
|
||||
self.decode_kv_states[layer_idx] + kv_state
|
||||
)
|
||||
self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state
|
||||
|
||||
self.k_cache.append(k[:, :, -self.window_size :, :])
|
||||
self.v_cache.append(v[:, :, -self.window_size :, :])
|
||||
|
||||
def update_for_decoding(
|
||||
self, k: torch.Tensor, v: torch.Tensor, layer_idx: int, feature_map_k: Callable
|
||||
):
|
||||
"""
|
||||
Update the cache for decoding
|
||||
"""
|
||||
k_cache = self.k_cache[layer_idx]
|
||||
v_cache = self.v_cache[layer_idx]
|
||||
k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||
v_state = v_cache[:, :, :1, :]
|
||||
kv_state = torch.einsum("bhlf,bhld->bhfd", k_state.float(), v_state.float()).to(
|
||||
k.dtype
|
||||
)
|
||||
|
||||
self.decode_kv_states[layer_idx] += kv_state
|
||||
self.decode_k_states[layer_idx] += k_state
|
||||
|
||||
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2)
|
||||
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2)
|
||||
if layer_idx == 0:
|
||||
self._seen_tokens += k.shape[-2]
|
||||
self._seen_tokens_by_layer[layer_idx] += k.shape[-2]
|
||||
return (
|
||||
self.k_cache[layer_idx],
|
||||
self.v_cache[layer_idx],
|
||||
self.decode_kv_states[layer_idx],
|
||||
self.decode_k_states[layer_idx],
|
||||
)
|
||||
@@ -1,306 +0,0 @@
|
||||
"""
|
||||
LoLCATs attention combining sliding window and linear attentions
|
||||
- Using the TK "terracing" arrangement
|
||||
- Training over long sequences with fixed memory with recurrent view
|
||||
- During attention transfer, use Flash Attention to compute softmax attention outputs
|
||||
|
||||
For each layer:
|
||||
- We first compute (softmax) attention over sliding windows
|
||||
- We then compute standard linear attention to "fill in" the earlier parts
|
||||
- We combine to model the entire sequence
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.cache_utils import Cache
|
||||
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
except ModuleNotFoundError:
|
||||
_flash_attention_forward = None # Transformers v4.36
|
||||
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
from .linear_attention import softmax_attention
|
||||
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
||||
|
||||
LOG = logging.getLogger(
|
||||
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
|
||||
)
|
||||
|
||||
|
||||
class LolcatsTKWindowLongAttention(LolcatsTKWindowAttention):
|
||||
"""
|
||||
Lolcats attention combining sliding window and linear attention
|
||||
"""
|
||||
|
||||
def __init__(self, remove_base_attn=True, **kwargs):
|
||||
# keep self.base_attn for Flash Attention inference
|
||||
super().__init__(remove_base_attn=True, **kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Forward pass with the option to compute attention weights multiple ways
|
||||
if self.train_attention is True
|
||||
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||
"""
|
||||
b, l, _ = hidden_states.size()
|
||||
if self.train_attention and self.base_inference:
|
||||
with torch.no_grad():
|
||||
# LOG.debug(hidden_states.shape)
|
||||
_y_true = flash_attention_2(
|
||||
self, # self.base_attn,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=None,
|
||||
position_ids=position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
# output_hidden_states=False,
|
||||
use_cache=False,
|
||||
)[0]
|
||||
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
|
||||
y_true = _y_true.reshape(b, l, -1).contiguous()
|
||||
y_true = self.o_proj(y_true)
|
||||
layer_io = (hidden_states, _y_true) # hack
|
||||
# layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
|
||||
return y_true, layer_io, None
|
||||
|
||||
q, k, v, kv_seq_len = self.process_qkv(
|
||||
hidden_states, attention_mask, position_ids, past_key_value
|
||||
)
|
||||
f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
|
||||
|
||||
# attention_mask = None # For now this is always True
|
||||
if past_key_value is None: # Regular training
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
||||
y_pred, a_pred = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
else:
|
||||
past_key_value.window_size = self.decode_window_size
|
||||
if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating
|
||||
assert use_cache is True
|
||||
_kv = past_key_value.update_for_decoding(
|
||||
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
|
||||
)
|
||||
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
||||
|
||||
# Sliding window + linear attention decode
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
|
||||
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
|
||||
k.shape[-1] ** -0.5
|
||||
)
|
||||
# a_sm = torch.softmax(a_sm, dim=-1)
|
||||
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||
|
||||
y_pred = torch.einsum(
|
||||
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||
) + linear_factors * torch.einsum(
|
||||
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
|
||||
)
|
||||
sum_ln = (
|
||||
linear_factors
|
||||
* torch.einsum("bhlf,bhnf->bhl", f_q.float(), f_k_state.float())[
|
||||
..., None
|
||||
]
|
||||
)
|
||||
y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype)
|
||||
|
||||
else: # Stateful training
|
||||
if (
|
||||
self.state_grad_enabled
|
||||
and self.layer_idx == 0
|
||||
and position_ids is not None
|
||||
):
|
||||
LOG.debug(
|
||||
f"\n position_ids: [{position_ids[0, 0]}, {position_ids[0, -1]}]"
|
||||
)
|
||||
LOG.debug(
|
||||
f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}"
|
||||
)
|
||||
try:
|
||||
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||
k_state = past_key_value.k_states[self.layer_idx]
|
||||
except IndexError:
|
||||
kv_state, k_state = None, None
|
||||
window_factors = F.sigmoid(self.window_factors)
|
||||
linear_factors = (
|
||||
1 - window_factors if self.affine_attention_factors else 1
|
||||
)
|
||||
y_pred, a_pred = self.quadratic_attention(
|
||||
q,
|
||||
k,
|
||||
f_q,
|
||||
f_k,
|
||||
v,
|
||||
window_factors,
|
||||
linear_factors,
|
||||
window_size=self.window_size,
|
||||
kv_state=kv_state,
|
||||
k_state=k_state,
|
||||
)
|
||||
# Save and update KV cache and states
|
||||
# past_key_value.update(k, v.detach(), self.layer_idx,
|
||||
# fmap_key_states=f_k.detach(),
|
||||
# accumulate_in_fp32=True)
|
||||
past_key_value.update(
|
||||
k, v, self.layer_idx, fmap_key_states=f_k, accumulate_in_fp32=True
|
||||
)
|
||||
|
||||
# Concatenate heads and apply output projection
|
||||
_y_pred = y_pred.transpose(1, 2).contiguous()
|
||||
y_pred = self.o_proj(_y_pred.view(b, l, self.hidden_size))
|
||||
|
||||
if self.train_attention:
|
||||
with torch.no_grad():
|
||||
a_true = softmax_attention(q, k, None, causal=True)[1]
|
||||
attn_weights = (_y_pred, (a_pred, a_true))
|
||||
else:
|
||||
attn_weights = _y_pred # flash_attn outputs are shape (b, l, h, d)
|
||||
return y_pred, attn_weights, past_key_value
|
||||
|
||||
|
||||
# -----------------
|
||||
# Flash Attention 2
|
||||
# -----------------
|
||||
|
||||
|
||||
def flash_attention_2(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
"""
|
||||
Wrapper for LlamaFlashAttention2
|
||||
Copied and modified from HF Transformers v4.36 and v4.43 implementations
|
||||
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
|
||||
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
|
||||
"""
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
try: # As in Transformers v4.36
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
except Exception: # As in Transformers v4.39
|
||||
cos, sin = self.rotary_emb(key_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
LOG.debug(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
if getattr(self, "_flash_attention_forward", False):
|
||||
attn_output = self._flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
is_causal=True,
|
||||
)
|
||||
else:
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=0, # dropout_rate,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=True,
|
||||
)
|
||||
return attn_output, past_key_value
|
||||
@@ -1,361 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
"""Linear LLaMA model implementation."""
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
LlamaForCausalLM,
|
||||
LlamaModel,
|
||||
LlamaRMSNorm,
|
||||
LlamaRotaryEmbedding,
|
||||
)
|
||||
|
||||
from .configuration_linear_llama import LinearLlamaConfig
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LinearLlamaDecoderLayer(LlamaDecoderLayer):
|
||||
"""
|
||||
Modified LlamaDecoderLayer that uses LinearAttention instead of standard attention.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LinearLlamaConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
|
||||
# Replace the attention layer with our custom attention
|
||||
self.self_attn = convert_llama_attention(
|
||||
layer=self, attention_config=config.attention_config
|
||||
)
|
||||
|
||||
|
||||
class LinearLlamaModel(LlamaModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LinearLlamaDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: LinearLlamaConfig
|
||||
"""
|
||||
|
||||
config_class = LinearLlamaConfig
|
||||
base_model_prefix = "linear_llama"
|
||||
|
||||
def __init__(self, config: LinearLlamaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(
|
||||
config.vocab_size, config.hidden_size, self.padding_idx
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
LinearLlamaDecoderLayer(config, layer_idx)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
|
||||
class LinearLlamaForCausalLM(LlamaForCausalLM):
|
||||
"""
|
||||
Linear LLaMA model for causal language modeling.
|
||||
"""
|
||||
|
||||
config_class = LinearLlamaConfig
|
||||
base_model_prefix = "linear_llama"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = LinearLlamaModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@classmethod
|
||||
def from_llama(
|
||||
cls,
|
||||
model: LlamaForCausalLM,
|
||||
config: LinearLlamaConfig,
|
||||
train_attention: bool = False,
|
||||
remove_base_attn: bool = True,
|
||||
) -> "LinearLlamaForCausalLM":
|
||||
"""
|
||||
Initialize a LinearLlamaForCausalLM from a LlamaModel
|
||||
"""
|
||||
|
||||
if config is None:
|
||||
raise ValueError("Missing config")
|
||||
|
||||
# initialize a new model with config
|
||||
new_model = cls(config=config)
|
||||
|
||||
# remove the default model and lm_head
|
||||
del new_model.model
|
||||
del new_model.lm_head
|
||||
|
||||
# load converted model, lm_head, and vocab_size from llama model
|
||||
new_model.model = convert_attention(
|
||||
model.model,
|
||||
attention_config=config.attention_config,
|
||||
train_attention=train_attention,
|
||||
remove_base_attn=remove_base_attn,
|
||||
)
|
||||
new_model.lm_head = model.lm_head
|
||||
new_model.vocab_size = model.vocab_size
|
||||
|
||||
return new_model
|
||||
|
||||
def toggle_attention(self, train: bool = True):
|
||||
"""
|
||||
Toggle attention to be trainable or not
|
||||
"""
|
||||
|
||||
toggle_attention(self.model, train=train)
|
||||
|
||||
def remove_base_attention(self):
|
||||
"""
|
||||
Remove base attention after distillation
|
||||
"""
|
||||
|
||||
remove_base_attention(self.model)
|
||||
|
||||
|
||||
def convert_attention(
|
||||
model: nn.Module,
|
||||
attention_config: dict,
|
||||
train_attention: bool = False,
|
||||
remove_base_attn: bool = True,
|
||||
):
|
||||
"""
|
||||
Call to convert all attention layers
|
||||
"""
|
||||
# Get the layers to convert if provided
|
||||
softmax_attns = attention_config.get("softmax_attentions", [])
|
||||
|
||||
# Get the attention to convert to
|
||||
attention_type = attention_config.get("attention_type")
|
||||
|
||||
if attention_type != "softmax":
|
||||
layers = traverse_layers(model)
|
||||
for layer_idx, layer in enumerate(
|
||||
tqdm(layers, desc="Converting attentions...")
|
||||
):
|
||||
if layer_idx not in softmax_attns:
|
||||
layer.self_attn = convert_llama_attention(
|
||||
layer,
|
||||
attention_config,
|
||||
layers,
|
||||
train_attention,
|
||||
remove_base_attn,
|
||||
)
|
||||
layer.self_attn.converted = True
|
||||
else:
|
||||
# Freeze any preserved softmax attention layers
|
||||
for p in layer.parameters():
|
||||
p.requires_grad = False
|
||||
else:
|
||||
LOG.info(
|
||||
f"-> attention_config.attention_type is {attention_type}; not converting attentions"
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def toggle_attention(llama_model: nn.Module, train: bool = False):
|
||||
"""
|
||||
Make attentions trainable if train is True
|
||||
-> Set train_attention = False when finetuning
|
||||
"""
|
||||
for layer in traverse_layers(llama_model):
|
||||
layer.self_attn.train_attention = train
|
||||
return llama_model
|
||||
|
||||
|
||||
def remove_base_attention(llama_model: nn.Module):
|
||||
"""
|
||||
Remove teacher attention after distillation (if we keep it)
|
||||
"""
|
||||
for layer in traverse_layers(llama_model):
|
||||
if getattr(layer.self_attn, "base_attn", False):
|
||||
del layer.self_attn.base_attn
|
||||
return llama_model
|
||||
|
||||
|
||||
def traverse_layers(model: nn.Module, verbose: bool = False):
|
||||
"""
|
||||
Return list of model layers
|
||||
"""
|
||||
try:
|
||||
layers = model.model.layers
|
||||
if verbose:
|
||||
LOG.info("-> Loading from model.model.layers")
|
||||
except AttributeError as e: # if base model
|
||||
if verbose:
|
||||
LOG.info(e)
|
||||
try:
|
||||
layers = model.layers
|
||||
if verbose:
|
||||
LOG.info("-> Loading from model.layers")
|
||||
except AttributeError as e1: # If we make a PEFT model
|
||||
if verbose:
|
||||
LOG.info(e1)
|
||||
layers = model.base_model.model.model.layers
|
||||
if verbose:
|
||||
LOG.info("-> Loading from model.base_model.model.model.layers")
|
||||
return layers
|
||||
|
||||
|
||||
def convert_llama_attention(
|
||||
layer: nn.Module,
|
||||
attention_config: dict,
|
||||
layers: Optional[list[nn.Module]] = None, # list of layers
|
||||
train_attention: bool = False,
|
||||
remove_base_attn: bool = True,
|
||||
):
|
||||
"""
|
||||
Converts a single layer's attention layer as specified by attention_config
|
||||
"""
|
||||
return get_attention(**attention_config)(
|
||||
base_attn=layer.self_attn,
|
||||
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
|
||||
max_layer_idx=len(layers) - 1 if layers else None,
|
||||
train_attention=train_attention,
|
||||
remove_base_attn=remove_base_attn,
|
||||
)
|
||||
|
||||
|
||||
def get_attention(attention_type: str, **kwargs):
|
||||
"""
|
||||
Get the linear attention class; either purely linear or linear with sliding window
|
||||
-> 'linear' == 'lolcats_llama'
|
||||
-> 'linear and sliding_window' == 'lolcats_llama_window_*'
|
||||
"""
|
||||
kwargs["attention_type"] = attention_type
|
||||
|
||||
if attention_type == "lolcats_llama":
|
||||
from .linear_attention import LolcatsLinearAttention
|
||||
|
||||
return partial(LolcatsLinearAttention, **kwargs)
|
||||
|
||||
elif attention_type == "lolcats_llama_window_tk":
|
||||
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
||||
|
||||
return partial(LolcatsTKWindowAttention, **kwargs)
|
||||
|
||||
elif attention_type == "lolcats_llama_window_sw":
|
||||
from .linear_window_attention_sw import LolcatsSlidingWindowAttention
|
||||
|
||||
return partial(LolcatsSlidingWindowAttention, **kwargs)
|
||||
|
||||
elif attention_type == "lolcats_llama_window_sw_linear":
|
||||
from .linear_window_attention_sw_linear import (
|
||||
LolcatsLinearSlidingWindowAttention,
|
||||
)
|
||||
|
||||
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
|
||||
|
||||
# Experimental chunked linear attentions below
|
||||
elif attention_type == "lolcats_long_llama_window_tk":
|
||||
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||
|
||||
return partial(LolcatsTKWindowLongAttention, **kwargs)
|
||||
|
||||
elif attention_type == "lolcats_long_llama_window_sw":
|
||||
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
|
||||
|
||||
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
|
||||
|
||||
# TK generation build (requires Thunderkittens)
|
||||
elif attention_type == "lolcats_llama_window_tk_gen":
|
||||
from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen
|
||||
|
||||
return partial(LolcatsWindowAttentionTKGen, **kwargs)
|
||||
|
||||
else:
|
||||
LOG.info(f"-> attention_type {attention_type} not handled... returning None")
|
||||
return None
|
||||
|
||||
|
||||
def get_attention_cache(attention_type: str, past_key_values: Any = None):
|
||||
"""
|
||||
Determine how we store past keys and values when generating
|
||||
"""
|
||||
if attention_type is None:
|
||||
return past_key_values
|
||||
|
||||
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
|
||||
elif "lolcats_llama_window_tk_gen" in attention_type:
|
||||
from .linear_window_attention_tk_gen import (
|
||||
LinearAttentionTKWindowGenerationCache,
|
||||
)
|
||||
|
||||
return LinearAttentionTKWindowGenerationCache()
|
||||
|
||||
elif "llama_window_tk" in attention_type:
|
||||
from .linear_window_attention_tk import LinearAttentionTKWindowCache
|
||||
|
||||
return LinearAttentionTKWindowCache()
|
||||
|
||||
elif "llama_window_sw" in attention_type:
|
||||
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
|
||||
|
||||
return LinearAttentionSlidingWindowCache()
|
||||
|
||||
elif "llama_window_sw_linear" in attention_type:
|
||||
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
|
||||
|
||||
return LinearAttentionSlidingWindowCache()
|
||||
|
||||
# TK generation build (requires Thunderkittens)
|
||||
elif attention_type == "lolcats_llama_window_tk_gen":
|
||||
from .linear_window_attention_tk_gen import (
|
||||
LinearAttentionTKWindowGenerationCache,
|
||||
)
|
||||
|
||||
return LinearAttentionTKWindowGenerationCache()
|
||||
|
||||
elif "softmax" in attention_type:
|
||||
return past_key_values
|
||||
|
||||
else:
|
||||
from .linear_attention import LinearAttentionState
|
||||
|
||||
return LinearAttentionState()
|
||||
|
||||
|
||||
def register_linear_llama():
|
||||
"""
|
||||
Register Linear LLaMA model with the Transformers library.
|
||||
"""
|
||||
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||
|
||||
AutoConfig.register("linear_llama", LinearLlamaConfig)
|
||||
AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
|
||||
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
|
||||
|
||||
# registering for auto classes to save files
|
||||
LinearLlamaConfig.register_for_auto_class("AutoConfig")
|
||||
LinearLlamaModel.register_for_auto_class("AutoModel")
|
||||
LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
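# Hedged usage sketch (reloading through the auto classes is an assumption, not
# shown in this file):
#   register_linear_llama()
#   model = AutoModelForCausalLM.from_pretrained("/path/to/linear-llama-checkpoint")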
|
||||
@@ -1,118 +0,0 @@
|
||||
"""
|
||||
Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer.
|
||||
|
||||
In this implementation we support using either just the softmax attention outputs, or the softmax attention weights.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from torch import Tensor, nn, tensor
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
|
||||
class DistillAttentionXentMSETrainer(AxolotlTrainer):
|
||||
"""
|
||||
Custom trainer class for distilling attentions.
|
||||
- We compute and store the attention outputs and/or weights for each head and layer,
|
||||
for both the "teacher" softmax attentions and "student" learnable subquadratic attentions
|
||||
- We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
mse_factor: float = 1e3,
|
||||
xent_factor: float = 0,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(model=model, **kwargs)
|
||||
self.criterion_xent = nn.CrossEntropyLoss(reduction="mean")
|
||||
self.criterion_mse = nn.MSELoss(reduction="mean")
|
||||
self.mse_factor = mse_factor
|
||||
self.xent_factor = xent_factor
|
||||
# self.compute_loss_backprop = False # Whether we backprop in self.compute_loss # NOTE: this config seems unnecessary
|
||||
|
||||
self.model_accepts_loss_kwargs = False # added to combat explosive loss
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: dict[str, Tensor],
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None,
|
||||
) -> tuple[Tensor, dict]:
|
||||
"""
|
||||
Attention distillation ("attention transfer")
|
||||
- For each layer and head, get attentions and train to
|
||||
minimize some combo of MSE and cross-entropy loss
|
||||
"""
|
||||
# alias inputs to data
|
||||
data = inputs
|
||||
|
||||
device = model.device
|
||||
|
||||
# Filter out labels
|
||||
inputs = {k: v.to(device) for k, v in data.items() if k != "labels"}
|
||||
|
||||
# set num_items_in_batch
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
if num_items_in_batch is not None:
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
|
||||
# Forward pass
|
||||
outputs = model(**inputs, output_attentions=True, use_cache=False)
|
||||
outputs = outputs.get("attentions")
|
||||
|
||||
# Attentions are tuple[tuple[torch.Tensor, torch.Tensor]]
|
||||
# n_layers x (predicted_attns, true_attns)
|
||||
# predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len)
|
||||
loss_mse = tensor(0.0, device=device)
|
||||
loss_xent = tensor(0.0, device=device)
|
||||
n_layers = 0 # Number of layers to distill
|
||||
softmax_layers = []
|
||||
for layer_idx, attns in enumerate(outputs):
|
||||
if attns is not None:
|
||||
if len(attns) != 2:
|
||||
attns = attns.cpu()
|
||||
else:
|
||||
if self.xent_factor > 0:
|
||||
# Cross-entropy loss
|
||||
a_pred, a_true = attns[0]
|
||||
a_pred = a_pred.clamp(
|
||||
min=1e-12
|
||||
).log() # nn.CrossEntropy assumes unnormalized logits
|
||||
k_len = a_true.shape[-1] # batch, n_heads, q_len, k_len
|
||||
# Compute mean cross-entropy over all queries
|
||||
a_pred = a_pred.contiguous().view(-1, k_len)
|
||||
a_true = a_true.contiguous().view(-1, k_len)
|
||||
loss_xent += self.criterion_xent(a_pred, a_true)
|
||||
if self.mse_factor > 0:
|
||||
loss_mse += self.criterion_mse(*attns[1])
|
||||
n_layers += 1
|
||||
else:
|
||||
softmax_layers.append(layer_idx)
|
||||
if n_layers > 0:
|
||||
loss_xent = loss_xent / n_layers * self.xent_factor
|
||||
loss_mse = loss_mse / n_layers * self.mse_factor
|
||||
loss = loss_xent + loss_mse
|
||||
|
||||
if "position_ids" in data:
|
||||
outputs = {
|
||||
"loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
|
||||
"loss_mse": loss_mse if self.mse_factor > 0 else 0,
|
||||
"input_len": data["position_ids"].shape[1],
|
||||
"position_ids": data["position_ids"][0].detach().cpu().numpy(),
|
||||
"mse_factor": self.mse_factor,
|
||||
"xent_factor": self.xent_factor,
|
||||
}
|
||||
else:
|
||||
outputs = {
|
||||
"loss_xent": loss_xent.item() if self.xent_factor > 0 else 0,
|
||||
"loss_mse": loss_mse if self.mse_factor > 0 else 0,
|
||||
"mse_factor": self.mse_factor,
|
||||
"xent_factor": self.xent_factor,
|
||||
}
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
7 file diffs suppressed because they are too large
@@ -0,0 +1,590 @@
|
||||
{
|
||||
"model.layers.0.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.1.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.2.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.3.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.4.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.5.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.6.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.7.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.8.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.9.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.10.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.11.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.12.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.13.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.14.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.15.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"lm_head": {
|
||||
"snr": Infinity,
|
||||
"type": "lm_head"
|
||||
},
|
||||
"model.layers.0.mlp.down_proj": {
|
||||
"snr": 70.0594253540039,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.1.mlp.down_proj": {
|
||||
"snr": 11.135851860046387,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.2.mlp.down_proj": {
|
||||
"snr": 7.035482883453369,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.3.mlp.down_proj": {
|
||||
"snr": 6.422532081604004,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.4.mlp.down_proj": {
|
||||
"snr": 5.748020172119141,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.5.mlp.down_proj": {
|
||||
"snr": 3.885556697845459,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.6.mlp.down_proj": {
|
||||
"snr": 3.4336745738983154,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.7.mlp.down_proj": {
|
||||
"snr": 2.791595935821533,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.8.mlp.down_proj": {
|
||||
"snr": 5.36277961730957,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.9.mlp.down_proj": {
|
||||
"snr": 4.459208011627197,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.10.mlp.down_proj": {
|
||||
"snr": 6.272170066833496,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.11.mlp.down_proj": {
|
||||
"snr": 5.264761447906494,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.12.mlp.down_proj": {
|
||||
"snr": 4.324735641479492,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.13.mlp.down_proj": {
|
||||
"snr": 3.878648042678833,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.14.mlp.down_proj": {
|
||||
"snr": 2.9773054122924805,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.15.mlp.down_proj": {
|
||||
"snr": 4.471445560455322,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.0.mlp.gate_proj": {
|
||||
"snr": 25.227100372314453,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.1.mlp.gate_proj": {
|
||||
"snr": 6.58299446105957,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.2.mlp.gate_proj": {
|
||||
"snr": 3.4688243865966797,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.3.mlp.gate_proj": {
|
||||
"snr": 1.555246114730835,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.4.mlp.gate_proj": {
|
||||
"snr": 0.7770601511001587,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.5.mlp.gate_proj": {
|
||||
"snr": 0.6239906549453735,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.6.mlp.gate_proj": {
|
||||
"snr": 0.6440379023551941,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.7.mlp.gate_proj": {
|
||||
"snr": 0.5120116472244263,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.8.mlp.gate_proj": {
|
||||
"snr": 0.6544050574302673,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.9.mlp.gate_proj": {
|
||||
"snr": 0.5381016731262207,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.10.mlp.gate_proj": {
|
||||
"snr": 0.622873842716217,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.11.mlp.gate_proj": {
|
||||
"snr": 0.9361700415611267,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.12.mlp.gate_proj": {
|
||||
"snr": 1.475605845451355,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.13.mlp.gate_proj": {
|
||||
"snr": 1.608325719833374,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.14.mlp.gate_proj": {
|
||||
"snr": 1.0720024108886719,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.15.mlp.gate_proj": {
|
||||
"snr": 0.7111338973045349,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.0.mlp.up_proj": {
|
||||
"snr": 28.431896209716797,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.1.mlp.up_proj": {
|
||||
"snr": 15.546019554138184,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.2.mlp.up_proj": {
|
||||
"snr": 23.048023223876953,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.3.mlp.up_proj": {
|
||||
"snr": 25.790977478027344,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.4.mlp.up_proj": {
|
||||
"snr": 18.552549362182617,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.5.mlp.up_proj": {
|
||||
"snr": 8.85106372833252,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.6.mlp.up_proj": {
|
||||
"snr": 10.653799057006836,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.7.mlp.up_proj": {
|
||||
"snr": 7.365357875823975,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.8.mlp.up_proj": {
|
||||
"snr": 11.98373794555664,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.9.mlp.up_proj": {
|
||||
"snr": 8.04493236541748,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.10.mlp.up_proj": {
|
||||
"snr": 8.523039817810059,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.11.mlp.up_proj": {
|
||||
"snr": 5.381742477416992,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.12.mlp.up_proj": {
|
||||
"snr": 3.9845118522644043,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.13.mlp.up_proj": {
|
||||
"snr": 3.4893221855163574,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.14.mlp.up_proj": {
|
||||
"snr": 1.764201045036316,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.15.mlp.up_proj": {
|
||||
"snr": 0.9730708599090576,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.embed_tokens": {
|
||||
"snr": Infinity,
|
||||
"type": "model.embed_tokens"
|
||||
},
|
||||
"model.norm": {
|
||||
"snr": Infinity,
|
||||
"type": "model.norm"
|
||||
},
|
||||
"model.layers.0.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.1.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.2.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.3.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.4.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.5.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.6.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.7.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.8.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.9.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.10.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.11.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.12.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.13.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.14.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.15.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.0.self_attn.k_proj": {
|
||||
"snr": 0.11727584153413773,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.1.self_attn.k_proj": {
|
||||
"snr": 0.24786807596683502,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.2.self_attn.k_proj": {
|
||||
"snr": 0.36378130316734314,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.3.self_attn.k_proj": {
|
||||
"snr": 0.2983120381832123,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.4.self_attn.k_proj": {
|
||||
"snr": 0.33789733052253723,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.5.self_attn.k_proj": {
|
||||
"snr": 0.29155924916267395,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.6.self_attn.k_proj": {
|
||||
"snr": 0.2537297010421753,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.7.self_attn.k_proj": {
|
||||
"snr": 0.28204113245010376,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.8.self_attn.k_proj": {
|
||||
"snr": 0.2776711583137512,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.9.self_attn.k_proj": {
|
||||
"snr": 0.2927376627922058,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.10.self_attn.k_proj": {
|
||||
"snr": 0.31486213207244873,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.11.self_attn.k_proj": {
|
||||
"snr": 0.32363659143447876,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.12.self_attn.k_proj": {
|
||||
"snr": 0.31382912397384644,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.13.self_attn.k_proj": {
|
||||
"snr": 0.4635234773159027,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.14.self_attn.k_proj": {
|
||||
"snr": 0.25379249453544617,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.15.self_attn.k_proj": {
|
||||
"snr": 0.2628238797187805,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.0.self_attn.o_proj": {
|
||||
"snr": 0.27602291107177734,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.1.self_attn.o_proj": {
|
||||
"snr": 0.2149604707956314,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.2.self_attn.o_proj": {
|
||||
"snr": 0.2540294826030731,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.3.self_attn.o_proj": {
|
||||
"snr": 0.27978822588920593,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.4.self_attn.o_proj": {
|
||||
"snr": 0.3121289908885956,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.5.self_attn.o_proj": {
|
||||
"snr": 0.35037684440612793,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.6.self_attn.o_proj": {
|
||||
"snr": 0.366205096244812,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.7.self_attn.o_proj": {
|
||||
"snr": 0.3692712187767029,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.8.self_attn.o_proj": {
|
||||
"snr": 0.3301038146018982,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.9.self_attn.o_proj": {
|
||||
"snr": 0.3003396987915039,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.10.self_attn.o_proj": {
|
||||
"snr": 0.30804169178009033,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.11.self_attn.o_proj": {
|
||||
"snr": 0.28501132130622864,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.12.self_attn.o_proj": {
|
||||
"snr": 0.2171541005373001,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.13.self_attn.o_proj": {
|
||||
"snr": 0.19183959066867828,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.14.self_attn.o_proj": {
|
||||
"snr": 0.19215913116931915,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.15.self_attn.o_proj": {
|
||||
"snr": 0.25486502051353455,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.0.self_attn.q_proj": {
|
||||
"snr": 0.03850084915757179,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.1.self_attn.q_proj": {
|
||||
"snr": 0.0713055431842804,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.2.self_attn.q_proj": {
|
||||
"snr": 0.07948919385671616,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.3.self_attn.q_proj": {
|
||||
"snr": 0.08047746121883392,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.4.self_attn.q_proj": {
|
||||
"snr": 0.0852593332529068,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.5.self_attn.q_proj": {
|
||||
"snr": 0.09794823825359344,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.6.self_attn.q_proj": {
|
||||
"snr": 0.09627152234315872,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.7.self_attn.q_proj": {
|
||||
"snr": 0.11065381020307541,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.8.self_attn.q_proj": {
|
||||
"snr": 0.12031875550746918,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.9.self_attn.q_proj": {
|
||||
"snr": 0.09804573655128479,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.10.self_attn.q_proj": {
|
||||
"snr": 0.10897502303123474,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.11.self_attn.q_proj": {
|
||||
"snr": 0.09267337620258331,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.12.self_attn.q_proj": {
|
||||
"snr": 0.08803492039442062,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.13.self_attn.q_proj": {
|
||||
"snr": 0.0902542844414711,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.14.self_attn.q_proj": {
|
||||
"snr": 0.10154066979885101,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.15.self_attn.q_proj": {
|
||||
"snr": 0.09083802253007889,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.0.self_attn.v_proj": {
|
||||
"snr": 2.842210054397583,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.1.self_attn.v_proj": {
|
||||
"snr": 10.59461498260498,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.2.self_attn.v_proj": {
|
||||
"snr": 8.993025779724121,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.3.self_attn.v_proj": {
|
||||
"snr": 62.567787170410156,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.4.self_attn.v_proj": {
|
||||
"snr": 23.80082893371582,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.5.self_attn.v_proj": {
|
||||
"snr": 7.957369804382324,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.6.self_attn.v_proj": {
|
||||
"snr": 12.01815414428711,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.7.self_attn.v_proj": {
|
||||
"snr": 5.095500469207764,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.8.self_attn.v_proj": {
|
||||
"snr": 11.719332695007324,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.9.self_attn.v_proj": {
|
||||
"snr": 555.0869750976562,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.10.self_attn.v_proj": {
|
||||
"snr": 22.95538330078125,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.11.self_attn.v_proj": {
|
||||
"snr": 30.042158126831055,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.12.self_attn.v_proj": {
|
||||
"snr": 9.577271461486816,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.13.self_attn.v_proj": {
|
||||
"snr": 18.176361083984375,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.14.self_attn.v_proj": {
|
||||
"snr": 1.5695856809616089,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.15.self_attn.v_proj": {
|
||||
"snr": 2.7235565185546875,
|
||||
"type": "self_attn.v_proj"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,590 @@
|
||||
{
|
||||
"model.layers.0.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.1.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.2.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.3.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.4.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.5.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.6.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.7.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.8.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.9.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.10.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.11.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.12.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.13.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.14.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"model.layers.15.input_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "input_layernorm"
|
||||
},
|
||||
"lm_head": {
|
||||
"snr": Infinity,
|
||||
"type": "lm_head"
|
||||
},
|
||||
"model.layers.0.mlp.down_proj": {
|
||||
"snr": 57.09797286987305,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.1.mlp.down_proj": {
|
||||
"snr": 9.538983345031738,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.2.mlp.down_proj": {
|
||||
"snr": 6.227016925811768,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.3.mlp.down_proj": {
|
||||
"snr": 5.660686492919922,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.4.mlp.down_proj": {
|
||||
"snr": 5.178432464599609,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.5.mlp.down_proj": {
|
||||
"snr": 3.5638349056243896,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.6.mlp.down_proj": {
|
||||
"snr": 3.0918056964874268,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.7.mlp.down_proj": {
|
||||
"snr": 2.456392288208008,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.8.mlp.down_proj": {
|
||||
"snr": 4.525328636169434,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.9.mlp.down_proj": {
|
||||
"snr": 3.9409055709838867,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.10.mlp.down_proj": {
|
||||
"snr": 5.447249412536621,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.11.mlp.down_proj": {
|
||||
"snr": 4.807600975036621,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.12.mlp.down_proj": {
|
||||
"snr": 3.915374517440796,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.13.mlp.down_proj": {
|
||||
"snr": 3.4820363521575928,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.14.mlp.down_proj": {
|
||||
"snr": 2.6045074462890625,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.15.mlp.down_proj": {
|
||||
"snr": 3.7237701416015625,
|
||||
"type": "mlp.down_proj"
|
||||
},
|
||||
"model.layers.0.mlp.gate_proj": {
|
||||
"snr": 22.160131454467773,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.1.mlp.gate_proj": {
|
||||
"snr": 6.072206020355225,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.2.mlp.gate_proj": {
|
||||
"snr": 3.2467362880706787,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.3.mlp.gate_proj": {
|
||||
"snr": 1.4111896753311157,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.4.mlp.gate_proj": {
|
||||
"snr": 0.7405938506126404,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.5.mlp.gate_proj": {
|
||||
"snr": 0.5916463136672974,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.6.mlp.gate_proj": {
|
||||
"snr": 0.6149423718452454,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.7.mlp.gate_proj": {
|
||||
"snr": 0.48369669914245605,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.8.mlp.gate_proj": {
|
||||
"snr": 0.6047574877738953,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.9.mlp.gate_proj": {
|
||||
"snr": 0.5092479586601257,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.10.mlp.gate_proj": {
|
||||
"snr": 0.5999670624732971,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.11.mlp.gate_proj": {
|
||||
"snr": 0.8980127573013306,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.12.mlp.gate_proj": {
|
||||
"snr": 1.4252448081970215,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.13.mlp.gate_proj": {
|
||||
"snr": 1.509937047958374,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.14.mlp.gate_proj": {
|
||||
"snr": 1.0066585540771484,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.15.mlp.gate_proj": {
|
||||
"snr": 0.6413647532463074,
|
||||
"type": "mlp.gate_proj"
|
||||
},
|
||||
"model.layers.0.mlp.up_proj": {
|
||||
"snr": 26.08852195739746,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.1.mlp.up_proj": {
|
||||
"snr": 13.382951736450195,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.2.mlp.up_proj": {
|
||||
"snr": 20.088768005371094,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.3.mlp.up_proj": {
|
||||
"snr": 23.0632381439209,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.4.mlp.up_proj": {
|
||||
"snr": 16.07433319091797,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.5.mlp.up_proj": {
|
||||
"snr": 8.00507640838623,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.6.mlp.up_proj": {
|
||||
"snr": 9.538354873657227,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.7.mlp.up_proj": {
|
||||
"snr": 6.286602973937988,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.8.mlp.up_proj": {
|
||||
"snr": 10.092820167541504,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.9.mlp.up_proj": {
|
||||
"snr": 7.193963527679443,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.10.mlp.up_proj": {
|
||||
"snr": 7.320116996765137,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.11.mlp.up_proj": {
|
||||
"snr": 4.8728532791137695,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.12.mlp.up_proj": {
|
||||
"snr": 3.596583366394043,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.13.mlp.up_proj": {
|
||||
"snr": 3.166161298751831,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.14.mlp.up_proj": {
|
||||
"snr": 1.5600818395614624,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.layers.15.mlp.up_proj": {
|
||||
"snr": 0.8726214170455933,
|
||||
"type": "mlp.up_proj"
|
||||
},
|
||||
"model.embed_tokens": {
|
||||
"snr": Infinity,
|
||||
"type": "model.embed_tokens"
|
||||
},
|
||||
"model.norm": {
|
||||
"snr": Infinity,
|
||||
"type": "model.norm"
|
||||
},
|
||||
"model.layers.0.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.1.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.2.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.3.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.4.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.5.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.6.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.7.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.8.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.9.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.10.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.11.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.12.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.13.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.14.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.15.post_attention_layernorm": {
|
||||
"snr": Infinity,
|
||||
"type": "post_attention_layernorm"
|
||||
},
|
||||
"model.layers.0.self_attn.k_proj": {
|
||||
"snr": 0.1154392883181572,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.1.self_attn.k_proj": {
|
||||
"snr": 0.24299409985542297,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.2.self_attn.k_proj": {
|
||||
"snr": 0.3624322712421417,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.3.self_attn.k_proj": {
|
||||
"snr": 0.29509487748146057,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.4.self_attn.k_proj": {
|
||||
"snr": 0.32953736186027527,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.5.self_attn.k_proj": {
|
||||
"snr": 0.2908833622932434,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.6.self_attn.k_proj": {
|
||||
"snr": 0.2488437294960022,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.7.self_attn.k_proj": {
|
||||
"snr": 0.27847856283187866,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.8.self_attn.k_proj": {
|
||||
"snr": 0.27143892645835876,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.9.self_attn.k_proj": {
|
||||
"snr": 0.28804272413253784,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.10.self_attn.k_proj": {
|
||||
"snr": 0.31197959184646606,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.11.self_attn.k_proj": {
|
||||
"snr": 0.3203586935997009,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.12.self_attn.k_proj": {
|
||||
"snr": 0.30905747413635254,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.13.self_attn.k_proj": {
|
||||
"snr": 0.46828722953796387,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.14.self_attn.k_proj": {
|
||||
"snr": 0.24205778539180756,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.15.self_attn.k_proj": {
|
||||
"snr": 0.2559327781200409,
|
||||
"type": "self_attn.k_proj"
|
||||
},
|
||||
"model.layers.0.self_attn.o_proj": {
|
||||
"snr": 0.2638678550720215,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.1.self_attn.o_proj": {
|
||||
"snr": 0.21109595894813538,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.2.self_attn.o_proj": {
|
||||
"snr": 0.24751724302768707,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.3.self_attn.o_proj": {
|
||||
"snr": 0.2728094160556793,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.4.self_attn.o_proj": {
|
||||
"snr": 0.3001374304294586,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.5.self_attn.o_proj": {
|
||||
"snr": 0.33903488516807556,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.6.self_attn.o_proj": {
|
||||
"snr": 0.3530929982662201,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.7.self_attn.o_proj": {
|
||||
"snr": 0.36753255128860474,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.8.self_attn.o_proj": {
|
||||
"snr": 0.3373180329799652,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.9.self_attn.o_proj": {
|
||||
"snr": 0.2970578670501709,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.10.self_attn.o_proj": {
|
||||
"snr": 0.3076324760913849,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.11.self_attn.o_proj": {
|
||||
"snr": 0.2766900658607483,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.12.self_attn.o_proj": {
|
||||
"snr": 0.20973259210586548,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.13.self_attn.o_proj": {
|
||||
"snr": 0.18185566365718842,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.14.self_attn.o_proj": {
|
||||
"snr": 0.18329747021198273,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.15.self_attn.o_proj": {
|
||||
"snr": 0.2437991499900818,
|
||||
"type": "self_attn.o_proj"
|
||||
},
|
||||
"model.layers.0.self_attn.q_proj": {
|
||||
"snr": 0.038040731102228165,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.1.self_attn.q_proj": {
|
||||
"snr": 0.0707998052239418,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.2.self_attn.q_proj": {
|
||||
"snr": 0.0787411704659462,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.3.self_attn.q_proj": {
|
||||
"snr": 0.08089710026979446,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.4.self_attn.q_proj": {
|
||||
"snr": 0.08591937273740768,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.5.self_attn.q_proj": {
|
||||
"snr": 0.09852176159620285,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.6.self_attn.q_proj": {
|
||||
"snr": 0.09690654277801514,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.7.self_attn.q_proj": {
|
||||
"snr": 0.11181341856718063,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.8.self_attn.q_proj": {
|
||||
"snr": 0.12042108923196793,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.9.self_attn.q_proj": {
|
||||
"snr": 0.09799323976039886,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.10.self_attn.q_proj": {
|
||||
"snr": 0.10901063680648804,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.11.self_attn.q_proj": {
|
||||
"snr": 0.09307146072387695,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.12.self_attn.q_proj": {
|
||||
"snr": 0.0880950540304184,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.13.self_attn.q_proj": {
|
||||
"snr": 0.08886399120092392,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.14.self_attn.q_proj": {
|
||||
"snr": 0.09955056011676788,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.15.self_attn.q_proj": {
|
||||
"snr": 0.08929339051246643,
|
||||
"type": "self_attn.q_proj"
|
||||
},
|
||||
"model.layers.0.self_attn.v_proj": {
|
||||
"snr": 2.5501928329467773,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.1.self_attn.v_proj": {
|
||||
"snr": 9.449499130249023,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.2.self_attn.v_proj": {
|
||||
"snr": 7.9920830726623535,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.3.self_attn.v_proj": {
|
||||
"snr": 50.69462585449219,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.4.self_attn.v_proj": {
|
||||
"snr": 19.083511352539062,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.5.self_attn.v_proj": {
|
||||
"snr": 7.21597146987915,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.6.self_attn.v_proj": {
|
||||
"snr": 11.27744197845459,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.7.self_attn.v_proj": {
|
||||
"snr": 4.579711437225342,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.8.self_attn.v_proj": {
|
||||
"snr": 10.940719604492188,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.9.self_attn.v_proj": {
|
||||
"snr": 553.4417724609375,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.10.self_attn.v_proj": {
|
||||
"snr": 20.59434700012207,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.11.self_attn.v_proj": {
|
||||
"snr": 26.636865615844727,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.12.self_attn.v_proj": {
|
||||
"snr": 8.614749908447266,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.13.self_attn.v_proj": {
|
||||
"snr": 17.722007751464844,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.14.self_attn.v_proj": {
|
||||
"snr": 1.48500657081604,
|
||||
"type": "self_attn.v_proj"
|
||||
},
|
||||
"model.layers.15.self_attn.v_proj": {
|
||||
"snr": 2.5776851177215576,
|
||||
"type": "self_attn.v_proj"
|
||||
}
|
||||
}
|
||||
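The JSON fixtures above record one entry per module, each with a measured `snr` value and a module `type` (layernorms, embeddings, and `lm_head` are reported as `Infinity`). Purely as an illustration of how such a file can be consumed (the file path and the ranking step below are assumptions, not part of this diff), it can be loaded and sorted by SNR:

import json
import math

# Hypothetical path; the real fixture location is not shown in this excerpt.
with open("snr_results.json", encoding="utf-8") as f:
    snr = json.load(f)  # Python's json parser accepts the bare Infinity literals

# Drop the Infinity entries (norms, embeddings) and list the highest-SNR projections.
finite = {name: info for name, info in snr.items() if math.isfinite(info["snr"])}
ranked = sorted(finite.items(), key=lambda kv: kv[1]["snr"], reverse=True)
for name, info in ranked[:5]:
    print(f"{name:45s} {info['type']:20s} snr={info['snr']:.3f}")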
File diff suppressed because it is too large
File diff suppressed because it is too large
159
src/axolotl/kernels/geglu.py
Normal file
@@ -0,0 +1,159 @@
"""
Module for definition of GEGLU Triton kernels.

See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).

Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code

import torch
import triton
import triton.language as tl

SQRT_2_PI: tl.constexpr = 0.7978845608028654  # sqrt(2/π)


@triton.jit
def _geglu_fwd_kernel(
    gate_ptr,
    up_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """GEGLU forward kernel.

    Args:
        gate_ptr: Pointer to gate tensor [*, hidden_dim].
        up_ptr: Pointer to up-projection tensor [*, hidden_dim].
        out_ptr: Pointer to output tensor [*, hidden_dim].
        n_elements: Total number of elements in the input tensors.
        BLOCK_SIZE: Size of thread blocks for parallel computation.
    """
    block_idx = tl.program_id(0)
    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0)

    # Compute activation in fp32 then convert back
    gelu_gate = 0.5 * gate * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
    gelu_gate = gelu_gate.to(up.dtype)
    result = gelu_gate * up

    tl.store(out_ptr + offsets, result, mask=mask)


def geglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """GEGLU forward pass.

    Args:
        gate: Input gate tensor of shape [batch, seq_len, hidden_dim].
        up: Up-projection tensor of shape [batch, seq_len, hidden_dim].

    Returns:
        torch.Tensor: Output tensor of shape [batch, seq_len, hidden_dim].
    """
    batch, seq_len, hidden_dim = gate.shape
    n_elements = gate.numel()
    out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda")

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)  # noqa: E731
    _geglu_fwd_kernel[grid](
        gate_ptr=gate,
        up_ptr=up,
        out_ptr=out,
        n_elements=n_elements,
        BLOCK_SIZE=1024,
    )
    return out


@triton.jit
def _geglu_bwd_kernel(
    grad_out_ptr,
    gate_ptr,
    up_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """GEGLU backward kernel. Stores gradient results in-place.

    Args:
        grad_out_ptr: Pointer to gradient output tensor [*, hidden_dim].
        gate_ptr: Pointer to gate tensor [*, hidden_dim].
        up_ptr: Pointer to up-projection tensor [*, hidden_dim].
        n_elements: Total number of elements in the input tensors.
        BLOCK_SIZE: Size of thread blocks for parallel computation.

    Note:
        After kernel execution, tensors are modified in-place:
        - `grad_out_ptr` contains GEGLU activation output (`h`)
        - `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
        - `up_ptr` contains gradient w.r.t up (`grad_up`)
    """
    block_idx = tl.program_id(0)
    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0)

    # Forward pass
    gelu_partial = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
    gelu_gate = gelu_partial * gate
    gelu_gate = gelu_gate.to(grad_out.dtype)

    # Forward output
    h = gelu_gate * up

    # Compute gradients
    grad_up = grad_out * gelu_gate

    # Compute gate gradient using GELU derivative
    temp = grad_out * up
    t = 0.3989422804014327  # 1/sqrt(2*pi)
    dgelu_dgate = gelu_partial + t * gate * tl.exp(-0.5 * gate * gate)
    grad_gate = temp.to(tl.float32) * dgelu_dgate
    grad_gate = grad_gate.to(grad_out.dtype)

    # Store results
    tl.store(grad_out_ptr + offsets, h, mask=mask)
    tl.store(gate_ptr + offsets, grad_gate, mask=mask)
    tl.store(up_ptr + offsets, grad_up, mask=mask)


def geglu_backward(
    grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """GEGLU backward pass using in-place operations.

    Args:
        grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
        gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
        up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.

    Returns:
        Tuple containing:
        - GEGLU activation output (`h`)
        - Gradient with respect to gate (`grad_gate`)
        - Gradient with respect to up (`grad_up`)

    Note:
        This function modifies its input tensors in-place to store results.
    """
    n_elements = grad_output.numel()

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)  # noqa: E731
    _geglu_bwd_kernel[grid](
        grad_out_ptr=grad_output,
        gate_ptr=gate,
        up_ptr=up,
        n_elements=n_elements,
        BLOCK_SIZE=1024,
    )

    return grad_output, gate, up
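A quick sanity check for the kernels above (not part of the diff; shapes, tolerances, and the CUDA requirement are illustrative assumptions) compares the Triton path against PyTorch's exact GELU:

import torch
import torch.nn.functional as F

from axolotl.kernels.geglu import geglu_backward, geglu_forward

gate = torch.randn(2, 16, 64, device="cuda", dtype=torch.float16)
up = torch.randn_like(gate)

# Forward: GEGLU(gate, up) = GELU(gate) * up, with GELU computed in fp32.
out = geglu_forward(gate, up)
ref = F.gelu(gate.float(), approximate="none").to(gate.dtype) * up
assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2)

# Backward: note the in-place contract -- the three inputs are overwritten with
# (activation output h, grad w.r.t gate, grad w.r.t up).
grad_out = torch.ones_like(gate)
h, grad_gate, grad_up = geglu_backward(grad_out.clone(), gate.clone(), up.clone())
print(h.shape, grad_gate.shape, grad_up.shape)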
779
src/axolotl/kernels/lora.py
Normal file
@@ -0,0 +1,779 @@
"""
Module for definition of Low-Rank Adaptation (LoRA) Triton kernels.

See "LoRA: Low-Rank Adaptation of Large Language Models"
(https://arxiv.org/abs/2106.09685).

Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name

from typing import Callable

import torch
from bitsandbytes.functional import QuantState
from torch import nn

from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize
from .swiglu import swiglu_backward, swiglu_forward
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd


def get_lora_parameters(
    proj: nn.Module,
) -> tuple[
    torch.Tensor,
    QuantState | None,
    torch.Tensor | None,
    torch.Tensor | None,
    float | None,
]:
    """
    Gets LoRA parameters from a projection module.

    Args:
        proj: The projection module to extract parameters from.

    Returns:
        A tuple containing the base weight matrix, quantization state, LoRA A matrix,
        LoRA B matrix, and scaling factor. States and matrices may be None if not
        available.
    """
    # For DPO or disabled adapters
    base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
    W = base_layer.weight

    if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
        quant_state = getattr(W, "quant_state", None)
        return W, quant_state, None, None, None

    active_adapter = (
        proj.active_adapters[0]
        if hasattr(proj, "active_adapters")
        else proj.active_adapter
    )
    A = proj.lora_A[active_adapter].weight
    B = proj.lora_B[active_adapter].weight
    s = proj.scaling[active_adapter]

    quant_state = getattr(W, "quant_state", None)

    return W, quant_state, A, B, s


def matmul_lora(
    X: torch.Tensor,
    W: torch.Tensor,
    W_quant: QuantState,
    A: torch.Tensor,
    B: torch.Tensor,
    s: float,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Efficient fused matmul + LoRA computation.

    Args:
        X: Input tensor [*, in_features]
        W: Base weight matrix [out_features, in_features]
        W_quant: Quantization state for W
        A: LoRA A matrix [rank, in_features]
        B: LoRA B matrix [out_features, rank]
        s: LoRA scaling factor
        out: Optional output tensor for inplace operations

    Returns:
        Result of X @ W + X @ A @ B
    """
    dtype = X.dtype
    W = dequantize(W.t(), W_quant)

    if X.dim() == 3:
        batch, seq_len, _ = X.shape
        X = X.view(-1, X.shape[-1])
        reshape = True
    else:
        reshape = False

    out = torch.matmul(X, W, out=out)
    if W_quant is not None:
        del W

    if A is not None:
        A, B = A.t(), B.t()
        out += (X @ A.to(dtype)) @ (s * B.to(dtype))

    return out.view(batch, seq_len, -1) if reshape else out
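With `W_quant=None` (so `dequantize` is a no-op) the function above reduces to the standard LoRA identity `X @ W.T + s * (X @ A.T) @ B.T`; a small self-check, illustrative only and assuming the module imports cleanly on the host:

import torch

from axolotl.kernels.lora import matmul_lora

batch, seq_len, d_in, d_out, rank, s = 2, 8, 32, 64, 4, 0.5
X = torch.randn(batch, seq_len, d_in)
W = torch.randn(d_out, d_in)   # base weight [out_features, in_features]
A = torch.randn(rank, d_in)    # LoRA A [rank, in_features]
B = torch.randn(d_out, rank)   # LoRA B [out_features, rank]

out = matmul_lora(X, W, None, A, B, s)
ref = X @ W.t() + s * (X @ A.t()) @ B.t()
assert torch.allclose(out, ref, atol=1e-4)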
class LoRA_MLP(torch.autograd.Function):
|
||||
"""Optimized LoRA MLP implementation."""
|
||||
|
||||
@staticmethod
|
||||
@torch_amp_custom_fwd
|
||||
def forward(
|
||||
ctx,
|
||||
X: torch.Tensor,
|
||||
gate_weight: torch.Tensor,
|
||||
gate_quant: object | None,
|
||||
gate_A: torch.Tensor | None,
|
||||
gate_B: torch.Tensor | None,
|
||||
gate_scale: float,
|
||||
up_weight: torch.Tensor,
|
||||
up_quant: object | None,
|
||||
up_A: torch.Tensor | None,
|
||||
up_B: torch.Tensor | None,
|
||||
up_scale: float,
|
||||
down_weight: torch.Tensor,
|
||||
down_quant: object | None,
|
||||
down_A: torch.Tensor | None,
|
||||
down_B: torch.Tensor | None,
|
||||
down_scale: float,
|
||||
activation_fn: Callable,
|
||||
activation_fn_backward: Callable,
|
||||
inplace: bool | None = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for LoRA MLP.
|
||||
|
||||
Args:
|
||||
ctx: Autograd context
|
||||
X: Input features
|
||||
gate_weight: Gate projection weight
|
||||
gate_quant: Gate quantization state
|
||||
gate_A: Gate LoRA A matrix
|
||||
gate_B: Gate LoRA B matrix
|
||||
gate_scale: Gate LoRA scale
|
||||
up_weight: Up-projection weight
|
||||
up_quant: Up-projection quantization state
|
||||
up_A: Up-projection LoRA A matrix
|
||||
up_B: Up-projection LoRA B matrix
|
||||
up_scale: Up-projection LoRA scale
|
||||
down_weight: Down-projection weight
|
||||
down_quant: Down-projection quantization state
|
||||
down_A: Down-projection LoRA A matrix
|
||||
down_B: Down-projection LoRA B matrix
|
||||
down_scale: Down-projection LoRA scale
|
||||
activation_fn: Forward activation function
|
||||
activation_fn_backward: Backward activation function
|
||||
inplace: Whether to perform operations in-place
|
||||
|
||||
Returns:
|
||||
Output transformed by multi-layer perceptron and activation function
|
||||
"""
|
||||
# Compute projections
|
||||
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
|
||||
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
|
||||
|
||||
# Activation
|
||||
hidden = activation_fn(gate, up)
|
||||
|
||||
# Down projection
|
||||
output = matmul_lora(
|
||||
hidden, down_weight, down_quant, down_A, down_B, down_scale
|
||||
)
|
||||
|
||||
# Save for backward
|
||||
ctx.save_for_backward(X, gate, up, gate_A, gate_B, up_A, up_B, down_A, down_B)
|
||||
ctx.scales = (gate_scale, up_scale, down_scale)
|
||||
ctx.quants = (gate_quant, up_quant, down_quant)
|
||||
ctx.weights = (gate_weight, up_weight, down_weight)
|
||||
ctx.activation_fn = activation_fn
|
||||
ctx.activation_fn_backward = activation_fn_backward
|
||||
ctx.inplace = inplace
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch_amp_custom_bwd
|
||||
def backward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_output: torch.Tensor,
|
||||
) -> tuple[
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Performs backward pass computation for LoRA MLP.
|
||||
|
||||
Args:
|
||||
ctx: Context object storing tensors saved during forward pass
|
||||
grad_output: Gradient of loss with respect to layer output
|
||||
|
||||
Returns:
|
||||
Tuple containing gradients for all inputs from forward pass:
|
||||
- Input gradient tensor (or `None`)
|
||||
- `None` for weights/quantization states
|
||||
- LoRA A/B matrix gradients (or `None`)
|
||||
- `None` for scaling factors
|
||||
- `None` for activation functions and flags
|
||||
"""
|
||||
(
|
||||
X,
|
||||
gate,
|
||||
up,
|
||||
gate_A,
|
||||
gate_B,
|
||||
up_A,
|
||||
up_B,
|
||||
down_A,
|
||||
down_B,
|
||||
) = ctx.saved_tensors
|
||||
gate_scale, up_scale, down_scale = ctx.scales
|
||||
gate_quant, up_quant, down_quant = ctx.quants
|
||||
gate_weight, up_weight, down_weight = ctx.weights
|
||||
|
||||
# Transpose all LoRA matrices
|
||||
gate_A, gate_B = (
|
||||
gate_A.t() if gate_A is not None else None,
|
||||
gate_B.t() if gate_B is not None else None,
|
||||
)
|
||||
up_A, up_B = (
|
||||
up_A.t() if up_A is not None else None,
|
||||
up_B.t() if up_B is not None else None,
|
||||
)
|
||||
down_A, down_B = (
|
||||
down_A.t() if down_A is not None else None,
|
||||
down_B.t() if down_B is not None else None,
|
||||
)
|
||||
|
||||
# Reshape inputs
|
||||
batch, seq_len, hd = X.shape
|
||||
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||
X = X.view(-1, X.shape[-1])
|
||||
gate = gate.view(-1, gate.shape[-1])
|
||||
up = up.view(-1, up.shape[-1])
|
||||
dtype = X.dtype
|
||||
|
||||
# Down projection
|
||||
DW = matmul_lora(
|
||||
grad_output,
|
||||
down_weight.t(),
|
||||
down_quant,
|
||||
down_B,
|
||||
down_A,
|
||||
down_scale,
|
||||
)
|
||||
|
||||
# Activation backward
|
||||
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
|
||||
|
||||
# Initialize and compute LoRA gradients
|
||||
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
|
||||
|
||||
if down_A is not None:
|
||||
d_down_A = h.t() @ (grad_output @ down_B.t())
|
||||
d_down_B = (down_A.t() @ h.t()) @ grad_output
|
||||
d_down_A *= down_scale
|
||||
d_down_B *= down_scale
|
||||
|
||||
if up_A is not None:
|
||||
d_up_A = X.t() @ (grad_up @ up_B.t())
|
||||
d_up_B = (up_A.t() @ X.t()) @ grad_up
|
||||
d_up_A *= up_scale
|
||||
d_up_B *= up_scale
|
||||
|
||||
if gate_A is not None:
|
||||
d_gate_A = X.t() @ (grad_gate @ gate_B.t())
|
||||
d_gate_B = (gate_A.t() @ X.t()) @ grad_gate
|
||||
d_gate_A *= gate_scale
|
||||
d_gate_B *= gate_scale
|
||||
|
||||
# Compute input gradients
|
||||
dX = torch.zeros_like(X) if ctx.needs_input_grad[0] else None
|
||||
|
||||
if dX is not None:
|
||||
# Up projection gradients
|
||||
up_weight = dequantize(up_weight.t(), up_quant)
|
||||
if ctx.inplace:
|
||||
dX = torch.matmul(grad_up, up_weight.t(), out=X)
|
||||
else:
|
||||
dX = torch.matmul(grad_up, up_weight.t())
|
||||
del up_weight
|
||||
|
||||
# Note the .to(dtype) only where mixing LoRA with base weights
|
||||
if up_A is not None:
|
||||
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
|
||||
|
||||
# Gate projection gradients
|
||||
gate_weight = dequantize(gate_weight.t(), gate_quant)
|
||||
dX += grad_gate @ gate_weight.t()
|
||||
del gate_weight
|
||||
|
||||
if gate_A is not None:
|
||||
dX += (
|
||||
grad_gate
|
||||
@ gate_B.to(dtype).t()
|
||||
@ (gate_scale * gate_A.to(dtype).t())
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
dX = dX.view(batch, seq_len, hd)
|
||||
|
||||
# Return gradients in correct order matching forward inputs
|
||||
return (
|
||||
dX,
|
||||
None,
|
||||
None,
|
||||
d_gate_A.t() if d_gate_A is not None else None,
|
||||
d_gate_B.t() if d_gate_B is not None else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_up_A.t() if d_up_A is not None else None,
|
||||
d_up_B.t() if d_up_B is not None else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_down_A.t() if d_down_A is not None else None,
|
||||
d_down_B.t() if d_down_B is not None else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Applies LoRA to MLP layer with SwiGLU activation.
|
||||
|
||||
Args:
|
||||
X: Input tensor for the MLP layer
|
||||
inplace: Whether to perform operations in-place to save memory
|
||||
|
||||
Returns:
|
||||
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
|
||||
"""
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
|
||||
out = LoRA_MLP.apply(
|
||||
X,
|
||||
gateW,
|
||||
gateW_quant,
|
||||
gateA,
|
||||
gateB,
|
||||
gateS,
|
||||
upW,
|
||||
upW_quant,
|
||||
upA,
|
||||
upB,
|
||||
upS,
|
||||
downW,
|
||||
downW_quant,
|
||||
downA,
|
||||
downB,
|
||||
downS,
|
||||
swiglu_forward,
|
||||
swiglu_backward,
|
||||
inplace,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Applies LoRA to MLP layer with GEGLU activation.
|
||||
|
||||
Args:
|
||||
X: Input tensor for the MLP layer
|
||||
inplace: Whether to perform operations in-place to save memory
|
||||
|
||||
Returns:
|
||||
Output tensor after applying LoRA-adapted MLP with GEGLU activation
|
||||
"""
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
out = LoRA_MLP.apply(
|
||||
X,
|
||||
gateW,
|
||||
gateW_quant,
|
||||
gateA,
|
||||
gateB,
|
||||
gateS,
|
||||
upW,
|
||||
upW_quant,
|
||||
upA,
|
||||
upB,
|
||||
upS,
|
||||
downW,
|
||||
downW_quant,
|
||||
downA,
|
||||
downB,
|
||||
downS,
|
||||
geglu_forward,
|
||||
geglu_backward,
|
||||
inplace,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class LoRA_QKV(torch.autograd.Function):
|
||||
"""
|
||||
Optimized LoRA QKV implementation with quantization support.
|
||||
|
||||
Implements efficient computation of query, key, value projections with LoRA,
|
||||
supporting quantization and memory optimization.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@torch_amp_custom_fwd
|
||||
def forward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
X: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
q_quant: QuantState | None,
|
||||
q_A: torch.Tensor | None,
|
||||
q_B: torch.Tensor | None,
|
||||
q_scale: float,
|
||||
k_weight: torch.Tensor,
|
||||
k_quant: QuantState | None,
|
||||
k_A: torch.Tensor | None,
|
||||
k_B: torch.Tensor | None,
|
||||
k_scale: float,
|
||||
v_weight: torch.Tensor,
|
||||
v_quant: QuantState | None,
|
||||
v_A: torch.Tensor | None,
|
||||
v_B: torch.Tensor | None,
|
||||
v_scale: float,
|
||||
inplace: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass computing Q, K, V projections with LoRA.
|
||||
|
||||
Args:
|
||||
ctx: Autograd context
|
||||
X: Input tensor
|
||||
q_weight: Query projection weight
|
||||
q_quant: Query quantization state
|
||||
q_A: Query LoRA A matrix
|
||||
q_B: Query LoRA B matrix
|
||||
q_scale: Query LoRA scale
|
||||
k_weight: Key projection weight
|
||||
k_quant: Key quantization state
|
||||
k_A: Key LoRA A matrix
|
||||
k_B: Key LoRA B matrix
|
||||
k_scale: Key LoRA scale
|
||||
v_weight: Value projection weight
|
||||
v_quant: Value quantization state
|
||||
v_A: Value LoRA A matrix
|
||||
v_B: Value LoRA B matrix
|
||||
v_scale: Value LoRA scale
|
||||
inplace: Whether to perform operations in-place
|
||||
|
||||
Returns:
|
||||
Tuple of (Query, Key, Value) projection tensors
|
||||
"""
|
||||
Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
|
||||
K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
|
||||
V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
|
||||
|
||||
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
|
||||
ctx.scales = (q_scale, k_scale, v_scale)
|
||||
ctx.quants = (q_quant, k_quant, v_quant)
|
||||
ctx.weights = (q_weight, k_weight, v_weight)
|
||||
ctx.inplace = inplace
|
||||
|
||||
return Q, K, V
|
||||
|
||||
@staticmethod
|
||||
@torch_amp_custom_bwd
|
||||
def backward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
q_grad: torch.Tensor,
|
||||
k_grad: torch.Tensor,
|
||||
v_grad: torch.Tensor,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Backward pass computing gradients for LoRA QKV.
|
||||
|
||||
Args:
|
||||
ctx: Autograd context
|
||||
q_grad: Gradient for query projection
|
||||
k_grad: Gradient for key projection
|
||||
v_grad: Gradient for value projection
|
||||
|
||||
Returns:
|
||||
Tuple containing gradients for all forward inputs
|
||||
"""
|
||||
X, A_q, B_q, A_k, B_k, A_v, B_v = ctx.saved_tensors
|
||||
q_weight, k_weight, v_weight = ctx.weights
|
||||
q_quant, k_quant, v_quant = ctx.quants
|
||||
q_scale, k_scale, v_scale = ctx.scales
|
||||
dtype = X.dtype
|
||||
|
||||
# Reshape gradients
|
||||
batch, seq_len = X.shape[:2]
|
||||
q_grad = q_grad.view(-1, q_grad.shape[-1])
|
||||
k_grad = k_grad.reshape(-1, k_grad.shape[-1])
|
||||
v_grad = v_grad.view(-1, v_grad.shape[-1])
|
||||
X = X.view(-1, X.shape[-1])
|
||||
|
||||
# Pre-transpose X once
|
||||
X_t = X.t()
|
||||
|
||||
# Initialize LoRA gradients as None
|
||||
d_A_q = d_B_q = d_A_k = d_B_k = d_A_v = d_B_v = None
|
||||
|
||||
# Compute q path LoRA gradients if adapters exist
|
||||
if A_q is not None and B_q is not None:
|
||||
A_q_scaled = (q_scale * A_q).to(dtype)
|
||||
B_q_scaled = B_q.to(dtype)
|
||||
d_A_q = torch.mm(X_t, torch.mm(q_grad, B_q_scaled))
|
||||
d_B_q = torch.mm(torch.mm(A_q_scaled, X_t), q_grad)
|
||||
|
||||
# Compute k path LoRA gradients if adapters exist
|
||||
if A_k is not None and B_k is not None:
|
||||
A_k_scaled = (k_scale * A_k).to(dtype)
|
||||
B_k_scaled = B_k.to(dtype)
|
||||
d_A_k = torch.mm(X_t, torch.mm(k_grad, B_k_scaled))
|
||||
d_B_k = torch.mm(torch.mm(A_k_scaled, X_t), k_grad)
|
||||
|
||||
# Compute v path LoRA gradients if adapters exist
|
||||
if A_v is not None and B_v is not None:
|
||||
A_v_scaled = (v_scale * A_v).to(dtype)
|
||||
B_v_scaled = B_v.to(dtype)
|
||||
d_A_v = torch.mm(X_t, torch.mm(v_grad, B_v_scaled))
|
||||
d_B_v = torch.mm(torch.mm(A_v_scaled, X_t), v_grad)
|
||||
|
||||
# Compute input gradient, reusing X memory if possible
|
||||
out_buffer = X if ctx.inplace else None
|
||||
|
||||
# Q path
|
||||
q_weight_t = dequantize(q_weight, q_quant)
|
||||
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
|
||||
del q_weight
|
||||
del q_weight_t
|
||||
if A_q is not None and B_q is not None:
|
||||
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
|
||||
|
||||
# K path
|
||||
k_weight_t = dequantize(k_weight, k_quant)
|
||||
grad_X.addmm_(k_grad, k_weight_t)
|
||||
del k_weight
|
||||
del k_weight_t
|
||||
if A_k is not None and B_k is not None:
|
||||
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
|
||||
|
||||
# V path
|
||||
v_weight_t = dequantize(v_weight, v_quant)
|
||||
grad_X.addmm_(v_grad, v_weight_t)
|
||||
del v_weight
|
||||
del v_weight_t
|
||||
if A_v is not None and B_v is not None:
|
||||
grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
|
||||
|
||||
# Transpose gradients if needed
|
||||
if d_A_q is not None:
|
||||
d_A_q = d_A_q.t()
|
||||
if d_B_q is not None:
|
||||
d_B_q = d_B_q.t()
|
||||
if d_A_k is not None:
|
||||
d_A_k = d_A_k.t()
|
||||
if d_B_k is not None:
|
||||
d_B_k = d_B_k.t()
|
||||
if d_A_v is not None:
|
||||
d_A_v = d_A_v.t()
|
||||
if d_B_v is not None:
|
||||
d_B_v = d_B_v.t()
|
||||
|
||||
return (
|
||||
grad_X.view(batch, seq_len, -1),
|
||||
None,
|
||||
None,
|
||||
d_A_q,
|
||||
d_B_q,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_A_k,
|
||||
d_B_k,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_A_v,
|
||||
d_B_v,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def apply_lora_qkv(
|
||||
self, X: torch.Tensor, inplace: bool = True
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Applies LoRA to compute Query, Key, Value projections.
|
||||
|
||||
Args:
|
||||
X: Input tensor
|
||||
inplace: Whether to perform operations in-place
|
||||
|
||||
Returns:
|
||||
Tuple of (Query, Key, Value) projection tensors
|
||||
"""
|
||||
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
|
||||
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
|
||||
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
|
||||
Q, K, V = LoRA_QKV.apply(
|
||||
X,
|
||||
QW,
|
||||
QW_quant,
|
||||
QA,
|
||||
QB,
|
||||
QS,
|
||||
KW,
|
||||
KW_quant,
|
||||
KA,
|
||||
KB,
|
||||
KS,
|
||||
VW,
|
||||
VW_quant,
|
||||
VA,
|
||||
VB,
|
||||
VS,
|
||||
inplace,
|
||||
)
|
||||
|
||||
return Q, K, V
|
||||
|
||||
|
||||
class LoRA_O(torch.autograd.Function):
|
||||
"""Optimized LoRA implementation for output projection."""
|
||||
|
||||
@staticmethod
|
||||
@torch_amp_custom_fwd
|
||||
def forward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
W_quant: QuantState | None,
|
||||
A: torch.Tensor | None,
|
||||
B: torch.Tensor | None,
|
||||
S: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for output projection with LoRA.
|
||||
|
||||
Args:
|
||||
ctx: Autograd context
|
||||
X: Input tensor
|
||||
W: Output projection weight
|
||||
W_quant: Weight quantization state
|
||||
A: LoRA A matrix
|
||||
B: LoRA B matrix
|
||||
S: LoRA scaling factor
|
||||
|
||||
Returns:
|
||||
Output projection tensor
|
||||
"""
|
||||
XW = matmul_lora(X, W, W_quant, A, B, S)
|
||||
ctx.custom_saved_tensors = (
|
||||
W,
|
||||
W_quant,
|
||||
S,
|
||||
)
|
||||
ctx.save_for_backward(A, B, X)
|
||||
|
||||
return XW
|
||||
|
||||
@staticmethod
|
||||
@torch_amp_custom_bwd
|
||||
def backward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
dY: torch.Tensor,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Backward pass computing gradients for LoRA output projection.
|
||||
|
||||
Args:
|
||||
ctx: Autograd context
|
||||
dY: Gradient of loss with respect to output
|
||||
|
||||
Returns:
|
||||
Tuple containing gradients for all forward inputs
|
||||
"""
|
||||
W, W_quant, S = ctx.custom_saved_tensors
|
||||
A, B, X = ctx.saved_tensors
|
||||
|
||||
batch, seq_len, hd = X.shape
|
||||
dY = dY.reshape(-1, dY.shape[-1])
|
||||
X = X.reshape(-1, X.shape[-1])
|
||||
dtype = X.dtype
|
||||
|
||||
# Weight projection
|
||||
dY_X = X.t() @ dY
|
||||
d_A = S * dY_X @ B
|
||||
d_B = S * A @ dY_X
|
||||
|
||||
# Get derivative for dX
|
||||
W = dequantize(W.t(), W_quant)
|
||||
dX = dY @ W.t()
|
||||
del W
|
||||
dX += dY @ B.to(dtype) @ (S * A.to(dtype))
|
||||
|
||||
# W, W_quant, A, B, S
|
||||
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
|
||||
|
||||
|
||||
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies LoRA to output projection layer.
|
||||
|
||||
Args:
|
||||
X: Input tensor
|
||||
|
||||
Returns:
|
||||
Transformed output tensor
|
||||
"""
|
||||
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
|
||||
output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
|
||||
|
||||
return output
|
||||
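The `apply_lora_*` helpers above are written as unbound methods: they expect `self` to expose `gate_proj`/`up_proj`/`down_proj` (MLP), `q_proj`/`k_proj`/`v_proj` (attention), or `o_proj`. How axolotl attaches them to a model is not shown in this excerpt; a minimal sketch of one way to bind the SwiGLU MLP path onto a PEFT-wrapped Llama-style MLP module (the module layout here is an assumption) could look like:

from types import MethodType

from axolotl.kernels.lora import apply_lora_mlp_swiglu

def patch_mlp(mlp_module):
    """Swap the MLP forward for the fused LoRA + SwiGLU path (sketch only)."""
    # `mlp_module` must expose gate_proj / up_proj / down_proj, as required by
    # the attribute accesses inside apply_lora_mlp_swiglu.
    mlp_module.forward = MethodType(apply_lora_mlp_swiglu, mlp_module)

# Hypothetical usage on a PEFT-wrapped causal LM:
# for layer in model.model.model.layers:
#     patch_mlp(layer.mlp)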
149
src/axolotl/kernels/quantize.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Dequantization utilities for `bitsandbytes` integration."""
|
||||
# pylint: disable=invalid-name,global-statement
|
||||
|
||||
import ctypes
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState, get_ptr
|
||||
from packaging.version import Version
|
||||
|
||||
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
|
||||
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
|
||||
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4

CUDA_STREAM: torch.cuda.Stream | None = None
HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3")


def dequantize(
    W: torch.Tensor,
    quant_state: QuantState | list | None = None,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Fast NF4 dequantization using `bitsandbytes` CUDA kernels.

    Performs efficient dequantization of weights from NF4 format using `bitsandbytes`'
    optimized CUDA implementations. Supports both legacy list and new `QuantState`
    formats.

    Args:
        W: Quantized weight tensor to dequantize.
        quant_state: Quantization state containing metadata needed for
            dequantization. Can be either a `QuantState` object or legacy list format.
            If None, returns `W` unchanged.
        out: Optional output tensor for storing dequantized results. Must match
            expected shape and dtype if provided.

    Returns:
        Dequantized tensor in the specified dtype (fp16 or bf16). Will be transposed if
        input `W` was transposed.

    Raises:
        AssertionError: If provided output tensor doesn't match expected shape / dtype.

    Note:
        Uses CUDA streams for better performance when available in newer `bitsandbytes`
        versions (>0.43.3).
    """
    if quant_state is None:
        return W

    # Get the target device from input tensor W
    target_device = W.device

    # Extract quantization state
    if not isinstance(quant_state, list):
        # New style quant_state class
        absmax = quant_state.absmax.to(target_device)
        shape = quant_state.shape
        dtype = quant_state.dtype
        blocksize = quant_state.blocksize
        offset = quant_state.offset.to(target_device)
        state2 = quant_state.state2
        absmax2 = state2.absmax.to(target_device)
        code2 = state2.code.to(target_device)
        blocksize2 = state2.blocksize
    else:
        # Legacy list format
        absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
        absmax = absmax.to(target_device)
        offset, state2 = compressed_stats
        offset = offset.to(target_device)
        absmax2, code2, blocksize2, _, _, _, _ = state2
        absmax2 = absmax2.to(target_device)
        code2 = code2.to(target_device)

    # Setup output tensor on the same device as input
    if out is None:
        out = torch.empty(shape, dtype=dtype, device=target_device)
    else:
        assert out.shape == shape and out.dtype == dtype
        out = out.to(target_device)

    # Dequantize statistics on the target device
    n_elements_absmax: int = absmax.numel()
    out_absmax: torch.Tensor = torch.empty(
        n_elements_absmax, dtype=torch.float32, device=target_device
    )
    ptr_out_absmax: int = get_ptr(out_absmax)

    # Use CUDA stream if available
    if HAS_CUDA_STREAM:
        global CUDA_STREAM
        if CUDA_STREAM is None:
            CUDA_STREAM = torch.cuda.current_stream(target_device)

        cdequantize_blockwise_fp32(
            get_ptr(code2),
            get_ptr(absmax),
            get_ptr(absmax2),
            ptr_out_absmax,
            ctypes.c_int(blocksize2),
            ctypes.c_int(n_elements_absmax),
            CUDA_STREAM,
        )
    else:
        cdequantize_blockwise_fp32(
            get_ptr(code2),
            get_ptr(absmax),
            get_ptr(absmax2),
            ptr_out_absmax,
            ctypes.c_int(blocksize2),
            ctypes.c_int(n_elements_absmax),
        )

    out_absmax += offset

    # Choose appropriate dequantization function
    fx = (
        cdequantize_blockwise_fp16_nf4
        if dtype == torch.float16
        else cdequantize_blockwise_bf16_nf4
    )

    # Dequantize weights
    if HAS_CUDA_STREAM:
        fx(
            get_ptr(None),
            get_ptr(W),
            ptr_out_absmax,
            get_ptr(out),
            ctypes.c_int(blocksize),
            ctypes.c_int(out.numel()),
            CUDA_STREAM,
        )
    else:
        fx(
            get_ptr(None),
            get_ptr(W),
            ptr_out_absmax,
            get_ptr(out),
            ctypes.c_int(blocksize),
            ctypes.c_int(out.numel()),
        )

    # Handle transposed data
    is_transposed: bool = W.shape[0] == 1
    return out.t() if is_transposed else out
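Not part of the diff — a minimal usage sketch of the helper above, assuming a `bitsandbytes` `Linear4bit` layer whose weight carries a `quant_state` after being moved to GPU; the module path `axolotl.kernels.quantize` is an assumption, since this view does not show the file name.

import torch
from bitsandbytes.nn import Linear4bit

from axolotl.kernels.quantize import dequantize  # assumed module path for the code above

# Hypothetical NF4 layer; in practice this comes from a 4-bit quantized model.
layer = Linear4bit(4096, 4096, quant_type="nf4", compute_dtype=torch.bfloat16).cuda()

# Recover a dense bf16 weight from the packed NF4 tensor and its quant_state.
dense_weight = dequantize(layer.weight.data, layer.weight.quant_state)
print(dense_weight.shape, dense_weight.dtype)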
163 src/axolotl/kernels/swiglu.py Normal file
@@ -0,0 +1,163 @@
"""
|
||||
Module for definition of SwiGLU Triton kernels.
|
||||
|
||||
See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
|
||||
|
||||
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
||||
"""
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _swiglu_fwd_kernel(
|
||||
gate_ptr,
|
||||
up_ptr,
|
||||
out_ptr,
|
||||
n_elements,
|
||||
block_size: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
SwiGLU forward kernel. The kernel computes activation in fp32 precision for better
|
||||
numerical stability, then converts back to original dtype for the final result.
|
||||
|
||||
Args:
|
||||
gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
|
||||
up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
|
||||
out_ptr: Pointer to output tensor `[*, hidden_dim]`.
|
||||
n_elements: Total number of elements in the input tensors.
|
||||
block_size: Size of thread blocks for parallel computation.
|
||||
"""
|
||||
block_idx = tl.program_id(0)
|
||||
offsets = block_idx * block_size + tl.arange(0, block_size)
|
||||
mask = offsets < n_elements
|
||||
|
||||
# Load gate in fp32, keep up in original dtype
|
||||
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
|
||||
up = tl.load(up_ptr + offsets, mask=mask, other=0)
|
||||
|
||||
# Compute activation in fp32 then convert back
|
||||
f = gate * tl.sigmoid(gate)
|
||||
f = f.to(up.dtype)
|
||||
result = f * up
|
||||
|
||||
tl.store(out_ptr + offsets, result, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _swiglu_bwd_kernel(
|
||||
grad_out_ptr,
|
||||
gate_ptr,
|
||||
up_ptr,
|
||||
n_elements,
|
||||
block_size: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
SwiGLU backward kernel. Stores gradient results in-place.
|
||||
|
||||
Args:
|
||||
grad_out_ptr: Pointer to gradient output tensor `[*, hidden_dim]`.
|
||||
gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
|
||||
up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
|
||||
n_elements: Total number of elements in the input tensors.
|
||||
block_size: Size of thread blocks for parallel computation.
|
||||
|
||||
Note:
|
||||
After kernel execution, tensors are modified in-place:
|
||||
- `grad_out_ptr` contains forward output (`h`)
|
||||
- `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
|
||||
- `up_ptr` contains gradient w.r.t up (`grad_up`)
|
||||
"""
|
||||
block_idx = tl.program_id(0)
|
||||
offsets = block_idx * block_size + tl.arange(0, block_size)
|
||||
mask = offsets < n_elements
|
||||
|
||||
# Load values - only convert gate to fp32
|
||||
grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
|
||||
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
|
||||
up = tl.load(up_ptr + offsets, mask=mask, other=0)
|
||||
|
||||
# Compute SiLU and forward output
|
||||
sigmoid_gate = tl.sigmoid(gate)
|
||||
silu_gate = sigmoid_gate * gate
|
||||
silu_gate = silu_gate.to(grad_out.dtype)
|
||||
h = silu_gate * up
|
||||
|
||||
# Compute gradients
|
||||
grad_up = grad_out * silu_gate # gradient for up is grad_out * SiLU(gate)
|
||||
|
||||
# Compute gate gradient
|
||||
temp = grad_out * up
|
||||
grad_gate = temp.to(tl.float32) * sigmoid_gate * (1.0 + gate * (1.0 - sigmoid_gate))
|
||||
grad_gate = grad_gate.to(grad_out.dtype)
|
||||
|
||||
# Store results with correct gradient ordering
|
||||
tl.store(grad_out_ptr + offsets, h, mask=mask)
|
||||
tl.store(gate_ptr + offsets, grad_gate, mask=mask) # grad wrt gate
|
||||
tl.store(up_ptr + offsets, grad_up, mask=mask) # grad wrt up
|
||||
|
||||
|
||||
# pylint: disable=unnecessary-lambda-assignment
|
||||
def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
SwiGLU forward pass. Computes SwiGLU activation: `x * sigmoid(x) * up`, where
|
||||
`x` is the gate tensor.
|
||||
|
||||
Args:
|
||||
gate: Input gate tensor of shape `[batch, seq_len, hidden_dim]`.
|
||||
up: Up-projection tensor of shape `[batch, seq_len, hidden_dim]`.
|
||||
|
||||
Returns:
|
||||
Output tensor of shape `[batch, seq_len, hidden_dim]`.
|
||||
"""
|
||||
batch, seq_len, hidden_dim = gate.shape
|
||||
n_elements = gate.numel()
|
||||
out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda")
|
||||
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),) # noqa: E731
|
||||
_swiglu_fwd_kernel[grid](
|
||||
gate_ptr=gate,
|
||||
up_ptr=up,
|
||||
out_ptr=out,
|
||||
n_elements=n_elements,
|
||||
block_size=1024,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
# pylint: disable=unnecessary-lambda-assignment
|
||||
def swiglu_backward(
|
||||
grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
SwiGLU backward pass using in-place operations.
|
||||
|
||||
Args:
|
||||
grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
|
||||
gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
|
||||
up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- Forward pass output (`h`)
|
||||
- Gradient with respect to gate (`df`)
|
||||
- Gradient with respect to up-projection (`de`)
|
||||
"""
|
||||
n_elements = grad_output.numel()
|
||||
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),) # noqa: E731
|
||||
_swiglu_bwd_kernel[grid](
|
||||
grad_out_ptr=grad_output,
|
||||
gate_ptr=gate,
|
||||
up_ptr=up,
|
||||
n_elements=n_elements,
|
||||
block_size=1024,
|
||||
)
|
||||
|
||||
# After kernel execution, tensors contain:
|
||||
# grad_output: h (forward output)
|
||||
# gate: grad_gate (grad wrt gate)
|
||||
# up: grad_up (grad wrt up)
|
||||
return grad_output, gate, up
|
||||
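Not part of the diff — a small sanity-check sketch for the two wrappers above, with arbitrary shapes; the reference computation mirrors what the forward kernel does (fp32 SiLU, cast, multiply).

import torch

from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward

gate = torch.randn(2, 16, 64, device="cuda", dtype=torch.float16)
up = torch.randn(2, 16, 64, device="cuda", dtype=torch.float16)

# Forward: silu(gate) * up, computed by the Triton kernel.
out = swiglu_forward(gate, up)

# Plain-PyTorch reference; expected to match closely.
ref = torch.nn.functional.silu(gate.float()).to(up.dtype) * up
print(torch.allclose(out, ref, atol=1e-3))

# Backward mutates its inputs in place: grad_output becomes h, gate becomes
# grad_gate, and up becomes grad_up.
grad_output = torch.ones_like(out)
h, grad_gate, grad_up = swiglu_backward(grad_output, gate, up)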
11 src/axolotl/kernels/utils.py Normal file
@@ -0,0 +1,11 @@
"""Utilities for `axolotl.kernels` submodules."""
|
||||
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
|
||||
if Version(torch.__version__) < Version("2.4.0"):
|
||||
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
|
||||
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
|
||||
else:
|
||||
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||
@@ -0,0 +1,45 @@
from enum import Enum
from functools import partial

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from yunchang import set_seq_parallel_pg, EXTRACT_FUNC_DICT

from axolotl.utils.distributed import get_world_size, get_rank


class USPRingAttnType(Enum):
    BASIC = "basic"
    ZIGZAG = "zigzag"
    STRIPE = "stripe"

def apply_usp_attn_patch(ring_impl_type: USPRingAttnType):
    from axolotl.monkeypatch.attention.sequence_parallel.usp import build_usp_fa_forward

    fa_forward = build_usp_fa_forward(ring_impl_type)
    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = fa_forward

def get_extract_fn(ring_impl_type: USPRingAttnType, sp_ulysses_degree: int):
    fn = EXTRACT_FUNC_DICT["basic"]
    if ring_impl_type.value in EXTRACT_FUNC_DICT.keys():
        fn = EXTRACT_FUNC_DICT[ring_impl_type.value]

    # map bad key upstream
    elif ring_impl_type == USPRingAttnType.STRIPE:
        fn = EXTRACT_FUNC_DICT["strip"]

    world_size = get_world_size()
    rd = world_size // sp_ulysses_degree

    return partial(fn, rank=get_rank(), world_size=world_size, rd=rd, ud=sp_ulysses_degree)

def set_usp_parallel_group(sp_ulysses_degree):
    """
    Set up the distributed parallel group for USP attention.
    Make sure this gets called before building any USP attention modules.
    :param sp_ulysses_degree: degree of Ulysses (head) parallelism
    :return: None
    """
    world_size = get_world_size()
    rank = get_rank()
    sp_ring_degree = world_size // sp_ulysses_degree
    set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
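Not part of the diff — a rough sketch of how these helpers appear intended to be wired together in a distributed launch (e.g. torchrun across 8 GPUs). The Ulysses degree of 2 is illustrative, and the import path assumes these helpers live in the `sequence_parallel` package `__init__` (this view does not show the file name; the inference comes from the import in usp.py below).

import torch.distributed as dist

from axolotl.monkeypatch.attention.sequence_parallel import (
    USPRingAttnType,
    apply_usp_attn_patch,
    set_usp_parallel_group,
)

dist.init_process_group("nccl")

# Split the world into Ulysses (head-parallel) and ring (sequence-parallel)
# groups before any USP attention module is constructed.
set_usp_parallel_group(sp_ulysses_degree=2)

# Route transformers' "flash_attention_2" implementation through yunchang's
# LongContextAttention, using the zigzag ring variant.
apply_usp_attn_patch(USPRingAttnType.ZIGZAG)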
36 src/axolotl/monkeypatch/attention/sequence_parallel/usp.py Normal file
@@ -0,0 +1,36 @@
from enum import Enum
from typing import Optional, Tuple, Callable

import torch
from yunchang import LongContextAttention

from axolotl.monkeypatch.attention.sequence_parallel import USPRingAttnType


def build_usp_fa_forward(ring_impl_type: USPRingAttnType) -> Callable:
    usp_attn = LongContextAttention(ring_impl_type.value)

    def flash_attention_forward(
        module: torch.nn.Module,  # pylint: disable=unused-argument
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],  # pylint: disable=unused-argument
        dropout: float = 0.0,
        scaling: Optional[float] = None,
        sliding_window: Optional[int] = None,  # pylint: disable=unused-argument
        softcap: Optional[float] = None,
        **kwargs,  # pylint: disable=unused-argument
    ) -> Tuple[torch.Tensor, None]:
        attn_output = usp_attn(
            query,
            key,
            value,
            dropout_p=dropout,
            softmax_scale=scaling,
            causal=True,
            softcap=softcap,
        )
        return attn_output, None

    return flash_attention_forward
333 src/axolotl/monkeypatch/lora_kernels.py Normal file
@@ -0,0 +1,333 @@
"""Module for patching custom LoRA Triton kernels and `torch.autograd` functions."""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import types
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
from peft import PeftModelForCausalLM
|
||||
from torch import nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
from axolotl.kernels.lora import (
|
||||
apply_lora_mlp_geglu,
|
||||
apply_lora_mlp_swiglu,
|
||||
apply_lora_o,
|
||||
apply_lora_qkv,
|
||||
)
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
ORIGINAL_QKV_CODE = """
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
PATCHED_QKV_CODE = """
|
||||
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||
query_states = query_states.view(hidden_shape).transpose(1, 2)
|
||||
key_states = key_states.view(hidden_shape).transpose(1, 2)
|
||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
ORIGINAL_O_CODE = """
|
||||
attn_output = self.o_proj(attn_output)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
PATCHED_O_CODE = """
|
||||
attn_output = self.apply_o(attn_output)
|
||||
""".lstrip(
|
||||
"\n"
|
||||
)
|
||||
|
||||
SUPPORTED_ACTIVATIONS = ["silu", "gelu"]
|
||||
APPLY_FN_MAPPING = {
|
||||
"silu": apply_lora_mlp_swiglu,
|
||||
"gelu": apply_lora_mlp_geglu,
|
||||
}
|
||||
|
||||
|
||||
def original_apply_qkv(
|
||||
self: nn.Module, hidden_states: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Original implementation of QKV projection without optimizations.
|
||||
|
||||
Args:
|
||||
self: The attention module instance.
|
||||
hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim].
|
||||
|
||||
Returns:
|
||||
A tuple `(query_states, key_states, value_states)` containing the projected
|
||||
states for query, key, and value.
|
||||
"""
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
return query_states, key_states, value_states
|
||||
|
||||
|
||||
def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Original implementation of output projection without optimizations.
|
||||
|
||||
Args:
|
||||
self: The attention module instance.
|
||||
hidden_states: Input tensor of shape `[`batch_size, seq_len, hidden_dim]`.
|
||||
|
||||
Returns:
|
||||
The output projection result.
|
||||
"""
|
||||
attn_output = self.o_proj(hidden_states)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
"""
|
||||
Get the appropriate attention class by inspecting the model config.
|
||||
Uses dynamic import to support any model architecture that follows
|
||||
the standard transformers naming convention.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
The appropriate attention class for the model.
|
||||
|
||||
Raises:
|
||||
ValueError: If `base_model` not specified or attention class cannot be imported
|
||||
ImportError: If the model module or attention class doesn't exist
|
||||
"""
|
||||
if "base_model" not in cfg:
|
||||
raise ValueError("base_model must be specified in config")
|
||||
|
||||
# Get model config without loading the model
|
||||
model_config = AutoConfig.from_pretrained(cfg["base_model"])
|
||||
model_type = model_config.model_type
|
||||
|
||||
# Special case for model_type = "qwen2"
|
||||
if model_type == "qwen2":
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
|
||||
|
||||
return Qwen2Attention
|
||||
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
module = __import__(
|
||||
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
|
||||
)
|
||||
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
|
||||
|
||||
return attention_cls
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(
|
||||
f"Could not import attention class for model_type: {model_type}. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def patch_self_attn_lora(cfg: DictDefault):
|
||||
"""
|
||||
Given an `axolotl` config, this method patches the inferred attention class forward
|
||||
pass with optimized LoRA implementations.
|
||||
|
||||
It modifies the attention class to use optimized QKV and output projections. The
|
||||
original implementation is preserved and can be restored if needed.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the required code blocks are not found in the attention
|
||||
implementation.
|
||||
"""
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
|
||||
# Check if already patched
|
||||
if hasattr(attention_cls, "_original_forward"):
|
||||
LOG.info(f"{attention_cls.__name__} already patched")
|
||||
return
|
||||
|
||||
self_attn_forward = inspect.getsource(attention_cls.forward)
|
||||
attention_cls._original_forward = self_attn_forward
|
||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
||||
|
||||
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
|
||||
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
|
||||
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(
|
||||
"def forward(",
|
||||
"def axolotl_attn_forward(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Load necessary imports
|
||||
module_name = attention_cls.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in self_attn_forward:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
|
||||
attention_cls.forward = (
|
||||
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
|
||||
def apply_lora_kernel_patches(
|
||||
model: PeftModelForCausalLM, cfg: DictDefault
|
||||
) -> PeftModelForCausalLM:
|
||||
"""
|
||||
Applies optimized Triton kernel patches to a PEFT model.
|
||||
|
||||
Patches a PEFT model with optimized implementations for MLP and attention
|
||||
computations. The optimizations include custom Triton kernels for activation
|
||||
functions and specialized autograd functions for LoRA computations.
|
||||
|
||||
Args:
|
||||
model: A PEFT model to be patched with optimized kernels.
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
PeftModelForCausalLM: The patched model with optimized kernels.
|
||||
|
||||
Raises:
|
||||
TypeError: If the provided model is not a `PeftModelForCausalLM`.
|
||||
NotImplementedError: If the model type is not supported.
|
||||
AssertionError: If multiple adapters are active (currently unsupported).
|
||||
|
||||
Note:
|
||||
The optimizations require LoRA adapters with no dropout and no bias terms. The
|
||||
function will skip patching if these conditions aren't met.
|
||||
"""
|
||||
if not isinstance(model, PeftModelForCausalLM):
|
||||
raise TypeError("Model must be a PeftModelForCausalLM")
|
||||
|
||||
# Get active LoRA adapter config
|
||||
if hasattr(model, "active_adapters"):
|
||||
assert (
|
||||
len(model.active_adapters) == 1
|
||||
), "Axolotl currently does not support LoRA Triton kernels for multiple adapters"
|
||||
active_adapter = model.active_adapters[0]
|
||||
else:
|
||||
active_adapter = model.active_adapter
|
||||
lora_config = model.model.peft_config[active_adapter]
|
||||
|
||||
# Only patch if conditions are met
|
||||
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
|
||||
|
||||
if not can_patch:
|
||||
LOG.warning("Cannot patch layers - requires no dropout and no bias")
|
||||
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
|
||||
return model
|
||||
|
||||
# This needs to be reset after patching
|
||||
original_level = LOG.getEffectiveLevel()
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
# Choose activation based on model type
|
||||
activation = model.config.hidden_act
|
||||
if activation not in SUPPORTED_ACTIVATIONS:
|
||||
raise NotImplementedError(f"Activation {activation} is not supported")
|
||||
|
||||
# Patch each layer
|
||||
for layer in model.model.model.layers:
|
||||
# Add QKV, O fallback implementations to start
|
||||
# These will be overwritten later (if some conditions apply)
|
||||
layer.self_attn.apply_qkv = types.MethodType(
|
||||
original_apply_qkv, layer.self_attn
|
||||
)
|
||||
layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
|
||||
|
||||
if cfg.lora_mlp_kernel:
|
||||
# MLP patching
|
||||
gate_proj = layer.mlp.gate_proj
|
||||
up_proj = layer.mlp.up_proj
|
||||
down_proj = layer.mlp.down_proj
|
||||
|
||||
can_patch_mlp = all(
|
||||
hasattr(proj, "lora_A")
|
||||
and getattr(proj, "base_layer", proj).bias is None
|
||||
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
||||
for proj in (gate_proj, up_proj, down_proj)
|
||||
)
|
||||
|
||||
if can_patch_mlp:
|
||||
apply_fn = APPLY_FN_MAPPING[activation]
|
||||
layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
|
||||
)
|
||||
if cfg.lora_qkv_kernel:
|
||||
# Query, key, value patching
|
||||
layer_modules = [
|
||||
getattr(layer.self_attn, linear_proj)
|
||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
||||
]
|
||||
can_patch_qkv = all(
|
||||
hasattr(module, "lora_A")
|
||||
and getattr(module, "base_layer", module).bias is None
|
||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
if can_patch_qkv:
|
||||
# Add optimized implementation
|
||||
layer.self_attn.apply_qkv = types.MethodType(
|
||||
apply_lora_qkv, layer.self_attn
|
||||
)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
|
||||
)
|
||||
if cfg.lora_o_kernel:
|
||||
# Output patching
|
||||
layer_modules = [
|
||||
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||
]
|
||||
can_patch_o = all(
|
||||
hasattr(module, "lora_A")
|
||||
and getattr(module, "base_layer", module).bias is None
|
||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
if can_patch_o:
|
||||
layer.self_attn.apply_o = types.MethodType(
|
||||
apply_lora_o, layer.self_attn
|
||||
)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
|
||||
)
|
||||
|
||||
LOG.setLevel(original_level)
|
||||
|
||||
return model
|
||||
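Not part of the diff — an illustrative sketch of how the two entry points above might be invoked directly. The config keys mirror those referenced in the code (`base_model`, `lora_qkv_kernel`, `lora_o_kernel`, `lora_mlp_kernel`); the model id and LoRA hyperparameters are placeholders.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.lora_kernels import (
    apply_lora_kernel_patches,
    patch_self_attn_lora,
)
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "meta-llama/Llama-3.2-1B",  # any causal LM id, for illustration
        "lora_qkv_kernel": True,
        "lora_o_kernel": True,
        "lora_mlp_kernel": True,
    }
)

# Rewrite the attention class forward() so it routes through apply_qkv / apply_o.
patch_self_attn_lora(cfg)

base = AutoModelForCausalLM.from_pretrained(cfg["base_model"])
peft_model = get_peft_model(
    base,
    LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32, lora_dropout=0.0, bias="none"),
)

# Swap in the Triton LoRA kernels; requires lora_dropout=0 and bias="none".
peft_model = apply_lora_kernel_patches(peft_model, cfg)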
@@ -127,6 +127,8 @@ class ReLoRACallback(TrainerCallback):
        optimizer: torch.optim.Optimizer,
        **_kwargs,
    ):
        if not optimizer:
            optimizer = state.optimizer
        if state.global_step > 0 and state.global_step % self.relora_steps == 0:
            checkpoint_folder = os.path.join(
                args.output_dir,
@@ -41,10 +41,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
            load_kwargs["ds_cfg"] = ds_cfg
        if "processor" in sig.parameters:
            load_kwargs["processor"] = processor

        return func(tokenizer, cfg, **load_kwargs)
    except ModuleNotFoundError:
        return None
    except Exception as exc:  # pylint: disable=broad-exception-caught
        LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
        raise exc
    return None
@@ -13,8 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs):
        if len(strategy.split(".")) == 1:
            strategy = strategy + ".default"
        load_fn = strategy.split(".")[-1]
        strategy = ".".join(strategy.split(".")[:-1])
        mod = importlib.import_module(f".{strategy}", module_base)
        if len(strategy.split(".")) > 1:
            try:
                importlib.import_module(
                    strategy.split(".")[-2],
                    ".".join(strategy.split(".")[:-2]),
                )
                module_base = ".".join(strategy.split(".")[:-2])
                strategy = strategy.split(".")[-2]
            except ModuleNotFoundError:
                strategy = "." + ".".join(strategy.split(".")[:-1])
        else:
            strategy = "." + ".".join(strategy.split(".")[:-1])
        mod = importlib.import_module(strategy, module_base)
        func = getattr(mod, load_fn)
        return func(cfg, **kwargs)
    except Exception:  # pylint: disable=broad-exception-caught
@@ -34,20 +34,17 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):

        max_length = self.prompter.max_length

        self.messages = "chosen_messages"
        # pylint: disable=duplicate-code
        prompt[self.messages] = []
        prompt["messages"] = []
        if prompt["system"]:
            prompt[self.messages].append(
                {"role": "system", "content": prompt["system"]}
            )
        prompt[self.messages].append({"role": "user", "content": prompt["input"]})
        prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
            prompt["messages"].append({"role": "system", "content": prompt["system"]})
        prompt["messages"].append({"role": "user", "content": prompt["input"]})
        prompt["messages"].append({"role": "assistant", "content": prompt["chosen"]})
        chosen_tokenized = super()._tokenize_single_prompt(prompt)

        if len(chosen_tokenized["input_ids"]) > max_length:
            LOG.warning(
                f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
                f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
            )

        chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
@@ -55,22 +52,17 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
            :max_length
        ]

        self.messages = "rejected_messages"
        # pylint: disable=duplicate-code
        prompt[self.messages] = []
        prompt["messages"] = []
        if prompt["system"]:
            prompt[self.messages].append(
                {"role": "system", "content": prompt["system"]}
            )
        prompt[self.messages].append({"role": "user", "content": prompt["input"]})
        prompt[self.messages].append(
            {"role": "assistant", "content": prompt["rejected"]}
        )
            prompt["messages"].append({"role": "system", "content": prompt["system"]})
        prompt["messages"].append({"role": "user", "content": prompt["input"]})
        prompt["messages"].append({"role": "assistant", "content": prompt["rejected"]})
        rejected_tokenized = super()._tokenize_single_prompt(prompt)

        if len(rejected_tokenized["input_ids"]) > max_length:
            LOG.warning(
                f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
                f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
            )

        rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
@@ -99,8 +91,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    prompter_params = {
        "tokenizer": tokenizer,
        "chat_template": chat_template_string,
        "message_field_role": ds_cfg.get("message_field_role", "role"),
        "message_field_content": ds_cfg.get("message_field_content", "content"),
        "message_property_mappings": ds_cfg.get(
            "message_property_mappings",
            {
                "role": "role",
                "content": "content",
            },
        ),
        "message_field_training": ds_cfg.get("message_field_training", None),
        "message_field_training_detail": ds_cfg.get(
            "message_field_training_detail", None
@@ -124,7 +121,4 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
        ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
    )

    if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
        strategy.messages = ds_cfg["field_messages"]

    return strategy
@@ -4,13 +4,16 @@ HF Chat Templates prompt strategy

import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set, Union

from pydantic import BaseModel
from transformers import ProcessorMixin

from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig

# Configure the logger
LOG = logging.getLogger("axolotl")
@@ -23,16 +26,23 @@ class ChatTemplatePrompter(Prompter):
    def __init__(
        self,
        tokenizer,
        chat_template: str,
        processor=None,
        chat_template=None,
        max_length=2048,
        message_field_role: str = "role",
        message_field_content: str = "content",
        message_property_mappings: Optional[Dict[str, str]] = None,
        message_field_training: Optional[str] = None,
        message_field_training_detail: Optional[str] = None,
        field_messages: str = "messages",
        roles: Optional[Dict[str, List[str]]] = None,
        drop_system_message: bool = False,
    ):
        # check if message_property_mappings is None or empty dict
        if message_property_mappings is None or (not message_property_mappings):
            message_property_mappings = {
                "role": "role",
                "content": "content",
            }

        if roles:
            self.roles = {s: t for t, sources in roles.items() for s in sources}
        else:
@@ -45,18 +55,28 @@ class ChatTemplatePrompter(Prompter):
                "tool": "tool",
            }

        self.message_field_role = message_field_role
        self.message_field_content = message_field_content
        self._chat_template_msg_variables = self.get_chat_template_msg_variables(
            chat_template, field_messages
        )
        self.message_property_mappings = message_property_mappings
        self.message_field_training = message_field_training
        self.message_field_training_detail = message_field_training_detail
        self.field_messages = field_messages
        self.tokenizer = tokenizer
        self.processor: ProcessorMixin = processor
        self.processor: Optional[ProcessorMixin] = processor
        self.chat_template = chat_template
        self.max_length = max_length
        self.drop_system_message = drop_system_message

    @property
    def chat_template_msg_variables(self) -> Set[str]:
        return self._chat_template_msg_variables

    def build_prompt(self, conversation, add_generation_prompt=False, images=None):
        if self.processor:
            if not callable(self.processor):
                raise TypeError("Processor must be callable")

            text = self.processor.apply_chat_template(
                conversation,
                chat_template=self.chat_template,
@@ -184,17 +204,21 @@ class ChatTemplatePrompter(Prompter):

        return adjusted_details

    def get_chat_template_msg_variables(
        self, chat_template: str, field_messages: str
    ) -> Set[str]:
        template_analyzer = JinjaTemplateAnalyzer(chat_template)
        return template_analyzer.get_message_vars(field_messages)


class ChatTemplateStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for instruction-based prompts.
    """

    _messages = "messages"

    def __init__(
        self,
        prompter: ChatTemplatePrompter,
        prompter: "ChatTemplatePrompter",
        tokenizer,
        train_on_inputs,
        sequence_len,
@@ -202,6 +226,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
        train_on_eos=None,
    ):
        super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
        self.prompter: ChatTemplatePrompter = prompter

        self.roles_to_train = []
        if roles_to_train:
@@ -213,13 +238,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
        self.train_on_eos = train_on_eos
        self.images = "images"

    @property
    def messages(self):
        return self._messages

    @messages.setter
    def messages(self, messages):
        self._messages = messages
        LOG.debug(
            f"The chat template uses the following properties on the message: {self.prompter.chat_template_msg_variables}"
        )

    @property
    def supports_batched(self) -> bool:
@@ -229,7 +250,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
    def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
        try:
            return all(isinstance(v, list) for v in prompt.values()) and all(
                isinstance(v, list) for v in prompt[self.messages]
                isinstance(v, list) for v in prompt[self.prompter.field_messages]
            )
        except KeyError:
            return False
@@ -251,8 +272,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
                    dict(zip(feature_names, row))
                )
            for key, val in tokenized_prompt.items():
                for i in range(0, len(val), self.sequence_len):
                    res[key].append(val[i : i + self.sequence_len])
                res[key].append(val)

        # If there are no examples left, return an empty dictionary
        if not res:
@@ -464,30 +484,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

    def get_conversation_thread(self, prompt):
        turns = []
        optional_keys = [
            "tool_calls",  # tool that 'assistant' calls
            "name",  # name of tool given by 'tool'
            "tool_call_id",  # mistral/mixtral requires this
        ]
        for message in prompt[self.messages]:
        for message in prompt[self.prompter.field_messages]:
            transformed_message = self.transform_message(message)

            turn = {
                "role": self.prompter.roles[message[self.prompter.message_field_role]],
                **transformed_message,
                "training": message.get(self.prompter.message_field_training),
                "training_detail": message.get(
                    self.prompter.message_field_training_detail
                ),
            }

            # do not add content if None as it may conflict with some templates due to tools
            content = message.get(self.prompter.message_field_content, None)
            if content is not None:
                turn["content"] = content

            for key in optional_keys:
                value = message.get(key, None)
                if value is not None:
                    turn[key] = value

            turns.append(turn)

        if self.prompter.drop_system_message and turns[0]["role"] == "system":
@@ -495,6 +502,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

        return turns

    def transform_message(self, message):
        # Build the initial transformed message from the mappings
        transformed_message = {}
        for key, value in self.prompter.message_property_mappings.items():
            if message.get(value) is not None:
                transformed_message[key] = message[value]
            else:
                LOG.debug(
                    f"Could not find value for property {value} in message: {message}"
                )

        # Map the role if necessary
        if "role" in transformed_message:
            transformed_message["role"] = self.prompter.roles.get(
                transformed_message["role"], transformed_message["role"]
            )

        # Determine which keys in the original message were not mapped
        mapped_values = set(self.prompter.message_property_mappings.values())
        remaining_keys = set(message) - mapped_values

        # Keep only the properties defined in the chat template
        # and not already mapped
        for key in self.prompter.chat_template_msg_variables:
            if key in remaining_keys:
                val = message.get(key)
                if val is not None:
                    transformed_message[key] = val

        return transformed_message

    def get_images(self, prompt):
        return prompt.get(self.images, None)

@@ -516,33 +554,46 @@ class StrategyLoader:
        }

    def __call__(
        self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
        self,
        tokenizer,
        cfg,
        ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
        processor=None,
    ):
        # pylint: disable=duplicate-code
        ds_cfg = ds_cfg or {}
        if ds_cfg is None:
            dataset_config = {}
        elif isinstance(ds_cfg, BaseModel):
            dataset_config = ds_cfg.model_dump()
        else:
            dataset_config = ds_cfg

        chat_template_string = get_chat_template_from_config(
            cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
            cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
        )
        LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")

        prompter_params = {
            "tokenizer": tokenizer,
            "chat_template": chat_template_string,
            "message_field_role": ds_cfg.get("message_field_role", "role"),
            "message_field_content": ds_cfg.get("message_field_content", "content"),
            "message_field_training": ds_cfg.get("message_field_training", None),
            "message_field_training_detail": ds_cfg.get(
            "message_property_mappings": dataset_config.get(
                "message_property_mappings", {}
            ),
            "message_field_training": dataset_config.get(
                "message_field_training", None
            ),
            "message_field_training_detail": dataset_config.get(
                "message_field_training_detail",
                None,
            ),
            "roles": ds_cfg.get("roles"),
            "drop_system_message": ds_cfg.get("drop_system_message", False),
            "field_messages": dataset_config.get("field_messages", "messages"),
            "roles": dataset_config.get("roles"),
            "drop_system_message": dataset_config.get("drop_system_message", False),
            # we need to add one for detecting sequences exceeding the `sequence_len` limit
            "max_length": cfg.sequence_len + 1,
            "processor": processor,
        }

        strategy_params = self._get_strategy_params(cfg, ds_cfg)
        strategy_params = self._get_strategy_params(cfg, dataset_config)
        strategy_cls = self._get_strategy_cls()

        strategy = strategy_cls(
@@ -551,9 +602,6 @@ class StrategyLoader:
            **strategy_params,
        )

        if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
            strategy.messages = ds_cfg["field_messages"]

        return strategy

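Not part of the diff — an illustrative dataset entry showing the new `message_property_mappings` option that the loader above now reads; the dataset path and source field names are invented for the example.

ds_cfg = {
    "path": "my_org/my_sharegpt_dataset",  # hypothetical dataset
    "type": "chat_template",
    "field_messages": "conversations",
    # Map chat-template message properties (keys) to dataset message fields (values).
    "message_property_mappings": {
        "role": "from",
        "content": "value",
    },
}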
@@ -3,20 +3,28 @@ DPO prompt strategies for using tokenizer chat templates.
"""

from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic


def default(
    cfg, dataset_idx=0, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    ds_cfg = cfg["datasets"][dataset_idx]
    ds_cfg = handle_legacy_message_fields_logic(ds_cfg)

    chat_template_choice, chat_template_jinja = extract_chat_template_args(
        cfg=cfg, ds_cfg=ds_cfg
    )
    field_messages = ds_cfg.get("field_messages", "messages")
    field_chosen = ds_cfg.get("field_chosen", "chosen")
    field_rejected = ds_cfg.get("field_rejected", "rejected")
    field_message_role = ds_cfg.get("message_field_role", "role")
    field_message_content = ds_cfg.get("message_field_content", "content")
    message_property_mappings = ds_cfg.get(
        "message_property_mappings",
        {
            "role": "role",
            "content": "content",
        },
    )
    role_map_inv = ds_cfg.get(
        "roles",
        {
@@ -40,18 +48,18 @@ def default(
        messages = sample[field_messages]
        messages = [
            {
                "role": role_map[m[field_message_role]],
                "content": m[field_message_content],
                "role": role_map[m[message_property_mappings["role"]]],
                "content": m[message_property_mappings["content"]],
            }
            for m in messages
        ]
        chosen = {
            "role": role_map[sample[field_chosen][field_message_role]],
            "content": sample[field_chosen][field_message_content],
            "role": role_map[sample[field_chosen][message_property_mappings["role"]]],
            "content": sample[field_chosen][message_property_mappings["content"]],
        }
        rejected = {
            "role": role_map[sample[field_rejected][field_message_role]],
            "content": sample[field_rejected][field_message_content],
            "role": role_map[sample[field_rejected][message_property_mappings["role"]]],
            "content": sample[field_rejected][message_property_mappings["content"]],
        }
        dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

14 src/axolotl/prompt_strategies/dpo/passthrough.py Normal file
@@ -0,0 +1,14 @@
"""
|
||||
DPO prompt strategies passthrough/zero-processing strategy
|
||||
"""
|
||||
|
||||
|
||||
def default(
|
||||
cfg, dataset_idx=0, **kwargs
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
def transform_fn(
|
||||
sample, tokenizer=None
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
318 src/axolotl/prompt_strategies/jinja_template_analyzer.py Normal file
@@ -0,0 +1,318 @@
"""Module for inspect jinja templates for the variables they use"""
|
||||
from typing import Dict, Optional, Set, TypedDict, Union
|
||||
|
||||
from jinja2 import Environment, meta, nodes
|
||||
|
||||
|
||||
class JinjaTemplateAnalysis(TypedDict):
|
||||
"""
|
||||
Represents the detailed analysis of a Jinja template variable.
|
||||
|
||||
Attributes:
|
||||
accessed_properties (Set[str]): A set of properties accessed from the variable
|
||||
(e.g., `foo.bar` results in 'bar' being accessed for 'foo').
|
||||
accessed_indices (Set[Union[int, float]]): A set of indices accessed from the variable.
|
||||
is_iterated (bool): Indicates if the variable is used as an iteration source in a `for` loop.
|
||||
is_conditional (bool): Indicates if the variable is referenced within a conditional statement (e.g., an `if` block).
|
||||
iteration_source (Optional[str]): The name of the variable being iterated over, if applicable.
|
||||
iteration_target (Optional[Union[str, list[str]]]): The loop target(s) assigned in the iteration.
|
||||
"""
|
||||
|
||||
accessed_properties: Set[str]
|
||||
accessed_indices: Set[Union[int, float]]
|
||||
is_iterated: bool
|
||||
is_conditional: bool
|
||||
iteration_source: Optional[str]
|
||||
iteration_target: Optional[Union[str, list[str]]]
|
||||
|
||||
|
||||
class JinjaTemplateAnalyzer:
|
||||
"""
|
||||
Analyzes Jinja templates to extract information about variable usage,
|
||||
including accessed properties, iteration, and conditional references.
|
||||
|
||||
Attributes:
|
||||
env (jinja2.Environment): The Jinja2 environment used for parsing templates.
|
||||
property_access (Dict[str, Set[str]]): Tracks accessed properties for variables.
|
||||
iteration_targets (Dict[str, str]): Maps iteration target variables to their sources.
|
||||
|
||||
Methods:
|
||||
get_template_variables(template: str) -> Dict[str, Set[str]]:
|
||||
Parse a Jinja template and return a mapping of variables to their accessed properties.
|
||||
|
||||
analyze_template(template: str) -> Dict[str, JinjaTemplateAnalysis]:
|
||||
Perform a detailed analysis of the template, including variable usage,
|
||||
iteration, and conditional references.
|
||||
|
||||
Private Methods:
|
||||
_visit_node(node) -> None:
|
||||
Recursively visit AST nodes to detect attribute access and iteration targets.
|
||||
|
||||
_get_base_name(node) -> Optional[str]:
|
||||
Extract the base variable name from a node.
|
||||
|
||||
_get_target_name(node) -> Optional[Union[str, list[str]]]:
|
||||
Extract the target name(s) from a `For` node.
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
self.env: Environment = Environment(autoescape=True)
|
||||
self.property_access: Dict[str, Set[str]] = {}
|
||||
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
|
||||
self.index_access: Dict[str, Set[Union[int, float]]] = {}
|
||||
self.ast: nodes.Node = self.env.parse(template)
|
||||
self.template: str = template
|
||||
self.variable_assignments: Dict[str, str] = {}
|
||||
|
||||
def _visit_node(self, node) -> None:
|
||||
"""Recursively visit AST nodes to find attribute access."""
|
||||
# Handle attribute access (dot notation)
|
||||
if isinstance(node, nodes.Getattr):
|
||||
base_name = self._get_base_name(node.node)
|
||||
if base_name:
|
||||
self.property_access.setdefault(base_name, set()).add(node.attr)
|
||||
|
||||
# Handle dictionary access (subscript notation)
|
||||
elif isinstance(node, nodes.Getitem):
|
||||
base_name = self._get_base_name(node.node)
|
||||
if base_name and isinstance(node.arg, nodes.Const):
|
||||
value = node.arg.value
|
||||
if isinstance(value, (int, float)):
|
||||
self.index_access.setdefault(base_name, set()).add(value)
|
||||
else:
|
||||
self.property_access.setdefault(base_name, set()).add(value)
|
||||
|
||||
elif isinstance(node, nodes.Test) and node.name == "defined":
|
||||
base_name = self._get_base_name(node.node)
|
||||
if base_name:
|
||||
if isinstance(node.node, nodes.Getattr):
|
||||
self.property_access.setdefault(base_name, set()).add(
|
||||
node.node.attr
|
||||
)
|
||||
|
||||
# Handle loop variables
|
||||
elif isinstance(node, nodes.For):
|
||||
iter_name = self._get_base_name(node.iter)
|
||||
target_name = self._get_target_name(node.target)
|
||||
if iter_name and target_name:
|
||||
self.iteration_targets[target_name] = iter_name
|
||||
self.property_access.setdefault(iter_name, set())
|
||||
|
||||
elif isinstance(node, nodes.Assign):
|
||||
target_name = self._get_target_name(node.target)
|
||||
source_name = self._get_base_name(node.node)
|
||||
if target_name and source_name:
|
||||
self.variable_assignments[target_name] = source_name
|
||||
|
||||
elif isinstance(node, nodes.Filter):
|
||||
if node.name == "selectattr":
|
||||
target = self._get_base_name(node.node)
|
||||
if target:
|
||||
self.variable_assignments[f"filtered_{target}"] = target
|
||||
|
||||
for child in node.iter_child_nodes():
|
||||
self._visit_node(child)
|
||||
|
||||
def _get_target_name(self, node) -> Optional[str]:
|
||||
"""Get the target variable name from a For node.
|
||||
|
||||
Args:
|
||||
node: A Jinja AST node representing either a Name or Tuple node
|
||||
|
||||
Returns:
|
||||
- str: For simple variable targets (e.g., "item" in "for item in items")
|
||||
- None: If the node type is not recognized or is a tuple
|
||||
"""
|
||||
if isinstance(node, nodes.Name):
|
||||
return node.name
|
||||
return None
|
||||
|
||||
def _get_target_names(self, node) -> list[str]:
|
||||
"""Get all target variable names from a For node, including tuple unpacking.
|
||||
|
||||
Args:
|
||||
node: A Jinja AST node representing either a Name or Tuple node
|
||||
|
||||
Returns:
|
||||
List of target variable names
|
||||
"""
|
||||
if isinstance(node, nodes.Name):
|
||||
return [node.name]
|
||||
|
||||
if isinstance(node, nodes.Tuple):
|
||||
names = []
|
||||
for n in node.items:
|
||||
if isinstance(n, nodes.Name):
|
||||
names.append(n.name)
|
||||
return names
|
||||
|
||||
return []
|
||||
|
||||
def _get_base_name(self, node) -> Optional[str]:
|
||||
"""Get the base variable name from a node."""
|
||||
if isinstance(node, nodes.Name):
|
||||
return node.name
|
||||
|
||||
if isinstance(node, nodes.Getattr):
|
||||
return self._get_base_name(node.node)
|
||||
|
||||
if isinstance(node, nodes.Getitem):
|
||||
return self._get_base_name(node.node)
|
||||
|
||||
return None
|
||||
|
||||
def get_template_variables(self) -> Dict[str, Set[str]]:
|
||||
"""
|
||||
Parse a Jinja template and return both variables and their accessed properties.
|
||||
|
||||
Args:
|
||||
template (str): The Jinja template string
|
||||
|
||||
Returns:
|
||||
Dict[str, Set[str]]: Dictionary mapping variable names to sets of accessed properties
|
||||
"""
|
||||
# Parse the template
|
||||
ast = self.env.parse(self.template)
|
||||
|
||||
# Get all undeclared variables
|
||||
variables = meta.find_undeclared_variables(ast)
|
||||
|
||||
# Reset property access tracking
|
||||
self.property_access = {}
|
||||
|
||||
# Visit all nodes to find property access
|
||||
self._visit_node(ast)
|
||||
|
||||
# Create result dictionary
|
||||
result: Dict[str, Set[str]] = {var: set() for var in variables}
|
||||
# Merge in any discovered sub-properties
|
||||
for var, props in self.property_access.items():
|
||||
if var not in result:
|
||||
result[var] = set()
|
||||
result[var].update(props)
|
||||
|
||||
return result
|
||||
|
||||
def analyze_template(self) -> Dict[str, JinjaTemplateAnalysis]:
|
||||
"""
|
||||
Provide a detailed analysis of template variables and their usage.
|
||||
"""
|
||||
variables = self.get_template_variables()
|
||||
self.iteration_targets = {}
|
||||
|
||||
analysis: Dict[str, JinjaTemplateAnalysis] = {
|
||||
var: JinjaTemplateAnalysis(
|
||||
accessed_properties=props,
|
||||
accessed_indices=set(),
|
||||
is_iterated=False,
|
||||
is_conditional=False,
|
||||
iteration_source=None,
|
||||
iteration_target=None,
|
||||
)
|
||||
for var, props in variables.items()
|
||||
}
|
||||
|
||||
for var, indices in self.index_access.items():
|
||||
if var in analysis:
|
||||
analysis[var]["accessed_indices"] = indices
|
||||
|
||||
def visit_node(node):
|
||||
if isinstance(node, nodes.If):
|
||||
|
||||
def find_test_vars(test_node):
|
||||
if isinstance(test_node, nodes.Name):
|
||||
if test_node.name in analysis:
|
||||
analysis[test_node.name]["is_conditional"] = True
|
||||
for child in test_node.iter_child_nodes():
|
||||
find_test_vars(child)
|
||||
|
||||
find_test_vars(node.test)
|
||||
|
||||
if isinstance(node, nodes.For):
|
||||
iter_target = self._get_base_name(node.iter)
|
||||
target_name = self._get_target_name(node.target)
|
||||
if iter_target in analysis:
|
||||
analysis[iter_target]["is_iterated"] = True
|
||||
if target_name:
|
||||
analysis[iter_target]["iteration_target"] = target_name
|
||||
if isinstance(target_name, str) and target_name not in analysis:
|
||||
analysis[target_name] = {
|
||||
"accessed_properties": set(),
|
||||
"is_iterated": False,
|
||||
"is_conditional": False,
|
||||
"iteration_source": iter_target,
|
||||
"iteration_target": None,
|
||||
}
|
||||
|
||||
for child in node.iter_child_nodes():
|
||||
visit_node(child)
|
||||
|
||||
visit_node(self.ast)
|
||||
return analysis
|
||||
|
||||
def get_downstream_properties(self, start_var: str) -> Dict[str, Set[str]]:
|
||||
"""
|
||||
Get all properties accessed on a variable and its downstream assignments.
|
||||
|
||||
Args:
|
||||
start_var: The starting variable to trace
|
||||
|
||||
Returns:
|
||||
Dict mapping variable names to their accessed properties
|
||||
"""
|
||||
visited = set()
|
||||
properties = {}
|
||||
|
||||
def trace_variable(var_name: str):
|
||||
if var_name in visited:
|
||||
return
|
||||
visited.add(var_name)
|
||||
|
||||
# Get direct properties
|
||||
if var_name in self.property_access:
|
||||
properties[var_name] = self.property_access[var_name]
|
||||
|
||||
# Get properties from iteration targets
|
||||
if var_name in self.iteration_targets:
|
||||
target = self.iteration_targets[var_name]
|
||||
if isinstance(target, str):
|
||||
trace_variable(target)
|
||||
elif isinstance(target, list):
|
||||
for t in target:
|
||||
trace_variable(t)
|
||||
|
||||
# Follow assignments
|
||||
for target, source in self.variable_assignments.items():
|
||||
if source == var_name:
|
||||
trace_variable(target)
|
||||
|
||||
# Check for array slicing
|
||||
analysis = self.analyze_template()
|
||||
if var_name in analysis:
|
||||
var_info = analysis[var_name]
|
||||
if var_info["accessed_indices"]:
|
||||
# If this variable is sliced, follow the resulting assignment
|
||||
slice_result = f"{var_name}_slice"
|
||||
if slice_result in self.property_access:
|
||||
trace_variable(slice_result)
|
||||
|
||||
trace_variable(start_var)
|
||||
return properties
|
||||
|
||||
def get_message_vars(self, field_messages: str = "messages") -> Set[str]:
|
||||
"""
|
||||
Get all properties accessed on messages and derived variables.
|
||||
"""
|
||||
all_properties = self.get_downstream_properties(field_messages)
|
||||
|
||||
# Combine all properties from all related variables
|
||||
combined_properties = set()
|
||||
for properties in all_properties.values():
|
||||
combined_properties.update(properties)
|
||||
|
||||
# Also include properties from the message iteration variable
|
||||
analysis = self.analyze_template()
|
||||
if "message" in analysis:
|
||||
combined_properties.update(analysis["message"]["accessed_properties"])
|
||||
|
||||
return combined_properties
|
||||
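Not part of the diff — a quick sketch of what the analyzer reports for a toy chat template; the template string is invented for illustration.

from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer

template = (
    "{% for message in messages %}"
    "{{ message.role }}: {{ message.content }}"
    "{% if message.weight is defined %}{{ message.weight }}{% endif %}"
    "{% endfor %}"
)

analyzer = JinjaTemplateAnalyzer(template)

# Properties the template reads off each message; expected to include
# 'role', 'content', and 'weight'.
print(analyzer.get_message_vars("messages"))

# Full per-variable breakdown (iteration, conditionals, accessed indices, ...).
print(analyzer.analyze_template()["messages"])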
@@ -51,8 +51,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    ds_cfg = ds_cfg or {}

    field_messages = ds_cfg.get("field_messages")
    message_field_role = ds_cfg.get("message_field_role")
    message_field_content = ds_cfg.get("message_field_content")
    message_property_mappings = ds_cfg.get("message_property_mappings")
    message_field_role = (
        message_property_mappings.get("role") if message_property_mappings else None
    )
    message_field_content = (
        message_property_mappings.get("content") if message_property_mappings else None
    )
    message_field_training = ds_cfg.get("message_field_training")

    builder_kwargs = {}

Some files were not shown because too many files have changed in this diff.