Compare commits

..

34 Commits

Author SHA1 Message Date
Wing Lian
c9d842ef2e test not deleting pythonpath for custom code bundling 2025-02-06 13:40:30 -05:00
Wing Lian
ecea44c902 fix num_processes in passing to accelerate 2025-02-06 13:39:46 -05:00
Wing Lian
4f9c57e95d check for src axolotl in PYTHONPATH before removing it 2025-02-06 13:26:23 -05:00
Wing Lian
3d38bc82b8 include vllm in build 2025-02-06 11:09:42 -05:00
Wing Lian
756a8332d6 set default on trl config 2025-02-05 22:17:10 -05:00
Wing Lian
aded9c500d refactor cfg.grpo_* to use cfg.trl.* 2025-02-05 20:41:14 -05:00
Wing Lian
3659d812f7 use cfg.max_completion_length, not sequence_len 2025-02-05 13:20:17 -05:00
Salman Mohammadi
bdb0f97082 adding 'reward_processing_classes' 2025-02-05 18:18:42 +00:00
Salman Mohammadi
65b6519447 adding 'reward_processing_classes' 2025-02-05 18:13:05 +00:00
Wing Lian
a1958b09de seperately include max_completion_len 2025-02-05 13:01:52 -05:00
Salman Mohammadi
b8f258817e adding reward fn verification 2025-02-05 13:30:02 +00:00
Wing Lian
753146b458 max_length moved to reward config 2025-02-04 11:06:26 -05:00
Wing Lian
d683c50113 fix config cls 2025-02-04 11:06:26 -05:00
Wing Lian
234cd8311e fix failure case in prompter loading 2025-02-04 11:06:26 -05:00
Wing Lian
f9893e3842 fix dpo config and add use_logits_to_keep 2025-02-04 11:06:26 -05:00
Wing Lian
ac1ebc58a8 add support for num_generations 2025-02-04 11:06:25 -05:00
Wing Lian
56f3b9f20f bump pydantic to support vllm 2025-02-04 11:06:25 -05:00
Wing Lian
2c1376d8c4 don't shrink embeddings unless told to 2025-02-04 11:06:25 -05:00
Wing Lian
3c7517fd55 add support for passing map kwargs to dataset map in rl 2025-02-04 11:06:25 -05:00
Wing Lian
1e94d7ef65 more fixes to get grpo working 2025-02-04 11:06:25 -05:00
Wing Lian
cfc7fe0df2 remove ununsable args kwargs 2025-02-04 11:06:25 -05:00
Wing Lian
3c4fe478cf be nice with self.cfg.dataset_processes 2025-02-04 11:06:25 -05:00
Wing Lian
c810599c66 order matters 2025-02-04 11:06:24 -05:00
Wing Lian
300ffc2cb6 make it a dataclass 2025-02-04 11:06:24 -05:00
Wing Lian
b1c4711145 load the class from strat 2025-02-04 11:06:24 -05:00
Wing Lian
d155849e2c use correct builder 2025-02-04 11:06:24 -05:00
Wing Lian
626db6cb84 collator for grpo and prompt loader 2025-02-04 11:06:24 -05:00
Wing Lian
79159b4871 support custom module prompt strategy for rl 2025-02-04 11:06:24 -05:00
Wing Lian
704ddd6ff1 honor skip prepare for rl 2025-02-04 11:06:24 -05:00
Wing Lian
54b0d3d0e8 passthrough dataset parser for dpo/grpo 2025-02-04 11:06:23 -05:00
Wing Lian
59ad21f2de refactor a bit for better grpo support 2025-02-04 11:06:23 -05:00
Wing Lian
57264b6491 respect dotenv for cli 2025-02-04 11:06:23 -05:00
Wing Lian
d495e41ba1 refactor dpo trainer into own module 2025-02-04 11:06:23 -05:00
Wing Lian
6067fe6c28 upgrade trl to 0.14.0 2025-02-04 11:06:23 -05:00
116 changed files with 501 additions and 17104 deletions

View File

@@ -22,6 +22,12 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.10"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
@@ -34,12 +40,6 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -19,7 +19,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
python-version: '3.10'
- name: install dependencies
run: |
python3 -m pip install jupyter

View File

@@ -19,6 +19,6 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1

View File

@@ -24,13 +24,8 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras: vllm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -24,21 +24,13 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras: # no vllm support for 2.4.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
# awaiting vllm#12721
axolotl_extras:
num_gpus: 2
nightly_build: "true"
@@ -50,7 +42,7 @@ jobs:
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip

View File

@@ -22,11 +22,6 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -36,7 +36,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
python-version: "3.10"
- name: Install dependencies
run: |

View File

@@ -12,7 +12,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
@@ -25,8 +25,13 @@ jobs:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
python_version: ["3.10", "3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
timeout-minutes: 20
steps:
@@ -107,20 +112,13 @@ jobs:
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip

View File

@@ -35,7 +35,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
@@ -48,8 +48,13 @@ jobs:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
python_version: ["3.10", "3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
timeout-minutes: 20
steps:
@@ -122,7 +127,7 @@ jobs:
max-parallel: 1
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
pytorch_version: ["2.4.1", "2.5.1"]
timeout-minutes: 20
steps:
@@ -204,14 +209,14 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
num_gpus: 1
axolotl_extras: vllm
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
@@ -246,19 +251,13 @@ jobs:
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip

View File

@@ -51,7 +51,7 @@ Features:
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- Python 3.10
- PyTorch ≥2.4.1
### Installation

View File

@@ -32,9 +32,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -4,8 +4,8 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -1,4 +1,6 @@
"""Modal app to run axolotl GPU tests"""
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,vllm] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -46,10 +46,6 @@ overrides_of_model_config:
type: # linear | dynamic
factor: # float
# optional overrides the base model loading from_pretrained
overrides_of_model_kwargs:
# use_cache: False
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
@@ -91,12 +87,7 @@ datasets:
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
shards: # Optional[int] split dataset into N pieces (use with shards_idx)
shards_idx: # Optional[int] = 0 the index of sharded dataset to use
preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
@@ -142,19 +133,10 @@ datasets:
# Key containing the messages (default: "messages")
field_messages: messages
# Mapping of properties from the input dataset to the chat template.
# (default: message_property_mappings={'role':'role', 'content':'content'})
# If a property exists in the template but not in this mapping, the system will attempt
# to load it directly from the message using the property name as the key.
# Example: In the mapping below, 'from' is loaded from input dataset and used as 'role',
# while 'value' is loaded and used as 'content' in the chat template.
message_property_mappings:
role: from
content: value
# ...
message_property_mappings:
# Key for role in each message (default: "role")
message_field_role: role
# Key for content in each message (default: "content")
message_field_content: content
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
@@ -314,13 +296,6 @@ lora_modules_to_save:
lora_fan_in_fan_out: false
# Apply custom LoRA autograd functions and activation function Triton kernels for
# speed and memory savings
# See: https://axolotl-ai-cloud.github.io/axolotl/docs/lora_optims.html
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
# LoRA+ hyperparameters
# For more details about the following options, see:
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
@@ -369,9 +344,6 @@ comet_mode: # Create a new experiment ("create") or log to an existing one ("get
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
# Tensorboard
use_tensorboard: # Optional[bool]
# Where to save the full-finetuned model to
output_dir: ./completed-model
@@ -406,9 +378,6 @@ save_total_limit: # Checkpoints saved at a time
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
max_steps:
# bool of whether to include tokens per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]

View File

@@ -6,7 +6,7 @@ order: 3
## sharegpt
IMPORTANT: ShareGPT is deprecated! Please see the [chat_template](#chat_template) section below.
IMPORTANT: ShareGPT is deprecated! Please see the `chat_template` section below.
## pygmalion
@@ -22,7 +22,7 @@ Chat Template strategy uses a jinja2 template that converts a list of messages i
{"conversations": [{"role": "...", "content": "..."}]}
```
See [configs](../config.qmd) for full configs and supported templates.
See `config.qmd` for full configs and supported templates.
### Migrating from sharegpt
@@ -42,9 +42,8 @@ datasets:
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
message_field_role: from
message_field_content: value
# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
@@ -53,9 +52,8 @@ datasets:
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
message_field_role: from
message_field_content: value
```
We recommend checking the examples below for other use cases.
@@ -140,9 +138,8 @@ datasets:
type: chat_template
chat_template: tokenizer_default
field_messages: conversations
message_property_mappings:
role: from
content: value
message_field_role: from
message_field_content: value
roles_to_train: []
train_on_eos: turn
message_field_training: train

View File

@@ -1,458 +1,14 @@
---
title: Dataset Formats
description: Guide to Dataset Formats in Axolotl
back-to-top-navigation: true
toc: true
toc-depth: 5
description: Supported dataset formats.
listing:
fields: [title, description]
type: table
sort-ui: false
filter-ui: false
max-description-length: 250
---
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL format. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
Axolotl is a training framework that aims to make training convenient yet flexible: users drive everything by passing a config YAML file.
As there are a lot of available options in Axolotl, this guide aims to simplify the experience of choosing the right one.
Axolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has its own dataset format, described below.
## [Pre-training](pretraining.qmd)
When aiming to train on large corpora of text, pre-training is your go-to choice. Due to the size of these datasets, downloading them in full before beginning training would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to load only one batch into memory at a time.
A sample format for a pre-training dataset is as follows:
```json
{"text": "first row"}
{"text": "second row"}
...
```
It is typically recommended to save your dataset as `.jsonl` due to its flexibility and simplicity.
Axolotl supports loading from a Hugging Face hub repo or from local files.
::: {.callout-important}
For pre-training only, Axolotl splits texts that exceed the context length into multiple smaller prompts.
:::
### Pre-training from Hugging Face hub datasets
As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:
```yaml
pretraining_dataset: hf_org/name
```
### Pre-training from local dataset files
Given a few corpus files: `A.jsonl`, `B.jsonl`, and `C.jsonl`, your config will look like the below:
```yaml
pretraining_dataset:
- path: json
data_files:
- A.jsonl
- B.jsonl
- C.jsonl
```
While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet`, `arrow`, `SQL`, `WebDataset`) supported by [`datasets.load_dataset`](https://huggingface.co/docs/datasets/loading#local-and-remote-files).
### Pre-training without streaming
In the rare case that the dataset is small enough to be loaded entirely into memory, another approach to pre-training is to use the `completion` format. This means the entire dataset is pre-tokenized up front instead of on demand while streaming.
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.
From Hugging Face:
```yaml
datasets:
- path: hf_org/name
type: completion
```
From local files (either example works):
```yaml
datasets:
- path: A.jsonl
type: completion
- path: json
data_files: ["A.jsonl", "B.jsonl", "C.jsonl"]
type: completion
```
### Pre-training dataset configuration tips
#### Setting max_steps
When streaming large datasets, Axolotl does not know in advance how large the dataset is, so it cannot determine when to stop.
Therefore, you must set `max_steps: int` in your pre-training config so that Axolotl knows when to stop training.
One step is equal to `sequence_len * micro_batch_size * gradient_accumulation_steps * total_num_gpus` tokens.
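For instance (illustrative numbers, not defaults): with the settings below on 8 GPUs, one step covers 2048 × 2 × 4 × 8 = 131,072 tokens, so covering roughly one billion tokens would need `max_steps: 7630` (1e9 / 131,072 ≈ 7629.4).
```yaml
sequence_len: 2048
micro_batch_size: 2
gradient_accumulation_steps: 4
# assuming 8 GPUs in total
max_steps: 7630  # ≈ 1B tokens / (2048 * 2 * 4 * 8)
```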
#### Group_by_length
It is recommended to leave this off when downloading from the Hugging Face Hub, as it would download the entire dataset, which can be very large.
## Supervised fine-tuning (SFT)
Supervised fine-tuning is the process of training models to respond to an instruction or chat input.
As there are a wide variety of dataset formats, Axolotl tries to support a majority of the formats available in public datasets.
Axolotl provides four approaches for loading datasets; however, it's easier to work backwards from the dataset you have available to figure out which approach to use.
A flow chart is as follows:
1. Do you already have the dataset tokenized? If yes, check [Pre-Tokenized Dataset](#pre-tokenized-dataset).
2. Do you want to format the dataset yourself and manually choose each section to mask? If yes, check [Template Free Dataset](#template-free-dataset)
3. Is your dataset in a "conversation" format, containing a `list[messages]`? If yes, check [Conversation Dataset](#conversation-dataset)
4. Is your dataset in an "instruct" format, containing `{ instruction, response }`? If yes, check [Instruction Dataset](#instruction-dataset)
If you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above or create a thread in GitHub Discussions.
::: {.callout-tip}
You can mix and match within each approach or across approaches to train a model on a variety of datasets.
:::
### [Pre-Tokenized Dataset](tokenized.qmd)
We suggest this approach when you want to bring your own tokenized dataset.
Axolotl expects the dataset to have three keys:
- `input_ids`: from tokenizing formatted prompt
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
- `labels`: the same as `input_ids`; however, to mask certain tokens, set those indices to `-100` (a sample row is shown after the config below).
::: {.callout-tip}
Make sure to add BOS/EOS tokens to your prompt and mask it appropriately.
:::
A config for this would look like:
```yaml
datasets:
- path: A.jsonl
type:
```
::: {.callout-note}
`type: ` is empty!
:::
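As an illustration, a single pre-tokenized row could look like the following (the token IDs are made up; the first two positions are masked with `-100`):
```json
{"input_ids": [1, 9906, 1917, 2], "attention_mask": [1, 1, 1, 1], "labels": [-100, -100, 1917, 2]}
```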
### [Template Free Dataset](template_free.qmd)
We recommend this approach when you want granular control over prompt formatting, special tokens, and masking, whilst letting Axolotl handle the tokenization. This is very useful if your dataset has prompts that differ across samples and a single general template wouldn't suffice.
In the example below, you can see that there is no fixed structure. At the same time, it's very flexible, as there are no constraints on how your prompt can look.
```json
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
```
Each prompt must have a key called `segments`, which is a list of `{ text, label }` entries.
```yaml
datasets:
- path: A.jsonl
type: input_output
```
### [Conversation Dataset](conversation.qmd)
`conversation` messages are a list of messages which usually contain a `role` and `content` key.
::: {.callout-tip}
Fun fact: Axolotl refers to "chat" messages interchangeably as `conversation` messages because FastChat initially used this term for its widely used [fastchat conversation](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) method for formatting chat messages, prior to the creation of `chat_templates`.
:::
#### What are `chat_templates`?
Currently, the most popular and convenient method for formatting prompts at inference time is to use `chat_templates`. Axolotl supports using `chat_templates` for training to ensure that the model sees the same format during training as during inference.
Here's a quick rundown on `chat_template`: A `chat_template` is a Jinja2 template which formats a list of messages into a prompt.
An example of a prompt formatted into a popular template called ChatML can be seen below:
Single prompt (pretty-printed):
```json
{
"messages": [
{
"role": "user",
"content": "Hi"
},
{
"role": "assistant",
"content": "How can I help you?"
},
{
"role": "user",
"content": "Can you add 3+5?"
},
{
"role": "assistant",
"content": "The answer is 8."
}
]
}
```
The ChatML template is as follows:
```jinja2
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
```
The above prompt formatted into this template will result in:
```
<|im_start|>user
Hi<|im_end|>
<|im_start|>assistant
How can I help you?<|im_end|>
<|im_start|>user
Can you add 3+5?<|im_end|>
<|im_start|>assistant
The answer is 8.<|im_end|>
```
By using delimiters (`<|im_start|>` and `<|im_end|>`), the prompt separates different speakers, which helps the model identify which portion belongs to whom.
#### Common Conversation Dataset formats
Older conversation datasets with the following format are colloquially called `sharegpt` datasets.
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
Newer conversation datasets usually follow the OpenAI format.
```json
{"messages": [{"role": "...", "content": "..."}]}
```
Axolotl supports both, as well as customization of arbitrary keys.
#### [Chat Template Usage](conversation.qmd#chat_template)
To properly use this method, it is important to identify three things:
1. Which `chat_template` would you use?
2. What are the keys in your dataset, and what are the possible roles? For example, in OpenAI format, the keys would be `messages`, `role`, and `content`, respectively, whereas the possible roles are `system`, `user`, and `assistant`.
3. What do you want to mask? For instance, only assistant messages, only the last message, or nothing.
##### Choosing a `chat_template`
There are a lot of `chat_templates` out there. Axolotl supports the common ones: [supported chat templates](https://github.com/axolotl-ai-cloud/axolotl/blob/860609392184cf62a7e0ca676658b170e059ce6c/src/axolotl/utils/chat_templates.py#L17). For example, to use ChatML, it would be `chat_template: chatml`.
However, it is also possible to use the template already configured in the tokenizer by specifying `chat_template: tokenizer_default`. If you want a fallback (in case a tokenizer does not have one pre-configured), you can use `chat_template: tokenizer_default_fallback_chatml` to fall back to the ChatML template when no tokenizer template is found.
One last but powerful approach is to bring your own template. This can be set via:
```yaml
chat_template_jinja: # your template
```
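For reference, the three ways of selecting a template described above look like this in the config (pick one):
```yaml
# use a built-in template
chat_template: chatml

# or reuse the template already configured in the tokenizer
chat_template: tokenizer_default

# or prefer the tokenizer's template, falling back to ChatML if none is found
chat_template: tokenizer_default_fallback_chatml
```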
##### Setting `chat_template` dataset keys
We currently default to OpenAI format for dataset keys, so if that's your current dataset format, there's nothing to do here.
If your dataset format is different, here are the keys you should check (with their defaults):
```yaml
datasets:
...
field_messages: messages # this should point to the key containing the list of conversations
message_property_mappings: # this is a mapping from keys in your dataset to keys in chat_template
role: role
content: content
```
In some `chat_templates` (e.g. [Gemma](https://huggingface.co/google/gemma-2b-it/blob/main/tokenizer_config.json#L1507)), the roles are hardcoded to `user` and `assistant`. Consequently, you may find it necessary to map the roles in your dataset to these. We currently have defaults that should work for common datasets, but if you get a `KeyError`, you will need to add a mapping for your roles. Here is an example of what it would look like:
```yaml
datasets:
...
roles:
assistant:
- gpt
- model
user:
- human
```
In the example above, all `gpt` and `model` values are converted to `assistant`. All `human` values are converted to `user`.
##### Handling masking
The common use case for `chat_template` is chat messages; therefore, it is common to mask all non-assistant messages. Assistant messages are the bot messages that you want the model to learn from.
To train on all `assistant` messages, set the following configs:
```yaml
datasets:
...
roles_to_train: ["assistant"]
train_on_eos: "turn"
```
With `train_on_eos: turn`, EOS tokens are masked for all turns that are not being trained on. The other options are `all` and `last`, which choose which EOS tokens to train on.
Perhaps you want to train on `assistant` and `narrator` roles; in that case, simply add `narrator` to the list of `roles_to_train`. You would also need to add it to the `roles` mapping above.
```yaml
datasets:
...
roles_to_train: ["assistant", "narrator"]
roles:
assistant:
- gpt
- model
user:
- human
narrator: ["narrator"]
```
#### Applying `chat_template`
Once all the above steps are completed, you could combine all these configs together to form a bespoke configuration for your custom dataset. The final step would be to correctly set the EOS token in your config:
```yaml
datasets:
- path: A.jsonl
type: chat_template
# step 1
chat_template: chatml
# step 2
field_messages: messages
message_property_mappings:
role: role
content: content
roles:
assistant:
- gpt
- model
- assistant
user:
- human
- user
# step 3
roles_to_train: ["assistant"]
train_on_eos: "turn"
special_tokens:
eos_token: <|im_end|>
```
If this config is applied to the sample dataset above, the output looks as follows (retrievable via `axolotl preprocess config.yaml --debug`):
```
<|im_start|>(-100, 128256) user(-100, 882)
(-100, 198) Hi(-100, 13347) <|im_end|>(-100, 128257)
(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)
(-100, 198) How(4438, 4438) can(649, 649) I(358, 358) help(1520, 1520) you(499, 499) ?(30, 30) <|im_end|>(128257, 128257)
(-100, 198) <|im_start|>(-100, 128256) user(-100, 882)
(-100, 198) Can(-100, 6854) you(-100, 499) add(-100, 923) (-100, 220) 3(-100, 18) +(-100, 10) 5(-100, 20) ?(-100, 30) <|im_end|>(-100, 128257)
(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)
(-100, 198) The(791, 791) answer(4320, 4320) is(374, 374) (220, 220) 8(23, 23) .(13, 13) <|im_end|>(128257, 128257)
(-100, 198)
```
The first number is the label, the second is the `token_id`. For example, `-100` labels appear on non-assistant portions, meaning they are masked during training. For assistant portions, the label is the same as the `token_id`.
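For background, `-100` is the default `ignore_index` of PyTorch's cross-entropy loss, so masked positions contribute nothing to the loss or gradients. A minimal sketch (illustrative shapes and values, not Axolotl code):
```python
import torch
import torch.nn.functional as F

# logits for 4 token positions over a tiny vocabulary of 8 tokens
logits = torch.randn(4, 8)
# the first two positions are labelled -100 and are skipped by the loss
labels = torch.tensor([-100, -100, 3, 7])

loss = F.cross_entropy(logits, labels)  # ignore_index defaults to -100
print(loss)
```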
### [Instruction Dataset](inst_tune.qmd)
Instruction datasets are used to train instruction-following models and comprise a prompt, containing an instruction, and a single response. In contrast to chat datasets which may be multi-turn, instruct datasets are typically single-turn.
An example of a common format is Alpaca:
```json
{"instruction": "...", "input": "...", "output": "..."}
```
Using those keys, a prompt can be built as follows:
```
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:
{output}
```
This can be configured as such:
```yaml
datasets:
- path: A.jsonl
type: alpaca
```
Axolotl supports many kinds of instruction datasets. All of them can be found [here](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/inst_tune.html) with their respective type and sample row format.
#### Custom Instruct Prompt Format
Due to the myriad possibilities of instruction formats, Axolotl allows you to define your own instruction format without having to dive into the code directly.
In the example below, a sample row is formatted into the `mistral_v1` format.
```json
{"input": "...", "output": "..."}
```
```yaml
datasets:
- path: repo
type:
system_prompt: ""
field_system:
field_instruction: input
field_input:
field_output: output
# multi-line example with input
format: |-
[INST] {instruction} {input} [/INST]
# single-line example without input
no_input_format: "[INST] {instruction} [/INST]"
```
The config specifies that `field_instruction` is actually named `input`, and that `field_input` is empty, as this sample has no separate input. Generally, `instruction` can be thought of as the question to the model, `input` as additional context, and `output` as the response. It is not necessary to have an `input` or a `system`. In the end, the most important part is to understand what you want the prompt to look like and how to customize it to your use case.
## Reinforcement Learning from Human Feedback (RLHF)
As there are multiple RLHF methods, each with its own dataset requirements, please see the [RLHF datasets](../rlhf.qmd) documentation for more detail.
Below are the various formats, organized by task:

View File

@@ -19,11 +19,3 @@ description: Frequently asked questions
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
> A: You may be using DeepSpeed with a single GPU. Please don't set `deepspeed:` in the YAML or CLI.
**Q: The code is stuck on saving preprocessed datasets.**
> A: This is usually an issue with the GPU. It can be resolved by setting the OS environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on RunPod, this is usually a pod issue; starting a new pod should take care of it.
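For example, before launching Axolotl:
```bash
export CUDA_VISIBLE_DEVICES=0
```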
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
> A: This means that the property mapping for the stated attribute does not exist when building `chat_template` prompt. For example, if `no attribute 'content'`, please check you have added the correct mapping for `content` under `message_property_mappings`.

View File

@@ -1,128 +0,0 @@
---
title: "LoRA Optimizations"
description: "Custom autograd functions and Triton kernels in Axolotl for optimized
LoRA fine-tuning"
---
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
to leverage operator fusion and tensor re-use in order to improve speed and reduce
memory usage during the forward and backward passes of these calculations.
We currently support several common model architectures, including (but not limited to):
- `llama`
- `mistral`
- `qwen2`
- `gemma`
- `gemma2`
<details>
The set of models we support is currently limited by our attention patching strategy,
which assumes (and replaces) specific code blocks for query / key / value and output
projections:
```python
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip(
"\n"
)
```
This is replaced with:
```python
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
PATCHED_O_CODE = """
attn_output = self.apply_o(attn_output)
""".lstrip(
"\n"
)
```
Where `apply_qkv` and `apply_o` are defined in the `axolotl.kernels.lora` module.
We welcome testing of other model architectures and / or PRs to expand our patching
logic to be compatible with more of them.
</details>
## Usage
These optimizations can be enabled in your Axolotl config YAML file. The
`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
`lora_o_kernel` enable the fused query-key-value projection and optimized output
projection, respectively.
```yaml
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
- Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
- Targeted LoRA adapters cannot use Dropout
- This may limit model expressivity / cause overfitting
- Targeted LoRA adapters cannot have bias terms
- This may limit model expressivity
Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
be re-finetuned without these features in order to be useful.
## Implementation details
### Custom autograd functions
The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the
LoRA and base weight computations together and provides a single, efficient backward
pass for the entire MLP block.
For attention components, similar optimizations are provided through a function that
handles the query, key, and value projections, and a function that handles the output
projection. They are designed to work with the existing `transformers` attention
implementation via some monkey-patching logic.
### Triton kernels
Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for
improved speed and memory performance. These kernels handle both the forward and
backward passes.
### Integration
The custom autograd functions and Triton kernels are designed to work together. The
autograd function manages the high-level computation flow and gradient tracking, while
calling the Triton kernels for the activation function computation. During the backward
pass, the kernel computes both the activation output and the required gradients, which
the autograd function then uses to compute the final gradients for the entire
computation path.
## Future Work
- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions

View File

@@ -3,18 +3,6 @@ title: Multi Node
description: How to use Axolotl on multiple machines
---
Below are three ways to train on multiple nodes with Axolotl.
::: {.callout-important}
Each machine needs a copy of Axolotl; we suggest using the same commit to ensure compatibility.
You will also need to have the same configuration file for your model on each machine.
Make sure the main machine is reachable by other machines.
:::
# Accelerate
You will need to create a configuration for accelerate, either by running `accelerate config` and following the instructions, or by using one of the presets below:
~/.cache/huggingface/accelerate/default_config.yaml
@@ -38,7 +26,7 @@ tpu_use_sudo: false
use_cpu: false
```
Configure your model to use FSDP in the Axolotl yaml. For example:
Configure your model to use FSDP with for example:
```yaml
fsdp:
- full_shard
@@ -49,40 +37,12 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Machine configuration
On each machine you need a copy of Axolotl; we suggest using the same commit to ensure compatibility.
You will also need to have the same configuration file for your model on each machine.
On the main machine only, make sure the port you set as `main_process_port` is open in TCP and reachable by other machines.
All you have to do now is launch with accelerate on each machine as you usually would, and voilà, the processes will start once accelerate has been launched on every machine.
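As a sketch (assuming the accelerate config above is in place on every node and `config.yaml` is your Axolotl config), the per-machine launch would look like:
```bash
accelerate launch -m axolotl.cli.train config.yaml
```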
# Ray Train
Please see the Ray Train doc [here](ray-integration.qmd).
# Torchrun
If you are using InfiniBand, we recommend torchrun to utilize the full bandwidth.
Set the following env (change buffersize/socketname depending on your system):
```bash
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
export NCCL_BUFFSIZE=2097152
```
Run the following on each node:
```bash
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
```
Please make sure to substitute the placeholder variables.
- `num_nodes`: Number of nodes (containing GPUs)
- `gpu_per_node`: Number of gpus per node
- `head_node_ip`: IP of the head node (make sure other machines can connect to this)
- `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
- `rdzv_id`: A unique job ID that is used by the job across nodes.
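For instance, on a hypothetical two-node cluster with 8 GPUs per node, a head node at `10.0.0.1:29400`, and job id `42`, each node would run:
```bash
torchrun --nnodes 2 --nproc_per_node 8 --rdzv_id 42 --rdzv_backend c10d --rdzv_endpoint "10.0.0.1:29400" -m axolotl.cli.train config.yaml
```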
::: {.callout-note}
You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
:::
More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)

View File

@@ -1,39 +1,26 @@
---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-depth: 3
---
# Overview
### Overview
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
feedback. Various methods include, but are not limited to:
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)
# RLHF using Axolotl
### RLHF using Axolotl
::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
:::
>[!IMPORTANT]
>This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap to expose in Axolotl. Each method has its own supported ways of loading datasets and prompt formats.
::: {.callout-tip}
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}` where `{method}` is one of our supported methods. The `type: ` can be retrieved from `{method}.{function_name}`.
:::
## DPO
Example config:
The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML
#### DPO
```yaml
rl: dpo
datasets:
@@ -45,265 +32,12 @@ datasets:
type: chatml
```
DPO supports the following types with the following dataset format:
### chatml.argilla
```json
{
"system": "...", // optional
"instruction": "...",
"chosen_response": "...",
"rejected_response": "..."
}
```
### chatml.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
### chatml.icr
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### chatml.intel
```json
{
"system": "...", // optional
"question": "...",
"chosen": "...",
"rejected": "..."
}
```
### chatml.prompt_pairs
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
```
### chatml.ultra
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
### llama3.argilla
```json
{
"system": "...", // optional
"instruction": "...",
"chosen_response": "...",
"rejected_response": "..."
}
```
### llama3.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
### llama3.icr
```json
{
"system": "...", // optional
"input": "...",
"chosen": "...",
"rejected": "..."
}
```
### llama3.intel
```json
{
"system": "...", // optional
"question": "...",
"chosen": "...",
"rejected": "..."
}
```
### llama3.prompt_pairs
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
```
### llama3.ultra
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
### zephyr.nectar
```json
{
"prompt": "...",
"answers": [
{
"answer": "...",
"rank": 1
},
{
"answer": "...",
"rank": 2
}
// ... more answers with ranks
]
}
```
### chat_template.default
```yaml
rl: dpo
datasets:
- path: ...
split: train
type: chat_template.default
field_messages: "messages"
field_chosen: "chosen"
field_rejected: "rejected"
message_property_mappings:
role: role
content: content
roles:
user: ["user"]
assistant: ["assistant"]
system: ["system"]
```
Sample input format:
```json
{
"messages": [
{
"role": "system",
"content": "..."
},
{
"role": "user",
"content": "..."
},
// ... more messages
],
"chosen": {
"role": "assistant",
"content": "..."
},
"rejected": {
"role": "assistant",
"content": "..."
}
}
```
### user_defined.default
For custom behaviors,
```yaml
rl: dpo
datasets:
- path: ...
split: train
type: user_defined.default
field_prompt: "prompt"
field_system: "system"
field_chosen: "chosen"
field_rejected: "rejected"
prompt_format: "{prompt}"
chosen_format: "{chosen}"
rejected_format: "{rejected}"
```
The input format is a simple JSON input with customizable fields based on the above config.
```json
{
"system": "...", // optional
"prompt": "...",
"chosen": "...",
"rejected": "..."
}
```
## IPO
As IPO is just DPO with a different loss function, all supported options for DPO work here.
#### IPO
```yaml
rl: ipo
```
## ORPO
#### ORPO
Paper: https://arxiv.org/abs/2403.07691
@@ -318,28 +52,8 @@ datasets:
type: chat_template.argilla
```
ORPO supports the following types with the following dataset format:
### chat_template.argilla
```json
{
"system": "...", // optional
"prompt": "...", // if available, will be taken as user message for single-turn instead of from list below
// chosen/rejected should be identical up to the last content, with an even number of alternating user/assistant turns
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
## KTO
#### KTO
```yaml
rl: kto
@@ -358,144 +72,7 @@ gradient_checkpointing_kwargs:
use_reentrant: true
```
KTO supports the following types with the following dataset format:
### chatml.argilla
```json
{
"system": "...", // optional
"instruction": "...",
"completion": "..."
}
```
### chatml.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."}
],
"completion": [
{"role": "assistant", "content": "..."}
]
}
```
### chatml.intel
```json
{
"system": "...", // optional
"question": "...",
"completion": "..."
}
```
### chatml.prompt_pairs
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
```
### chatml.ultra
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
```
### llama3.argilla
```json
{
"system": "...", // optional
"instruction": "...",
"completion": "..."
}
```
### llama3.argilla_chat
```json
{
"completion": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
### llama3.intel
```json
{
"system": "...", // optional
"question": "...",
"completion": "..."
}
```
### llama3.prompt_pairs
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
```
### llama3.ultra
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "..."
}
```
### user_defined.default
For custom behaviors,
```yaml
rl: kto
datasets:
- path: ...
split: train
type: user_defined.default
field_prompt: "prompt"
field_system: "system"
field_completion: "completion"
field_label: "label"
prompt_format: "{prompt}"
completion_format: "{completion}"
```
The input format is a simple JSON input with customizable fields based on the above config.
```json
{
"system": "...", // optional
"prompt": "...",
"completion": "...",
"label": "..."
}
```
## Using local dataset files
#### Using local dataset files
```yaml
datasets:
- ds_type: json
@@ -505,9 +82,9 @@ datasets:
type: chatml.intel
```
## TRL auto-unwrapping for PEFT
#### Trl autounwrap for peft
TRL supports auto-unwrapping PEFT models for RL training paradigms which rely on a reference model. This significantly reduces memory pressure as an additional reference model does not need to be loaded, and reference model log-probabilities can be obtained by disabling PEFT adapters. This is enabled by default. To turn it off, pass the following config:
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
```yaml
# load ref model when adapter training.

View File

@@ -21,9 +21,8 @@ datasets:
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -16,9 +16,8 @@ datasets:
type: chat_template
drop_system_message: true
field_messages: conversations
message_property_mappings:
role: from
content: value
message_field_role: from
message_field_content: value
val_set_size: 0.0
output_dir: ./outputs/out

View File

@@ -13,9 +13,8 @@ datasets:
type: chat_template
drop_system_message: true
field_messages: conversations
message_property_mappings:
role: from
content: value
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -17,9 +17,8 @@ datasets:
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.02

View File

@@ -17,9 +17,8 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_property_mappings:
role: role
content: content
message_field_role: role
message_field_content: content
roles:
system:
- system

View File

@@ -14,9 +14,8 @@ datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
message_field_role: role
message_field_content: content
roles:
user:
- user

View File

@@ -17,9 +17,8 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_property_mappings:
role: role
content: content
message_field_role: role
message_field_content: content
roles:
system:
- system
@@ -32,9 +31,8 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_property_mappings:
role: role
content: content
message_field_role: role
message_field_content: content
roles:
system:
- system

View File

@@ -1,82 +0,0 @@
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
# Currently, we don't support dropout with our custom Triton kernels
# lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
# These options enable our custom Triton kernels / autograd
# functions for MLP and attention calculations
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -22,9 +22,8 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_property_mappings:
role: role
content: content
message_field_role: role
message_field_content: content
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -14,9 +14,8 @@ datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
message_field_role: role
message_field_content: content
roles:
user:
- user

View File

@@ -12,9 +12,8 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_property_mappings:
role: role
content: content
message_field_role: role
message_field_content: content
roles:
system:
- system

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.2
bitsandbytes==0.45.1
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.4.post1
flash-attn==2.7.0.post2
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.2
@@ -13,12 +13,12 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
transformers==4.48.3
transformers==4.48.2
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.15.0
trl==0.14.0
optimum==1.16.2
hf_transfer

View File

@@ -31,26 +31,27 @@ def parse_dataset(dataset=None, split="train"):
ds_cfg["field_messages"] = field_messages
message_fields = features[field_messages][0].keys()
message_property_mappings = {"role": None, "content": None}
message_field_role = None
for key in ["from", "role"]:
if key in message_fields:
message_property_mappings["role"] = key
message_field_role = key
break
if not message_property_mappings["role"]:
if not message_field_role:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_role"] = message_field_role
message_field_content = None
for key in ["content", "text", "value"]:
if key in message_fields:
message_property_mappings["content"] = key
message_field_content = key
break
if not message_property_mappings["content"]:
if not message_field_content:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_property_mappings"] = message_property_mappings
ds_cfg["message_field_content"] = message_field_content
print(yaml.dump({"datasets": [ds_cfg]}))

View File

@@ -71,15 +71,12 @@ def parse_requirements():
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post2")
elif (major, minor) >= (2, 5):
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
_install_requires.append("xformers==0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
@@ -125,7 +122,7 @@ setup(
},
extras_require={
"flash-attn": [
"flash-attn==2.7.4.post1",
"flash-attn==2.7.0.post2",
],
"deepspeed": [
"deepspeed==0.16.1",
@@ -157,7 +154,7 @@ setup(
"ray[train]",
],
"vllm": [
"vllm==0.7.2",
"vllm>=0.7.1",
],
},
)

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.8.0.dev0"
__version__ = "0.6.0"

View File

@@ -35,18 +35,13 @@ def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
cwd=None,
**kwargs,
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
local_dirs = {}
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
local_dirs = {"/workspace/mounts": cwd}
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
cloud.train(config_yaml, accelerate=accelerate)
def do_cli_lm_eval(

View File

@@ -7,7 +7,6 @@ import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
from typing import Optional
import modal
@@ -23,18 +22,11 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
# modal workaround so it doesn't use the automounted axolotl
new_env = copy.deepcopy(os.environ)
if "PYTHONPATH" in new_env:
paths = ["/workspace/mounts"]
for sub_python_path_str in new_env["PYTHONPATH"].split(":"):
sub_python_path = Path(sub_python_path_str)
if not sub_python_path.joinpath("src", "axolotl").exists():
# we don't want to use the automounted axolotl or unexpected behavior happens
paths.append(str(sub_python_path))
if paths:
new_env["PYTHONPATH"] = ":".join(paths)
else:
del new_env["PYTHONPATH"]
# if "PYTHONPATH" in new_env:
# python_path = Path(new_env["PYTHONPATH"].split(":")[0])
# if python_path.joinpath("src", "axolotl").exists():
# # we don't want to use the automounted axolotl or unexpected behavior happens
# del new_env["PYTHONPATH"]
# Propagate errors from subprocess.
if exit_code := subprocess.call( # nosec B603
@@ -214,12 +206,9 @@ class ModalCloud(Cloud):
memory = int(self.config.memory)
return 1024 * memory
def get_train_env(self, local_dirs=None):
image = self.get_image()
for mount, local_dir in (local_dirs or {}).items():
image = image.add_local_dir(local_dir, mount)
def get_train_env(self):
return self.app.function(
image=image,
image=self.get_image(),
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=16.0,
gpu=self.get_train_gpu(),
@@ -228,21 +217,14 @@ class ModalCloud(Cloud):
secrets=self.get_secrets(),
)
def train(
self,
config_yaml: str,
accelerate: bool = True,
local_dirs: Optional[dict[str, str]] = None,
**kwargs,
):
modal_fn = self.get_train_env(local_dirs)(_train)
def train(self, config_yaml: str, accelerate: bool = True):
modal_fn = self.get_train_env()(_train)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
accelerate=accelerate,
volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs,
)
def lm_eval(self, config_yaml: str):
@@ -273,7 +255,7 @@ def _preprocess(config_yaml: str, volumes=None):
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
@@ -283,11 +265,8 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
accelerate_args = "--accelerate"
else:
accelerate_args = "--no-accelerate"
num_processes_args = ""
if num_processes := kwargs.pop("num_processes", None):
num_processes_args = f"--num-processes {num_processes}"
run_cmd(
f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)
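As an illustrative sketch (not from the diff) of the command assembled on the keyword-aware path above when num_processes=2 is passed through kwargs:
accelerate_args = "--accelerate"
num_processes_args = "--num-processes 2"  # hypothetical value passed through kwargs
cmd = f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml"
# -> "axolotl train --accelerate --num-processes 2 /workspace/artifacts/axolotl/config.yaml"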

View File

@@ -2,9 +2,11 @@
# pylint: disable=redefined-outer-name
import logging
import os
import random
import subprocess # nosec B404
import tempfile
from copy import deepcopy
from itertools import product
from pathlib import Path
from typing import Optional
@@ -14,7 +16,6 @@ from dotenv import load_dotenv
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
@@ -27,6 +28,76 @@ from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
def generate_sweep_configs(base_config, sweeps_config):
"""
Recursively generates all possible configurations by applying sweeps to the base config.
Args:
base_config (dict): The original configuration dictionary
sweeps_config (dict): Dictionary where keys are parameters and values are either:
- lists of values to sweep independently
- or for paired values, a list of dicts under the '_' key
Returns:
list: List of all possible configuration dictionaries
Example:
sweeps_config = {
'learning_rate': [0.1, 0.01],
'_': [
{'load_in_8bit': True, 'adapter': 'lora'},
{'load_in_4bit': True, 'adapter': 'qlora'}
]
}
"""
# Separate paired values from regular sweeps
paired_values = sweeps_config.get("_", [])
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
# Process regular sweeps
param_names = list(regular_sweeps.keys())
param_values = list(regular_sweeps.values())
# Generate combinations for regular sweeps
regular_combinations = list(product(*param_values)) if param_values else [()]
# Combine regular sweeps with paired values
all_combinations = []
for reg_combo in regular_combinations:
if paired_values:
for paired_set in paired_values:
new_config = {}
# new_config = deepcopy(base_config)
# Combine regular parameters with paired parameters
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
for param_name, param_value in full_combo.items():
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
else:
# If no paired values, just use regular combinations
# new_config = deepcopy(base_config)
new_config = {}
for param_name, param_value in zip(param_names, reg_combo):
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
# randomize the order of trials
random.seed(42)
random.shuffle(all_combinations)
# Generate a new config for each combination
result_configs = []
for combination in all_combinations:
new_config = deepcopy(base_config)
for param_name, param_value in combination.items():
new_config[param_name] = param_value
result_configs.append(new_config)
return result_configs
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
@@ -95,6 +166,7 @@ def train(
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
from axolotl.cli.cloud import do_cli_train
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
@@ -128,16 +200,7 @@ def train(
try:
if accelerate:
if cloud:
from axolotl.cli.cloud import do_cli_train
cwd = os.getcwd()
do_cli_train(
cloud_config=cloud,
config=config,
accelerate=True,
cwd=cwd,
**kwargs,
)
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
else:
accelerate_args = []
if "main_process_port" in kwargs:
@@ -158,11 +221,7 @@ def train(
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
from axolotl.cli.cloud import do_cli_train
do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
else:
from axolotl.cli.train import do_cli

View File

@@ -1,77 +0,0 @@
"""Utilities for handling sweeps over configs for axolotl train CLI command"""
import random
from copy import deepcopy
from itertools import product
def generate_sweep_configs(
base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, list]]:
"""
Recursively generates all possible configurations by applying sweeps to the base config.
Args:
base_config (dict): The original configuration dictionary
sweeps_config (dict): Dictionary where keys are parameters and values are either:
- lists of values to sweep independently
- or for paired values, a list of dicts under the '_' key
Returns:
list: List of all possible configuration dictionaries
Example:
sweeps_config = {
'learning_rate': [0.1, 0.01],
'_': [
{'load_in_8bit': True, 'adapter': 'lora'},
{'load_in_4bit': True, 'adapter': 'qlora'}
]
}
"""
# Separate paired values from regular sweeps
paired_values = sweeps_config.get("_", [])
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
# Process regular sweeps
param_names = list(regular_sweeps.keys())
param_values = list(regular_sweeps.values())
# Generate combinations for regular sweeps
regular_combinations = list(product(*param_values)) if param_values else [()]
# Combine regular sweeps with paired values
all_combinations = []
for reg_combo in regular_combinations:
if paired_values:
for paired_set in paired_values:
new_config = {}
# new_config = deepcopy(base_config)
# Combine regular parameters with paired parameters
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
for param_name, param_value in full_combo.items():
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
else:
# If no paired values, just use regular combinations
# new_config = deepcopy(base_config)
new_config = {}
for param_name, param_value in zip(param_names, reg_combo):
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
# randomize the order of trials
random.seed(42)
random.shuffle(all_combinations)
# Generate a new config for each combination
result_configs = []
for combination in all_combinations:
new_config = deepcopy(base_config)
for param_name, param_value in combination.items():
new_config[param_name] = param_value
result_configs.append(new_config)
return result_configs
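A usage sketch (illustrative, not from the diff) of the paired-sweep behaviour documented in the docstring above; the base config values are hypothetical:
base_config = {"base_model": "org/model", "learning_rate": 2e-5}  # hypothetical
sweeps_config = {
    "learning_rate": [0.1, 0.01],
    "_": [
        {"load_in_8bit": True, "adapter": "lora"},
        {"load_in_4bit": True, "adapter": "qlora"},
    ],
}
configs = generate_sweep_configs(base_config, sweeps_config)
# 2 learning rates x 2 paired sets = 4 configs, each a deep copy of base_config
# with the swept values applied, shuffled deterministically (seed 42)
assert len(configs) == 4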

View File

@@ -122,11 +122,9 @@ def load_preference_datasets(
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps: Optional[int] = int(
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl == "grpo":
total_num_steps = None
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
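A worked example (illustrative, not from the diff) of the step computation above, with hypothetical sizes:
import math
# 1,000 preference pairs, 3 epochs, batch size 8 (hypothetical values)
total_num_steps = int(math.ceil(1000 * 3 / 8))  # -> 375
# on the GRPO branch shown above, this is then set to None so the trainer derives its own step count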

View File

@@ -330,12 +330,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_arguments_kwargs = {}
if self.cfg.include_tokens_per_second is not None:
training_arguments_kwargs[
"include_tokens_per_second"
] = self.cfg.include_tokens_per_second
if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
@@ -648,6 +642,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
tokenizer=self.tokenizer,
)
if self.cfg.rl == "orpo":
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
"neftune_noise_alpha"
@@ -1035,12 +1032,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
max_steps = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=max_steps,
max_steps=self.cfg.max_steps or total_num_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
warmup_steps=self.cfg.warmup_steps,
@@ -1070,7 +1065,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model]
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = DPOStrategy.get_trainer_class()
@@ -1088,10 +1082,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys():
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
else:
if "processing_class" in sig.parameters.keys():
dpo_trainer_kwargs["processing_class"] = self.tokenizer
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
if self.cfg.datasets is not None and (
trainer_cls is DPOStrategy.get_trainer_class()

View File

@@ -54,22 +54,16 @@ class GRPOStrategy:
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
return grpo_args_kwargs
@classmethod
def set_trainer_args(cls, cfg):
trainer_args = []
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
for reward_func_fqn in cfg.trl.reward_funcs:
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
trainer_args.append(reward_funcs)
return trainer_args
@classmethod
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
for reward_func_fqn in cfg.trl.reward_funcs:
reward_funcs.append(cls.get_reward_func(reward_func_fqn))
trainer_kwargs["reward_funcs"] = reward_funcs
if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[
"reward_processing_classes"
@@ -114,6 +108,6 @@ class GRPOStrategy:
except ModuleNotFoundError:
# the user has passed a string (ideally indicating the path of a reward model)
LOG.info(
f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
f"Reward function {reward_func} is a pre-trained model path - if this is unexpected, please check the reward function path."
)
return reward_func
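For context, a minimal sketch (illustrative, not from the diff) of a custom reward function that cfg.trl.reward_funcs could reference by dotted path; the module and function names are hypothetical, and the signature assumes TRL's GRPO convention of returning one float per completion:
# my_rewards.py -- hypothetical module, referenced as `my_rewards.length_penalty` in the trl reward_funcs list
def length_penalty(completions, **kwargs):
    """Toy reward: favour shorter completions, one float per completion."""
    return [-float(len(str(c))) for c in completions]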

View File

@@ -1,107 +1,14 @@
"""
Axolotl GRPO trainer
"""
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel
from trl import GRPOConfig, GRPOTrainer
from trl.models import unwrap_model_for_generation
from trl import GRPOTrainer
from axolotl.core.trainers.base import SchedulerMixin
# mypy: ignore-errors
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""
_tag_names = ["trl", "grpo", "axolotl"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# pylint: disable=access-member-before-definition
# Enable gradient checkpointing if requested
if kwargs["args"].gradient_checkpointing:
# Ensure use_cache is disabled
if hasattr(self.model, "config"):
self.model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(self.model) and hasattr(
self.model.base_model, "gradient_checkpointing_enable"
):
self.model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
elif hasattr(self.model, "gradient_checkpointing_enable"):
self.model.gradient_checkpointing_enable()
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
# pylint: enable=access-member-before-definition
def _enable_gradient_checkpointing(
self, model: PreTrainedModel, args: GRPOConfig
) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# pylint: disable=unused-argument,redefined-builtin
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs
or gradient_checkpointing_kwargs["use_reentrant"]
)
if use_reentrant:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(
make_inputs_require_grad
)
return model
# pylint: enable=unused-argument,redefined-builtin
def _move_model_to_vllm(self):
with unwrap_model_for_generation(
self.model,
self.accelerator,
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
unwrapped_model = (
unwrapped_model._orig_mod # pylint: disable=protected-access
)
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", ""): v
for k, v in state_dict.items()
}
# Remove values with adapter prefix (example: "_lora")
state_dict = {
k: v
for k, v in state_dict.items()
if unwrapped_model.prefix not in k
}
# When module to save, remove its prefix and discard the original module
state_dict = {
k.replace("modules_to_save.default.", ""): v
for k, v in state_dict.items()
if "original_module" not in k
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
llm_model.load_weights(state_dict.items())
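To make the key rewriting above concrete, a small illustrative example (not from the diff) of how one merged-PEFT parameter name maps back to its base-model name:
key = "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight"
clean = key.removeprefix("base_model.model.").replace(".base_layer", "")
# clean == "model.layers.0.self_attn.q_proj.weight"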

View File

@@ -1,590 +0,0 @@
{
"model.layers.0.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.1.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.2.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.3.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.4.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.5.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.6.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.7.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.8.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.9.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.10.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.11.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.12.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.13.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.14.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.15.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"lm_head": {
"snr": Infinity,
"type": "lm_head"
},
"model.layers.0.mlp.down_proj": {
"snr": 70.0594253540039,
"type": "mlp.down_proj"
},
"model.layers.1.mlp.down_proj": {
"snr": 11.135851860046387,
"type": "mlp.down_proj"
},
"model.layers.2.mlp.down_proj": {
"snr": 7.035482883453369,
"type": "mlp.down_proj"
},
"model.layers.3.mlp.down_proj": {
"snr": 6.422532081604004,
"type": "mlp.down_proj"
},
"model.layers.4.mlp.down_proj": {
"snr": 5.748020172119141,
"type": "mlp.down_proj"
},
"model.layers.5.mlp.down_proj": {
"snr": 3.885556697845459,
"type": "mlp.down_proj"
},
"model.layers.6.mlp.down_proj": {
"snr": 3.4336745738983154,
"type": "mlp.down_proj"
},
"model.layers.7.mlp.down_proj": {
"snr": 2.791595935821533,
"type": "mlp.down_proj"
},
"model.layers.8.mlp.down_proj": {
"snr": 5.36277961730957,
"type": "mlp.down_proj"
},
"model.layers.9.mlp.down_proj": {
"snr": 4.459208011627197,
"type": "mlp.down_proj"
},
"model.layers.10.mlp.down_proj": {
"snr": 6.272170066833496,
"type": "mlp.down_proj"
},
"model.layers.11.mlp.down_proj": {
"snr": 5.264761447906494,
"type": "mlp.down_proj"
},
"model.layers.12.mlp.down_proj": {
"snr": 4.324735641479492,
"type": "mlp.down_proj"
},
"model.layers.13.mlp.down_proj": {
"snr": 3.878648042678833,
"type": "mlp.down_proj"
},
"model.layers.14.mlp.down_proj": {
"snr": 2.9773054122924805,
"type": "mlp.down_proj"
},
"model.layers.15.mlp.down_proj": {
"snr": 4.471445560455322,
"type": "mlp.down_proj"
},
"model.layers.0.mlp.gate_proj": {
"snr": 25.227100372314453,
"type": "mlp.gate_proj"
},
"model.layers.1.mlp.gate_proj": {
"snr": 6.58299446105957,
"type": "mlp.gate_proj"
},
"model.layers.2.mlp.gate_proj": {
"snr": 3.4688243865966797,
"type": "mlp.gate_proj"
},
"model.layers.3.mlp.gate_proj": {
"snr": 1.555246114730835,
"type": "mlp.gate_proj"
},
"model.layers.4.mlp.gate_proj": {
"snr": 0.7770601511001587,
"type": "mlp.gate_proj"
},
"model.layers.5.mlp.gate_proj": {
"snr": 0.6239906549453735,
"type": "mlp.gate_proj"
},
"model.layers.6.mlp.gate_proj": {
"snr": 0.6440379023551941,
"type": "mlp.gate_proj"
},
"model.layers.7.mlp.gate_proj": {
"snr": 0.5120116472244263,
"type": "mlp.gate_proj"
},
"model.layers.8.mlp.gate_proj": {
"snr": 0.6544050574302673,
"type": "mlp.gate_proj"
},
"model.layers.9.mlp.gate_proj": {
"snr": 0.5381016731262207,
"type": "mlp.gate_proj"
},
"model.layers.10.mlp.gate_proj": {
"snr": 0.622873842716217,
"type": "mlp.gate_proj"
},
"model.layers.11.mlp.gate_proj": {
"snr": 0.9361700415611267,
"type": "mlp.gate_proj"
},
"model.layers.12.mlp.gate_proj": {
"snr": 1.475605845451355,
"type": "mlp.gate_proj"
},
"model.layers.13.mlp.gate_proj": {
"snr": 1.608325719833374,
"type": "mlp.gate_proj"
},
"model.layers.14.mlp.gate_proj": {
"snr": 1.0720024108886719,
"type": "mlp.gate_proj"
},
"model.layers.15.mlp.gate_proj": {
"snr": 0.7111338973045349,
"type": "mlp.gate_proj"
},
"model.layers.0.mlp.up_proj": {
"snr": 28.431896209716797,
"type": "mlp.up_proj"
},
"model.layers.1.mlp.up_proj": {
"snr": 15.546019554138184,
"type": "mlp.up_proj"
},
"model.layers.2.mlp.up_proj": {
"snr": 23.048023223876953,
"type": "mlp.up_proj"
},
"model.layers.3.mlp.up_proj": {
"snr": 25.790977478027344,
"type": "mlp.up_proj"
},
"model.layers.4.mlp.up_proj": {
"snr": 18.552549362182617,
"type": "mlp.up_proj"
},
"model.layers.5.mlp.up_proj": {
"snr": 8.85106372833252,
"type": "mlp.up_proj"
},
"model.layers.6.mlp.up_proj": {
"snr": 10.653799057006836,
"type": "mlp.up_proj"
},
"model.layers.7.mlp.up_proj": {
"snr": 7.365357875823975,
"type": "mlp.up_proj"
},
"model.layers.8.mlp.up_proj": {
"snr": 11.98373794555664,
"type": "mlp.up_proj"
},
"model.layers.9.mlp.up_proj": {
"snr": 8.04493236541748,
"type": "mlp.up_proj"
},
"model.layers.10.mlp.up_proj": {
"snr": 8.523039817810059,
"type": "mlp.up_proj"
},
"model.layers.11.mlp.up_proj": {
"snr": 5.381742477416992,
"type": "mlp.up_proj"
},
"model.layers.12.mlp.up_proj": {
"snr": 3.9845118522644043,
"type": "mlp.up_proj"
},
"model.layers.13.mlp.up_proj": {
"snr": 3.4893221855163574,
"type": "mlp.up_proj"
},
"model.layers.14.mlp.up_proj": {
"snr": 1.764201045036316,
"type": "mlp.up_proj"
},
"model.layers.15.mlp.up_proj": {
"snr": 0.9730708599090576,
"type": "mlp.up_proj"
},
"model.embed_tokens": {
"snr": Infinity,
"type": "model.embed_tokens"
},
"model.norm": {
"snr": Infinity,
"type": "model.norm"
},
"model.layers.0.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.1.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.2.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.3.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.4.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.5.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.6.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.7.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.8.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.9.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.10.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.11.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.12.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.13.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.14.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.15.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.0.self_attn.k_proj": {
"snr": 0.11727584153413773,
"type": "self_attn.k_proj"
},
"model.layers.1.self_attn.k_proj": {
"snr": 0.24786807596683502,
"type": "self_attn.k_proj"
},
"model.layers.2.self_attn.k_proj": {
"snr": 0.36378130316734314,
"type": "self_attn.k_proj"
},
"model.layers.3.self_attn.k_proj": {
"snr": 0.2983120381832123,
"type": "self_attn.k_proj"
},
"model.layers.4.self_attn.k_proj": {
"snr": 0.33789733052253723,
"type": "self_attn.k_proj"
},
"model.layers.5.self_attn.k_proj": {
"snr": 0.29155924916267395,
"type": "self_attn.k_proj"
},
"model.layers.6.self_attn.k_proj": {
"snr": 0.2537297010421753,
"type": "self_attn.k_proj"
},
"model.layers.7.self_attn.k_proj": {
"snr": 0.28204113245010376,
"type": "self_attn.k_proj"
},
"model.layers.8.self_attn.k_proj": {
"snr": 0.2776711583137512,
"type": "self_attn.k_proj"
},
"model.layers.9.self_attn.k_proj": {
"snr": 0.2927376627922058,
"type": "self_attn.k_proj"
},
"model.layers.10.self_attn.k_proj": {
"snr": 0.31486213207244873,
"type": "self_attn.k_proj"
},
"model.layers.11.self_attn.k_proj": {
"snr": 0.32363659143447876,
"type": "self_attn.k_proj"
},
"model.layers.12.self_attn.k_proj": {
"snr": 0.31382912397384644,
"type": "self_attn.k_proj"
},
"model.layers.13.self_attn.k_proj": {
"snr": 0.4635234773159027,
"type": "self_attn.k_proj"
},
"model.layers.14.self_attn.k_proj": {
"snr": 0.25379249453544617,
"type": "self_attn.k_proj"
},
"model.layers.15.self_attn.k_proj": {
"snr": 0.2628238797187805,
"type": "self_attn.k_proj"
},
"model.layers.0.self_attn.o_proj": {
"snr": 0.27602291107177734,
"type": "self_attn.o_proj"
},
"model.layers.1.self_attn.o_proj": {
"snr": 0.2149604707956314,
"type": "self_attn.o_proj"
},
"model.layers.2.self_attn.o_proj": {
"snr": 0.2540294826030731,
"type": "self_attn.o_proj"
},
"model.layers.3.self_attn.o_proj": {
"snr": 0.27978822588920593,
"type": "self_attn.o_proj"
},
"model.layers.4.self_attn.o_proj": {
"snr": 0.3121289908885956,
"type": "self_attn.o_proj"
},
"model.layers.5.self_attn.o_proj": {
"snr": 0.35037684440612793,
"type": "self_attn.o_proj"
},
"model.layers.6.self_attn.o_proj": {
"snr": 0.366205096244812,
"type": "self_attn.o_proj"
},
"model.layers.7.self_attn.o_proj": {
"snr": 0.3692712187767029,
"type": "self_attn.o_proj"
},
"model.layers.8.self_attn.o_proj": {
"snr": 0.3301038146018982,
"type": "self_attn.o_proj"
},
"model.layers.9.self_attn.o_proj": {
"snr": 0.3003396987915039,
"type": "self_attn.o_proj"
},
"model.layers.10.self_attn.o_proj": {
"snr": 0.30804169178009033,
"type": "self_attn.o_proj"
},
"model.layers.11.self_attn.o_proj": {
"snr": 0.28501132130622864,
"type": "self_attn.o_proj"
},
"model.layers.12.self_attn.o_proj": {
"snr": 0.2171541005373001,
"type": "self_attn.o_proj"
},
"model.layers.13.self_attn.o_proj": {
"snr": 0.19183959066867828,
"type": "self_attn.o_proj"
},
"model.layers.14.self_attn.o_proj": {
"snr": 0.19215913116931915,
"type": "self_attn.o_proj"
},
"model.layers.15.self_attn.o_proj": {
"snr": 0.25486502051353455,
"type": "self_attn.o_proj"
},
"model.layers.0.self_attn.q_proj": {
"snr": 0.03850084915757179,
"type": "self_attn.q_proj"
},
"model.layers.1.self_attn.q_proj": {
"snr": 0.0713055431842804,
"type": "self_attn.q_proj"
},
"model.layers.2.self_attn.q_proj": {
"snr": 0.07948919385671616,
"type": "self_attn.q_proj"
},
"model.layers.3.self_attn.q_proj": {
"snr": 0.08047746121883392,
"type": "self_attn.q_proj"
},
"model.layers.4.self_attn.q_proj": {
"snr": 0.0852593332529068,
"type": "self_attn.q_proj"
},
"model.layers.5.self_attn.q_proj": {
"snr": 0.09794823825359344,
"type": "self_attn.q_proj"
},
"model.layers.6.self_attn.q_proj": {
"snr": 0.09627152234315872,
"type": "self_attn.q_proj"
},
"model.layers.7.self_attn.q_proj": {
"snr": 0.11065381020307541,
"type": "self_attn.q_proj"
},
"model.layers.8.self_attn.q_proj": {
"snr": 0.12031875550746918,
"type": "self_attn.q_proj"
},
"model.layers.9.self_attn.q_proj": {
"snr": 0.09804573655128479,
"type": "self_attn.q_proj"
},
"model.layers.10.self_attn.q_proj": {
"snr": 0.10897502303123474,
"type": "self_attn.q_proj"
},
"model.layers.11.self_attn.q_proj": {
"snr": 0.09267337620258331,
"type": "self_attn.q_proj"
},
"model.layers.12.self_attn.q_proj": {
"snr": 0.08803492039442062,
"type": "self_attn.q_proj"
},
"model.layers.13.self_attn.q_proj": {
"snr": 0.0902542844414711,
"type": "self_attn.q_proj"
},
"model.layers.14.self_attn.q_proj": {
"snr": 0.10154066979885101,
"type": "self_attn.q_proj"
},
"model.layers.15.self_attn.q_proj": {
"snr": 0.09083802253007889,
"type": "self_attn.q_proj"
},
"model.layers.0.self_attn.v_proj": {
"snr": 2.842210054397583,
"type": "self_attn.v_proj"
},
"model.layers.1.self_attn.v_proj": {
"snr": 10.59461498260498,
"type": "self_attn.v_proj"
},
"model.layers.2.self_attn.v_proj": {
"snr": 8.993025779724121,
"type": "self_attn.v_proj"
},
"model.layers.3.self_attn.v_proj": {
"snr": 62.567787170410156,
"type": "self_attn.v_proj"
},
"model.layers.4.self_attn.v_proj": {
"snr": 23.80082893371582,
"type": "self_attn.v_proj"
},
"model.layers.5.self_attn.v_proj": {
"snr": 7.957369804382324,
"type": "self_attn.v_proj"
},
"model.layers.6.self_attn.v_proj": {
"snr": 12.01815414428711,
"type": "self_attn.v_proj"
},
"model.layers.7.self_attn.v_proj": {
"snr": 5.095500469207764,
"type": "self_attn.v_proj"
},
"model.layers.8.self_attn.v_proj": {
"snr": 11.719332695007324,
"type": "self_attn.v_proj"
},
"model.layers.9.self_attn.v_proj": {
"snr": 555.0869750976562,
"type": "self_attn.v_proj"
},
"model.layers.10.self_attn.v_proj": {
"snr": 22.95538330078125,
"type": "self_attn.v_proj"
},
"model.layers.11.self_attn.v_proj": {
"snr": 30.042158126831055,
"type": "self_attn.v_proj"
},
"model.layers.12.self_attn.v_proj": {
"snr": 9.577271461486816,
"type": "self_attn.v_proj"
},
"model.layers.13.self_attn.v_proj": {
"snr": 18.176361083984375,
"type": "self_attn.v_proj"
},
"model.layers.14.self_attn.v_proj": {
"snr": 1.5695856809616089,
"type": "self_attn.v_proj"
},
"model.layers.15.self_attn.v_proj": {
"snr": 2.7235565185546875,
"type": "self_attn.v_proj"
}
}

View File

@@ -1,590 +0,0 @@
{
"model.layers.0.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.1.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.2.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.3.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.4.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.5.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.6.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.7.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.8.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.9.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.10.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.11.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.12.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.13.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.14.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.15.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"lm_head": {
"snr": Infinity,
"type": "lm_head"
},
"model.layers.0.mlp.down_proj": {
"snr": 57.09797286987305,
"type": "mlp.down_proj"
},
"model.layers.1.mlp.down_proj": {
"snr": 9.538983345031738,
"type": "mlp.down_proj"
},
"model.layers.2.mlp.down_proj": {
"snr": 6.227016925811768,
"type": "mlp.down_proj"
},
"model.layers.3.mlp.down_proj": {
"snr": 5.660686492919922,
"type": "mlp.down_proj"
},
"model.layers.4.mlp.down_proj": {
"snr": 5.178432464599609,
"type": "mlp.down_proj"
},
"model.layers.5.mlp.down_proj": {
"snr": 3.5638349056243896,
"type": "mlp.down_proj"
},
"model.layers.6.mlp.down_proj": {
"snr": 3.0918056964874268,
"type": "mlp.down_proj"
},
"model.layers.7.mlp.down_proj": {
"snr": 2.456392288208008,
"type": "mlp.down_proj"
},
"model.layers.8.mlp.down_proj": {
"snr": 4.525328636169434,
"type": "mlp.down_proj"
},
"model.layers.9.mlp.down_proj": {
"snr": 3.9409055709838867,
"type": "mlp.down_proj"
},
"model.layers.10.mlp.down_proj": {
"snr": 5.447249412536621,
"type": "mlp.down_proj"
},
"model.layers.11.mlp.down_proj": {
"snr": 4.807600975036621,
"type": "mlp.down_proj"
},
"model.layers.12.mlp.down_proj": {
"snr": 3.915374517440796,
"type": "mlp.down_proj"
},
"model.layers.13.mlp.down_proj": {
"snr": 3.4820363521575928,
"type": "mlp.down_proj"
},
"model.layers.14.mlp.down_proj": {
"snr": 2.6045074462890625,
"type": "mlp.down_proj"
},
"model.layers.15.mlp.down_proj": {
"snr": 3.7237701416015625,
"type": "mlp.down_proj"
},
"model.layers.0.mlp.gate_proj": {
"snr": 22.160131454467773,
"type": "mlp.gate_proj"
},
"model.layers.1.mlp.gate_proj": {
"snr": 6.072206020355225,
"type": "mlp.gate_proj"
},
"model.layers.2.mlp.gate_proj": {
"snr": 3.2467362880706787,
"type": "mlp.gate_proj"
},
"model.layers.3.mlp.gate_proj": {
"snr": 1.4111896753311157,
"type": "mlp.gate_proj"
},
"model.layers.4.mlp.gate_proj": {
"snr": 0.7405938506126404,
"type": "mlp.gate_proj"
},
"model.layers.5.mlp.gate_proj": {
"snr": 0.5916463136672974,
"type": "mlp.gate_proj"
},
"model.layers.6.mlp.gate_proj": {
"snr": 0.6149423718452454,
"type": "mlp.gate_proj"
},
"model.layers.7.mlp.gate_proj": {
"snr": 0.48369669914245605,
"type": "mlp.gate_proj"
},
"model.layers.8.mlp.gate_proj": {
"snr": 0.6047574877738953,
"type": "mlp.gate_proj"
},
"model.layers.9.mlp.gate_proj": {
"snr": 0.5092479586601257,
"type": "mlp.gate_proj"
},
"model.layers.10.mlp.gate_proj": {
"snr": 0.5999670624732971,
"type": "mlp.gate_proj"
},
"model.layers.11.mlp.gate_proj": {
"snr": 0.8980127573013306,
"type": "mlp.gate_proj"
},
"model.layers.12.mlp.gate_proj": {
"snr": 1.4252448081970215,
"type": "mlp.gate_proj"
},
"model.layers.13.mlp.gate_proj": {
"snr": 1.509937047958374,
"type": "mlp.gate_proj"
},
"model.layers.14.mlp.gate_proj": {
"snr": 1.0066585540771484,
"type": "mlp.gate_proj"
},
"model.layers.15.mlp.gate_proj": {
"snr": 0.6413647532463074,
"type": "mlp.gate_proj"
},
"model.layers.0.mlp.up_proj": {
"snr": 26.08852195739746,
"type": "mlp.up_proj"
},
"model.layers.1.mlp.up_proj": {
"snr": 13.382951736450195,
"type": "mlp.up_proj"
},
"model.layers.2.mlp.up_proj": {
"snr": 20.088768005371094,
"type": "mlp.up_proj"
},
"model.layers.3.mlp.up_proj": {
"snr": 23.0632381439209,
"type": "mlp.up_proj"
},
"model.layers.4.mlp.up_proj": {
"snr": 16.07433319091797,
"type": "mlp.up_proj"
},
"model.layers.5.mlp.up_proj": {
"snr": 8.00507640838623,
"type": "mlp.up_proj"
},
"model.layers.6.mlp.up_proj": {
"snr": 9.538354873657227,
"type": "mlp.up_proj"
},
"model.layers.7.mlp.up_proj": {
"snr": 6.286602973937988,
"type": "mlp.up_proj"
},
"model.layers.8.mlp.up_proj": {
"snr": 10.092820167541504,
"type": "mlp.up_proj"
},
"model.layers.9.mlp.up_proj": {
"snr": 7.193963527679443,
"type": "mlp.up_proj"
},
"model.layers.10.mlp.up_proj": {
"snr": 7.320116996765137,
"type": "mlp.up_proj"
},
"model.layers.11.mlp.up_proj": {
"snr": 4.8728532791137695,
"type": "mlp.up_proj"
},
"model.layers.12.mlp.up_proj": {
"snr": 3.596583366394043,
"type": "mlp.up_proj"
},
"model.layers.13.mlp.up_proj": {
"snr": 3.166161298751831,
"type": "mlp.up_proj"
},
"model.layers.14.mlp.up_proj": {
"snr": 1.5600818395614624,
"type": "mlp.up_proj"
},
"model.layers.15.mlp.up_proj": {
"snr": 0.8726214170455933,
"type": "mlp.up_proj"
},
"model.embed_tokens": {
"snr": Infinity,
"type": "model.embed_tokens"
},
"model.norm": {
"snr": Infinity,
"type": "model.norm"
},
"model.layers.0.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.1.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.2.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.3.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.4.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.5.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.6.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.7.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.8.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.9.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.10.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.11.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.12.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.13.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.14.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.15.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.0.self_attn.k_proj": {
"snr": 0.1154392883181572,
"type": "self_attn.k_proj"
},
"model.layers.1.self_attn.k_proj": {
"snr": 0.24299409985542297,
"type": "self_attn.k_proj"
},
"model.layers.2.self_attn.k_proj": {
"snr": 0.3624322712421417,
"type": "self_attn.k_proj"
},
"model.layers.3.self_attn.k_proj": {
"snr": 0.29509487748146057,
"type": "self_attn.k_proj"
},
"model.layers.4.self_attn.k_proj": {
"snr": 0.32953736186027527,
"type": "self_attn.k_proj"
},
"model.layers.5.self_attn.k_proj": {
"snr": 0.2908833622932434,
"type": "self_attn.k_proj"
},
"model.layers.6.self_attn.k_proj": {
"snr": 0.2488437294960022,
"type": "self_attn.k_proj"
},
"model.layers.7.self_attn.k_proj": {
"snr": 0.27847856283187866,
"type": "self_attn.k_proj"
},
"model.layers.8.self_attn.k_proj": {
"snr": 0.27143892645835876,
"type": "self_attn.k_proj"
},
"model.layers.9.self_attn.k_proj": {
"snr": 0.28804272413253784,
"type": "self_attn.k_proj"
},
"model.layers.10.self_attn.k_proj": {
"snr": 0.31197959184646606,
"type": "self_attn.k_proj"
},
"model.layers.11.self_attn.k_proj": {
"snr": 0.3203586935997009,
"type": "self_attn.k_proj"
},
"model.layers.12.self_attn.k_proj": {
"snr": 0.30905747413635254,
"type": "self_attn.k_proj"
},
"model.layers.13.self_attn.k_proj": {
"snr": 0.46828722953796387,
"type": "self_attn.k_proj"
},
"model.layers.14.self_attn.k_proj": {
"snr": 0.24205778539180756,
"type": "self_attn.k_proj"
},
"model.layers.15.self_attn.k_proj": {
"snr": 0.2559327781200409,
"type": "self_attn.k_proj"
},
"model.layers.0.self_attn.o_proj": {
"snr": 0.2638678550720215,
"type": "self_attn.o_proj"
},
"model.layers.1.self_attn.o_proj": {
"snr": 0.21109595894813538,
"type": "self_attn.o_proj"
},
"model.layers.2.self_attn.o_proj": {
"snr": 0.24751724302768707,
"type": "self_attn.o_proj"
},
"model.layers.3.self_attn.o_proj": {
"snr": 0.2728094160556793,
"type": "self_attn.o_proj"
},
"model.layers.4.self_attn.o_proj": {
"snr": 0.3001374304294586,
"type": "self_attn.o_proj"
},
"model.layers.5.self_attn.o_proj": {
"snr": 0.33903488516807556,
"type": "self_attn.o_proj"
},
"model.layers.6.self_attn.o_proj": {
"snr": 0.3530929982662201,
"type": "self_attn.o_proj"
},
"model.layers.7.self_attn.o_proj": {
"snr": 0.36753255128860474,
"type": "self_attn.o_proj"
},
"model.layers.8.self_attn.o_proj": {
"snr": 0.3373180329799652,
"type": "self_attn.o_proj"
},
"model.layers.9.self_attn.o_proj": {
"snr": 0.2970578670501709,
"type": "self_attn.o_proj"
},
"model.layers.10.self_attn.o_proj": {
"snr": 0.3076324760913849,
"type": "self_attn.o_proj"
},
"model.layers.11.self_attn.o_proj": {
"snr": 0.2766900658607483,
"type": "self_attn.o_proj"
},
"model.layers.12.self_attn.o_proj": {
"snr": 0.20973259210586548,
"type": "self_attn.o_proj"
},
"model.layers.13.self_attn.o_proj": {
"snr": 0.18185566365718842,
"type": "self_attn.o_proj"
},
"model.layers.14.self_attn.o_proj": {
"snr": 0.18329747021198273,
"type": "self_attn.o_proj"
},
"model.layers.15.self_attn.o_proj": {
"snr": 0.2437991499900818,
"type": "self_attn.o_proj"
},
"model.layers.0.self_attn.q_proj": {
"snr": 0.038040731102228165,
"type": "self_attn.q_proj"
},
"model.layers.1.self_attn.q_proj": {
"snr": 0.0707998052239418,
"type": "self_attn.q_proj"
},
"model.layers.2.self_attn.q_proj": {
"snr": 0.0787411704659462,
"type": "self_attn.q_proj"
},
"model.layers.3.self_attn.q_proj": {
"snr": 0.08089710026979446,
"type": "self_attn.q_proj"
},
"model.layers.4.self_attn.q_proj": {
"snr": 0.08591937273740768,
"type": "self_attn.q_proj"
},
"model.layers.5.self_attn.q_proj": {
"snr": 0.09852176159620285,
"type": "self_attn.q_proj"
},
"model.layers.6.self_attn.q_proj": {
"snr": 0.09690654277801514,
"type": "self_attn.q_proj"
},
"model.layers.7.self_attn.q_proj": {
"snr": 0.11181341856718063,
"type": "self_attn.q_proj"
},
"model.layers.8.self_attn.q_proj": {
"snr": 0.12042108923196793,
"type": "self_attn.q_proj"
},
"model.layers.9.self_attn.q_proj": {
"snr": 0.09799323976039886,
"type": "self_attn.q_proj"
},
"model.layers.10.self_attn.q_proj": {
"snr": 0.10901063680648804,
"type": "self_attn.q_proj"
},
"model.layers.11.self_attn.q_proj": {
"snr": 0.09307146072387695,
"type": "self_attn.q_proj"
},
"model.layers.12.self_attn.q_proj": {
"snr": 0.0880950540304184,
"type": "self_attn.q_proj"
},
"model.layers.13.self_attn.q_proj": {
"snr": 0.08886399120092392,
"type": "self_attn.q_proj"
},
"model.layers.14.self_attn.q_proj": {
"snr": 0.09955056011676788,
"type": "self_attn.q_proj"
},
"model.layers.15.self_attn.q_proj": {
"snr": 0.08929339051246643,
"type": "self_attn.q_proj"
},
"model.layers.0.self_attn.v_proj": {
"snr": 2.5501928329467773,
"type": "self_attn.v_proj"
},
"model.layers.1.self_attn.v_proj": {
"snr": 9.449499130249023,
"type": "self_attn.v_proj"
},
"model.layers.2.self_attn.v_proj": {
"snr": 7.9920830726623535,
"type": "self_attn.v_proj"
},
"model.layers.3.self_attn.v_proj": {
"snr": 50.69462585449219,
"type": "self_attn.v_proj"
},
"model.layers.4.self_attn.v_proj": {
"snr": 19.083511352539062,
"type": "self_attn.v_proj"
},
"model.layers.5.self_attn.v_proj": {
"snr": 7.21597146987915,
"type": "self_attn.v_proj"
},
"model.layers.6.self_attn.v_proj": {
"snr": 11.27744197845459,
"type": "self_attn.v_proj"
},
"model.layers.7.self_attn.v_proj": {
"snr": 4.579711437225342,
"type": "self_attn.v_proj"
},
"model.layers.8.self_attn.v_proj": {
"snr": 10.940719604492188,
"type": "self_attn.v_proj"
},
"model.layers.9.self_attn.v_proj": {
"snr": 553.4417724609375,
"type": "self_attn.v_proj"
},
"model.layers.10.self_attn.v_proj": {
"snr": 20.59434700012207,
"type": "self_attn.v_proj"
},
"model.layers.11.self_attn.v_proj": {
"snr": 26.636865615844727,
"type": "self_attn.v_proj"
},
"model.layers.12.self_attn.v_proj": {
"snr": 8.614749908447266,
"type": "self_attn.v_proj"
},
"model.layers.13.self_attn.v_proj": {
"snr": 17.722007751464844,
"type": "self_attn.v_proj"
},
"model.layers.14.self_attn.v_proj": {
"snr": 1.48500657081604,
"type": "self_attn.v_proj"
},
"model.layers.15.self_attn.v_proj": {
"snr": 2.5776851177215576,
"type": "self_attn.v_proj"
}
}

View File

@@ -1,159 +0,0 @@
"""
Module for definition of GEGLU Triton kernels.
See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
import torch
import triton
import triton.language as tl
SQRT_2_PI: tl.constexpr = 0.7978845608028654 # sqrt(2/π)
@triton.jit
def _geglu_fwd_kernel(
gate_ptr,
up_ptr,
out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
"""GEGLU forward kernel.
Args:
gate_ptr: Pointer to gate tensor [*, hidden_dim].
up_ptr: Pointer to up-projection tensor [*, hidden_dim].
out_ptr: Pointer to output tensor [*, hidden_dim].
n_elements: Total number of elements in the input tensors.
BLOCK_SIZE: Size of thread blocks for parallel computation.
"""
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
up = tl.load(up_ptr + offsets, mask=mask, other=0)
# Compute activation in fp32 then convert back
gelu_gate = 0.5 * gate * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
gelu_gate = gelu_gate.to(up.dtype)
result = gelu_gate * up
tl.store(out_ptr + offsets, result, mask=mask)
def geglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
"""GEGLU forward pass.
Args:
gate: Input gate tensor of shape [batch, seq_len, hidden_dim].
up: Up-projection tensor of shape [batch, seq_len, hidden_dim].
Returns:
torch.Tensor: Output tensor of shape [batch, seq_len, hidden_dim].
"""
batch, seq_len, hidden_dim = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda")
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) # noqa: E731
_geglu_fwd_kernel[grid](
gate_ptr=gate,
up_ptr=up,
out_ptr=out,
n_elements=n_elements,
BLOCK_SIZE=1024,
)
return out
@triton.jit
def _geglu_bwd_kernel(
grad_out_ptr,
gate_ptr,
up_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
"""GEGLU backward kernel. Stores gradient results in-place.
Args:
grad_out_ptr: Pointer to gradient output tensor [*, hidden_dim].
gate_ptr: Pointer to gate tensor [*, hidden_dim].
up_ptr: Pointer to up-projection tensor [*, hidden_dim].
n_elements: Total number of elements in the input tensors.
BLOCK_SIZE: Size of thread blocks for parallel computation.
Note:
After kernel execution, tensors are modified in-place:
- `grad_out_ptr` contains GEGLU activation output (`h`)
- `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
- `up_ptr` contains gradient w.r.t up (`grad_up`)
"""
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
up = tl.load(up_ptr + offsets, mask=mask, other=0)
# Forward pass
gelu_partial = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
gelu_gate = gelu_partial * gate
gelu_gate = gelu_gate.to(grad_out.dtype)
# Forward output
h = gelu_gate * up
# Compute gradients
grad_up = grad_out * gelu_gate
# Compute gate gradient using GELU derivative
temp = grad_out * up
t = 0.3989422804014327 # 1/sqrt(2*pi)
dgelu_dgate = gelu_partial + t * gate * tl.exp(-0.5 * gate * gate)
grad_gate = temp.to(tl.float32) * dgelu_dgate
grad_gate = grad_gate.to(grad_out.dtype)
# Store results
tl.store(grad_out_ptr + offsets, h, mask=mask)
tl.store(gate_ptr + offsets, grad_gate, mask=mask)
tl.store(up_ptr + offsets, grad_up, mask=mask)
def geglu_backward(
grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""GEGLU backward pass using in-place operations.
Args:
grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
Returns:
Tuple containing:
- GEGLU activation output (`h`)
- Gradient with respect to gate (`grad_gate`)
- Gradient with respect to up (`grad_up`)
Note:
This function modifies its input tensors in-place to store results.
"""
n_elements = grad_output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) # noqa: E731
_geglu_bwd_kernel[grid](
grad_out_ptr=grad_output,
gate_ptr=gate,
up_ptr=up,
n_elements=n_elements,
BLOCK_SIZE=1024,
)
return grad_output, gate, up
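As a sanity reference (illustrative, not from the diff), the exact-GELU formulation these kernels implement can be expressed in plain PyTorch, assuming the standard torch.nn.functional.gelu behaviour:
import torch
import torch.nn.functional as F
def geglu_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # 0.5 * gate * (erf(gate / sqrt(2)) + 1) * up, computed in fp32 then cast back
    return F.gelu(gate.float(), approximate="none").to(up.dtype) * up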

View File

@@ -1,779 +0,0 @@
"""
Module for definition of Low-Rank Adaptation (LoRA) Triton kernels.
See "LoRA: Low-Rank Adaptation of Large Language Models"
(https://arxiv.org/abs/2106.09685).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name
from typing import Callable
import torch
from bitsandbytes.functional import QuantState
from torch import nn
from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize
from .swiglu import swiglu_backward, swiglu_forward
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd
def get_lora_parameters(
proj: nn.Module,
) -> tuple[
torch.Tensor,
QuantState | None,
torch.Tensor | None,
torch.Tensor | None,
float | None,
]:
"""
Gets LoRA parameters from a projection module.
Args:
proj: The projection module to extract parameters from.
Returns:
A tuple containing the base weight matrix, quantization state, LoRA A matrix,
LoRA B matrix, and scaling factor. States and matrices may be None if not
available.
"""
# For DPO or disabled adapters
base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
W = base_layer.weight
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
return W, quant_state, None, None, None
active_adapter = (
proj.active_adapters[0]
if hasattr(proj, "active_adapters")
else proj.active_adapter
)
A = proj.lora_A[active_adapter].weight
B = proj.lora_B[active_adapter].weight
s = proj.scaling[active_adapter]
quant_state = getattr(W, "quant_state", None)
return W, quant_state, A, B, s
def matmul_lora(
X: torch.Tensor,
W: torch.Tensor,
W_quant: QuantState,
A: torch.Tensor,
B: torch.Tensor,
s: float,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Efficient fused matmul + LoRA computation.
Args:
X: Input tensor [*, in_features]
W: Base weight matrix [out_features, in_features]
W_quant: Quantization state for W
A: LoRA A matrix [rank, in_features]
B: LoRA B matrix [out_features, rank]
s: LoRA scaling factor
out: Optional output tensor for inplace operations
Returns:
Result of X @ W + X @ A @ B
"""
dtype = X.dtype
W = dequantize(W.t(), W_quant)
if X.dim() == 3:
batch, seq_len, _ = X.shape
X = X.view(-1, X.shape[-1])
reshape = True
else:
reshape = False
out = torch.matmul(X, W, out=out)
if W_quant is not None:
del W
if A is not None:
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
return out.view(batch, seq_len, -1) if reshape else out
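For orientation, a dense (unquantized) reference of what matmul_lora computes, as an illustrative sketch (not from the diff) under the shapes given in the docstring above:
def matmul_lora_reference(X, W, A, B, s):
    # X: [*, in_features], W: [out_features, in_features], A: [rank, in_features], B: [out_features, rank]
    out = X @ W.t()
    if A is not None:
        out = out + s * (X @ A.t()) @ B.t()
    return out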
class LoRA_MLP(torch.autograd.Function):
"""Optimized LoRA MLP implementation."""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx,
X: torch.Tensor,
gate_weight: torch.Tensor,
gate_quant: object | None,
gate_A: torch.Tensor | None,
gate_B: torch.Tensor | None,
gate_scale: float,
up_weight: torch.Tensor,
up_quant: object | None,
up_A: torch.Tensor | None,
up_B: torch.Tensor | None,
up_scale: float,
down_weight: torch.Tensor,
down_quant: object | None,
down_A: torch.Tensor | None,
down_B: torch.Tensor | None,
down_scale: float,
activation_fn: Callable,
activation_fn_backward: Callable,
inplace: bool | None = True,
) -> torch.Tensor:
"""
Forward pass for LoRA MLP.
Args:
ctx: Autograd context
X: Input features
gate_weight: Gate projection weight
gate_quant: Gate quantization state
gate_A: Gate LoRA A matrix
gate_B: Gate LoRA B matrix
gate_scale: Gate LoRA scale
up_weight: Up-projection weight
up_quant: Up-projection quantization state
up_A: Up-projection LoRA A matrix
up_B: Up-projection LoRA B matrix
up_scale: Up-projection LoRA scale
down_weight: Down-projection weight
down_quant: Down-projection quantization state
down_A: Down-projection LoRA A matrix
down_B: Down-projection LoRA B matrix
down_scale: Down-projection LoRA scale
activation_fn: Forward activation function
activation_fn_backward: Backward activation function
inplace: Whether to perform operations in-place
Returns:
Output transformed by multi-layer perceptron and activation function
"""
# Compute projections
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
# Activation
hidden = activation_fn(gate, up)
# Down projection
output = matmul_lora(
hidden, down_weight, down_quant, down_A, down_B, down_scale
)
# Save for backward
ctx.save_for_backward(X, gate, up, gate_A, gate_B, up_A, up_B, down_A, down_B)
ctx.scales = (gate_scale, up_scale, down_scale)
ctx.quants = (gate_quant, up_quant, down_quant)
ctx.weights = (gate_weight, up_weight, down_weight)
ctx.activation_fn = activation_fn
ctx.activation_fn_backward = activation_fn_backward
ctx.inplace = inplace
return output
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_output: torch.Tensor,
) -> tuple[
torch.Tensor | None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
]:
"""
Performs backward pass computation for LoRA MLP.
Args:
ctx: Context object storing tensors saved during forward pass
grad_output: Gradient of loss with respect to layer output
Returns:
Tuple containing gradients for all inputs from forward pass:
- Input gradient tensor (or `None`)
- `None` for weights/quantization states
- LoRA A/B matrix gradients (or `None`)
- `None` for scaling factors
- `None` for activation functions and flags
"""
(
X,
gate,
up,
gate_A,
gate_B,
up_A,
up_B,
down_A,
down_B,
) = ctx.saved_tensors
gate_scale, up_scale, down_scale = ctx.scales
gate_quant, up_quant, down_quant = ctx.quants
gate_weight, up_weight, down_weight = ctx.weights
# Transpose all LoRA matrices
gate_A, gate_B = (
gate_A.t() if gate_A is not None else None,
gate_B.t() if gate_B is not None else None,
)
up_A, up_B = (
up_A.t() if up_A is not None else None,
up_B.t() if up_B is not None else None,
)
down_A, down_B = (
down_A.t() if down_A is not None else None,
down_B.t() if down_B is not None else None,
)
# Reshape inputs
batch, seq_len, hd = X.shape
grad_output = grad_output.view(-1, grad_output.shape[-1])
X = X.view(-1, X.shape[-1])
gate = gate.view(-1, gate.shape[-1])
up = up.view(-1, up.shape[-1])
dtype = X.dtype
# Down projection
DW = matmul_lora(
grad_output,
down_weight.t(),
down_quant,
down_B,
down_A,
down_scale,
)
# Activation backward
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
# Initialize and compute LoRA gradients
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
if down_A is not None:
d_down_A = h.t() @ (grad_output @ down_B.t())
d_down_B = (down_A.t() @ h.t()) @ grad_output
d_down_A *= down_scale
d_down_B *= down_scale
if up_A is not None:
d_up_A = X.t() @ (grad_up @ up_B.t())
d_up_B = (up_A.t() @ X.t()) @ grad_up
d_up_A *= up_scale
d_up_B *= up_scale
if gate_A is not None:
d_gate_A = X.t() @ (grad_gate @ gate_B.t())
d_gate_B = (gate_A.t() @ X.t()) @ grad_gate
d_gate_A *= gate_scale
d_gate_B *= gate_scale
# Compute input gradients
dX = torch.zeros_like(X) if ctx.needs_input_grad[0] else None
if dX is not None:
# Up projection gradients
up_weight = dequantize(up_weight.t(), up_quant)
if ctx.inplace:
dX = torch.matmul(grad_up, up_weight.t(), out=X)
else:
dX = torch.matmul(grad_up, up_weight.t())
del up_weight
# Note the .to(dtype) only where mixing LoRA with base weights
if up_A is not None:
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
# Gate projection gradients
gate_weight = dequantize(gate_weight.t(), gate_quant)
dX += grad_gate @ gate_weight.t()
del gate_weight
if gate_A is not None:
dX += (
grad_gate
@ gate_B.to(dtype).t()
@ (gate_scale * gate_A.to(dtype).t())
)
# Reshape back
dX = dX.view(batch, seq_len, hd)
# Return gradients in correct order matching forward inputs
return (
dX,
None,
None,
d_gate_A.t() if d_gate_A is not None else None,
d_gate_B.t() if d_gate_B is not None else None,
None,
None,
None,
d_up_A.t() if d_up_A is not None else None,
d_up_B.t() if d_up_B is not None else None,
None,
None,
None,
d_down_A.t() if d_down_A is not None else None,
d_down_B.t() if d_down_B is not None else None,
None,
None,
None,
None,
)
def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
"""
Applies LoRA to MLP layer with SwiGLU activation.
Args:
X: Input tensor for the MLP layer
inplace: Whether to perform operations in-place to save memory
Returns:
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
"""
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upW_quant,
upA,
upB,
upS,
downW,
downW_quant,
downA,
downB,
downS,
swiglu_forward,
swiglu_backward,
inplace,
)
return out
def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
"""
Applies LoRA to MLP layer with GEGLU activation.
Args:
X: Input tensor for the MLP layer
inplace: Whether to perform operations in-place to save memory
Returns:
Output tensor after applying LoRA-adapted MLP with GEGLU activation
"""
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upW_quant,
upA,
upB,
upS,
downW,
downW_quant,
downA,
downB,
downS,
geglu_forward,
geglu_backward,
inplace,
)
return out
class LoRA_QKV(torch.autograd.Function):
"""
Optimized LoRA QKV implementation with quantization support.
Implements efficient computation of query, key, value projections with LoRA,
supporting quantization and memory optimization.
"""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
q_weight: torch.Tensor,
q_quant: QuantState | None,
q_A: torch.Tensor | None,
q_B: torch.Tensor | None,
q_scale: float,
k_weight: torch.Tensor,
k_quant: QuantState | None,
k_A: torch.Tensor | None,
k_B: torch.Tensor | None,
k_scale: float,
v_weight: torch.Tensor,
v_quant: QuantState | None,
v_A: torch.Tensor | None,
v_B: torch.Tensor | None,
v_scale: float,
inplace: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass computing Q, K, V projections with LoRA.
Args:
ctx: Autograd context
X: Input tensor
q_weight: Query projection weight
q_quant: Query quantization state
q_A: Query LoRA A matrix
q_B: Query LoRA B matrix
q_scale: Query LoRA scale
k_weight: Key projection weight
k_quant: Key quantization state
k_A: Key LoRA A matrix
k_B: Key LoRA B matrix
k_scale: Key LoRA scale
v_weight: Value projection weight
v_quant: Value quantization state
v_A: Value LoRA A matrix
v_B: Value LoRA B matrix
v_scale: Value LoRA scale
inplace: Whether to perform operations in-place
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
ctx.scales = (q_scale, k_scale, v_scale)
ctx.quants = (q_quant, k_quant, v_quant)
ctx.weights = (q_weight, k_weight, v_weight)
ctx.inplace = inplace
return Q, K, V
@staticmethod
    @torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor,
k_grad: torch.Tensor,
v_grad: torch.Tensor,
) -> tuple[
torch.Tensor,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
]:
"""
Backward pass computing gradients for LoRA QKV.
Args:
ctx: Autograd context
q_grad: Gradient for query projection
k_grad: Gradient for key projection
v_grad: Gradient for value projection
Returns:
Tuple containing gradients for all forward inputs
"""
X, A_q, B_q, A_k, B_k, A_v, B_v = ctx.saved_tensors
q_weight, k_weight, v_weight = ctx.weights
q_quant, k_quant, v_quant = ctx.quants
q_scale, k_scale, v_scale = ctx.scales
dtype = X.dtype
# Reshape gradients
batch, seq_len = X.shape[:2]
q_grad = q_grad.view(-1, q_grad.shape[-1])
k_grad = k_grad.reshape(-1, k_grad.shape[-1])
v_grad = v_grad.view(-1, v_grad.shape[-1])
X = X.view(-1, X.shape[-1])
# Pre-transpose X once
X_t = X.t()
# Initialize LoRA gradients as None
d_A_q = d_B_q = d_A_k = d_B_k = d_A_v = d_B_v = None
# Compute q path LoRA gradients if adapters exist
if A_q is not None and B_q is not None:
A_q_scaled = (q_scale * A_q).to(dtype)
B_q_scaled = B_q.to(dtype)
d_A_q = torch.mm(X_t, torch.mm(q_grad, B_q_scaled))
d_B_q = torch.mm(torch.mm(A_q_scaled, X_t), q_grad)
# Compute k path LoRA gradients if adapters exist
if A_k is not None and B_k is not None:
A_k_scaled = (k_scale * A_k).to(dtype)
B_k_scaled = B_k.to(dtype)
d_A_k = torch.mm(X_t, torch.mm(k_grad, B_k_scaled))
d_B_k = torch.mm(torch.mm(A_k_scaled, X_t), k_grad)
# Compute v path LoRA gradients if adapters exist
if A_v is not None and B_v is not None:
A_v_scaled = (v_scale * A_v).to(dtype)
B_v_scaled = B_v.to(dtype)
d_A_v = torch.mm(X_t, torch.mm(v_grad, B_v_scaled))
d_B_v = torch.mm(torch.mm(A_v_scaled, X_t), v_grad)
# Compute input gradient, reusing X memory if possible
out_buffer = X if ctx.inplace else None
# Q path
q_weight_t = dequantize(q_weight, q_quant)
grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
del q_weight
del q_weight_t
if A_q is not None and B_q is not None:
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
# K path
k_weight_t = dequantize(k_weight, k_quant)
grad_X.addmm_(k_grad, k_weight_t)
del k_weight
del k_weight_t
if A_k is not None and B_k is not None:
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
# V path
v_weight_t = dequantize(v_weight, v_quant)
grad_X.addmm_(v_grad, v_weight_t)
del v_weight
del v_weight_t
if A_v is not None and B_v is not None:
grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
# Transpose gradients if needed
if d_A_q is not None:
d_A_q = d_A_q.t()
if d_B_q is not None:
d_B_q = d_B_q.t()
if d_A_k is not None:
d_A_k = d_A_k.t()
if d_B_k is not None:
d_B_k = d_B_k.t()
if d_A_v is not None:
d_A_v = d_A_v.t()
if d_B_v is not None:
d_B_v = d_B_v.t()
return (
grad_X.view(batch, seq_len, -1),
None,
None,
d_A_q,
d_B_q,
None,
None,
None,
d_A_k,
d_B_k,
None,
None,
None,
d_A_v,
d_B_v,
None,
None,
)
def apply_lora_qkv(
self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies LoRA to compute Query, Key, Value projections.
Args:
X: Input tensor
inplace: Whether to perform operations in-place
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
Q, K, V = LoRA_QKV.apply(
X,
QW,
QW_quant,
QA,
QB,
QS,
KW,
KW_quant,
KA,
KB,
KS,
VW,
VW_quant,
VA,
VB,
VS,
inplace,
)
return Q, K, V
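# Wiring note (see the monkeypatch module later in this compare): when the QKV
# kernel patch is enabled, the attention forward's original three projections
#
#   query_states = self.q_proj(hidden_states)
#   key_states = self.k_proj(hidden_states)
#   value_states = self.v_proj(hidden_states)
#
# are replaced by a single `self.apply_qkv(hidden_states)` call, which is bound
# to the method above via types.MethodType on each patched attention layer.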
class LoRA_O(torch.autograd.Function):
"""Optimized LoRA implementation for output projection."""
@staticmethod
@torch_amp_custom_fwd
def forward(
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
W: torch.Tensor,
W_quant: QuantState | None,
A: torch.Tensor | None,
B: torch.Tensor | None,
S: float,
) -> torch.Tensor:
"""
Forward pass for output projection with LoRA.
Args:
ctx: Autograd context
X: Input tensor
W: Output projection weight
W_quant: Weight quantization state
A: LoRA A matrix
B: LoRA B matrix
S: LoRA scaling factor
Returns:
Output projection tensor
"""
XW = matmul_lora(X, W, W_quant, A, B, S)
ctx.custom_saved_tensors = (
W,
W_quant,
S,
)
ctx.save_for_backward(A, B, X)
return XW
@staticmethod
@torch_amp_custom_bwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
dY: torch.Tensor,
) -> tuple[
torch.Tensor,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
]:
"""
Backward pass computing gradients for LoRA output projection.
Args:
ctx: Autograd context
dY: Gradient of loss with respect to output
Returns:
Tuple containing gradients for all forward inputs
"""
W, W_quant, S = ctx.custom_saved_tensors
A, B, X = ctx.saved_tensors
batch, seq_len, hd = X.shape
dY = dY.reshape(-1, dY.shape[-1])
X = X.reshape(-1, X.shape[-1])
dtype = X.dtype
# Weight projection
dY_X = X.t() @ dY
d_A = S * dY_X @ B
d_B = S * A @ dY_X
# Get derivative for dX
W = dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype) @ (S * A.to(dtype))
# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
"""
Applies LoRA to output projection layer.
Args:
X: Input tensor
Returns:
Transformed output tensor
"""
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
return output

View File

@@ -1,149 +0,0 @@
"""Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement
import ctypes
import bitsandbytes as bnb
import torch
from bitsandbytes.functional import QuantState, get_ptr
from packaging.version import Version
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
CUDA_STREAM: torch.cuda.Stream | None = None
HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3")
def dequantize(
W: torch.Tensor,
quant_state: QuantState | list | None = None,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Fast NF4 dequantization using `bitsandbytes` CUDA kernels.
Performs efficient dequantization of weights from NF4 format using `bitsandbytes`'
optimized CUDA implementations. Supports both legacy list and new `QuantState`
formats.
Args:
W: Quantized weight tensor to dequantize
quant_state: Quantization state containing metadata needed for
dequantization. Can be either a `QuantState` object or legacy list format.
If None, returns `W` unchanged.
out: Optional output tensor for storing dequantized results. Must match
expected shape and dtype if provided.
Returns:
Dequantized tensor in the specified dtype (fp16 or bf16). Will be transposed if
input `W` was transposed.
Raises:
AssertionError: If provided output tensor doesn't match expected shape / dtype.
Note:
Uses CUDA streams for better performance when available in newer `bitsandbytes`
versions (>0.43.3).
"""
if quant_state is None:
return W
# Get the target device from input tensor W
target_device = W.device
# Extract quantization state
if not isinstance(quant_state, list):
# New style quant_state class
absmax = quant_state.absmax.to(target_device)
shape = quant_state.shape
dtype = quant_state.dtype
blocksize = quant_state.blocksize
offset = quant_state.offset.to(target_device)
state2 = quant_state.state2
absmax2 = state2.absmax.to(target_device)
code2 = state2.code.to(target_device)
blocksize2 = state2.blocksize
else:
# Legacy list format
absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
absmax = absmax.to(target_device)
offset, state2 = compressed_stats
offset = offset.to(target_device)
absmax2, code2, blocksize2, _, _, _, _ = state2
absmax2 = absmax2.to(target_device)
code2 = code2.to(target_device)
# Setup output tensor on the same device as input
if out is None:
out = torch.empty(shape, dtype=dtype, device=target_device)
else:
assert out.shape == shape and out.dtype == dtype
out = out.to(target_device)
# Dequantize statistics on the target device
n_elements_absmax: int = absmax.numel()
out_absmax: torch.Tensor = torch.empty(
n_elements_absmax, dtype=torch.float32, device=target_device
)
ptr_out_absmax: int = get_ptr(out_absmax)
# Use CUDA stream if available
if HAS_CUDA_STREAM:
global CUDA_STREAM
if CUDA_STREAM is None:
CUDA_STREAM = torch.cuda.current_stream(target_device)
cdequantize_blockwise_fp32(
get_ptr(code2),
get_ptr(absmax),
get_ptr(absmax2),
ptr_out_absmax,
ctypes.c_int(blocksize2),
ctypes.c_int(n_elements_absmax),
CUDA_STREAM,
)
else:
cdequantize_blockwise_fp32(
get_ptr(code2),
get_ptr(absmax),
get_ptr(absmax2),
ptr_out_absmax,
ctypes.c_int(blocksize2),
ctypes.c_int(n_elements_absmax),
)
out_absmax += offset
# Choose appropriate dequantization function
fx = (
cdequantize_blockwise_fp16_nf4
if dtype == torch.float16
else cdequantize_blockwise_bf16_nf4
)
# Dequantize weights
if HAS_CUDA_STREAM:
fx(
get_ptr(None),
get_ptr(W),
ptr_out_absmax,
get_ptr(out),
ctypes.c_int(blocksize),
ctypes.c_int(out.numel()),
CUDA_STREAM,
)
else:
fx(
get_ptr(None),
get_ptr(W),
ptr_out_absmax,
get_ptr(out),
ctypes.c_int(blocksize),
ctypes.c_int(out.numel()),
)
# Handle transposed data
is_transposed: bool = W.shape[0] == 1
return out.t() if is_transposed else out
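# Hypothetical usage sketch (not part of this module; layer sizes are made up):
# recovering the full-precision weight of a bitsandbytes NF4-quantized linear
# layer with the fast path above.
#
#   import bitsandbytes as bnb
#   layer = bnb.nn.Linear4bit(4096, 4096, quant_type="nf4").cuda()  # quantizes on move
#   w_dequant = dequantize(layer.weight.data, layer.weight.quant_state)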

View File

@@ -1,163 +0,0 @@
"""
Module for definition of SwiGLU Triton kernels.
See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
import torch
import triton
import triton.language as tl
@triton.jit
def _swiglu_fwd_kernel(
gate_ptr,
up_ptr,
out_ptr,
n_elements,
block_size: tl.constexpr,
):
"""
SwiGLU forward kernel. The kernel computes activation in fp32 precision for better
numerical stability, then converts back to original dtype for the final result.
Args:
gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
out_ptr: Pointer to output tensor `[*, hidden_dim]`.
n_elements: Total number of elements in the input tensors.
block_size: Size of thread blocks for parallel computation.
"""
block_idx = tl.program_id(0)
offsets = block_idx * block_size + tl.arange(0, block_size)
mask = offsets < n_elements
# Load gate in fp32, keep up in original dtype
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
up = tl.load(up_ptr + offsets, mask=mask, other=0)
# Compute activation in fp32 then convert back
f = gate * tl.sigmoid(gate)
f = f.to(up.dtype)
result = f * up
tl.store(out_ptr + offsets, result, mask=mask)
@triton.jit
def _swiglu_bwd_kernel(
grad_out_ptr,
gate_ptr,
up_ptr,
n_elements,
block_size: tl.constexpr,
):
"""
SwiGLU backward kernel. Stores gradient results in-place.
Args:
grad_out_ptr: Pointer to gradient output tensor `[*, hidden_dim]`.
gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
n_elements: Total number of elements in the input tensors.
block_size: Size of thread blocks for parallel computation.
Note:
After kernel execution, tensors are modified in-place:
- `grad_out_ptr` contains forward output (`h`)
- `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
- `up_ptr` contains gradient w.r.t up (`grad_up`)
"""
block_idx = tl.program_id(0)
offsets = block_idx * block_size + tl.arange(0, block_size)
mask = offsets < n_elements
# Load values - only convert gate to fp32
grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
up = tl.load(up_ptr + offsets, mask=mask, other=0)
# Compute SiLU and forward output
sigmoid_gate = tl.sigmoid(gate)
silu_gate = sigmoid_gate * gate
silu_gate = silu_gate.to(grad_out.dtype)
h = silu_gate * up
# Compute gradients
grad_up = grad_out * silu_gate # gradient for up is grad_out * SiLU(gate)
# Compute gate gradient
temp = grad_out * up
grad_gate = temp.to(tl.float32) * sigmoid_gate * (1.0 + gate * (1.0 - sigmoid_gate))
grad_gate = grad_gate.to(grad_out.dtype)
# Store results with correct gradient ordering
tl.store(grad_out_ptr + offsets, h, mask=mask)
tl.store(gate_ptr + offsets, grad_gate, mask=mask) # grad wrt gate
tl.store(up_ptr + offsets, grad_up, mask=mask) # grad wrt up
# pylint: disable=unnecessary-lambda-assignment
def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
"""
SwiGLU forward pass. Computes SwiGLU activation: `x * sigmoid(x) * up`, where
`x` is the gate tensor.
Args:
gate: Input gate tensor of shape `[batch, seq_len, hidden_dim]`.
up: Up-projection tensor of shape `[batch, seq_len, hidden_dim]`.
Returns:
Output tensor of shape `[batch, seq_len, hidden_dim]`.
"""
batch, seq_len, hidden_dim = gate.shape
n_elements = gate.numel()
out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda")
grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),) # noqa: E731
_swiglu_fwd_kernel[grid](
gate_ptr=gate,
up_ptr=up,
out_ptr=out,
n_elements=n_elements,
block_size=1024,
)
return out
# pylint: disable=unnecessary-lambda-assignment
def swiglu_backward(
grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
SwiGLU backward pass using in-place operations.
Args:
grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
Returns:
Tuple containing:
- Forward pass output (`h`)
- Gradient with respect to gate (`df`)
- Gradient with respect to up-projection (`de`)
"""
n_elements = grad_output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),) # noqa: E731
_swiglu_bwd_kernel[grid](
grad_out_ptr=grad_output,
gate_ptr=gate,
up_ptr=up,
n_elements=n_elements,
block_size=1024,
)
# After kernel execution, tensors contain:
# grad_output: h (forward output)
# gate: grad_gate (grad wrt gate)
# up: grad_up (grad wrt up)
return grad_output, gate, up
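# Illustrative numerical check (assumes a CUDA device with Triton available):
# the fused forward above should agree with an eager PyTorch SwiGLU.
#
#   gate = torch.randn(2, 8, 64, device="cuda", dtype=torch.float16)
#   up = torch.randn_like(gate)
#   eager = torch.nn.functional.silu(gate.float()).to(gate.dtype) * up
#   fused = swiglu_forward(gate, up)
#   assert torch.allclose(eager, fused, atol=1e-3)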

View File

@@ -1,11 +0,0 @@
"""Utilities for `axolotl.kernels` submodules."""
import torch
from packaging.version import Version
if Version(torch.__version__) < Version("2.4.0"):
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
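# Illustrative usage (mirrors how the kernel modules apply these decorators):
# the pair makes a custom autograd.Function cooperate with torch.autocast on CUDA.
#
#   class MyOp(torch.autograd.Function):
#       @staticmethod
#       @torch_amp_custom_fwd
#       def forward(ctx, x):
#           return x * 2
#
#       @staticmethod
#       @torch_amp_custom_bwd
#       def backward(ctx, grad):
#           return grad * 2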

View File

@@ -1,333 +0,0 @@
"""Module for patching custom LoRA Triton kernels and `torch.autograd` functions."""
import importlib
import inspect
import logging
import types
from typing import Type
import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers import AutoConfig
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qkv,
)
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.dict import DictDefault
LOG = get_logger(__name__)
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip(
"\n"
)
PATCHED_O_CODE = """
attn_output = self.apply_o(attn_output)
""".lstrip(
"\n"
)
SUPPORTED_ACTIVATIONS = ["silu", "gelu"]
APPLY_FN_MAPPING = {
"silu": apply_lora_mlp_swiglu,
"gelu": apply_lora_mlp_geglu,
}
def original_apply_qkv(
self: nn.Module, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Original implementation of QKV projection without optimizations.
Args:
self: The attention module instance.
hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim].
Returns:
A tuple `(query_states, key_states, value_states)` containing the projected
states for query, key, and value.
"""
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Original implementation of output projection without optimizations.
Args:
self: The attention module instance.
        hidden_states: Input tensor of shape `[batch_size, seq_len, hidden_dim]`.
Returns:
The output projection result.
"""
attn_output = self.o_proj(hidden_states)
return attn_output
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
"""
Get the appropriate attention class by inspecting the model config.
Uses dynamic import to support any model architecture that follows
the standard transformers naming convention.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
The appropriate attention class for the model.
Raises:
ValueError: If `base_model` not specified or attention class cannot be imported
ImportError: If the model module or attention class doesn't exist
"""
if "base_model" not in cfg:
raise ValueError("base_model must be specified in config")
# Get model config without loading the model
model_config = AutoConfig.from_pretrained(cfg["base_model"])
model_type = model_config.model_type
# Special case for model_type = "qwen2"
if model_type == "qwen2":
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
return Qwen2Attention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__(
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
)
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
return attention_cls
except (ImportError, AttributeError) as e:
raise ValueError(
f"Could not import attention class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
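# Illustrative example (hypothetical `base_model`; any Llama checkpoint behaves
# the same): for a config whose model_type is "llama", the dynamic import above
# is equivalent to
#
#   from transformers.models.llama.modeling_llama import LlamaAttention
#
# since "llama".capitalize() + "Attention" == "LlamaAttention".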
# pylint: disable=protected-access
def patch_self_attn_lora(cfg: DictDefault):
"""
Given an `axolotl` config, this method patches the inferred attention class forward
pass with optimized LoRA implementations.
It modifies the attention class to use optimized QKV and output projections. The
original implementation is preserved and can be restored if needed.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Raises:
AssertionError: If the required code blocks are not found in the attention
implementation.
"""
attention_cls = get_attention_cls_from_config(cfg)
# Check if already patched
if hasattr(attention_cls, "_original_forward"):
LOG.info(f"{attention_cls.__name__} already patched")
return
self_attn_forward = inspect.getsource(attention_cls.forward)
attention_cls._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def axolotl_attn_forward(",
1,
)
# Load necessary imports
module_name = attention_cls.__module__
module = importlib.import_module(module_name)
items_to_import = []
for item in dir(module):
if item in self_attn_forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
attention_cls.forward = (
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
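# Illustrative call order (an assumption about how the caller is expected to use
# these helpers; axolotl's trainer wires this up elsewhere):
#
#   patch_self_attn_lora(cfg)                      # patch the attention class first
#   model = ...                                    # load base model, attach LoRA adapters
#   model = apply_lora_kernel_patches(model, cfg)  # then swap in the Triton kernels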
def apply_lora_kernel_patches(
model: PeftModelForCausalLM, cfg: DictDefault
) -> PeftModelForCausalLM:
"""
Applies optimized Triton kernel patches to a PEFT model.
Patches a PEFT model with optimized implementations for MLP and attention
computations. The optimizations include custom Triton kernels for activation
functions and specialized autograd functions for LoRA computations.
Args:
model: A PEFT model to be patched with optimized kernels.
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
PeftModelForCausalLM: The patched model with optimized kernels.
Raises:
TypeError: If the provided model is not a `PeftModelForCausalLM`.
NotImplementedError: If the model type is not supported.
AssertionError: If multiple adapters are active (currently unsupported).
Note:
The optimizations require LoRA adapters with no dropout and no bias terms. The
function will skip patching if these conditions aren't met.
"""
if not isinstance(model, PeftModelForCausalLM):
raise TypeError("Model must be a PeftModelForCausalLM")
# Get active LoRA adapter config
if hasattr(model, "active_adapters"):
assert (
len(model.active_adapters) == 1
), "Axolotl currently does not support LoRA Triton kernels for multiple adapters"
active_adapter = model.active_adapters[0]
else:
active_adapter = model.active_adapter
lora_config = model.model.peft_config[active_adapter]
# Only patch if conditions are met
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
if not can_patch:
LOG.warning("Cannot patch layers - requires no dropout and no bias")
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
return model
# This needs to be reset after patching
original_level = LOG.getEffectiveLevel()
LOG.setLevel(logging.INFO)
# Choose activation based on model type
activation = model.config.hidden_act
if activation not in SUPPORTED_ACTIVATIONS:
raise NotImplementedError(f"Activation {activation} is not supported")
# Patch each layer
for layer in model.model.model.layers:
# Add QKV, O fallback implementations to start
        # These are overwritten below when the corresponding patch conditions are met
layer.self_attn.apply_qkv = types.MethodType(
original_apply_qkv, layer.self_attn
)
layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)
if cfg.lora_mlp_kernel:
# MLP patching
gate_proj = layer.mlp.gate_proj
up_proj = layer.mlp.up_proj
down_proj = layer.mlp.down_proj
can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
)
if can_patch_mlp:
apply_fn = APPLY_FN_MAPPING[activation]
layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
)
if cfg.lora_qkv_kernel:
# Query, key, value patching
layer_modules = [
getattr(layer.self_attn, linear_proj)
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_qkv:
# Add optimized implementation
layer.self_attn.apply_qkv = types.MethodType(
apply_lora_qkv, layer.self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_o:
layer.self_attn.apply_o = types.MethodType(
apply_lora_o, layer.self_attn
)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)
LOG.setLevel(original_level)
return model
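# Illustrative axolotl config fragment (an assumption inferred from the cfg.*
# flags read above; the kernels also require `lora_dropout: 0` and no bias):
#
#   lora_dropout: 0
#   lora_mlp_kernel: true
#   lora_qkv_kernel: true
#   lora_o_kernel: true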

View File

@@ -41,10 +41,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
return None

View File

@@ -34,12 +34,15 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
max_length = self.prompter.max_length
self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt["messages"] = []
prompt[self.messages] = []
if prompt["system"]:
prompt["messages"].append({"role": "system", "content": prompt["system"]})
prompt["messages"].append({"role": "user", "content": prompt["input"]})
prompt["messages"].append({"role": "assistant", "content": prompt["chosen"]})
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super()._tokenize_single_prompt(prompt)
if len(chosen_tokenized["input_ids"]) > max_length:
@@ -52,12 +55,17 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
:max_length
]
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt["messages"] = []
prompt[self.messages] = []
if prompt["system"]:
prompt["messages"].append({"role": "system", "content": prompt["system"]})
prompt["messages"].append({"role": "user", "content": prompt["input"]})
prompt["messages"].append({"role": "assistant", "content": prompt["rejected"]})
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
rejected_tokenized = super()._tokenize_single_prompt(prompt)
if len(rejected_tokenized["input_ids"]) > max_length:
@@ -91,13 +99,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_property_mappings": ds_cfg.get(
"message_property_mappings",
{
"role": "role",
"content": "content",
},
),
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", None
@@ -121,4 +124,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy

View File

@@ -4,16 +4,13 @@ HF Chat Templates prompt strategy
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from transformers import ProcessorMixin
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
# Configure the logger
LOG = logging.getLogger("axolotl")
@@ -26,23 +23,16 @@ class ChatTemplatePrompter(Prompter):
def __init__(
self,
tokenizer,
chat_template: str,
processor=None,
chat_template=None,
max_length=2048,
message_property_mappings: Optional[Dict[str, str]] = None,
message_field_role: str = "role",
message_field_content: str = "content",
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
field_messages: str = "messages",
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
# check if message_property_mappings is None or empty dict
if message_property_mappings is None or (not message_property_mappings):
message_property_mappings = {
"role": "role",
"content": "content",
}
if roles:
self.roles = {s: t for t, sources in roles.items() for s in sources}
else:
@@ -55,28 +45,18 @@ class ChatTemplatePrompter(Prompter):
"tool": "tool",
}
self._chat_template_msg_variables = self.get_chat_template_msg_variables(
chat_template, field_messages
)
self.message_property_mappings = message_property_mappings
self.message_field_role = message_field_role
self.message_field_content = message_field_content
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.field_messages = field_messages
self.tokenizer = tokenizer
self.processor: Optional[ProcessorMixin] = processor
self.processor: ProcessorMixin = processor
self.chat_template = chat_template
self.max_length = max_length
self.drop_system_message = drop_system_message
@property
def chat_template_msg_variables(self) -> Set[str]:
return self._chat_template_msg_variables
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
if self.processor:
if not callable(self.processor):
raise TypeError("Processor must be callable")
text = self.processor.apply_chat_template(
conversation,
chat_template=self.chat_template,
@@ -204,21 +184,17 @@ class ChatTemplatePrompter(Prompter):
return adjusted_details
def get_chat_template_msg_variables(
self, chat_template: str, field_messages: str
) -> Set[str]:
template_analyzer = JinjaTemplateAnalyzer(chat_template)
return template_analyzer.get_message_vars(field_messages)
class ChatTemplateStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for instruction-based prompts.
"""
_messages = "messages"
def __init__(
self,
prompter: "ChatTemplatePrompter",
prompter: ChatTemplatePrompter,
tokenizer,
train_on_inputs,
sequence_len,
@@ -226,7 +202,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
train_on_eos=None,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.prompter: ChatTemplatePrompter = prompter
self.roles_to_train = []
if roles_to_train:
@@ -238,9 +213,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self.train_on_eos = train_on_eos
self.images = "images"
LOG.debug(
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
)
@property
def messages(self):
return self._messages
@messages.setter
def messages(self, messages):
self._messages = messages
@property
def supports_batched(self) -> bool:
@@ -250,7 +229,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
try:
return all(isinstance(v, list) for v in prompt.values()) and all(
isinstance(v, list) for v in prompt[self.prompter.field_messages]
isinstance(v, list) for v in prompt[self.messages]
)
except KeyError:
return False
@@ -485,17 +464,30 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
turns = []
for message in prompt[self.prompter.field_messages]:
transformed_message = self.transform_message(message)
optional_keys = [
"tool_calls", # tool that 'assistant' calls
"name", # name of tool given by 'tool'
"tool_call_id", # mistral/mixtral requires this
]
for message in prompt[self.messages]:
turn = {
**transformed_message,
"role": self.prompter.roles[message[self.prompter.message_field_role]],
"training": message.get(self.prompter.message_field_training),
"training_detail": message.get(
self.prompter.message_field_training_detail
),
}
            # do not add a content key if it is None, since a null content value can conflict with tool-calling templates
content = message.get(self.prompter.message_field_content, None)
if content is not None:
turn["content"] = content
for key in optional_keys:
value = message.get(key, None)
if value is not None:
turn[key] = value
turns.append(turn)
if self.prompter.drop_system_message and turns[0]["role"] == "system":
@@ -503,37 +495,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return turns
def transform_message(self, message):
# Build the initial transformed message from the mappings
transformed_message = {}
for key, value in self.prompter.message_property_mappings.items():
if message.get(value) is not None:
transformed_message[key] = message[value]
else:
LOG.debug(
f"Could not find value for property {value} in message: {message}"
)
# Map the role if necessary
if "role" in transformed_message:
transformed_message["role"] = self.prompter.roles.get(
transformed_message["role"], transformed_message["role"]
)
# Determine which keys in the original message were not mapped
mapped_values = set(self.prompter.message_property_mappings.values())
remaining_keys = set(message) - mapped_values
# Keep only the properties defined in the chat template
# and not already mapped
for key in self.prompter.chat_template_msg_variables:
if key in remaining_keys:
val = message.get(key)
if val is not None:
transformed_message[key] = val
return transformed_message
def get_images(self, prompt):
return prompt.get(self.images, None)
@@ -555,46 +516,33 @@ class StrategyLoader:
}
def __call__(
self,
tokenizer,
cfg,
ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
processor=None,
self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
):
if ds_cfg is None:
dataset_config = {}
elif isinstance(ds_cfg, BaseModel):
dataset_config = ds_cfg.model_dump()
else:
dataset_config = ds_cfg
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_property_mappings": dataset_config.get(
"message_property_mappings", {}
),
"message_field_training": dataset_config.get(
"message_field_training", None
),
"message_field_training_detail": dataset_config.get(
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail",
None,
),
"field_messages": dataset_config.get("field_messages", "messages"),
"roles": dataset_config.get("roles"),
"drop_system_message": dataset_config.get("drop_system_message", False),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
            # we need to add one for detecting sequences that exceed the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
}
strategy_params = self._get_strategy_params(cfg, dataset_config)
strategy_params = self._get_strategy_params(cfg, ds_cfg)
strategy_cls = self._get_strategy_cls()
strategy = strategy_cls(
@@ -603,6 +551,9 @@ class StrategyLoader:
**strategy_params,
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy

View File

@@ -3,28 +3,20 @@ DPO prompt strategies for using tokenizer chat templates.
"""
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
ds_cfg = handle_legacy_message_fields_logic(ds_cfg)
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
message_property_mappings = ds_cfg.get(
"message_property_mappings",
{
"role": "role",
"content": "content",
},
)
field_message_role = ds_cfg.get("message_field_role", "role")
field_message_content = ds_cfg.get("message_field_content", "content")
role_map_inv = ds_cfg.get(
"roles",
{
@@ -48,18 +40,18 @@ def default(
messages = sample[field_messages]
messages = [
{
"role": role_map[m[message_property_mappings["role"]]],
"content": m[message_property_mappings["content"]],
"role": role_map[m[field_message_role]],
"content": m[field_message_content],
}
for m in messages
]
chosen = {
"role": role_map[sample[field_chosen][message_property_mappings["role"]]],
"content": sample[field_chosen][message_property_mappings["content"]],
"role": role_map[sample[field_chosen][field_message_role]],
"content": sample[field_chosen][field_message_content],
}
rejected = {
"role": role_map[sample[field_rejected][message_property_mappings["role"]]],
"content": sample[field_rejected][message_property_mappings["content"]],
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

View File

@@ -1,318 +0,0 @@
"""Module for inspect jinja templates for the variables they use"""
from typing import Dict, Optional, Set, TypedDict, Union
from jinja2 import Environment, meta, nodes
class JinjaTemplateAnalysis(TypedDict):
"""
Represents the detailed analysis of a Jinja template variable.
Attributes:
accessed_properties (Set[str]): A set of properties accessed from the variable
(e.g., `foo.bar` results in 'bar' being accessed for 'foo').
accessed_indices (Set[Union[int, float]]): A set of indices accessed from the variable.
is_iterated (bool): Indicates if the variable is used as an iteration source in a `for` loop.
is_conditional (bool): Indicates if the variable is referenced within a conditional statement (e.g., an `if` block).
iteration_source (Optional[str]): The name of the variable being iterated over, if applicable.
iteration_target (Optional[Union[str, list[str]]]): The loop target(s) assigned in the iteration.
"""
accessed_properties: Set[str]
accessed_indices: Set[Union[int, float]]
is_iterated: bool
is_conditional: bool
iteration_source: Optional[str]
iteration_target: Optional[Union[str, list[str]]]
class JinjaTemplateAnalyzer:
"""
Analyzes Jinja templates to extract information about variable usage,
including accessed properties, iteration, and conditional references.
Attributes:
env (jinja2.Environment): The Jinja2 environment used for parsing templates.
property_access (Dict[str, Set[str]]): Tracks accessed properties for variables.
iteration_targets (Dict[str, str]): Maps iteration target variables to their sources.
Methods:
get_template_variables(template: str) -> Dict[str, Set[str]]:
Parse a Jinja template and return a mapping of variables to their accessed properties.
analyze_template(template: str) -> Dict[str, JinjaTemplateAnalysis]:
Perform a detailed analysis of the template, including variable usage,
iteration, and conditional references.
Private Methods:
_visit_node(node) -> None:
Recursively visit AST nodes to detect attribute access and iteration targets.
_get_base_name(node) -> Optional[str]:
Extract the base variable name from a node.
_get_target_name(node) -> Optional[Union[str, list[str]]]:
Extract the target name(s) from a `For` node.
"""
def __init__(self, template: str):
self.env: Environment = Environment(autoescape=True)
self.property_access: Dict[str, Set[str]] = {}
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
self.index_access: Dict[str, Set[Union[int, float]]] = {}
self.ast: nodes.Node = self.env.parse(template)
self.template: str = template
self.variable_assignments: Dict[str, str] = {}
def _visit_node(self, node) -> None:
"""Recursively visit AST nodes to find attribute access."""
# Handle attribute access (dot notation)
if isinstance(node, nodes.Getattr):
base_name = self._get_base_name(node.node)
if base_name:
self.property_access.setdefault(base_name, set()).add(node.attr)
# Handle dictionary access (subscript notation)
elif isinstance(node, nodes.Getitem):
base_name = self._get_base_name(node.node)
if base_name and isinstance(node.arg, nodes.Const):
value = node.arg.value
if isinstance(value, (int, float)):
self.index_access.setdefault(base_name, set()).add(value)
else:
self.property_access.setdefault(base_name, set()).add(value)
elif isinstance(node, nodes.Test) and node.name == "defined":
base_name = self._get_base_name(node.node)
if base_name:
if isinstance(node.node, nodes.Getattr):
self.property_access.setdefault(base_name, set()).add(
node.node.attr
)
# Handle loop variables
elif isinstance(node, nodes.For):
iter_name = self._get_base_name(node.iter)
target_name = self._get_target_name(node.target)
if iter_name and target_name:
self.iteration_targets[target_name] = iter_name
self.property_access.setdefault(iter_name, set())
elif isinstance(node, nodes.Assign):
target_name = self._get_target_name(node.target)
source_name = self._get_base_name(node.node)
if target_name and source_name:
self.variable_assignments[target_name] = source_name
elif isinstance(node, nodes.Filter):
if node.name == "selectattr":
target = self._get_base_name(node.node)
if target:
self.variable_assignments[f"filtered_{target}"] = target
for child in node.iter_child_nodes():
self._visit_node(child)
def _get_target_name(self, node) -> Optional[str]:
"""Get the target variable name from a For node.
Args:
node: A Jinja AST node representing either a Name or Tuple node
Returns:
- str: For simple variable targets (e.g., "item" in "for item in items")
- None: If the node type is not recognized or is a tuple
"""
if isinstance(node, nodes.Name):
return node.name
return None
def _get_target_names(self, node) -> list[str]:
"""Get all target variable names from a For node, including tuple unpacking.
Args:
node: A Jinja AST node representing either a Name or Tuple node
Returns:
List of target variable names
"""
if isinstance(node, nodes.Name):
return [node.name]
if isinstance(node, nodes.Tuple):
names = []
for n in node.items:
if isinstance(n, nodes.Name):
names.append(n.name)
return names
return []
def _get_base_name(self, node) -> Optional[str]:
"""Get the base variable name from a node."""
if isinstance(node, nodes.Name):
return node.name
if isinstance(node, nodes.Getattr):
return self._get_base_name(node.node)
if isinstance(node, nodes.Getitem):
return self._get_base_name(node.node)
return None
def get_template_variables(self) -> Dict[str, Set[str]]:
"""
        Parse the Jinja template supplied at construction time and return both its
        variables and their accessed properties.
Returns:
Dict[str, Set[str]]: Dictionary mapping variable names to sets of accessed properties
"""
# Parse the template
ast = self.env.parse(self.template)
# Get all undeclared variables
variables = meta.find_undeclared_variables(ast)
# Reset property access tracking
self.property_access = {}
# Visit all nodes to find property access
self._visit_node(ast)
# Create result dictionary
result: Dict[str, Set[str]] = {var: set() for var in variables}
# Merge in any discovered sub-properties
for var, props in self.property_access.items():
if var not in result:
result[var] = set()
result[var].update(props)
return result
def analyze_template(self) -> Dict[str, JinjaTemplateAnalysis]:
"""
Provide a detailed analysis of template variables and their usage.
"""
variables = self.get_template_variables()
self.iteration_targets = {}
analysis: Dict[str, JinjaTemplateAnalysis] = {
var: JinjaTemplateAnalysis(
accessed_properties=props,
accessed_indices=set(),
is_iterated=False,
is_conditional=False,
iteration_source=None,
iteration_target=None,
)
for var, props in variables.items()
}
for var, indices in self.index_access.items():
if var in analysis:
analysis[var]["accessed_indices"] = indices
def visit_node(node):
if isinstance(node, nodes.If):
def find_test_vars(test_node):
if isinstance(test_node, nodes.Name):
if test_node.name in analysis:
analysis[test_node.name]["is_conditional"] = True
for child in test_node.iter_child_nodes():
find_test_vars(child)
find_test_vars(node.test)
if isinstance(node, nodes.For):
iter_target = self._get_base_name(node.iter)
target_name = self._get_target_name(node.target)
if iter_target in analysis:
analysis[iter_target]["is_iterated"] = True
if target_name:
analysis[iter_target]["iteration_target"] = target_name
if isinstance(target_name, str) and target_name not in analysis:
analysis[target_name] = {
"accessed_properties": set(),
"is_iterated": False,
"is_conditional": False,
"iteration_source": iter_target,
"iteration_target": None,
}
for child in node.iter_child_nodes():
visit_node(child)
visit_node(self.ast)
return analysis
def get_downstream_properties(self, start_var: str) -> Dict[str, Set[str]]:
"""
Get all properties accessed on a variable and its downstream assignments.
Args:
start_var: The starting variable to trace
Returns:
Dict mapping variable names to their accessed properties
"""
visited = set()
properties = {}
def trace_variable(var_name: str):
if var_name in visited:
return
visited.add(var_name)
# Get direct properties
if var_name in self.property_access:
properties[var_name] = self.property_access[var_name]
# Get properties from iteration targets
if var_name in self.iteration_targets:
target = self.iteration_targets[var_name]
if isinstance(target, str):
trace_variable(target)
elif isinstance(target, list):
for t in target:
trace_variable(t)
# Follow assignments
for target, source in self.variable_assignments.items():
if source == var_name:
trace_variable(target)
# Check for array slicing
analysis = self.analyze_template()
if var_name in analysis:
var_info = analysis[var_name]
if var_info["accessed_indices"]:
# If this variable is sliced, follow the resulting assignment
slice_result = f"{var_name}_slice"
if slice_result in self.property_access:
trace_variable(slice_result)
trace_variable(start_var)
return properties
def get_message_vars(self, field_messages: str = "messages") -> Set[str]:
"""
Get all properties accessed on messages and derived variables.
"""
all_properties = self.get_downstream_properties(field_messages)
# Combine all properties from all related variables
combined_properties = set()
for properties in all_properties.values():
combined_properties.update(properties)
# Also include properties from the message iteration variable
analysis = self.analyze_template()
if "message" in analysis:
combined_properties.update(analysis["message"]["accessed_properties"])
return combined_properties
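# Illustrative usage (not part of this module): inspecting which per-message
# properties a chat template reads.
#
#   template = (
#       "{% for message in messages %}"
#       "{{ message['role'] }}: {{ message['content'] }}"
#       "{% endfor %}"
#   )
#   analyzer = JinjaTemplateAnalyzer(template)
#   analyzer.get_message_vars("messages")  # expected: {'role', 'content'}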

View File

@@ -51,13 +51,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
field_messages = ds_cfg.get("field_messages")
message_property_mappings = ds_cfg.get("message_property_mappings")
message_field_role = (
message_property_mappings.get("role") if message_property_mappings else None
)
message_field_content = (
message_property_mappings.get("content") if message_property_mappings else None
)
message_field_role = ds_cfg.get("message_field_role")
message_field_content = ds_cfg.get("message_field_content")
message_field_training = ds_cfg.get("message_field_training")
builder_kwargs = {}

View File

@@ -175,7 +175,6 @@ def train(
LOG.info("hang tight... sorting dataset for group_by_length")
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -186,7 +185,6 @@ def train(
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
post_train_hooks(cfg, trainer)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

View File

@@ -15,7 +15,7 @@ _DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
_CHAT_TEMPLATES = {
"alpaca": "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' and loop.first %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '### Instruction:\n' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '### Response:\n' + message['content'] + eos_token }}{% endif %}{% if not loop.last %}{{ '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\n\n### Response:\n' }}{% endif %}",
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
"mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1...
"mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large...
"mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3-Tekken: Nemo, Pixtral...
@@ -38,7 +38,7 @@ def get_chat_template(
user_choice: str,
jinja_template: Optional[str] = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> str:
):
"""
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
@@ -70,7 +70,7 @@ def get_chat_template(
f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
f"Please add a chat_template in tokenizer config"
)
return tokenizer.chat_template # type: ignore
return tokenizer.chat_template
if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
if not tokenizer:
@@ -78,7 +78,7 @@ def get_chat_template(
f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
)
if tokenizer.chat_template:
return tokenizer.chat_template # type: ignore
return tokenizer.chat_template
user_choice = user_choice[
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :

View File

@@ -18,7 +18,6 @@ from axolotl.utils.config.models.input.v0_4_1 import (
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
@@ -259,7 +258,7 @@ def validate_config(
cfg: DictDefault,
capabilities: Optional[dict] = None,
env_capabilities: Optional[dict] = None,
) -> DictDefault:
):
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
AxolotlInputConfig = AxolotlInputConfigBase
@@ -269,16 +268,6 @@ def validate_config(
AxolotlInputConfig, # pylint: disable=invalid-name
) = merge_input_args()
# Convert datasets to proper format if needed
if cfg.get("datasets"):
for idx, ds_cfg in enumerate(cfg["datasets"]):
if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset):
cfg["datasets"][idx] = DPODataset(**ds_cfg)
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
elif not isinstance(ds_cfg, SFTDataset):
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
if capabilities or env_capabilities:
if (capabilities and env_capabilities is None) or (
env_capabilities and capabilities is None

View File

@@ -1,4 +1,7 @@
"""Module with Pydantic models for configuration."""
"""
Module for pydantic models for configuration
"""
# pylint: disable=too-many-lines
import logging
@@ -6,13 +9,12 @@ import os
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from annotated_types import MinLen
from packaging import version
from pydantic import (
BaseModel,
Field,
StringConstraints,
field_serializer,
conlist,
field_validator,
model_validator,
)
@@ -22,7 +24,7 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
from .trl import TRLConfig
from .trl import TrlConfig
LOG = logging.getLogger("axolotl.utils.config.models.input")
@@ -116,9 +118,6 @@ class RemappedParameters(BaseModel):
overrides_of_model_config: Optional[Dict[str, Any]] = Field(
default=None, alias="model_config"
)
overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(
default=None, alias="model_kwargs"
)
type_of_model: Optional[str] = Field(default=None, alias="model_type")
revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
@@ -167,7 +166,6 @@ class SFTDataset(BaseModel):
type: Optional[Union[str, UserDefinedPrompterType]] = None
input_transform: Optional[str] = None
shards: Optional[int] = None
shards_idx: Optional[int] = None
preprocess_shards: Optional[int] = None
conversation: Optional[str] = None
# Do not make this too strict or it will break the validator to choose different dataset class
@@ -187,13 +185,8 @@ class SFTDataset(BaseModel):
field_human: Optional[str] = None
field_model: Optional[str] = None
field_messages: Optional[str] = None
message_field_role: Optional[
str
] = None # deprecated, use message_property_mappings
message_field_content: Optional[
str
] = None # deprecated, use message_property_mappings
message_property_mappings: Optional[Dict[str, str]] = None
message_field_role: Optional[str] = None
message_field_content: Optional[str] = None
message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None
logprobs_field: Optional[str] = None
@@ -205,18 +198,9 @@ class SFTDataset(BaseModel):
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
@model_validator(mode="before")
@classmethod
def handle_legacy_message_fields(cls, data):
"""Handle backwards compatibility between legacy message field mapping and new property mapping system."""
return handle_legacy_message_fields_logic(data)
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
if isinstance(data, BaseModel):
data = data.model_dump()
# Set chat_template to tokenizer_default if not set
if data.get("type") == "chat_template" and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.tokenizer_default
@@ -256,7 +240,6 @@ class DPODataset(BaseModel):
type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None
revision: Optional[str] = None
field_messages: Optional[str] = None
class StepwiseSupervisedDataset(BaseModel):
@@ -293,9 +276,6 @@ class KTODataset(BaseModel):
revision: Optional[str] = None
DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
@@ -435,8 +415,6 @@ class ReLoRAConfig(BaseModel):
class ModelInputConfig(BaseModel):
"""model to train on configuration subset"""
model_config = {"protected_namespaces": ()}
base_model: str
base_model_config: Optional[str] = None
cls_model_config: Optional[str] = None
@@ -451,6 +429,8 @@ class ModelInputConfig(BaseModel):
)
trust_remote_code: Optional[bool] = None
model_kwargs: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):
@@ -503,7 +483,7 @@ class HyperparametersConfig(BaseModel):
"adopt_adamw",
],
]
] = OptimizerNames.ADAMW_HF
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
@@ -515,9 +495,7 @@ class HyperparametersConfig(BaseModel):
},
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[
Union[SchedulerType, Literal["one_cycle"]]
] = SchedulerType.COSINE
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
cosine_min_lr_ratio: Optional[float] = None
@@ -641,19 +619,19 @@ class RayConfig(BaseModel):
use_ray: bool = Field(default=False)
ray_run_name: Optional[str] = Field(
default=None,
json_schema_extra={
metadata={
"help": "The training results will be saved at `saves/ray_run_name`."
},
)
ray_num_workers: int = Field(
default=1,
json_schema_extra={
metadata={
"help": "The number of workers for Ray training. Default is 1 worker."
},
)
resources_per_worker: dict = Field(
default_factory=lambda: {"GPU": 1},
json_schema_extra={
metadata={
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
},
)
@@ -678,7 +656,10 @@ class AxolotlInputConfig(
):
"""wrapper of all config options"""
model_config = {"populate_by_name": True}
class Config:
"""Config for alias"""
populate_by_name = True
strict: Optional[bool] = Field(default=False)
resume_from_checkpoint: Optional[str] = None
@@ -689,8 +670,8 @@ class AxolotlInputConfig(
shrink_embeddings: Optional[bool] = None
rl: Optional[RLType] = None
trl: Optional[TRLConfig] = Field(
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
trl: Optional[TrlConfig] = Field(
default_factory=lambda: TrlConfig(), # pylint: disable=unnecessary-lambda
)
reward_model: Optional[bool] = None
process_reward_model: Optional[bool] = None
@@ -700,27 +681,16 @@ class AxolotlInputConfig(
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
dpo_use_logits_to_keep: Optional[bool] = None
datasets: Optional[
Annotated[
list[Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]],
MinLen(1),
]
] = None
test_datasets: Optional[
Annotated[
list[Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]],
MinLen(1),
]
] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
shuffle_merged_datasets: Optional[bool] = True
dataset_prepared_path: Optional[str] = None
dataset_shard_num: Optional[int] = None
dataset_shard_idx: Optional[int] = None
skip_prepare_dataset: Optional[bool] = False
pretraining_dataset: Optional[
Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]
pretraining_dataset: Optional[ # type: ignore
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
] = Field(
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
@@ -838,10 +808,6 @@ class AxolotlInputConfig(
unsloth_rms_norm: Optional[bool] = None
unsloth_rope: Optional[bool] = None
lora_mlp_kernel: Optional[bool] = None
lora_qkv_kernel: Optional[bool] = None
lora_o_kernel: Optional[bool] = None
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
fsdp_config: Optional[Dict[str, Any]] = None
@@ -864,7 +830,7 @@ class AxolotlInputConfig(
warmup_steps: Optional[int] = None
warmup_ratio: Optional[float] = None
eval_steps: Optional[Union[int, float]] = None
evals_per_epoch: Optional[int] = None
evals_per_epoch: Optional[Union[int]] = None
eval_strategy: Optional[str] = None
save_steps: Optional[Union[int, float]] = None
saves_per_epoch: Optional[int] = None
@@ -876,7 +842,6 @@ class AxolotlInputConfig(
save_only_model: Optional[bool] = False
use_tensorboard: Optional[bool] = None
profiler_steps: Optional[int] = None
include_tokens_per_second: Optional[bool] = None
neftune_noise_alpha: Optional[float] = None
@@ -926,15 +891,10 @@ class AxolotlInputConfig(
@classmethod
def deprecate_sharegpt_datasets(cls, datasets):
for _, ds_cfg in enumerate(datasets):
# Handle both dict and pydantic model cases
ds_type = (
ds_cfg.get("type")
if isinstance(ds_cfg, dict)
else getattr(ds_cfg, "type", None)
)
if not ds_type:
if not ds_cfg.get("type"):
continue
ds_type = ds_cfg["type"]
# skip if it's a dict (for custom user instruction prompt)
if isinstance(ds_type, dict):
continue
@@ -946,14 +906,6 @@ class AxolotlInputConfig(
return datasets
@field_serializer("datasets")
def datasets_serializer(
self, ds_configs: Optional[List[DatasetConfig]]
) -> Optional[List[Dict[str, Any]]]:
if ds_configs:
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
return None
@model_validator(mode="before")
@classmethod
def check_batch_size_fields(cls, data):
@@ -1579,42 +1531,12 @@ class AxolotlInputConfig(
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_8bit(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_lora_axolotl_unsloth(cls, data):
is_lora_kernel = any(
data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
)
is_unsloth_lora = any(
data.get(k)
for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
)
if is_lora_kernel and is_unsloth_lora:
raise ValueError(
"both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
@@ -1747,29 +1669,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_multigpu_lora_kernels(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp") is not None
is_deepspeed = data.get("deepspeed") is not None
if capabilities and capabilities.get("n_gpu", 0) > 1:
if is_fsdp:
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
)
if is_deepspeed:
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with DeepSpeed."
)
return data
@model_validator(mode="before")
@classmethod
def check_adopt_torch_version(cls, data):
@@ -1806,77 +1705,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
else:
data["torch_compile"] = False
return data
def handle_legacy_message_fields_logic(data: dict) -> dict:
"""
Handle backwards compatibility between legacy message field mapping and new property mapping system.
Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
- message_field_role: Mapped to the role field
- message_field_content: Mapped to the content field
The new system uses message_property_mappings to support arbitrary field mappings:
message_property_mappings:
role: source_role_field
content: source_content_field
additional_field: source_field
Args:
data: Dictionary containing configuration data
Returns:
Updated dictionary with message field mappings consolidated
Raises:
ValueError: If there are conflicts between legacy and new mappings
"""
data = data.copy() # Create a copy to avoid modifying the original
if data.get("message_property_mappings") is None:
data["message_property_mappings"] = {}
# Check for conflicts and handle role
if "message_field_role" in data:
LOG.warning(
"message_field_role is deprecated, use message_property_mappings instead. "
f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
)
if (
"role" in data["message_property_mappings"]
and data["message_property_mappings"]["role"] != data["message_field_role"]
):
raise ValueError(
f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
)
data["message_property_mappings"]["role"] = data["message_field_role"] or "role"
del data["message_field_role"]
elif "role" not in data["message_property_mappings"]:
data["message_property_mappings"]["role"] = "role"
# Check for conflicts and handle content
if "message_field_content" in data:
LOG.warning(
"message_field_content is deprecated, use message_property_mappings instead. "
f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
)
if (
"content" in data["message_property_mappings"]
and data["message_property_mappings"]["content"]
!= data["message_field_content"]
):
raise ValueError(
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
)
data["message_property_mappings"]["content"] = (
data["message_field_content"] or "content"
)
del data["message_field_content"]
elif "content" not in data["message_property_mappings"]:
data["message_property_mappings"]["content"] = "content"
return data
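A minimal sketch of the consolidation performed by the function above (the input dict is illustrative):

    legacy = {
        "type": "chat_template",
        "message_field_role": "from",
        "message_field_content": "value",
    }
    migrated = handle_legacy_message_fields_logic(legacy)
    # migrated["message_property_mappings"] == {"role": "from", "content": "value"}
    # the deprecated message_field_* keys are removed from the returned copy;
    # conflicting legacy and new mappings raise ValueError instead.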

View File

@@ -6,7 +6,7 @@ from typing import List, Optional
from pydantic import BaseModel, Field
class TRLConfig(BaseModel):
class TrlConfig(BaseModel):
"""
Input args for TRL.
"""
@@ -25,11 +25,8 @@ class TRLConfig(BaseModel):
vllm_gpu_memory_utilization: Optional[float] = 0.9
vllm_max_model_len: Optional[int] = None
vllm_dtype: Optional[str] = "auto"
reward_funcs: Optional[List[str]] = None
num_generations: Optional[int] = None
log_completions: Optional[bool] = False
sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64
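A quick illustration of the nested TRL config above (values are illustrative, not defaults; the class is named TRLConfig on one side of this diff and TrlConfig on the other):

    trl_cfg = TrlConfig(
        reward_funcs=["rewards.rand_reward_func"],  # dotted path to a custom reward function
        num_generations=4,
        vllm_gpu_memory_utilization=0.15,
        sync_ref_model=False,
    )
    # fields left unset keep the defaults shown above, e.g. ref_model_sync_steps == 64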

View File

@@ -4,16 +4,15 @@ import inspect
import logging
from functools import partial
from pathlib import Path
from typing import Any, List, Union
from typing import Any, List
import yaml
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
@@ -114,21 +113,29 @@ def drop_long_rl_seq(
return (len_prompt + len_completion) <= sequence_len
if rl == "grpo":
return True
raise ValueError("Unknown RL type")
def load_prepare_preference_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
use_auth_token = _cfg.hf_use_auth_token
for config_dataset in datasets_w_name_generator(dataset_cfgs):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token, streaming=False
)
split_datasets.append(ds)
for i, ds_cfg in enumerate(dataset_cfgs):
if ds_cfg["ds_type"] == "json":
for data_file in ds_cfg["data_files"]:
data_files = {ds_cfg["split"]: data_file}
ds = load_dataset( # pylint: disable=invalid-name
"json",
data_files=data_files,
split=ds_cfg["split"],
)
split_datasets.insert(i, ds)
else:
ds = load_dataset( # pylint: disable=invalid-name
ds_cfg["path"],
split=ds_cfg["split"],
revision=ds_cfg.get("revision", None),
)
split_datasets.insert(i, ds)
tokenizer = load_tokenizer(cfg)
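For reference, the two dataset-config shapes handled by load_split above look roughly like this (paths and file names are placeholders):

    json_ds = DictDefault(
        {"ds_type": "json", "data_files": ["dpo_pairs.jsonl"], "split": "train"}
    )
    hub_ds = DictDefault(
        {"path": "org-name/preference-dataset", "split": "train"}
    )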

View File

@@ -43,7 +43,7 @@ from axolotl.prompters import (
UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
drop_long_seq_in_dataset,
@@ -180,7 +180,6 @@ def load_tokenized_prepared_datasets(
) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = cfg.tokenizer_config
ds_hash = str(
md5(
(
@@ -264,11 +263,30 @@ def load_tokenized_prepared_datasets(
datasets = []
def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
# load_dataset doesn't properly handle multiple named configurations
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
elif dataset.preprocess_shards and not dataset.shards:
for shard in range(dataset.preprocess_shards):
yield DictDefault(
{
**dataset,
"shards": dataset.preprocess_shards,
"shards_idx": shard,
}
)
else:
yield dataset
streaming_ds = False
if preprocess_iterable:
streaming_ds = True
# pylint: disable=invalid-name
for config_dataset in datasets_w_name_generator(cfg_datasets):
for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token, streaming=streaming_ds
)

View File

@@ -1,7 +1,6 @@
"""
dataset loading shared utils
"""
from pathlib import Path
from typing import Optional, Union
@@ -30,43 +29,9 @@ def get_ds_type(config_dataset: DictDefault):
return ds_type
def datasets_w_name_generator(dataset_configs: list[DictDefault]):
"""
Yields dataset configs handling multiple names or preprocess_shards
Args:
dataset_configs: list of dataset configs (equivalent to cfg.datasets)
"""
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
# load_dataset doesn't properly handle multiple named configurations
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
elif dataset.preprocess_shards and not dataset.shards:
for shard in range(dataset.preprocess_shards):
yield DictDefault(
{
**dataset,
"shards": dataset.preprocess_shards,
"shards_idx": shard,
}
)
else:
yield dataset
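A rough sketch of how the generator above expands a multi-name config (the path is a placeholder):

    cfgs = [DictDefault({"path": "org-name/dataset", "name": ["config_a", "config_b"]})]
    expanded = list(datasets_w_name_generator(cfgs))
    # yields two configs, one per name, each otherwise identical to the original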
def load_dataset_w_config(
config_dataset: DictDefault, use_auth_token: bool, streaming=False
config_dataset, auth_token, streaming=False
) -> Union[Dataset, DatasetDict]:
"""
Load a dataset from a config
Args:
config_dataset: single dataset config
use_auth_token: whether to use HF auth token
streaming: whether to stream the dataset
"""
# pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False
@@ -78,7 +43,7 @@ def load_dataset_w_config(
config_dataset.path,
name=config_dataset.name,
streaming=True,
token=use_auth_token,
token=auth_token,
revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code,
)
@@ -196,7 +161,7 @@ def load_dataset_w_config(
name=config_dataset.name,
streaming=streaming,
data_files=config_dataset.data_files,
token=use_auth_token,
token=auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,

View File

@@ -13,26 +13,3 @@ class DictDefault(Dict):
def __or__(self, other):
return DictDefault(super().__ror__(other))
def __setitem__(self, name, value):
# workaround for pickle/unpickle issues and __frozen not being available
try:
isFrozen = hasattr( # pylint: disable=invalid-name
self, "__frozen"
) and object.__getattribute__(self, "__frozen")
except AttributeError:
isFrozen = False # pylint: disable=invalid-name
if isFrozen and name not in super().keys():
raise KeyError(name)
super(Dict, self).__setitem__(name, value) # pylint: disable=bad-super-call
try:
p = object.__getattribute__(self, "__parent")
key = object.__getattribute__(self, "__key")
except AttributeError:
p = None
key = None
if p is not None:
p[key] = self
object.__delattr__(self, "__parent")
object.__delattr__(self, "__key")

View File

@@ -1,75 +0,0 @@
# Copyright 2025 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
module to get the state dict of a merged lora model
"""
import torch
from peft.tuners.tuners_utils import onload_layer
from peft.utils import ModulesToSaveWrapper, _get_submodules
def get_lora_merged_state_dict(
model: torch.nn.Module,
) -> dict:
r"""
Create and return a state_dict that has the LoRA deltas
merged into the base model's weights, without modifying `model` in place.
Arguments:
model (torch.nn.Module): A model that has LoRA/PEFT adapters attached.
Returns:
dict: A state_dict of the merged parameters.
"""
base_model_prefix = "base_model.model."
state_dict = {}
key_list = [key for key, _ in model.named_modules() if model.prefix not in key]
for key in key_list:
try:
_, target, _ = _get_submodules(model, key)
except AttributeError:
continue
with onload_layer(target):
weight_key = key.replace(base_model_prefix, "") + ".weight"
bias_key = key.replace(base_model_prefix, "") + ".bias"
if hasattr(target, "base_layer"):
target.merge(safe_merge=True, adapter_names=None)
# get the state_dict of target.base_layer
layer_state_dict = target.base_layer.state_dict()
state_dict[weight_key] = layer_state_dict["weight"]
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
new_module = target.modules_to_save[target.active_adapter]
if hasattr(new_module, "base_layer"):
# check if the module is itself a tuner layer
new_module.merge(safe_merge=True, adapter_names=None)
layer_state_dict = new_module.state_dict()
state_dict[weight_key] = layer_state_dict["weight"]
elif hasattr(target, "weight"):
if any(
skip in key
for skip in [
".original_module",
".modules_to_save",
".base_layer",
]
):
continue
layer_state_dict = target.state_dict()
state_dict[weight_key] = layer_state_dict["weight"]
if hasattr(target, "bias") and "bias" in layer_state_dict.keys():
state_dict[bias_key] = layer_state_dict["bias"]
return state_dict
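A usage sketch (the save step is an assumption, not part of this module):

    # model: a PEFT model with LoRA adapters attached, as described above
    merged_state_dict = get_lora_merged_state_dict(model)
    torch.save(merged_state_dict, "merged_state_dict.pt")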

View File

@@ -357,8 +357,8 @@ class ModelLoader:
# init model kwargs
self.model_kwargs: Dict[str, Any] = {}
if cfg.overrides_of_model_kwargs:
for key, val in cfg.overrides_of_model_kwargs.items():
if cfg.model_kwargs:
for key, val in cfg.model_kwargs.items():
self.model_kwargs[key] = val
# init model
@@ -414,7 +414,6 @@ class ModelLoader:
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False:
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
has_remote_code = self.cfg.trust_remote_code
@@ -426,6 +425,10 @@ class ModelLoader:
if self.cfg.is_llama_derived_model:
self.patch_loss_llama()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
elif self.cfg.is_llama_derived_model:
self.patch_llama_derived_model()
@@ -439,11 +442,6 @@ class ModelLoader:
patch_mistral_cross_entropy()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.cfg)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -474,7 +472,9 @@ class ModelLoader:
return importlib.util.find_spec("flash_attn") is not None
def patch_loss_llama(self) -> None:
"""Patch loss functions and other optimizations"""
"""
Patch loss functions
"""
if self.has_flash_attn:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_fa_llama_cross_entropy,
@@ -494,14 +494,15 @@ class ModelLoader:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
def patch_llama_derived_model(self) -> None:
"""Modify all llama derived models in one block"""
"""
Modify all llama derived models in one block
"""
self.patch_loss_llama()
if self.cfg.flash_attention:
@@ -1012,8 +1013,7 @@ class ModelLoader:
if hasattr(module, "weight"):
module.to(dist_dtype)
# TODO: Deprecate this.
def apply_unsloth_lora_patch(self) -> None:
def apply_lora_patch(self) -> None:
if self.cfg.unsloth_lora_mlp:
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
@@ -1027,16 +1027,6 @@ class ModelLoader:
integrate_rope_embeddings()
def apply_lora_patch(self) -> None:
if (
self.cfg.lora_mlp_kernel
or self.cfg.lora_qkv_kernel
or self.cfg.lora_o_kernel
):
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
apply_lora_kernel_patches(self.model, self.cfg)
def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
self.apply_patches()
self.set_auto_model_loader()
@@ -1181,7 +1171,6 @@ class ModelLoader:
if self.cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
self.apply_unsloth_lora_patch()
self.apply_lora_patch()
for _ in range(3):
@@ -1323,7 +1312,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
LOG.info("Initializing LoRA weights using dora. This might take longer.")
if cfg.peft_use_rslora:
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
if cfg.peft_layer_replication:

View File

@@ -5,12 +5,12 @@ import numpy as np
def get_dataset_lengths(dataset):
if "length" in dataset.column_names:
lengths = np.array(dataset["length"])
elif "position_ids" in dataset.column_names:
position_ids = dataset["position_ids"]
if "length" in dataset.data.column_names:
lengths = np.array(dataset.data.column("length"))
elif "position_ids" in dataset.data.column_names:
position_ids = dataset.data.column("position_ids")
lengths = np.array([x[-1] + 1 for x in position_ids])
else:
input_ids = dataset["input_ids"]
lengths = np.array([len(seq) for seq in input_ids])
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
return lengths
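To make the position_ids branch concrete (a toy sketch; each row holds one sample whose position ids run 0..n-1, so the last id plus one is its length):

    import numpy as np
    position_ids = [[0, 1, 2, 3], [0, 1]]
    lengths = np.array([x[-1] + 1 for x in position_ids])
    # lengths -> array([4, 2]), matching len() of each row without re-tokenizing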

View File

@@ -396,8 +396,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
):
total_num_tokens = np.sum(
train_dataset.select_columns("input_ids")
.to_pandas()["input_ids"]
.apply(len)
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True)
@@ -576,7 +576,7 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
if cfg.rl:
if cfg.rl in ("dpo", "grpo", "ipo", "orpo", "kto", "simpo"):
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]

View File

@@ -9,7 +9,7 @@ from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
@@ -79,7 +79,6 @@ class TestKnowledgeDistillation:
def test_llama_kd(self, temp_dir, kd_min_cfg):
cfg = DictDefault(kd_min_cfg)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -110,7 +109,6 @@ class TestKnowledgeDistillation:
| kd_min_cfg
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()

View File

@@ -1,76 +0,0 @@
"""Tests for GEGLU activation function Triton kernels."""
# pylint: disable=duplicate-code
import torch
import torch.nn.functional as F
from axolotl.kernels.geglu import geglu_backward, geglu_forward
def test_geglu_forward_shape():
"""Test that GEGLU forward pass preserves expected shapes."""
batch, seq_len, hidden_dim = 2, 3, 64
gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
up = torch.randn(batch, seq_len, hidden_dim, device="cuda")
out = geglu_forward(gate, up)
assert out.shape == (batch, seq_len, hidden_dim)
assert out.dtype == gate.dtype
assert out.device == gate.device
def test_geglu_forward_values():
"""Test GEGLU forward pass matches PyTorch reference implementation."""
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
# Custom implementation
triton_out = geglu_forward(gate.clone(), up.clone())
# PyTorch reference
torch_out = F.gelu(gate) * up
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
def test_geglu_backward():
"""Test GEGLU backward pass matches PyTorch autograd."""
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
grad_output = torch.randn(2, 3, 64, device="cuda")
# PyTorch reference - compute intermediates
gelu_gate = F.gelu(gate)
torch_out = gelu_gate * up
torch_out.backward(grad_output)
# Custom backward pass
gate_clone = gate.clone().detach()
up_clone = up.clone().detach()
grad_output_clone = grad_output.clone()
h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone)
# Compare outputs and gradients
assert torch.allclose(h, torch_out, rtol=1e-3)
assert torch.allclose(grad_gate, gate.grad, rtol=1e-3)
assert torch.allclose(grad_up, up.grad, rtol=1e-3)
def test_geglu_inplace_preservation():
"""Test that GEGLU backward doesn't modify original tensors unexpectedly."""
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
grad_output = torch.randn(2, 3, 64, device="cuda")
gate_copy = gate.clone()
up_copy = up.clone()
grad_copy = grad_output.clone()
geglu_backward(grad_output, gate, up)
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
assert not torch.equal(up, up_copy), "Up should be modified in-place"
assert not torch.equal(
grad_output, grad_copy
), "Grad output should be modified in-place"

View File

@@ -1,531 +0,0 @@
"""Tests for LoRA custom autograd."""
# pylint: disable=invalid-name,redefined-outer-name
import pytest
import torch
from bitsandbytes.functional import QuantState
from torch import nn
from axolotl.kernels.geglu import geglu_backward, geglu_forward
from axolotl.kernels.lora import (
LoRA_MLP,
LoRA_O,
LoRA_QKV,
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
get_lora_parameters,
matmul_lora,
)
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward
@pytest.fixture
def mock_quantstate():
"""Creates a mock QuantState for testing"""
shape = (64, 64)
n_blocks = shape[0] # Assuming blockwise quantization along first dimension
# Create nested state first
nested_state = QuantState(
absmax=torch.ones(n_blocks, device="cuda"), # One value per block
shape=shape,
code=torch.randint(0, 15, shape, device="cuda"), # NF4 range is 0-15
dtype=torch.float16,
blocksize=64,
quant_type="nf4",
offset=None,
state2=None,
)
# Create main state with nested state
return QuantState(
absmax=torch.ones(n_blocks, device="cuda"),
shape=shape,
code=torch.randint(0, 15, shape, device="cuda"),
dtype=torch.float16,
blocksize=64,
quant_type="nf4",
offset=torch.zeros(n_blocks, dtype=torch.int32, device="cuda"),
state2=nested_state,
)
@pytest.fixture
def sample_tensors():
"""Creates sample tensors for testing"""
torch.manual_seed(42)
batch_size, seq_len, hidden_dim = 2, 3, 64
rank = 8
out_dim = hidden_dim
return {
"X": torch.randn(
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
),
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
"scale": 0.5,
"shapes": {
"batch": batch_size,
"seq": seq_len,
"hidden": hidden_dim,
"out": out_dim,
"rank": rank,
},
}
@pytest.fixture
def mock_proj():
"""Creates a mock projection module for testing."""
class MockProj(nn.Module):
"""Mock projection class."""
def __init__(self, in_features=64, out_features=128, rank=8):
super().__init__()
self.base_layer = nn.Linear(in_features, out_features)
self.base_layer.to("cuda")
self.lora_A = nn.ModuleDict(
{"default": nn.Linear(in_features, rank, bias=False).to("cuda")}
)
self.lora_B = nn.ModuleDict(
{"default": nn.Linear(rank, out_features, bias=False).to("cuda")}
)
self.scaling = {"default": 0.5}
self.active_adapter = "default"
self.disable_adapters = False
self.merged = False
return MockProj()
def test_get_lora_parameters(mock_proj):
"""Tests get_lora_parameters function"""
# Test with LoRA enabled
W, _, A, B, s = get_lora_parameters(mock_proj)
assert isinstance(W, torch.Tensor)
assert W.shape == (128, 64)
assert A.shape == (8, 64)
assert B.shape == (128, 8)
assert s == 0.5
# Test with LoRA disabled
mock_proj.disable_adapters = True
W, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
# Test with merged state
mock_proj.disable_adapters = False
mock_proj.merged = True
W, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
def test_matmul_lora(sample_tensors):
"""Tests matmul_lora function"""
X = sample_tensors["X"]
W = sample_tensors["W"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test base matmul
out1 = matmul_lora(X, W, None, None, None, None)
expected1 = torch.matmul(X, W.t())
assert torch.allclose(out1, expected1, rtol=1e-3)
# Test with LoRA
out2 = matmul_lora(X, W, None, A, B, scale)
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
expected2 = expected1 + lora_term
assert torch.allclose(out2, expected2, rtol=1e-3)
# Test 3D input reshaping
X_3d = X.clone()
out3 = matmul_lora(X_3d, W, None, A, B, scale)
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
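In equation form, the reference computation checked above is (s being the LoRA scale):

    Y = X W^{\top} + s (X A^{\top}) B^{\top}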
@pytest.mark.parametrize(
"activation_forward,activation_backward",
[(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
)
def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward):
"""Tests LoRA_MLP directly with different activation functions"""
X = sample_tensors["X"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
# Create linear layers
gate_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
up_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
down_proj = nn.Linear(out_dim, hidden_dim).to(device="cuda", dtype=torch.float16)
# Test SwiGLU path
X.requires_grad = True
output = LoRA_MLP.apply(
X,
gate_proj.weight,
None, # gate_quant
None, # gate_A
None, # gate_B
None, # gate_scale
up_proj.weight,
None, # up_quant
None, # up_A
None, # up_B
None, # up_scale
down_proj.weight,
None, # down_quant
None, # down_A
None, # down_B
None, # down_scale
activation_forward,
activation_backward,
True, # inplace
)
assert output.shape == X.shape
assert not torch.isnan(output).any()
# Test backward pass
loss = output.sum()
loss.backward()
assert X.grad is not None
assert not torch.isnan(X.grad).any()
@pytest.mark.parametrize(
"activation_forward,activation_backward",
[(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
)
def test_lora_mlp_with_adapters(
sample_tensors, activation_forward, activation_backward
):
"""Tests LoRA_MLP with LoRA adapters"""
X = sample_tensors["X"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
# Create LoRA components
gate_A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
gate_B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
up_A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
up_B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
down_A = torch.randn(rank, out_dim, device="cuda", dtype=torch.float16)
down_B = torch.randn(hidden_dim, rank, device="cuda", dtype=torch.float16)
scale = 0.5
gate_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
up_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
down_proj = nn.Linear(out_dim, hidden_dim).to(device="cuda", dtype=torch.float16)
X.requires_grad = True
gate_A.requires_grad = True
gate_B.requires_grad = True
up_A.requires_grad = True
up_B.requires_grad = True
down_A.requires_grad = True
down_B.requires_grad = True
# Forward pass with adapters
output = LoRA_MLP.apply(
X,
gate_proj.weight,
None,
gate_A,
gate_B,
scale,
up_proj.weight,
None,
up_A,
up_B,
scale,
down_proj.weight,
None,
down_A,
down_B,
scale,
activation_forward,
activation_backward,
True,
)
assert output.shape == X.shape
assert not torch.isnan(output).any()
# Test backward pass
loss = output.sum()
loss.backward()
# Check all gradients
assert X.grad is not None
assert gate_A.grad is not None
assert gate_B.grad is not None
assert up_A.grad is not None
assert up_B.grad is not None
assert down_A.grad is not None
assert down_B.grad is not None
assert not torch.isnan(X.grad).any()
assert not torch.isnan(gate_A.grad).any()
assert not torch.isnan(gate_B.grad).any()
assert not torch.isnan(up_A.grad).any()
assert not torch.isnan(up_B.grad).any()
assert not torch.isnan(down_A.grad).any()
assert not torch.isnan(down_B.grad).any()
def test_lora_qkv(sample_tensors):
"""Tests LoRA QKV implementation with and without adapters"""
X = sample_tensors["X"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
rank = shapes["rank"]
# Create base weights
q_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
k_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
v_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
# Create LoRA matrices
q_A = torch.randn(
rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
)
q_B = torch.randn(
hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
)
k_A = torch.randn(
rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
)
k_B = torch.randn(
hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
)
v_A = torch.randn(
rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
)
v_B = torch.randn(
hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
)
scale = 0.5
X.requires_grad = True
# Test without LoRA adapters
Q1, K1, V1 = LoRA_QKV.apply(
X,
q_weight,
None,
None,
None,
None,
k_weight,
None,
None,
None,
None,
v_weight,
None,
None,
None,
None,
True,
)
assert Q1.shape == K1.shape == V1.shape == X.shape
loss1 = (Q1 + K1 + V1).sum()
loss1.backward()
assert X.grad is not None
# Clear gradients
X.grad = None
# Test with LoRA adapters
Q2, K2, V2 = LoRA_QKV.apply(
X,
q_weight,
None,
q_A,
q_B,
scale,
k_weight,
None,
k_A,
k_B,
scale,
v_weight,
None,
v_A,
v_B,
scale,
True,
)
assert Q2.shape == K2.shape == V2.shape == X.shape
loss2 = (Q2 + K2 + V2).sum()
loss2.backward()
# Check gradients
assert X.grad is not None
assert q_A.grad is not None
assert q_B.grad is not None
assert k_A.grad is not None
assert k_B.grad is not None
assert v_A.grad is not None
assert v_B.grad is not None
# Check for NaN values
assert not torch.isnan(X.grad).any()
assert not torch.isnan(q_A.grad).any()
assert not torch.isnan(q_B.grad).any()
assert not torch.isnan(k_A.grad).any()
assert not torch.isnan(k_B.grad).any()
assert not torch.isnan(v_A.grad).any()
assert not torch.isnan(v_B.grad).any()
def test_lora_o(sample_tensors):
"""Tests LoRA output projection"""
X = sample_tensors["X"]
W = sample_tensors["W"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test forward pass
X.requires_grad = True
output = LoRA_O.apply(X, W, None, A, B, scale)
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
# Test backward pass
loss = output.sum()
loss.backward()
assert X.grad is not None
def test_with_quantization(sample_tensors, mock_quantstate):
"""Tests LoRA with quantized weights"""
X = sample_tensors["X"] # [batch, seq, hidden]
W = sample_tensors["W"] # [out, hidden]
scale = 0.5
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test matmul with quantization
out = matmul_lora(X, W, mock_quantstate, A, B, scale)
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
assert not torch.isnan(out).any()
# Test with different batch sizes
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
assert out2.shape == (4, 6, W.shape[0])
assert not torch.isnan(out2).any()
@pytest.mark.parametrize(
"batch,seq,hidden,rank,out",
[
(1, 1, 32, 4, 64),
(2, 3, 64, 8, 128),
(4, 5, 128, 16, 256),
],
)
def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
"""Tests various input shapes and dimensions"""
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
scale = 0.5
result = matmul_lora(X, W, None, A, B, scale)
assert result.shape == (batch, seq, out)
def test_gradient_flow(sample_tensors):
"""Tests gradient flow through LoRA layers"""
X = sample_tensors["X"].clone()
W = sample_tensors["W"].clone()
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
hidden_dim = shapes["hidden"]
out_dim = shapes["out"]
rank = shapes["rank"]
A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
X.requires_grad = True
A.requires_grad = True
B.requires_grad = True
# Forward pass
out = matmul_lora(X, W, None, A, B, scale)
loss = out.sum()
# Backward pass
loss.backward()
assert X.grad is not None
assert A.grad is not None
assert B.grad is not None
assert not torch.isnan(X.grad).any()
assert not torch.isnan(A.grad).any()
assert not torch.isnan(B.grad).any()
@pytest.mark.parametrize(
"apply_function",
[apply_lora_mlp_swiglu, apply_lora_mlp_geglu],
)
def test_inplace_operations(sample_tensors, apply_function):
"""Tests inplace operation behavior"""
X = sample_tensors["X"]
shapes = sample_tensors["shapes"]
# Create MLP with both inplace=True and inplace=False
mlp = type(
"MLPModule",
(),
{
"gate_proj": nn.Linear(shapes["hidden"], shapes["out"]).to(
device="cuda", dtype=torch.float16
),
"up_proj": nn.Linear(shapes["hidden"], shapes["out"]).to(
device="cuda", dtype=torch.float16
),
"down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
device="cuda", dtype=torch.float16
),
},
)
out1 = apply_function(mlp, X.clone(), inplace=True)
out2 = apply_function(mlp, X.clone(), inplace=False)
assert torch.allclose(out1, out2, rtol=1e-3)

View File

@@ -1,103 +0,0 @@
"""Tests for quantization utility functions."""
# pylint: disable=invalid-name
import torch
from bitsandbytes.functional import QuantState
from axolotl.kernels.quantize import dequantize
def test_dequantize_null_state():
"""Test that dequantize returns input unchanged when quant_state is None"""
W = torch.randn(32, 32)
assert torch.equal(dequantize(W, None), W)
def test_dequantize_shape_preservation():
"""Test that dequantization preserves expected shapes"""
shape = (32, 32)
W = torch.randn(shape, device="cuda")
quant_state = QuantState(
absmax=torch.ones(shape[0], device="cuda"),
shape=shape,
code=torch.randint(0, 15, shape, device="cuda"),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=torch.zeros(shape[0], dtype=torch.int32, device="cuda"),
state2=QuantState(
absmax=torch.ones(shape[0], device="cuda"),
shape=shape,
code=torch.randint(0, 15, shape, device="cuda"),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=None,
state2=None,
),
)
result = dequantize(W, quant_state)
assert result.shape == shape
assert result.dtype == torch.float16
assert result.device == W.device
def test_dequantize_transposed():
"""Test that transposed input produces transposed output"""
shape = (32, 32)
W = torch.randn(1, shape[1], device="cuda") # Transposed input
quant_state = QuantState(
absmax=torch.ones(1),
shape=shape,
code=torch.randint(0, 15, shape),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=torch.zeros(1, dtype=torch.int32),
state2=QuantState(
absmax=torch.ones(1),
shape=shape,
code=torch.randint(0, 15, shape),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=None,
state2=None,
),
)
result = dequantize(W, quant_state)
assert result.shape[0] == shape[0]
def test_dequantize_output_tensor():
"""Test dequantization with provided output tensor"""
shape = (32, 32)
W = torch.randn(shape, device="cuda")
out = torch.empty(shape, dtype=torch.float16, device="cuda")
quant_state = QuantState(
absmax=torch.ones(shape[0]),
shape=shape,
code=torch.randint(0, 15, shape),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=torch.zeros(shape[0], dtype=torch.int32),
state2=QuantState(
absmax=torch.ones(shape[0]),
shape=shape,
code=torch.randint(0, 15, shape),
dtype=torch.float16,
blocksize=32,
quant_type="nf4",
offset=None,
state2=None,
),
)
result = dequantize(W, quant_state, out=out)
assert result is out

View File

@@ -1,78 +0,0 @@
"""Tests for SwiGLU activation function Triton kernels."""
# pylint: disable=duplicate-code
import torch
import torch.nn.functional as F
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward
def test_swiglu_forward_shape():
"""Test that SwiGLU forward pass preserves expected shapes"""
batch, seq_len, hidden_dim = 2, 3, 64
gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
up = torch.randn(batch, seq_len, hidden_dim, device="cuda")
out = swiglu_forward(gate, up)
assert out.shape == (batch, seq_len, hidden_dim)
assert out.dtype == gate.dtype
assert out.device == gate.device
def test_swiglu_forward_values():
"""Test SwiGLU forward pass matches PyTorch reference implementation"""
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
# Custom implementation
triton_out = swiglu_forward(gate.clone(), up.clone())
# PyTorch reference
torch_out = F.silu(gate) * up
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
def test_swiglu_backward():
"""Test SwiGLU backward pass matches PyTorch autograd"""
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
grad_output = torch.randn(2, 3, 64, device="cuda")
# PyTorch reference - compute intermediates
silu_gate = F.silu(gate)
torch_out = silu_gate * up
torch_out.backward(grad_output)
# Custom backward pass
gate_clone = gate.clone().detach()
up_clone = up.clone().detach()
grad_output_clone = grad_output.clone()
h, our_grad_gate, our_grad_up = swiglu_backward(
grad_output_clone, gate_clone, up_clone
)
# Compare outputs and gradients
assert torch.allclose(h, torch_out, rtol=1e-3)
assert torch.allclose(our_grad_gate, gate.grad, rtol=1e-3)
assert torch.allclose(our_grad_up, up.grad, rtol=1e-3)
def test_swiglu_inplace_preservation():
"""Test that SwiGLU backward doesn't modify original tensors unexpectedly"""
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
grad_output = torch.randn(2, 3, 64, device="cuda")
gate_copy = gate.clone()
up_copy = up.clone()
grad_copy = grad_output.clone()
swiglu_backward(grad_output, gate, up)
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
assert not torch.equal(up, up_copy), "Up should be modified in-place"
assert not torch.equal(
grad_output, grad_copy
), "Grad output should be modified in-place"

View File

@@ -1,173 +0,0 @@
"""
GRPO test suite
"""
import random
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import require_vllm
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
class TestGRPO:
"""
Test case for GRPO training using multiple GPUs
"""
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_dora(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
"vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"peft_use_dora": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_fft(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
"vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)

View File

@@ -9,7 +9,7 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard, require_torch_lt_2_6_0
from e2e.utils import check_tensorboard
from axolotl.utils.dict import DictDefault
@@ -24,7 +24,6 @@ class TestMultiGPURay:
Test cases for AnyScale Ray post training
"""
@require_torch_lt_2_6_0
def test_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -81,7 +80,6 @@ class TestMultiGPURay:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@require_torch_lt_2_6_0
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],

View File

@@ -1,414 +0,0 @@
"""Integration tests for LoRA activation and attention kernels."""
# pylint: disable=redefined-outer-name
import pytest
import torch
from accelerate.state import PartialState
from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
apply_lora_qkv,
)
from axolotl.monkeypatch.lora_kernels import (
apply_lora_kernel_patches,
patch_self_attn_lora,
)
from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [
{
"name": "openaccess-ai-collective/tiny-mistral",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
{
"name": "Qwen/Qwen2-7B",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
{
"name": "HuggingFaceTB/SmolLM2-135M",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float32,
},
{
"name": "mhenrichsen/gemma-2b",
"expected_activation": apply_lora_mlp_geglu,
"dtype": torch.float16,
},
]
@pytest.fixture(autouse=True)
def init_accelerate():
"""Initialize Accelerate state before tests."""
_ = PartialState()
@pytest.fixture
def small_llama_model():
"""Create a small LLaMA model for testing."""
config = {
"vocab_size": 100,
"hidden_size": 128,
"intermediate_size": 256,
"num_hidden_layers": 2,
"num_attention_heads": 4,
}
return LlamaForCausalLM(LlamaConfig(**config))
def test_attention_patching_integration():
"""Test attention patching in integration context."""
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
# Store the original implementation
original_forward = getattr(LlamaAttention, "forward")
# Apply patch
patch_self_attn_lora(cfg)
# Get the new forward method
patched_forward = LlamaAttention.forward
# Check the forward method was replaced
assert original_forward is not patched_forward
assert patched_forward.__name__ == "axolotl_attn_forward"
# Check original implementation was stored
assert hasattr(LlamaAttention, "_original_forward")
# Clean up
setattr(LlamaAttention, "forward", original_forward)
delattr(LlamaAttention, "_original_forward")
def test_swiglu_mlp_integration(small_llama_model):
"""Test SwiGLU activation in LoRA MLP context."""
peft_config = get_peft_config(
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "none",
}
)
model = PeftModelForCausalLM(small_llama_model, peft_config).to("cuda")
cfg = DictDefault({"lora_mlp_kernel": True})
# Apply patches
patched_model = apply_lora_kernel_patches(model, cfg)
# Verify patches
layer = patched_model.model.model.layers[0]
assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
# Test forward pass
batch_size, seq_len = 2, 10
hidden_states = torch.randn(
batch_size, seq_len, model.config.hidden_size, device=model.device
)
position_ids = (
torch.arange(seq_len, device=model.device).unsqueeze(0).expand(batch_size, -1)
)
cos, sin = model.model.model.rotary_emb(hidden_states, position_ids)
inputs = {
"hidden_states": hidden_states,
"attention_mask": None,
"position_embeddings": (cos, sin),
"output_attentions": False,
"use_cache": False,
"past_key_value": None,
}
# Compare outputs
with torch.no_grad():
original_output = model.model.model.layers[0](**inputs)[0]
patched_output = layer(**inputs)[0]
assert torch.allclose(original_output, patched_output, rtol=1e-4)
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
)
peft_config = get_peft_config(
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "none",
}
)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
# Verify patches
layer = patched_model.model.model.layers[0]
assert layer.mlp.forward.__func__ is apply_lora_mlp_geglu
# Test end-to-end
inputs = torch.randint(0, 100, (1, 20), device=model.device, dtype=torch.long)
with torch.no_grad():
original_output = model(inputs).logits
patched_output = patched_model(inputs).logits
assert torch.allclose(original_output, patched_output, rtol=1e-4)
@pytest.mark.parametrize(
"model_name,expected_activation",
[
("HuggingFaceTB/SmolLM2-135M", apply_lora_mlp_swiglu),
("mhenrichsen/gemma-2b", apply_lora_mlp_geglu),
],
)
def test_model_specific_activation(model_name, expected_activation):
"""Test that each model type gets the correct activation function."""
model = AutoModelForCausalLM.from_pretrained(model_name)
peft_config = get_peft_config(
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "none",
}
)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0]
assert layer.mlp.forward.__func__ is expected_activation
def test_kernel_patch_conditions():
"""Test various conditions that should prevent kernel patching."""
test_configs = [
# Dropout prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0.1,
"bias": "none",
},
# Bias prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "lora_only",
},
]
for config in test_configs:
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
# Should not patch
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify no patches applied
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu


def test_kernel_config_options():
"""Test that kernel configuration options are respected."""
# Test different configurations
test_configs = [
(
{"lora_mlp_kernel": True, "lora_qkv_kernel": False, "lora_o_kernel": False},
lambda layer: (
layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
and layer.self_attn.apply_qkv.__func__ is not apply_lora_qkv
and layer.self_attn.apply_o.__func__ is not apply_lora_o
),
),
(
{"lora_mlp_kernel": False, "lora_qkv_kernel": True, "lora_o_kernel": False},
lambda layer: (
layer.mlp.forward.__func__ is not apply_lora_mlp_swiglu
and layer.self_attn.apply_qkv.__func__ is apply_lora_qkv
and layer.self_attn.apply_o.__func__ is not apply_lora_o
),
),
(
{"lora_mlp_kernel": False, "lora_qkv_kernel": False, "lora_o_kernel": True},
lambda layer: (
layer.mlp.forward.__func__ is not apply_lora_mlp_swiglu
and layer.self_attn.apply_qkv.__func__ is not apply_lora_qkv
and layer.self_attn.apply_o.__func__ is apply_lora_o
),
),
]
for config_dict, check_fn in test_configs:
# Create fresh model for each test
config = {
"vocab_size": 100,
"hidden_size": 128,
"intermediate_size": 256,
"num_hidden_layers": 2,
"num_attention_heads": 4,
}
small_llama_model = LlamaForCausalLM(LlamaConfig(**config))
peft_config = get_peft_config(
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": [
"gate_proj",
"up_proj",
"down_proj",
"q_proj",
"k_proj",
"v_proj",
"o_proj",
],
"lora_dropout": 0,
"bias": "none",
}
)
model = PeftModelForCausalLM(small_llama_model, peft_config).to("cuda")
cfg = DictDefault(config_dict)
patched_model = apply_lora_kernel_patches(model, cfg)
# Verify only requested optimizations were applied
for layer in patched_model.model.model.layers:
assert check_fn(layer), f"Failed for config: {config_dict}"
# Clean up
del model
del small_llama_model
del patched_model


def get_lora_config():
"""Get standard LoRA configuration for testing."""
return {
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "none",
}


def get_test_inputs(model, seq_length=20):
"""Generate test inputs for model evaluation."""
return torch.randint(
0,
model.config.vocab_size,
(1, seq_length),
device=model.device,
dtype=torch.long,
)


@pytest.mark.parametrize("model_config", MODEL_CONFIGS)
def test_model_architecture(model_config):
"""Test LoRA kernel patches across different model architectures."""
# Load model with appropriate dtype
model = AutoModelForCausalLM.from_pretrained(
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
)
# Apply LoRA configuration
peft_config = get_peft_config(get_lora_config())
model = PeftModelForCausalLM(model, peft_config)
# Apply kernel patches
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
# Verify correct activation function
layer = patched_model.model.model.layers[0]
assert (
layer.mlp.forward.__func__ is model_config["expected_activation"]
), f"Wrong activation for {model_config['name']}"
# Test forward pass
inputs = get_test_inputs(model)
with torch.no_grad():
original_output = model(inputs).logits
patched_output = patched_model(inputs).logits
# Check outputs match
assert torch.allclose(
original_output, patched_output, rtol=1e-4
), f"Outputs don't match for {model_config['name']}"


# pylint: disable=duplicate-code
def test_kernel_training_integration():
"""Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer
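    # Exercise the full CLI loading path with all three kernel flags enabled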
# Create minimal config
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"sequence_len": 1024,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
# Load model
model, _ = load_model_and_tokenizer(cfg=cfg)
# Verify correct activation function
layer = model.model.model.layers[0]
assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
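

# A minimal sketch (assumed equivalent, not part of this test file) of the YAML needed to
# enable the fused LoRA kernels in an axolotl config, mirroring the DictDefault above:
#
#   adapter: lora
#   lora_r: 8
#   lora_alpha: 16
#   lora_dropout: 0.0       # dropout must stay 0 for the kernel patches to apply
#   lora_target_linear: true
#   lora_mlp_kernel: true
#   lora_qkv_kernel: true
#   lora_o_kernel: true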

View File

@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
@@ -76,9 +76,7 @@ class TestFAXentropyLlama:
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -10,7 +10,7 @@ from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -73,8 +73,6 @@ class TestReLoraLlama(unittest.TestCase):
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_preference_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -63,8 +63,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -110,8 +108,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -157,8 +153,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -204,8 +198,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -250,8 +242,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -299,8 +289,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -365,8 +353,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -56,8 +56,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -65,8 +65,6 @@ class TestFalcon(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -120,8 +118,6 @@ class TestFalcon(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -161,8 +157,6 @@ class TestFalcon(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -10,7 +10,7 @@ from e2e.utils import check_model_output_exists
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
@@ -56,8 +56,6 @@ class TestLlama:
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -101,8 +99,6 @@ class TestLlama:
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -142,8 +138,6 @@ class TestLlama:
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -10,7 +10,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard
@@ -69,8 +69,6 @@ class TestPretrainLlama:
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -62,8 +62,6 @@ class TestLlamaVision(unittest.TestCase):
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -59,8 +59,6 @@ class TestLoraLlama(unittest.TestCase):
"max_steps": 20,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -59,8 +59,6 @@ class TestMamba(unittest.TestCase):
"save_safetensors": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

Some files were not shown because too many files have changed in this diff