Compare commits

..

4 Commits

Author SHA1 Message Date
sunny
66a1e209e3 changed yml 2024-10-18 10:46:36 -04:00
sunny
dba033cb5b fixed yml for issue1947 testing 2024-10-17 12:31:11 -04:00
sunny
6a32e9a0df added yml for testing issue 1947 2024-10-17 11:26:32 -04:00
sunny
4afb2656b3 added yml for testing issue 1947 2024-10-17 11:23:47 -04:00
69 changed files with 1258 additions and 2937 deletions

View File

@@ -36,12 +36,6 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3

View File

@@ -29,11 +29,6 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -91,11 +86,6 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -21,17 +21,10 @@ jobs:
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
- cuda: 124
cuda_version: 12.4.1
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"

View File

@@ -28,11 +28,6 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -90,11 +85,6 @@ jobs:
python_version: "3.11"
pytorch: 2.4.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -27,7 +27,7 @@ jobs:
run: |
pip3 install wheel packaging
pip3 install -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pip3 install -r requirements-tests.txt
- name: Extract tag name
id: tag

View File

@@ -25,7 +25,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
pytorch_version: ["2.3.1", "2.4.1"]
timeout-minutes: 20
steps:
@@ -47,14 +47,13 @@ jobs:
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pip3 install -r requirements-tests.txt
- name: Run tests
run: |
@@ -82,17 +81,17 @@ jobs:
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.3.1
num_gpus: 1
axolotl_extras:
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"

View File

@@ -36,7 +36,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
pytorch_version: ["2.3.1", "2.4.1"]
timeout-minutes: 20
steps:
@@ -49,20 +49,16 @@ jobs:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
- name: Install dependencies
run: |
pip3 show torch
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pip3 install -r requirements-tests.txt
- name: Run tests
run: |
@@ -72,17 +68,29 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests-1st:
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
timeout-minutes: 60
needs: [pre-commit, pytest]
strategy:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -111,49 +119,3 @@ jobs:
- name: Run tests job on Modal
run: |
modal run cicd.tests
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest, docker-e2e-tests-1st]
strategy:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests

View File

@@ -562,8 +562,7 @@ plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
```

View File

@@ -23,11 +23,11 @@ RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
@@ -37,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
fi
# So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
RUN pip install -r requirements-tests.txt
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -1,6 +1,6 @@
#!/bin/bash
set -e
pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -64,7 +64,7 @@ def run_cmd(cmd: str, run_folder: str):
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
timeout=45 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
)

View File

@@ -65,7 +65,7 @@ def run_cmd(cmd: str, run_folder: str):
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
timeout=45 * 60,
cpu=8.0,
memory=131072,
)

View File

@@ -14,6 +14,15 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -24,6 +24,15 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -20,6 +20,15 @@
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",

View File

@@ -7,8 +7,8 @@ load_in_8bit: true
load_in_4bit: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
- path: philschmid/guanaco-sharegpt-style
type: sharegpt
shards: 10
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model

View File

@@ -20,6 +20,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -83,7 +83,7 @@ lora_on_cpu: true
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
@@ -124,48 +124,6 @@ datasets:
# For `completion` datsets only, uses the provided field instead of `text` column
field:
# Using chat template
- path: ...
# Set type to `chat_template` to use this strategy
type: chat_template
# Specify the name of the chat template to use
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
chat_template: tokenizer_default
# Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
chat_template_jinja:
# The key in the data example that contains the messages. Default is "messages".
field_messages: messages
# The key in the message turn that contains the role. Default is "role".
message_field_role: role
# The key in the message turn that contains the content. Default is "content".
message_field_content: content
# Optional[Dict[str, List]]. Roles mapping for the messages.
roles:
user: ["human", "user"]
assistant: ["gpt", "assistant", "ai"]
system: ["system"]
## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["gpt", "assistant"]
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
# - all: train on all EOS tokens
# - turn: train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
train_on_eos: last
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
# See example at `docs/dataset-formats/conversation.qmd`
message_field_training_detail: train_detail
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
@@ -184,16 +142,9 @@ test_datasets:
# use RL training: 'dpo', 'ipo', 'kto'
rl:
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
# Currently supports chatml and inst (mistral/mixtral)
chat_template: chatml
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so

View File

@@ -6,8 +6,6 @@ order: 3
## sharegpt
UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.
conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
```{.json filename="data.jsonl"}
@@ -71,138 +69,3 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f
```{.json filename="data.jsonl"}
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
```
## chat_template
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "content": "..."}]}
```
See `config.qmd` for full configs and supported templates.
### Migrating from sharegpt
Most configs can be adapted as follows:
```yaml
# old
chat_template: chatml
datasets:
- path: ...
type: sharegpt
conversation: chatml
# new (if using tokenizer's chat_template)
datasets:
- path: ...
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
datasets:
- path: ...
type: chat_template
field_messages: conversations
message_field_role: from
message_field_content: value
```
We recommend checking the below examples for other usecases.
### Examples
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
```yaml
datasets:
- path: ...
type: chat_template
```
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: gemma # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
```yaml
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
```yaml
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
datasets:
- path: ...
type: chat_template
roles_to_train: ["assistant"]
```
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
```{.json filename="data.jsonl"}
{
"conversations": [
{"from": "system", "value": "You are an AI assistant.", "train": false},
{"from": "human", "value": "Hello", "train": false},
{"from": "assistant", "value": "Hello", "train": true},
{"from": "human", "value": "How are you?", "train": true},
{
"from": "assistant",
"value": "I'm doing very well, thank you!",
"train_detail": [
{"begin_offset": 0, "end_offset": 8, "train": false},
{"begin_offset": 9, "end_offset": 18, "train": true},
{"begin_offset": 19, "end_offset": 30, "train": false},
],
},
{
"from": "human",
"value": "I'm doing very well, thank you!",
"train": true,
},
{"from": "assistant", "value": "Hi there!", "train": true}
]
}
```
The configuration would look like:
```yaml
datasets:
- path: ...
type: chat_template
chat_template: tokenizer_default
field_messages: conversations
message_field_role: from
message_field_content: value
roles_to_train: []
train_on_eos: turn
message_field_training: train
message_field_training_detail: train_detail
```
Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.

View File

@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.
### Background
The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:
The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:
```yaml
datasets:
- path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
type: chat_template
- path: <path to your sharegpt formatted dataset> # example on HF Hub: philschmid/guanaco-sharegpt-style
type: sharegpt
```
>[!Important]
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
```jsonc
// .vscode/launch.json
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"version": "0.2.0",
"configurations": [
{
"name": "Debug axolotl prompt - chat_template",
"name": "Debug axolotl prompt - sharegpt",
"type": "python",
"module": "accelerate.commands.launch",
"request": "launch",
"args": [
"-m", "axolotl.cli.train", "dev_chat_template.yml",
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process
@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
</div>
<br>
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).

View File

@@ -16,10 +16,7 @@ chat_template: deepseek_v2
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -11,11 +11,8 @@ chat_template: gemma
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: gemma
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
val_set_size: 0.0
output_dir: ./outputs/out

View File

@@ -4,15 +4,11 @@ tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
use_tensorboard: true
chat_template: jamba
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: jamba
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft

View File

@@ -0,0 +1,50 @@
base_model: meta-llama/Llama-3.1-8B-Instruct
save_safetensors: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: ./last_run_prepared
output_dir: ./outputs/fft-out
sequence_len: 8192
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
learning_rate: 2e-5
bf16: auto
fp16:
tf32: false
logging_steps: 2
xformers_attention:
flash_attention: true
warmup_steps: 2
evals_per_epoch: 2
save_steps: 2
max_steps: 2
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: true
fsdp_cpu_ram_efficient_loading: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -4,26 +4,28 @@ plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
strict: false
chat_template: llama3
datasets:
- path: tatsu-lab/alpaca
type: alpaca
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
dataset_prepared_path: last_run_prepared
val_set_size: 0
val_set_size: 0.02
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project: check_liger_hf_GA_llama_fix-3
wandb_entity: axolotl-ai
wandb_project:
wandb_entity:
wandb_watch:
wandb_name: pr/fix333-tr4.46.1
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4

View File

@@ -1,77 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"

View File

@@ -10,6 +10,7 @@ chat_template: phi_3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: phi_3
field_messages: messages
message_field_role: role
message_field_content: content

View File

@@ -2,4 +2,3 @@ pre-commit
black
mypy
types-requests
tbparse

View File

@@ -1,12 +1,12 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.46.1
transformers==4.45.2
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.0.1
datasets==3.0.1
deepspeed==0.15.3
deepspeed==0.14.4
pydantic==2.6.3
addict
fire
@@ -16,7 +16,7 @@ flash-attn==2.6.3
sentencepiece
wandb
einops
xformers>=0.0.23.post1
xformers==0.0.28.post1
optimum==1.16.2
hf_transfer
colorama
@@ -33,8 +33,8 @@ gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
triton>=3.1.0
liger-kernel==0.3.1
triton>=2.3.0
liger-kernel==0.3.0
mamba-ssm==1.2.0.post1
@@ -43,7 +43,7 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
trl==0.9.6
zstandard==0.22.0
fastcore

View File

@@ -31,8 +31,6 @@ def parse_requirements():
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
@@ -52,16 +50,10 @@ def parse_requirements():
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if (major, minor) >= (2, 4):
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0:
@@ -81,6 +73,7 @@ def parse_requirements():
except PackageNotFoundError:
pass
return _install_requires, _dependency_links
@@ -109,7 +102,6 @@ setup(
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",

View File

@@ -30,7 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -272,7 +272,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
chat_template_str = chat_templates(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -462,12 +462,7 @@ def load_datasets(
processor=processor,
)
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(

View File

@@ -23,7 +23,7 @@ class TrainerCliArgs:
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)

View File

@@ -7,7 +7,6 @@ import abc
import gc
import importlib
import importlib.util
import inspect
import logging
import math
import os
@@ -28,6 +27,7 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
EarlyStoppingCallback,
PreTrainedModel,
Trainer,
TrainerCallback,
TrainingArguments,
@@ -48,7 +48,6 @@ from trl import (
)
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
@@ -64,7 +63,7 @@ from axolotl.utils.callbacks import (
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
@@ -667,9 +666,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
@@ -677,18 +674,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
return super().compute_loss(model, inputs, return_outputs=return_outputs)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
@@ -784,13 +771,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
).squeeze(2)
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
def orpo_compute_loss(
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument
):
def orpo_compute_loss(self, model, inputs, return_outputs=False):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs,
label_pad_token=-100,
@@ -917,7 +898,6 @@ class AxolotlMambaTrainer(AxolotlTrainer):
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
@@ -1025,32 +1005,18 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs)
def tokenize_row(
self,
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
) -> Dict:
res = super().tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
res = super().tokenize_row(feature, model=model)
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
return res
def training_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
loss: torch.Tensor = super().training_step(model, inputs)
gc.collect()
torch.cuda.empty_cache()
return loss
@@ -1148,28 +1114,17 @@ class TrainerBuilderBase(abc.ABC):
def get_callbacks(self) -> List[TrainerCallback]:
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow and is_mlflow_available():
from transformers.integrations.integration_utils import MLflowCallback
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.extend(
[
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
MLflowCallback,
]
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_comet and is_comet_available():
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
@@ -1180,17 +1135,11 @@ class TrainerBuilderBase(abc.ABC):
return callbacks
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
@@ -1236,7 +1185,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
@@ -1607,9 +1556,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template,
tokenizer=self.tokenizer,
training_arguments_kwargs["chat_template"] = chat_templates(
self.cfg.chat_template
)
if self.cfg.rl == "orpo":
@@ -1714,17 +1662,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return_tensors="pt",
**data_collator_kwargs,
)
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
tokenizer=self.tokenizer,
data_collator=self.build_collator(training_args, **data_collator_kwargs),
callbacks=self.get_callbacks(),
**trainer_kwargs,
@@ -1765,8 +1708,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
@@ -1804,7 +1745,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
callbacks = []
return callbacks
def build_training_arguments(self, total_num_steps):
@@ -1969,7 +1910,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
dpo_trainer_kwargs["generate_during_eval"] = True
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
@@ -1981,17 +1922,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
sig = inspect.signature(trainer_cls)
if "processing_class" in sig.parameters.keys():
dpo_trainer_kwargs["processing_class"] = self.tokenizer
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
tokenizer=self.tokenizer,
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
)
@@ -2013,11 +1948,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks = []
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
callbacks = []
return callbacks
def build(self, total_num_steps):

View File

@@ -18,10 +18,9 @@ Plugins can be used to integrate third-party models, modify the training process
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
import collections
import importlib
import logging
from typing import OrderedDict
from typing import List
class BasePlugin:
@@ -48,7 +47,7 @@ class BasePlugin:
Initializes the BasePlugin.
"""
def register(self, cfg): # pylint: disable=unused-argument
def register(self, cfg):
"""
Registers the plugin with the given configuration.
@@ -64,7 +63,7 @@ class BasePlugin:
Returns a pydantic model for the plugin's input arguments.
"""
def pre_model_load(self, cfg): # pylint: disable=unused-argument
def pre_model_load(self, cfg):
"""
Performs actions before the model is loaded.
@@ -75,7 +74,7 @@ class BasePlugin:
None
"""
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
def post_model_load(self, cfg, model):
"""
Performs actions after the model is loaded.
@@ -87,7 +86,7 @@ class BasePlugin:
None
"""
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
def pre_lora_load(self, cfg, model):
"""
Performs actions before LoRA weights are loaded.
@@ -99,7 +98,7 @@ class BasePlugin:
None
"""
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
def post_lora_load(self, cfg, model):
"""
Performs actions after LoRA weights are loaded.
@@ -111,7 +110,7 @@ class BasePlugin:
None
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
def create_optimizer(self, cfg, trainer):
"""
Creates and returns an optimizer for training.
@@ -123,9 +122,7 @@ class BasePlugin:
object: The created optimizer.
"""
def create_lr_scheduler(
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
def create_lr_scheduler(self, cfg, trainer, optimizer):
"""
Creates and returns a learning rate scheduler.
@@ -138,7 +135,7 @@ class BasePlugin:
object: The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
def add_callbacks_pre_trainer(self, cfg, model):
"""
Adds callbacks to the trainer before training.
@@ -149,11 +146,8 @@ class BasePlugin:
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def add_callbacks_post_trainer(
self, cfg, trainer
): # pylint: disable=unused-argument
def add_callbacks_post_trainer(self, cfg, trainer):
"""
Adds callbacks to the trainer after training.
@@ -164,9 +158,8 @@ class BasePlugin:
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def post_train(self, cfg, model): # pylint: disable=unused-argument
def post_train(self, cfg, model):
"""
Performs actions after training is complete.
@@ -178,7 +171,7 @@ class BasePlugin:
None
"""
def post_train_unload(self, cfg): # pylint: disable=unused-argument
def post_train_unload(self, cfg):
"""
Performs actions after training is complete and the model is unloaded.
@@ -234,7 +227,7 @@ class PluginManager:
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
plugins: List[BasePlugin] = []
_instance = None
@@ -244,7 +237,7 @@ class PluginManager:
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins = collections.OrderedDict()
cls._instance.plugins: List[BasePlugin] = []
return cls._instance
@staticmethod
@@ -272,7 +265,7 @@ class PluginManager:
"""
try:
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
self.plugins.append(plugin)
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
@@ -284,7 +277,7 @@ class PluginManager:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
"""
input_args = []
for plugin in self.plugins.values():
for plugin in self.plugins:
input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin)
@@ -300,7 +293,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.pre_model_load(cfg)
def post_model_load(self, cfg, model):
@@ -314,7 +307,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model):
@@ -328,7 +321,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model):
@@ -342,7 +335,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.post_lora_load(cfg, model)
def create_optimizer(self, cfg, trainer):
@@ -356,7 +349,7 @@ class PluginManager:
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None:
return optimizer
@@ -374,7 +367,7 @@ class PluginManager:
Returns:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None:
return scheduler
@@ -392,7 +385,7 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins.values():
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
return callbacks
@@ -408,7 +401,7 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins.values():
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks
@@ -423,5 +416,5 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.post_train_unload(cfg)

View File

@@ -18,23 +18,20 @@ Module for the Plugin for LIGER integraton with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
import inspect
import logging
import sys
from functools import partial
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.liger")
class LigerPlugin(BasePlugin):
"""
@@ -45,31 +42,59 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
kwargs = {}
if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs[
"fused_linear_cross_entropy"
] = cfg.liger_fused_linear_cross_entropy
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:
kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
with zero_only():
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
if cfg.model_config_type == "llama":
from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward,
)
from transformers.models.llama import modeling_llama
if cfg.liger_rope:
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_llama.LlamaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_llama.LlamaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
elif cfg.liger_fused_linear_cross_entropy:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
elif cfg.model_config_type == "mistral":
from liger_kernel.transformers.model.mistral import (
lce_forward as mistral_lce_forward,
)
from transformers.models.mistral import modeling_mistral
if cfg.liger_rope:
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_mistral.MistralRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma
if cfg.liger_rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma.GemmaRMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
apply_liger_fn(**kwargs)
if cfg.liger_swiglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
@@ -79,12 +104,30 @@ class LigerPlugin(BasePlugin):
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
if cfg.liger_swiglu:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "qwen2":
from liger_kernel.transformers.model.qwen2 import (
lce_forward as qwen2_lce_forward,
)
from transformers.models.qwen2 import modeling_qwen2
if cfg.liger_rope:
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
@@ -103,9 +146,44 @@ class LigerPlugin(BasePlugin):
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
if cfg.liger_swiglu:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_cross_entropy:
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "gemma2":
from transformers.models.gemma2 import modeling_gemma2
if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma2.Gemma2RMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma 2."
)
elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3
if cfg.liger_rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

View File

@@ -15,12 +15,9 @@
"""
Module for handling LIGER input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.liger.args")
from pydantic import BaseModel
class LigerArgs(BaseModel):
@@ -30,24 +27,6 @@ class LigerArgs(BaseModel):
liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def check_deprecated_swiglu(cls, data):
if data.get("liger_swiglu") is not None:
if data.get("liger_glu_activation") is not None:
raise ValueError(
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
)
LOG.warning(
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
"Please use 'liger_glu_activation' instead."
)
data["liger_glu_activation"] = data.pop("liger_swiglu")
return data

View File

@@ -22,6 +22,7 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb,
repeat_kv,
)
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
@@ -43,19 +44,7 @@ except ImportError:
LOG = logging.getLogger("axolotl")
def is_xformers_available() -> bool:
try:
import xformers # pylint: disable=unused-import # noqa: F401
return True
except ImportError:
return False
def is_xformers_swiglu_available() -> bool:
if not is_xformers_available():
return False
from xformers.ops.common import get_xformers_operator
try:
@@ -68,11 +57,6 @@ def is_xformers_swiglu_available() -> bool:
def replace_llama_mlp_with_swiglu(model):
if is_xformers_swiglu_available():
from axolotl.monkeypatch.xformers_ import FusedMLP
else:
raise RuntimeError("xformers SwiGLU not available for this environment")
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
mlp = FusedMLP(
@@ -197,6 +181,49 @@ class FusedAttention(LlamaAttention):
set_module_name(model, name, new_attn)
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(

View File

@@ -16,6 +16,26 @@ from transformers.models.llama.modeling_llama import (
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
"""
PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
)
"""
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
@@ -60,6 +80,12 @@ def get_forward_code() -> str:
return forward
def check_cel_is_patchable() -> bool:
forward = get_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_CEL_CODE in forward
def get_self_attn_code() -> str:
forward = inspect.getsource(LlamaFlashAttention2.forward)
return forward
@@ -72,31 +98,48 @@ def check_self_attn_is_patchable() -> bool:
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
def UnslothForCausalLMLoss( # pylint: disable=invalid-name
logits,
labels,
vocab_size: int, # pylint: disable=unused-argument
num_items_in_batch: int = None,
ignore_index: int = -100, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
)
return loss
if model_type == "llama":
from transformers.loss import loss_utils
forward = get_forward_code()
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
forward = forward.replace(
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
)
forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"",
)
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace(
"def forward(",
"def fast_cross_entropy_loss_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
else:
raise ValueError("Unsupported model type")

View File

@@ -1,51 +0,0 @@
"""
Fused MLP layer for incrementally improved training efficiency
"""
import torch
from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import set_module_name
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)

View File

@@ -6,7 +6,7 @@ import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
LOG = logging.getLogger("axolotl.prompt_strategies")
def load(strategy, tokenizer, cfg, ds_cfg):

View File

@@ -2,18 +2,13 @@
Bradley-Terry model with chat template prompt strategy.
"""
import logging
from typing import Any, Dict, Optional
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import get_chat_template_from_config
# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)
from axolotl.utils.chat_templates import chat_templates
class BTChatTemplateStrategy(ChatTemplateStrategy):
@@ -32,24 +27,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
prompt[self.messages].append({"from": "system", "value": prompt["system"]})
prompt[self.messages].append({"from": "user", "value": prompt["input"]})
prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]})
rejected_tokenized = super().tokenize_prompt(prompt)
return {
@@ -64,18 +53,15 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "from"),
"message_field_content": ds_cfg.get("message_field_content", "value"),
"message_field_training": ds_cfg.get("message_field_training", "training"),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", None
"message_field_training_detail", "train_detail"
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
@@ -88,8 +74,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", None),
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
strategy = BTChatTemplateStrategy(

View File

@@ -9,7 +9,7 @@ from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.chat_templates import chat_templates
# Configure the logger
LOG = logging.getLogger("axolotl")
@@ -405,14 +405,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),

View File

@@ -2,16 +2,15 @@
DPO prompt strategies for using tokenizer chat templates.
"""
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
from axolotl.utils.chat_templates import chat_templates
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
chat_template_str = chat_templates(cfg.chat_template)
field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
@@ -31,12 +30,6 @@ def default(
role_map[source] = target
def transform_fn(sample, tokenizer=None):
chat_template_string = get_chat_template(
user_choice=chat_template_choice,
jinja_template=chat_template_jinja,
tokenizer=tokenizer,
)
messages = sample[field_messages]
messages = [
{
@@ -53,29 +46,28 @@ def default(
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)
result["chosen"] = tokenizer.apply_chat_template(
[dummy_user_message, chosen],
[chosen],
add_generation_prompt=False,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
[dummy_user_message, rejected],
[rejected],
add_generation_prompt=False,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.chat_templates import chat_templates
class Message(BaseModel):
@@ -28,13 +28,18 @@ def load(
"""
chatml transforms for datasets with system, input, chosen, rejected
"""
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
tokenizer.chat_template = chat_template_string
chat_template = chat_templates("chatml")
if ds_cfg and "chat_template" in ds_cfg:
chat_template = ds_cfg["chat_template"]
try:
chat_template = chat_templates(chat_template)
except ValueError:
pass
tokenizer.chat_template = chat_template
return ORPOTokenizingStrategy(
ORPOPrompter(chat_template_string, tokenizer),
ORPOPrompter(chat_template, tokenizer),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
@@ -243,30 +248,28 @@ class ORPOPrompter(Prompter):
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
dataset_parser = ORPODatasetParsingStrategy()
chat_template_str = chat_templates(cfg.chat_template)
def transform_fn(sample, tokenizer=None):
res = {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, tokenizer=tokenizer
)
res["prompt"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
add_generation_prompt=True,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)
prompt_str_len = len(res["prompt"])
res["chosen"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]
res["rejected"] = tokenizer.apply_chat_template(
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
add_generation_prompt=False,
chat_template=chat_template_string,
chat_template=chat_template_str,
tokenize=False,
)[prompt_str_len:]

View File

@@ -62,7 +62,7 @@ def build_loader(
):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
LOG.warning(
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.",
)
conversation = (
ds_cfg["conversation"]

View File

@@ -260,10 +260,8 @@ def train(
if not cfg.hub_model_id:
try:
trainer.create_model_card(
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
)
except (AttributeError, UnicodeDecodeError):
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
except AttributeError:
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated

File diff suppressed because one or more lines are too long

View File

@@ -228,7 +228,6 @@ def normalize_cfg_datasets(cfg):
f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
)
cfg.datasets[idx].chat_template = cfg.chat_template
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):

View File

@@ -8,16 +8,9 @@ import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import (
BaseModel,
Field,
StringConstraints,
conlist,
field_validator,
model_validator,
)
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
@@ -28,38 +21,6 @@ LOG = logging.getLogger("axolotl.utils.config.models.input")
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
@@ -144,19 +105,13 @@ class SFTDataset(BaseModel):
input_transform: Optional[str] = None
shards: Optional[int] = None
conversation: Optional[str] = None
# Do not make this too strict or it will break the validator to choose different dataset class
chat_template: Optional[
Union[
ChatTemplate,
str,
]
] = None
chat_template_jinja: Optional[str] = None
chat_template: Optional[str] = None
data_files: Optional[Union[str, List[str]]] = None
input_format: Optional[str] = None
name: Optional[str] = None
ds_type: Optional[str] = None
train_on_split: Optional[str] = None
field: Optional[str] = None
field_human: Optional[str] = None
field_model: Optional[str] = None
@@ -167,32 +122,13 @@ class SFTDataset(BaseModel):
message_field_training_detail: Optional[str] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
# Set chat_template to tokenizer_default if not set
if data.get("type") == "chat_template" and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.tokenizer_default
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
@@ -238,6 +174,35 @@ class KTODataset(BaseModel):
revision: Optional[str] = None
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
alpaca = "alpaca" # pylint: disable=invalid-name
chatml = "chatml" # pylint: disable=invalid-name
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
@@ -584,7 +549,6 @@ class AxolotlInputConfig(
resume_from_checkpoint: Optional[str] = None
auto_resume_from_checkpoints: Optional[bool] = None
resize_token_embeddings_to_32x: Optional[bool] = None
mean_resizing_embeddings: Optional[bool] = False
rl: Optional[RLType] = None
reward_model: Optional[bool] = None
@@ -754,13 +718,7 @@ class AxolotlInputConfig(
gpu_memory_limit: Optional[Union[int, str]] = None
low_cpu_mem_usage: Optional[bool] = None
chat_template: Optional[
Union[
ChatTemplate,
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
]
] = None
chat_template_jinja: Optional[str] = None
chat_template: Optional[ChatTemplate] = None
default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None
@@ -869,23 +827,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_wo_flash(cls, data):

View File

@@ -2,11 +2,9 @@
import functools
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple, Union
import requests
from datasets import (
Dataset,
DatasetDict,
@@ -55,28 +53,6 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
prompters = []
if not cfg.pretraining_dataset:

View File

@@ -16,7 +16,3 @@ def setup_mlflow_env_vars(cfg: DictDefault):
# Enable mlflow if experiment name is present
if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
cfg.use_mlflow = True
# Enable logging hf artifacts in mlflow if value is truthy
if cfg.hf_mlflow_log_artifacts is True:
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "true"

File diff suppressed because it is too large Load Diff

View File

@@ -133,8 +133,6 @@ class MultipackBatchSampler(BatchSampler):
self.eff_total_used = 0
self.eff_total_slots = 0
self.len_across_ranks = None
def set_epoch(self, epoch: int):
self.epoch = epoch
@@ -197,14 +195,15 @@ class MultipackBatchSampler(BatchSampler):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(0.998 * min(estimates))
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
min_len_batches = reduce_and_broadcast(
lambda: num,
calc_min_len,
)
return min_len_batches
def __len__(self):
if not self.len_across_ranks:
len_batches = self.num_batches()
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks
len_batches = self.num_batches()
return self.gather_len_batches(len_batches)
def _len_est(self):
efficiency = (

View File

@@ -1,155 +0,0 @@
"""
E2E tests for multigpu eval
"""
import logging
import os
import unittest
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
class TestMultiGPUEval(unittest.TestCase):
"""
Test case for MultiGPU Eval Sample Packing
"""
@with_temp_dir
def test_eval_sample_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": True,
"pad_to_sequence_len": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
"warmup_steps": 1,
"evals_per_epoch": 2,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_eval(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
"warmup_steps": 1,
"evals_per_epoch": 2,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -14,7 +14,7 @@ from huggingface_hub import snapshot_download
from axolotl.utils.dict import DictDefault
from ..utils import is_hopper, with_temp_dir
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -59,7 +59,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -116,7 +116,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 50,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -144,146 +144,6 @@ class TestMultiGPULlama(unittest.TestCase):
]
)
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
@with_temp_dir
def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_dpo_qlora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM-135M",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_fsdp(self, temp_dir):
# pylint: disable=duplicate-code
@@ -305,7 +165,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -371,7 +231,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -413,6 +273,7 @@ class TestMultiGPULlama(unittest.TestCase):
]
)
@pytest.mark.skip("disabled due to upstream issue")
@with_temp_dir
def test_fsdp_qlora_prequant_packed(self, temp_dir):
# pylint: disable=duplicate-code
@@ -421,7 +282,6 @@ class TestMultiGPULlama(unittest.TestCase):
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
"tokenizer_type": "AutoTokenizer",
"adapter": "qlora",
"mean_resizing_embeddings": True,
"load_in_4bit": True,
"lora_r": 8,
"lora_alpha": 16,
@@ -437,7 +297,7 @@ class TestMultiGPULlama(unittest.TestCase):
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "</s>",
"pad_token": "<|end_of_text|>",
},
"datasets": [
{
@@ -447,7 +307,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -513,7 +373,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -572,7 +432,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,

View File

@@ -47,7 +47,7 @@ class TestMultiGPUQwen2(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 15,
"max_steps": 100,
"warmup_steps": 20,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import require_torch_2_3_1, with_temp_dir
from ..utils import require_torch_2_1_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -24,7 +24,7 @@ class Test4dMultipackLlama(unittest.TestCase):
Test case for Llama models using 4d attention with multipack
"""
@require_torch_2_3_1
@require_torch_2_1_1
@with_temp_dir
def test_sdp_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -1,12 +1,22 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
from axolotl.monkeypatch.unsloth_ import (
check_cel_is_patchable,
check_self_attn_is_patchable,
)
class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""
def test_is_cel_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_cel_is_patchable(),
"HF transformers loss code has changed and isn't patchable",
)
def test_is_self_attn_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(

View File

@@ -1,95 +0,0 @@
"""Module for testing ModelLoader."""
import shutil
import tempfile
import pytest
import torch
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
@pytest.fixture(name="temp_dir")
def fixture_temp_dir():
temp_dir = tempfile.mkdtemp()
yield temp_dir
shutil.rmtree(temp_dir)
class TestLoadModelUtils:
"""
Testing module testing ModelLoader.
"""
def setup_method(self):
# load config
self.cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"tokenizer_config": "JackFram/llama-68m",
"sequence_len": 1024,
"load_in_8bit": False,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
ModelLoader(
cfg=self.cfg,
tokenizer="",
)
)
@pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"])
@pytest.mark.parametrize(
"dist_dtype", [torch.bfloat16, torch.float16, torch.float32]
)
@pytest.mark.parametrize("before_kbit_train_or_finetune", [True, False])
def test_convert_embedding_modules_dtype(
self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune
):
self.cfg.output_dir = temp_dir
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
self.model_loader.model, _ = load_model(
self.cfg,
self.model_loader.tokenizer,
inference=False,
reference_model=True,
)
self.model_loader.convert_embedding_modules_dtype(
embedding_modules, dist_dtype, before_kbit_train_or_finetune
)
for name, module in self.model_loader.model.named_modules():
if (
"norm" in name
or (before_kbit_train_or_finetune and name.endswith(".gate"))
or (
any(m in name for m in embedding_modules)
and hasattr(module, "weight")
)
):
for _, param in module.named_parameters():
assert param.dtype == dist_dtype

View File

@@ -1,74 +0,0 @@
"""
E2E tests for packed training
"""
import logging
import os
import unittest
from tbparse import SummaryReader
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import most_recent_subdir, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPackedLlama(unittest.TestCase):
"""
Test case for Packed training of llama models
"""
@with_temp_dir
def test_loss_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM-135M",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.0, "Loss is too high"

View File

@@ -9,8 +9,6 @@ from functools import wraps
from importlib.metadata import version
from pathlib import Path
import torch
def with_temp_dir(test_func):
@wraps(test_func)
@@ -37,18 +35,13 @@ def most_recent_subdir(path):
return subdir
def require_torch_2_3_1(test_case):
def require_torch_2_1_1(test_case):
"""
Decorator marking a test that requires torch >= 2.3.1
Decorator marking a test that requires torch >= 2.1.1
"""
def is_min_2_3_1():
def is_min_2_1_1():
torch_version = version("torch")
return torch_version >= "2.3.1"
return torch_version >= "2.1.1"
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
def is_hopper():
compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0)
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)

View File

@@ -1,125 +0,0 @@
"""
Tests for utils in axolotl.utils.chat_templates
"""
import unittest
import pytest
from transformers import AutoTokenizer
from axolotl.utils.chat_templates import (
_CHAT_TEMPLATES,
extract_chat_template_args,
get_chat_template,
)
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
return tokenizer
class TestGetChatTemplateUtils:
"""
Tests the get_chat_template function.
"""
def test_known_chat_template(self):
chat_template_str = get_chat_template("llama3")
assert chat_template_str == _CHAT_TEMPLATES["llama3"]
def test_invalid_chat_template(self):
with pytest.raises(ValueError) as exc:
get_chat_template("invalid_template")
assert str(exc) == "Template 'invalid_template' not found."
def test_tokenizer_default_no_tokenizer(self):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default", tokenizer=None)
def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer)
def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
llama3_tokenizer.chat_template = "test_template"
chat_template_str = get_chat_template(
"tokenizer_default", tokenizer=llama3_tokenizer
)
assert chat_template_str == "test_template"
def test_tokenizer_default_fallback_no_tokenizer(self):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default_fallback_test", tokenizer=None)
def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
self, llama3_tokenizer
):
chat_template_str = get_chat_template(
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
)
assert chat_template_str == get_chat_template("chatml")
def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
self, llama3_tokenizer
):
llama3_tokenizer.chat_template = "test_template"
chat_template_str = get_chat_template(
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
)
assert chat_template_str == "test_template"
def test_jinja_template_mode(self):
jinja_template = "example_jinja_template"
chat_template_str = get_chat_template("jinja", jinja_template=jinja_template)
assert chat_template_str == jinja_template
def test_jinja_template_mode_no_jinja_template(self):
with pytest.raises(ValueError):
get_chat_template("jinja", jinja_template=None)
def test_extract_chat_template_args(self):
# No ds_cfg
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={"chat_template": "chatml"},
)
assert chat_template_choice == "chatml"
assert chat_template_jinja is None
# ds_cfg provided
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={
"chat_template": "jinja",
"chat_template_jinja": "global_jinja_template",
},
ds_cfg={"chat_template": "llama3", "chat_template_jinja": None},
)
assert chat_template_choice == "llama3"
assert chat_template_jinja is None
# ds_cfg provided with jinja template
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={"chat_template": "chatml", "chat_template_jinja": None},
ds_cfg={
"chat_template": "jinja",
"chat_template_jinja": "ds_jinja_template",
},
)
assert chat_template_choice == "jinja"
assert chat_template_jinja == "ds_jinja_template"
# ds_cfg provided with no chat_template
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={
"chat_template": "jinja",
"chat_template_jinja": "global_jinja_template",
},
ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"},
)
assert chat_template_choice == "jinja"
assert chat_template_jinja == "global_jinja_template"
if __name__ == "__main__":
unittest.main()

View File

@@ -11,7 +11,7 @@ from axolotl.prompt_strategies.chat_template import (
load,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.dict import DictDefault
logging.basicConfig(level=logging.DEBUG)
@@ -73,7 +73,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
chat_template=chat_templates("llama3"),
message_field_role="role",
message_field_content="content",
roles={
@@ -113,7 +113,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
phi35_tokenizer,
chat_template=get_chat_template("phi_35"),
chat_template=chat_templates("phi_35"),
message_field_role="role",
message_field_content="content",
roles={
@@ -171,7 +171,7 @@ class TestAssistantChatTemplateLlama3:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
chat_template=chat_templates("llama3"),
message_field_role="role",
message_field_content="content",
message_field_training="training",
@@ -230,7 +230,7 @@ class TestSharegptChatTemplateLlama3:
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -283,7 +283,7 @@ class TestSharegptChatTemplateLlama3:
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -336,7 +336,7 @@ class TestSharegptChatTemplateLlama3:
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,

View File

@@ -12,7 +12,7 @@ from axolotl.prompt_strategies.chat_template import (
ChatTemplateStrategy,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.chat_templates import chat_templates
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
@@ -35,7 +35,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_inputs=True")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
@@ -80,7 +80,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_inputs=False")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -123,7 +123,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing roles_to_train with assistant only")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -151,7 +151,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing roles_to_train with all roles")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=True,
@@ -184,7 +184,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with empty roles_to_train")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -205,7 +205,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='all'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -232,7 +232,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='turn'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -282,7 +282,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='last'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -315,7 +315,7 @@ class TestChatTemplateConfigurations:
LOG.info("Testing with train_on_eos='none'")
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer, chat_template=get_chat_template("llama3")
llama3_tokenizer, chat_template=chat_templates("llama3")
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -343,7 +343,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
chat_template=chat_templates("llama3"),
drop_system_message=True,
),
tokenizer=llama3_tokenizer,
@@ -371,7 +371,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
chat_template=chat_templates("llama3"),
roles=custom_roles,
),
tokenizer=llama3_tokenizer,
@@ -424,7 +424,7 @@ class TestChatTemplateConfigurations:
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
chat_template=chat_templates("llama3"),
message_field_training="train",
message_field_training_detail="train_detail",
),

View File

@@ -86,20 +86,6 @@ def fixture_llama3_tokenizer():
return tokenizer
@pytest.fixture(name="phi3_tokenizer")
def fixture_phi3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
return tokenizer
@pytest.fixture(name="gemma_tokenizer")
def fixture_gemma_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")
return tokenizer
class TestAssistantDPOChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
@@ -113,7 +99,7 @@ class TestAssistantDPOChatTemplateLlama3:
"chat_template": "llama3",
"datasets": [
{
"type": "chat_template",
"chat_template": "llama3",
}
],
}
@@ -138,7 +124,7 @@ class TestAssistantDPOChatTemplateLlama3:
"chat_template": "llama3",
"datasets": [
{
"type": "chat_template",
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "better",
"field_rejected": "worse",
@@ -166,65 +152,5 @@ class TestAssistantDPOChatTemplateLlama3:
assert result["rejected"] == "party on<|eot_id|>"
class TestAssistantDPOChatTemplatePhi3:
"""
Test class for assistant style datasets with phi-3 prompts using the tokenizer's chat_template strategy.
"""
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "tokenizer_default",
"datasets": [
{
"type": "chat_template",
}
],
}
)
)
result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer)
assert result["prompt"] == (
"<|user|>\nhello<|end|>\n"
+ "<|assistant|>\nhello<|end|>\n"
+ "<|user|>\ngoodbye<|end|>\n"
+ "<|assistant|>\n"
)
assert result["chosen"] == "goodbye<|end|>"
assert result["rejected"] == "party on<|end|>"
class TestAssistantDPOChatTemplateGemma:
"""
Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy.
"""
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "tokenizer_default",
"datasets": [
{
"type": "chat_template",
}
],
}
)
)
result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer)
assert result["prompt"] == (
"<bos><start_of_turn>user\nhello<end_of_turn>\n"
+ "<start_of_turn>model\nhello<end_of_turn>\n"
+ "<start_of_turn>user\ngoodbye<end_of_turn>\n"
+ "<start_of_turn>model\n"
)
assert result["chosen"] == "goodbye<end_of_turn>"
assert result["rejected"] == "party on<end_of_turn>"
if __name__ == "__main__":
unittest.main()

View File

@@ -367,44 +367,43 @@ class TestDatasetPreparation(unittest.TestCase):
def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
with tempfile.TemporaryDirectory() as tmp_dir2:
tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
f"{tmp_ds_path}/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":

View File

@@ -13,7 +13,6 @@ from axolotl.utils import is_comet_available
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import check_model_config
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -1433,58 +1432,3 @@ class TestValidationComet(BaseValidation):
for key in comet_env.keys():
os.environ.pop(key, None)
class TestValidationMLflow(BaseValidation):
"""
Validation test for MLflow
"""
def test_hf_mlflow_artifacts_config_sets_env(self, minimal_cfg):
cfg = (
DictDefault(
{
"hf_mlflow_log_artifacts": True,
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.hf_mlflow_log_artifacts is True
# Check it's not already present in env
assert "HF_MLFLOW_LOG_ARTIFACTS" not in os.environ
setup_mlflow_env_vars(new_cfg)
assert os.environ.get("HF_MLFLOW_LOG_ARTIFACTS") == "true"
os.environ.pop("HF_MLFLOW_LOG_ARTIFACTS", None)
def test_mlflow_not_used_by_default(self, minimal_cfg):
cfg = DictDefault({}) | minimal_cfg
new_cfg = validate_config(cfg)
setup_mlflow_env_vars(new_cfg)
assert cfg.use_mlflow is not True
cfg = (
DictDefault(
{
"mlflow_experiment_name": "foo",
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
setup_mlflow_env_vars(new_cfg)
assert new_cfg.use_mlflow is True
os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)

View File

@@ -1,238 +0,0 @@
"""Module for testing the validation module for the dataset config"""
import warnings
from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
from axolotl.utils.dict import DictDefault
warnings.filterwarnings("error")
@pytest.fixture(name="minimal_cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"learning_rate": 0.000001,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
# pylint: disable=too-many-public-methods (duplicate-code)
class BaseValidation:
"""
Base validation module to setup the log capture
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
class TestValidationCheckDatasetConfig(BaseValidation):
"""
Test the validation for the dataset config to ensure no correct parameters are dropped
"""
def test_dataset_config_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "sharegpt",
"conversation": "chatml",
"shards": 10,
}
]
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()
def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template is None
assert (
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()
def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"chat_template": "chatml",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template == ChatTemplate.chatml
assert (
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()
def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"chat_template": "chatml",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"chat_template": "gemma",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template == cfg.chat_template
assert (
checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()

View File

@@ -1,64 +1,18 @@
"""Module for testing models utils file."""
from unittest.mock import MagicMock, patch
import unittest
from unittest.mock import patch
import pytest
from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils.import_utils import is_torch_mps_available
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model
from axolotl.utils.models import load_model
class TestModelsUtils:
class ModelsUtilsTest(unittest.TestCase):
"""Testing module for models utils."""
def setup_method(self) -> None:
# load config
self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init
{
"base_model": "JackFram/llama-68m",
"model_type": "LlamaForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"load_in_8bit": True,
"load_in_4bit": False,
"adapter": "lora",
"flash_attention": False,
"sample_packing": True,
"device_map": "auto",
}
)
self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init
spec=PreTrainedTokenizerBase
)
self.inference = False # pylint: disable=attribute-defined-outside-init
self.reference_model = True # pylint: disable=attribute-defined-outside-init
# init ModelLoader
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
ModelLoader(
cfg=self.cfg,
tokenizer=self.tokenizer,
inference=self.inference,
reference_model=self.reference_model,
)
)
def test_set_device_map_config(self):
# check device_map
device_map = self.cfg.device_map
if is_torch_mps_available():
device_map = "mps"
self.model_loader.set_device_map_config()
if is_deepspeed_zero3_enabled():
assert "device_map" not in self.model_loader.model_kwargs
else:
assert device_map in self.model_loader.model_kwargs["device_map"]
# check torch_dtype
assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
cfg = DictDefault(
{
@@ -81,38 +35,3 @@ class TestModelsUtils:
"shifted-sparse attention does not currently support sample packing"
in str(exc.value)
)
@pytest.mark.parametrize("adapter", ["lora", "qlora", None])
@pytest.mark.parametrize("load_in_8bit", [True, False])
@pytest.mark.parametrize("load_in_4bit", [True, False])
@pytest.mark.parametrize("gptq", [True, False])
def test_set_quantization_config(
self,
adapter,
load_in_8bit,
load_in_4bit,
gptq,
):
# init cfg as args
self.cfg.load_in_8bit = load_in_8bit
self.cfg.load_in_4bit = load_in_4bit
self.cfg.gptq = gptq
self.cfg.adapter = adapter
self.model_loader.set_quantization_config()
if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
assert not (
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
)
elif load_in_8bit and self.cfg.adapter is not None:
assert self.model_loader.model_kwargs["load_in_8bit"]
elif load_in_4bit and self.cfg.adapter is not None:
assert self.model_loader.model_kwargs["load_in_4bit"]
if (self.cfg.adapter == "qlora" and load_in_4bit) or (
self.cfg.adapter == "lora" and load_in_8bit
):
assert self.model_loader.model_kwargs.get(
"quantization_config", BitsAndBytesConfig
)