Compare commits: patch_lora...docs-lint- (1 commit, f24efd77a1)

.github/workflows/main.yml (2 changes)
@@ -24,7 +24,7 @@ jobs:
          cuda_version: 12.4.1
          python_version: "3.11"
          pytorch: 2.5.1
          axolotl_extras: vllm
          axolotl_extras:
          is_latest: true
        - cuda: 124
          cuda_version: 12.4.1
.github/workflows/multi-gpu-e2e.yml (5 changes)
@@ -24,21 +24,20 @@ jobs:
          cuda_version: 12.4.1
          python_version: "3.11"
          pytorch: 2.4.1
          axolotl_extras: # no vllm support for 2.4.1
          axolotl_extras:
          num_gpus: 2
          nightly_build: "true"
        - cuda: 124
          cuda_version: 12.4.1
          python_version: "3.11"
          pytorch: 2.5.1
          axolotl_extras: vllm
          axolotl_extras:
          num_gpus: 2
          nightly_build: "true"
        - cuda: 124
          cuda_version: 12.4.1
          python_version: "3.11"
          pytorch: 2.6.0
          # awaiting vllm#12721
          axolotl_extras:
          num_gpus: 2
          nightly_build: "true"
.github/workflows/tests.yml (2 changes)
@@ -204,7 +204,7 @@ jobs:
          python_version: "3.11"
          pytorch: 2.5.1
          num_gpus: 1
          axolotl_extras: vllm
          axolotl_extras:
    steps:
      - name: Checkout
        uses: actions/checkout@v4
@@ -4,8 +4,8 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"

pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels  # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
@@ -1,4 +1,6 @@
"""Modal app to run axolotl GPU tests"""
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code

import os
@@ -91,12 +91,7 @@ datasets:
    type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
    ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
    data_files: # Optional[str] path to source data files

    shards: # Optional[int] split dataset into N pieces (use with shards_idx)
    shards_idx: # Optional[int] = 0 the index of sharded dataset to use

    preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)

    shards: # Optional[int] number of shards to split data into
    name: # Optional[str] name of dataset configuration to load
    train_on_split: train # Optional[str] name of dataset split to load from
    revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
@@ -142,19 +137,10 @@ datasets:

    # Key containing the messages (default: "messages")
    field_messages: messages

    # Mapping of properties from the input dataset to the chat template.
    # (default: message_property_mappings={'role':'role', 'content':'content'})
    # If a property exists in the template but not in this mapping, the system will attempt
    # to load it directly from the message using the property name as the key.
    # Example: In the mapping below, 'from' is loaded from the input dataset and used as 'role',
    # while 'value' is loaded and used as 'content' in the chat template.
    message_property_mappings:
      role: from
      content: value
    # ...

    message_property_mappings:
    # Key for role in each message (default: "role")
    message_field_role: role
    # Key for content in each message (default: "content")
    message_field_content: content

    # Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
    roles:
@@ -314,13 +300,6 @@ lora_modules_to_save:

lora_fan_in_fan_out: false

# Apply custom LoRA autograd functions and activation function Triton kernels for
# speed and memory savings
# See: https://axolotl-ai-cloud.github.io/axolotl/docs/lora_optims.html
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

# LoRA+ hyperparameters
# For more details about the following options, see:
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
@@ -369,9 +348,6 @@ comet_mode: # Create a new experiment ("create") or log to an existing one ("get
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.

# Tensorboard
use_tensorboard: # Optional[bool]

# Where to save the full-finetuned model to
output_dir: ./completed-model

@@ -406,9 +382,6 @@ save_total_limit: # Checkpoints saved at a time
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
max_steps:

# bool of whether to include tokens trained per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second:

eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
@@ -6,7 +6,7 @@ order: 3

## sharegpt

IMPORTANT: ShareGPT is deprecated! Please see the [chat_template](#chat_template) section below.
IMPORTANT: ShareGPT is deprecated! Please see the `chat_template` section below.

## pygmalion

@@ -22,7 +22,7 @@ Chat Template strategy uses a jinja2 template that converts a list of messages i
{"conversations": [{"role": "...", "content": "..."}]}
```

See [configs](../config.qmd) for full configs and supported templates.
See `config.qmd` for full configs and supported templates.

### Migrating from sharegpt

@@ -42,9 +42,8 @@ datasets:
    type: chat_template

    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
    message_field_role: from
    message_field_content: value

# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
@@ -53,9 +52,8 @@ datasets:
    type: chat_template

    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
    message_field_role: from
    message_field_content: value
```

We recommend checking the examples below for other use cases.

@@ -140,9 +138,8 @@ datasets:
    type: chat_template
    chat_template: tokenizer_default
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
    message_field_role: from
    message_field_content: value
    roles_to_train: []
    train_on_eos: turn
    message_field_training: train

@@ -1,458 +1,14 @@

---
title: Dataset Formats
description: Guide to Dataset Formats in Axolotl
back-to-top-navigation: true
toc: true
toc-depth: 5
description: Supported dataset formats.
listing:
  fields: [title, description]
  type: table
  sort-ui: false
  filter-ui: false
  max-description-length: 250
---

Axolotl supports a variety of dataset formats. It is recommended to use a JSONL format. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a Hugging Face dataset with columns for each JSONL field.

Axolotl is a training framework that aims to make the process convenient yet flexible for users via a single YAML config file.

As Axolotl has a lot of available options, this guide aims to simplify choosing the right ones.

Axolotl supports 3 kinds of training methods: pre-training, supervised fine-tuning, and preference-based post-training (e.g. DPO, ORPO, PRMs). Each method has its own dataset formats, which are described below.

## [Pre-training](pretraining.qmd)

When training on large corpora of text, pre-training is the go-to choice. Due to the size of these datasets, downloading them in full before training begins would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to load only a batch at a time into memory.

A sample format for a pre-training dataset is as follows:

```json
{"text": "first row"}
{"text": "second row"}
...
```

It is typically recommended to save your dataset as `.jsonl` due to its flexibility and simplicity.

Axolotl supports loading from a Hugging Face Hub repo or from local files.

::: {.callout-important}
For pre-training only, Axolotl splits texts that exceed the context length into multiple smaller prompts.
:::

### Pre-training from Hugging Face Hub datasets

As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:

```yaml
pretraining_dataset: hf_org/name
```

### Pre-training from local dataset files

Given a few corpus files `A.jsonl`, `B.jsonl`, and `C.jsonl`, your config would look like this:

```yaml
pretraining_dataset:
  - path: json
    data_files:
      - A.jsonl
      - B.jsonl
      - C.jsonl
```

While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet`, `arrow`, `SQL`, `WebDataset`) supported by [`datasets.load_dataset`](https://huggingface.co/docs/datasets/loading#local-and-remote-files).

### Pre-training without streaming

In the rare case that the dataset is small and can be loaded entirely into memory, another approach to pre-training is to use the `completion` format. This means the entire dataset is tokenized up front instead of on demand during streaming.

One benefit of this is that the tokenization can be performed separately on a CPU-only machine and then transferred to a GPU machine for training, to save costs.

From Hugging Face:

```yaml
datasets:
  - path: hf_org/name
    type: completion
```

From local files (either example works):

```yaml
datasets:
  - path: A.jsonl
    type: completion

  - path: json
    data_files: ["A.jsonl", "B.jsonl", "C.jsonl"]
    type: completion
```

### Pre-training dataset configuration tips

#### Setting max_steps

When streaming large datasets, Axolotl does not know in advance how large the dataset is, so it does not know when to stop.

Therefore, it is necessary to set `max_steps: int` in your config for pre-training to run, so that Axolotl knows when to stop training.

One step is equal to `sequence_len * micro_batch_size * gradient_accumulation_steps * total_num_gpus` tokens.
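
For example (hypothetical numbers): with `sequence_len: 2048`, `micro_batch_size: 2`, and `gradient_accumulation_steps: 4` on 8 GPUs, one step covers 2048 × 2 × 4 × 8 = 131,072 tokens, so covering roughly 1B tokens would require `max_steps: 7630`.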

#### group_by_length

It is recommended to leave this off when streaming from the Hugging Face Hub, as it would require downloading the entire dataset, which can be very large.

## Supervised fine-tuning (SFT)

Supervised fine-tuning is the process of training models to respond to an instruction or chat input.

As there is a wide variety of dataset formats, Axolotl tries to support the majority of the formats found in public datasets.

Axolotl provides four approaches for loading datasets; it's easiest to work backwards from the dataset you have available to figure out which approach to use.

A flow chart is as follows:

1. Do you already have the dataset tokenized? If yes, check [Pre-Tokenized Dataset](#pre-tokenized-dataset).

2. Do you want to format the dataset yourself and manually choose each section to mask? If yes, check [Template Free Dataset](#template-free-dataset).

3. Is your dataset in a "conversation" format, containing a `list[messages]`? If yes, check [Conversation Dataset](#conversation-dataset).

4. Is your dataset in an "instruct" format, containing `{ instruction, response }`? If yes, check [Instruction Dataset](#instruction-dataset).

If you went through the flow chart and did not find a match, it is recommended to preprocess your dataset into one of the above or open a thread on GitHub Discussions.

::: {.callout-tip}
You can mix and match within each approach or across approaches to train a model on a variety of datasets.
:::

### [Pre-Tokenized Dataset](tokenized.qmd)

We suggest this approach when you want to bring your own tokenized dataset.

Axolotl expects the dataset to have three keys (a sample row is sketched below):
- `input_ids`: from tokenizing the formatted prompt
- `attention_mask`: for masking padding. If you don't add padding, it is equal to `len(input_ids) * [1]`
- `labels`: the same as `input_ids`; however, to mask certain tokens, set those indices to `-100`
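
As a minimal sketch (the token IDs are illustrative, not from a real tokenizer), one pre-tokenized row might look like:

```json
{"input_ids": [1, 9906, 1917, 2], "attention_mask": [1, 1, 1, 1], "labels": [-100, -100, 1917, 2]}
```

Here the first two positions (the BOS token and a prompt token) are set to `-100`, so loss is computed only on the last two tokens.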

::: {.callout-tip}
Make sure to add BOS/EOS tokens to your prompt and mask them appropriately.
:::

A config for this would look like:

```yaml
datasets:
  - path: A.jsonl
    type:
```

::: {.callout-note}
`type: ` is empty!
:::

### [Template Free Dataset](template_free.qmd)

We recommend this approach when you want granular control over the prompt formatting, special tokens, and masking, whilst letting Axolotl handle the tokenization. This is very useful if your dataset has unique prompts that differ across samples, where one single general template wouldn't suffice.

In the example below, you can see that there is no fixed structure. At the same time, it's very flexible, as there are no constraints on how your prompt can look.

```json
{
  "segments": [
    {
      "label": true,
      "text": "<s>Hello\n"
    },
    {
      "label": true,
      "text": "hi there!. "
    },
    {
      "label": false,
      "text": "goodbye "
    },
    {
      "label": true,
      "text": "farewell</s>"
    }
  ]
}
```

Each prompt must have a key called `segments`, which is a list of `{ text, label }` objects.

```yaml
datasets:
  - path: A.jsonl
    type: input_output
```

### [Conversation Dataset](conversation.qmd)

`conversation` messages are a list of messages, each of which usually contains a `role` and `content` key.

::: {.callout-tip}
Fun fact: Axolotl refers to "chat" messages as `conversation` messages synonymously, because FastChat initially used this term for its widely used [fastchat conversation](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) method for formatting chat messages, prior to the creation of `chat_templates`.
:::

#### What are `chat_templates`?

The currently most popular and convenient method of formatting prompts for inference is to use `chat_templates`. Axolotl supports using `chat_templates` for training as well, to ensure that the model sees the same format during training as at inference.

Here's a quick rundown: a `chat_template` is a Jinja2 template that formats a list of messages into a prompt.

An example of a prompt formatted with a popular template called ChatML can be seen below.

Single prompt (pretty-printed):
```json
{
  "messages": [
    {
      "role": "user",
      "content": "Hi"
    },
    {
      "role": "assistant",
      "content": "How can I help you?"
    },
    {
      "role": "user",
      "content": "Can you add 3+5?"
    },
    {
      "role": "assistant",
      "content": "The answer is 8."
    }
  ]
}
```

The ChatML template is as follows:
```jinja2
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}
```

Formatting the above prompt with this template results in:

```
<|im_start|>user
Hi<|im_end|>
<|im_start|>assistant
How can I help you?<|im_end|>
<|im_start|>user
Can you add 3+5?<|im_end|>
<|im_start|>assistant
The answer is 8.<|im_end|>
```

Through delimiters (`<|im_start|>` and `<|im_end|>`), the prompt separates the different speakers, which helps the model identify which portion belongs to whom.

#### Common Conversation Dataset formats

Older conversation datasets with the following format are colloquially called `sharegpt` datasets.

```json
{"conversations": [{"from": "...", "value": "..."}]}
```

Newer conversation datasets usually follow the OpenAI format.

```json
{"messages": [{"role": "...", "content": "..."}]}
```

Axolotl supports both, and also allows customizing the keys.

#### [Chat Template Usage](conversation.qmd#chat_template)

To properly use this method, it is important to identify three things:

1. Which `chat_template` will you use?

2. What are the keys in your dataset, and what are the possible roles? For example, in OpenAI format, the keys are `messages`, `role`, and `content`, and the possible roles are `system`, `user`, and `assistant`.

3. What do you want to train on (with everything else masked)? For instance, only assistant messages, only the last message, or nothing.

##### Choosing a `chat_template`

There are a lot of `chat_templates` out there. Axolotl supports the common ones: [supported chat templates](https://github.com/axolotl-ai-cloud/axolotl/blob/860609392184cf62a7e0ca676658b170e059ce6c/src/axolotl/utils/chat_templates.py#L17). For example, to use ChatML, set `chat_template: chatml`.

However, it is also possible to use the template already configured within the tokenizer by specifying `chat_template: tokenizer_default`. If you want a fallback (in case some tokenizer does not have one pre-configured), you can set `chat_template: tokenizer_default_fallback_chatml` to fall back to the ChatML template if no tokenizer template is found.

One last but powerful approach is to bring your own template. This can be set via:

```yaml
chat_template_jinja: # your template
```
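
As an illustrative sketch only (not one of the built-in templates), a minimal custom Jinja2 template might look like:

```yaml
chat_template_jinja: |
  {% for message in messages %}{{ '<|' + message['role'] + '|>\n' + message['content'] + '\n' }}{% endfor %}
```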

##### Setting `chat_template` dataset keys

We currently default to the OpenAI format for dataset keys, so if that's your current dataset format, there's nothing to do here.

If your dataset format is different, here are the keys you should check (with their defaults):

```yaml
datasets:
  ...
    field_messages: messages  # this should point to the key containing the list of conversations
    message_property_mappings:  # this is a mapping from keys in your dataset to keys in the chat_template
      role: role
      content: content
```

In some `chat_templates` (e.g. [Gemma](https://huggingface.co/google/gemma-2b-it/blob/main/tokenizer_config.json#L1507)), the roles are hardcoded to `user` and `assistant`. Consequently, you may need to map the roles in your dataset to these. We currently have some defaults that should work for common datasets, but if you get a `KeyError`, you will need to add a mapping for your roles. Here is an example:

```yaml
datasets:
  ...
    roles:
      assistant:
        - gpt
        - model
      user:
        - human
```

In the example above, all `gpt` and `model` values are converted to `assistant`. All `human` values are converted to `user`.

##### Handling masking

The common use case for `chat_template` is chat messages, so it is common to mask all non-assistant messages. Assistant messages are the bot messages that you want the model to learn from.

To train on all `assistant` messages, you would set the following config:

```yaml
datasets:
  ...
    roles_to_train: ["assistant"]
    train_on_eos: "turn"
```

Setting `train_on_eos: turn` masks the EOS token on every turn that is not an assistant turn. The other options, `all` and `last`, choose which EOS tokens to train on.

If you want to train on `assistant` and `narrator` roles, you can simply add `narrator` to the `roles_to_train` list. You would also need to add it to the `roles` mapping above.

```yaml
datasets:
  ...
    roles_to_train: ["assistant", "narrator"]
    roles:
      assistant:
        - gpt
        - model
      user:
        - human
      narrator: ["narrator"]
```

#### Applying `chat_template`

Once all the above steps are completed, you can combine these configs to form a bespoke configuration for your custom dataset. The final step is to correctly set the EOS token in your config:

```yaml
datasets:
  - path: A.jsonl
    type: chat_template

    # step 1
    chat_template: chatml

    # step 2
    field_messages: messages
    message_property_mappings:
      role: role
      content: content

    roles:
      assistant:
        - gpt
        - model
        - assistant
      user:
        - human
        - user

    # step 3
    roles_to_train: ["assistant"]
    train_on_eos: "turn"

special_tokens:
  eos_token: <|im_end|>
```

Applying this config to the sample dataset above produces the following output (which can be retrieved via `axolotl preprocess config.yaml --debug`):

```
<|im_start|>(-100, 128256) user(-100, 882)
(-100, 198) Hi(-100, 13347) <|im_end|>(-100, 128257)
(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)
(-100, 198) How(4438, 4438) can(649, 649) I(358, 358) help(1520, 1520) you(499, 499) ?(30, 30) <|im_end|>(128257, 128257)
(-100, 198) <|im_start|>(-100, 128256) user(-100, 882)
(-100, 198) Can(-100, 6854) you(-100, 499) add(-100, 923) (-100, 220) 3(-100, 18) +(-100, 10) 5(-100, 20) ?(-100, 30) <|im_end|>(-100, 128257)
(-100, 198) <|im_start|>(-100, 128256) assistant(-100, 78191)
(-100, 198) The(791, 791) answer(4320, 4320) is(374, 374) (220, 220) 8(23, 23) .(13, 13) <|im_end|>(128257, 128257)
(-100, 198)
```

The first number is the label and the second is the `token_id`. Labels of `-100` appear on non-assistant portions, meaning they are masked during training. For assistant portions, the label is the same as the `token_id`.
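
This mirrors how the loss is computed downstream: PyTorch's cross-entropy ignores targets labeled `-100` by default. A standalone sketch with dummy values, for illustration only:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(5, 100)                      # 5 positions over a dummy vocab of 100
labels = torch.tensor([-100, -100, 42, 7, -100])  # only positions 2 and 3 contribute
loss = F.cross_entropy(logits, labels)            # ignore_index defaults to -100
```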

### [Instruction Dataset](inst_tune.qmd)

Instruction datasets are used to train instruction-following models and comprise a prompt, containing an instruction, and a single response. In contrast to chat datasets, which may be multi-turn, instruct datasets are typically single-turn.

An example of a common format called Alpaca:
```json
{"instruction": "...", "input": "...", "output": "..."}
```

Using those keys, a prompt can be built from each row:
```
Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
{output}
```

This can be configured as such:
```yaml
datasets:
  - path: A.jsonl
    type: alpaca
```

Axolotl supports many kinds of instruction datasets. All of them are listed at https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/inst_tune.html with their respective `type` and sample row format.

#### Custom Instruct Prompt Format

Due to the myriad possible instruction formats, Axolotl allows you to customize your own instruction format without having to dive into the code directly.

In the example below, a sample row of the following shape is rendered in the `mistral_v1` format.
```json
{"input": "...", "output": "..."}
```

```yaml
datasets:
  - path: repo
    type:
      system_prompt: ""

      field_system:
      field_instruction: input
      field_input:
      field_output: output

      # multi-line example with input
      format: |-
        [INST] {instruction} {input} [/INST]

      # single-line example without input
      no_input_format: "[INST] {instruction} [/INST]"
```

The config declares that the instruction field is actually named `input`, and `field_input` is left empty as this sample has no separate input. Generally, `instruction` can be thought of as the question to the model, `input` as additional information, and `output` as the response. It is not necessary to have an `input` or a `system` field. In the end, the most important part is to understand what format you want and how to customize it to your use case.
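
For illustration, a hypothetical row `{"input": "What is 2+2?", "output": "4"}` has no separate input, so `no_input_format` applies, and the rendered training text would be roughly:

```
[INST] What is 2+2? [/INST]4
```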

## Reinforcement Learning from Human Feedback (RLHF)

There are multiple RLHF methods, each with its own dataset requirements. Please see the [RLHF datasets](../rlhf.qmd) documentation for more detail.
Below are these various formats organized by task:

@@ -23,7 +23,3 @@ description: Frequently asked questions

**Q: The code is stuck on saving preprocessed datasets.**

> A: This is usually an issue with the GPU. It can be resolved by setting the OS environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on RunPod, this is usually a pod issue; starting a new pod should take care of it.
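> For example (a sketch of the workaround):
>
> ```bash
> CUDA_VISIBLE_DEVICES=0 axolotl preprocess config.yaml
> ```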

**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

> A: This means that the property mapping for the stated attribute does not exist when building the `chat_template` prompt. For example, for `no attribute 'content'`, check that you have added the correct mapping for `content` under `message_property_mappings`.
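> As a sketch, if your messages use `from`/`value` instead of `role`/`content`, the mapping would be:
>
> ```yaml
> datasets:
>   - path: ...
>     type: chat_template
>     message_property_mappings:
>       role: from
>       content: value
> ```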

@@ -1,127 +0,0 @@

---
title: "LoRA Optimizations"
description: "Custom autograd functions and Triton kernels in Axolotl for optimized LoRA fine-tuning"
---

Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU (in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was to leverage operator fusion and tensor re-use in order to improve speed and reduce memory usage during the forward and backward passes of these calculations.

We currently support several common model architectures, including (but not limited to):
- `llama`
- `mistral`
- `qwen2`
- `gemma`
- `gemma2`

<details>

The set of models we support is currently limited by our attention patching strategy, which assumes (and replaces) specific code blocks for the query / key / value and output projections:

```python
ORIGINAL_QKV_CODE = """
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
    "\n"
)

ORIGINAL_O_CODE = """
    attn_output = self.o_proj(attn_output)
""".lstrip(
    "\n"
)
```

This is replaced with:

```python
PATCHED_QKV_CODE = """
    query_states, key_states, value_states = self.apply_qkv(hidden_states)
    query_states = query_states.view(hidden_shape).transpose(1, 2)
    key_states = key_states.view(hidden_shape).transpose(1, 2)
    value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
    "\n"
)

PATCHED_O_CODE = """
    attn_output = self.apply_o(attn_output)
""".lstrip(
    "\n"
)
```

Here, `apply_qkv` and `apply_o` are defined in the `axolotl.kernels.lora` module.

We welcome testing of other model architectures and / or PRs to expand our patching logic to be compatible with more of them.

</details>

## Usage

These optimizations can be enabled in your Axolotl config YAML file. The `lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and `lora_o_kernel` enable the fused query-key-value projection and the optimized output projection, respectively.

```yaml
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```

## Requirements

- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
  - Note: set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
- Targeted LoRA adapters cannot use dropout
  - This may limit model expressivity / cause overfitting
- Targeted LoRA adapters cannot have bias terms
  - This may limit model expressivity

Models with pre-existing LoRA adapters that use dropout or have bias terms may need to be re-finetuned without these features in order to be useful.

## Implementation details

### Custom autograd functions

The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the LoRA and base weight computations together and provides a single, efficient backward pass for the entire MLP block.

For attention components, similar optimizations are provided through a function that handles the query, key, and value projections, and a function that handles the output projection. They are designed to work with the existing `transformers` attention implementation via some monkey-patching logic.

### Triton kernels

Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for improved speed and memory performance. These kernels handle both the forward and backward passes.
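
For reference, the underlying elementwise computation is simple; a plain, unfused PyTorch sketch of SwiGLU (for illustration only, not the kernel itself):

```python
import torch
import torch.nn.functional as F

def swiglu_ref(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # SwiGLU combines the gate and up projections: SiLU(gate) * up.
    # The Triton kernel fuses this (and its backward) into single passes.
    return F.silu(gate) * up
```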

### Integration

The custom autograd functions and Triton kernels are designed to work together. The autograd function manages the high-level computation flow and gradient tracking, while calling the Triton kernels for the activation function computation. During the backward pass, the kernel computes both the activation output and the required gradients, which the autograd function then uses to compute the final gradients for the entire computation path.

## Future Work

- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions

docs/rlhf.qmd (451 changes)
@@ -1,39 +1,26 @@

---
title: "RLHF (Beta)"
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
back-to-top-navigation: true
toc: true
toc-depth: 3
---

# Overview
### Overview

Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback. Various methods include, but are not limited to:

- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- Direct Preference Optimization (DPO)
- Identity Preference Optimization (IPO)

# RLHF using Axolotl
### RLHF using Axolotl

::: {.callout-important}
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
:::
>[!IMPORTANT]
>This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.

We rely on the [TRL](https://github.com/huggingface/trl) library for implementations of various RL training methods, which we wrap to expose in axolotl. Each method has its own supported ways of loading datasets and prompt formats.

::: {.callout-tip}
You can find what each method supports by going into `src/axolotl/prompt_strategies/{method}`, where `{method}` is one of our supported methods. The `type: ` can be retrieved from `{method}.{function_name}`.
:::

## DPO

Example config:
The various RL training methods are implemented in TRL and wrapped via axolotl. Below are examples of how you can use various preference datasets to train models that use ChatML.

#### DPO
```yaml
rl: dpo
datasets:
@@ -45,265 +32,12 @@ datasets:
    type: chatml
```

DPO supports the following types, with the following dataset formats:

### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### chatml.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "chosen_response": "...",
  "rejected_response": "..."
}
```

### llama3.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### llama3.icr

```json
{
  "system": "...", // optional
  "input": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### zephyr.nectar

```json
{
  "prompt": "...",
  "answers": [
    {
      "answer": "...",
      "rank": 1
    },
    {
      "answer": "...",
      "rank": 2
    }
    // ... more answers with ranks
  ]
}
```

### chat_template.default

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: chat_template.default
    field_messages: "messages"
    field_chosen: "chosen"
    field_rejected: "rejected"
    message_property_mappings:
      role: role
      content: content
    roles:
      user: ["user"]
      assistant: ["assistant"]
      system: ["system"]
```

Sample input format:

```json
{
  "messages": [
    {
      "role": "system",
      "content": "..."
    },
    {
      "role": "user",
      "content": "..."
    },
    // ... more messages
  ],
  "chosen": {
    "role": "assistant",
    "content": "..."
  },
  "rejected": {
    "role": "assistant",
    "content": "..."
  }
}
```

### user_defined.default

For custom behaviors:

```yaml
rl: dpo
datasets:
  - path: ...
    split: train
    type: user_defined.default

    field_prompt: "prompt"
    field_system: "system"
    field_chosen: "chosen"
    field_rejected: "rejected"
    prompt_format: "{prompt}"
    chosen_format: "{chosen}"
    rejected_format: "{rejected}"
```

The input format is a simple JSON input with customizable fields based on the above config.

```json
{
  "system": "...", // optional
  "prompt": "...",
  "chosen": "...",
  "rejected": "..."
}
```

## IPO

As IPO is just DPO with a different loss function, all supported options for DPO work here.

#### IPO
```yaml
rl: ipo
```

## ORPO
#### ORPO

Paper: https://arxiv.org/abs/2403.07691

@@ -318,28 +52,8 @@ datasets:
    type: chat_template.argilla
```

ORPO supports the following types, with the following dataset format:

### chat_template.argilla

```json
{
  "system": "...", // optional
  "prompt": "...", // if available, will be taken as the user message for single-turn instead of from the list below

  // chosen/rejected should be identical up to the last content, with only an even number of alternating user/assistant turns
  "chosen": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ],
  "rejected": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

## KTO
#### KTO

```yaml
rl: kto
@@ -358,144 +72,7 @@ gradient_checkpointing_kwargs:
  use_reentrant: true
```

KTO supports the following types, with the following dataset formats:

### chatml.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

### chatml.argilla_chat

```json
{
  "chosen": [
    {"role": "user", "content": "..."}
  ],
  "completion": [
    {"role": "assistant", "content": "..."}
  ]
}
```

### chatml.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

### chatml.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

### chatml.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

### llama3.argilla

```json
{
  "system": "...", // optional
  "instruction": "...",
  "completion": "..."
}
```

### llama3.argilla_chat

```json
{
  "completion": [
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."}
  ]
}
```

### llama3.intel

```json
{
  "system": "...", // optional
  "question": "...",
  "completion": "..."
}
```

### llama3.prompt_pairs

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

### llama3.ultra

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "..."
}
```

### user_defined.default

For custom behaviors:

```yaml
rl: kto
datasets:
  - path: ...
    split: train
    type: user_defined.default

    field_prompt: "prompt"
    field_system: "system"
    field_completion: "completion"
    field_label: "label"
    prompt_format: "{prompt}"
    completion_format: "{completion}"
```

The input format is a simple JSON input with customizable fields based on the above config.

```json
{
  "system": "...", // optional
  "prompt": "...",
  "completion": "...",
  "label": "..."
}
```

## Using local dataset files

#### Using local dataset files
```yaml
datasets:
  - ds_type: json
@@ -505,9 +82,9 @@ datasets:
    type: chatml.intel
```

## TRL auto-unwrapping for PEFT
#### TRL auto-unwrap for PEFT

TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, as an additional reference model does not need to be loaded, and reference-model log-probabilities can be obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:
TRL supports auto-unwrapping PEFT models, so that a reference model does not need to be loaded additionally, requiring less VRAM. This is on by default. To turn it off, pass the following config.
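
Conceptually (a sketch using PEFT's adapter-disabling context manager, not axolotl's exact code; `model` is assumed to be a `PeftModel`):

```python
# policy log-probs come from the adapter-enabled model
policy_out = model(input_ids)

# reference log-probs come from the same base weights with adapters disabled
with model.disable_adapter():
    ref_out = model(input_ids)
```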

```yaml
# load ref model when adapter training.

@@ -21,9 +21,8 @@ datasets:
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
    message_field_role: from
    message_field_content: value

dataset_prepared_path: last_run_prepared
val_set_size: 0.0

@@ -16,9 +16,8 @@ datasets:
    type: chat_template
    drop_system_message: true
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
    message_field_role: from
    message_field_content: value

val_set_size: 0.0
output_dir: ./outputs/out

@@ -13,9 +13,8 @@ datasets:
    type: chat_template
    drop_system_message: true
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
    message_field_role: from
    message_field_content: value

dataset_prepared_path: last_run_prepared
val_set_size: 0.0

@@ -17,9 +17,8 @@ datasets:
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value
    message_field_role: from
    message_field_content: value

dataset_prepared_path: last_run_prepared
val_set_size: 0.02

@@ -17,9 +17,8 @@ datasets:
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_property_mappings:
      role: role
      content: content
    message_field_role: role
    message_field_content: content
    roles:
      system:
        - system

@@ -14,9 +14,8 @@ datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template
    field_messages: messages
    message_property_mappings:
      role: role
      content: content
    message_field_role: role
    message_field_content: content
    roles:
      user:
        - user

@@ -17,9 +17,8 @@ datasets:
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_property_mappings:
      role: role
      content: content
    message_field_role: role
    message_field_content: content
    roles:
      system:
        - system
@@ -32,9 +31,8 @@ datasets:
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_property_mappings:
      role: role
      content: content
    message_field_role: role
    message_field_content: content
    roles:
      system:
        - system
@@ -1,82 +0,0 @@
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

adapter: lora
lora_model_dir:

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 32
# Currently, we don't support dropout with our custom Triton kernels
# lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

# These options enable our custom Triton kernels / autograd
# functions for MLP and attention calculations
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|end_of_text|>"
@@ -22,9 +22,8 @@ datasets:
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_property_mappings:
      role: role
      content: content
    message_field_role: role
    message_field_content: content

dataset_prepared_path:
val_set_size: 0.05

@@ -14,9 +14,8 @@ datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template
    field_messages: messages
    message_property_mappings:
      role: role
      content: content
    message_field_role: role
    message_field_content: content
    roles:
      user:
        - user

@@ -12,9 +12,8 @@ datasets:
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_property_mappings:
      role: role
      content: content
    message_field_role: role
    message_field_content: content
    roles:
      system:
        - system
@@ -18,7 +18,7 @@ tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.15.0
trl==0.13.0

optimum==1.16.2
hf_transfer
@@ -26,7 +26,7 @@ sentencepiece
gradio==3.50.2

modal==0.70.5
pydantic==2.10.6
pydantic==2.6.3
addict
fire
PyYAML>=6.0
@@ -31,26 +31,27 @@ def parse_dataset(dataset=None, split="train"):
    ds_cfg["field_messages"] = field_messages

    message_fields = features[field_messages][0].keys()

    message_property_mappings = {"role": None, "content": None}
    message_field_role = None
    for key in ["from", "role"]:
        if key in message_fields:
            message_property_mappings["role"] = key
            message_field_role = key
            break
    if not message_property_mappings["role"]:
    if not message_field_role:
        raise ValueError(
            f'No role field found in messages: {", ".join(message_fields)}'
        )
    ds_cfg["message_field_role"] = message_field_role

    message_field_content = None
    for key in ["content", "text", "value"]:
        if key in message_fields:
            message_property_mappings["content"] = key
            message_field_content = key
            break
    if not message_property_mappings["content"]:
    if not message_field_content:
        raise ValueError(
            f'No content field found in messages: {", ".join(message_fields)}'
        )
    ds_cfg["message_property_mappings"] = message_property_mappings
    ds_cfg["message_field_content"] = message_field_content

    print(yaml.dump({"datasets": [ds_cfg]}))

setup.py (7 changes)
@@ -79,7 +79,7 @@ def parse_requirements():
        if patch == 0:
            _install_requires.append("xformers==0.0.28.post2")
        else:
            _install_requires.append("xformers>=0.0.28.post3")
            _install_requires.append("xformers==0.0.29")
        _install_requires.pop(_install_requires.index(autoawq_version))
    elif (major, minor) >= (2, 4):
        if patch == 0:
@@ -125,7 +125,7 @@ setup(
    },
    extras_require={
        "flash-attn": [
            "flash-attn==2.7.4.post1",
            "flash-attn==2.7.0.post2",
        ],
        "deepspeed": [
            "deepspeed==0.16.1",
@@ -156,8 +156,5 @@ setup(
        "ray": [
            "ray[train]",
        ],
        "vllm": [
            "vllm==0.7.2",
        ],
    },
)
@@ -4,4 +4,4 @@ import pkgutil

__path__ = pkgutil.extend_path(__path__, __name__)  # Make this a namespace package

__version__ = "0.8.0.dev0"
__version__ = "0.6.0"
@@ -35,18 +35,13 @@ def do_cli_train(
    cloud_config: Union[Path, str],
    config: Union[Path, str],
    accelerate: bool = True,
    cwd=None,
    **kwargs,
) -> None:
    print_axolotl_text_art()
    cloud_cfg = load_cloud_cfg(cloud_config)
    cloud = ModalCloud(cloud_cfg)
    with open(config, "r", encoding="utf-8") as file:
        config_yaml = file.read()
    local_dirs = {}
    if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
        local_dirs = {"/workspace/mounts": cwd}
    cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
    cloud.train(config_yaml, accelerate=accelerate)


def do_cli_lm_eval(
@@ -7,7 +7,6 @@ import os
import subprocess  # nosec B404
from pathlib import Path
from random import randint
from typing import Optional

import modal

@@ -23,18 +22,8 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):

    # modal workaround so it doesn't use the automounted axolotl
    new_env = copy.deepcopy(os.environ)

    if "PYTHONPATH" in new_env:
        paths = ["/workspace/mounts"]
        for sub_python_path_str in new_env["PYTHONPATH"].split(":"):
            sub_python_path = Path(sub_python_path_str)
            if not sub_python_path.joinpath("src", "axolotl").exists():
                # we don't want to use the automounted axolotl or unexpected behavior happens
                paths.append(str(sub_python_path))
        if paths:
            new_env["PYTHONPATH"] = ":".join(paths)
        else:
            del new_env["PYTHONPATH"]
        del new_env["PYTHONPATH"]

    # Propagate errors from subprocess.
    if exit_code := subprocess.call(  # nosec B603
@@ -214,12 +203,9 @@ class ModalCloud(Cloud):
            memory = int(self.config.memory)
        return 1024 * memory

    def get_train_env(self, local_dirs=None):
        image = self.get_image()
        for mount, local_dir in (local_dirs or {}).items():
            image = image.add_local_dir(local_dir, mount)
    def get_train_env(self):
        return self.app.function(
            image=image,
            image=self.get_image(),
            volumes={k: v[0] for k, v in self.volumes.items()},
            cpu=16.0,
            gpu=self.get_train_gpu(),
@@ -228,21 +214,14 @@ class ModalCloud(Cloud):
            secrets=self.get_secrets(),
        )

    def train(
        self,
        config_yaml: str,
        accelerate: bool = True,
        local_dirs: Optional[dict[str, str]] = None,
        **kwargs,
    ):
        modal_fn = self.get_train_env(local_dirs)(_train)
    def train(self, config_yaml: str, accelerate: bool = True):
        modal_fn = self.get_train_env()(_train)
        with modal.enable_output():
            with self.app.run(detach=True):
                modal_fn.remote(
                    config_yaml,
                    accelerate=accelerate,
                    volumes={k: v[0] for k, v in self.volumes.items()},
                    **kwargs,
                )

    def lm_eval(self, config_yaml: str):
@@ -273,7 +252,7 @@ def _preprocess(config_yaml: str, volumes=None):
    )


def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
    with open(
        "/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
    ) as f_out:
@@ -283,11 +262,8 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None):
        accelerate_args = "--accelerate"
    else:
        accelerate_args = "--no-accelerate"
    num_processes_args = ""
    if num_processes := kwargs.pop("num_processes", None):
        num_processes_args = f"--num-processes {num_processes}"
    run_cmd(
        f"axolotl train {accelerate_args} {num_processes_args} /workspace/artifacts/axolotl/config.yaml",
        f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
        run_folder,
        volumes,
    )
@@ -2,19 +2,19 @@
# pylint: disable=redefined-outer-name

import logging
import os
import random
import subprocess  # nosec B404
import tempfile
from copy import deepcopy
from itertools import product
from pathlib import Path
from typing import Optional

import click
import yaml
from dotenv import load_dotenv

import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import (
    add_options_from_config,
    add_options_from_dataclass,
@@ -27,6 +27,76 @@ from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig


def generate_sweep_configs(base_config, sweeps_config):
    """
    Recursively generates all possible configurations by applying sweeps to the base config.

    Args:
        base_config (dict): The original configuration dictionary
        sweeps_config (dict): Dictionary where keys are parameters and values are either:
            - lists of values to sweep independently
            - or for paired values, a list of dicts under the '_' key

    Returns:
        list: List of all possible configuration dictionaries

    Example:
        sweeps_config = {
            'learning_rate': [0.1, 0.01],
            '_': [
                {'load_in_8bit': True, 'adapter': 'lora'},
                {'load_in_4bit': True, 'adapter': 'qlora'}
            ]
        }
    """
    # Separate paired values from regular sweeps
    paired_values = sweeps_config.get("_", [])
    regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}

    # Process regular sweeps
    param_names = list(regular_sweeps.keys())
    param_values = list(regular_sweeps.values())

    # Generate combinations for regular sweeps
    regular_combinations = list(product(*param_values)) if param_values else [()]

    # Combine regular sweeps with paired values
    all_combinations = []
    for reg_combo in regular_combinations:
        if paired_values:
            for paired_set in paired_values:
                new_config = {}
                # new_config = deepcopy(base_config)
                # Combine regular parameters with paired parameters
                full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
                for param_name, param_value in full_combo.items():
                    new_config[param_name] = param_value
                print(new_config)
                all_combinations.append(new_config)
        else:
            # If no paired values, just use regular combinations
            # new_config = deepcopy(base_config)
            new_config = {}
            for param_name, param_value in zip(param_names, reg_combo):
                new_config[param_name] = param_value
            print(new_config)
            all_combinations.append(new_config)

    # randomize the order of trials
    random.seed(42)
    random.shuffle(all_combinations)

    # Generate a new config for each combination
    result_configs = []
    for combination in all_combinations:
        new_config = deepcopy(base_config)
        for param_name, param_value in combination.items():
            new_config[param_name] = param_value
        result_configs.append(new_config)

    return result_configs

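The sweep generator above takes the cartesian product of the independent sweep keys and crosses it with each paired set under the '_' key. A minimal usage sketch (the base-config values here are illustrative, not taken from this diff):

# Illustrative only: 2 learning rates x 2 paired adapter sets = 4 configs,
# shuffled with a fixed seed; each result is a deepcopy of `base` with the
# swept keys overridden.
base = {"base_model": "NousResearch/Llama-3.2-1B", "learning_rate": 2e-5}
sweeps = {
    "learning_rate": [0.1, 0.01],
    "_": [
        {"load_in_8bit": True, "adapter": "lora"},
        {"load_in_4bit": True, "adapter": "qlora"},
    ],
}
configs = generate_sweep_configs(base, sweeps)
assert len(configs) == 4
assert all(c["base_model"] == base["base_model"] for c in configs)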

@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
@@ -95,6 +165,7 @@ def train(
    """
    # Enable expandable segments for cuda allocation to improve VRAM usage
    set_pytorch_cuda_alloc_conf()
    from axolotl.cli.cloud import do_cli_train

    if "use_ray" in kwargs and kwargs["use_ray"]:
        accelerate = False
@@ -128,16 +199,7 @@ def train(
    try:
        if accelerate:
            if cloud:
                from axolotl.cli.cloud import do_cli_train

                cwd = os.getcwd()
                do_cli_train(
                    cloud_config=cloud,
                    config=config,
                    accelerate=True,
                    cwd=cwd,
                    **kwargs,
                )
                do_cli_train(cloud_config=cloud, config=config, accelerate=True)
            else:
                accelerate_args = []
                if "main_process_port" in kwargs:
@@ -146,7 +208,7 @@ def train(
                    accelerate_args.append(str(main_process_port))
                if "num_processes" in kwargs:
                    num_processes = kwargs.pop("num_processes", None)
                    accelerate_args.append("--num_processes")
                    accelerate_args.append("--num-processes")
                    accelerate_args.append(str(num_processes))

                base_cmd = ["accelerate", "launch"]
@@ -158,11 +220,7 @@ def train(
                subprocess.run(cmd, check=True)  # nosec B603
        else:
            if cloud:
                from axolotl.cli.cloud import do_cli_train

                do_cli_train(
                    cloud_config=cloud, config=config, accelerate=False, **kwargs
                )
                do_cli_train(cloud_config=cloud, config=config, accelerate=False)
            else:
                from axolotl.cli.train import do_cli
@@ -323,5 +381,4 @@ def main():


if __name__ == "__main__":
    load_dotenv()
    main()

@@ -1,77 +0,0 @@
"""Utilities for handling sweeps over configs for axolotl train CLI command"""

import random
from copy import deepcopy
from itertools import product


def generate_sweep_configs(
    base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, list]]:
    """
    Recursively generates all possible configurations by applying sweeps to the base config.

    Args:
        base_config (dict): The original configuration dictionary
        sweeps_config (dict): Dictionary where keys are parameters and values are either:
            - lists of values to sweep independently
            - or for paired values, a list of dicts under the '_' key

    Returns:
        list: List of all possible configuration dictionaries

    Example:
        sweeps_config = {
            'learning_rate': [0.1, 0.01],
            '_': [
                {'load_in_8bit': True, 'adapter': 'lora'},
                {'load_in_4bit': True, 'adapter': 'qlora'}
            ]
        }
    """
    # Separate paired values from regular sweeps
    paired_values = sweeps_config.get("_", [])
    regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}

    # Process regular sweeps
    param_names = list(regular_sweeps.keys())
    param_values = list(regular_sweeps.values())

    # Generate combinations for regular sweeps
    regular_combinations = list(product(*param_values)) if param_values else [()]

    # Combine regular sweeps with paired values
    all_combinations = []
    for reg_combo in regular_combinations:
        if paired_values:
            for paired_set in paired_values:
                new_config = {}
                # new_config = deepcopy(base_config)
                # Combine regular parameters with paired parameters
                full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
                for param_name, param_value in full_combo.items():
                    new_config[param_name] = param_value
                print(new_config)
                all_combinations.append(new_config)
        else:
            # If no paired values, just use regular combinations
            # new_config = deepcopy(base_config)
            new_config = {}
            for param_name, param_value in zip(param_names, reg_combo):
                new_config[param_name] = param_value
            print(new_config)
            all_combinations.append(new_config)

    # randomize the order of trials
    random.seed(42)
    random.shuffle(all_combinations)

    # Generate a new config for each combination
    result_configs = []
    for combination in all_combinations:
        new_config = deepcopy(base_config)
        for param_name, param_value in combination.items():
            new_config[param_name] = param_value
        result_configs.append(new_config)

    return result_configs
@@ -122,11 +122,9 @@ def load_preference_datasets(
    `total_num_steps`.
    """
    train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
    total_num_steps: Optional[int] = int(
    total_num_steps = int(
        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
    )
    if cfg.rl == "grpo":
        total_num_steps = None

    if cli_args.debug or cfg.debug:
        LOG.info("check_dataset_labels...")

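As a quick sanity check of the step arithmetic above (numbers invented for illustration): 1,000 preference pairs, 3 epochs, and batch size 8 give ceil(1000 * 3 / 8) = 375 optimizer steps, while the grpo branch leaves total_num_steps as None so the step count is resolved later.

import math

# Illustrative numbers only, not from this diff.
assert math.ceil(1000 * 3 / 8) == 375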
@@ -39,6 +39,7 @@ from trl.trainer.utils import RewardDataCollatorWithPadding

from axolotl.core.trainers.base import (
    AxolotlCPOTrainer,
    AxolotlDPOTrainer,
    AxolotlKTOTrainer,
    AxolotlMambaTrainer,
    AxolotlORPOTrainer,
@@ -47,11 +48,9 @@ from axolotl.core.trainers.base import (
    AxolotlTrainer,
    ReLoRATrainer,
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import (
    AxolotlCPOConfig,
    AxolotlDPOConfig,
    AxolotlKTOConfig,
    AxolotlORPOConfig,
    AxolotlPRMConfig,
@@ -330,12 +329,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        )

        training_arguments_kwargs = {}

        if self.cfg.include_tokens_per_second is not None:
            training_arguments_kwargs[
                "include_tokens_per_second"
            ] = self.cfg.include_tokens_per_second

        if self.cfg.bf16 == "full":
            training_arguments_kwargs["bf16_full_eval"] = True
        else:
@@ -648,6 +641,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                tokenizer=self.tokenizer,
            )

        if self.cfg.rl == "orpo":
            training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha

        if self.cfg.neftune_noise_alpha is not None:
            training_arguments_kwargs[
                "neftune_noise_alpha"
@@ -656,7 +652,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        trainer_kwargs = {}

        if self.cfg.reward_model:
            training_arguments_kwargs["max_length"] = self.cfg.sequence_len
            trainer_kwargs["max_length"] = self.cfg.sequence_len

        # pylint: disable=duplicate-code
        if self.cfg.optimizer in [
@@ -969,11 +965,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
            # default to saving each epoch if not defined
            training_args_kwargs["save_strategy"] = "epoch"

        if self.cfg.dataset_processes:
            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
        training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes

        if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
            training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
        if self.cfg.rl_beta:
            training_args_kwargs["beta"] = self.cfg.rl_beta
        if self.cfg.orpo_alpha:
            # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
            training_args_kwargs["beta"] = self.cfg.orpo_alpha
@@ -982,7 +977,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
            training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha

        training_args_cls = None
        blocklist_args_kwargs = []
        if self.cfg.rl == "simpo":
            training_args_cls = AxolotlCPOConfig
            training_args_kwargs["loss_type"] = "simpo"
@@ -1007,15 +1001,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                self.cfg.kto_undesirable_weight or 1.0
            )

            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
            training_args_kwargs["max_length"] = self.cfg.sequence_len
            if self.cfg.max_prompt_len:
                training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

        elif self.cfg.rl == "grpo":
            training_args_cls = GRPOStrategy.get_training_args_class()
            training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
            blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()

        else:
            training_args_cls = AxolotlDPOConfig
            if self.cfg.rl == "ipo":
@@ -1026,21 +1016,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
            training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
            if self.cfg.dpo_use_weighting is not None:
                training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
            if self.cfg.dpo_use_logits_to_keep is not None:
                training_args_kwargs[
                    "use_logits_to_keep"
                ] = self.cfg.dpo_use_logits_to_keep

        for blocklist_key in blocklist_args_kwargs:
            if blocklist_key in training_args_kwargs:
                del training_args_kwargs[blocklist_key]

        max_steps = self.cfg.max_steps or total_num_steps or -1
        training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
        training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
            self.cfg.output_dir,
            output_dir=self.cfg.output_dir,
            per_device_train_batch_size=self.cfg.micro_batch_size,
            max_steps=max_steps,
            max_steps=self.cfg.max_steps or total_num_steps,
            gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
            learning_rate=self.cfg.learning_rate,
            warmup_steps=self.cfg.warmup_steps,
@@ -1067,13 +1047,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
            dpo_trainer_kwargs[
                "precompute_ref_log_probs"
            ] = self.cfg.precompute_ref_log_probs
        if self.cfg.rl == "grpo":
            trainer_cls = GRPOStrategy.get_trainer_class()
            trainer_cls_args = [self.model]
            trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
            dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
        elif self.cfg.rl in ["dpo", "ipo"]:
            trainer_cls = DPOStrategy.get_trainer_class()
        if self.cfg.rl in ["dpo", "ipo"]:
            trainer_cls = AxolotlDPOTrainer
            trainer_cls_args = [self.model, self.model_ref]
        elif self.cfg.rl == "orpo":
            trainer_cls = AxolotlORPOTrainer
@@ -1088,14 +1063,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
            raise ValueError(f"Unsupported RL: {self.cfg.rl}")

        sig = inspect.signature(trainer_cls)
        if "tokenizer" in sig.parameters.keys():
            dpo_trainer_kwargs["tokenizer"] = self.tokenizer
        else:
        if "processing_class" in sig.parameters.keys():
            dpo_trainer_kwargs["processing_class"] = self.tokenizer
        else:
            dpo_trainer_kwargs["tokenizer"] = self.tokenizer

        if self.cfg.datasets is not None and (
            trainer_cls is DPOStrategy.get_trainer_class()
        ):
        if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
            dpo_trainer_kwargs["dataset_tags"] = [
                d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
            ]

@@ -5,21 +5,30 @@ module for customized trainers
from __future__ import annotations

# pylint: disable=too-many-lines
import gc
import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Dict, Literal, Optional
from typing import Any, Dict, Literal, Optional, Union

import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl import (
    CPOTrainer,
    DPOTrainer,
    KTOTrainer,
    ORPOTrainer,
    PRMTrainer,
    RewardTrainer,
)
from trl.trainer.utils import pad_to_length

from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -838,6 +847,107 @@ class ReLoRATrainer(AxolotlTrainer):
        return self.lr_scheduler


class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
    """
    Extend the base DPOTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "dpo"]

    def __init__(self, *args, dataset_tags=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_tags = dataset_tags
        self.optimizer = None
        self.model_accepts_loss_kwargs = False

    def create_optimizer(self):
        if self.args.loraplus_lr_ratio is None:
            return super().create_optimizer()

        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if self.optimizer is None:  # pylint: disable=access-member-before-definition
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args,
                opt_model,
            )

            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
            if loraplus_lr_ratio:
                print("Using lora+")
            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
                opt_model,
                optimizer_cls,
                loraplus_lr_ratio=loraplus_lr_ratio,
                loraplus_lr_embedding=loraplus_lr_embedding,
                **optimizer_kwargs,
            )

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
                self.optimizer
            )

        return self.optimizer

    @wraps(DPOTrainer.push_to_hub)
    def push_to_hub(self, *args, **kwargs) -> str:
        """
        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = _sanitize_kwargs_for_ds_tagging(
            dataset_tags=self.dataset_tags, kwargs=kwargs
        )
        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

        return super().push_to_hub(*args, **kwargs)

    @staticmethod
    def tokenize_row(
        features,
        processing_class,
        max_prompt_length,
        max_completion_length,
        add_special_tokens,
    ) -> Dict:
        res = DPOTrainer.tokenize_row(
            features,
            processing_class,
            max_prompt_length,
            max_completion_length,
            add_special_tokens,
        )
        # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
        if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
            for key in res.keys():
                res[key] = res[key][1:]

        if processing_class.bos_token and processing_class.bos_token_id is not None:
            # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
            if res["chosen_input_ids"][0] == processing_class.bos_token_id:
                res["chosen_input_ids"] = res["chosen_input_ids"][1:]
                res["chosen_labels"] = res["chosen_labels"][1:]
                res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
            if res["rejected_input_ids"][0] == processing_class.bos_token_id:
                res["rejected_input_ids"] = res["rejected_input_ids"][1:]
                res["rejected_labels"] = res["rejected_labels"][1:]
                res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]

        return res

    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch=None,
    ) -> torch.Tensor:
        loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
        gc.collect()
        torch.cuda.empty_cache()
        return loss


class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
    """
    Extend the base ORPOTrainer for axolotl helpers

@@ -1,33 +0,0 @@
"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer


class DPOStrategy:
    """
    Strategy for DPO training
    """

    @classmethod
    def get_trainer_class(cls):
        return AxolotlDPOTrainer

    @classmethod
    def get_training_args_class(cls):
        from axolotl.core.trainers.dpo.args import AxolotlDPOConfig

        return AxolotlDPOConfig

    @classmethod
    def set_training_args_kwargs(cls, cfg):
        training_args_kwargs = {}
        if cfg.rl == "ipo":
            training_args_kwargs["loss_type"] = "ipo"
        training_args_kwargs["max_length"] = cfg.sequence_len
        training_args_kwargs["max_completion_length"] = None
        training_args_kwargs["max_prompt_length"] = cfg.sequence_len
        training_args_kwargs["generate_during_eval"] = cfg.use_wandb
        if cfg.dpo_use_weighting is not None:
            training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
        return training_args_kwargs
@@ -1,15 +0,0 @@
"""
Axolotl specific DPO args
"""
from dataclasses import dataclass

from trl import DPOConfig

from axolotl.core.training_args import AxolotlTrainingMixins


@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
    """
    DPO config for DPO training
    """
@@ -1,125 +0,0 @@
"""
DPO trainer for axolotl
"""
import gc
from functools import wraps
from typing import Any, Dict, Union

import torch
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer

from axolotl.core.trainers.base import (
    SchedulerMixin,
    _sanitize_kwargs_for_ds_tagging,
    _sanitize_kwargs_for_tagging,
)

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp


class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
    """
    Extend the base DPOTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "dpo"]

    def __init__(self, *args, dataset_tags=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_tags = dataset_tags
        self.optimizer = None
        self.model_accepts_loss_kwargs = False

    def create_optimizer(self):
        # pylint: disable=duplicate-code
        if self.args.loraplus_lr_ratio is None:
            return super().create_optimizer()

        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if self.optimizer is None:  # pylint: disable=access-member-before-definition
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args,
                opt_model,
            )

            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
            if loraplus_lr_ratio:
                print("Using lora+")
            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
            # pylint: disable=duplicate-code
            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
                opt_model,
                optimizer_cls,
                loraplus_lr_ratio=loraplus_lr_ratio,
                loraplus_lr_embedding=loraplus_lr_embedding,
                **optimizer_kwargs,
            )

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
                self.optimizer
            )

        return self.optimizer

    @wraps(DPOTrainer.push_to_hub)
    def push_to_hub(self, *args, **kwargs) -> str:
        """
        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
        """
        kwargs = _sanitize_kwargs_for_ds_tagging(
            dataset_tags=self.dataset_tags, kwargs=kwargs
        )
        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

        return super().push_to_hub(*args, **kwargs)

    @staticmethod
    def tokenize_row(
        features,
        processing_class,
        max_prompt_length,
        max_completion_length,
        add_special_tokens,
    ) -> Dict:
        res = DPOTrainer.tokenize_row(
            features,
            processing_class,
            max_prompt_length,
            max_completion_length,
            add_special_tokens,
        )
        # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
        if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
            for key in res.keys():
                res[key] = res[key][1:]

        if processing_class.bos_token and processing_class.bos_token_id is not None:
            # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
            if res["chosen_input_ids"][0] == processing_class.bos_token_id:
                res["chosen_input_ids"] = res["chosen_input_ids"][1:]
                res["chosen_labels"] = res["chosen_labels"][1:]
                res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
            if res["rejected_input_ids"][0] == processing_class.bos_token_id:
                res["rejected_input_ids"] = res["rejected_input_ids"][1:]
                res["rejected_labels"] = res["rejected_labels"][1:]
                res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]

        return res

    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch=None,
    ) -> torch.Tensor:
        loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
        gc.collect()
        torch.cuda.empty_cache()
        return loss
@@ -1,119 +0,0 @@
"""
GRPO Specific Strategy for training
"""

import importlib
import inspect
import logging

from trl.trainer.grpo_trainer import RewardFunc

from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer

LOG = logging.getLogger("axolotl")


class GRPOStrategy:
    """
    Strategy for GRPO training
    """

    @classmethod
    def get_trainer_class(cls):
        return AxolotlGRPOTrainer

    @classmethod
    def get_training_args_class(cls):
        from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig

        return AxolotlGRPOConfig

    @classmethod
    def set_training_args_kwargs(cls, cfg):
        grpo_args_kwargs = {}
        if cfg.trl and cfg.trl.use_vllm:
            grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
        if cfg.trl and cfg.trl.vllm_device:
            grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
        else:
            grpo_args_kwargs["vllm_device"] = "auto"
        if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
            grpo_args_kwargs[
                "vllm_gpu_memory_utilization"
            ] = cfg.trl.vllm_gpu_memory_utilization
        if cfg.trl and cfg.trl.vllm_max_model_len:
            grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
        if cfg.trl and cfg.trl.num_generations:
            grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
        if cfg.trl and cfg.trl.sync_ref_model:
            grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
        if cfg.trl and cfg.trl.ref_model_mixup_alpha:
            grpo_args_kwargs[
                "ref_model_mixup_alpha"
            ] = cfg.trl.ref_model_mixup_alpha
        if cfg.trl and cfg.trl.ref_model_sync_steps:
            grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
        grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
        grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
        return grpo_args_kwargs

    @classmethod
    def set_trainer_args(cls, cfg):
        trainer_args = []
        if cfg.trl and cfg.trl.reward_funcs:
            reward_funcs = []
            for reward_func_fqn in cfg.trl.reward_funcs:
                reward_funcs.append(cls.get_reward_func(reward_func_fqn))
            trainer_args.append(reward_funcs)
        return trainer_args

    @classmethod
    def set_trainer_kwargs(cls, cfg):
        trainer_kwargs = {}
        if cfg.trl and cfg.trl.reward_processing_classes:
            trainer_kwargs[
                "reward_processing_classes"
            ] = cfg.trl.reward_processing_classes
        return trainer_kwargs

    @classmethod
    def get_collator(cls, *args, **kwargs):  # pylint: disable=unused-argument
        # No data collation is needed in GRPO, handled by trl's trainer __init__
        return None

    @classmethod
    def get_blocklist_args_kwargs(cls):
        return ["dataset_num_proc"]

    @classmethod
    def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:
        """
        Returns the reward function from the given fully qualified name, or the path to the reward function model.

        Args:
            reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform),
                or a HF hub path to the reward model.

        Raises:
            ValueError: If the reward function does not accept at least two arguments.

        Returns:
            RewardFunc: A callable that accepts prompts and completions and returns rewards,
                or a path to a reward model.
        """
        try:
            # use importlib to dynamically load the reward function from the module
            reward_func_module_name = reward_func_fqn.split(".")[-1]
            reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
            reward_func = getattr(reward_func_module, reward_func_module_name)
            if len(inspect.signature(reward_func).parameters) < 2:
                raise ValueError(
                    "Reward function must accept at least two arguments: prompts: list and completions: list"
                )
            return reward_func
        except ModuleNotFoundError:
            # the user has passed a string (ideally indicating the path of a reward model);
            # return the path string itself rather than the unbound `reward_func` name
            LOG.info(
                f"Reward function {reward_func_fqn} is a pre-trained model path - if this is unexpected, please check the reward function path."
            )
            return reward_func_fqn
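For reference, a minimal reward function that satisfies the two-parameter signature check in get_reward_func above (the length-based scoring rule is invented purely for illustration):

def brevity_reward(prompts: list, completions: list, **kwargs) -> list[float]:
    # One float per completion; shorter completions score higher. Invented
    # example scoring, not part of the diff above.
    return [1.0 / (1.0 + len(completion)) for completion in completions]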
@@ -1,15 +0,0 @@
"""
Axolotl Specific Training Args
"""
from dataclasses import dataclass

from trl import GRPOConfig

from axolotl.core.training_args import AxolotlTrainingMixins


@dataclass
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
    """
    Axolotl GRPO Config for GRPO training
    """
@@ -1,107 +0,0 @@
"""
Axolotl GRPO trainer
"""
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel
from trl import GRPOConfig, GRPOTrainer
from trl.models import unwrap_model_for_generation

from axolotl.core.trainers.base import SchedulerMixin


# mypy: ignore-errors
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
    """
    Extend the base GRPOTrainer for axolotl helpers
    """

    _tag_names = ["trl", "grpo", "axolotl"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # pylint: disable=access-member-before-definition
        # Enable gradient checkpointing if requested
        if kwargs["args"].gradient_checkpointing:
            # Ensure use_cache is disabled
            if hasattr(self.model, "config"):
                self.model.config.use_cache = False

            # Enable gradient checkpointing on the base model for PEFT
            if is_peft_model(self.model) and hasattr(
                self.model.base_model, "gradient_checkpointing_enable"
            ):
                self.model.base_model.gradient_checkpointing_enable()
            # Enable gradient checkpointing for non-PEFT models
            elif hasattr(self.model, "gradient_checkpointing_enable"):
                self.model.gradient_checkpointing_enable()
            self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
        # pylint: enable=access-member-before-definition

    def _enable_gradient_checkpointing(
        self, model: PreTrainedModel, args: GRPOConfig
    ) -> PreTrainedModel:
        """Enables gradient checkpointing for the model."""
        # pylint: disable=unused-argument,redefined-builtin
        gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
        use_reentrant = (
            "use_reentrant" not in gradient_checkpointing_kwargs
            or gradient_checkpointing_kwargs["use_reentrant"]
        )

        if use_reentrant:
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad
                )

        return model
        # pylint: enable=unused-argument,redefined-builtin

    def _move_model_to_vllm(self):
        with unwrap_model_for_generation(
            self.model,
            self.accelerator,
            gather_deepspeed3_params=self.args.ds3_gather_for_generation,
        ) as unwrapped_model:
            if is_compiled_module(unwrapped_model):
                unwrapped_model = (
                    unwrapped_model._orig_mod  # pylint: disable=protected-access
                )
            if is_peft_model(unwrapped_model):
                unwrapped_model.merge_adapter()
                state_dict = unwrapped_model.state_dict()
                unwrapped_model.unmerge_adapter()
                # Remove base_model and base_layer prefixes
                state_dict = {
                    k.removeprefix("base_model.model.")
                    .removeprefix("base_model.model.")
                    .replace(".base_layer", ""): v
                    for k, v in state_dict.items()
                }
                # Remove values with adapter prefix (example: "_lora")
                state_dict = {
                    k: v
                    for k, v in state_dict.items()
                    if unwrapped_model.prefix not in k
                }
                # When module to save, remove its prefix and discard the original module
                state_dict = {
                    k.replace("modules_to_save.default.", ""): v
                    for k, v in state_dict.items()
                    if "original_module" not in k
                }
            else:
                state_dict = unwrapped_model.state_dict()
            if self.accelerator.is_main_process:
                llm_model = (
                    self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                )
                llm_model.load_weights(state_dict.items())
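A toy sketch of the state-dict key cleanup the PEFT branch above performs; the keys are made up, and "lora_" stands in for unwrapped_model.prefix:

sd = {
    "base_model.model.layers.0.q_proj.base_layer.weight": "w",
    "base_model.model.layers.0.q_proj.lora_A.default.weight": "a",
}
# Strip the PEFT wrapper prefix and the .base_layer suffix, ...
sd = {
    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
    for k, v in sd.items()
}
# ... then drop adapter-only entries before handing weights to vLLM.
sd = {k: v for k, v in sd.items() if "lora_" not in k}
assert sd == {"layers.0.q_proj.weight": "w"}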
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig


@dataclass
@@ -217,6 +217,13 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
    """


@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
    """
    DPO config for DPO training
    """


@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
    """

@@ -1,159 +0,0 @@
"""
Module for definition of GEGLU Triton kernels.

See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).

Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code

import torch
import triton
import triton.language as tl

SQRT_2_PI: tl.constexpr = 0.7978845608028654  # sqrt(2/π)


@triton.jit
def _geglu_fwd_kernel(
    gate_ptr,
    up_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """GEGLU forward kernel.

    Args:
        gate_ptr: Pointer to gate tensor [*, hidden_dim].
        up_ptr: Pointer to up-projection tensor [*, hidden_dim].
        out_ptr: Pointer to output tensor [*, hidden_dim].
        n_elements: Total number of elements in the input tensors.
        BLOCK_SIZE: Size of thread blocks for parallel computation.
    """
    block_idx = tl.program_id(0)
    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0)

    # Compute activation in fp32 then convert back
    gelu_gate = 0.5 * gate * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
    gelu_gate = gelu_gate.to(up.dtype)
    result = gelu_gate * up

    tl.store(out_ptr + offsets, result, mask=mask)


def geglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """GEGLU forward pass.

    Args:
        gate: Input gate tensor of shape [batch, seq_len, hidden_dim].
        up: Up-projection tensor of shape [batch, seq_len, hidden_dim].

    Returns:
        torch.Tensor: Output tensor of shape [batch, seq_len, hidden_dim].
    """
    batch, seq_len, hidden_dim = gate.shape
    n_elements = gate.numel()
    out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda")

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)  # noqa: E731
    _geglu_fwd_kernel[grid](
        gate_ptr=gate,
        up_ptr=up,
        out_ptr=out,
        n_elements=n_elements,
        BLOCK_SIZE=1024,
    )
    return out

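The kernel computes the exact (erf-based) GELU of the gate in fp32 and multiplies by the up projection. A plain-PyTorch reference that matches the same formula, useful for checking on small CPU inputs (a sketch for verification, not part of the diff):

import math

import torch
import torch.nn.functional as F

gate, up = torch.randn(2, 4, 8), torch.randn(2, 4, 8)
ref = F.gelu(gate, approximate="none") * up
# Same erf-based formula the kernel uses: 0.5 * g * (erf(g / sqrt(2)) + 1) * up
manual = 0.5 * gate * (torch.erf(gate / math.sqrt(2.0)) + 1.0) * up
torch.testing.assert_close(ref, manual)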

@triton.jit
def _geglu_bwd_kernel(
    grad_out_ptr,
    gate_ptr,
    up_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """GEGLU backward kernel. Stores gradient results in-place.

    Args:
        grad_out_ptr: Pointer to gradient output tensor [*, hidden_dim].
        gate_ptr: Pointer to gate tensor [*, hidden_dim].
        up_ptr: Pointer to up-projection tensor [*, hidden_dim].
        n_elements: Total number of elements in the input tensors.
        BLOCK_SIZE: Size of thread blocks for parallel computation.

    Note:
        After kernel execution, tensors are modified in-place:
        - `grad_out_ptr` contains GEGLU activation output (`h`)
        - `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
        - `up_ptr` contains gradient w.r.t up (`grad_up`)
    """
    block_idx = tl.program_id(0)
    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0)

    # Forward pass
    gelu_partial = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0)
    gelu_gate = gelu_partial * gate
    gelu_gate = gelu_gate.to(grad_out.dtype)

    # Forward output
    h = gelu_gate * up

    # Compute gradients
    grad_up = grad_out * gelu_gate

    # Compute gate gradient using GELU derivative
    temp = grad_out * up
    t = 0.3989422804014327  # 1/sqrt(2*pi)
    dgelu_dgate = gelu_partial + t * gate * tl.exp(-0.5 * gate * gate)
    grad_gate = temp.to(tl.float32) * dgelu_dgate
    grad_gate = grad_gate.to(grad_out.dtype)

    # Store results
    tl.store(grad_out_ptr + offsets, h, mask=mask)
    tl.store(gate_ptr + offsets, grad_gate, mask=mask)
    tl.store(up_ptr + offsets, grad_up, mask=mask)


def geglu_backward(
    grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """GEGLU backward pass using in-place operations.

    Args:
        grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
        gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
        up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.

    Returns:
        Tuple containing:
            - GEGLU activation output (`h`)
            - Gradient with respect to gate (`grad_gate`)
            - Gradient with respect to up (`grad_up`)

    Note:
        This function modifies its input tensors in-place to store results.
    """
    n_elements = grad_output.numel()

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)  # noqa: E731
    _geglu_bwd_kernel[grid](
        grad_out_ptr=grad_output,
        gate_ptr=gate,
        up_ptr=up,
        n_elements=n_elements,
        BLOCK_SIZE=1024,
    )

    return grad_output, gate, up
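The gate derivative used above is d/dg[g * Phi(g)] = Phi(g) + g * phi(g), where Phi is the standard normal CDF and phi its density (the 1/sqrt(2*pi) constant `t`). A small autograd cross-check of that identity (a CPU-only sketch):

import math

import torch

g = torch.randn(16, dtype=torch.float64, requires_grad=True)
# g * Phi(g), the exact GELU the kernel computes in fp32
(0.5 * g * (torch.erf(g / math.sqrt(2.0)) + 1.0)).sum().backward()
gd = g.detach()
cdf = 0.5 * (torch.erf(gd / math.sqrt(2.0)) + 1.0)
pdf = torch.exp(-0.5 * gd * gd) / math.sqrt(2.0 * math.pi)
torch.testing.assert_close(g.grad, cdf + gd * pdf)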
@@ -1,779 +0,0 @@
"""
Module for definition of Low-Rank Adaptation (LoRA) Triton kernels.

See "LoRA: Low-Rank Adaptation of Large Language Models"
(https://arxiv.org/abs/2106.09685).

Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name

from typing import Callable

import torch
from bitsandbytes.functional import QuantState
from torch import nn

from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize
from .swiglu import swiglu_backward, swiglu_forward
from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd


def get_lora_parameters(
    proj: nn.Module,
) -> tuple[
    torch.Tensor,
    QuantState | None,
    torch.Tensor | None,
    torch.Tensor | None,
    float | None,
]:
    """
    Gets LoRA parameters from a projection module.

    Args:
        proj: The projection module to extract parameters from.

    Returns:
        A tuple containing the base weight matrix, quantization state, LoRA A matrix,
        LoRA B matrix, and scaling factor. States and matrices may be None if not
        available.
    """
    # For DPO or disabled adapters
    base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
    W = base_layer.weight

    if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
        quant_state = getattr(W, "quant_state", None)
        return W, quant_state, None, None, None

    active_adapter = (
        proj.active_adapters[0]
        if hasattr(proj, "active_adapters")
        else proj.active_adapter
    )
    A = proj.lora_A[active_adapter].weight
    B = proj.lora_B[active_adapter].weight
    s = proj.scaling[active_adapter]

    quant_state = getattr(W, "quant_state", None)

    return W, quant_state, A, B, s


def matmul_lora(
    X: torch.Tensor,
    W: torch.Tensor,
    W_quant: QuantState,
    A: torch.Tensor,
    B: torch.Tensor,
    s: float,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Efficient fused matmul + LoRA computation.

    Args:
        X: Input tensor [*, in_features]
        W: Base weight matrix [out_features, in_features]
        W_quant: Quantization state for W
        A: LoRA A matrix [rank, in_features]
        B: LoRA B matrix [out_features, rank]
        s: LoRA scaling factor
        out: Optional output tensor for inplace operations

    Returns:
        Result of X @ W + X @ A @ B
    """
    dtype = X.dtype
    W = dequantize(W.t(), W_quant)

    if X.dim() == 3:
        batch, seq_len, _ = X.shape
        X = X.view(-1, X.shape[-1])
        reshape = True
    else:
        reshape = False

    out = torch.matmul(X, W, out=out)
    if W_quant is not None:
        del W

    if A is not None:
        A, B = A.t(), B.t()
        out += (X @ A.to(dtype)) @ (s * B.to(dtype))

    return out.view(batch, seq_len, -1) if reshape else out

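With no quantization, matmul_lora reduces to X @ W.T + s * (X @ A.T) @ B.T, i.e. a single matmul against the effective weight W + s * B @ A. A plain-torch equivalence sketch (bypassing the dequantize path entirely):

import torch

X = torch.randn(3, 16, dtype=torch.float64)   # [tokens, in_features]
W = torch.randn(32, 16, dtype=torch.float64)  # [out_features, in_features]
A = torch.randn(4, 16, dtype=torch.float64)   # [rank, in_features]
B = torch.randn(32, 4, dtype=torch.float64)   # [out_features, rank]
s = 0.5

# The fused form computed by matmul_lora when W_quant is None ...
fused = X @ W.t() + (X @ A.t()) @ (s * B.t())
# ... equals a matmul against the merged weight W + s * B @ A.
torch.testing.assert_close(fused, X @ (W + s * B @ A).t())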
|
||||
class LoRA_MLP(torch.autograd.Function):
|
||||
"""Optimized LoRA MLP implementation."""
|
||||
|
||||
@staticmethod
|
||||
@torch_amp_custom_fwd
|
||||
def forward(
|
||||
ctx,
|
||||
X: torch.Tensor,
|
||||
gate_weight: torch.Tensor,
|
||||
gate_quant: object | None,
|
||||
gate_A: torch.Tensor | None,
|
||||
gate_B: torch.Tensor | None,
|
||||
gate_scale: float,
|
||||
up_weight: torch.Tensor,
|
||||
up_quant: object | None,
|
||||
up_A: torch.Tensor | None,
|
||||
up_B: torch.Tensor | None,
|
||||
up_scale: float,
|
||||
down_weight: torch.Tensor,
|
||||
down_quant: object | None,
|
||||
down_A: torch.Tensor | None,
|
||||
down_B: torch.Tensor | None,
|
||||
down_scale: float,
|
||||
activation_fn: Callable,
|
||||
activation_fn_backward: Callable,
|
||||
inplace: bool | None = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for LoRA MLP.
|
||||
|
||||
Args:
|
||||
ctx: Autograd context
|
||||
X: Input features
|
||||
gate_weight: Gate projection weight
|
||||
gate_quant: Gate quantization state
|
||||
gate_A: Gate LoRA A matrix
|
||||
gate_B: Gate LoRA B matrix
|
||||
gate_scale: Gate LoRA scale
|
||||
up_weight: Up-projection weight
|
||||
up_quant: Up-projection quantization state
|
||||
up_A: Up-projection LoRA A matrix
|
||||
up_B: Up-projection LoRA B matrix
|
||||
up_scale: Up-projection LoRA scale
|
||||
down_weight: Down-projection weight
|
||||
down_quant: Down-projection quantization state
|
||||
down_A: Down-projection LoRA A matrix
|
||||
down_B: Down-projection LoRA B matrix
|
||||
down_scale: Down-projection LoRA scale
|
||||
activation_fn: Forward activation function
|
||||
activation_fn_backward: Backward activation function
|
||||
inplace: Whether to perform operations in-place
|
||||
|
||||
Returns:
|
||||
Output transformed by multi-layer perceptron and activation function
|
||||
"""
|
||||
# Compute projections
|
||||
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
|
||||
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
|
||||
|
||||
# Activation
|
||||
hidden = activation_fn(gate, up)
|
||||
|
||||
# Down projection
|
||||
output = matmul_lora(
|
||||
hidden, down_weight, down_quant, down_A, down_B, down_scale
|
||||
)
|
||||
|
||||
# Save for backward
|
||||
ctx.save_for_backward(X, gate, up, gate_A, gate_B, up_A, up_B, down_A, down_B)
|
||||
ctx.scales = (gate_scale, up_scale, down_scale)
|
||||
ctx.quants = (gate_quant, up_quant, down_quant)
|
||||
ctx.weights = (gate_weight, up_weight, down_weight)
|
||||
ctx.activation_fn = activation_fn
|
||||
ctx.activation_fn_backward = activation_fn_backward
|
||||
ctx.inplace = inplace
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch_amp_custom_bwd
|
||||
def backward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_output: torch.Tensor,
|
||||
) -> tuple[
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Performs backward pass computation for LoRA MLP.
|
||||
|
||||
Args:
|
||||
ctx: Context object storing tensors saved during forward pass
|
||||
grad_output: Gradient of loss with respect to layer output
|
||||
|
||||
Returns:
|
||||
Tuple containing gradients for all inputs from forward pass:
|
||||
- Input gradient tensor (or `None`)
|
||||
- `None` for weights/quantization states
|
||||
- LoRA A/B matrix gradients (or `None`)
|
||||
- `None` for scaling factors
|
||||
- `None` for activation functions and flags
|
||||
"""
|
||||
(
|
||||
X,
|
||||
gate,
|
||||
up,
|
||||
gate_A,
|
||||
gate_B,
|
||||
up_A,
|
||||
up_B,
|
||||
down_A,
|
||||
down_B,
|
||||
) = ctx.saved_tensors
|
||||
gate_scale, up_scale, down_scale = ctx.scales
|
||||
gate_quant, up_quant, down_quant = ctx.quants
|
||||
gate_weight, up_weight, down_weight = ctx.weights
|
||||
|
||||
# Transpose all LoRA matrices
|
||||
gate_A, gate_B = (
|
||||
gate_A.t() if gate_A is not None else None,
|
||||
gate_B.t() if gate_B is not None else None,
|
||||
)
|
||||
up_A, up_B = (
|
||||
up_A.t() if up_A is not None else None,
|
||||
up_B.t() if up_B is not None else None,
|
||||
)
|
||||
down_A, down_B = (
|
||||
down_A.t() if down_A is not None else None,
|
||||
down_B.t() if down_B is not None else None,
|
||||
)
|
||||
|
||||
# Reshape inputs
|
||||
batch, seq_len, hd = X.shape
|
||||
grad_output = grad_output.view(-1, grad_output.shape[-1])
|
||||
X = X.view(-1, X.shape[-1])
|
||||
gate = gate.view(-1, gate.shape[-1])
|
||||
up = up.view(-1, up.shape[-1])
|
||||
dtype = X.dtype
|
||||
|
||||
# Down projection
|
||||
DW = matmul_lora(
|
||||
grad_output,
|
||||
down_weight.t(),
|
||||
down_quant,
|
||||
down_B,
|
||||
down_A,
|
||||
down_scale,
|
||||
)
|
||||
|
||||
# Activation backward
|
||||
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
|
||||
|
||||
# Initialize and compute LoRA gradients
|
||||
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
|
||||
|
||||
if down_A is not None:
|
||||
d_down_A = h.t() @ (grad_output @ down_B.t())
|
||||
d_down_B = (down_A.t() @ h.t()) @ grad_output
|
||||
d_down_A *= down_scale
|
||||
d_down_B *= down_scale
|
||||
|
||||
if up_A is not None:
|
||||
d_up_A = X.t() @ (grad_up @ up_B.t())
|
||||
d_up_B = (up_A.t() @ X.t()) @ grad_up
|
||||
d_up_A *= up_scale
|
||||
d_up_B *= up_scale
|
||||
|
||||
if gate_A is not None:
|
||||
d_gate_A = X.t() @ (grad_gate @ gate_B.t())
|
||||
d_gate_B = (gate_A.t() @ X.t()) @ grad_gate
|
||||
d_gate_A *= gate_scale
|
||||
d_gate_B *= gate_scale
|
||||
|
||||
# Compute input gradients
|
||||
dX = torch.zeros_like(X) if ctx.needs_input_grad[0] else None
|
||||
|
||||
if dX is not None:
|
||||
# Up projection gradients
|
||||
up_weight = dequantize(up_weight.t(), up_quant)
|
||||
if ctx.inplace:
|
||||
dX = torch.matmul(grad_up, up_weight.t(), out=X)
|
||||
else:
|
||||
dX = torch.matmul(grad_up, up_weight.t())
|
||||
del up_weight
|
||||
|
||||
# Note the .to(dtype) only where mixing LoRA with base weights
|
||||
if up_A is not None:
|
||||
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
|
||||
|
||||
# Gate projection gradients
|
||||
gate_weight = dequantize(gate_weight.t(), gate_quant)
|
||||
dX += grad_gate @ gate_weight.t()
|
||||
del gate_weight
|
||||
|
||||
if gate_A is not None:
|
||||
dX += (
|
||||
grad_gate
|
||||
@ gate_B.to(dtype).t()
|
||||
@ (gate_scale * gate_A.to(dtype).t())
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
dX = dX.view(batch, seq_len, hd)
|
||||
|
||||
# Return gradients in correct order matching forward inputs
|
||||
return (
|
||||
dX,
|
||||
None,
|
||||
None,
|
||||
d_gate_A.t() if d_gate_A is not None else None,
|
||||
d_gate_B.t() if d_gate_B is not None else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_up_A.t() if d_up_A is not None else None,
|
||||
d_up_B.t() if d_up_B is not None else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_down_A.t() if d_down_A is not None else None,
|
||||
d_down_B.t() if d_down_B is not None else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Applies LoRA to MLP layer with SwiGLU activation.
|
||||
|
||||
Args:
|
||||
X: Input tensor for the MLP layer
|
||||
inplace: Whether to perform operations in-place to save memory
|
||||
|
||||
Returns:
|
||||
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
|
||||
"""
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
|
||||
out = LoRA_MLP.apply(
|
||||
X,
|
||||
gateW,
|
||||
gateW_quant,
|
||||
gateA,
|
||||
gateB,
|
||||
gateS,
|
||||
upW,
|
||||
upW_quant,
|
||||
upA,
|
||||
upB,
|
||||
upS,
|
||||
downW,
|
||||
downW_quant,
|
||||
downA,
|
||||
downB,
|
||||
downS,
|
||||
swiglu_forward,
|
||||
swiglu_backward,
|
||||
inplace,
|
||||
)
|
||||
|
||||
return out


def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor:
    """
    Applies LoRA to MLP layer with GEGLU activation.

    Args:
        X: Input tensor for the MLP layer
        inplace: Whether to perform operations in-place to save memory

    Returns:
        Output tensor after applying LoRA-adapted MLP with GEGLU activation
    """
    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
    out = LoRA_MLP.apply(
        X,
        gateW,
        gateW_quant,
        gateA,
        gateB,
        gateS,
        upW,
        upW_quant,
        upA,
        upB,
        upS,
        downW,
        downW_quant,
        downA,
        downB,
        downS,
        geglu_forward,
        geglu_backward,
        inplace,
    )

    return out


class LoRA_QKV(torch.autograd.Function):
    """
    Optimized LoRA QKV implementation with quantization support.

    Implements efficient computation of query, key, value projections with LoRA,
    supporting quantization and memory optimization.
    """

    @staticmethod
    @torch_amp_custom_fwd
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        X: torch.Tensor,
        q_weight: torch.Tensor,
        q_quant: QuantState | None,
        q_A: torch.Tensor | None,
        q_B: torch.Tensor | None,
        q_scale: float,
        k_weight: torch.Tensor,
        k_quant: QuantState | None,
        k_A: torch.Tensor | None,
        k_B: torch.Tensor | None,
        k_scale: float,
        v_weight: torch.Tensor,
        v_quant: QuantState | None,
        v_A: torch.Tensor | None,
        v_B: torch.Tensor | None,
        v_scale: float,
        inplace: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass computing Q, K, V projections with LoRA.

        Args:
            ctx: Autograd context
            X: Input tensor
            q_weight: Query projection weight
            q_quant: Query quantization state
            q_A: Query LoRA A matrix
            q_B: Query LoRA B matrix
            q_scale: Query LoRA scale
            k_weight: Key projection weight
            k_quant: Key quantization state
            k_A: Key LoRA A matrix
            k_B: Key LoRA B matrix
            k_scale: Key LoRA scale
            v_weight: Value projection weight
            v_quant: Value quantization state
            v_A: Value LoRA A matrix
            v_B: Value LoRA B matrix
            v_scale: Value LoRA scale
            inplace: Whether to perform operations in-place

        Returns:
            Tuple of (Query, Key, Value) projection tensors
        """
        Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
        K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
        V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)

        ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
        ctx.scales = (q_scale, k_scale, v_scale)
        ctx.quants = (q_quant, k_quant, v_quant)
        ctx.weights = (q_weight, k_weight, v_weight)
        ctx.inplace = inplace

        return Q, K, V

    @staticmethod
    @torch_amp_custom_bwd
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        q_grad: torch.Tensor,
        k_grad: torch.Tensor,
        v_grad: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
        None,
    ]:
        """
        Backward pass computing gradients for LoRA QKV.

        Args:
            ctx: Autograd context
            q_grad: Gradient for query projection
            k_grad: Gradient for key projection
            v_grad: Gradient for value projection

        Returns:
            Tuple containing gradients for all forward inputs
        """
        X, A_q, B_q, A_k, B_k, A_v, B_v = ctx.saved_tensors
        q_weight, k_weight, v_weight = ctx.weights
        q_quant, k_quant, v_quant = ctx.quants
        q_scale, k_scale, v_scale = ctx.scales
        dtype = X.dtype

        # Reshape gradients
        batch, seq_len = X.shape[:2]
        q_grad = q_grad.view(-1, q_grad.shape[-1])
        k_grad = k_grad.reshape(-1, k_grad.shape[-1])
        v_grad = v_grad.view(-1, v_grad.shape[-1])
        X = X.view(-1, X.shape[-1])

        # Pre-transpose X once
        X_t = X.t()

        # Initialize LoRA gradients as None
        d_A_q = d_B_q = d_A_k = d_B_k = d_A_v = d_B_v = None

        # Compute q path LoRA gradients if adapters exist
        if A_q is not None and B_q is not None:
            A_q_scaled = (q_scale * A_q).to(dtype)
            B_q_scaled = B_q.to(dtype)
            d_A_q = torch.mm(X_t, torch.mm(q_grad, B_q_scaled))
            d_B_q = torch.mm(torch.mm(A_q_scaled, X_t), q_grad)

        # Compute k path LoRA gradients if adapters exist
        if A_k is not None and B_k is not None:
            A_k_scaled = (k_scale * A_k).to(dtype)
            B_k_scaled = B_k.to(dtype)
            d_A_k = torch.mm(X_t, torch.mm(k_grad, B_k_scaled))
            d_B_k = torch.mm(torch.mm(A_k_scaled, X_t), k_grad)

        # Compute v path LoRA gradients if adapters exist
        if A_v is not None and B_v is not None:
            A_v_scaled = (v_scale * A_v).to(dtype)
            B_v_scaled = B_v.to(dtype)
            d_A_v = torch.mm(X_t, torch.mm(v_grad, B_v_scaled))
            d_B_v = torch.mm(torch.mm(A_v_scaled, X_t), v_grad)

        # Compute input gradient, reusing X memory if possible
        out_buffer = X if ctx.inplace else None

        # Q path
        q_weight_t = dequantize(q_weight, q_quant)
        grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer)
        del q_weight
        del q_weight_t
        if A_q is not None and B_q is not None:
            grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))

        # K path
        k_weight_t = dequantize(k_weight, k_quant)
        grad_X.addmm_(k_grad, k_weight_t)
        del k_weight
        del k_weight_t
        if A_k is not None and B_k is not None:
            grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))

        # V path
        v_weight_t = dequantize(v_weight, v_quant)
        grad_X.addmm_(v_grad, v_weight_t)
        del v_weight
        del v_weight_t
        if A_v is not None and B_v is not None:
            grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))

        # Transpose gradients if needed
        if d_A_q is not None:
            d_A_q = d_A_q.t()
        if d_B_q is not None:
            d_B_q = d_B_q.t()
        if d_A_k is not None:
            d_A_k = d_A_k.t()
        if d_B_k is not None:
            d_B_k = d_B_k.t()
        if d_A_v is not None:
            d_A_v = d_A_v.t()
        if d_B_v is not None:
            d_B_v = d_B_v.t()

        return (
            grad_X.view(batch, seq_len, -1),
            None,
            None,
            d_A_q,
            d_B_q,
            None,
            None,
            None,
            d_A_k,
            d_B_k,
            None,
            None,
            None,
            d_A_v,
            d_B_v,
            None,
            None,
        )


def apply_lora_qkv(
    self, X: torch.Tensor, inplace: bool = True
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Applies LoRA to compute Query, Key, Value projections.

    Args:
        X: Input tensor
        inplace: Whether to perform operations in-place

    Returns:
        Tuple of (Query, Key, Value) projection tensors
    """
    QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
    KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
    VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
    Q, K, V = LoRA_QKV.apply(
        X,
        QW,
        QW_quant,
        QA,
        QB,
        QS,
        KW,
        KW_quant,
        KA,
        KB,
        KS,
        VW,
        VW_quant,
        VA,
        VB,
        VS,
        inplace,
    )

    return Q, K, V
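
# Editor's note (assumption, not part of the original source): `matmul_lora` is
# taken here to compute the usual LoRA-augmented projection, roughly
#
#   out = X @ W.t() + scale * (X @ A.t()) @ B.t()
#
# dequantizing `W` on the fly when a quant state is present; the backward pass
# above mirrors that structure (a dense grad @ W term plus the low-rank correction).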


class LoRA_O(torch.autograd.Function):
    """Optimized LoRA implementation for output projection."""

    @staticmethod
    @torch_amp_custom_fwd
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        X: torch.Tensor,
        W: torch.Tensor,
        W_quant: QuantState | None,
        A: torch.Tensor | None,
        B: torch.Tensor | None,
        S: float,
    ) -> torch.Tensor:
        """
        Forward pass for output projection with LoRA.

        Args:
            ctx: Autograd context
            X: Input tensor
            W: Output projection weight
            W_quant: Weight quantization state
            A: LoRA A matrix
            B: LoRA B matrix
            S: LoRA scaling factor

        Returns:
            Output projection tensor
        """
        XW = matmul_lora(X, W, W_quant, A, B, S)
        ctx.custom_saved_tensors = (
            W,
            W_quant,
            S,
        )
        ctx.save_for_backward(A, B, X)

        return XW

    @staticmethod
    @torch_amp_custom_bwd
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        dY: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        None,
        None,
        torch.Tensor | None,
        torch.Tensor | None,
        None,
    ]:
        """
        Backward pass computing gradients for LoRA output projection.

        Args:
            ctx: Autograd context
            dY: Gradient of loss with respect to output

        Returns:
            Tuple containing gradients for all forward inputs
        """
        W, W_quant, S = ctx.custom_saved_tensors
        A, B, X = ctx.saved_tensors

        batch, seq_len, hd = X.shape
        dY = dY.reshape(-1, dY.shape[-1])
        X = X.reshape(-1, X.shape[-1])
        dtype = X.dtype

        # Weight projection
        dY_X = X.t() @ dY
        d_A = S * dY_X @ B
        d_B = S * A @ dY_X

        # Get derivative for dX
        W = dequantize(W.t(), W_quant)
        dX = dY @ W.t()
        del W
        dX += dY @ B.to(dtype) @ (S * A.to(dtype))

        # W, W_quant, A, B, S
        return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None


def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
    """
    Applies LoRA to output projection layer.

    Args:
        X: Input tensor

    Returns:
        Transformed output tensor
    """
    OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
    output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)

    return output
@@ -1,149 +0,0 @@
"""Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement

import ctypes

import bitsandbytes as bnb
import torch
from bitsandbytes.functional import QuantState, get_ptr
from packaging.version import Version

cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4

CUDA_STREAM: torch.cuda.Stream | None = None
HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3")


def dequantize(
    W: torch.Tensor,
    quant_state: QuantState | list | None = None,
    out: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Fast NF4 dequantization using `bitsandbytes` CUDA kernels.

    Performs efficient dequantization of weights from NF4 format using `bitsandbytes`'
    optimized CUDA implementations. Supports both legacy list and new `QuantState`
    formats.

    Args:
        W: Quantized weight tensor to dequantize
        quant_state: Quantization state containing metadata needed for
            dequantization. Can be either a `QuantState` object or legacy list format.
            If None, returns `W` unchanged.
        out: Optional output tensor for storing dequantized results. Must match
            expected shape and dtype if provided.

    Returns:
        Dequantized tensor in the specified dtype (fp16 or bf16). Will be transposed if
        input `W` was transposed.

    Raises:
        AssertionError: If provided output tensor doesn't match expected shape / dtype.

    Note:
        Uses CUDA streams for better performance when available in newer `bitsandbytes`
        versions (>0.43.3).
    """
    if quant_state is None:
        return W

    # Get the target device from input tensor W
    target_device = W.device

    # Extract quantization state
    if not isinstance(quant_state, list):
        # New style quant_state class
        absmax = quant_state.absmax.to(target_device)
        shape = quant_state.shape
        dtype = quant_state.dtype
        blocksize = quant_state.blocksize
        offset = quant_state.offset.to(target_device)
        state2 = quant_state.state2
        absmax2 = state2.absmax.to(target_device)
        code2 = state2.code.to(target_device)
        blocksize2 = state2.blocksize
    else:
        # Legacy list format
        absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
        absmax = absmax.to(target_device)
        offset, state2 = compressed_stats
        offset = offset.to(target_device)
        absmax2, code2, blocksize2, _, _, _, _ = state2
        absmax2 = absmax2.to(target_device)
        code2 = code2.to(target_device)

    # Setup output tensor on the same device as input
    if out is None:
        out = torch.empty(shape, dtype=dtype, device=target_device)
    else:
        assert out.shape == shape and out.dtype == dtype
        out = out.to(target_device)

    # Dequantize statistics on the target device
    n_elements_absmax: int = absmax.numel()
    out_absmax: torch.Tensor = torch.empty(
        n_elements_absmax, dtype=torch.float32, device=target_device
    )
    ptr_out_absmax: int = get_ptr(out_absmax)

    # Use CUDA stream if available
    if HAS_CUDA_STREAM:
        global CUDA_STREAM
        if CUDA_STREAM is None:
            CUDA_STREAM = torch.cuda.current_stream(target_device)

        cdequantize_blockwise_fp32(
            get_ptr(code2),
            get_ptr(absmax),
            get_ptr(absmax2),
            ptr_out_absmax,
            ctypes.c_int(blocksize2),
            ctypes.c_int(n_elements_absmax),
            CUDA_STREAM,
        )
    else:
        cdequantize_blockwise_fp32(
            get_ptr(code2),
            get_ptr(absmax),
            get_ptr(absmax2),
            ptr_out_absmax,
            ctypes.c_int(blocksize2),
            ctypes.c_int(n_elements_absmax),
        )

    out_absmax += offset

    # Choose appropriate dequantization function
    fx = (
        cdequantize_blockwise_fp16_nf4
        if dtype == torch.float16
        else cdequantize_blockwise_bf16_nf4
    )

    # Dequantize weights
    if HAS_CUDA_STREAM:
        fx(
            get_ptr(None),
            get_ptr(W),
            ptr_out_absmax,
            get_ptr(out),
            ctypes.c_int(blocksize),
            ctypes.c_int(out.numel()),
            CUDA_STREAM,
        )
    else:
        fx(
            get_ptr(None),
            get_ptr(W),
            ptr_out_absmax,
            get_ptr(out),
            ctypes.c_int(blocksize),
            ctypes.c_int(out.numel()),
        )

    # Handle transposed data
    is_transposed: bool = W.shape[0] == 1
    return out.t() if is_transposed else out
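
# Editor's note (illustrative sketch, not part of the original source): a typical
# call site materializes an NF4 weight for a dense matmul, e.g. with a hypothetical
# bitsandbytes 4-bit layer:
#
#   layer = ...  # a bnb.nn.Linear4bit instance (assumed API)
#   W_fp = dequantize(layer.weight.data, layer.weight.quant_state)
#   y = x @ W_fp.t()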
@@ -1,163 +0,0 @@
"""
Module for definition of SwiGLU Triton kernels.

See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).

Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
import torch
import triton
import triton.language as tl


@triton.jit
def _swiglu_fwd_kernel(
    gate_ptr,
    up_ptr,
    out_ptr,
    n_elements,
    block_size: tl.constexpr,
):
    """
    SwiGLU forward kernel. The kernel computes activation in fp32 precision for better
    numerical stability, then converts back to original dtype for the final result.

    Args:
        gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
        up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
        out_ptr: Pointer to output tensor `[*, hidden_dim]`.
        n_elements: Total number of elements in the input tensors.
        block_size: Size of thread blocks for parallel computation.
    """
    block_idx = tl.program_id(0)
    offsets = block_idx * block_size + tl.arange(0, block_size)
    mask = offsets < n_elements

    # Load gate in fp32, keep up in original dtype
    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0)

    # Compute activation in fp32 then convert back
    f = gate * tl.sigmoid(gate)
    f = f.to(up.dtype)
    result = f * up

    tl.store(out_ptr + offsets, result, mask=mask)


@triton.jit
def _swiglu_bwd_kernel(
    grad_out_ptr,
    gate_ptr,
    up_ptr,
    n_elements,
    block_size: tl.constexpr,
):
    """
    SwiGLU backward kernel. Stores gradient results in-place.

    Args:
        grad_out_ptr: Pointer to gradient output tensor `[*, hidden_dim]`.
        gate_ptr: Pointer to gate tensor `[*, hidden_dim]`.
        up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`.
        n_elements: Total number of elements in the input tensors.
        block_size: Size of thread blocks for parallel computation.

    Note:
        After kernel execution, tensors are modified in-place:
        - `grad_out_ptr` contains forward output (`h`)
        - `gate_ptr` contains gradient w.r.t gate (`grad_gate`)
        - `up_ptr` contains gradient w.r.t up (`grad_up`)
    """
    block_idx = tl.program_id(0)
    offsets = block_idx * block_size + tl.arange(0, block_size)
    mask = offsets < n_elements

    # Load values - only convert gate to fp32
    grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0)
    gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask, other=0)

    # Compute SiLU and forward output
    sigmoid_gate = tl.sigmoid(gate)
    silu_gate = sigmoid_gate * gate
    silu_gate = silu_gate.to(grad_out.dtype)
    h = silu_gate * up

    # Compute gradients
    grad_up = grad_out * silu_gate  # gradient for up is grad_out * SiLU(gate)

    # Compute gate gradient
    temp = grad_out * up
    grad_gate = temp.to(tl.float32) * sigmoid_gate * (1.0 + gate * (1.0 - sigmoid_gate))
    grad_gate = grad_gate.to(grad_out.dtype)

    # Store results with correct gradient ordering
    tl.store(grad_out_ptr + offsets, h, mask=mask)
    tl.store(gate_ptr + offsets, grad_gate, mask=mask)  # grad wrt gate
    tl.store(up_ptr + offsets, grad_up, mask=mask)  # grad wrt up
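
# Editor's note (added derivation, not part of the original source): with
# h = SiLU(g) * up and SiLU(g) = g * sigmoid(g), the derivative is
#
#   dSiLU/dg = sigmoid(g) * (1 + g * (1 - sigmoid(g)))
#
# so dL/dgate = (dL/dh * up) * sigmoid(g) * (1 + g * (1 - sigmoid(g))), which is
# exactly the `grad_gate` expression computed above.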


# pylint: disable=unnecessary-lambda-assignment
def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """
    SwiGLU forward pass. Computes SwiGLU activation: `x * sigmoid(x) * up`, where
    `x` is the gate tensor.

    Args:
        gate: Input gate tensor of shape `[batch, seq_len, hidden_dim]`.
        up: Up-projection tensor of shape `[batch, seq_len, hidden_dim]`.

    Returns:
        Output tensor of shape `[batch, seq_len, hidden_dim]`.
    """
    batch, seq_len, hidden_dim = gate.shape
    n_elements = gate.numel()
    out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda")

    grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),)  # noqa: E731
    _swiglu_fwd_kernel[grid](
        gate_ptr=gate,
        up_ptr=up,
        out_ptr=out,
        n_elements=n_elements,
        block_size=1024,
    )

    return out


# pylint: disable=unnecessary-lambda-assignment
def swiglu_backward(
    grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    SwiGLU backward pass using in-place operations.

    Args:
        grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`.
        gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.
        up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`.

    Returns:
        Tuple containing:
        - Forward pass output (`h`)
        - Gradient with respect to gate (`df`)
        - Gradient with respect to up-projection (`de`)
    """
    n_elements = grad_output.numel()

    grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),)  # noqa: E731
    _swiglu_bwd_kernel[grid](
        grad_out_ptr=grad_output,
        gate_ptr=gate,
        up_ptr=up,
        n_elements=n_elements,
        block_size=1024,
    )

    # After kernel execution, tensors contain:
    # grad_output: h (forward output)
    # gate: grad_gate (grad wrt gate)
    # up: grad_up (grad wrt up)
    return grad_output, gate, up
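
# Editor's note (illustrative check, not part of the original source): a quick
# sanity test against eager PyTorch on a CUDA device might look like:
#
#   gate = torch.randn(2, 16, 64, device="cuda", dtype=torch.float16)
#   up = torch.randn_like(gate)
#   ref = torch.nn.functional.silu(gate) * up
#   out = swiglu_forward(gate.clone(), up.clone())
#   torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)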
@@ -1,11 +0,0 @@
"""Utilities for `axolotl.kernels` submodules."""

import torch
from packaging.version import Version

if Version(torch.__version__) < Version("2.4.0"):
    torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
    torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
@@ -1,314 +0,0 @@
"""Module for patching custom LoRA Triton kernels and `torch.autograd` functions."""

import importlib
import inspect
import logging
import types

import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers.modeling_utils import PreTrainedModel

from axolotl.kernels.lora import (
    apply_lora_mlp_geglu,
    apply_lora_mlp_swiglu,
    apply_lora_o,
    apply_lora_qkv,
)
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.dict import DictDefault

LOG = get_logger(__name__)

ORIGINAL_QKV_CODE = """
    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
    "\n"
)

PATCHED_QKV_CODE = """
    query_states, key_states, value_states = self.apply_qkv(hidden_states)
    query_states = query_states.view(hidden_shape).transpose(1, 2)
    key_states = key_states.view(hidden_shape).transpose(1, 2)
    value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
    "\n"
)

ORIGINAL_O_CODE = """
    attn_output = self.o_proj(attn_output)
""".lstrip(
    "\n"
)

PATCHED_O_CODE = """
    attn_output = self.apply_o(attn_output)
""".lstrip(
    "\n"
)

SUPPORTED_ACTIVATIONS = ["silu", "gelu"]
APPLY_FN_MAPPING = {
    "silu": apply_lora_mlp_swiglu,
    "gelu": apply_lora_mlp_geglu,
}


def original_apply_qkv(
    self: nn.Module, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Original implementation of QKV projection without optimizations.

    Args:
        self: The attention module instance.
        hidden_states: Input tensor of shape `[batch_size, seq_len, hidden_dim]`.

    Returns:
        A tuple `(query_states, key_states, value_states)` containing the projected
        states for query, key, and value.
    """
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    return query_states, key_states, value_states


def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
    """
    Original implementation of output projection without optimizations.

    Args:
        self: The attention module instance.
        hidden_states: Input tensor of shape `[batch_size, seq_len, hidden_dim]`.

    Returns:
        The output projection result.
    """
    attn_output = self.o_proj(hidden_states)

    return attn_output


# pylint: disable=protected-access
def patch_self_attn_lora(model: PreTrainedModel):
    """
    Patches the attention classes in a transformer model with optimized LoRA implementations.

    It modifies the attention class to use optimized QKV and output projections. The
    original implementation is preserved and can be restored if needed.

    Args:
        model: A HuggingFace transformers model.

    Raises:
        AssertionError: If the required code blocks are not found in the attention
            implementation.
    """
    # Find all attention modules in the model
    attention_modules = [
        module
        for module in model.modules()
        if "attention" in module.__class__.__name__.lower()
        and hasattr(module, "forward")
    ]

    if not attention_modules:
        LOG.warning("No attention modules found in model")
        return

    attention_classes = {type(module) for module in attention_modules}
    LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}")

    for attention_cls in attention_classes:
        # Skip if already patched
        if hasattr(attention_cls, "_original_forward"):
            LOG.info(f"{attention_cls.__name__} already patched")
            continue

        # Get and store original forward implementation
        self_attn_forward = inspect.getsource(attention_cls.forward)
        attention_cls._original_forward = self_attn_forward

        # Remove indentation
        self_attn_forward, _ = detab_code(self_attn_forward)

        # Verify required code blocks exist
        assert (
            ORIGINAL_QKV_CODE in self_attn_forward
        ), f"Original QKV code not found in {attention_cls.__name__}"
        assert (
            ORIGINAL_O_CODE in self_attn_forward
        ), f"Original O code not found in {attention_cls.__name__}"

        # Replace code blocks
        self_attn_forward = self_attn_forward.replace(
            ORIGINAL_QKV_CODE, PATCHED_QKV_CODE
        )
        self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
        self_attn_forward = self_attn_forward.replace(
            "def forward(",
            "def axolotl_attn_forward(",
            1,
        )

        # Import necessary symbols from the attention module
        module_name = attention_cls.__module__
        module = importlib.import_module(module_name)

        items_to_import = []
        for item in dir(module):
            if item in self_attn_forward:
                items_to_import.append(item)

        if items_to_import:
            exec(  # pylint: disable=exec-used  # nosec B102
                f"from {module_name} import ({', '.join(items_to_import)})",
                globals(),
            )

        # Execute the new implementation
        exec(self_attn_forward, globals())  # pylint: disable=exec-used  # nosec B102

        LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
        attention_cls.forward = (
            axolotl_attn_forward  # pylint: disable=undefined-variable  # noqa: F821
        )


def apply_lora_kernel_patches(
    model: PeftModelForCausalLM, cfg: DictDefault
) -> PeftModelForCausalLM:
    """
    Applies optimized Triton kernel patches to a PEFT model.

    Patches a PEFT model with optimized implementations for MLP and attention
    computations. The optimizations include custom Triton kernels for activation
    functions and specialized autograd functions for LoRA computations.

    Args:
        model: A PEFT model to be patched with optimized kernels.
        cfg: Dictionary mapping `axolotl` config keys to values.

    Returns:
        PeftModelForCausalLM: The patched model with optimized kernels.

    Raises:
        TypeError: If the provided model is not a `PeftModelForCausalLM`.
        NotImplementedError: If the model type is not supported.
        AssertionError: If multiple adapters are active (currently unsupported).

    Note:
        The optimizations require LoRA adapters with no dropout and no bias terms. The
        function will skip patching if these conditions aren't met.
    """
    if not isinstance(model, PeftModelForCausalLM):
        raise TypeError("Model must be a PeftModelForCausalLM")

    # Get active LoRA adapter config
    if hasattr(model, "active_adapters"):
        assert (
            len(model.active_adapters) == 1
        ), "Axolotl currently does not support LoRA Triton kernels for multiple adapters"
        active_adapter = model.active_adapters[0]
    else:
        active_adapter = model.active_adapter
    lora_config = model.model.peft_config[active_adapter]

    # Only patch if conditions are met
    can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"

    if not can_patch:
        LOG.warning("Cannot patch layers - requires no dropout and no bias")
        LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
        return model

    # This needs to be reset after patching
    original_level = LOG.getEffectiveLevel()
    LOG.setLevel(logging.INFO)

    # Choose activation based on model type
    activation = model.config.hidden_act
    if activation not in SUPPORTED_ACTIVATIONS:
        raise NotImplementedError(f"Activation {activation} is not supported")

    # Patch each layer
    for layer in model.model.model.layers:
        # Add QKV, O fallback implementations to start
        # These will be overwritten later (if some conditions apply)
        layer.self_attn.apply_qkv = types.MethodType(
            original_apply_qkv, layer.self_attn
        )
        layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn)

        if cfg.lora_mlp_kernel:
            # MLP patching
            gate_proj = layer.mlp.gate_proj
            up_proj = layer.mlp.up_proj
            down_proj = layer.mlp.down_proj

            can_patch_mlp = all(
                hasattr(proj, "lora_A")
                and getattr(proj, "base_layer", proj).bias is None
                and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
                for proj in (gate_proj, up_proj, down_proj)
            )

            if can_patch_mlp:
                apply_fn = APPLY_FN_MAPPING[activation]
                layer.mlp.forward = types.MethodType(apply_fn, layer.mlp)
            else:
                LOG.warning_once(
                    "Cannot patch some MLP layers - requires LoRA adapters with no bias"
                )
        if cfg.lora_qkv_kernel:
            # Query, key, value patching
            layer_modules = [
                getattr(layer.self_attn, linear_proj)
                for linear_proj in ["q_proj", "k_proj", "v_proj"]
            ]
            can_patch_qkv = all(
                hasattr(module, "lora_A")
                and getattr(module, "base_layer", module).bias is None
                and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
                for module in layer_modules
            )

            if can_patch_qkv:
                # Add optimized implementation
                layer.self_attn.apply_qkv = types.MethodType(
                    apply_lora_qkv, layer.self_attn
                )
            else:
                LOG.warning_once(
                    "Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
                )
        if cfg.lora_o_kernel:
            # Output patching
            layer_modules = [
                getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
            ]
            can_patch_o = all(
                hasattr(module, "lora_A")
                and getattr(module, "base_layer", module).bias is None
                and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
                for module in layer_modules
            )

            if can_patch_o:
                layer.self_attn.apply_o = types.MethodType(
                    apply_lora_o, layer.self_attn
                )
            else:
                LOG.warning_once(
                    "Cannot patch some attention output projection - requires LoRA adapters with no bias"
                )

    LOG.setLevel(original_level)

    return model
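
# Editor's note (illustrative sketch, not part of the original source): the intended
# entry point is roughly the following, with `lora_mlp_kernel` / `lora_qkv_kernel` /
# `lora_o_kernel` set in the axolotl config (key names taken from the checks above;
# the call site is a hypothetical example):
#
#   from axolotl.utils.dict import DictDefault
#   cfg = DictDefault(
#       {"lora_mlp_kernel": True, "lora_qkv_kernel": True, "lora_o_kernel": True}
#   )
#   patch_self_attn_lora(model.get_base_model())
#   model = apply_lora_kernel_patches(model, cfg)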
@@ -41,10 +41,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
            load_kwargs["ds_cfg"] = ds_cfg
        if "processor" in sig.parameters:
            load_kwargs["processor"] = processor

        return func(tokenizer, cfg, **load_kwargs)
    except ModuleNotFoundError:
        return None
    except Exception as exc:  # pylint: disable=broad-exception-caught
        LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
        raise exc
        return None

@@ -13,19 +13,8 @@ def load(strategy, cfg, module_base=None, **kwargs):
    if len(strategy.split(".")) == 1:
        strategy = strategy + ".default"
    load_fn = strategy.split(".")[-1]
    if len(strategy.split(".")) > 1:
        try:
            importlib.import_module(
                strategy.split(".")[-2],
                ".".join(strategy.split(".")[:-2]),
            )
            module_base = ".".join(strategy.split(".")[:-2])
            strategy = strategy.split(".")[-2]
        except ModuleNotFoundError:
            strategy = "." + ".".join(strategy.split(".")[:-1])
    else:
        strategy = "." + ".".join(strategy.split(".")[:-1])
    mod = importlib.import_module(strategy, module_base)
    strategy = ".".join(strategy.split(".")[:-1])
    mod = importlib.import_module(f".{strategy}", module_base)
    func = getattr(mod, load_fn)
    return func(cfg, **kwargs)
except Exception:  # pylint: disable=broad-exception-caught

@@ -34,12 +34,15 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):

        max_length = self.prompter.max_length

        self.messages = "chosen_messages"
        # pylint: disable=duplicate-code
        prompt["messages"] = []
        prompt[self.messages] = []
        if prompt["system"]:
            prompt["messages"].append({"role": "system", "content": prompt["system"]})
        prompt["messages"].append({"role": "user", "content": prompt["input"]})
        prompt["messages"].append({"role": "assistant", "content": prompt["chosen"]})
            prompt[self.messages].append(
                {"role": "system", "content": prompt["system"]}
            )
        prompt[self.messages].append({"role": "user", "content": prompt["input"]})
        prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
        chosen_tokenized = super()._tokenize_single_prompt(prompt)

        if len(chosen_tokenized["input_ids"]) > max_length:

@@ -52,12 +55,17 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
                :max_length
            ]

        self.messages = "rejected_messages"
        # pylint: disable=duplicate-code
        prompt["messages"] = []
        prompt[self.messages] = []
        if prompt["system"]:
            prompt["messages"].append({"role": "system", "content": prompt["system"]})
        prompt["messages"].append({"role": "user", "content": prompt["input"]})
        prompt["messages"].append({"role": "assistant", "content": prompt["rejected"]})
            prompt[self.messages].append(
                {"role": "system", "content": prompt["system"]}
            )
        prompt[self.messages].append({"role": "user", "content": prompt["input"]})
        prompt[self.messages].append(
            {"role": "assistant", "content": prompt["rejected"]}
        )
        rejected_tokenized = super()._tokenize_single_prompt(prompt)

        if len(rejected_tokenized["input_ids"]) > max_length:

@@ -91,13 +99,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    prompter_params = {
        "tokenizer": tokenizer,
        "chat_template": chat_template_string,
        "message_property_mappings": ds_cfg.get(
            "message_property_mappings",
            {
                "role": "role",
                "content": "content",
            },
        ),
        "message_field_role": ds_cfg.get("message_field_role", "role"),
        "message_field_content": ds_cfg.get("message_field_content", "content"),
        "message_field_training": ds_cfg.get("message_field_training", None),
        "message_field_training_detail": ds_cfg.get(
            "message_field_training_detail", None

@@ -121,4 +124,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
        ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
    )

    if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
        strategy.messages = ds_cfg["field_messages"]

    return strategy

@@ -4,16 +4,13 @@ HF Chat Templates prompt strategy

import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional

from pydantic import BaseModel
from transformers import ProcessorMixin

from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig

# Configure the logger
LOG = logging.getLogger("axolotl")

@@ -26,23 +23,16 @@ class ChatTemplatePrompter(Prompter):
    def __init__(
        self,
        tokenizer,
        chat_template: str,
        processor=None,
        chat_template=None,
        max_length=2048,
        message_property_mappings: Optional[Dict[str, str]] = None,
        message_field_role: str = "role",
        message_field_content: str = "content",
        message_field_training: Optional[str] = None,
        message_field_training_detail: Optional[str] = None,
        field_messages: str = "messages",
        roles: Optional[Dict[str, List[str]]] = None,
        drop_system_message: bool = False,
    ):
        # check if message_property_mappings is None or empty dict
        if message_property_mappings is None or (not message_property_mappings):
            message_property_mappings = {
                "role": "role",
                "content": "content",
            }

        if roles:
            self.roles = {s: t for t, sources in roles.items() for s in sources}
        else:

@@ -55,28 +45,18 @@ class ChatTemplatePrompter(Prompter):
                "tool": "tool",
            }

        self._chat_template_msg_variables = self.get_chat_template_msg_variables(
            chat_template, field_messages
        )
        self.message_property_mappings = message_property_mappings
        self.message_field_role = message_field_role
        self.message_field_content = message_field_content
        self.message_field_training = message_field_training
        self.message_field_training_detail = message_field_training_detail
        self.field_messages = field_messages
        self.tokenizer = tokenizer
        self.processor: Optional[ProcessorMixin] = processor
        self.processor: ProcessorMixin = processor
        self.chat_template = chat_template
        self.max_length = max_length
        self.drop_system_message = drop_system_message

    @property
    def chat_template_msg_variables(self) -> Set[str]:
        return self._chat_template_msg_variables

    def build_prompt(self, conversation, add_generation_prompt=False, images=None):
        if self.processor:
            if not callable(self.processor):
                raise TypeError("Processor must be callable")

            text = self.processor.apply_chat_template(
                conversation,
                chat_template=self.chat_template,

@@ -204,21 +184,17 @@ class ChatTemplatePrompter(Prompter):

        return adjusted_details

    def get_chat_template_msg_variables(
        self, chat_template: str, field_messages: str
    ) -> Set[str]:
        template_analyzer = JinjaTemplateAnalyzer(chat_template)
        return template_analyzer.get_message_vars(field_messages)


class ChatTemplateStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for instruction-based prompts.
    """

    _messages = "messages"

    def __init__(
        self,
        prompter: "ChatTemplatePrompter",
        prompter: ChatTemplatePrompter,
        tokenizer,
        train_on_inputs,
        sequence_len,

@@ -226,7 +202,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
        train_on_eos=None,
    ):
        super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
        self.prompter: ChatTemplatePrompter = prompter

        self.roles_to_train = []
        if roles_to_train:

@@ -238,9 +213,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
        self.train_on_eos = train_on_eos
        self.images = "images"

        LOG.debug(
            f"The chat template uses the following properties on the message: {self.prompter.chat_template_msg_variables}"
        )

    @property
    def messages(self):
        return self._messages

    @messages.setter
    def messages(self, messages):
        self._messages = messages

    @property
    def supports_batched(self) -> bool:

@@ -250,7 +229,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
    def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
        try:
            return all(isinstance(v, list) for v in prompt.values()) and all(
                isinstance(v, list) for v in prompt[self.prompter.field_messages]
                isinstance(v, list) for v in prompt[self.messages]
            )
        except KeyError:
            return False

@@ -485,17 +464,30 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

    def get_conversation_thread(self, prompt):
        turns = []
        for message in prompt[self.prompter.field_messages]:
            transformed_message = self.transform_message(message)

        optional_keys = [
            "tool_calls",  # tool that 'assistant' calls
            "name",  # name of tool given by 'tool'
            "tool_call_id",  # mistral/mixtral requires this
        ]
        for message in prompt[self.messages]:
            turn = {
                **transformed_message,
                "role": self.prompter.roles[message[self.prompter.message_field_role]],
                "training": message.get(self.prompter.message_field_training),
                "training_detail": message.get(
                    self.prompter.message_field_training_detail
                ),
            }

            # do not add content if None as it may conflict with some templates due to tools
            content = message.get(self.prompter.message_field_content, None)
            if content is not None:
                turn["content"] = content

            for key in optional_keys:
                value = message.get(key, None)
                if value is not None:
                    turn[key] = value

            turns.append(turn)

        if self.prompter.drop_system_message and turns[0]["role"] == "system":

@@ -503,37 +495,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

        return turns

    def transform_message(self, message):
        # Build the initial transformed message from the mappings
        transformed_message = {}
        for key, value in self.prompter.message_property_mappings.items():
            if message.get(value) is not None:
                transformed_message[key] = message[value]
            else:
                LOG.debug(
                    f"Could not find value for property {value} in message: {message}"
                )

        # Map the role if necessary
        if "role" in transformed_message:
            transformed_message["role"] = self.prompter.roles.get(
                transformed_message["role"], transformed_message["role"]
            )

        # Determine which keys in the original message were not mapped
        mapped_values = set(self.prompter.message_property_mappings.values())
        remaining_keys = set(message) - mapped_values

        # Keep only the properties defined in the chat template
        # and not already mapped
        for key in self.prompter.chat_template_msg_variables:
            if key in remaining_keys:
                val = message.get(key)
                if val is not None:
                    transformed_message[key] = val

        return transformed_message

    def get_images(self, prompt):
        return prompt.get(self.images, None)

@@ -555,46 +516,33 @@ class StrategyLoader:
        }

    def __call__(
        self,
        tokenizer,
        cfg,
        ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
        processor=None,
        self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
    ):
        if ds_cfg is None:
            dataset_config = {}
        elif isinstance(ds_cfg, BaseModel):
            dataset_config = ds_cfg.model_dump()
        else:
            dataset_config = ds_cfg

        # pylint: disable=duplicate-code
        ds_cfg = ds_cfg or {}
        chat_template_string = get_chat_template_from_config(
            cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
            cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
        )
        LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")

        prompter_params = {
            "tokenizer": tokenizer,
            "chat_template": chat_template_string,
            "message_property_mappings": dataset_config.get(
                "message_property_mappings", {}
            ),
            "message_field_training": dataset_config.get(
                "message_field_training", None
            ),
            "message_field_training_detail": dataset_config.get(
            "message_field_role": ds_cfg.get("message_field_role", "role"),
            "message_field_content": ds_cfg.get("message_field_content", "content"),
            "message_field_training": ds_cfg.get("message_field_training", None),
            "message_field_training_detail": ds_cfg.get(
                "message_field_training_detail",
                None,
            ),
            "field_messages": dataset_config.get("field_messages", "messages"),
            "roles": dataset_config.get("roles"),
            "drop_system_message": dataset_config.get("drop_system_message", False),
            "roles": ds_cfg.get("roles"),
            "drop_system_message": ds_cfg.get("drop_system_message", False),
            # we need to add one for detecting sequences exceeding the `sequence_len` limit.
            "max_length": cfg.sequence_len + 1,
            "processor": processor,
        }

        strategy_params = self._get_strategy_params(cfg, dataset_config)
        strategy_params = self._get_strategy_params(cfg, ds_cfg)
        strategy_cls = self._get_strategy_cls()

        strategy = strategy_cls(

@@ -603,6 +551,9 @@ class StrategyLoader:
            **strategy_params,
        )

        if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
            strategy.messages = ds_cfg["field_messages"]

        return strategy


@@ -3,28 +3,20 @@ DPO prompt strategies for using tokenizer chat templates.
"""

from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic


def default(
    cfg, dataset_idx=0, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    ds_cfg = cfg["datasets"][dataset_idx]
    ds_cfg = handle_legacy_message_fields_logic(ds_cfg)

    chat_template_choice, chat_template_jinja = extract_chat_template_args(
        cfg=cfg, ds_cfg=ds_cfg
    )
    field_messages = ds_cfg.get("field_messages", "messages")
    field_chosen = ds_cfg.get("field_chosen", "chosen")
    field_rejected = ds_cfg.get("field_rejected", "rejected")
    message_property_mappings = ds_cfg.get(
        "message_property_mappings",
        {
            "role": "role",
            "content": "content",
        },
    )
    field_message_role = ds_cfg.get("message_field_role", "role")
    field_message_content = ds_cfg.get("message_field_content", "content")
    role_map_inv = ds_cfg.get(
        "roles",
        {

@@ -48,18 +40,18 @@ def default(
        messages = sample[field_messages]
        messages = [
            {
                "role": role_map[m[message_property_mappings["role"]]],
                "content": m[message_property_mappings["content"]],
                "role": role_map[m[field_message_role]],
                "content": m[field_message_content],
            }
            for m in messages
        ]
        chosen = {
            "role": role_map[sample[field_chosen][message_property_mappings["role"]]],
            "content": sample[field_chosen][message_property_mappings["content"]],
            "role": role_map[sample[field_chosen][field_message_role]],
            "content": sample[field_chosen][field_message_content],
        }
        rejected = {
            "role": role_map[sample[field_rejected][message_property_mappings["role"]]],
            "content": sample[field_rejected][message_property_mappings["content"]],
            "role": role_map[sample[field_rejected][field_message_role]],
            "content": sample[field_rejected][field_message_content],
        }
        dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

@@ -1,14 +0,0 @@
"""
DPO prompt strategies passthrough/zero-processing strategy
"""


def default(
    cfg, dataset_idx=0, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    def transform_fn(
        sample, tokenizer=None
    ):  # pylint: disable=possibly-unused-variable,unused-argument
        return sample

    return transform_fn
@@ -1,318 +0,0 @@
|
||||
"""Module for inspect jinja templates for the variables they use"""
|
||||
from typing import Dict, Optional, Set, TypedDict, Union
|
||||
|
||||
from jinja2 import Environment, meta, nodes
|
||||
|
||||
|
||||
class JinjaTemplateAnalysis(TypedDict):
|
||||
"""
|
||||
Represents the detailed analysis of a Jinja template variable.
|
||||
|
||||
Attributes:
|
||||
accessed_properties (Set[str]): A set of properties accessed from the variable
|
||||
(e.g., `foo.bar` results in 'bar' being accessed for 'foo').
|
||||
accessed_indices (Set[Union[int, float]]): A set of indices accessed from the variable.
|
||||
is_iterated (bool): Indicates if the variable is used as an iteration source in a `for` loop.
|
||||
is_conditional (bool): Indicates if the variable is referenced within a conditional statement (e.g., an `if` block).
|
||||
iteration_source (Optional[str]): The name of the variable being iterated over, if applicable.
|
||||
iteration_target (Optional[Union[str, list[str]]]): The loop target(s) assigned in the iteration.
|
||||
"""
|
||||
|
||||
accessed_properties: Set[str]
|
||||
accessed_indices: Set[Union[int, float]]
|
||||
is_iterated: bool
|
||||
is_conditional: bool
|
||||
iteration_source: Optional[str]
|
||||
iteration_target: Optional[Union[str, list[str]]]
|
||||
|
||||
|
||||
class JinjaTemplateAnalyzer:
|
||||
"""
|
||||
Analyzes Jinja templates to extract information about variable usage,
|
||||
including accessed properties, iteration, and conditional references.
|
||||
|
||||
Attributes:
|
||||
env (jinja2.Environment): The Jinja2 environment used for parsing templates.
|
||||
property_access (Dict[str, Set[str]]): Tracks accessed properties for variables.
|
||||
iteration_targets (Dict[str, str]): Maps iteration target variables to their sources.
|
||||
|
||||
Methods:
|
||||
get_template_variables(template: str) -> Dict[str, Set[str]]:
|
||||
Parse a Jinja template and return a mapping of variables to their accessed properties.
|
||||
|
||||
analyze_template(template: str) -> Dict[str, JinjaTemplateAnalysis]:
|
||||
Perform a detailed analysis of the template, including variable usage,
|
||||
iteration, and conditional references.
|
||||
|
||||
Private Methods:
|
||||
_visit_node(node) -> None:
|
||||
Recursively visit AST nodes to detect attribute access and iteration targets.
|
||||
|
||||
_get_base_name(node) -> Optional[str]:
|
||||
Extract the base variable name from a node.
|
||||
|
||||
_get_target_name(node) -> Optional[Union[str, list[str]]]:
|
||||
Extract the target name(s) from a `For` node.
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
self.env: Environment = Environment(autoescape=True)
|
||||
self.property_access: Dict[str, Set[str]] = {}
|
||||
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
|
||||
self.index_access: Dict[str, Set[Union[int, float]]] = {}
|
||||
self.ast: nodes.Node = self.env.parse(template)
|
||||
self.template: str = template
|
||||
self.variable_assignments: Dict[str, str] = {}
|
||||
|
||||
def _visit_node(self, node) -> None:
|
||||
"""Recursively visit AST nodes to find attribute access."""
|
||||
# Handle attribute access (dot notation)
|
||||
if isinstance(node, nodes.Getattr):
|
||||
base_name = self._get_base_name(node.node)
|
||||
if base_name:
|
||||
self.property_access.setdefault(base_name, set()).add(node.attr)
|
||||
|
||||
# Handle dictionary access (subscript notation)
|
||||
elif isinstance(node, nodes.Getitem):
|
||||
base_name = self._get_base_name(node.node)
|
||||
if base_name and isinstance(node.arg, nodes.Const):
|
||||
value = node.arg.value
|
||||
if isinstance(value, (int, float)):
|
||||
self.index_access.setdefault(base_name, set()).add(value)
|
||||
else:
|
||||
self.property_access.setdefault(base_name, set()).add(value)
|
||||
|
||||
elif isinstance(node, nodes.Test) and node.name == "defined":
|
||||
base_name = self._get_base_name(node.node)
|
||||
if base_name:
|
||||
if isinstance(node.node, nodes.Getattr):
|
||||
self.property_access.setdefault(base_name, set()).add(
|
||||
                    node.node.attr
                )

        # Handle loop variables
        elif isinstance(node, nodes.For):
            iter_name = self._get_base_name(node.iter)
            target_name = self._get_target_name(node.target)
            if iter_name and target_name:
                self.iteration_targets[target_name] = iter_name
                self.property_access.setdefault(iter_name, set())

        elif isinstance(node, nodes.Assign):
            target_name = self._get_target_name(node.target)
            source_name = self._get_base_name(node.node)
            if target_name and source_name:
                self.variable_assignments[target_name] = source_name

        elif isinstance(node, nodes.Filter):
            if node.name == "selectattr":
                target = self._get_base_name(node.node)
                if target:
                    self.variable_assignments[f"filtered_{target}"] = target

        for child in node.iter_child_nodes():
            self._visit_node(child)

    def _get_target_name(self, node) -> Optional[str]:
        """Get the target variable name from a For node.

        Args:
            node: A Jinja AST node representing either a Name or Tuple node

        Returns:
            - str: For simple variable targets (e.g., "item" in "for item in items")
            - None: If the node type is not recognized or is a tuple
        """
        if isinstance(node, nodes.Name):
            return node.name
        return None

    def _get_target_names(self, node) -> list[str]:
        """Get all target variable names from a For node, including tuple unpacking.

        Args:
            node: A Jinja AST node representing either a Name or Tuple node

        Returns:
            List of target variable names
        """
        if isinstance(node, nodes.Name):
            return [node.name]

        if isinstance(node, nodes.Tuple):
            names = []
            for n in node.items:
                if isinstance(n, nodes.Name):
                    names.append(n.name)
            return names

        return []

    def _get_base_name(self, node) -> Optional[str]:
        """Get the base variable name from a node."""
        if isinstance(node, nodes.Name):
            return node.name

        if isinstance(node, nodes.Getattr):
            return self._get_base_name(node.node)

        if isinstance(node, nodes.Getitem):
            return self._get_base_name(node.node)

        return None

    def get_template_variables(self) -> Dict[str, Set[str]]:
        """
        Parse a Jinja template and return both variables and their accessed properties.

        Args:
            template (str): The Jinja template string

        Returns:
            Dict[str, Set[str]]: Dictionary mapping variable names to sets of accessed properties
        """
        # Parse the template
        ast = self.env.parse(self.template)

        # Get all undeclared variables
        variables = meta.find_undeclared_variables(ast)

        # Reset property access tracking
        self.property_access = {}

        # Visit all nodes to find property access
        self._visit_node(ast)

        # Create result dictionary
        result: Dict[str, Set[str]] = {var: set() for var in variables}
        # Merge in any discovered sub-properties
        for var, props in self.property_access.items():
            if var not in result:
                result[var] = set()
            result[var].update(props)

        return result

    def analyze_template(self) -> Dict[str, JinjaTemplateAnalysis]:
        """
        Provide a detailed analysis of template variables and their usage.
        """
        variables = self.get_template_variables()
        self.iteration_targets = {}
        analysis: Dict[str, JinjaTemplateAnalysis] = {
            var: JinjaTemplateAnalysis(
                accessed_properties=props,
                accessed_indices=set(),
                is_iterated=False,
                is_conditional=False,
                iteration_source=None,
                iteration_target=None,
            )
            for var, props in variables.items()
        }

        for var, indices in self.index_access.items():
            if var in analysis:
                analysis[var]["accessed_indices"] = indices

        def visit_node(node):
            if isinstance(node, nodes.If):

                def find_test_vars(test_node):
                    if isinstance(test_node, nodes.Name):
                        if test_node.name in analysis:
                            analysis[test_node.name]["is_conditional"] = True
                    for child in test_node.iter_child_nodes():
                        find_test_vars(child)

                find_test_vars(node.test)

            if isinstance(node, nodes.For):
                iter_target = self._get_base_name(node.iter)
                target_name = self._get_target_name(node.target)
                if iter_target in analysis:
                    analysis[iter_target]["is_iterated"] = True
                    if target_name:
                        analysis[iter_target]["iteration_target"] = target_name
                if isinstance(target_name, str) and target_name not in analysis:
                    analysis[target_name] = {
                        "accessed_properties": set(),
                        # include accessed_indices so downstream lookups of
                        # analysis[var]["accessed_indices"] don't KeyError
                        "accessed_indices": set(),
                        "is_iterated": False,
                        "is_conditional": False,
                        "iteration_source": iter_target,
                        "iteration_target": None,
                    }

            for child in node.iter_child_nodes():
                visit_node(child)
        visit_node(self.ast)
        return analysis

    def get_downstream_properties(self, start_var: str) -> Dict[str, Set[str]]:
        """
        Get all properties accessed on a variable and its downstream assignments.

        Args:
            start_var: The starting variable to trace

        Returns:
            Dict mapping variable names to their accessed properties
        """
        visited = set()
        properties = {}

        def trace_variable(var_name: str):
            if var_name in visited:
                return
            visited.add(var_name)

            # Get direct properties
            if var_name in self.property_access:
                properties[var_name] = self.property_access[var_name]

            # Get properties from iteration targets
            if var_name in self.iteration_targets:
                target = self.iteration_targets[var_name]
                if isinstance(target, str):
                    trace_variable(target)
                elif isinstance(target, list):
                    for t in target:
                        trace_variable(t)

            # Follow assignments
            for target, source in self.variable_assignments.items():
                if source == var_name:
                    trace_variable(target)

            # Check for array slicing
            analysis = self.analyze_template()
            if var_name in analysis:
                var_info = analysis[var_name]
                if var_info["accessed_indices"]:
                    # If this variable is sliced, follow the resulting assignment
                    slice_result = f"{var_name}_slice"
                    if slice_result in self.property_access:
                        trace_variable(slice_result)

        trace_variable(start_var)
        return properties

    def get_message_vars(self, field_messages: str = "messages") -> Set[str]:
        """
        Get all properties accessed on messages and derived variables.
        """
        all_properties = self.get_downstream_properties(field_messages)

        # Combine all properties from all related variables
        combined_properties = set()
        for properties in all_properties.values():
            combined_properties.update(properties)

        # Also include properties from the message iteration variable
        analysis = self.analyze_template()
        if "message" in analysis:
            combined_properties.update(analysis["message"]["accessed_properties"])

        return combined_properties
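The analyzer above is essentially a directed walk over jinja2's parse tree. A minimal, self-contained sketch of the underlying jinja2 calls it builds on (the template string is invented for illustration):

```python
from jinja2 import Environment, meta, nodes

env = Environment()
template = "{% for message in messages %}{{ message.content }}{% endfor %}"
ast = env.parse(template)

# find_undeclared_variables reports top-level variables ("messages") but not
# properties accessed on loop targets; the AST walk above recovers those.
print(meta.find_undeclared_variables(ast))  # {'messages'}

# Getattr nodes carry the attribute names recorded per variable
for node in ast.find_all(nodes.Getattr):
    print(node.attr)  # content
```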
@@ -51,13 +51,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    ds_cfg = ds_cfg or {}

    field_messages = ds_cfg.get("field_messages")
    message_property_mappings = ds_cfg.get("message_property_mappings")
    message_field_role = (
        message_property_mappings.get("role") if message_property_mappings else None
    )
    message_field_content = (
        message_property_mappings.get("content") if message_property_mappings else None
    )
    message_field_role = ds_cfg.get("message_field_role")
    message_field_content = ds_cfg.get("message_field_content")
    message_field_training = ds_cfg.get("message_field_training")

    builder_kwargs = {}
@@ -175,7 +175,6 @@ def train(
        LOG.info("hang tight... sorting dataset for group_by_length")

    pretrain_hooks(cfg, trainer)

    if cfg.flash_optimum:
        with torch.backends.cuda.sdp_kernel(
            # TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
@@ -186,7 +185,6 @@ def train(
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    post_train_hooks(cfg, trainer)

    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
@@ -15,7 +15,7 @@ _DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"

_CHAT_TEMPLATES = {
    "alpaca": "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' and loop.first %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '### Instruction:\n' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '### Response:\n' + message['content'] + eos_token }}{% endif %}{% if not loop.last %}{{ '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '\n\n### Response:\n' }}{% endif %}",
    "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
    "mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1...
    "mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # V3: Mistral 7B V3, Small, Large...
    "mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # V3-Tekken: Nemo, Pixtral...
@@ -38,7 +38,7 @@ def get_chat_template(
    user_choice: str,
    jinja_template: Optional[str] = None,
    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> str:
):
    """
    Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.

@@ -70,7 +70,7 @@ def get_chat_template(
                f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
                f"Please add a chat_template in tokenizer config"
            )
        return tokenizer.chat_template  # type: ignore
        return tokenizer.chat_template

    if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
        if not tokenizer:
@@ -78,7 +78,7 @@ def get_chat_template(
                f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
            )
        if tokenizer.chat_template:
            return tokenizer.chat_template  # type: ignore
            return tokenizer.chat_template

        user_choice = user_choice[
            len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
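For reference, a hedged sketch of how one of the `_CHAT_TEMPLATES` strings is typically exercised end to end; the tokenizer name is a placeholder and `_CHAT_TEMPLATES` is assumed to be importable from this module:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("some-org/some-model")  # placeholder
tokenizer.chat_template = _CHAT_TEMPLATES["mistral_v1"]  # assumed import

messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there."},
]
# tokenize=False returns the rendered prompt string so the template can be inspected
print(tokenizer.apply_chat_template(messages, tokenize=False))
```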
@@ -18,7 +18,6 @@ from axolotl.utils.config.models.input.v0_4_1 import (
from axolotl.utils.config.models.input.v0_4_1 import (
    AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config

@@ -259,7 +258,7 @@ def validate_config(
    cfg: DictDefault,
    capabilities: Optional[dict] = None,
    env_capabilities: Optional[dict] = None,
) -> DictDefault:
):
    AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
    AxolotlInputConfig = AxolotlInputConfigBase

@@ -269,16 +268,6 @@ def validate_config(
        AxolotlInputConfig,  # pylint: disable=invalid-name
    ) = merge_input_args()

    # Convert datasets to proper format if needed
    if cfg.get("datasets"):
        for idx, ds_cfg in enumerate(cfg["datasets"]):
            if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset):
                cfg["datasets"][idx] = DPODataset(**ds_cfg)
            elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
                cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
            elif not isinstance(ds_cfg, SFTDataset):
                cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))

    if capabilities or env_capabilities:
        if (capabilities and env_capabilities is None) or (
            env_capabilities and capabilities is None
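The removed block coerces plain YAML dicts into the typed dataset models. A standalone sketch of that pattern with a stand-in model (not the real `SFTDataset`):

```python
from typing import Optional

from pydantic import BaseModel


class DatasetCfg(BaseModel):  # stand-in for SFTDataset/DPODataset/KTODataset
    path: str
    type: Optional[str] = None


raw = {"path": "tatsu-lab/alpaca", "type": "alpaca"}
cfg = DatasetCfg(**raw)  # a plain dict coerces into the typed model
print(cfg.model_dump(exclude_none=True))
```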
@@ -1,4 +1,7 @@
"""Module with Pydantic models for configuration."""
"""
Module for pydantic models for configuration
"""

# pylint: disable=too-many-lines

import logging
@@ -6,13 +9,12 @@ import os
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union

from annotated_types import MinLen
from packaging import version
from pydantic import (
    BaseModel,
    Field,
    StringConstraints,
    field_serializer,
    conlist,
    field_validator,
    model_validator,
)
@@ -22,8 +24,6 @@ from transformers.utils.import_utils import is_torch_npu_available

from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities

from .trl import TRLConfig

LOG = logging.getLogger("axolotl.utils.config.models.input")

SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
@@ -33,7 +33,6 @@ class RLType(str, Enum):
    """RL trainer type configuration subset"""

    dpo = "dpo"  # pylint: disable=invalid-name
    grpo = "grpo"  # pylint: disable=invalid-name
    ipo = "ipo"  # pylint: disable=invalid-name
    orpo = "orpo"  # pylint: disable=invalid-name
    kto = "kto"  # pylint: disable=invalid-name
@@ -167,7 +166,6 @@ class SFTDataset(BaseModel):
    type: Optional[Union[str, UserDefinedPrompterType]] = None
    input_transform: Optional[str] = None
    shards: Optional[int] = None
    shards_idx: Optional[int] = None
    preprocess_shards: Optional[int] = None
    conversation: Optional[str] = None
    # Do not make this too strict or it will break the validator to choose different dataset class
@@ -187,13 +185,8 @@ class SFTDataset(BaseModel):
    field_human: Optional[str] = None
    field_model: Optional[str] = None
    field_messages: Optional[str] = None
    message_field_role: Optional[
        str
    ] = None  # deprecated, use message_property_mappings
    message_field_content: Optional[
        str
    ] = None  # deprecated, use message_property_mappings
    message_property_mappings: Optional[Dict[str, str]] = None
    message_field_role: Optional[str] = None
    message_field_content: Optional[str] = None
    message_field_training: Optional[str] = None
    message_field_training_detail: Optional[str] = None
    logprobs_field: Optional[str] = None
@@ -205,18 +198,9 @@ class SFTDataset(BaseModel):
    trust_remote_code: Optional[bool] = False
    revision: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def handle_legacy_message_fields(cls, data):
        """Handle backwards compatibility between legacy message field mapping and new property mapping system."""
        return handle_legacy_message_fields_logic(data)

    @model_validator(mode="before")
    @classmethod
    def check_chat_template_config(cls, data):
        if isinstance(data, BaseModel):
            data = data.model_dump()

        # Set chat_template to tokenizer_default if not set
        if data.get("type") == "chat_template" and not data.get("chat_template"):
            data["chat_template"] = ChatTemplate.tokenizer_default
@@ -256,7 +240,6 @@ class DPODataset(BaseModel):
    type: Optional[Union[UserDefinedDPOType, str]] = None
    data_files: Optional[List[str]] = None
    revision: Optional[str] = None
    field_messages: Optional[str] = None


class StepwiseSupervisedDataset(BaseModel):
@@ -293,9 +276,6 @@ class KTODataset(BaseModel):
    revision: Optional[str] = None


DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]


class LoftQConfig(BaseModel):
    """LoftQ configuration subset"""

@@ -435,8 +415,6 @@ class ReLoRAConfig(BaseModel):
class ModelInputConfig(BaseModel):
    """model to train on configuration subset"""

    model_config = {"protected_namespaces": ()}

    base_model: str
    base_model_config: Optional[str] = None
    cls_model_config: Optional[str] = None
@@ -503,7 +481,7 @@ class HyperparametersConfig(BaseModel):
                "adopt_adamw",
            ],
        ]
    ] = OptimizerNames.ADAMW_HF
    ] = OptimizerNames.ADAMW_HF.value
    optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
        default=None,
        json_schema_extra={"description": "Optional arguments to supply to optimizer."},
@@ -515,9 +493,7 @@ class HyperparametersConfig(BaseModel):
        },
    )
    torchdistx_path: Optional[str] = None
    lr_scheduler: Optional[
        Union[SchedulerType, Literal["one_cycle"]]
    ] = SchedulerType.COSINE
    lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
    lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
    lr_quadratic_warmup: Optional[bool] = None
    cosine_min_lr_ratio: Optional[float] = None
@@ -641,19 +617,19 @@ class RayConfig(BaseModel):
    use_ray: bool = Field(default=False)
    ray_run_name: Optional[str] = Field(
        default=None,
        json_schema_extra={
        metadata={
            "help": "The training results will be saved at `saves/ray_run_name`."
        },
    )
    ray_num_workers: int = Field(
        default=1,
        json_schema_extra={
        metadata={
            "help": "The number of workers for Ray training. Default is 1 worker."
        },
    )
    resources_per_worker: dict = Field(
        default_factory=lambda: {"GPU": 1},
        json_schema_extra={
        metadata={
            "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
        },
    )
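Note the two spellings in this hunk: in pydantic v2, `json_schema_extra` is the supported home for descriptive metadata on a `Field`, while a bare `metadata=` keyword is only accepted as a deprecated extra kwarg. A sketch with a stand-in model:

```python
from pydantic import BaseModel, Field


class RayCfg(BaseModel):  # stand-in model, not the real RayConfig
    ray_num_workers: int = Field(
        default=1,
        json_schema_extra={"help": "The number of workers for Ray training."},
    )


# the extra keys surface in the generated JSON schema
print(RayCfg.model_json_schema()["properties"]["ray_num_workers"])
# {'default': 1, 'help': 'The number of workers for Ray training.', ...}
```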
@@ -678,49 +654,35 @@ class AxolotlInputConfig(
):
    """wrapper of all config options"""

    model_config = {"populate_by_name": True}
    class Config:
        """Config for alias"""

        populate_by_name = True

    strict: Optional[bool] = Field(default=False)
    resume_from_checkpoint: Optional[str] = None
    auto_resume_from_checkpoints: Optional[bool] = None
    resize_token_embeddings_to_32x: Optional[bool] = None
    mean_resizing_embeddings: Optional[bool] = False
    # optionally shrink the embeddings when the tokenizer vocab size is smaller
    shrink_embeddings: Optional[bool] = None

    rl: Optional[RLType] = None
    trl: Optional[TRLConfig] = Field(
        default_factory=lambda: TRLConfig(),  # pylint: disable=unnecessary-lambda
    )
    reward_model: Optional[bool] = None
    process_reward_model: Optional[bool] = None
    num_labels: Optional[int] = None
    dpo_use_weighting: Optional[
        bool
    ] = None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.
    dpo_use_logits_to_keep: Optional[bool] = None

    datasets: Optional[
        Annotated[
            list[Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]],
            MinLen(1),
        ]
    ] = None

    test_datasets: Optional[
        Annotated[
            list[Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]],
            MinLen(1),
        ]
    ] = None
    datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None  # type: ignore
    test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None  # type: ignore
    shuffle_merged_datasets: Optional[bool] = True
    dataset_prepared_path: Optional[str] = None
    dataset_shard_num: Optional[int] = None
    dataset_shard_idx: Optional[int] = None
    skip_prepare_dataset: Optional[bool] = False

    pretraining_dataset: Optional[
        Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)]
    pretraining_dataset: Optional[  # type: ignore
        conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
    ] = Field(
        default=None,
        json_schema_extra={"description": "streaming dataset to use for pretraining"},
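The `Annotated[..., MinLen(1)]` and `conlist(..., min_length=1)` annotations juxtaposed in this hunk enforce the same constraint; a sketch with `int` elements instead of the dataset models:

```python
from typing import Annotated, Optional

from annotated_types import MinLen
from pydantic import BaseModel, ValidationError, conlist


class A(BaseModel):
    items: Optional[Annotated[list[int], MinLen(1)]] = None


class B(BaseModel):
    items: Optional[conlist(int, min_length=1)] = None  # type: ignore


for model in (A, B):
    try:
        model(items=[])
    except ValidationError as err:
        print(model.__name__, "rejects empty lists:", err.error_count(), "error")
```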
@@ -838,10 +800,6 @@ class AxolotlInputConfig(
    unsloth_rms_norm: Optional[bool] = None
    unsloth_rope: Optional[bool] = None

    lora_mlp_kernel: Optional[bool] = None
    lora_qkv_kernel: Optional[bool] = None
    lora_o_kernel: Optional[bool] = None

    deepspeed: Optional[Union[str, Dict[str, Any]]] = None
    fsdp: Optional[List[str]] = None
    fsdp_config: Optional[Dict[str, Any]] = None
@@ -864,7 +822,7 @@ class AxolotlInputConfig(
    warmup_steps: Optional[int] = None
    warmup_ratio: Optional[float] = None
    eval_steps: Optional[Union[int, float]] = None
    evals_per_epoch: Optional[int] = None
    evals_per_epoch: Optional[Union[int]] = None
    eval_strategy: Optional[str] = None
    save_steps: Optional[Union[int, float]] = None
    saves_per_epoch: Optional[int] = None
@@ -876,7 +834,6 @@ class AxolotlInputConfig(
    save_only_model: Optional[bool] = False
    use_tensorboard: Optional[bool] = None
    profiler_steps: Optional[int] = None
    include_tokens_per_second: Optional[bool] = None

    neftune_noise_alpha: Optional[float] = None

@@ -926,15 +883,10 @@ class AxolotlInputConfig(
    @classmethod
    def deprecate_sharegpt_datasets(cls, datasets):
        for _, ds_cfg in enumerate(datasets):
            # Handle both dict and pydantic model cases
            ds_type = (
                ds_cfg.get("type")
                if isinstance(ds_cfg, dict)
                else getattr(ds_cfg, "type", None)
            )
            if not ds_type:
            if not ds_cfg.get("type"):
                continue

            ds_type = ds_cfg["type"]
            # skip if it's a dict (for custom user instruction prompt)
            if isinstance(ds_type, dict):
                continue
@@ -946,14 +898,6 @@ class AxolotlInputConfig(

        return datasets

    @field_serializer("datasets")
    def datasets_serializer(
        self, ds_configs: Optional[List[DatasetConfig]]
    ) -> Optional[List[Dict[str, Any]]]:
        if ds_configs:
            return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
        return None

    @model_validator(mode="before")
    @classmethod
    def check_batch_size_fields(cls, data):
@@ -1579,42 +1523,12 @@ class AxolotlInputConfig(
            or data.get("unsloth_lora_qkv")
            or data.get("unsloth_lora_o")
        ):
            if data.get("adapter") == "lora" and data.get("load_in_8bit"):
            if data.get("adapter") == "lora" or data.get("load_in_8bit"):
                raise ValueError(
                    "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
                )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_lora_8bit(cls, data):
        if (
            data.get("lora_mlp_kernel")
            or data.get("lora_qkv_kernel")
            or data.get("lora_o_kernel")
        ):
            if data.get("adapter") == "lora" and data.get("load_in_8bit"):
                raise ValueError(
                    "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
                )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_lora_axolotl_unsloth(cls, data):
        is_lora_kernel = any(
            data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
        )
        is_unsloth_lora = any(
            data.get(k)
            for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
        )
        if is_lora_kernel and is_unsloth_lora:
            raise ValueError(
                "both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_torch_compile_deepspeed(cls, data):
@@ -1747,29 +1661,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_multigpu_lora_kernels(cls, data):
        if (
            data.get("lora_mlp_kernel")
            or data.get("lora_qkv_kernel")
            or data.get("lora_o_kernel")
        ):
            capabilities = data.get("capabilities")
            is_fsdp = data.get("fsdp") is not None
            is_deepspeed = data.get("deepspeed") is not None

            if capabilities and capabilities.get("n_gpu", 0) > 1:
                if is_fsdp:
                    raise ValueError(
                        "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP."
                    )
                if is_deepspeed:
                    raise ValueError(
                        "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with DeepSpeed."
                    )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_adopt_torch_version(cls, data):
@@ -1806,77 +1697,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
        else:
            data["torch_compile"] = False
        return data


def handle_legacy_message_fields_logic(data: dict) -> dict:
    """
    Handle backwards compatibility between legacy message field mapping and new property mapping system.

    Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
    - message_field_role: Mapped to the role field
    - message_field_content: Mapped to the content field

    The new system uses message_property_mappings to support arbitrary field mappings:
    message_property_mappings:
        role: source_role_field
        content: source_content_field
        additional_field: source_field

    Args:
        data: Dictionary containing configuration data

    Returns:
        Updated dictionary with message field mappings consolidated

    Raises:
        ValueError: If there are conflicts between legacy and new mappings
    """
    data = data.copy()  # Create a copy to avoid modifying the original

    if data.get("message_property_mappings") is None:
        data["message_property_mappings"] = {}

    # Check for conflicts and handle role
    if "message_field_role" in data:
        LOG.warning(
            "message_field_role is deprecated, use message_property_mappings instead. "
            f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
        )
        if (
            "role" in data["message_property_mappings"]
            and data["message_property_mappings"]["role"] != data["message_field_role"]
        ):
            raise ValueError(
                f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
                f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
            )
        data["message_property_mappings"]["role"] = data["message_field_role"] or "role"

        del data["message_field_role"]
    elif "role" not in data["message_property_mappings"]:
        data["message_property_mappings"]["role"] = "role"

    # Check for conflicts and handle content
    if "message_field_content" in data:
        LOG.warning(
            "message_field_content is deprecated, use message_property_mappings instead. "
            f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
        )
        if (
            "content" in data["message_property_mappings"]
            and data["message_property_mappings"]["content"]
            != data["message_field_content"]
        ):
            raise ValueError(
                f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
                f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
            )
        data["message_property_mappings"]["content"] = (
            data["message_field_content"] or "content"
        )

        del data["message_field_content"]
    elif "content" not in data["message_property_mappings"]:
        data["message_property_mappings"]["content"] = "content"

    return data
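Assuming `handle_legacy_message_fields_logic` is importable from this module, the migration behaves as follows (the dataset field names are invented):

```python
legacy = {"message_field_role": "from", "message_field_content": "value"}
migrated = handle_legacy_message_fields_logic(legacy)
assert migrated == {
    "message_property_mappings": {"role": "from", "content": "value"}
}

# a conflicting explicit mapping raises instead of silently winning
try:
    handle_legacy_message_fields_logic(
        {"message_field_role": "from", "message_property_mappings": {"role": "sender"}}
    )
except ValueError as err:
    print(err)  # Conflicting message role fields ...
```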
@@ -1,35 +0,0 @@
"""
GRPO specific configuration args
"""
from typing import List, Optional

from pydantic import BaseModel, Field


class TRLConfig(BaseModel):
    """
    Input args for TRL.
    """

    beta: Optional[float] = None
    max_completion_length: Optional[int] = Field(
        default=None,
        json_schema_extra={
            "description": "Maximum length of the completion for RL training"
        },
    )

    # GRPO specific args
    use_vllm: Optional[bool] = False
    vllm_device: Optional[str] = "auto"
    vllm_gpu_memory_utilization: Optional[float] = 0.9
    vllm_max_model_len: Optional[int] = None
    vllm_dtype: Optional[str] = "auto"

    reward_funcs: Optional[List[str]] = None
    num_generations: Optional[int] = None
    log_completions: Optional[bool] = False

    sync_ref_model: Optional[bool] = False
    ref_model_mixup_alpha: Optional[float] = 0.9
    ref_model_sync_steps: Optional[int] = 64
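A usage sketch for the `TRLConfig` removed here (it now lives behind the `from .trl import TRLConfig` import added earlier in this diff); the values are illustrative, not recommendations:

```python
cfg = TRLConfig(  # assumes TRLConfig is imported from its new location
    beta=0.04,
    use_vllm=True,
    vllm_gpu_memory_utilization=0.8,
    num_generations=4,
    log_completions=True,
)
print(cfg.model_dump(exclude_none=True))
```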
@@ -4,16 +4,15 @@ import inspect
import logging
from functools import partial
from pathlib import Path
from typing import Any, List, Union
from typing import Any, List

import yaml
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
@@ -58,7 +57,7 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
    dataset.save_to_disk(str(prepared_ds_path))


def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
    sig = inspect.signature(ds_transform_fn)
    if "tokenizer" in sig.parameters:
        if not tokenizer:
@@ -71,7 +70,6 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
    data_set = data_set.map(
        ds_transform_fn,
        desc="Mapping RL Dataset",
        **map_kwargs,
    )

    return data_set
@@ -114,21 +112,29 @@ def drop_long_rl_seq(

        return (len_prompt + len_completion) <= sequence_len

    if rl == "grpo":
        return True

    raise ValueError("Unknown RL type")


def load_prepare_preference_datasets(cfg):
    def load_split(dataset_cfgs, _cfg):
        split_datasets: List[Any] = []
        use_auth_token = _cfg.hf_use_auth_token
        for config_dataset in datasets_w_name_generator(dataset_cfgs):
            ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
                config_dataset, use_auth_token, streaming=False
            )
            split_datasets.append(ds)
        for i, ds_cfg in enumerate(dataset_cfgs):
            if ds_cfg["ds_type"] == "json":
                for data_file in ds_cfg["data_files"]:
                    data_files = {ds_cfg["split"]: data_file}
                    ds = load_dataset(  # pylint: disable=invalid-name
                        "json",
                        data_files=data_files,
                        split=ds_cfg["split"],
                    )
                    split_datasets.insert(i, ds)
            else:
                ds = load_dataset(  # pylint: disable=invalid-name
                    ds_cfg["path"],
                    split=ds_cfg["split"],
                    revision=ds_cfg.get("revision", None),
                )
                split_datasets.insert(i, ds)

        tokenizer = load_tokenizer(cfg)
@@ -144,45 +150,36 @@ def load_prepare_preference_datasets(cfg):
            else:
                ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)

            map_kwargs = {}
            if isinstance(ds_transform_fn, tuple):
                ds_transform_fn, map_kwargs = ds_transform_fn
            split_datasets[i] = map_dataset(
                cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
                cfg, data_set, ds_transform_fn, tokenizer
            )
        elif _cfg.rl == "kto":
            ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
            map_kwargs = {}
            if isinstance(ds_transform_fn, tuple):
                ds_transform_fn, map_kwargs = ds_transform_fn
            split_datasets[i] = map_dataset(
                cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
                cfg, data_set, ds_transform_fn, tokenizer
            )
        else:
            # If no `type` is provided, assume the dataset is already in the expected format with
            # "prompt", "chosen" and "rejected" already preprocessed
            split_datasets[i] = data_set

        if not cfg.skip_prepare_dataset:
            drop_long = partial(
                drop_long_rl_seq,
                rl=_cfg.rl,
                tokenizer=tokenizer,
                sequence_len=cfg.sequence_len,
            )
        drop_long = partial(
            drop_long_rl_seq,
            rl=_cfg.rl,
            tokenizer=tokenizer,
            sequence_len=cfg.sequence_len,
        )

            prior_len = len(split_datasets[i])
            split_datasets[i] = split_datasets[i].filter(
                drop_long,
                num_proc=cfg.dataset_processes,
                load_from_cache_file=not cfg.is_preprocess,
                desc="Dropping Long Sequences",
            )
            dropped = prior_len - len(split_datasets[i])
            if dropped:
                LOG.warning(
                    f"Dropped {dropped} long samples from dataset index {i}"
                )
        prior_len = len(split_datasets[i])
        split_datasets[i] = split_datasets[i].filter(
            drop_long,
            num_proc=cfg.dataset_processes,
            load_from_cache_file=not cfg.is_preprocess,
            desc="Dropping Long Sequences",
        )
        dropped = prior_len - len(split_datasets[i])
        if dropped:
            LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")

    combined_datasets = concatenate_datasets(split_datasets)
    combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
@@ -43,7 +43,7 @@ from axolotl.prompters import (
    UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_config
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.data.utils import (
    deduplicate_and_log_datasets,
    drop_long_seq_in_dataset,
@@ -180,7 +180,6 @@ def load_tokenized_prepared_datasets(
) -> Tuple[DatasetDict, List[Prompter]]:
    cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
    tokenizer_name = cfg.tokenizer_config

    ds_hash = str(
        md5(
            (
@@ -264,11 +263,30 @@ def load_tokenized_prepared_datasets(

    datasets = []

    def for_d_in_datasets(dataset_configs):
        for dataset in dataset_configs:
            if dataset.name and isinstance(dataset.name, list):
                # load_dataset doesn't properly handle multiple named configurations
                # at the same time for a given dataset
                for name in dataset.name:
                    yield DictDefault({**dataset, "name": name})
            elif dataset.preprocess_shards and not dataset.shards:
                for shard in range(dataset.preprocess_shards):
                    yield DictDefault(
                        {
                            **dataset,
                            "shards": dataset.preprocess_shards,
                            "shards_idx": shard,
                        }
                    )
            else:
                yield dataset

    streaming_ds = False
    if preprocess_iterable:
        streaming_ds = True
    # pylint: disable=invalid-name
    for config_dataset in datasets_w_name_generator(cfg_datasets):
    for config_dataset in for_d_in_datasets(cfg_datasets):
        ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
            config_dataset, use_auth_token, streaming=streaming_ds
        )
@@ -1,7 +1,6 @@
"""
dataset loading shared utils
"""

from pathlib import Path
from typing import Optional, Union

@@ -30,43 +29,9 @@ def get_ds_type(config_dataset: DictDefault):
    return ds_type


def datasets_w_name_generator(dataset_configs: list[DictDefault]):
    """
    Yields dataset configs handling multiple names or preprocess_shards

    Args:
        dataset_configs: list of dataset configs (equivalent to cfg.datasets)
    """
    for dataset in dataset_configs:
        if dataset.name and isinstance(dataset.name, list):
            # load_dataset doesn't properly handle multiple named configurations
            # at the same time for a given dataset
            for name in dataset.name:
                yield DictDefault({**dataset, "name": name})
        elif dataset.preprocess_shards and not dataset.shards:
            for shard in range(dataset.preprocess_shards):
                yield DictDefault(
                    {
                        **dataset,
                        "shards": dataset.preprocess_shards,
                        "shards_idx": shard,
                    }
                )
        else:
            yield dataset


def load_dataset_w_config(
    config_dataset: DictDefault, use_auth_token: bool, streaming=False
    config_dataset, auth_token, streaming=False
) -> Union[Dataset, DatasetDict]:
    """
    Load a dataset from a config

    Args:
        config_dataset: single dataset config
        use_auth_token: whether to use HF auth token
        streaming: whether to stream the dataset
    """
    # pylint: disable=invalid-name
    ds: Optional[Union[Dataset, DatasetDict]] = None  # pylint: disable=invalid-name
    ds_from_hub = False
@@ -78,7 +43,7 @@ def load_dataset_w_config(
            config_dataset.path,
            name=config_dataset.name,
            streaming=True,
            token=use_auth_token,
            token=auth_token,
            revision=config_dataset.revision,
            trust_remote_code=ds_trust_remote_code,
        )
@@ -196,7 +161,7 @@ def load_dataset_w_config(
            name=config_dataset.name,
            streaming=streaming,
            data_files=config_dataset.data_files,
            token=use_auth_token,
            token=auth_token,
            revision=config_dataset.revision,
            trust_remote_code=config_dataset.trust_remote_code,
            **load_ds_kwargs,
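A sketch of what `datasets_w_name_generator` expands a config list into, assuming it and `DictDefault` are importable as above (paths are placeholders):

```python
from axolotl.utils.dict import DictDefault

cfgs = [
    DictDefault({"path": "org/multi-config-ds", "name": ["part1", "part2"]}),
    DictDefault({"path": "org/big-ds", "preprocess_shards": 3}),
]
# a multi-name config yields one config per name; a preprocess_shards config
# yields one config per shard with shards/shards_idx filled in
for expanded in datasets_w_name_generator(cfgs):
    print(expanded.get("name"), expanded.get("shards"), expanded.get("shards_idx"))
# part1 None None
# part2 None None
# None 3 0
# None 3 1
# None 3 2
```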
@@ -13,26 +13,3 @@ class DictDefault(Dict):

    def __or__(self, other):
        return DictDefault(super().__ror__(other))

    def __setitem__(self, name, value):
        # workaround for pickle/unpickle issues and __frozen not being available
        try:
            isFrozen = hasattr(  # pylint: disable=invalid-name
                self, "__frozen"
            ) and object.__getattribute__(self, "__frozen")
        except AttributeError:
            isFrozen = False  # pylint: disable=invalid-name

        if isFrozen and name not in super().keys():
            raise KeyError(name)
        super(Dict, self).__setitem__(name, value)  # pylint: disable=bad-super-call
        try:
            p = object.__getattribute__(self, "__parent")
            key = object.__getattribute__(self, "__key")
        except AttributeError:
            p = None
            key = None
        if p is not None:
            p[key] = self
            object.__delattr__(self, "__parent")
            object.__delattr__(self, "__key")
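A sketch of the freeze behavior the removed `__setitem__` implements, assuming `freeze()` from the addict base class is available on `DictDefault`:

```python
from axolotl.utils.dict import DictDefault

d = DictDefault({"a": 1})
d["b"] = 2  # unfrozen: new keys are accepted
d.freeze()  # freeze() is assumed to come from the addict base class
try:
    d["c"] = 3  # frozen + unknown key -> KeyError from __setitem__
except KeyError:
    print("frozen DictDefault rejects unknown keys")
d["a"] = 10  # keys that already exist can still be reassigned
```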
@@ -1,75 +0,0 @@
# Copyright 2025 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
module to get the state dict of a merged lora model
"""
import torch
from peft.tuners.tuners_utils import onload_layer
from peft.utils import ModulesToSaveWrapper, _get_submodules


def get_lora_merged_state_dict(
    model: torch.nn.Module,
) -> dict:
    r"""
    Create and return a state_dict that has the LoRA deltas
    merged into the base model's weights, without modifying `model` in place.

    Arguments:
        model (torch.nn.Module): A model that has LoRA/PEFT adapters attached.

    Returns:
        dict: A state_dict of the merged parameters.
    """

    base_model_prefix = "base_model.model."
    state_dict = {}
    key_list = [key for key, _ in model.named_modules() if model.prefix not in key]
    for key in key_list:
        try:
            _, target, _ = _get_submodules(model, key)
        except AttributeError:
            continue
        with onload_layer(target):
            weight_key = key.replace(base_model_prefix, "") + ".weight"
            bias_key = key.replace(base_model_prefix, "") + ".bias"
            if hasattr(target, "base_layer"):
                target.merge(safe_merge=True, adapter_names=None)
                # get the state_dict of target.base_layer
                layer_state_dict = target.base_layer.state_dict()
                state_dict[weight_key] = layer_state_dict["weight"]
            elif isinstance(target, ModulesToSaveWrapper):
                # save any additional trainable modules part of `modules_to_save`
                new_module = target.modules_to_save[target.active_adapter]
                if hasattr(new_module, "base_layer"):
                    # check if the module is itself a tuner layer
                    new_module.merge(safe_merge=True, adapter_names=None)
                layer_state_dict = new_module.state_dict()
                state_dict[weight_key] = layer_state_dict["weight"]
            elif hasattr(target, "weight"):
                if any(
                    skip in key
                    for skip in [
                        ".original_module",
                        ".modules_to_save",
                        ".base_layer",
                    ]
                ):
                    continue
                layer_state_dict = target.state_dict()
                state_dict[weight_key] = layer_state_dict["weight"]
            if hasattr(target, "bias") and "bias" in layer_state_dict.keys():
                state_dict[bias_key] = layer_state_dict["bias"]
    return state_dict
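A hedged usage sketch for the deleted helper; the adapter path is a placeholder, and the PEFT wrapper is assumed to forward `.prefix` to its tuner model (PeftModel delegates unknown attributes to `base_model`):

```python
import torch
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained("path/to/lora-adapter")  # placeholder
state_dict = get_lora_merged_state_dict(model)  # assumes the helper above is importable
torch.save(state_dict, "merged_state_dict.pt")
```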
@@ -414,7 +414,6 @@ class ModelLoader:
            has_remote_code = "AutoModelForCausalLM" in auto_map_config
        else:
            has_remote_code = False

        if has_remote_code and self.cfg.trust_remote_code is False:
            # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
            has_remote_code = self.cfg.trust_remote_code
@@ -426,6 +425,10 @@ class ModelLoader:

        if self.cfg.is_llama_derived_model:
            self.patch_loss_llama()
            if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
                from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

                patch_self_attn_lora()
        elif self.cfg.is_llama_derived_model:
            self.patch_llama_derived_model()

@@ -469,7 +472,9 @@ class ModelLoader:
        return importlib.util.find_spec("flash_attn") is not None

    def patch_loss_llama(self) -> None:
        """Patch loss functions and other optimizations"""
        """
        Patch loss functions
        """
        if self.has_flash_attn:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
                patch_fa_llama_cross_entropy,
@@ -489,14 +494,15 @@ class ModelLoader:
            from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm

            patch_unsloth_layernorm()

        if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

            patch_self_attn_lora()

    def patch_llama_derived_model(self) -> None:
        """Modify all llama derived models in one block"""
        """
        Modify all llama derived models in one block
        """
        self.patch_loss_llama()

        if self.cfg.flash_attention:
@@ -1007,8 +1013,7 @@ class ModelLoader:
            if hasattr(module, "weight"):
                module.to(dist_dtype)

    # TODO: Deprecate this.
    def apply_unsloth_lora_patch(self) -> None:
    def apply_lora_patch(self) -> None:
        if self.cfg.unsloth_lora_mlp:
            from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch

@@ -1022,22 +1027,6 @@ class ModelLoader:

            integrate_rope_embeddings()

    def apply_lora_patch(self) -> None:
        """Applies patching relevant to LoRA Triton kernels if enabled."""
        if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
            from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora

            patch_self_attn_lora(self.model)

        if (
            self.cfg.lora_mlp_kernel
            or self.cfg.lora_qkv_kernel
            or self.cfg.lora_o_kernel
        ):
            from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches

            apply_lora_kernel_patches(self.model, self.cfg)

    def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
        self.apply_patches()
        self.set_auto_model_loader()
@@ -1064,12 +1053,9 @@ class ModelLoader:
            if self.cfg.resize_token_embeddings_to_32x
            else len(self.tokenizer)
        )
        if hasattr(self.model, "get_input_embeddings") and (
            self.model.get_input_embeddings().num_embeddings < embeddings_len
            or (
                self.model.get_input_embeddings().num_embeddings > embeddings_len
                and self.cfg.shrink_embeddings
            )
        if (
            hasattr(self.model, "get_input_embeddings")
            and self.model.get_input_embeddings().num_embeddings != embeddings_len
        ):
            resize_kwargs = {}
            if self.cfg.mean_resizing_embeddings is not None:
@@ -1182,8 +1168,6 @@ class ModelLoader:
        if self.cfg.adapter is not None:
            log_gpu_memory_usage(LOG, "after adapters", self.model.device)

        # TODO: Deprecate this.
        self.apply_unsloth_lora_patch()
        self.apply_lora_patch()

        for _ in range(3):
@@ -1203,7 +1187,9 @@ def load_model(
    reference_model: bool = False,
    **kwargs,  # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
    """Load a model for a given configuration and tokenizer."""
    """
    Load a model for a given configuration and tokenizer.
    """
    loader = ModelLoader(
        cfg,
        tokenizer,
@@ -1323,7 +1309,6 @@ def load_lora(model, cfg, inference=False, config_only=False):
        lora_config_kwargs["init_lora_weights"] = "loftq"
    if cfg.peft_use_dora:
        lora_config_kwargs["use_dora"] = cfg.peft_use_dora
        LOG.info("Initializing LoRA weights using dora. This might take longer.")
    if cfg.peft_use_rslora:
        lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
    if cfg.peft_layer_replication:
@@ -396,8 +396,8 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
    ):
        total_num_tokens = np.sum(
            train_dataset.select_columns("input_ids")
            .to_pandas()["input_ids"]
            .apply(len)
            .to_pandas()
            .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
            .values
        )
        LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True)
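The replacement line indexes the `input_ids` column before applying `len`; a toy pandas check of that computation (data invented):

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"input_ids": [[1, 2, 3], [4, 5]]})
# per-row token counts, summed into a total token count
total_num_tokens = np.sum(df["input_ids"].apply(len).values)
print(total_num_tokens)  # 5
```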
@@ -576,7 +576,7 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
    cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
    if cfg.rl:
        if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
            trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
            trainer_builder.model_ref = model[1]
            trainer_builder.peft_config = model[2]

@@ -9,7 +9,7 @@ from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault


@@ -79,7 +79,6 @@ class TestKnowledgeDistillation:
    def test_llama_kd(self, temp_dir, kd_min_cfg):
        cfg = DictDefault(kd_min_cfg)
        # pylint: disable=duplicate-code
        cfg = validate_config(cfg)
        prepare_plugins(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
@@ -110,7 +109,6 @@ class TestKnowledgeDistillation:
            | kd_min_cfg
        )
        # pylint: disable=duplicate-code
        cfg = validate_config(cfg)
        prepare_plugins(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
@@ -1,76 +0,0 @@
"""Tests for GEGLU activation function Triton kernels."""
# pylint: disable=duplicate-code

import torch
import torch.nn.functional as F

from axolotl.kernels.geglu import geglu_backward, geglu_forward


def test_geglu_forward_shape():
    """Test that GEGLU forward pass preserves expected shapes."""
    batch, seq_len, hidden_dim = 2, 3, 64
    gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
    up = torch.randn(batch, seq_len, hidden_dim, device="cuda")

    out = geglu_forward(gate, up)
    assert out.shape == (batch, seq_len, hidden_dim)
    assert out.dtype == gate.dtype
    assert out.device == gate.device


def test_geglu_forward_values():
    """Test GEGLU forward pass matches PyTorch reference implementation."""
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")

    # Custom implementation
    triton_out = geglu_forward(gate.clone(), up.clone())

    # PyTorch reference
    torch_out = F.gelu(gate) * up

    assert torch.allclose(triton_out, torch_out, rtol=1e-3)


def test_geglu_backward():
    """Test GEGLU backward pass matches PyTorch autograd."""
    gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    grad_output = torch.randn(2, 3, 64, device="cuda")

    # PyTorch reference - compute intermediates
    gelu_gate = F.gelu(gate)
    torch_out = gelu_gate * up
    torch_out.backward(grad_output)

    # Custom backward pass
    gate_clone = gate.clone().detach()
    up_clone = up.clone().detach()
    grad_output_clone = grad_output.clone()

    h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone)

    # Compare outputs and gradients
    assert torch.allclose(h, torch_out, rtol=1e-3)
    assert torch.allclose(grad_gate, gate.grad, rtol=1e-3)
    assert torch.allclose(grad_up, up.grad, rtol=1e-3)


def test_geglu_inplace_preservation():
    """Test that GEGLU backward doesn't modify original tensors unexpectedly."""
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")
    grad_output = torch.randn(2, 3, 64, device="cuda")

    gate_copy = gate.clone()
    up_copy = up.clone()
    grad_copy = grad_output.clone()

    geglu_backward(grad_output, gate, up)

    assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
    assert not torch.equal(up, up_copy), "Up should be modified in-place"
    assert not torch.equal(
        grad_output, grad_copy
    ), "Grad output should be modified in-place"
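A CPU-only reference for the semantics the deleted tests exercised: GEGLU is `gelu(gate) * up`, with autograd providing the reference gradients (shapes here are arbitrary):

```python
import torch
import torch.nn.functional as F

gate = torch.randn(2, 3, 8, requires_grad=True)
up = torch.randn(2, 3, 8, requires_grad=True)
grad_output = torch.randn(2, 3, 8)

h = F.gelu(gate) * up    # forward
h.backward(grad_output)  # autograd yields the reference grads

print(h.shape, gate.grad.shape, up.grad.shape)
```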
@@ -1,531 +0,0 @@
"""Tests for LoRA custom autograd."""
# pylint: disable=invalid-name,redefined-outer-name

import pytest
import torch
from bitsandbytes.functional import QuantState
from torch import nn

from axolotl.kernels.geglu import geglu_backward, geglu_forward
from axolotl.kernels.lora import (
    LoRA_MLP,
    LoRA_O,
    LoRA_QKV,
    apply_lora_mlp_geglu,
    apply_lora_mlp_swiglu,
    get_lora_parameters,
    matmul_lora,
)
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward


@pytest.fixture
def mock_quantstate():
    """Creates a mock QuantState for testing"""
    shape = (64, 64)
    n_blocks = shape[0]  # Assuming blockwise quantization along first dimension

    # Create nested state first
    nested_state = QuantState(
        absmax=torch.ones(n_blocks, device="cuda"),  # One value per block
        shape=shape,
        code=torch.randint(0, 15, shape, device="cuda"),  # NF4 range is 0-15
        dtype=torch.float16,
        blocksize=64,
        quant_type="nf4",
        offset=None,
        state2=None,
    )

    # Create main state with nested state
    return QuantState(
        absmax=torch.ones(n_blocks, device="cuda"),
        shape=shape,
        code=torch.randint(0, 15, shape, device="cuda"),
        dtype=torch.float16,
        blocksize=64,
        quant_type="nf4",
        offset=torch.zeros(n_blocks, dtype=torch.int32, device="cuda"),
        state2=nested_state,
    )


@pytest.fixture
def sample_tensors():
    """Creates sample tensors for testing"""
    torch.manual_seed(42)
    batch_size, seq_len, hidden_dim = 2, 3, 64
    rank = 8
    out_dim = hidden_dim

    return {
        "X": torch.randn(
            batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
        ),
        "W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
        "scale": 0.5,
        "shapes": {
            "batch": batch_size,
            "seq": seq_len,
            "hidden": hidden_dim,
            "out": out_dim,
            "rank": rank,
        },
    }


@pytest.fixture
def mock_proj():
    """Creates a mock projection module for testing."""

    class MockProj(nn.Module):
        """Mock projection class."""

        def __init__(self, in_features=64, out_features=128, rank=8):
            super().__init__()
            self.base_layer = nn.Linear(in_features, out_features)
            self.base_layer.to("cuda")
            self.lora_A = nn.ModuleDict(
                {"default": nn.Linear(in_features, rank, bias=False).to("cuda")}
            )
            self.lora_B = nn.ModuleDict(
                {"default": nn.Linear(rank, out_features, bias=False).to("cuda")}
            )
            self.scaling = {"default": 0.5}
            self.active_adapter = "default"
            self.disable_adapters = False
            self.merged = False

    return MockProj()

def test_get_lora_parameters(mock_proj):
    """Tests get_lora_parameters function"""
    # Test with LoRA enabled
    W, _, A, B, s = get_lora_parameters(mock_proj)

    assert isinstance(W, torch.Tensor)
    assert W.shape == (128, 64)
    assert A.shape == (8, 64)
    assert B.shape == (128, 8)
    assert s == 0.5

    # Test with LoRA disabled
    mock_proj.disable_adapters = True
    W, _, A, B, s = get_lora_parameters(mock_proj)
    assert A is None and B is None and s is None

    # Test with merged state
    mock_proj.disable_adapters = False
    mock_proj.merged = True
    W, _, A, B, s = get_lora_parameters(mock_proj)
    assert A is None and B is None and s is None


def test_matmul_lora(sample_tensors):
    """Tests matmul_lora function"""
    X = sample_tensors["X"]
    W = sample_tensors["W"]
    scale = sample_tensors["scale"]

    shapes = sample_tensors["shapes"]
    hidden_dim = shapes["hidden"]
    out_dim = shapes["out"]
    rank = shapes["rank"]

    A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
    B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)

    # Test base matmul
    out1 = matmul_lora(X, W, None, None, None, None)
    expected1 = torch.matmul(X, W.t())
    assert torch.allclose(out1, expected1, rtol=1e-3)

    # Test with LoRA
    out2 = matmul_lora(X, W, None, A, B, scale)
    lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
    expected2 = expected1 + lora_term
    assert torch.allclose(out2, expected2, rtol=1e-3)

    # Test 3D input reshaping
    X_3d = X.clone()
    out3 = matmul_lora(X_3d, W, None, A, B, scale)
    assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
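The identity `test_matmul_lora` checks, written out on CPU: the adapted projection equals the base matmul plus a scaled low-rank correction (dimensions chosen to mirror the fixture):

```python
import torch

batch, seq, hidden, out_dim, rank = 2, 3, 64, 64, 8
X = torch.randn(batch, seq, hidden)
W = torch.randn(out_dim, hidden)
A = torch.randn(rank, hidden)
B = torch.randn(out_dim, rank)
scale = 0.5

# base projection plus the LoRA low-rank term
out = X @ W.T + scale * ((X @ A.T) @ B.T)
assert out.shape == (batch, seq, out_dim)
```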
@pytest.mark.parametrize(
    "activation_forward,activation_backward",
    [(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
)
def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward):
    """Tests LoRA_MLP directly with different activation functions"""
    X = sample_tensors["X"]
    shapes = sample_tensors["shapes"]
    hidden_dim = shapes["hidden"]
    out_dim = shapes["out"]

    # Create linear layers
    gate_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
    up_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
    down_proj = nn.Linear(out_dim, hidden_dim).to(device="cuda", dtype=torch.float16)

    # Test forward pass without adapters (parametrized over SwiGLU and GeGLU)
    X.requires_grad = True
    output = LoRA_MLP.apply(
        X,
        gate_proj.weight,
        None,  # gate_quant
        None,  # gate_A
        None,  # gate_B
        None,  # gate_scale
        up_proj.weight,
        None,  # up_quant
        None,  # up_A
        None,  # up_B
        None,  # up_scale
        down_proj.weight,
        None,  # down_quant
        None,  # down_A
        None,  # down_B
        None,  # down_scale
        activation_forward,
        activation_backward,
        True,  # inplace
    )

    assert output.shape == X.shape
    assert not torch.isnan(output).any()

    # Test backward pass
    loss = output.sum()
    loss.backward()
    assert X.grad is not None
    assert not torch.isnan(X.grad).any()

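The eighteen positional arguments above map onto a standard gated MLP; as a reference sketch (assuming the fused autograd function computes the same thing unfused):

import torch.nn.functional as F

def gated_mlp_reference(X, gate_proj, up_proj, down_proj, activation=F.silu):
    # down( act(gate(X)) * up(X) ): SwiGLU when activation is SiLU,
    # GeGLU when it is GELU
    return down_proj(activation(gate_proj(X)) * up_proj(X))
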
@pytest.mark.parametrize(
    "activation_forward,activation_backward",
    [(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
)
def test_lora_mlp_with_adapters(
    sample_tensors, activation_forward, activation_backward
):
    """Tests LoRA_MLP with LoRA adapters"""
    X = sample_tensors["X"]
    shapes = sample_tensors["shapes"]
    hidden_dim = shapes["hidden"]
    out_dim = shapes["out"]
    rank = shapes["rank"]

    # Create LoRA components
    gate_A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
    gate_B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
    up_A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
    up_B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
    down_A = torch.randn(rank, out_dim, device="cuda", dtype=torch.float16)
    down_B = torch.randn(hidden_dim, rank, device="cuda", dtype=torch.float16)
    scale = 0.5

    gate_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
    up_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16)
    down_proj = nn.Linear(out_dim, hidden_dim).to(device="cuda", dtype=torch.float16)

    X.requires_grad = True
    gate_A.requires_grad = True
    gate_B.requires_grad = True
    up_A.requires_grad = True
    up_B.requires_grad = True
    down_A.requires_grad = True
    down_B.requires_grad = True

    # Forward pass with adapters
    output = LoRA_MLP.apply(
        X,
        gate_proj.weight,
        None,
        gate_A,
        gate_B,
        scale,
        up_proj.weight,
        None,
        up_A,
        up_B,
        scale,
        down_proj.weight,
        None,
        down_A,
        down_B,
        scale,
        activation_forward,
        activation_backward,
        True,
    )

    assert output.shape == X.shape
    assert not torch.isnan(output).any()

    # Test backward pass
    loss = output.sum()
    loss.backward()

    # Check all gradients
    assert X.grad is not None
    assert gate_A.grad is not None
    assert gate_B.grad is not None
    assert up_A.grad is not None
    assert up_B.grad is not None
    assert down_A.grad is not None
    assert down_B.grad is not None

    assert not torch.isnan(X.grad).any()
    assert not torch.isnan(gate_A.grad).any()
    assert not torch.isnan(gate_B.grad).any()
    assert not torch.isnan(up_A.grad).any()
    assert not torch.isnan(up_B.grad).any()
    assert not torch.isnan(down_A.grad).any()
    assert not torch.isnan(down_B.grad).any()

def test_lora_qkv(sample_tensors):
    """Tests LoRA QKV implementation with and without adapters"""
    X = sample_tensors["X"]
    shapes = sample_tensors["shapes"]
    hidden_dim = shapes["hidden"]
    rank = shapes["rank"]

    # Create base weights
    q_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
    k_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)
    v_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16)

    # Create LoRA matrices
    q_A = torch.randn(
        rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
    )
    q_B = torch.randn(
        hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
    )
    k_A = torch.randn(
        rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
    )
    k_B = torch.randn(
        hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
    )
    v_A = torch.randn(
        rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True
    )
    v_B = torch.randn(
        hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True
    )
    scale = 0.5

    X.requires_grad = True

    # Test without LoRA adapters
    Q1, K1, V1 = LoRA_QKV.apply(
        X,
        q_weight,
        None,
        None,
        None,
        None,
        k_weight,
        None,
        None,
        None,
        None,
        v_weight,
        None,
        None,
        None,
        None,
        True,
    )

    assert Q1.shape == K1.shape == V1.shape == X.shape
    loss1 = (Q1 + K1 + V1).sum()
    loss1.backward()
    assert X.grad is not None

    # Clear gradients
    X.grad = None

    # Test with LoRA adapters
    Q2, K2, V2 = LoRA_QKV.apply(
        X,
        q_weight,
        None,
        q_A,
        q_B,
        scale,
        k_weight,
        None,
        k_A,
        k_B,
        scale,
        v_weight,
        None,
        v_A,
        v_B,
        scale,
        True,
    )

    assert Q2.shape == K2.shape == V2.shape == X.shape
    loss2 = (Q2 + K2 + V2).sum()
    loss2.backward()

    # Check gradients
    assert X.grad is not None
    assert q_A.grad is not None
    assert q_B.grad is not None
    assert k_A.grad is not None
    assert k_B.grad is not None
    assert v_A.grad is not None
    assert v_B.grad is not None

    # Check for NaN values
    assert not torch.isnan(X.grad).any()
    assert not torch.isnan(q_A.grad).any()
    assert not torch.isnan(q_B.grad).any()
    assert not torch.isnan(k_A.grad).any()
    assert not torch.isnan(k_B.grad).any()
    assert not torch.isnan(v_A.grad).any()
    assert not torch.isnan(v_B.grad).any()

def test_lora_o(sample_tensors):
    """Tests LoRA output projection"""
    X = sample_tensors["X"]
    W = sample_tensors["W"]
    scale = sample_tensors["scale"]

    shapes = sample_tensors["shapes"]
    hidden_dim = shapes["hidden"]
    out_dim = shapes["out"]
    rank = shapes["rank"]

    A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
    B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)

    # Test forward pass
    X.requires_grad = True
    output = LoRA_O.apply(X, W, None, A, B, scale)

    assert output.shape == (X.shape[0], X.shape[1], W.shape[0])

    # Test backward pass
    loss = output.sum()
    loss.backward()
    assert X.grad is not None

def test_with_quantization(sample_tensors, mock_quantstate):
    """Tests LoRA with quantized weights"""
    X = sample_tensors["X"]  # [batch, seq, hidden]
    W = sample_tensors["W"]  # [out, hidden]
    scale = 0.5

    shapes = sample_tensors["shapes"]
    hidden_dim = shapes["hidden"]
    out_dim = shapes["out"]
    rank = shapes["rank"]

    A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
    B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)

    # Test matmul with quantization
    out = matmul_lora(X, W, mock_quantstate, A, B, scale)
    assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
    assert not torch.isnan(out).any()

    # Test with different batch sizes
    X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
    out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
    assert out2.shape == (4, 6, W.shape[0])
    assert not torch.isnan(out2).any()

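When a quant state is supplied, the expected behavior is equivalent to dequantizing the base weight first and then applying the LoRA math (a sketch under that assumption; the real kernel may fuse these steps):

from axolotl.kernels.quantize import dequantize

def matmul_lora_quantized_reference(X, W_q, quant_state, A, B, scale):
    # Dequantize the packed 4-bit base weight, then the usual LoRA math
    W = dequantize(W_q, quant_state)
    return X @ W.t() + scale * (X @ A.t()) @ B.t()
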
@pytest.mark.parametrize(
    "batch,seq,hidden,rank,out",
    [
        (1, 1, 32, 4, 64),
        (2, 3, 64, 8, 128),
        (4, 5, 128, 16, 256),
    ],
)
def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
    """Tests various input shapes and dimensions"""
    X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
    W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
    A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
    B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
    scale = 0.5

    result = matmul_lora(X, W, None, A, B, scale)
    assert result.shape == (batch, seq, out)

def test_gradient_flow(sample_tensors):
    """Tests gradient flow through LoRA layers"""
    X = sample_tensors["X"].clone()
    W = sample_tensors["W"].clone()
    scale = sample_tensors["scale"]

    shapes = sample_tensors["shapes"]
    hidden_dim = shapes["hidden"]
    out_dim = shapes["out"]
    rank = shapes["rank"]

    A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16)
    B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)

    X.requires_grad = True
    A.requires_grad = True
    B.requires_grad = True

    # Forward pass
    out = matmul_lora(X, W, None, A, B, scale)
    loss = out.sum()

    # Backward pass
    loss.backward()

    assert X.grad is not None
    assert A.grad is not None
    assert B.grad is not None
    assert not torch.isnan(X.grad).any()
    assert not torch.isnan(A.grad).any()
    assert not torch.isnan(B.grad).any()

@pytest.mark.parametrize(
    "apply_function",
    [apply_lora_mlp_swiglu, apply_lora_mlp_geglu],
)
def test_inplace_operations(sample_tensors, apply_function):
    """Tests inplace operation behavior"""
    X = sample_tensors["X"]
    shapes = sample_tensors["shapes"]

    # Create MLP with both inplace=True and inplace=False
    mlp = type(
        "MLPModule",
        (),
        {
            "gate_proj": nn.Linear(shapes["hidden"], shapes["out"]).to(
                device="cuda", dtype=torch.float16
            ),
            "up_proj": nn.Linear(shapes["hidden"], shapes["out"]).to(
                device="cuda", dtype=torch.float16
            ),
            "down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
                device="cuda", dtype=torch.float16
            ),
        },
    )

    out1 = apply_function(mlp, X.clone(), inplace=True)
    out2 = apply_function(mlp, X.clone(), inplace=False)

    assert torch.allclose(out1, out2, rtol=1e-3)
@@ -1,103 +0,0 @@
"""Tests for quantization utility functions."""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
|
||||
from axolotl.kernels.quantize import dequantize
|
||||
|
||||
|
||||
def test_dequantize_null_state():
|
||||
"""Test that dequantize returns input unchanged when quant_state is None"""
|
||||
W = torch.randn(32, 32)
|
||||
assert torch.equal(dequantize(W, None), W)
|
||||
|
||||
|
||||
def test_dequantize_shape_preservation():
    """Test that dequantization preserves expected shapes"""
    shape = (32, 32)
    W = torch.randn(shape, device="cuda")

    quant_state = QuantState(
        absmax=torch.ones(shape[0], device="cuda"),
        shape=shape,
        code=torch.randint(0, 15, shape, device="cuda"),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(shape[0], dtype=torch.int32, device="cuda"),
        state2=QuantState(
            absmax=torch.ones(shape[0], device="cuda"),
            shape=shape,
            code=torch.randint(0, 15, shape, device="cuda"),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state)
    assert result.shape == shape
    assert result.dtype == torch.float16
    assert result.device == W.device

def test_dequantize_transposed():
    """Test that transposed input produces transposed output"""
    shape = (32, 32)
    W = torch.randn(1, shape[1], device="cuda")  # Transposed input

    quant_state = QuantState(
        absmax=torch.ones(1),
        shape=shape,
        code=torch.randint(0, 15, shape),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(1, dtype=torch.int32),
        state2=QuantState(
            absmax=torch.ones(1),
            shape=shape,
            code=torch.randint(0, 15, shape),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state)
    assert result.shape[0] == shape[0]

def test_dequantize_output_tensor():
    """Test dequantization with provided output tensor"""
    shape = (32, 32)
    W = torch.randn(shape, device="cuda")
    out = torch.empty(shape, dtype=torch.float16, device="cuda")

    quant_state = QuantState(
        absmax=torch.ones(shape[0]),
        shape=shape,
        code=torch.randint(0, 15, shape),
        dtype=torch.float16,
        blocksize=32,
        quant_type="nf4",
        offset=torch.zeros(shape[0], dtype=torch.int32),
        state2=QuantState(
            absmax=torch.ones(shape[0]),
            shape=shape,
            code=torch.randint(0, 15, shape),
            dtype=torch.float16,
            blocksize=32,
            quant_type="nf4",
            offset=None,
            state2=None,
        ),
    )

    result = dequantize(W, quant_state, out=out)
    assert result is out
@@ -1,78 +0,0 @@
"""Tests for SwiGLU activation function Triton kernels."""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward
|
||||
|
||||
|
||||
def test_swiglu_forward_shape():
|
||||
"""Test that SwiGLU forward pass preserves expected shapes"""
|
||||
batch, seq_len, hidden_dim = 2, 3, 64
|
||||
gate = torch.randn(batch, seq_len, hidden_dim, device="cuda")
|
||||
up = torch.randn(batch, seq_len, hidden_dim, device="cuda")
|
||||
|
||||
out = swiglu_forward(gate, up)
|
||||
assert out.shape == (batch, seq_len, hidden_dim)
|
||||
assert out.dtype == gate.dtype
|
||||
assert out.device == gate.device
|
||||
|
||||
|
||||
def test_swiglu_forward_values():
    """Test SwiGLU forward pass matches PyTorch reference implementation"""
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")

    # Custom implementation
    triton_out = swiglu_forward(gate.clone(), up.clone())

    # PyTorch reference
    torch_out = F.silu(gate) * up

    assert torch.allclose(triton_out, torch_out, rtol=1e-3)

def test_swiglu_backward():
    """Test SwiGLU backward pass matches PyTorch autograd"""
    gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    grad_output = torch.randn(2, 3, 64, device="cuda")

    # PyTorch reference - compute intermediates
    silu_gate = F.silu(gate)
    torch_out = silu_gate * up
    torch_out.backward(grad_output)

    # Custom backward pass
    gate_clone = gate.clone().detach()
    up_clone = up.clone().detach()
    grad_output_clone = grad_output.clone()

    h, our_grad_gate, our_grad_up = swiglu_backward(
        grad_output_clone, gate_clone, up_clone
    )

    # Compare outputs and gradients
    assert torch.allclose(h, torch_out, rtol=1e-3)
    assert torch.allclose(our_grad_gate, gate.grad, rtol=1e-3)
    assert torch.allclose(our_grad_up, up.grad, rtol=1e-3)

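For reference, the gradients checked above follow from h = silu(g) * u, a standard derivation independent of the Triton implementation (a sketch; the kernel may compute them fused and in-place):

import torch

def swiglu_backward_reference(grad_out, gate, up):
    # h = silu(g) * u, where silu(g) = g * sigmoid(g)
    sig = torch.sigmoid(gate)
    silu = gate * sig
    grad_up = grad_out * silu
    # d silu(g)/dg = sigmoid(g) * (1 + g * (1 - sigmoid(g)))
    grad_gate = grad_out * up * sig * (1 + gate * (1 - sig))
    return silu * up, grad_gate, grad_up
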
def test_swiglu_inplace_preservation():
    """Test that SwiGLU backward modifies its input buffers in-place, as the kernel is expected to"""
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")
    grad_output = torch.randn(2, 3, 64, device="cuda")

    gate_copy = gate.clone()
    up_copy = up.clone()
    grad_copy = grad_output.clone()

    swiglu_backward(grad_output, gate, up)

    assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
    assert not torch.equal(up, up_copy), "Up should be modified in-place"
    assert not torch.equal(
        grad_output, grad_copy
    ), "Grad output should be modified in-place"
@@ -1,173 +0,0 @@
"""
|
||||
GRPO test suite
|
||||
"""
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from e2e.utils import require_vllm
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestGRPO:
    """
    Test case for GRPO training using multiple GPUs
    """

    def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
        # write cfg to yaml file
        Path(temp_dir).mkdir(parents=True, exist_ok=True)
        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
        with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
            fout.write(
                """import random
def rand_reward_func(completions, **kwargs) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]

def oai_gsm8k_transform(cfg, *args, **kwargs):
    def transform_fn(example, tokenizer=None):
        label = example["answer"].split("####")[-1].strip().replace(",", "")
        return {
            "prompt": [{"role": "user", "content": example["question"]},],
            "answer": label,
        }
    return transform_fn, {"remove_columns": ["question"]}
"""
            )

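The helper writes a throwaway rewards_<suffix>.py module so the config can reference the reward and transform functions by dotted path. Conceptually (a sketch of how TRL-style reward functions are invoked, not axolotl internals):

# reward_funcs entries resolve to callables like:
#   rand_reward_func(completions=["...", "..."]) -> [0.42, 0.87]
# i.e. one float score per sampled completion, which GRPO uses to
# compute group-relative advantages
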
    @pytest.mark.parametrize(
        "num_gpus",
        [1, 2],
    )
    @require_vllm
    def test_llama_dora(self, temp_dir, num_gpus):
        rnd_reward_suffix = str(random.randint(1000, 9999))
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "chat_template": "llama3",
                "rl": "grpo",
                "trl": {
                    "beta": 0.001,
                    "max_completion_length": 256,
                    "use_vllm": True,
                    "vllm_device": "auto" if num_gpus == 1 else "cuda:1",
                    "vllm_gpu_memory_utilization": 0.15,
                    "num_generations": 4,
                    "reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
                },
                "datasets": [
                    {
                        "path": "openai/gsm8k",
                        "name": "main",
                        "type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
                    },
                ],
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "peft_use_dora": True,
                "flash_attention": True,
                "sequence_len": 1024,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "max_steps": 5,
                "num_epochs": 1,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 2,
                "warmup_steps": 10,
                "val_set_size": 0.0,
                "output_dir": temp_dir,
                "learning_rate": 0.0001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "save_safetensors": True,
                "bf16": "auto",
                "use_tensorboard": True,
            }
        )

        self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)

        execute_subprocess_async(
            [
                "axolotl",
                "train",
                str(Path(temp_dir) / "config.yaml"),
                "--num-processes",
                str(num_gpus),
                "--main-process-port",
                f"{get_torch_dist_unique_port()}",
            ]
        )

    @pytest.mark.parametrize(
        "num_gpus",
        [1, 2],
    )
    @require_vllm
    def test_llama_fft(self, temp_dir, num_gpus):
        rnd_reward_suffix = str(random.randint(1000, 9999))
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "chat_template": "llama3",
                "rl": "grpo",
                "trl": {
                    "beta": 0.001,
                    "max_completion_length": 256,
                    "use_vllm": True,
                    "vllm_device": "auto" if num_gpus == 1 else "cuda:1",
                    "vllm_gpu_memory_utilization": 0.15,
                    "num_generations": 4,
                    "reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
                },
                "datasets": [
                    {
                        "path": "openai/gsm8k",
                        "name": "main",
                        "type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
                    },
                ],
                "flash_attention": True,
                "sequence_len": 1024,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "max_steps": 5,
                "num_epochs": 1,
                "micro_batch_size": 4,
                "gradient_accumulation_steps": 2,
                "warmup_steps": 10,
                "val_set_size": 0.0,
                "output_dir": temp_dir,
                "learning_rate": 0.0001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "save_safetensors": True,
                "bf16": "auto",
                "use_tensorboard": True,
            }
        )

        self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)

        execute_subprocess_async(
            [
                "axolotl",
                "train",
                str(Path(temp_dir) / "config.yaml"),
                "--num-processes",
                str(num_gpus),
                "--main-process-port",
                f"{get_torch_dist_unique_port()}",
            ]
        )
@@ -9,7 +9,7 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard, require_torch_lt_2_6_0
from e2e.utils import check_tensorboard

from axolotl.utils.dict import DictDefault

@@ -24,7 +24,6 @@ class TestMultiGPURay:
    Test cases for AnyScale Ray post training
    """

    @require_torch_lt_2_6_0
    def test_lora_ddp(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
@@ -81,7 +80,6 @@ class TestMultiGPURay:
            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
        )

    @require_torch_lt_2_6_0
    @pytest.mark.parametrize(
        "gradient_accumulation_steps",
        [1, 2],
@@ -1,414 +0,0 @@
"""Integration tests for LoRA activation and attention kernels."""
# pylint: disable=redefined-outer-name

import pytest
import torch
from accelerate.state import PartialState
from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.kernels.lora import (
    apply_lora_mlp_geglu,
    apply_lora_mlp_swiglu,
    apply_lora_o,
    apply_lora_qkv,
)
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
from axolotl.utils.dict import DictDefault

MODEL_CONFIGS = [
    {
        "name": "openaccess-ai-collective/tiny-mistral",
        "expected_activation": apply_lora_mlp_swiglu,
        "dtype": torch.float16,
    },
    {
        "name": "Qwen/Qwen2-7B",
        "expected_activation": apply_lora_mlp_swiglu,
        "dtype": torch.float16,
    },
    {
        "name": "HuggingFaceTB/SmolLM2-135M",
        "expected_activation": apply_lora_mlp_swiglu,
        "dtype": torch.float32,
    },
    {
        "name": "mhenrichsen/gemma-2b",
        "expected_activation": apply_lora_mlp_geglu,
        "dtype": torch.float16,
    },
]

@pytest.fixture(autouse=True)
def init_accelerate():
    """Initialize Accelerate state before tests."""
    _ = PartialState()


@pytest.fixture
def small_llama_model():
    """Create a small LLaMA model for testing."""
    config = {
        "vocab_size": 100,
        "hidden_size": 128,
        "intermediate_size": 256,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
    }

    return LlamaForCausalLM(LlamaConfig(**config))

# pylint: disable=duplicate-code
@pytest.fixture
def minimal_cfg():
    """Config for a real Hugging Face Hub model."""
    cfg = DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
            "learning_rate": 0.000001,
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                }
            ],
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "adapter": "lora",
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.0,
            "lora_target_linear": True,
            "sequence_len": 1024,
            "lora_mlp_kernel": True,
            "lora_qkv_kernel": True,
            "lora_o_kernel": True,
        }
    )

    return cfg

def test_attention_patching_integration(minimal_cfg):
    """Test attention patching in integration context."""
    # Store the original implementation
    original_forward = getattr(LlamaAttention, "forward")

    # Load model
    _, _ = load_model_and_tokenizer(cfg=minimal_cfg)

    # Get the new forward method
    patched_forward = LlamaAttention.forward

    # Check the forward method was replaced
    assert original_forward is not patched_forward
    assert patched_forward.__name__ == "axolotl_attn_forward"

    # Check original implementation was stored
    assert hasattr(LlamaAttention, "_original_forward")

    # Clean up
    setattr(LlamaAttention, "forward", original_forward)
    delattr(LlamaAttention, "_original_forward")

def test_swiglu_mlp_integration(small_llama_model):
    """Test SwiGLU activation in LoRA MLP context."""
    peft_config = get_peft_config(
        {
            "peft_type": "LORA",
            "task_type": "CAUSAL_LM",
            "r": 8,
            "lora_alpha": 16,
            "target_modules": ["gate_proj", "up_proj", "down_proj"],
            "lora_dropout": 0,
            "bias": "none",
        }
    )
    model = PeftModelForCausalLM(small_llama_model, peft_config).to("cuda")
    cfg = DictDefault({"lora_mlp_kernel": True})

    # Apply patches
    patched_model = apply_lora_kernel_patches(model, cfg)

    # Verify patches
    layer = patched_model.model.model.layers[0]
    assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu

    # Test forward pass
    batch_size, seq_len = 2, 10
    hidden_states = torch.randn(
        batch_size, seq_len, model.config.hidden_size, device=model.device
    )
    position_ids = (
        torch.arange(seq_len, device=model.device).unsqueeze(0).expand(batch_size, -1)
    )
    cos, sin = model.model.model.rotary_emb(hidden_states, position_ids)

    inputs = {
        "hidden_states": hidden_states,
        "attention_mask": None,
        "position_embeddings": (cos, sin),
        "output_attentions": False,
        "use_cache": False,
        "past_key_value": None,
    }

    # Compare outputs
    with torch.no_grad():
        original_output = model.model.model.layers[0](**inputs)[0]
        patched_output = layer(**inputs)[0]

    assert torch.allclose(original_output, patched_output, rtol=1e-4)

def test_geglu_model_integration():
    """Test GeGLU activation with Gemma model."""
    model = AutoModelForCausalLM.from_pretrained(
        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
    )
    peft_config = get_peft_config(
        {
            "peft_type": "LORA",
            "task_type": "CAUSAL_LM",
            "r": 8,
            "lora_alpha": 16,
            "target_modules": ["gate_proj", "up_proj", "down_proj"],
            "lora_dropout": 0,
            "bias": "none",
        }
    )
    model = PeftModelForCausalLM(model, peft_config)

    cfg = DictDefault({"lora_mlp_kernel": True})
    patched_model = apply_lora_kernel_patches(model, cfg)

    # Verify patches
    layer = patched_model.model.model.layers[0]
    assert layer.mlp.forward.__func__ is apply_lora_mlp_geglu

    # Test end-to-end
    inputs = torch.randint(0, 100, (1, 20), device=model.device, dtype=torch.long)
    with torch.no_grad():
        original_output = model(inputs).logits
        patched_output = patched_model(inputs).logits

    assert torch.allclose(original_output, patched_output, rtol=1e-4)

@pytest.mark.parametrize(
    "model_name,expected_activation",
    [
        ("HuggingFaceTB/SmolLM2-135M", apply_lora_mlp_swiglu),
        ("mhenrichsen/gemma-2b", apply_lora_mlp_geglu),
    ],
)
def test_model_specific_activation(model_name, expected_activation):
    """Test that each model type gets the correct activation function."""
    model = AutoModelForCausalLM.from_pretrained(model_name)
    peft_config = get_peft_config(
        {
            "peft_type": "LORA",
            "task_type": "CAUSAL_LM",
            "r": 8,
            "lora_alpha": 16,
            "target_modules": ["gate_proj", "up_proj", "down_proj"],
            "lora_dropout": 0,
            "bias": "none",
        }
    )
    model = PeftModelForCausalLM(model, peft_config)
    cfg = DictDefault({"lora_mlp_kernel": True})

    patched_model = apply_lora_kernel_patches(model, cfg)
    layer = patched_model.model.model.layers[0]
    assert layer.mlp.forward.__func__ is expected_activation

def test_kernel_patch_conditions():
    """Test various conditions that should prevent kernel patching."""
    test_configs = [
        # Dropout prevents patching
        {
            "peft_type": "LORA",
            "task_type": "CAUSAL_LM",
            "r": 8,
            "lora_alpha": 16,
            "target_modules": ["gate_proj", "up_proj", "down_proj"],
            "lora_dropout": 0.1,
            "bias": "none",
        },
        # Bias prevents patching
        {
            "peft_type": "LORA",
            "task_type": "CAUSAL_LM",
            "r": 8,
            "lora_alpha": 16,
            "target_modules": ["gate_proj", "up_proj", "down_proj"],
            "lora_dropout": 0,
            "bias": "lora_only",
        },
    ]

    for config in test_configs:
        model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
        peft_config = get_peft_config(config)
        model = PeftModelForCausalLM(model, peft_config)
        cfg = DictDefault({"lora_mlp_kernel": True})

        # Should not patch
        patched_model = apply_lora_kernel_patches(model, cfg)
        layer = patched_model.model.model.layers[0].mlp

        # Verify no patches applied
        assert layer.forward.__func__ is not apply_lora_mlp_swiglu
        assert layer.forward.__func__ is not apply_lora_mlp_geglu

def test_kernel_config_options():
    """Test that kernel configuration options are respected."""
    # Test different configurations
    test_configs = [
        (
            {"lora_mlp_kernel": True, "lora_qkv_kernel": False, "lora_o_kernel": False},
            lambda layer: (
                layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
                and layer.self_attn.apply_qkv.__func__ is not apply_lora_qkv
                and layer.self_attn.apply_o.__func__ is not apply_lora_o
            ),
        ),
        (
            {"lora_mlp_kernel": False, "lora_qkv_kernel": True, "lora_o_kernel": False},
            lambda layer: (
                layer.mlp.forward.__func__ is not apply_lora_mlp_swiglu
                and layer.self_attn.apply_qkv.__func__ is apply_lora_qkv
                and layer.self_attn.apply_o.__func__ is not apply_lora_o
            ),
        ),
        (
            {"lora_mlp_kernel": False, "lora_qkv_kernel": False, "lora_o_kernel": True},
            lambda layer: (
                layer.mlp.forward.__func__ is not apply_lora_mlp_swiglu
                and layer.self_attn.apply_qkv.__func__ is not apply_lora_qkv
                and layer.self_attn.apply_o.__func__ is apply_lora_o
            ),
        ),
    ]

    for config_dict, check_fn in test_configs:
        # Create fresh model for each test
        config = {
            "vocab_size": 100,
            "hidden_size": 128,
            "intermediate_size": 256,
            "num_hidden_layers": 2,
            "num_attention_heads": 4,
        }
        small_llama_model = LlamaForCausalLM(LlamaConfig(**config))

        peft_config = get_peft_config(
            {
                "peft_type": "LORA",
                "task_type": "CAUSAL_LM",
                "r": 8,
                "lora_alpha": 16,
                "target_modules": [
                    "gate_proj",
                    "up_proj",
                    "down_proj",
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                ],
                "lora_dropout": 0,
                "bias": "none",
            }
        )
        model = PeftModelForCausalLM(small_llama_model, peft_config).to("cuda")
        cfg = DictDefault(config_dict)
        patched_model = apply_lora_kernel_patches(model, cfg)

        # Verify only requested optimizations were applied
        for layer in patched_model.model.model.layers:
            assert check_fn(layer), f"Failed for config: {config_dict}"

        # Clean up
        del model
        del small_llama_model
        del patched_model

def get_lora_config():
    """Get standard LoRA configuration for testing."""
    return {
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",
        "r": 8,
        "lora_alpha": 16,
        "target_modules": ["gate_proj", "up_proj", "down_proj"],
        "lora_dropout": 0,
        "bias": "none",
    }


def get_test_inputs(model, seq_length=20):
    """Generate test inputs for model evaluation."""
    return torch.randint(
        0,
        model.config.vocab_size,
        (1, seq_length),
        device=model.device,
        dtype=torch.long,
    )

@pytest.mark.parametrize("model_config", MODEL_CONFIGS)
|
||||
def test_model_architecture(model_config):
|
||||
"""Test LoRA kernel patches across different model architectures."""
|
||||
# Load model with appropriate dtype
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
|
||||
)
|
||||
|
||||
# Apply LoRA configuration
|
||||
peft_config = get_peft_config(get_lora_config())
|
||||
model = PeftModelForCausalLM(model, peft_config)
|
||||
|
||||
# Apply kernel patches
|
||||
cfg = DictDefault({"lora_mlp_kernel": True})
|
||||
patched_model = apply_lora_kernel_patches(model, cfg)
|
||||
|
||||
# Verify correct activation function
|
||||
layer = patched_model.model.model.layers[0]
|
||||
assert (
|
||||
layer.mlp.forward.__func__ is model_config["expected_activation"]
|
||||
), f"Wrong activation for {model_config['name']}"
|
||||
|
||||
# Test forward pass
|
||||
inputs = get_test_inputs(model)
|
||||
with torch.no_grad():
|
||||
original_output = model(inputs).logits
|
||||
patched_output = patched_model(inputs).logits
|
||||
|
||||
# Check outputs match
|
||||
assert torch.allclose(
|
||||
original_output, patched_output, rtol=1e-4
|
||||
), f"Outputs don't match for {model_config['name']}"
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
def test_kernel_training_integration(minimal_cfg):
    """Test model loading with kernel patches enabled."""
    # Load model
    model, _ = load_model_and_tokenizer(cfg=minimal_cfg)

    # Verify correct activation function
    layer = model.model.model.layers[0]
    assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, check_tensorboard
@@ -76,9 +76,7 @@ class TestFAXentropyLlama:
        else:
            cfg.fp16 = True

        cfg = validate_config(cfg)
        normalize_config(cfg)

        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -10,7 +10,7 @@ from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -73,8 +73,6 @@ class TestReLoraLlama(unittest.TestCase):
                "use_tensorboard": True,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -12,7 +12,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_preference_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir
@@ -63,8 +63,6 @@ class TestDPOLlamaLora(unittest.TestCase):
                "gradient_checkpointing_kwargs": {"use_reentrant": True},
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -110,8 +108,6 @@ class TestDPOLlamaLora(unittest.TestCase):
                "gradient_checkpointing_kwargs": {"use_reentrant": True},
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -157,8 +153,6 @@ class TestDPOLlamaLora(unittest.TestCase):
                "gradient_checkpointing_kwargs": {"use_reentrant": True},
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -204,8 +198,6 @@ class TestDPOLlamaLora(unittest.TestCase):
                "gradient_checkpointing_kwargs": {"use_reentrant": True},
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -250,8 +242,6 @@ class TestDPOLlamaLora(unittest.TestCase):
                "gradient_checkpointing_kwargs": {"use_reentrant": True},
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -299,8 +289,6 @@ class TestDPOLlamaLora(unittest.TestCase):
                "gradient_checkpointing_kwargs": {"use_reentrant": True},
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -365,8 +353,6 @@ class TestDPOLlamaLora(unittest.TestCase):
                "gradient_checkpointing_kwargs": {"use_reentrant": True},
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -56,8 +56,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
                "use_tensorboard": True,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir
@@ -65,8 +65,6 @@ class TestFalcon(unittest.TestCase):
                "bf16": "auto",
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -120,8 +118,6 @@ class TestFalcon(unittest.TestCase):
                "bf16": "auto",
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -161,8 +157,6 @@ class TestFalcon(unittest.TestCase):
                "bf16": "auto",
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -10,7 +10,7 @@ from e2e.utils import check_model_output_exists
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl.tests.e2e")
@@ -56,8 +56,6 @@ class TestLlama:
                "save_safetensors": True,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -101,8 +99,6 @@ class TestLlama:
                "save_safetensors": True,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -142,8 +138,6 @@ class TestLlama:
                "save_safetensors": True,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -10,7 +10,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, check_tensorboard
@@ -69,8 +69,6 @@ class TestPretrainLlama:
                "use_tensorboard": True,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir
@@ -62,8 +62,6 @@ class TestLlamaVision(unittest.TestCase):
                "bf16": True,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir
@@ -59,8 +59,6 @@ class TestLoraLlama(unittest.TestCase):
                "max_steps": 20,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -11,7 +11,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir
@@ -59,8 +59,6 @@ class TestMamba(unittest.TestCase):
                "save_safetensors": False,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir
@@ -63,8 +63,6 @@ class TestMistral(unittest.TestCase):
                "eval_steps": 10,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -108,8 +106,6 @@ class TestMistral(unittest.TestCase):
            cfg.bf16 = True
        else:
            cfg.fp16 = True

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir
@@ -69,8 +69,6 @@ class TestMixtral(unittest.TestCase):
                "eval_steps": 10,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -125,8 +123,6 @@ class TestMixtral(unittest.TestCase):
                "eval_steps": 10,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -184,8 +180,6 @@ class TestMixtral(unittest.TestCase):
            cfg.bf16 = True
        else:
            cfg.fp16 = True

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -239,8 +233,6 @@ class TestMixtral(unittest.TestCase):
                "eval_steps": 10,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        if is_torch_bf16_gpu_available():
            cfg.bf16 = True
@@ -289,8 +281,6 @@ class TestMixtral(unittest.TestCase):
            cfg.bf16 = True
        else:
            cfg.fp16 = True

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
@@ -59,8 +59,6 @@ class TestCustomOptimizers(unittest.TestCase):
                "lr_scheduler": "cosine",
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -105,8 +103,6 @@ class TestCustomOptimizers(unittest.TestCase):
                "lr_scheduler": "cosine",
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -143,8 +139,6 @@ class TestCustomOptimizers(unittest.TestCase):
            }
        )
        # pylint: disable=duplicate-code

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_tensorboard, with_temp_dir
@@ -59,8 +59,6 @@ class TestPackedLlama(unittest.TestCase):
            cfg.bf16 = True
        else:
            cfg.fp16 = True

        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, with_temp_dir
@@ -61,7 +61,6 @@ class TestPhi(unittest.TestCase):
                "bf16": "auto",
            }
        )
        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -40,10 +40,8 @@ class TestE2eQwen:
                "field_messages": "conversation",
                "field_chosen": "chosen",
                "field_rejected": "rejected",
                "message_property_mappings": {
                    "role": "role",
                    "content": "content",
                },
                "message_field_role": "role",
                "message_field_content": "content",
                "roles": {
                    "system": ["system"],
                    "user": ["user"],

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -66,7 +66,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
                "use_tensorboard": True,
            }
        )
        cfg = validate_config(cfg)
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -66,36 +66,6 @@ def require_torch_2_5_1(test_case):
    return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)


def require_torch_lt_2_6_0(test_case):
    """
    Decorator marking a test that requires torch < 2.6.0
    """

    def is_max_2_6_0():
        torch_version = version.parse(torch.__version__)
        return torch_version < version.parse("2.6.0")

    return unittest.skipUnless(is_max_2_6_0(), "test requires torch<2.6.0")(test_case)


def require_vllm(test_case):
    """
    Decorator marking a test that requires vllm to be installed
    """

    def is_vllm_installed():
        try:
            import vllm  # pylint: disable=unused-import # noqa: F401

            return True
        except ImportError:
            return False

    return unittest.skipUnless(
        is_vllm_installed(), "test requires a vllm to be installed"
    )(test_case)


def is_hopper():
    compute_capability = torch.cuda.get_device_capability()
    return compute_capability == (9, 0)

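Usage follows the standard unittest skip-decorator pattern (illustrative test name only):

# @require_vllm
# def test_grpo_rollouts(self, temp_dir):
#     ...  # runs only when vllm is importable, skipped otherwise
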
@@ -7,7 +7,6 @@ from datasets import Dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
||||
from axolotl.utils.chat_templates import _CHAT_TEMPLATES
|
||||
|
||||
@@ -175,32 +174,3 @@ def fixture_llama3_2_vision_with_hardcoded_date() -> str:
    modified_template = template.replace(old_date_logic, new_date_logic)

    return modified_template


@pytest.fixture(name="chat_template_jinja_with_optional_fields")
def fixture_chat_template_jinja_with_optional_fields() -> str:
    return """{% for message in messages %}
{{'<|im_start|>'}}{{ message['role'] }}
{% if message['thoughts'] is defined %}[Thoughts: {{ message['thoughts'] }}]{% endif %}
{% if message['tool_calls'] is defined %}[Tool: {{ message['tool_calls'][0]['type'] }}]{% endif %}
{{ message['content'] }}{{'<|im_end|>'}}
{% endfor %}"""


@pytest.fixture(name="basic_jinja_template_analyzer")
def basic_jinja_template_analyzer():
    return JinjaTemplateAnalyzer(
        """{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>
' + message['content'] + '<|end|>
'}}{% elif message['role'] == 'user' %}{{'<|user|>
' + message['content'] + '<|end|>
'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>
' + message['content'] + '<|end|>
'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>
' }}{% else %}{{ eos_token }}{% endif %}"""
    )


@pytest.fixture(name="mistral_jinja_template_analyzer")
def mistral_jinja_template_analyzer(mistralv03_tokenizer_chat_template_jinja):
    return JinjaTemplateAnalyzer(mistralv03_tokenizer_chat_template_jinja)
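As a quick orientation before the analyzer tests later in this diff: JinjaTemplateAnalyzer takes a template string and reports which variables (and which of their properties) the template touches. A tiny sketch with a made-up template:

    analyzer = JinjaTemplateAnalyzer(
        "{% for m in messages %}{{ m.role }}: {{ m.content }}{% endfor %}"
    )
    variables = analyzer.get_template_variables()
    # expected shape, per the tests below: a dict keyed by variable name,
    # e.g. variables["m"] == {"role", "content"} and "messages" in variables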
@@ -38,10 +38,6 @@ class TestAssistantChatTemplateLlama3:
                "chat_template": "llama3",
                "message_field_role": "role",
                "message_field_content": "content",
                "message_property_mappings": {
                    "role": "role",
                    "content": "content",
                },
                "roles": {
                    "user": ["user"],
                    "assistant": ["assistant"],

@@ -78,10 +74,8 @@ class TestAssistantChatTemplateLlama3:
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_template=get_chat_template("llama3"),
                message_property_mappings={
                    "role": "role",
                    "content": "content",
                },
                message_field_role="role",
                message_field_content="content",
                roles={
                    "user": ["user"],
                    "assistant": ["assistant"],

@@ -92,7 +86,7 @@ class TestAssistantChatTemplateLlama3:
            train_on_inputs=False,
            sequence_len=512,
        )

        strategy.messages = "messages"
        res = strategy.tokenize_prompt(assistant_dataset[0])
        input_ids = res["input_ids"]
        # fmt: off

@@ -120,10 +114,8 @@ class TestAssistantChatTemplateLlama3:
            ChatTemplatePrompter(
                phi35_tokenizer,
                chat_template=get_chat_template("phi_35"),
                message_property_mappings={
                    "role": "role",
                    "content": "content",
                },
                message_field_role="role",
                message_field_content="content",
                roles={
                    "user": ["user"],
                    "assistant": ["assistant"],

@@ -134,7 +126,7 @@ class TestAssistantChatTemplateLlama3:
            train_on_inputs=False,
            sequence_len=512,
        )

        strategy.messages = "messages"
        res = strategy.tokenize_prompt(assistant_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]

@@ -178,11 +170,9 @@ class TestAssistantChatTemplateLlama3:
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_template=get_chat_template("llama3"),
                message_field_role="role",
                message_field_content="content",
                message_field_training="training",
                message_property_mappings={
                    "role": "role",
                    "content": "content",
                },
                roles={
                    "user": ["user"],
                    "assistant": ["assistant"],

@@ -195,7 +185,7 @@ class TestAssistantChatTemplateLlama3:
            sequence_len=512,
            roles_to_train=["assistant"],
        )

        strategy.messages = "messages"
        prompt_tokens = strategy.prompter.build_prompt(
            assistant_dataset[0]["messages"], False
        )
@@ -240,11 +230,8 @@ class TestSharegptChatTemplateLlama3:
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_template=get_chat_template("llama3"),
                message_property_mappings={
                    "role": "from",
                    "content": "value",
                },
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,

@@ -252,7 +239,7 @@ class TestSharegptChatTemplateLlama3:
            sequence_len=512,
            roles_to_train=["gpt"],
        )

        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(sharegpt_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]

@@ -300,11 +287,8 @@ class TestSharegptChatTemplateLlama3:
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_template=get_chat_template("llama3"),
                message_property_mappings={
                    "role": "from",
                    "content": "value",
                },
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,

@@ -312,7 +296,7 @@ class TestSharegptChatTemplateLlama3:
            sequence_len=512,
            roles_to_train=["human"],
        )

        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(sharegpt_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]

@@ -360,11 +344,8 @@ class TestSharegptChatTemplateLlama3:
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_template=get_chat_template("llama3"),
                message_property_mappings={
                    "role": "from",
                    "content": "value",
                },
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,

@@ -372,7 +353,7 @@ class TestSharegptChatTemplateLlama3:
            sequence_len=512,
            roles_to_train=["system", "human"],
        )

        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(basic_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]
@@ -436,7 +417,8 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
                chat_template=get_chat_template(
                    "jinja", jinja_template=llama3_2_vision_chat_template_jinja
                ),
                message_property_mappings={"role": "role", "content": "content"},
                message_field_role="role",
                message_field_content="content",
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,

@@ -504,7 +486,8 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
                chat_template=get_chat_template(
                    "jinja", jinja_template=llama3_2_vision_chat_template_jinja
                ),
                message_property_mappings={"role": "role", "content": "content"},
                message_field_role="role",
                message_field_content="content",
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
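The recurring edit across all of these hunks is a single migration: the legacy message_field_role / message_field_content keyword arguments are folded into one message_property_mappings dict. Schematically, with the other constructor arguments elided:

    # before
    ChatTemplatePrompter(
        tokenizer,
        chat_template=get_chat_template("llama3"),
        message_field_role="from",
        message_field_content="value",
    )

    # after
    ChatTemplatePrompter(
        tokenizer,
        chat_template=get_chat_template("llama3"),
        message_property_mappings={"role": "from", "content": "value"},
    )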
@@ -3,6 +3,7 @@ tests for chat_template prompt strategy
"""

import logging
import unittest
from copy import deepcopy

import pytest
@@ -122,15 +123,15 @@ class TestChatTemplateConfigurations:
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "from", "content": "value"},
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=tokenizer,
            train_on_inputs=True,
            sequence_len=512,
            roles_to_train=["assistant"],
        )

        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(basic_dataset[0])
        turns = strategy.get_conversation_thread(basic_dataset[0])
        labels = res["labels"]

@@ -179,15 +180,15 @@ class TestChatTemplateConfigurations:
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "from", "content": "value"},
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,
            sequence_len=512,
            roles_to_train=["assistant"],
        )

        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(basic_dataset[0])
        turns = strategy.get_conversation_thread(basic_dataset[0])
        labels = res["labels"]
@@ -240,15 +241,20 @@ class TestChatTemplateConfigurations:
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "from", "content": "value"},
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,
            sequence_len=512,
            roles_to_train=["assistant", "human"],
        )
        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(basic_dataset[0])
        labels = res["labels"]
        input_ids = res["input_ids"]

        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(basic_dataset[0])
        turns = strategy.get_conversation_thread(basic_dataset[0])
        labels = res["labels"]
@@ -301,15 +307,15 @@ class TestChatTemplateConfigurations:
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "from", "content": "value"},
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=tokenizer,
            train_on_inputs=True,
            sequence_len=512,
            roles_to_train=["human", "assistant"],
        )

        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(basic_dataset[0])
        turns = strategy.get_conversation_thread(basic_dataset[0])
        labels = res["labels"]
@@ -354,8 +360,8 @@ class TestChatTemplateConfigurations:
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "from", "content": "value"},
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,

@@ -363,7 +369,7 @@ class TestChatTemplateConfigurations:
            roles_to_train=[],
            train_on_eos="none",  # Add this line
        )

        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(basic_dataset[0])
        labels = res["labels"]

@@ -394,8 +400,8 @@ class TestChatTemplateConfigurations:
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "from", "content": "value"},
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,

@@ -403,7 +409,7 @@ class TestChatTemplateConfigurations:
            roles_to_train=["assistant"],
            train_on_eos="all",
        )

        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(basic_dataset[0])
        labels = res["labels"]
        input_ids = res["input_ids"]

@@ -440,8 +446,8 @@ class TestChatTemplateConfigurations:
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "from", "content": "value"},
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,

@@ -449,6 +455,7 @@ class TestChatTemplateConfigurations:
            roles_to_train=["assistant"],
            train_on_eos="turn",
        )
        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(basic_dataset[0])
        turns = strategy.get_conversation_thread(basic_dataset[0])
        labels = res["labels"]

@@ -519,8 +526,8 @@ class TestChatTemplateConfigurations:
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "from", "content": "value"},
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,

@@ -528,7 +535,7 @@ class TestChatTemplateConfigurations:
            roles_to_train=["assistant"],
            train_on_eos="last",
        )

        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(basic_dataset[0])
        labels = res["labels"]
        input_ids = res["input_ids"]

@@ -571,8 +578,8 @@ class TestChatTemplateConfigurations:
                chat_template=get_chat_template(
                    chat_template, jinja_template=chat_template_jinja
                ),
                message_property_mappings={"role": "from", "content": "value"},
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,

@@ -580,7 +587,7 @@ class TestChatTemplateConfigurations:
            roles_to_train=["assistant"],
            train_on_eos="none",
        )

        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(basic_dataset[0])
        labels = res["labels"]
        input_ids = res["input_ids"]
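These hunks exercise the train_on_eos setting, which the tests drive through the values "none", "all", "turn", and "last". A rough sketch of the masking intuition, heavily simplified from the actual strategy logic (the helper below is illustrative, not the repo's code):

    IGNORE_INDEX = -100  # HF convention: tokens with this label are excluded from the loss

    def mask_eos_labels(labels, eos_positions, trained_turn_eos, train_on_eos):
        # eos_positions: indices of every EOS token in the sequence
        # trained_turn_eos: EOS indices belonging to turns that are trained on
        if train_on_eos == "all":
            keep = set(eos_positions)
        elif train_on_eos == "turn":
            keep = set(trained_turn_eos)
        elif train_on_eos == "last":
            keep = {eos_positions[-1]} if eos_positions else set()
        else:  # "none"
            keep = set()
        for pos in eos_positions:
            if pos not in keep:
                labels[pos] = IGNORE_INDEX
        return labels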
@@ -617,15 +624,15 @@ class TestChatTemplateConfigurations:
                    chat_template, jinja_template=chat_template_jinja
                ),
                drop_system_message=True,
                message_property_mappings={"role": "from", "content": "value"},
                field_messages="conversations",
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,
            sequence_len=512,
            roles_to_train=["assistant"],
        )

        strategy.messages = "conversations"
        res = strategy.tokenize_prompt(basic_dataset[0])
        input_ids = res["input_ids"]

@@ -661,7 +668,8 @@ class TestChatTemplateConfigurations:
                    chat_template, jinja_template=chat_template_jinja
                ),
                roles=custom_roles,
                message_property_mappings={"role": "from", "content": "value"},
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,

@@ -733,7 +741,8 @@ class TestChatTemplateConfigurations:
                ),
                message_field_training="train",
                message_field_training_detail="train_detail",
                message_property_mappings={"role": "from", "content": "value"},
                message_field_role="from",
                message_field_content="value",
            ),
            tokenizer=tokenizer,
            train_on_inputs=False,

@@ -902,64 +911,6 @@ class TestChatTemplateConfigurations:
        LOG.debug(f"Final labels: {labels}")
        LOG.debug(f"Final input_ids: {input_ids}")
    def test_get_chat_template_variables(
        self, tokenizer, chat_template, chat_template_jinja, eos_token, request
    ):
        LOG.info("Testing get_chat_template_variables")

        actual_tokenizer, actual_jinja_template = self.setup_tokenizer(
            tokenizer, chat_template, chat_template_jinja, eos_token, request
        )

        prompter = ChatTemplatePrompter(
            actual_tokenizer,
            chat_template=get_chat_template(
                chat_template, jinja_template=actual_jinja_template
            ),
            message_property_mappings={"from": "role", "value": "content"},
        )

        variables = prompter.get_chat_template_msg_variables(
            actual_jinja_template
            if actual_jinja_template
            else actual_tokenizer.get_chat_template(),
            "messages",
        )

        if chat_template == "llama3":
            assert variables == {"role", "content"}, (
                f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
                f"Got: {variables}\n"
                f"Chat template: {actual_jinja_template}"
            )
        elif chat_template == "chatml":
            assert variables == {"role", "content"}, (
                f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
                f"Got: {variables}\n"
                f"Chat template: {actual_jinja_template}"
            )
        elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
            assert variables == {"role", "content", "tool_call_id", "tool_calls"}, (
                f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n"
                f"Got: {variables}\n"
                f"Chat template: {actual_jinja_template}"
            )
        elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer":
            assert variables == {"role", "content"}, (
                f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
                f"Got: {variables}\n"
                f"Chat template: {actual_jinja_template}"
            )
        elif chat_template == "phi_35":
            assert variables == {"role", "content"}, (
                f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
                f"Got: {variables}\n"
                f"Chat template: {actual_jinja_template}"
            )
        else:
            LOG.warning(
                f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
            )
            raise ValueError(
                f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
            )


if __name__ == "__main__":
    unittest.main()
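For readers skimming the removed analyzer tests below: the core idea, discovering which variables a Jinja template references, can be approximated with jinja2's built-in meta module (a simplified stand-in for illustration; the repo's JinjaTemplateAnalyzer additionally tracks per-variable property access):

    from jinja2 import Environment, meta

    env = Environment()
    ast = env.parse("{% for message in messages %}{{ message.content }}{% endfor %}")
    print(meta.find_undeclared_variables(ast))  # {'messages'}; loop targets are excluded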
@@ -1,159 +0,0 @@
"""
tests for jinja_template_analyzer
"""
import logging

import pytest

from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")


class TestJinjaTemplateAnalyzer:
    """
    tests for jinja_template_analyzer
    """

    def test_basic_variable_extraction(self, basic_jinja_template_analyzer):
        """Test that all top-level variables are correctly extracted."""
        LOG.info("Testing with train_on_inputs=True")

        variables = basic_jinja_template_analyzer.get_template_variables()
        expected_vars = {"messages", "add_generation_prompt", "eos_token", "message"}
        assert set(variables.keys()) == expected_vars

    def test_mixtral_variable_extraction(self, mistral_jinja_template_analyzer):
        """Test that all top-level variables are correctly extracted."""
        LOG.info("Testing with train_on_inputs=True")

        variables = mistral_jinja_template_analyzer.get_template_variables()
        expected_vars = {
            "messages",
            "content",
            "eos_token",
            "message",
            "tools",
            "system_message",
            "loop_messages",
            "ns",
            "tool_call",
            "tool",
            "loop",
            "bos_token",
            "raise_exception",
        }
        assert set(variables.keys()) == expected_vars
        message_vars = variables["message"]
        assert message_vars == {"role", "content", "tool_calls", "tool_call_id"}

    def test_message_property_access(self, basic_jinja_template_analyzer):
        """Test that properties accessed on 'message' variable are correctly identified."""
        LOG.info("Testing message property access")

        variables = basic_jinja_template_analyzer.get_template_variables()
        assert "messages" in variables
        assert "message" in variables
        assert "role" in variables["message"]
        assert "content" in variables["message"]

    def test_detailed_analysis(self, basic_jinja_template_analyzer):
        """Test the detailed analysis of variable usage."""
        LOG.info("Testing detailed analysis")

        analysis = basic_jinja_template_analyzer.analyze_template()

        assert analysis["messages"]["is_iterated"] is True
        assert "role" in analysis["message"]["accessed_properties"]
        assert "content" in analysis["message"]["accessed_properties"]

        assert analysis["add_generation_prompt"]["is_conditional"] is True
        assert len(analysis["add_generation_prompt"]["accessed_properties"]) == 0

        assert not analysis["eos_token"]["is_iterated"]
        assert len(analysis["eos_token"]["accessed_properties"]) == 0

    def test_nested_property_access(self):
        """Test handling of nested property access."""
        LOG.info("Testing nested property access")

        template = """{{ user.profile.name }}{{ user.settings['preference'] }}"""
        analyzer = JinjaTemplateAnalyzer(template)
        variables = analyzer.get_template_variables()

        assert "user" in variables
        assert "profile" in variables["user"]
        assert "settings" in variables["user"]

    def test_loop_variable_handling(self):
        """Test handling of loop variables and their properties."""
        LOG.info("Testing loop variable handling")

        template = """
        {% for item in items %}
            {{ item.name }}
            {% for subitem in item.subitems %}
                {{ subitem.value }}
            {% endfor %}
        {% endfor %}
        """
        analyzer = JinjaTemplateAnalyzer(template)
        analysis = analyzer.analyze_template()

        assert analysis["items"]["is_iterated"]
        assert "name" in analysis["item"]["accessed_properties"]
        assert "subitems" in analysis["item"]["accessed_properties"]

    def test_conditional_variable_usage(self):
        """Test detection of variables used in conditional statements."""
        LOG.info("Testing conditional variable usage")

        template = """
        {% if user.is_admin and config.debug_mode %}
            {{ debug_info }}
        {% endif %}
        """
        analyzer = JinjaTemplateAnalyzer(template)
        analysis = analyzer.analyze_template()

        assert analysis["user"]["is_conditional"]
        assert analysis["config"]["is_conditional"]
        assert "is_admin" in analysis["user"]["accessed_properties"]
        assert "debug_mode" in analysis["config"]["accessed_properties"]

    def test_complex_expressions(self):
        """Test handling of complex expressions and filters."""
        LOG.info("Testing complex expressions and filters")

        template = """
        {{ user.name | upper }}
        {{ messages | length > 0 and messages[0].content }}
        {{ data['key'].nested['value'] }}
        """
        analyzer = JinjaTemplateAnalyzer(template)
        variables = analyzer.get_template_variables()

        assert "user" in variables
        assert "name" in variables["user"]
        assert "messages" in variables
        assert "content" in variables["messages"]
        assert "data" in variables

    def test_basic_msg_vars(self, basic_jinja_template_analyzer):
        """Test that the basic message variables are correctly identified."""
        LOG.info("Testing basic message variables")

        variables = basic_jinja_template_analyzer.get_message_vars()
        assert variables == {"role", "content"}

    def test_mixtral_msg_vars(self, mistral_jinja_template_analyzer):
        """Test that the mixtral message variables are correctly identified."""
        LOG.info("Testing mixtral message variables")

        variables = mistral_jinja_template_analyzer.get_message_vars()
        assert variables == {"role", "content", "tool_calls", "tool_call_id"}


if __name__ == "__main__":
    pytest.main([__file__])
@@ -302,22 +302,3 @@ class TestValidationCheckDatasetConfig(BaseValidation):
        )

        validate_config(cfg)

    def test_message_property_mappings(self, minimal_cfg):
        cfg = DictDefault(
            minimal_cfg
            | {
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                        "message_property_mappings": {
                            "role": "role",
                            "content": "content",
                        },
                    }
                ],
            }
        )

        validate_config(cfg)
@@ -76,7 +76,7 @@ class TestModelsUtils:
        mocked_load_model_config.return_value = {}
        with pytest.raises(ValueError) as exc:
            # Should error before hitting tokenizer, so we pass in an empty str
            load_model(cfg, tokenizer="")  # type: ignore
            load_model(cfg, tokenizer="")
        assert (
            "shifted-sparse attention does not currently support sample packing"
            in str(exc.value)

@@ -116,79 +116,3 @@ class TestModelsUtils:
        assert self.model_loader.model_kwargs.get(
            "quantization_config", BitsAndBytesConfig
        )

    def test_message_property_mapping(self):
        """Test message property mapping configuration validation"""
        from axolotl.utils.config.models.input.v0_4_1 import SFTDataset

        # Test legacy fields are mapped correctly
        dataset = SFTDataset(
            path="test_path",
            message_field_role="role_field",
            message_field_content="content_field",
        )
        assert dataset.message_property_mappings == {
            "role": "role_field",
            "content": "content_field",
        }

        # Test direct message_property_mapping works
        dataset = SFTDataset(
            path="test_path",
            message_property_mappings={
                "role": "custom_role",
                "content": "custom_content",
            },
        )
        assert dataset.message_property_mappings == {
            "role": "custom_role",
            "content": "custom_content",
        }

        # Test both legacy and new fields work when they match
        dataset = SFTDataset(
            path="test_path",
            message_field_role="same_role",
            message_property_mappings={"role": "same_role"},
        )
        assert dataset.message_property_mappings == {
            "role": "same_role",
            "content": "content",
        }

        # Test both legacy and new fields work when they don't overlap
        dataset = SFTDataset(
            path="test_path",
            message_field_role="role_field",
            message_property_mappings={"content": "content_field"},
        )
        assert dataset.message_property_mappings == {
            "role": "role_field",
            "content": "content_field",
        }

        # Test no role or content provided
        dataset = SFTDataset(
            path="test_path",
        )
        assert dataset.message_property_mappings == {
            "role": "role",
            "content": "content",
        }

        # Test error when legacy and new fields conflict
        with pytest.raises(ValueError) as exc_info:
            SFTDataset(
                path="test_path",
                message_field_role="legacy_role",
                message_property_mappings={"role": "different_role"},
            )
        assert "Conflicting message role fields" in str(exc_info.value)

        with pytest.raises(ValueError) as exc_info:
            SFTDataset(
                path="test_path",
                message_field_content="legacy_content",
                message_property_mappings={"content": "different_content"},
            )
        assert "Conflicting message content fields" in str(exc_info.value)