Merge branch 'main' into cj_tokenizer_default_prompt_template

This commit is contained in:
Chirag Jain
2024-07-23 17:16:49 +05:30
committed by GitHub
42 changed files with 1262 additions and 157 deletions

View File

@@ -37,6 +37,11 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.0 pytorch: 2.3.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3

View File

@@ -19,7 +19,6 @@ jobs:
pytorch: 2.1.2 pytorch: 2.1.2
axolotl_extras: axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118" axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
is_latest: true
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.10" python_version: "3.10"
@@ -33,8 +32,9 @@ jobs:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.0 pytorch: 2.3.1
axolotl_extras: axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -80,7 +80,6 @@ jobs:
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.1.2
axolotl_extras: axolotl_extras:
is_latest: true
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.10" python_version: "3.10"
@@ -94,8 +93,9 @@ jobs:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.0 pytorch: 2.3.1
axolotl_extras: axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -136,7 +136,7 @@ jobs:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.0 pytorch: 2.3.1
axolotl_extras: axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:

View File

@@ -18,7 +18,6 @@ jobs:
pytorch: 2.1.2 pytorch: 2.1.2
axolotl_extras: axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118" axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
is_latest: true
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.10" python_version: "3.10"
@@ -32,8 +31,9 @@ jobs:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.0 pytorch: 2.3.1
axolotl_extras: axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -80,7 +80,6 @@ jobs:
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.2 pytorch: 2.1.2
axolotl_extras: axolotl_extras:
is_latest: true
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.10" python_version: "3.10"
@@ -94,8 +93,9 @@ jobs:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.0 pytorch: 2.3.1
axolotl_extras: axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -57,6 +57,10 @@ jobs:
run: | run: |
pytest --ignore=tests/e2e/ tests/ pytest --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests: docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud' if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
@@ -87,7 +91,7 @@ jobs:
- cuda: 121 - cuda: 121
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.0 pytorch: 2.3.1
num_gpus: 1 num_gpus: 1
steps: steps:
- name: Checkout - name: Checkout
@@ -99,7 +103,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal jinja2 pip install modal==0.63.64 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -46,6 +46,7 @@ Features:
- [Multipack](./docs/multipack.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg> - [Multipack](./docs/multipack.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [RLHF & DPO](./docs/rlhf.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg> - [RLHF & DPO](./docs/rlhf.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg> - [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [Unsloth](./docs/unsloth.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [Common Errors](#common-errors-) - [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training) - [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl) - [Debugging Axolotl](#debugging-axolotl)
@@ -333,7 +334,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field. Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
See [these docs](https://openaccess-ai-collective.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats. See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
### Config ### Config

View File

@@ -36,6 +36,7 @@ website:
- docs/nccl.qmd - docs/nccl.qmd
- docs/mac.qmd - docs/mac.qmd
- docs/multi-node.qmd - docs/multi-node.qmd
- docs/unsloth.qmd
- section: "Dataset Formats" - section: "Dataset Formats"
contents: docs/dataset-formats/* contents: docs/dataset-formats/*
- section: "Reference" - section: "Reference"

View File

@@ -24,13 +24,13 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
fi fi
# So we can test the Docker image # So we can test the Docker image
RUN pip install pytest RUN pip install -r requirements-tests.txt
# fix so that git fetch/pull from remote works # fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \ RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \

View File

@@ -2,5 +2,5 @@
set -e set -e
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/ pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest /workspace/axolotl/tests/e2e/patched/ pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/ pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/

View File

@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \ else \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \ pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
fi fi
# So we can test the Docker image # So we can test the Docker image

19
docs/torchao.qmd Normal file
View File

@@ -0,0 +1,19 @@
---
title: "PyTorch ao"
description: "Custom data types and layouts for training and inference"
---
### Installation
Stable Release from the PyTorch index
```bash
pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
```
Nightly release
```bash
pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
```

49
docs/unsloth.qmd Normal file
View File

@@ -0,0 +1,49 @@
---
title: "Unsloth"
description: "Hyper-optimized QLoRA finetuning for single GPUs"
---
### Overview
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
standard industry baselines.
### Installation
The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
to date libraries.
```bash
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps --force-reinstall xformers==0.0.26.post1
```
### Using unsloth w Axolotl
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
Our unsloth integration is currently limited to the following model architectures:
- llama
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
```yaml
unsloth_lora_mlp: true
unsloth_lora_qkv: true
unsloth_lora_o: true
```
These options are composable and can be used with multi-gpu finetuning
```
unsloth_cross_entropy_loss: true
unsloth_rms_norm: true
unsloth_rope: true
```
### Limitations
- Single GPU only; e.g. no multi-gpu support
- No deepspeed or FSDP support (requires multi-gpu)
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
- No MoE support.

View File

@@ -171,7 +171,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Buy using the ! the comand will be executed as a bash command\n", "# By using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml" "!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
] ]
}, },
@@ -188,7 +188,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Buy using the ! the comand will be executed as a bash command\n", "# By using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n", "!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --qlora_model_dir=\"./qlora-out\" --gradio" " --qlora_model_dir=\"./qlora-out\" --gradio"
] ]

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer

View File

@@ -0,0 +1,81 @@
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
chat_template: llama3
rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
chat_template: llama3
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B-Instruct base_model: NousResearch/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
@@ -15,6 +15,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true pad_to_sequence_len: true
adapter: lora adapter: lora

View File

@@ -1,4 +1,4 @@
base_model: meta-llama/Meta-Llama-3-8B base_model: NousResearch/Meta-Llama-3-8B
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer

View File

@@ -1 +1,2 @@
pytest pytest
pytest-xdist

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.11.1 peft==0.11.1
transformers==4.42.3 transformers @ git+https://github.com/huggingface/transformers.git@0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf
tokenizers==0.19.1 tokenizers==0.19.1
bitsandbytes==0.43.1 bitsandbytes==0.43.1
accelerate==0.32.0 accelerate==0.32.0
@@ -12,11 +12,11 @@ fire
PyYAML>=6.0 PyYAML>=6.0
requests requests
datasets==2.19.1 datasets==2.19.1
flash-attn==2.5.8 flash-attn==2.6.1
sentencepiece sentencepiece
wandb wandb
einops einops
xformers==0.0.26.post1 xformers==0.0.27
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer
colorama colorama

View File

@@ -29,9 +29,10 @@ def parse_requirements():
_install_requires.append(line) _install_requires.append(line)
try: try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
if "Darwin" in platform.system(): if "Darwin" in platform.system():
# don't install xformers on MacOS # don't install xformers on MacOS
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1")) _install_requires.pop(_install_requires.index(xformers_version))
else: else:
# detect the version of torch already installed # detect the version of torch already installed
# and set it so dependencies don't clobber the torch version # and set it so dependencies don't clobber the torch version
@@ -49,12 +50,14 @@ def parse_requirements():
raise ValueError("Invalid version format") raise ValueError("Invalid version format")
if (major, minor) >= (2, 3): if (major, minor) >= (2, 3):
pass if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
elif (major, minor) >= (2, 2): elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1")) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1") _install_requires.append("xformers>=0.0.25.post1")
else: else:
_install_requires.pop(_install_requires.index("xformers==0.0.26.post1")) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1") _install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError: except PackageNotFoundError:
@@ -77,10 +80,10 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn==2.5.8", "flash-attn==2.6.1",
], ],
"fused-dense-lib": [ "fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib", "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b", "deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
@@ -101,5 +104,11 @@ setup(
"galore": [ "galore": [
"galore_torch", "galore_torch",
], ],
"optimizers": [
"galore_torch",
"lion-pytorch==0.1.2",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
}, },
) )

View File

@@ -375,7 +375,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
cfg, cfg,
capabilities={ capabilities={
"bf16": is_torch_bf16_gpu_available(), "bf16": is_torch_bf16_gpu_available(),
"n_gpu": os.environ.get("WORLD_SIZE", 1), "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version, "compute_capability": gpu_version,
}, },
) )

View File

@@ -0,0 +1,150 @@
"""
helper functions for fixing the embeddings/tokenizer
"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import itertools
import numpy as np
import torch
@torch.inference_mode
def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
"""
Many of the newer models have reserved tokens that are not trained.
"""
embedding_matrix = model.get_input_embeddings().weight
lm_head_matrix = model.get_output_embeddings().weight
# Get untrained tokens
indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
where_untrained = torch.where(indicator_untrained)[0]
n_untrained = where_untrained.shape[0]
n_trained = embedding_matrix.shape[0] - n_untrained
# Get set and actual tokens
where_untrained = where_untrained.tolist()
if len(where_untrained) == 0:
return False
# Remove untrained indices where it's longer
where_untrained_set = frozenset(where_untrained)
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
# Remove None items in actual_bad_tokens
actual_bad_tokens = [x for x in actual_bad_tokens if x is not None]
# Check if tokenizer and training datasets have bad tokens
if_bad_first = False
if_bad_second = False
# Check tokenizer's chat template for any untrained tokens
chat_template = getattr(tokenizer, "chat_template", None)
if chat_template is not None:
if_bad_first = any(x in chat_template for x in actual_bad_tokens)
# Check the first 250, last 250 input_ids
size_dataset = len(train_dataset)
size = min(size_dataset, 250)
for j in range(size):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
if_bad = any(item in where_untrained_set for item in input_ids)
if if_bad:
if_bad_second = True
break
# Check last 250
if not if_bad_second:
left = max(size_dataset - 250, 0)
for j in range(left, size_dataset):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
if_bad = any(item in where_untrained_set for item in input_ids)
if if_bad:
if_bad_second = True
break
# Check if bad tokens exists!
if not if_bad_first and not if_bad_second:
return False
# Count all the possible bad tokens
final_counts = np.zeros(
max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
)
def mapping(examples):
input_ids = examples["input_ids"]
counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype=np.int32)
np.add.at(final_counts, counter, 1)
train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
# Get sum of all items
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
# Remove bad tokens
sum_embedding -= torch.sum(
embedding_matrix[where_untrained], dtype=torch.float32, axis=0
)
sum_lm_head -= torch.sum(
lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
)
# Find correct average by dividing by sum of trained tokens
mean_embedding = sum_embedding / n_trained
mean_lm_head = sum_lm_head / n_trained
# Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen
scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
mean_embedding = (
mean_embedding.repeat(
(
n_untrained,
1,
)
)
* scaling
)
mean_lm_head = (
mean_lm_head.repeat(
(
n_untrained,
1,
)
)
* scaling
)
where_null = scaling.ravel() == 0
mean_embedding[where_null] = 0
mean_lm_head[where_null] = 0
# Set them to the mean
embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
# Clean up
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return True

View File

@@ -226,6 +226,12 @@ class AxolotlTrainingMixins:
default=None, default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"}, metadata={"help": "whether to use sequential sampling for curriculum learning"},
) )
alternate_optimizer: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
@dataclass @dataclass
@@ -284,26 +290,91 @@ class AxolotlTrainer(Trainer):
if self.args.orpo_alpha: if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer(self): def create_optimizer(self):
if self.args.loraplus_lr_ratio is None: if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
):
return super().create_optimizer() return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args, self.args,
opt_model, opt_model,
) )
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) if self.args.loraplus_lr_ratio is not None:
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init loraplus_lr_embedding = getattr(
opt_model, self.args, "loraplus_lr_embedding", None
optimizer_cls, )
optimizer_kwargs, self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
loraplus_lr_ratio, opt_model,
loraplus_lr_embedding, optimizer_cls,
) optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -1235,6 +1306,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[ training_arguments_kwargs[
"torch_compile_backend" "torch_compile_backend"
] = self.cfg.torch_compile_backend ] = self.cfg.torch_compile_backend
if self.cfg.torch_compile_mode:
training_arguments_kwargs[
"torch_compile_mode"
] = self.cfg.torch_compile_mode
# DDP Config # DDP Config
if self.cfg.ddp_timeout: if self.cfg.ddp_timeout:
@@ -1396,6 +1471,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {} trainer_kwargs = {}
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
if self.cfg.optimizer == "lion_pytorch": if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion from lion_pytorch import Lion
@@ -1424,6 +1509,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
sys.path.append(self.cfg.torchdistx_path) sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx") importlib.import_module("torchdistx")
if self.cfg.accelerator_config:
training_arguments_kwargs[
"accelerator_config"
] = self.cfg.accelerator_config
training_args = ( training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs, **training_arguments_kwargs,
@@ -1621,6 +1711,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# trl does some odd mapping of alpha to beta to reuse the beta parameter ??? # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_cls = AxolotlDPOConfig training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None: if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
@@ -1688,8 +1779,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["max_target_length"] = None dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True dpo_trainer_kwargs["generate_during_eval"] = True
if self.cfg.rl == "dpo":
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl == "orpo": elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model] trainer_cls_args = [self.model]

View File

View File

@@ -78,6 +78,33 @@ def replace_llama_qkv_with_fused(model):
set_module_name(model, name, qkv) set_module_name(model, name, qkv)
def patch_llama_cross_entropy():
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
def patch_llama_rms_norm():
try:
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
def replace_llama_attn_with_flash_attn( def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False, packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False, cross_entropy: Optional[bool] = False,
@@ -104,35 +131,11 @@ def replace_llama_attn_with_flash_attn(
# skip only if explicitly disabled # skip only if explicitly disabled
if cross_entropy: if cross_entropy:
try: patch_llama_cross_entropy()
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.warning(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
# skip only if explicitly disabled # skip only if explicitly disabled
if rms_norm: if rms_norm:
try: patch_llama_rms_norm()
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
class FusedAttention(LlamaAttention): class FusedAttention(LlamaAttention):

View File

@@ -2,6 +2,7 @@
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import logging import logging
from functools import partial
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
@@ -45,6 +46,15 @@ def replace_mistral_attn_with_flash_attn(
) )
def patch_mistral_cross_entropy():
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
@torch.jit.script @torch.jit.script
def _make_sliding_window_causal_mask( def _make_sliding_window_causal_mask(
bsz: int, bsz: int,

View File

@@ -10,6 +10,8 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"llama",
"mistral",
"mixtral", "mixtral",
"qwen2", "qwen2",
"qwen2_moe", "qwen2_moe",
@@ -24,12 +26,35 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
def patch_for_multipack(model_type, model_name=None): def patch_for_multipack(model_type, model_name=None):
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
elif hasattr(transformers, "modeling_flash_attention_utils"):
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
return
# retain for legacy
if model_type == "mixtral": if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data
) )
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3() patch_mixtral_moe_forward_zero3()
elif model_type == "llama":
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "mistral":
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2": elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data
@@ -58,12 +83,6 @@ def patch_for_multipack(model_type, model_name=None):
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data
) )
elif model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "jamba":
patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
def patch_remote(model_name, config_name, modeling_name): def patch_remote(model_name, config_name, modeling_name):

View File

@@ -1,18 +1,20 @@
"""module for patching with unsloth optimizations""" """module for patching with unsloth optimizations"""
import inspect import inspect
import logging
import re import re
import types import types
from typing import Tuple from typing import Tuple
import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
LlamaFlashAttention2, LlamaFlashAttention2,
LlamaForCausalLM, LlamaForCausalLM,
) )
LOG = logging.getLogger("axolotl.monkeypatch.unsloth") LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """ if labels is not None: ORIGINAL_CEL_CODE = """ if labels is not None:
# Shift so that tokens < n predict n # Shift so that tokens < n predict n
@@ -97,48 +99,51 @@ def check_self_attn_is_patchable() -> bool:
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
def integrate_cross_entropy_loss_patch(): def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
forward = get_forward_code() if model_type == "llama":
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access forward = get_forward_code()
forward, _ = detab_code(forward) LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found" forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
forward = forward.replace( forward = forward.replace(
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", "" "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
) )
forward = forward.replace( forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)", "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"", "",
) )
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE) forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace( forward = forward.replace(
"def forward(", "def forward(",
"def fast_cross_entropy_loss_forward(", "def fast_cross_entropy_loss_forward(",
1, 1,
) )
# load imports necessary # load imports necessary
import transformers.models.llama.modeling_llama import transformers.models.llama.modeling_llama
items_to_import = [] items_to_import = []
for item in dir(transformers.models.llama.modeling_llama): for item in dir(transformers.models.llama.modeling_llama):
if item in forward: if item in forward:
items_to_import.append(item) items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102 exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss", "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(), globals(),
) )
exec( # pylint: disable=exec-used # nosec B102 exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import (" "from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import) + ", ".join(x for x in items_to_import)
+ ")", + ")",
globals(), globals(),
) )
exec(forward, globals()) # pylint: disable=exec-used # nosec B102 exec(forward, globals()) # pylint: disable=exec-used # nosec B102
print("patching unsloth fast_cross_entropy_loss") LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821 LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
else:
raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]: def detab_code(code: str) -> Tuple[str, str]:
@@ -179,12 +184,30 @@ def patch_self_attn_lora():
globals(), globals(),
) )
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
print("patching unsloth attn lora") LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = ( LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
) )
def integrate_rope_embeddings():
import transformers.models.llama.modeling_llama
from unsloth.kernels.rope_embedding import fast_rope_embedding
def apply_rotary_pos_emb( # pylint: disable=unused-argument
q, # pylint: disable=invalid-name
k, # pylint: disable=invalid-name
cos,
sin,
position_ids=None,
unsqueeze_dim=1,
):
return fast_rope_embedding(q, k, cos, sin)
LOG.info("patching unsloth RoPE embeddings", main_process_only=True)
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM): def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if peft_model.base_model.config.model_type in ["llama", "mistral"]: if peft_model.base_model.config.model_type in ["llama", "mistral"]:
from unsloth.kernels import apply_lora_mlp_swiglu from unsloth.kernels import apply_lora_mlp_swiglu
@@ -217,7 +240,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if is_mlp_lora and mlp_no_bias and mlp_not_dora: if is_mlp_lora and mlp_no_bias and mlp_not_dora:
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp) layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
else: else:
logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx) LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
@@ -243,9 +266,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_qkv = apply_lora_qkv layer.self_attn.apply_qkv = apply_lora_qkv
else: else:
layer.self_attn.apply_qkv = original_apply_qkv layer.self_attn.apply_qkv = original_apply_qkv
logging.warning( LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx)
"unable to apply unsloth lora qkv patch to layer %d", idx
)
if cfg.unsloth_lora_o: if cfg.unsloth_lora_o:
layer_modules = [ layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"] getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
@@ -264,6 +285,33 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_o = apply_lora_o layer.self_attn.apply_o = apply_lora_o
else: else:
layer.self_attn.apply_o = original_apply_o layer.self_attn.apply_o = original_apply_o
logging.warning( LOG.warning(
"unable to apply unsloth lora o_proj patch to layer %d", idx "unable to apply unsloth lora o_proj patch to layer %d", idx
) )
def patch_unsloth_layernorm():
try:
import transformers.models.llama.modeling_llama
from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
class LlamaRMSNorm(nn.Module):
"""LlamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
return Fast_RMS_Layernorm.apply(
hidden_states, self.weight, self.variance_epsilon, False
)
LOG.info("patching with unsloth.kernels.rms_layernorm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning("missing unsloth library")

View File

@@ -0,0 +1,78 @@
"""
DPO prompt strategies for using tokenizer chat templates.
"""
from axolotl.utils.chat_templates import chat_templates
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
chat_template_str = chat_templates(cfg.chat_template)
field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
field_message_role = ds_cfg.get("message_field_role", "role")
field_message_content = ds_cfg.get("message_field_content", "content")
role_map_inv = ds_cfg.get(
"roles",
{
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
)
role_map = {}
for target, sources in role_map_inv.items():
for source in sources:
role_map[source] = target
def transform_fn(sample, tokenizer=None):
messages = sample[field_messages]
messages = [
{
"role": role_map[m[field_message_role]],
"content": m[field_message_content],
}
for m in messages
]
chosen = {
"role": role_map[sample[field_chosen][field_message_role]],
"content": sample[field_chosen][field_message_content],
}
rejected = {
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=False,
)
result["chosen"] = tokenizer.apply_chat_template(
[chosen],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:]
result["rejected"] = tokenizer.apply_chat_template(
[rejected],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])
result["rejected"] = result["rejected"][rejected_strip_index:]
return result
return transform_fn

View File

@@ -19,6 +19,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.core.tokenizer_utils import fix_untrained_tokens
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.freeze import freeze_layers_except
@@ -52,6 +53,15 @@ class TrainDatasetMeta:
def train( def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# enable expandable segments for cuda allocation to improve VRAM usage
torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ[
"PYTORCH_CUDA_ALLOC_CONF"
] = "expandable_segments:True,roundup_power2_divisions:16"
# load the tokenizer first # load the tokenizer first
LOG.debug( LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
@@ -114,6 +124,13 @@ def train(
total_num_steps, total_num_steps,
) )
if cfg.fix_untrained_tokens:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
# go ahead and presave, so we have the adapter config available to inspect # go ahead and presave, so we have the adapter config available to inspect
if peft_config: if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")

View File

@@ -9,6 +9,7 @@ import os
from enum import Enum from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from importlib.metadata import version
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
Field, Field,
@@ -84,6 +85,7 @@ class PretrainingDataset(BaseModel):
split: Optional[str] = "train" split: Optional[str] = "train"
text_column: Optional[str] = "text" text_column: Optional[str] = "text"
type: Optional[str] = "pretrain" type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False
class UserDefinedPrompterType(BaseModel): class UserDefinedPrompterType(BaseModel):
@@ -125,6 +127,8 @@ class SFTDataset(BaseModel):
roles: Optional[Dict[str, List[str]]] = None roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False
class UserDefinedDPOType(BaseModel): class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO""" """User defined typing for DPO"""
@@ -165,6 +169,7 @@ class KTODataset(BaseModel):
split: Optional[str] = None split: Optional[str] = None
type: Optional[Union[UserDefinedKTOType, str]] = None type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None data_files: Optional[List[str]] = None
trust_remote_code: Optional[bool] = False
class RLType(str, Enum): class RLType(str, Enum):
@@ -350,7 +355,16 @@ class HyperparametersConfig(BaseModel):
learning_rate: Union[str, float] learning_rate: Union[str, float]
weight_decay: Optional[float] = 0.0 weight_decay: Optional[float] = 0.0
optimizer: Optional[ optimizer: Optional[
Union[OptimizerNames, Literal["lion_pytorch"]] Union[
OptimizerNames,
Literal[
"lion_pytorch",
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
],
]
] = OptimizerNames.ADAMW_HF.value ] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field( optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."} default=None, metadata={"help": "Optional arguments to supply to optimizer."}
@@ -513,6 +527,8 @@ class AxolotlInputConfig(
dataloader_prefetch_factor: Optional[int] = None dataloader_prefetch_factor: Optional[int] = None
dataloader_drop_last: Optional[bool] = None dataloader_drop_last: Optional[bool] = None
accelerator_config: Optional[Dict[str, Any]] = None
remove_unused_columns: Optional[bool] = None remove_unused_columns: Optional[bool] = None
push_dataset_to_hub: Optional[str] = None push_dataset_to_hub: Optional[str] = None
@@ -599,6 +615,8 @@ class AxolotlInputConfig(
unsloth_lora_mlp: Optional[bool] = None unsloth_lora_mlp: Optional[bool] = None
unsloth_lora_qkv: Optional[bool] = None unsloth_lora_qkv: Optional[bool] = None
unsloth_lora_o: Optional[bool] = None unsloth_lora_o: Optional[bool] = None
unsloth_rms_norm: Optional[bool] = None
unsloth_rope: Optional[bool] = None
deepspeed: Optional[Union[str, Dict[str, Any]]] = None deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None fsdp: Optional[List[str]] = None
@@ -611,6 +629,9 @@ class AxolotlInputConfig(
torch_compile: Optional[bool] = None torch_compile: Optional[bool] = None
torch_compile_backend: Optional[str] = None torch_compile_backend: Optional[str] = None
torch_compile_mode: Optional[
Literal["default", "reduce-overhead", "max-autotune"]
] = None
max_steps: Optional[int] = None max_steps: Optional[int] = None
warmup_steps: Optional[int] = None warmup_steps: Optional[int] = None
@@ -651,6 +672,8 @@ class AxolotlInputConfig(
] = None ] = None
default_system_message: Optional[str] = None default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None
# INTERNALS - document for now, generally not set externally # INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None is_preprocess: Optional[bool] = None
@@ -716,6 +739,24 @@ class AxolotlInputConfig(
) )
return data return data
@model_validator(mode="before")
@classmethod
def check_pretraining_split_batches_accelerate(cls, data):
# alternatively set ACCELERATE_SPLIT_BATCHES=False
if data.get("pretraining_dataset"):
accelerator_config = data.get("accelerator_config", {})
if not accelerator_config:
data["accelerator_config"] = {
"split_batches": False,
"dispatch_batches": False,
}
else:
if accelerator_config.get("split_batches") is None:
data["accelerator_config"]["split_batches"] = False
if accelerator_config.get("dispatch_batches") is None:
data["accelerator_config"]["dispatch_batches"] = False
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_gptq_w_revision(cls, data): def check_gptq_w_revision(cls, data):
@@ -834,7 +875,7 @@ class AxolotlInputConfig(
@model_validator(mode="after") @model_validator(mode="after")
def check_adamw_optimizer_params(self): def check_adamw_optimizer_params(self):
if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and ( if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
not self.optimizer or "adamw" not in self.optimizer.value not self.optimizer or "adamw" not in str(self.optimizer).lower()
): ):
LOG.warning("adamw hyperparameters found, but no adamw optimizer set") LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self return self
@@ -1126,6 +1167,55 @@ class AxolotlInputConfig(
raise ValueError("either datasets or pretraining_dataset is required") raise ValueError("either datasets or pretraining_dataset is required")
return data return data
@model_validator(mode="before")
@classmethod
def check_xentropy_patch_conflicts(cls, data):
if data.get("flash_attn_cross_entropy") and data.get(
"unsloth_cross_entropy_loss"
):
raise ValueError(
"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
)
return data
@model_validator(mode="before")
@classmethod
def check_qlora_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_unsloth_xformers_version(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
xformers_version = version("xformers")
if xformers_version == "0.0.27":
raise ValueError(
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
if data.get("deepspeed") and data.get("torch_compile"):
raise ValueError(
"torch_compile should be set within your deepspeed config file"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig): class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options""" """wrapper to valdiate gpu capabilities with the configured options"""
@@ -1177,3 +1267,18 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("deepspeed") and data.get("fsdp"): if data.get("deepspeed") and data.get("fsdp"):
raise ValueError("deepspeed and fsdp cannot be used together.") raise ValueError("deepspeed and fsdp cannot be used together.")
return data return data
@model_validator(mode="before")
@classmethod
def check_multigpu_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
capabilities = data.get("capabilities")
if capabilities and capabilities.get("n_gpu", 0) > 1:
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
)
return data

View File

@@ -1,4 +1,5 @@
"""data handling specific to DPO""" """data handling specific to DPO"""
import inspect import inspect
import logging import logging
from functools import partial from functools import partial

View File

@@ -1,7 +1,7 @@
"""Module for models and model loading""" """Module for models and model loading"""
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
import gc
import logging import logging
import math import math
import os import os
@@ -94,7 +94,7 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef
"Please make sure to point to a GPTQ model." "Please make sure to point to a GPTQ model."
) )
if not cfg.gptq and quant_config_exists: if not cfg.gptq and quant_config_exists and not cfg.load_in_4bit:
raise ValueError( raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. " "model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model." "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
@@ -347,6 +347,31 @@ def load_model(
and cfg.sample_packing and cfg.sample_packing
): ):
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model) patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
if cfg.is_llama_derived_model:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_llama_cross_entropy,
patch_llama_rms_norm,
)
if cfg.flash_attn_cross_entropy:
patch_llama_cross_entropy()
if cfg.flash_attn_rms_norm:
patch_llama_rms_norm()
elif cfg.unsloth_rms_norm:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm()
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import (
integrate_cross_entropy_loss_patch,
)
integrate_cross_entropy_loss_patch(model_type="llama")
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
elif cfg.is_llama_derived_model: elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block # Modify all llama derived models in one block
@@ -371,6 +396,12 @@ def load_model(
rms_norm=cfg.flash_attn_rms_norm, rms_norm=cfg.flash_attn_rms_norm,
use_shifted_sparse_attn=True, use_shifted_sparse_attn=True,
) )
elif cfg.flash_attn_cross_entropy or cfg.flash_attn_rms_norm:
replace_llama_attn_with_flash_attn(
packed=False,
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
)
elif cfg.xformers_attention: elif cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import ( from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention, hijack_llama_attention,
@@ -393,7 +424,7 @@ def load_model(
if cfg.unsloth_cross_entropy_loss: if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch() integrate_cross_entropy_loss_patch(model_type="llama")
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -401,23 +432,12 @@ def load_model(
patch_self_attn_lora() patch_self_attn_lora()
# Modify mistral derived models # Modify mistral derived models
if ( if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
cfg.model_config_type == "mistral"
and cfg.flash_attention
and cfg.sample_packing
):
from axolotl.monkeypatch.mistral_attn_hijack_flash import ( from axolotl.monkeypatch.mistral_attn_hijack_flash import (
replace_mistral_attn_with_flash_attn, patch_mistral_cross_entropy,
) )
LOG.info("patching mistral with flash attention") patch_mistral_cross_entropy()
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask")
hijack_expand_mask()
model_kwargs: Dict[str, Any] = {} model_kwargs: Dict[str, Any] = {}
@@ -599,9 +619,12 @@ def load_model(
and not cfg.trust_remote_code and not cfg.trust_remote_code
and not cfg.gptq and not cfg.gptq
): ):
from transformers import LlamaForCausalLM if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
model = LlamaForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=model_config, config=model_config,
**model_kwargs, **model_kwargs,
@@ -634,7 +657,11 @@ def load_model(
base_model, base_model,
**model_kwargs, **model_kwargs,
) )
elif model_type and not cfg.trust_remote_code: elif (
model_type
and model_type != "AutoModelForCausalLM"
and not cfg.trust_remote_code
):
if cfg.gptq: if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
@@ -675,6 +702,7 @@ def load_model(
) )
else: else:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading: if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
# disabling either of these two still leads to VRAM spike before setting back down
skip_move_to_device = True skip_move_to_device = True
if "device_map" in model_kwargs: if "device_map" in model_kwargs:
del model_kwargs["device_map"] del model_kwargs["device_map"]
@@ -849,6 +877,15 @@ def load_model(
integrate_lora_patch(model, cfg) integrate_lora_patch(model, cfg)
if cfg.unsloth_rope:
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
integrate_rope_embeddings()
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
# TODO resume_from_checkpoint handling # TODO resume_from_checkpoint handling
return model, lora_config return model, lora_config

View File

@@ -62,7 +62,7 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
"""Helper function to process and color tokens.""" """Helper function to process and color tokens."""
colored_tokens = [ colored_tokens = [
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only) color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
for token in tokenizer.encode(tokens) for token in tokenizer.encode(tokens, add_special_tokens=False)
] ]
return colored_tokens return colored_tokens

View File

@@ -189,9 +189,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
max_input_len = np.max(get_dataset_lengths(train_dataset)) max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True) LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if ( if cfg.model_config_type == "mamba":
cfg.is_mistral_derived_model and cfg.flash_attention
) or cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column") LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask") train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset: if eval_dataset:

View File

@@ -0,0 +1,87 @@
"""
E2E tests for lora llama
"""
import logging
import os
import unittest
from importlib import reload
from pathlib import Path
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.fixture(autouse=True)
def reload_transformers():
import transformers.models.llama.modeling_llama
yield
reload(transformers.models.llama.modeling_llama)
class TestFAXentropyLlama(unittest.TestCase):
"""
Test case for Llama models using LoRA w multipack
"""
@with_temp_dir
def test_lora_packing_fa_cross_entropy(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"flash_attn_cross_entropy": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.2,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -4,6 +4,8 @@ E2E smoke tests to check that the monkeypatches are in place for certain configu
import unittest import unittest
import transformers
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -87,9 +89,9 @@ class TestModelPatches(unittest.TestCase):
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) load_model(cfg, tokenizer, inference=cli_args.inference)
assert ( assert (
"axolotl.monkeypatch.mistral_attn_hijack_flash" "torch.jit"
in model.model.layers[0].self_attn.forward.__module__ in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access
) )

View File

@@ -0,0 +1,67 @@
"""
E2E tests for llama pretrain
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPretrainLlama(unittest.TestCase):
"""
Test case for Llama models w pretraining
"""
@with_temp_dir
def test_pretrain_w_sample_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": True,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"pretraining_dataset": [
{
"path": "allenai/c4",
"name": "en",
"type": "pretrain",
}
],
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -34,8 +34,8 @@ class TestLoraLlama(unittest.TestCase):
"sequence_len": 1024, "sequence_len": 1024,
"load_in_8bit": True, "load_in_8bit": True,
"adapter": "lora", "adapter": "lora",
"lora_r": 32, "lora_r": 8,
"lora_alpha": 64, "lora_alpha": 16,
"lora_dropout": 0.05, "lora_dropout": 0.05,
"lora_target_linear": True, "lora_target_linear": True,
"val_set_size": 0.1, "val_set_size": 0.1,
@@ -50,7 +50,7 @@ class TestLoraLlama(unittest.TestCase):
"type": "alpaca", "type": "alpaca",
}, },
], ],
"num_epochs": 2, "num_epochs": 1,
"micro_batch_size": 8, "micro_batch_size": 8,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": temp_dir, "output_dir": temp_dir,

View File

@@ -0,0 +1,67 @@
"""
E2E tests for custom optimizers using Llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestCustomOptimizers(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_optimi_adamw(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "optimi_adamw",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -0,0 +1,156 @@
"""
tests for chat_template prompt strategy
"""
import unittest
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"messages": [
{
"role": "user",
"content": "hello",
},
{
"role": "assistant",
"content": "hello",
},
{
"role": "user",
"content": "goodbye",
},
],
"chosen": {
"role": "assistant",
"content": "goodbye",
},
"rejected": {
"role": "assistant",
"content": "party on",
},
}
]
)
@pytest.fixture(name="custom_assistant_dataset")
def fixture_custom_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
"conversation": [
{
"speaker": "human",
"text": "hello",
},
{
"speaker": "agent",
"text": "hello",
},
{
"speaker": "human",
"text": "goodbye",
},
],
"better": {
"speaker": "agent",
"text": "goodbye",
},
"worse": {
"speaker": "agent",
"text": "party on",
},
}
]
)
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
tokenizer.eos_token = "<|eot_id|>"
return tokenizer
class TestAssistantDPOChatTemplateLlama3:
"""
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
"""
def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "llama3",
"datasets": [
{
"chat_template": "llama3",
}
],
}
)
)
result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer)
assert result["prompt"] == (
"<|begin_of_text|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert result["chosen"] == "goodbye<|eot_id|>"
assert result["rejected"] == "party on<|eot_id|>"
def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
DictDefault(
{
"chat_template": "llama3",
"datasets": [
{
"chat_template": "llama3",
"field_messages": "conversation",
"field_chosen": "better",
"field_rejected": "worse",
"message_field_role": "speaker",
"message_field_content": "text",
"roles": {
"user": ["human"],
"assistant": ["agent"],
"system": ["sys"],
},
}
],
}
)
)
result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer)
assert result["prompt"] == (
"<|begin_of_text|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
)
assert result["chosen"] == "goodbye<|eot_id|>"
assert result["rejected"] == "party on<|eot_id|>"
if __name__ == "__main__":
unittest.main()

View File

@@ -24,7 +24,7 @@ class TestPretrainingPacking(unittest.TestCase):
def test_packing_stream_dataset(self): def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
dataset = load_dataset( dataset = load_dataset(
"c4", "allenai/c4",
"en", "en",
streaming=True, streaming=True,
)["train"] )["train"]
@@ -33,7 +33,7 @@ class TestPretrainingPacking(unittest.TestCase):
{ {
"pretraining_dataset": [ "pretraining_dataset": [
{ {
"path": "c4", "path": "allenai/c4",
"name": "en", "name": "en",
"type": "pretrain", "type": "pretrain",
} }